Skip to content

Commit

Permalink
[RTG][Elaboration] Add support for context operations
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Jan 30, 2025
1 parent 0a2e772 commit 1cafc54
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 12 deletions.
135 changes: 123 additions & 12 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/RTG/IR/RTGAttributes.h"
#include "circt/Dialect/RTG/IR/RTGOps.h"
#include "circt/Dialect/RTG/IR/RTGVisitors.h"
#include "circt/Dialect/RTG/Transforms/RTGPasses.h"
Expand Down Expand Up @@ -309,12 +310,16 @@ struct SequenceStorage {

/// Storage object for an '!rtg.randomized_sequence'.
struct RandomizedSequenceStorage {
RandomizedSequenceStorage(StringRef name, SequenceStorage *sequence)
: hashcode(llvm::hash_combine(name, sequence)), name(name),
sequence(sequence) {}
RandomizedSequenceStorage(StringRef name,
ContextResourceAttrInterface context,
StringAttr test, SequenceStorage *sequence)
: hashcode(llvm::hash_combine(name, context, test, sequence)), name(name),
context(context), test(test), sequence(sequence) {}

bool isEqual(const RandomizedSequenceStorage *other) const {
return hashcode == other->hashcode && sequence == other->sequence;
return hashcode == other->hashcode && name == other->name &&
context == other->context && test == other->test &&
sequence == other->sequence;
}

// The cached hashcode to avoid repeated computations.
Expand All @@ -323,6 +328,12 @@ struct RandomizedSequenceStorage {
// The name of this fully substituted and elaborated sequence.
const StringRef name;

// The context under which this sequence is placed.
const ContextResourceAttrInterface context;

// The test in which this sequence is placed.
const StringAttr test;

const SequenceStorage *sequence;
};

Expand Down Expand Up @@ -420,7 +431,8 @@ static void print(SequenceStorage *val, llvm::raw_ostream &os) {

static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
os << "<randomized-sequence @" << val->name << " derived from @"
<< val->sequence->familyName.getValue() << "(";
<< val->sequence->familyName.getValue() << " under context "
<< val->context << " in test " << val->test << "(";
llvm::interleaveComma(val->sequence->args, os,
[&](const ElaboratorValue &val) { os << val; });
os << ") at " << val << ">";
Expand Down Expand Up @@ -546,6 +558,11 @@ class Materializer {
op->erase();
}

template <typename OpTy, typename... Args>
OpTy create(Location location, Args &&...args) {
return builder.create<OpTy>(location, std::forward<Args>(args)...);
}

private:
void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
auto ip = builder.getInsertionPoint();
Expand Down Expand Up @@ -720,14 +737,29 @@ struct ElaboratorSharedState {
std::queue<RandomizedSequenceStorage *> worklist;
};

/// A collection of state per RTG test.
struct TestState {
/// The name of the test.
StringAttr name;

/// The context switches registered for this test.
MapVector<
std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
SequenceStorage *>
contextSwitches;
};

/// Interprets the IR to perform and lower the represented randomizations.
class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
public:
using RTGBase = RTGOpVisitor<Elaborator, FailureOr<DeletionKind>>;
using RTGBase::visitOp;

Elaborator(ElaboratorSharedState &sharedState, Materializer &materializer)
: sharedState(sharedState), materializer(materializer) {}
Elaborator(ElaboratorSharedState &sharedState, TestState &testState,
Materializer &materializer,
ContextResourceAttrInterface currentContext = {})
: sharedState(sharedState), testState(testState),
materializer(materializer), currentContext(currentContext) {}

template <typename ValueTy>
inline ValueTy get(Value val) const {
Expand Down Expand Up @@ -804,12 +836,26 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {

auto name = sharedState.names.newName(seq->familyName.getValue());
state[op.getResult()] =
sharedState.internalizer.internalize<RandomizedSequenceStorage>(name,
seq);
sharedState.internalizer.internalize<RandomizedSequenceStorage>(
name, currentContext, testState.name, seq);
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
auto *seq = get<RandomizedSequenceStorage *>(op.getSequence());
if (seq->context != currentContext) {
auto err = op->emitError("attempting to place sequence ")
<< seq->name << " derived from "
<< seq->sequence->familyName.getValue() << " under context "
<< currentContext
<< ", but it was previously randomized for context ";
if (seq->context)
err << seq->context;
else
err << "'default'";
return err;
}

return DeletionKind::Keep;
}

Expand Down Expand Up @@ -1019,6 +1065,60 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(OnContextOp op) {
ContextResourceAttrInterface from = currentContext,
to = cast<ContextResourceAttrInterface>(
get<TypedAttr>(op.getContext()));
if (!currentContext)
from = DefaultContextAttr::get(op->getContext(), to.getType());

auto emitError = [&]() {
auto diag = op.emitError();
diag.attachNote(op.getLoc())
<< "while materializing value for context switching for " << op;
return diag;
};

if (from == to) {
Value seqVal = materializer.materialize(
get<SequenceStorage *>(op.getSequence()), op.getLoc(),
sharedState.worklist, emitError);
Value randSeqVal =
materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
return DeletionKind::Delete;
}

// Switch to the desired context.
auto *iter = testState.contextSwitches.find({from, to});
// NOTE: we could think about supporting context switching via intermediate
// context, i.e., treat it as a transitive relation.
if (iter == testState.contextSwitches.end())
return op->emitError("no context transition registered to switch from ")
<< from << " to " << to;

auto familyName = iter->second->familyName;
SmallVector<ElaboratorValue> args{from, to,
get<SequenceStorage *>(op.getSequence())};
auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
familyName, std::move(args));
auto *randSeq =
sharedState.internalizer.internalize<RandomizedSequenceStorage>(
sharedState.names.newName(familyName.getValue()), to,
testState.name, seq);
Value seqVal = materializer.materialize(randSeq, op.getLoc(),
sharedState.worklist, emitError);
materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);

return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
get<SequenceStorage *>(op.getSequence());
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(scf::IfOp op) {
bool cond = get<bool>(op.getCondition());
auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
Expand Down Expand Up @@ -1176,13 +1276,19 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
// State to be shared between all elaborator instances.
ElaboratorSharedState &sharedState;

// State to a specific RTG test and the sequences placed within it.
TestState &testState;

// Allows us to materialize ElaboratorValues to the IR operations necessary to
// obtain an SSA value representing that elaborated value.
Materializer &materializer;

// A map from SSA values to a pointer of an interned elaborator value.
DenseMap<Value, ElaboratorValue> state;

// The current context we are elaborating under.
ContextResourceAttrInterface currentContext;

uint64_t virtualRegisterID = 0;
uint64_t uniqueLabelID = 1;
};
Expand Down Expand Up @@ -1268,11 +1374,14 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,

// Initialize the worklist with the test ops since they cannot be placed by
// other ops.
DenseMap<StringAttr, TestState> testStates;
for (auto testOp : moduleOp.getOps<TestOp>()) {
LLVM_DEBUG(llvm::dbgs()
<< "\n=== Elaborating test @" << testOp.getSymName() << "\n\n");
Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
Elaborator elaborator(state, materializer);
testStates[testOp.getSymNameAttr()].name = testOp.getSymNameAttr();
Elaborator elaborator(state, testStates[testOp.getSymNameAttr()],
materializer);
if (failed(elaborator.elaborate(testOp.getBodyRegion())))
return failure();

Expand Down Expand Up @@ -1300,10 +1409,12 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,

LLVM_DEBUG(llvm::dbgs()
<< "\n=== Elaborating sequence family @" << familyOp.getSymName()
<< " into @" << seqOp.getSymName() << "\n\n");
<< " into @" << seqOp.getSymName() << " under context "
<< curr->context << "\n\n");

Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
Elaborator elaborator(state, materializer);
Elaborator elaborator(state, testStates[curr->test], materializer,
curr->context);
if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
curr->sequence->args)))
return failure();
Expand Down
116 changes: 116 additions & 0 deletions test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,78 @@ rtg.test @randomIntegers : !rtg.dict<> {
func.call @dummy2(%1) : (index) -> ()
}

