Skip to content

Commit

Permalink
[RTG] Support partial sequence substitutions
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Feb 3, 2025
1 parent fafb31a commit 4849514
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 76 deletions.
52 changes: 34 additions & 18 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,22 @@ def SequenceOp : RTGOp<"sequence", [
let hasRegionVerifier = 1;
}

def SequenceClosureOp : RTGOp<"sequence_closure", [
def GetSequenceOp : RTGOp<"get_sequence", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "create a sequence closure with the provided arguments";
let summary = "create a sequence value";
let description = [{
This operation creates a closure object for the provided sequence and
arguments. This allows sequences to be passed around as an SSA value.
For example, it can be inserted into a set and selected at random which
is one of the main ways to do randomization. Not having to deal with
sequence arguments after randomly selecting a sequence simplifies the
problem of coming up with values to pass as arguments, but also provides a
way for the user to constrain the arguments at the location where they are
added to the set. In the future, it can also be possible to add sequence
handles directly to a set and randomly pick arguments at the invokation
site.
This operation creates a sequence value referring to the provided sequence
by symbol. It allows sequences to be passed around as an SSA value. For
example, it can be inserted into a set and selected at random which is one
of the main ways to do randomization.
}];

let arguments = (ins SymbolNameAttr:$sequence, Variadic<AnyType>:$args);
let results = (outs FullySubstitutedSequenceType:$ref);
let arguments = (ins SymbolNameAttr:$sequence);
let results = (outs SequenceType:$ref);

let assemblyFormat = [{
$sequence (`(` $args^ `:` qualified(type($args)) `)`)? attr-dict
}];
let assemblyFormat = "$sequence `:` qualified(type($ref)) attr-dict";
}

def RandomizeSequenceOp : RTGOp<"randomize_sequence", []> {
Expand All @@ -105,11 +97,35 @@ def RandomizeSequenceOp : RTGOp<"randomize_sequence", []> {
}];

let arguments = (ins FullySubstitutedSequenceType:$sequence);

let results = (outs RandomizedSequenceType:$randomizedSequence);

let assemblyFormat = "$sequence attr-dict";
}

def SubstituteSequenceOp : RTGOp<"substitute_sequence", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
let summary = "partially substitute arguments of a sequence family";
let description = [{
This operation substitutes the first N of the M >= N arguments of the given
sequence family, where N is the size of provided argument substitution list.
A new sequence (if N == M) or sequence family with M-N will be returned.

Not having to deal with sequence arguments after randomly selecting a
sequence simplifies the problem of coming up with values to pass as
arguments, but also provides a way for the user to constrain the arguments
at the location where they are added to a set or bag.
}];

let arguments = (ins SequenceType:$sequence, Variadic<AnyType>:$replacements);
let results = (outs SequenceType:$result);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

