diff --git a/include/circt/Dialect/RTG/IR/ArithVisitors.h b/include/circt/Dialect/RTG/IR/ArithVisitors.h new file mode 100644 index 000000000000..7accac4a47a3 --- /dev/null +++ b/include/circt/Dialect/RTG/IR/ArithVisitors.h @@ -0,0 +1,126 @@ +//===- ArithVisitors.h - Arith Dialect Visitors -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines visitors that make it easier to work with Arith Ops. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_RTG_IR_ARITHVISITORS_H +#define CIRCT_DIALECT_RTG_IR_ARITHVISITORS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace arith { + +/// This helps visit TypeOp nodes. +template +class ArithOpVisitor { +public: + ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args) { + auto *thisCast = static_cast(this); + return TypeSwitch(op) + .template Case( + [&](auto expr) -> ResultType { + return thisCast->visitOp(expr, args...); + }) + .Default([&](auto expr) -> ResultType { + if (op->getDialect() == + op->getContext()->getLoadedDialect()) { + return visitInvalidTypeOp(op, args...); + } + return thisCast->visitExternalOp(op, args...); + }); + } + + /// This callback is invoked on any RTG operations not handled properly by the + /// TypeSwitch. + ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args) { + op->emitOpError("Unknown Arith operation: ") << op->getName(); + abort(); + } + + /// This callback is invoked on any operations that are not + /// handled by the concrete visitor. + ResultType visitUnhandledOp(Operation *op, ExtraArgs... args); + + ResultType visitExternalOp(Operation *op, ExtraArgs... args) { + return ResultType(); + } + +#define HANDLE(OPTYPE, OPKIND) \ + ResultType visitOp(OPTYPE op, ExtraArgs... args) { \ + return static_cast(this)->visit##OPKIND##Op(op, args...); \ + } + + HANDLE(ConstantOp, Unhandled); + HANDLE(AddIOp, Unhandled); + HANDLE(AddUIExtendedOp, Unhandled); + HANDLE(SubIOp, Unhandled); + HANDLE(MulIOp, Unhandled); + HANDLE(MulSIExtendedOp, Unhandled); + HANDLE(MulUIExtendedOp, Unhandled); + HANDLE(DivUIOp, Unhandled); + HANDLE(DivSIOp, Unhandled); + HANDLE(CeilDivUIOp, Unhandled); + HANDLE(CeilDivSIOp, Unhandled); + HANDLE(FloorDivSIOp, Unhandled); + HANDLE(RemUIOp, Unhandled); + HANDLE(RemSIOp, Unhandled); + HANDLE(AndIOp, Unhandled); + HANDLE(OrIOp, Unhandled); + HANDLE(XOrIOp, Unhandled); + HANDLE(ShLIOp, Unhandled); + HANDLE(ShRUIOp, Unhandled); + HANDLE(ShRSIOp, Unhandled); + HANDLE(NegFOp, Unhandled); + HANDLE(AddFOp, Unhandled); + HANDLE(SubFOp, Unhandled); + HANDLE(MaximumFOp, Unhandled); + HANDLE(MaxNumFOp, Unhandled); + HANDLE(MaxSIOp, Unhandled); + HANDLE(MaxUIOp, Unhandled); + HANDLE(MinimumFOp, Unhandled); + HANDLE(MinNumFOp, Unhandled); + HANDLE(MinSIOp, Unhandled); + HANDLE(MinUIOp, Unhandled); + HANDLE(MulFOp, Unhandled); + HANDLE(DivFOp, Unhandled); + HANDLE(RemFOp, Unhandled); + HANDLE(ExtUIOp, Unhandled); + HANDLE(ExtSIOp, Unhandled); + HANDLE(ExtFOp, Unhandled); + HANDLE(TruncIOp, Unhandled); + HANDLE(TruncFOp, Unhandled); + HANDLE(UIToFPOp, Unhandled); + HANDLE(SIToFPOp, Unhandled); + HANDLE(FPToUIOp, Unhandled); + HANDLE(FPToSIOp, Unhandled); + HANDLE(IndexCastOp, Unhandled); + HANDLE(IndexCastUIOp, Unhandled); + HANDLE(BitcastOp, Unhandled); + HANDLE(CmpIOp, Unhandled); + HANDLE(CmpFOp, Unhandled); + HANDLE(SelectOp, Unhandled); +#undef HANDLE +}; + +} // namespace arith +} // namespace mlir + +#endif // CIRCT_DIALECT_RTG_IR_ARITHVISITORS_H diff --git a/lib/Dialect/RTG/Transforms/CMakeLists.txt b/lib/Dialect/RTG/Transforms/CMakeLists.txt index cd363a383510..70dcd46cb5f9 100644 --- a/lib/Dialect/RTG/Transforms/CMakeLists.txt +++ b/lib/Dialect/RTG/Transforms/CMakeLists.txt @@ -9,6 +9,7 @@ add_circt_dialect_library(CIRCTRTGTransforms LINK_LIBS PRIVATE CIRCTRTGDialect + MLIRArithDialect MLIRIR MLIRPass ) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 4dc6e5b0d664..d0991bd29fe0 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -13,6 +13,7 @@ // //===----------------------------------------------------------------------===// +#include "circt/Dialect/RTG/IR/ArithVisitors.h" #include "circt/Dialect/RTG/IR/RTGOps.h" #include "circt/Dialect/RTG/IR/RTGVisitors.h" #include "circt/Dialect/RTG/Transforms/RTGPasses.h" @@ -228,6 +229,53 @@ class SequenceClosureValue : public ElaboratorValue { SmallVector args; }; +/// Holds an evaluated value of an `IndexType` or `IntegerType`'d value. +/// TODO: support integers with more than 64 bits +class IntegerValue : public ElaboratorValue { +public: + IntegerValue(Value value, uint64_t integer) + : ElaboratorValue(value, false), integer(integer) { + assert((isa(value.getType()) && + value.getType().getIntOrFloatBitWidth() <= 64) || + isa(value.getType())); + } + + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue *val) { + return !val->isOpaqueValue() && + (IndexType::classof(val->getType()) || + (IntegerType::classof(val->getType()) && + val->getType().getIntOrFloatBitWidth() <= 64)); + } + + bool containsOpaqueValue() const override { return false; } + + llvm::hash_code getHashValue() const override { + return llvm::hash_combine(integer, getType()); + } + + bool isEqual(const ElaboratorValue &other) const override { + auto *intVal = dyn_cast(&other); + if (!intVal) + return false; + + return integer == intVal->integer && getType() == intVal->getType(); + } + + std::string toString() const override { + std::string out; + llvm::raw_string_ostream stream(out); + stream << ""; + return out; + } + + uint64_t getInt() const { return integer; } + +private: + uint64_t integer; +}; + //===----------------------------------------------------------------------===// // Hash Map Helpers //===----------------------------------------------------------------------===// @@ -269,11 +317,19 @@ struct InternMapInfo : public DenseMapInfo { enum class DeletionKind { Keep, Delete }; /// Interprets the IR to perform and lower the represented randomizations. -class Elaborator : public RTGOpVisitor, - function_ref> { +class Elaborator + : public RTGOpVisitor, + function_ref>, + public mlir::arith::ArithOpVisitor, + function_ref> { public: - using RTGOpVisitor, - function_ref>::visitOp; + using RTGBase = RTGOpVisitor, + function_ref>; + using ArithBase = ArithOpVisitor, + function_ref>; + + using ArithBase::visitOp; + using RTGBase::visitOp; Elaborator(SymbolTable &table, const ElaborationOptions &options) : options(options), symTable(table) { @@ -313,6 +369,20 @@ class Elaborator : public RTGOpVisitor, return DeletionKind::Keep; } + FailureOr + visitOp(arith::ConstantOp op, function_ref addToWorklist) { + if (auto val = dyn_cast(op.getValue())) { + if (val.getValue().getBitWidth() <= 64 && + !val.getType().isSignedInteger()) { + internalizeResult(op.getResult(), + val.getValue().getZExtValue()); + return DeletionKind::Delete; + } + } + + return visitExternalOp(op, addToWorklist); + } + FailureOr visitOp(SequenceClosureOp op, function_ref addToWorklist) { SmallVector args; @@ -431,6 +501,15 @@ class Elaborator : public RTGOpVisitor, return DeletionKind::Delete; } + FailureOr + dispatchOpVisitor(Operation *op, + function_ref addToWorklist) { + if (op->getDialect() == op->getContext()->getLoadedDialect()) + return RTGBase::dispatchOpVisitor(op, addToWorklist); + + return ArithBase::dispatchOpVisitor(op, addToWorklist); + } + LogicalResult elaborate(TestOp testOp) { LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating Test @" << testOp.getSymName() << "\n\n"); diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index 362ab1ba8313..1c38be9c4d9d 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: rtg.sequence @seq0 rtg.sequence @seq0 { - %2 = arith.constant 2 : i32 + %2 = hw.constant 2 : i32 } // CHECK-LABEL: rtg.sequence @seq2 @@ -19,8 +19,8 @@ rtg.sequence @seq2 { // Test the set operations and passing a sequence ot another one via argument // CHECK-LABEL: rtg.test @setOperations rtg.test @setOperations : !rtg.dict<> { - // CHECK-NEXT: arith.constant 2 : i32 - // CHECK-NEXT: arith.constant 2 : i32 + // CHECK-NEXT: hw.constant 2 : i32 + // CHECK-NEXT: hw.constant 2 : i32 // CHECK-NEXT: } %0 = rtg.sequence_closure @seq0 %1 = rtg.sequence_closure @seq2(%0 : !rtg.sequence) @@ -41,8 +41,8 @@ rtg.sequence @seq3 { // CHECK-LABEL: rtg.test @setArguments rtg.test @setArguments : !rtg.dict<> { - // CHECK-NEXT: arith.constant 2 : i32 - // CHECK-NEXT: arith.constant 2 : i32 + // CHECK-NEXT: hw.constant 2 : i32 + // CHECK-NEXT: hw.constant 2 : i32 // CHECK-NEXT: } %0 = rtg.sequence_closure @seq0 %1 = rtg.sequence_closure @seq2(%0 : !rtg.sequence) @@ -70,37 +70,44 @@ rtg.test @noNullOperands : !rtg.dict<> { } rtg.target @target0 : !rtg.dict { - %0 = arith.constant 0 : i32 + %0 = hw.constant 0 : i32 rtg.yield %0 : i32 } rtg.target @target1 : !rtg.dict { - %0 = arith.constant 1 : i32 + %0 = hw.constant 1 : i32 rtg.yield %0 : i32 } // CHECK-LABEL: @targetTest_target0 -// CHECK: [[V0:%.+]] = arith.constant 0 -// CHECK: arith.addi [[V0]], [[V0]] +// CHECK: [[V0:%.+]] = hw.constant 0 +// CHECK: comb.add [[V0]], [[V0]] // CHECK-LABEL: @targetTest_target1 -// CHECK: [[V0:%.+]] = arith.constant 1 -// CHECK: arith.addi [[V0]], [[V0]] +// CHECK: [[V0:%.+]] = hw.constant 1 +// CHECK: comb.add [[V0]], [[V0]] rtg.test @targetTest : !rtg.dict { ^bb0(%arg0: i32): - arith.addi %arg0, %arg0 : i32 + comb.add %arg0, %arg0 : i32 } // CHECK-NOT: @unmatchedTest rtg.test @unmatchedTest : !rtg.dict { ^bb0(%arg0: i64): - arith.addi %arg0, %arg0 : i64 + comb.add %arg0, %arg0 : i64 +} + +// CHECK-LABEL: rtg.test @arithConstant +rtg.test @arithConstant : !rtg.dict<> { + %0 = arith.constant 2 : index + %1 = arith.constant 2 : i32 + // CHECK-NEXT: } } // ----- rtg.test @opaqueValuesAndSets : !rtg.dict<> { - %0 = arith.constant 2 : i32 + %0 = hw.constant 2 : i32 // expected-error @below {{cannot create a set of opaque values because they cannot be reliably uniqued}} %1 = rtg.set_create %0 : i32 } @@ -108,7 +115,7 @@ rtg.test @opaqueValuesAndSets : !rtg.dict<> { // ----- rtg.sequence @seq0 { - %2 = arith.constant 2 : i32 + %2 = hw.constant 2 : i32 } // Test that the elaborator value interning works as intended and exercise 'set_select_random' error messages.