// CHECK-LABEL: rtg.test @contexts_contextCpu
rtg.test @contexts : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
^bb0(%cpu0: !rtgtest.cpu, %cpu1: !rtgtest.cpu):
// CHECK-NEXT: rtg.label_decl "label0"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label5"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label2"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label7"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label4"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label8"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label3"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label6"
// CHECK-NEXT: rtg.label
// CHECK-NEXT: rtg.label_decl "label1"
// CHECK-NEXT: rtg.label
%0 = rtg.get_sequence @cpuSeq : !rtg.sequence<!rtgtest.cpu>
%1 = rtg.substitute_sequence %0(%cpu1) : !rtg.sequence<!rtgtest.cpu>
%l0 = rtg.label_decl "label0"
rtg.label local %l0
rtg.on_context %cpu0, %1 : !rtgtest.cpu
%l1 = rtg.label_decl "label1"
rtg.label local %l1
}

rtg.target @contextCpu : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
%cpu0 = rtgtest.cpu_decl <0>
%cpu1 = rtgtest.cpu_decl <1>
%0 = rtg.get_sequence @switchCpuSeq : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
%1 = rtg.get_sequence @switchNestedCpuSeq : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
rtg.context_switch #rtg.default : !rtgtest.cpu -> #rtgtest.cpu<0> : !rtgtest.cpu, %0 : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
rtg.context_switch #rtgtest.cpu<0> : !rtgtest.cpu -> #rtgtest.cpu<1> : !rtgtest.cpu, %1 : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
rtg.yield %cpu0, %cpu1 : !rtgtest.cpu, !rtgtest.cpu
}

