From 31e652bdbd95ee46808445a1915ec52e09a097c2 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Fri, 29 Nov 2024 11:00:53 +0000 Subject: [PATCH] [RTG] Elaboration support for get_size operations --- include/circt/Dialect/RTG/IR/RTGVisitors.h | 10 +-- .../circt/Dialect/RTG/Transforms/RTGPasses.td | 1 + .../RTG/Transforms/ElaborationPass.cpp | 61 +++++++++++++++++++ test/Dialect/RTG/Transform/elaboration.mlir | 23 +++++++ tools/circt-opt/CMakeLists.txt | 1 + tools/circt-opt/circt-opt.cpp | 2 + 6 files changed, 94 insertions(+), 4 deletions(-) diff --git a/include/circt/Dialect/RTG/IR/RTGVisitors.h b/include/circt/Dialect/RTG/IR/RTGVisitors.h index f21b0a12c423..60b4341ecf3f 100644 --- a/include/circt/Dialect/RTG/IR/RTGVisitors.h +++ b/include/circt/Dialect/RTG/IR/RTGVisitors.h @@ -111,10 +111,10 @@ class RTGTypeVisitor { ResultType dispatchTypeVisitor(Type type, ExtraArgs... args) { auto *thisCast = static_cast(this); return TypeSwitch(type) - .template Case( - [&](auto expr) -> ResultType { - return thisCast->visitType(expr, args...); - }) + .template Case([&](auto expr) -> ResultType { + return thisCast->visitType(expr, args...); + }) .template Case( [&](auto expr) -> ResultType { return thisCast->visitContextResourceType(expr, args...); @@ -158,6 +158,8 @@ class RTGTypeVisitor { HANDLE(SetType, Unhandled); HANDLE(BagType, Unhandled); HANDLE(DictType, Unhandled); + HANDLE(IndexType, Unhandled); + HANDLE(IntegerType, Unhandled); #undef HANDLE }; diff --git a/include/circt/Dialect/RTG/Transforms/RTGPasses.td b/include/circt/Dialect/RTG/Transforms/RTGPasses.td index 6769924a4d81..6542ac032e5b 100644 --- a/include/circt/Dialect/RTG/Transforms/RTGPasses.td +++ b/include/circt/Dialect/RTG/Transforms/RTGPasses.td @@ -27,6 +27,7 @@ def ElaborationPass : Pass<"rtg-elaborate", "mlir::ModuleOp"> { // Define a custom constructor to have more control over the pass options // (e.g., std::optional options are not handled very well). let constructor = "::circt::rtg::createElaborationPass()"; + let dependentDialects = ["mlir::arith::ArithDialect"]; } #endif // CIRCT_DIALECT_RTG_RTGPASSES_TD diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index e9df1962a8ef..956ab7f54f92 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -393,6 +393,43 @@ struct InternMapInfo : public DenseMapInfo { // Main Elaborator Implementation //===----------------------------------------------------------------------===// +/// Construct an SSA value from a given elaborated value. +class Materializer : public RTGTypeVisitor { +public: + using Base = RTGTypeVisitor; + using Base::visitType; + + Value visitUnhandledType(Type type, OpBuilder &builder, Location loc, + ElaboratorValue *val) { + return Value(); + } + + Value visitType(IndexType type, OpBuilder &builder, Location loc, + ElaboratorValue *val) { + auto res = builder.create( + loc, IntegerAttr::get(type, cast(val)->getInt())); + materializedValues[{val, builder.getBlock()}] = res; + return res; + } + + Value materialize(Block *block, Location loc, ElaboratorValue *val) { + if (val->isOpaqueValue()) + return val->getOpaqueValue(); + + auto iter = materializedValues.find({val, block}); + if (iter != materializedValues.end()) + return iter->second; + + OpBuilder builder = OpBuilder::atBlockBegin(block); + return dispatchTypeVisitor(val->getType(), builder, loc, val); + } + +private: + DenseMap, Value> materializedValues; +}; + /// Used to signal to the elaboration driver whether the operation should be /// removed. enum class DeletionKind { Keep, Delete }; @@ -444,6 +481,15 @@ class Elaborator FailureOr visitExternalOp(Operation *op, function_ref addToWorklist) { + for (auto &operand : op->getOpOperands()) { + auto val = materializer.materialize(op->getBlock(), op->getLoc(), + state.at(operand.get())); + if (!val) + return op->emitError("failed to materialize value for operand #") + << operand.getOperandNumber(); + operand.set(val); + } + // Treat values defined by external ops as opaque, non-elaborated values. for (auto res : op->getResults()) internalizeResult(res); @@ -607,6 +653,13 @@ class Elaborator return DeletionKind::Delete; } + FailureOr + visitOp(SetGetSizeOp op, function_ref addToWorklist) { + auto size = cast(state.at(op.getSet()))->getAsArrayRef().size(); + internalizeResult(op.getResult(), size); + return DeletionKind::Delete; + } + FailureOr visitOp(BagCreateOp op, function_ref addToWorklist) { DenseMap bag; @@ -749,6 +802,13 @@ class Elaborator return DeletionKind::Delete; } + FailureOr + visitOp(BagGetSizeOp op, function_ref addToWorklist) { + auto size = cast(state.at(op.getBag()))->getBag().size(); + internalizeResult(op.getResult(), size); + return DeletionKind::Delete; + } + FailureOr dispatchOpVisitor(Operation *op, function_ref addToWorklist) { @@ -858,6 +918,7 @@ class Elaborator // A map from SSA values to a pointer of an interned elaborator value. DenseMap state; + Materializer materializer; SymbolTable symTable; }; diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index c031e0937017..af1b05e616e5 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -62,6 +62,29 @@ rtg.test @bagOperations : !rtg.dict<> { rtg.invoke_sequence %seq2 } +// CHECK-LABEL: rtg.test @setSize +rtg.test @setSize : !rtg.dict<> { + // CHECK-NEXT: [[C:%.+]] = arith.constant 1 : index + // CHECK-NEXT: index.add [[C]], [[C]] + // CHECK-NEXT: } + %c5_i32 = arith.constant 5 : i32 + %set = rtg.set_create %c5_i32 : i32 + %size = rtg.set_get_size %set : !rtg.set + index.add %size, %size +} + +// CHECK-LABEL: rtg.test @bagSize +rtg.test @bagSize : !rtg.dict<> { + // CHECK-NEXT: [[C:%.+]] = arith.constant 1 : index + // CHECK-NEXT: index.add [[C]], [[C]] + // CHECK-NEXT: } + %c8 = arith.constant 8 : index + %c5_i32 = arith.constant 5 : i32 + %bag = rtg.bag_create (%c8 x %c5_i32) : i32 + %size = rtg.bag_get_size %bag : !rtg.bag + index.add %size, %size +} + // CHECK-LABEL: rtg.sequence @seq3 rtg.sequence @seq3 { ^bb0(%arg0: !rtg.set): diff --git a/tools/circt-opt/CMakeLists.txt b/tools/circt-opt/CMakeLists.txt index b933045537bd..aedd45fad84b 100644 --- a/tools/circt-opt/CMakeLists.txt +++ b/tools/circt-opt/CMakeLists.txt @@ -38,6 +38,7 @@ target_link_libraries(circt-opt MLIREmitCDialect MLIRFuncInlinerExtension MLIRVectorDialect + MLIRIndexDialect ) export_executable_symbols_for_plugins(circt-opt) diff --git a/tools/circt-opt/circt-opt.cpp b/tools/circt-opt/circt-opt.cpp index 9c4be5a329d1..d65d8c0015b5 100644 --- a/tools/circt-opt/circt-opt.cpp +++ b/tools/circt-opt/circt-opt.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -55,6 +56,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); circt::registerAllDialects(registry); circt::registerAllPasses();