def EmbedSequenceOp : RTGOp<"embed_sequence", []> {
let summary = "embed a sequence of instructions into another sequence";
let description = [{
Expand Down Expand Up @@ -417,7 +433,7 @@ def TestOp : RTGOp<"test", [
`rtg.sequence`'s body with the exception of the block arguments.
The arguments must match the fields of the dict type in the target attribute
exactly. The test must not have any additional arguments and cannot be
referenced by an `rtg.sequence_closure` operation.
referenced by an `rtg.get_sequence` operation.
}];

let arguments = (ins SymbolNameAttr:$sym_name,
Expand Down
6 changes: 4 additions & 2 deletions include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class RTGOpVisitor {
// RTG tests
TestOp, TargetOp, YieldOp,
// Sequences
SequenceOp, SequenceClosureOp, RandomizeSequenceOp, EmbedSequenceOp,
SequenceOp, GetSequenceOp, SubstituteSequenceOp,
RandomizeSequenceOp, EmbedSequenceOp,
// Sets
SetCreateOp, SetSelectRandomOp, SetDifferenceOp, SetUnionOp,
SetSizeOp>([&](auto expr) -> ResultType {
Expand Down Expand Up @@ -87,7 +88,8 @@ class RTGOpVisitor {
}

HANDLE(SequenceOp, Unhandled);
HANDLE(SequenceClosureOp, Unhandled);
HANDLE(GetSequenceOp, Unhandled);
HANDLE(SubstituteSequenceOp, Unhandled);
HANDLE(RandomizeSequenceOp, Unhandled);
HANDLE(EmbedSequenceOp, Unhandled);
HANDLE(SetCreateOp, Unhandled);
Expand Down
9 changes: 5 additions & 4 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,18 @@
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('sequence_name',
TypeAttr.get(rtg.SequenceType.get([])))
seq = rtg.SequenceOp('sequence_name', TypeAttr.get(rtg.SequenceType.get()))
Block.create_at_start(seq.bodyRegion, [])

test = rtg.TestOp('test_name', TypeAttr.get(rtg.DictType.get()))
block = Block.create_at_start(test.bodyRegion, [])
with InsertionPoint(block):
seq_closure = rtg.SequenceClosureOp('sequence_name', [])
seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name')
rtg.RandomizeSequenceOp(seq_get)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK-NEXT: rtg.sequence_closure
# CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name
# CHECK-NEXT: rtg.randomize_sequence [[SEQ]]
# CHECK-NEXT: }
print(m)

Expand Down
93 changes: 87 additions & 6 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,26 +90,107 @@ void SequenceOp::print(OpAsmPrinter &p) {
}

//===----------------------------------------------------------------------===//
// SequenceClosureOp
// GetSequenceOp
//===----------------------------------------------------------------------===//

LogicalResult
SequenceClosureOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
SequenceOp seq =
symbolTable.lookupNearestSymbolFrom<SequenceOp>(*this, getSequenceAttr());
if (!seq)
return emitOpError()
<< "'" << getSequence()
<< "' does not reference a valid 'rtg.sequence' operation";

if (TypeRange(seq.getSequenceType().getElementTypes()) !=
getArgs().getTypes())
return emitOpError("referenced 'rtg.sequence' op's argument types must "
"match 'args' types");
if (seq.getSequenceType() != getType())
return emitOpError("referenced 'rtg.sequence' op's type does not match");

return success();
}

//===----------------------------------------------------------------------===//
// SubstituteSequenceOp
//===----------------------------------------------------------------------===//

LogicalResult SubstituteSequenceOp::verify() {
if (getReplacements().empty())
return emitOpError("must at least have one replacement value");

if (getReplacements().size() >
getSequence().getType().getElementTypes().size())
return emitOpError(
"must not have more replacement values than sequence arguments");

if (getReplacements().getTypes() !=
getSequence().getType().getElementTypes().take_front(
getReplacements().size()))
return emitOpError("replacement types must match the same number of "
"sequence argument types from the front");

return success();
}

LogicalResult SubstituteSequenceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
ArrayRef<Type> argTypes =
cast<SequenceType>(operands[0].getType()).getElementTypes();
auto seqType =
SequenceType::get(context, argTypes.drop_front(operands.size() - 1));
inferredReturnTypes.push_back(seqType);
return success();
}

ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
OpAsmParser::UnresolvedOperand sequenceRawOperand;
SmallVector<OpAsmParser::UnresolvedOperand, 4> replacementsOperands;
Type sequenceRawType;

if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen())
return failure();

auto replacementsOperandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() ||
parser.parseColon() || parser.parseType(sequenceRawType) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();

if (!isa<SequenceType>(sequenceRawType))
return parser.emitError(parser.getNameLoc())
<< "'sequence' must be handle to a sequence or sequence family, but "
"got "
<< sequenceRawType;

if (parser.resolveOperand(sequenceRawOperand, sequenceRawType,
result.operands))
return failure();

if (parser.resolveOperands(replacementsOperands,
cast<SequenceType>(sequenceRawType)
.getElementTypes()
.take_front(replacementsOperands.size()),
replacementsOperandsLoc, result.operands))
return failure();

SmallVector<Type> inferredReturnTypes;
if (failed(inferReturnTypes(
parser.getContext(), result.location, result.operands,
result.attributes.getDictionary(parser.getContext()),
result.getRawProperties(), result.regions, inferredReturnTypes)))
return failure();

result.addTypes(inferredReturnTypes);
return success();
}

void SubstituteSequenceOp::print(OpAsmPrinter &p) {
p << ' ' << getSequence() << "(" << getReplacements()
<< ") : " << getSequence().getType();
p.printOptionalAttrDict((*this)->getAttrs(), {});
}

//===----------------------------------------------------------------------===//
// SetCreateOp
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 18 additions & 8 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,8 @@ class Materializer {
std::queue<RandomizedSequenceStorage *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
elabRequests.push(val);
Value seq = builder.create<SequenceClosureOp>(
loc, SequenceType::get(builder.getContext(), {}), val->name,
ValueRange{});
Value seq = builder.create<GetSequenceOp>(
loc, SequenceType::get(builder.getContext(), {}), val->name);
return builder.create<RandomizeSequenceOp>(loc, seq);
}

Expand Down Expand Up @@ -795,14 +794,25 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return visitUnhandledOp(op);
}

FailureOr<DeletionKind> visitOp(SequenceClosureOp op) {
FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
SmallVector<ElaboratorValue> replacements;
for (auto replacement : op.getArgs())
state[op.getResult()] =
sharedState.internalizer.internalize<SequenceStorage>(
op.getSequenceAttr(), std::move(replacements));
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
auto *seq = get<SequenceStorage *>(op.getSequence());

SmallVector<ElaboratorValue> replacements(seq->args);
for (auto replacement : op.getReplacements())
replacements.push_back(state.at(replacement));

state[op.getResult()] =
sharedState.internalizer.internalize<SequenceStorage>(
op.getSequenceAttr(), std::move(replacements));
seq->familyName, std::move(replacements));

return DeletionKind::Delete;
}