rtg.sequence @cpuSeq(%cpu: !rtgtest.cpu) {
%l2 = rtg.label_decl "label2"
rtg.label local %l2
%0 = rtg.get_sequence @nestedCpuSeq : !rtg.sequence
rtg.on_context %cpu, %0 : !rtgtest.cpu
%l3 = rtg.label_decl "label3"
rtg.label local %l3
}

rtg.sequence @nestedCpuSeq() {
%l4 = rtg.label_decl "label4"
rtg.label local %l4
}

rtg.sequence @switchCpuSeq(%parent: !rtgtest.cpu, %child: !rtgtest.cpu, %seq: !rtg.sequence) {
%l5 = rtg.label_decl "label5"
rtg.label local %l5
%0 = rtg.randomize_sequence %seq
rtg.embed_sequence %0
%l6 = rtg.label_decl "label6"
rtg.label local %l6
}

rtg.sequence @switchNestedCpuSeq(%parent: !rtgtest.cpu, %child: !rtgtest.cpu, %seq: !rtg.sequence) {
%l7 = rtg.label_decl "label7"
rtg.label local %l7
%0 = rtg.randomize_sequence %seq
rtg.embed_sequence %0
%l8 = rtg.label_decl "label8"
rtg.label local %l8
}

// -----

rtg.test @nestedRegionsNotSupported : !rtg.dict<> {
Expand Down Expand Up @@ -424,3 +496,47 @@ rtg.test @randomIntegers : !rtg.dict<> {
%0 = rtg.random_number_in_range [%c5, %c5)
func.call @dummy2(%0) : (index) -> ()
}

// -----

rtg.sequence @seq0(%seq: !rtg.randomized_sequence) {
// expected-error @below {{attempting to place sequence seq1_0 derived from seq1 under context #rtgtest.cpu<0> : !rtgtest.cpu, but it was previously randomized for context 'default'}}
rtg.embed_sequence %seq
}
rtg.sequence @seq1() { }
rtg.sequence @seq(%arg0: !rtgtest.cpu, %arg1: !rtgtest.cpu, %seq: !rtg.sequence) {
%0 = rtg.randomize_sequence %seq
rtg.embed_sequence %0
}

rtg.target @invalidRandomizationTarget : !rtg.dict<cpu: !rtgtest.cpu> {
%0 = rtg.get_sequence @seq : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
rtg.context_switch #rtg.default : !rtgtest.cpu -> #rtgtest.cpu<0>, %0 : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
%1 = rtgtest.cpu_decl <0>
rtg.yield %1 : !rtgtest.cpu
}

rtg.test @invalidRandomization : !rtg.dict<cpu: !rtgtest.cpu> {
^bb0(%cpu: !rtgtest.cpu):
%0 = rtg.get_sequence @seq1 : !rtg.sequence
%1 = rtg.randomize_sequence %0
%2 = rtg.get_sequence @seq0 : !rtg.sequence<!rtg.randomized_sequence>
%3 = rtg.substitute_sequence %2(%1) : !rtg.sequence<!rtg.randomized_sequence>
rtg.on_context %cpu, %3 : !rtgtest.cpu
}

// -----

rtg.sequence @seq() {}

rtg.target @target : !rtg.dict<cpu: !rtgtest.cpu> {
%0 = rtgtest.cpu_decl <0>
rtg.yield %0 : !rtgtest.cpu
}

rtg.test @contextSwitchNotAvailable : !rtg.dict<cpu: !rtgtest.cpu> {
^bb0(%cpu: !rtgtest.cpu):
%0 = rtg.get_sequence @seq : !rtg.sequence
// expected-error @below {{no context transition registered to switch from #rtg.default : !rtgtest.cpu to #rtgtest.cpu<0> : !rtgtest.cpu}}
rtg.on_context %cpu, %0 : !rtgtest.cpu
}

0 comments on commit 1cafc54

Please sign in to comment.