From ee90e9f6507378d350ad71a1ac4af99b64974ee3 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Thu, 30 Jan 2025 12:18:43 +0000 Subject: [PATCH] [RTG] Add operation to get a random number within a range --- include/circt/Dialect/RTG/IR/RTGOps.td | 18 +++++++++++++ include/circt/Dialect/RTG/IR/RTGVisitors.h | 3 +++ .../RTG/Transforms/ElaborationPass.cpp | 19 ++++++++++++++ test/Dialect/RTG/IR/basic.mlir | 6 +++++ test/Dialect/RTG/Transform/elaboration.mlir | 26 +++++++++++++++++++ 5 files changed, 72 insertions(+) diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index 0de59484fcd9..66968002f0c8 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -367,6 +367,24 @@ def BagUniqueSizeOp : RTGOp<"bag_unique_size", [Pure]> { }]; } +//===- Integer Operations -------------------------------------------------===// + +def RandomNumberInRangeOp : RTGOp<"random_number_in_range", []> { + let summary = "returns a number uniformly at random within the given range"; + let description = [{ + This operation computes a random number based on a uniform distribution + within the given range. The lower bound is inclusive while the upper bound + is exclusive. If the range is empty, compilation will fail. + This is (obviously) more performant than inserting all legal numbers into a + set and using 'set_select_random', but yields the same behavior. + }]; + + let arguments = (ins Index:$lowerBound, Index:$upperBound); + let results = (outs Index:$result); + + let assemblyFormat = "` ` `[` $lowerBound `,` $upperBound `)` attr-dict"; +} + //===- ISA Register Handling Operations -----------------------------------===// def FixedRegisterOp : RTGOp<"fixed_reg", [ diff --git a/include/circt/Dialect/RTG/IR/RTGVisitors.h b/include/circt/Dialect/RTG/IR/RTGVisitors.h index 13cb2b1a2fbe..f7ad6bfb29f5 100644 --- a/include/circt/Dialect/RTG/IR/RTGVisitors.h +++ b/include/circt/Dialect/RTG/IR/RTGVisitors.h @@ -41,6 +41,8 @@ class RTGOpVisitor { FixedRegisterOp, VirtualRegisterOp, // RTG tests TestOp, TargetOp, YieldOp, + // Integers + RandomNumberInRangeOp, // Sequences SequenceOp, GetSequenceOp, SubstituteSequenceOp, RandomizeSequenceOp, EmbedSequenceOp, @@ -92,6 +94,7 @@ class RTGOpVisitor { HANDLE(SubstituteSequenceOp, Unhandled); HANDLE(RandomizeSequenceOp, Unhandled); HANDLE(EmbedSequenceOp, Unhandled); + HANDLE(RandomNumberInRangeOp, Unhandled); HANDLE(SetCreateOp, Unhandled); HANDLE(SetSelectRandomOp, Unhandled); HANDLE(SetDifferenceOp, Unhandled); diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index be939b0d2c7f..ba79b7d4adb4 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -1017,6 +1017,25 @@ class Elaborator : public RTGOpVisitor> { FailureOr visitOp(LabelOp op) { return DeletionKind::Keep; } + FailureOr visitOp(RandomNumberInRangeOp op) { + size_t lower = get(op.getLowerBound()); + size_t upper = get(op.getUpperBound()) - 1; + if (lower > upper) + return op->emitError("cannot select a number from an empty range"); + + if (auto intAttr = + op->getAttrOfType("rtg.elaboration_custom_seed")) { + std::mt19937 customRng(intAttr.getInt()); + state[op.getResult()] = + size_t(getUniformlyInRange(customRng, lower, upper)); + } else { + state[op.getResult()] = + size_t(getUniformlyInRange(sharedState.rng, lower, upper)); + } + + return DeletionKind::Delete; + } + FailureOr visitOp(scf::IfOp op) { bool cond = get(op.getCondition()); auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion(); diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index cc4b4d18bd37..9ff816dca067 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -106,3 +106,9 @@ rtg.target @target : !rtg.dict { rtg.test @test : !rtg.dict { ^bb0(%arg0: i32, %arg1: i32): } + +// CHECK-LABEL: rtg.sequence @integerHandlingOps +rtg.sequence @integerHandlingOps(%arg0: index, %arg1: index) { + // CHECK: rtg.random_number_in_range [%arg0, %arg1) + rtg.random_number_in_range [%arg0, %arg1) +} diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index 26cddbdd92fe..427e2aa65158 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -374,6 +374,21 @@ rtg.test @labels : !rtg.dict<> { rtg.label local %l4 } +// CHECK-LABEL: rtg.test @randomIntegers +rtg.test @randomIntegers : !rtg.dict<> { + %lower = index.constant 5 + %upper = index.constant 10 + %0 = rtg.random_number_in_range [%lower, %upper) {rtg.elaboration_custom_seed=0} + // CHECK-NEXT: [[V0:%.+]] = index.constant 5 + // CHECK-NEXT: func.call @dummy2([[V0]]) + func.call @dummy2(%0) : (index) -> () + + %1 = rtg.random_number_in_range [%lower, %upper) {rtg.elaboration_custom_seed=3} + // CHECK-NEXT: [[V1:%.+]] = index.constant 8 + // CHECK-NEXT: func.call @dummy2([[V1]]) + func.call @dummy2(%1) : (index) -> () +} + // ----- rtg.test @nestedRegionsNotSupported : !rtg.dict<> { @@ -398,3 +413,14 @@ rtg.test @untypedAttributes : !rtg.dict<> { // expected-note @below {{while materializing value for operand#0}} func.call @dummy(%0) : (index) -> () } + +// ----- + +func.func @dummy2(%arg0: index) -> () {return} + +rtg.test @randomIntegers : !rtg.dict<> { + %c5 = index.constant 5 + // expected-error @below {{cannot select a number from an empty range}} + %0 = rtg.random_number_in_range [%c5, %c5) + func.call @dummy2(%0) : (index) -> () +}