From 48495144adba9f28467fcec35955d3e671baf16f Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Mon, 27 Jan 2025 16:28:33 +0000 Subject: [PATCH] [RTG] Support partial sequence substitutions --- include/circt/Dialect/RTG/IR/RTGOps.td | 52 +++++++---- include/circt/Dialect/RTG/IR/RTGVisitors.h | 6 +- .../Bindings/Python/dialects/rtg.py | 9 +- lib/Dialect/RTG/IR/RTGOps.cpp | 93 +++++++++++++++++-- .../RTG/Transforms/ElaborationPass.cpp | 26 ++++-- test/CAPI/rtg-pipelines.c | 11 ++- test/Dialect/RTG/IR/basic.mlir | 22 +++-- test/Dialect/RTG/IR/errors.mlir | 53 ++++++++++- test/Dialect/RTG/Transform/elaboration.mlir | 44 +++++---- 9 files changed, 240 insertions(+), 76 deletions(-) diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index c9aa3c0e390a..0de59484fcd9 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -59,30 +59,22 @@ def SequenceOp : RTGOp<"sequence", [ let hasRegionVerifier = 1; } -def SequenceClosureOp : RTGOp<"sequence_closure", [ +def GetSequenceOp : RTGOp<"get_sequence", [ Pure, DeclareOpInterfaceMethods ]> { - let summary = "create a sequence closure with the provided arguments"; + let summary = "create a sequence value"; let description = [{ - This operation creates a closure object for the provided sequence and - arguments. This allows sequences to be passed around as an SSA value. - For example, it can be inserted into a set and selected at random which - is one of the main ways to do randomization. Not having to deal with - sequence arguments after randomly selecting a sequence simplifies the - problem of coming up with values to pass as arguments, but also provides a - way for the user to constrain the arguments at the location where they are - added to the set. In the future, it can also be possible to add sequence - handles directly to a set and randomly pick arguments at the invokation - site. + This operation creates a sequence value referring to the provided sequence + by symbol. It allows sequences to be passed around as an SSA value. For + example, it can be inserted into a set and selected at random which is one + of the main ways to do randomization. }]; - let arguments = (ins SymbolNameAttr:$sequence, Variadic:$args); - let results = (outs FullySubstitutedSequenceType:$ref); + let arguments = (ins SymbolNameAttr:$sequence); + let results = (outs SequenceType:$ref); - let assemblyFormat = [{ - $sequence (`(` $args^ `:` qualified(type($args)) `)`)? attr-dict - }]; + let assemblyFormat = "$sequence `:` qualified(type($ref)) attr-dict"; } def RandomizeSequenceOp : RTGOp<"randomize_sequence", []> { @@ -105,11 +97,35 @@ def RandomizeSequenceOp : RTGOp<"randomize_sequence", []> { }]; let arguments = (ins FullySubstitutedSequenceType:$sequence); + let results = (outs RandomizedSequenceType:$randomizedSequence); let assemblyFormat = "$sequence attr-dict"; } +def SubstituteSequenceOp : RTGOp<"substitute_sequence", [ + Pure, + DeclareOpInterfaceMethods, +]> { + let summary = "partially substitute arguments of a sequence family"; + let description = [{ + This operation substitutes the first N of the M >= N arguments of the given + sequence family, where N is the size of provided argument substitution list. + A new sequence (if N == M) or sequence family with M-N will be returned. + + Not having to deal with sequence arguments after randomly selecting a + sequence simplifies the problem of coming up with values to pass as + arguments, but also provides a way for the user to constrain the arguments + at the location where they are added to a set or bag. + }]; + + let arguments = (ins SequenceType:$sequence, Variadic:$replacements); + let results = (outs SequenceType:$result); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + def EmbedSequenceOp : RTGOp<"embed_sequence", []> { let summary = "embed a sequence of instructions into another sequence"; let description = [{ @@ -417,7 +433,7 @@ def TestOp : RTGOp<"test", [ `rtg.sequence`'s body with the exception of the block arguments. The arguments must match the fields of the dict type in the target attribute exactly. The test must not have any additional arguments and cannot be - referenced by an `rtg.sequence_closure` operation. + referenced by an `rtg.get_sequence` operation. }]; let arguments = (ins SymbolNameAttr:$sym_name, diff --git a/include/circt/Dialect/RTG/IR/RTGVisitors.h b/include/circt/Dialect/RTG/IR/RTGVisitors.h index abe12e7cbe1e..13cb2b1a2fbe 100644 --- a/include/circt/Dialect/RTG/IR/RTGVisitors.h +++ b/include/circt/Dialect/RTG/IR/RTGVisitors.h @@ -42,7 +42,8 @@ class RTGOpVisitor { // RTG tests TestOp, TargetOp, YieldOp, // Sequences - SequenceOp, SequenceClosureOp, RandomizeSequenceOp, EmbedSequenceOp, + SequenceOp, GetSequenceOp, SubstituteSequenceOp, + RandomizeSequenceOp, EmbedSequenceOp, // Sets SetCreateOp, SetSelectRandomOp, SetDifferenceOp, SetUnionOp, SetSizeOp>([&](auto expr) -> ResultType { @@ -87,7 +88,8 @@ class RTGOpVisitor { } HANDLE(SequenceOp, Unhandled); - HANDLE(SequenceClosureOp, Unhandled); + HANDLE(GetSequenceOp, Unhandled); + HANDLE(SubstituteSequenceOp, Unhandled); HANDLE(RandomizeSequenceOp, Unhandled); HANDLE(EmbedSequenceOp, Unhandled); HANDLE(SetCreateOp, Unhandled); diff --git a/integration_test/Bindings/Python/dialects/rtg.py b/integration_test/Bindings/Python/dialects/rtg.py index 35827450ce11..355731a48869 100644 --- a/integration_test/Bindings/Python/dialects/rtg.py +++ b/integration_test/Bindings/Python/dialects/rtg.py @@ -53,17 +53,18 @@ circt.register_dialects(ctx) m = Module.create() with InsertionPoint(m.body): - seq = rtg.SequenceOp('sequence_name', - TypeAttr.get(rtg.SequenceType.get([]))) + seq = rtg.SequenceOp('sequence_name', TypeAttr.get(rtg.SequenceType.get())) Block.create_at_start(seq.bodyRegion, []) test = rtg.TestOp('test_name', TypeAttr.get(rtg.DictType.get())) block = Block.create_at_start(test.bodyRegion, []) with InsertionPoint(block): - seq_closure = rtg.SequenceClosureOp('sequence_name', []) + seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name') + rtg.RandomizeSequenceOp(seq_get) # CHECK: rtg.test @test_name : !rtg.dict<> { - # CHECK-NEXT: rtg.sequence_closure + # CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name + # CHECK-NEXT: rtg.randomize_sequence [[SEQ]] # CHECK-NEXT: } print(m) diff --git a/lib/Dialect/RTG/IR/RTGOps.cpp b/lib/Dialect/RTG/IR/RTGOps.cpp index 86e38dfc5553..4f3b8113ed8e 100644 --- a/lib/Dialect/RTG/IR/RTGOps.cpp +++ b/lib/Dialect/RTG/IR/RTGOps.cpp @@ -90,11 +90,11 @@ void SequenceOp::print(OpAsmPrinter &p) { } //===----------------------------------------------------------------------===// -// SequenceClosureOp +// GetSequenceOp //===----------------------------------------------------------------------===// LogicalResult -SequenceClosureOp::verifySymbolUses(SymbolTableCollection &symbolTable) { +GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { SequenceOp seq = symbolTable.lookupNearestSymbolFrom(*this, getSequenceAttr()); if (!seq) @@ -102,14 +102,95 @@ SequenceClosureOp::verifySymbolUses(SymbolTableCollection &symbolTable) { << "'" << getSequence() << "' does not reference a valid 'rtg.sequence' operation"; - if (TypeRange(seq.getSequenceType().getElementTypes()) != - getArgs().getTypes()) - return emitOpError("referenced 'rtg.sequence' op's argument types must " - "match 'args' types"); + if (seq.getSequenceType() != getType()) + return emitOpError("referenced 'rtg.sequence' op's type does not match"); return success(); } +//===----------------------------------------------------------------------===// +// SubstituteSequenceOp +//===----------------------------------------------------------------------===// + +LogicalResult SubstituteSequenceOp::verify() { + if (getReplacements().empty()) + return emitOpError("must at least have one replacement value"); + + if (getReplacements().size() > + getSequence().getType().getElementTypes().size()) + return emitOpError( + "must not have more replacement values than sequence arguments"); + + if (getReplacements().getTypes() != + getSequence().getType().getElementTypes().take_front( + getReplacements().size())) + return emitOpError("replacement types must match the same number of " + "sequence argument types from the front"); + + return success(); +} + +LogicalResult SubstituteSequenceOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + ArrayRef argTypes = + cast(operands[0].getType()).getElementTypes(); + auto seqType = + SequenceType::get(context, argTypes.drop_front(operands.size() - 1)); + inferredReturnTypes.push_back(seqType); + return success(); +} + +ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result) { + OpAsmParser::UnresolvedOperand sequenceRawOperand; + SmallVector replacementsOperands; + Type sequenceRawType; + + if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen()) + return failure(); + + auto replacementsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() || + parser.parseColon() || parser.parseType(sequenceRawType) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (!isa(sequenceRawType)) + return parser.emitError(parser.getNameLoc()) + << "'sequence' must be handle to a sequence or sequence family, but " + "got " + << sequenceRawType; + + if (parser.resolveOperand(sequenceRawOperand, sequenceRawType, + result.operands)) + return failure(); + + if (parser.resolveOperands(replacementsOperands, + cast(sequenceRawType) + .getElementTypes() + .take_front(replacementsOperands.size()), + replacementsOperandsLoc, result.operands)) + return failure(); + + SmallVector inferredReturnTypes; + if (failed(inferReturnTypes( + parser.getContext(), result.location, result.operands, + result.attributes.getDictionary(parser.getContext()), + result.getRawProperties(), result.regions, inferredReturnTypes))) + return failure(); + + result.addTypes(inferredReturnTypes); + return success(); +} + +void SubstituteSequenceOp::print(OpAsmPrinter &p) { + p << ' ' << getSequence() << "(" << getReplacements() + << ") : " << getSequence().getType(); + p.printOptionalAttrDict((*this)->getAttrs(), {}); +} + //===----------------------------------------------------------------------===// // SetCreateOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index e491b2d17a4e..be939b0d2c7f 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -668,9 +668,8 @@ class Materializer { std::queue &elabRequests, function_ref emitError) { elabRequests.push(val); - Value seq = builder.create( - loc, SequenceType::get(builder.getContext(), {}), val->name, - ValueRange{}); + Value seq = builder.create( + loc, SequenceType::get(builder.getContext(), {}), val->name); return builder.create(loc, seq); } @@ -795,14 +794,25 @@ class Elaborator : public RTGOpVisitor> { return visitUnhandledOp(op); } - FailureOr visitOp(SequenceClosureOp op) { + FailureOr visitOp(GetSequenceOp op) { SmallVector replacements; - for (auto replacement : op.getArgs()) + state[op.getResult()] = + sharedState.internalizer.internalize( + op.getSequenceAttr(), std::move(replacements)); + return DeletionKind::Delete; + } + + FailureOr visitOp(SubstituteSequenceOp op) { + auto *seq = get(op.getSequence()); + + SmallVector replacements(seq->args); + for (auto replacement : op.getReplacements()) replacements.push_back(state.at(replacement)); state[op.getResult()] = sharedState.internalizer.internalize( - op.getSequenceAttr(), std::move(replacements)); + seq->familyName, std::move(replacements)); + return DeletionKind::Delete; } @@ -1334,10 +1344,10 @@ LogicalResult ElaborationPass::inlineSequences(TestOp testOp, if (!randSeqOp) return embedOp->emitError("sequence operand not directly defined by " "'rtg.randomize_sequence' op"); - auto getSeqOp = randSeqOp.getSequence().getDefiningOp(); + auto getSeqOp = randSeqOp.getSequence().getDefiningOp(); if (!getSeqOp) return randSeqOp->emitError( - "sequence operand not directly defined by 'rtg.sequence_closure' op"); + "sequence operand not directly defined by 'rtg.get_sequence' op"); auto seqOp = table.lookup(getSeqOp.getSequenceAttr()); diff --git a/test/CAPI/rtg-pipelines.c b/test/CAPI/rtg-pipelines.c index 052c818434a8..adc0bdeba984 100644 --- a/test/CAPI/rtg-pipelines.c +++ b/test/CAPI/rtg-pipelines.c @@ -18,11 +18,12 @@ int main(int argc, char **argv) { mlirDialectHandleRegisterDialect(mlirGetDialectHandle__rtg__(), ctx); MlirModule moduleOp = mlirModuleCreateParse( - ctx, mlirStringRefCreateFromCString("rtg.sequence @seq() {\n" - "}\n" - "rtg.test @test : !rtg.dict<> {\n" - " %0 = rtg.sequence_closure @seq\n" - "}\n")); + ctx, mlirStringRefCreateFromCString( + "rtg.sequence @seq() {\n" + "}\n" + "rtg.test @test : !rtg.dict<> {\n" + " %0 = rtg.get_sequence @seq : !rtg.sequence\n" + "}\n")); if (mlirModuleIsNull(moduleOp)) { printf("ERROR: Could not parse.\n"); mlirContextDestroy(ctx); diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index e53edd58469b..cc4b4d18bd37 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -29,20 +29,22 @@ rtg.sequence @seq1(%arg0: i32, %arg1: !rtg.sequence) { } // CHECK-LABEL: rtg.sequence @seqRandomizationAndEmbedding rtg.sequence @seqRandomizationAndEmbedding() { - // CHECK: [[V0:%.+]] = rtg.sequence_closure @seq0 + // CHECK: [[V0:%.+]] = rtg.get_sequence @seq0 // CHECK: [[C0:%.+]] = arith.constant 0 : i32 - // CHECK: [[V1:%.+]] = rtg.sequence_closure @seq1([[C0]], [[V0]] : i32, !rtg.sequence) - // CHECK: [[V2:%.+]] = rtg.randomize_sequence [[V0]] - // CHECK: [[V3:%.+]] = rtg.randomize_sequence [[V1]] - // CHECK: rtg.embed_sequence [[V2]] + // CHECK: [[V1:%.+]] = rtg.get_sequence @seq1 + // CHECK: [[V2:%.+]] = rtg.substitute_sequence [[V1]]([[C0]], [[V0]]) : !rtg.sequence + // CHECK: [[V3:%.+]] = rtg.randomize_sequence [[V0]] + // CHECK: [[V4:%.+]] = rtg.randomize_sequence [[V2]] // CHECK: rtg.embed_sequence [[V3]] - %0 = rtg.sequence_closure @seq0 + // CHECK: rtg.embed_sequence [[V4]] + %0 = rtg.get_sequence @seq0 : !rtg.sequence %c0_i32 = arith.constant 0 : i32 - %1 = rtg.sequence_closure @seq1(%c0_i32, %0 : i32, !rtg.sequence) - %2 = rtg.randomize_sequence %0 - %3 = rtg.randomize_sequence %1 - rtg.embed_sequence %2 + %1 = rtg.get_sequence @seq1 : !rtg.sequence + %2 = rtg.substitute_sequence %1(%c0_i32, %0) : !rtg.sequence + %3 = rtg.randomize_sequence %0 + %4 = rtg.randomize_sequence %2 rtg.embed_sequence %3 + rtg.embed_sequence %4 } // CHECK-LABEL: @sets diff --git a/test/Dialect/RTG/IR/errors.mlir b/test/Dialect/RTG/IR/errors.mlir index b3b1be9047db..09be5cb5ce1a 100644 --- a/test/Dialect/RTG/IR/errors.mlir +++ b/test/Dialect/RTG/IR/errors.mlir @@ -6,14 +6,59 @@ func.func @seq0() { } // expected-error @below {{'seq0' does not reference a valid 'rtg.sequence' operation}} -rtg.sequence_closure @seq0 +rtg.get_sequence @seq0 : !rtg.sequence // ----- -rtg.sequence @seq0(%arg0: i32) { } +rtg.sequence @seq0(%arg0: index) { } -// expected-error @below {{referenced 'rtg.sequence' op's argument types must match 'args' types}} -rtg.sequence_closure @seq0 +// expected-error @below {{referenced 'rtg.sequence' op's type does not match}} +"rtg.get_sequence"() <{sequence="seq0"}> : () -> !rtg.sequence + +// ----- + +rtg.sequence @seq0(%arg0: index) { } + +%0 = rtg.get_sequence @seq0 : !rtg.sequence +// expected-error @below {{must at least have one replacement value}} +rtg.substitute_sequence %0() : !rtg.sequence + +// ----- + +rtg.sequence @seq0(%arg0: index) { } + +%c = index.constant 0 +%0 = rtg.get_sequence @seq0 : !rtg.sequence +// expected-error @below {{number of operands and types do not match: got 2 operands and 1 types}} +rtg.substitute_sequence %0(%c, %c) : !rtg.sequence + +// ----- + +rtg.sequence @seq0(%arg0: index) { } + +// expected-note @below {{prior use here}} +%c = index.bool.constant true +%0 = rtg.get_sequence @seq0 : !rtg.sequence +// expected-error @below {{use of value '%c' expects different type than prior uses: 'index' vs 'i1'}} +rtg.substitute_sequence %0(%c) : !rtg.sequence + +// ----- + +rtg.sequence @seq0(%arg0: index) { } + +%c = index.constant 0 +%0 = rtg.get_sequence @seq0 : !rtg.sequence +// expected-error @below {{must not have more replacement values than sequence arguments}} +"rtg.substitute_sequence"(%0, %c, %c) : (!rtg.sequence, index, index) -> !rtg.sequence + +// ----- + +rtg.sequence @seq0(%arg0: index) { } + +%c = index.bool.constant true +%0 = rtg.get_sequence @seq0 : !rtg.sequence +// expected-error @below {{replacement types must match the same number of sequence argument types from the front}} +"rtg.substitute_sequence"(%0, %c) : (!rtg.sequence, i1) -> !rtg.sequence // ----- diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index e750a8bfc6f9..26cddbdd92fe 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -115,10 +115,11 @@ rtg.sequence @seq0(%arg0: index) { } rtg.sequence @seq1(%arg0: index) { - %0 = rtg.sequence_closure @seq0(%arg0 : index) - %1 = rtg.randomize_sequence %0 + %0 = rtg.get_sequence @seq0 : !rtg.sequence + %1 = rtg.substitute_sequence %0(%arg0) : !rtg.sequence + %2 = rtg.randomize_sequence %1 func.call @dummy2(%arg0) : (index) -> () - rtg.embed_sequence %1 + rtg.embed_sequence %2 func.call @dummy2(%arg0) : (index) -> () } @@ -129,9 +130,10 @@ rtg.test @nestedSequences : !rtg.dict<> { // CHECK: func.call @dummy2 // CHECK: func.call @dummy2 %0 = index.constant 0 - %1 = rtg.sequence_closure @seq1(%0 : index) - %2 = rtg.randomize_sequence %1 - rtg.embed_sequence %2 + %1 = rtg.get_sequence @seq1 : !rtg.sequence + %2 = rtg.substitute_sequence %1(%0) : !rtg.sequence + %3 = rtg.randomize_sequence %2 + rtg.embed_sequence %3 } rtg.sequence @seq2(%arg0: index) { @@ -146,12 +148,14 @@ rtg.test @sameSequenceDifferentArgs : !rtg.dict<> { // CHECK: func.call @dummy2([[C1]]) %0 = index.constant 0 %1 = index.constant 1 - %2 = rtg.sequence_closure @seq2(%0 : index) - %3 = rtg.randomize_sequence %2 - %4 = rtg.sequence_closure @seq2(%1 : index) - %5 = rtg.randomize_sequence %4 - rtg.embed_sequence %3 - rtg.embed_sequence %5 + %2 = rtg.get_sequence @seq2 : !rtg.sequence + %3 = rtg.substitute_sequence %2(%0) : !rtg.sequence + %4 = rtg.randomize_sequence %3 + %5 = rtg.get_sequence @seq2 : !rtg.sequence + %6 = rtg.substitute_sequence %5(%1) : !rtg.sequence + %7 = rtg.randomize_sequence %6 + rtg.embed_sequence %4 + rtg.embed_sequence %7 } rtg.sequence @seq3(%arg0: !rtg.set) { @@ -169,13 +173,15 @@ rtg.test @sequenceClosureFixesRandomization : !rtg.dict<> { %0 = index.constant 0 %1 = index.constant 1 %2 = rtg.set_create %0, %1 : index - %3 = rtg.sequence_closure @seq3(%2 : !rtg.set) - %4 = rtg.randomize_sequence %3 - %5 = rtg.sequence_closure @seq3(%2 : !rtg.set) - %6 = rtg.randomize_sequence %5 - rtg.embed_sequence %4 - rtg.embed_sequence %6 - rtg.embed_sequence %4 + %3 = rtg.get_sequence @seq3 : !rtg.sequence> + %4 = rtg.substitute_sequence %3(%2) : !rtg.sequence> + %5 = rtg.randomize_sequence %4 + %6 = rtg.get_sequence @seq3 : !rtg.sequence> + %7 = rtg.substitute_sequence %6(%2) : !rtg.sequence> + %8 = rtg.randomize_sequence %7 + rtg.embed_sequence %5 + rtg.embed_sequence %8 + rtg.embed_sequence %5 } // CHECK-LABEL: @indexOps