Expand Down Expand Up @@ -1334,10 +1344,10 @@ LogicalResult ElaborationPass::inlineSequences(TestOp testOp,
if (!randSeqOp)
return embedOp->emitError("sequence operand not directly defined by "
"'rtg.randomize_sequence' op");
auto getSeqOp = randSeqOp.getSequence().getDefiningOp<SequenceClosureOp>();
auto getSeqOp = randSeqOp.getSequence().getDefiningOp<GetSequenceOp>();
if (!getSeqOp)
return randSeqOp->emitError(
"sequence operand not directly defined by 'rtg.sequence_closure' op");
"sequence operand not directly defined by 'rtg.get_sequence' op");

auto seqOp = table.lookup<SequenceOp>(getSeqOp.getSequenceAttr());

Expand Down
11 changes: 6 additions & 5 deletions test/CAPI/rtg-pipelines.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ int main(int argc, char **argv) {
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__rtg__(), ctx);

MlirModule moduleOp = mlirModuleCreateParse(
ctx, mlirStringRefCreateFromCString("rtg.sequence @seq() {\n"
"}\n"
"rtg.test @test : !rtg.dict<> {\n"
" %0 = rtg.sequence_closure @seq\n"
"}\n"));
ctx, mlirStringRefCreateFromCString(
"rtg.sequence @seq() {\n"
"}\n"
"rtg.test @test : !rtg.dict<> {\n"
" %0 = rtg.get_sequence @seq : !rtg.sequence\n"
"}\n"));
if (mlirModuleIsNull(moduleOp)) {
printf("ERROR: Could not parse.\n");
mlirContextDestroy(ctx);
Expand Down
22 changes: 12 additions & 10 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,22 @@ rtg.sequence @seq1(%arg0: i32, %arg1: !rtg.sequence) { }

// CHECK-LABEL: rtg.sequence @seqRandomizationAndEmbedding
rtg.sequence @seqRandomizationAndEmbedding() {
// CHECK: [[V0:%.+]] = rtg.sequence_closure @seq0
// CHECK: [[V0:%.+]] = rtg.get_sequence @seq0
// CHECK: [[C0:%.+]] = arith.constant 0 : i32
// CHECK: [[V1:%.+]] = rtg.sequence_closure @seq1([[C0]], [[V0]] : i32, !rtg.sequence)
// CHECK: [[V2:%.+]] = rtg.randomize_sequence [[V0]]
// CHECK: [[V3:%.+]] = rtg.randomize_sequence [[V1]]
// CHECK: rtg.embed_sequence [[V2]]
// CHECK: [[V1:%.+]] = rtg.get_sequence @seq1
// CHECK: [[V2:%.+]] = rtg.substitute_sequence [[V1]]([[C0]], [[V0]]) : !rtg.sequence<i32, !rtg.sequence>
// CHECK: [[V3:%.+]] = rtg.randomize_sequence [[V0]]
// CHECK: [[V4:%.+]] = rtg.randomize_sequence [[V2]]
// CHECK: rtg.embed_sequence [[V3]]
%0 = rtg.sequence_closure @seq0
// CHECK: rtg.embed_sequence [[V4]]
%0 = rtg.get_sequence @seq0 : !rtg.sequence
%c0_i32 = arith.constant 0 : i32
%1 = rtg.sequence_closure @seq1(%c0_i32, %0 : i32, !rtg.sequence)
%2 = rtg.randomize_sequence %0
%3 = rtg.randomize_sequence %1
rtg.embed_sequence %2
%1 = rtg.get_sequence @seq1 : !rtg.sequence<i32, !rtg.sequence>
%2 = rtg.substitute_sequence %1(%c0_i32, %0) : !rtg.sequence<i32, !rtg.sequence>
%3 = rtg.randomize_sequence %0
%4 = rtg.randomize_sequence %2
rtg.embed_sequence %3
rtg.embed_sequence %4
}

// CHECK-LABEL: @sets
Expand Down
Loading

0 comments on commit 4849514

Please sign in to comment.