Skip to content

Commit

Permalink
[RTGTest] Add the last remaining ops for RV32I (#8142)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Feb 3, 2025
1 parent deb88ba commit 5aa1cc3
Show file tree
Hide file tree
Showing 12 changed files with 390 additions and 15 deletions.
16 changes: 16 additions & 0 deletions include/circt-c/Dialect/RTGTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ MLIR_CAPI_EXPORTED MlirType rtgtestIntegerRegisterTypeGet(MlirContext ctxt);
// Immediates.
//===----------------------------------------------------------------------===//

/// If the type is an RTGTest Imm5Type.
MLIR_CAPI_EXPORTED bool rtgtestTypeIsAImm5(MlirType type);

/// Creates an RTGTest Imm5 type in the context.
MLIR_CAPI_EXPORTED MlirType rtgtestImm5TypeGet(MlirContext ctxt);

/// If the type is an RTGTest Imm12Type.
MLIR_CAPI_EXPORTED bool rtgtestTypeIsAImm12(MlirType type);

Expand Down Expand Up @@ -276,6 +282,16 @@ MLIR_CAPI_EXPORTED MlirAttribute rtgtestRegT6AttrGet(MlirContext ctxt);
// Immediates.
//===----------------------------------------------------------------------===//

/// If the attribute is an RTGTest Imm5Attr.
MLIR_CAPI_EXPORTED bool rtgtestAttrIsAImm5(MlirAttribute attr);

/// Creates an RTGTest Imm5 attribute in the context.
MLIR_CAPI_EXPORTED MlirAttribute rtgtestImm5AttrGet(MlirContext ctxt,
unsigned value);

/// Returns the value represented by the Imm5 attribute.
MLIR_CAPI_EXPORTED unsigned rtgtestImm5AttrGetValue(MlirAttribute attr);

/// If the attribute is an RTGTest Imm12Attr.
MLIR_CAPI_EXPORTED bool rtgtestAttrIsAImm12(MlirAttribute attr);

Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/RTGTest/IR/RTGTestAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ImmediateAttrBase<int width> : RTGTestAttrDef<"Imm" # width, [
let genVerifyDecl = 1;
}

def Imm5 : ImmediateAttrBase<5>;
def Imm12 : ImmediateAttrBase<12>;
def Imm13 : ImmediateAttrBase<13>;
def Imm21 : ImmediateAttrBase<21>;
Expand Down
204 changes: 201 additions & 3 deletions include/circt/Dialect/RTGTest/IR/RTGTestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def ImmediateOp : RTGTestOp<"immediate", [
]> {
let summary = "declare an immediate value";

let arguments = (ins AnyAttrOf<[Imm12, Imm13, Imm21, Imm32]>:$imm);
let arguments = (ins AnyAttrOf<[Imm5, Imm12, Imm13, Imm21, Imm32]>:$imm);
let results = (outs AnyType:$result);

let assemblyFormat = "$imm attr-dict";
Expand Down Expand Up @@ -98,7 +98,8 @@ class InstFormatIOpBase<string mnemonic, int opcode7, int funct3>
<< cast<rtg::RegisterAttrInterface>(adaptor.getRd())
.getRegisterAssembly()
<< ", "
<< cast<Imm12Attr>(adaptor.getImm()).getValue()
// The assembler only accepts signed values here.
<< cast<Imm12Attr>(adaptor.getImm()).getAPInt().getSExtValue()
<< "("
<< cast<rtg::RegisterAttrInterface>(adaptor.getRs())
.getRegisterAssembly()
Expand Down Expand Up @@ -182,6 +183,7 @@ class InstFormatBOpBase<string mnemonic, int opcode7, int funct3>
return;
}

// The assembler is fine with unsigned and signed values here.
os << cast<Imm13Attr>(adaptor.getImm()).getValue();
}
}];
Expand Down Expand Up @@ -266,7 +268,8 @@ class InstFormatSOpBase<string mnemonic, int opcode7, int funct3>
<< cast<rtg::RegisterAttrInterface>(adaptor.getRs1())
.getRegisterAssembly()
<< ", "
<< cast<Imm12Attr>(adaptor.getImm()).getValue()
// The assembler only accepts signed values here.
<< cast<Imm12Attr>(adaptor.getImm()).getAPInt().getSExtValue()
<< "("
<< cast<rtg::RegisterAttrInterface>(adaptor.getRs2())
.getRegisterAssembly()
Expand All @@ -275,8 +278,192 @@ class InstFormatSOpBase<string mnemonic, int opcode7, int funct3>
}];
}

class InstFormatUOpBase<string mnemonic, int opcode7>
: RTGTestOp<"rv32i." # mnemonic, [InstructionOpAdaptor]> {

let arguments = (ins IntegerRegisterType:$rd,
AnyTypeOf<[Imm32Type, LabelType]>:$imm);

let assemblyFormat = "$rd `,` $imm `:` type($imm) attr-dict";

let extraClassDefinition = [{
void $cppClass::printInstructionBinary(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
assert (isa<Imm32Attr>(adaptor.getImm()) &&
"binary of labels not supported");

auto rd = cast<rtg::RegisterAttrInterface>(adaptor.getRd());
auto imm = cast<Imm32Attr>(adaptor.getImm()).getAPInt();

auto binary = imm.extractBits(20, 12)
.concat(llvm::APInt(5, rd.getClassIndex()))
.concat(llvm::APInt(7, }] # opcode7 # [{));

SmallVector<char> str;
binary.toStringUnsigned(str, 16);
os << str;
}

void $cppClass::printInstructionAssembly(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
os << getOperationName().rsplit('.').second
<< " "
<< cast<rtg::RegisterAttrInterface>(adaptor.getRd())
.getRegisterAssembly()
<< ", ";

if (auto label = dyn_cast<StringAttr>(adaptor.getImm())) {
os << label.getValue();
return;
}

// The assembler wants an unsigned value here.
os << cast<Imm32Attr>(adaptor.getImm()).getValue();
}
}];
}

class InstFormatJOpBase<string mnemonic, int opcode7>
: RTGTestOp<"rv32i." # mnemonic, [InstructionOpAdaptor]> {

let arguments = (ins IntegerRegisterType:$rd,
AnyTypeOf<[Imm21Type, LabelType]>:$imm);

let assemblyFormat = "$rd `,` $imm `:` type($imm) attr-dict";

let extraClassDefinition = [{
void $cppClass::printInstructionBinary(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
assert (isa<Imm21Attr>(adaptor.getImm()) &&
"binary of labels not supported");

auto rd = cast<rtg::RegisterAttrInterface>(adaptor.getRd());
auto imm = cast<Imm21Attr>(adaptor.getImm()).getAPInt();

auto binary = imm.extractBits(1, 20)
.concat(imm.extractBits(10, 1))
.concat(imm.extractBits(1, 1))
.concat(imm.extractBits(8, 12))
.concat(llvm::APInt(5, rd.getClassIndex()))
.concat(llvm::APInt(7, }] # opcode7 # [{));

SmallVector<char> str;
binary.toStringUnsigned(str, 16);
os << str;
}

void $cppClass::printInstructionAssembly(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
os << getOperationName().rsplit('.').second
<< " "
<< cast<rtg::RegisterAttrInterface>(adaptor.getRd())
.getRegisterAssembly()
<< ", ";

if (auto label = dyn_cast<StringAttr>(adaptor.getImm())) {
os << label.getValue();
return;
}

// The assembler is fine with signed and unsigned values here.
os << cast<Imm21Attr>(adaptor.getImm()).getAPInt().getSExtValue();
}
}];
}

class InstFormatIAOpBase<string mnemonic, int opcode7, int funct3>
: RTGTestOp<"rv32i." # mnemonic, [InstructionOpAdaptor]> {

let arguments = (ins IntegerRegisterType:$rd,
IntegerRegisterType:$rs,
Imm12Type:$imm);

let assemblyFormat = "$rd `,` $rs `,` $imm attr-dict";

let extraClassDefinition = [{
void $cppClass::printInstructionBinary(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
auto rd = cast<rtg::RegisterAttrInterface>(adaptor.getRd());
auto rs = cast<rtg::RegisterAttrInterface>(adaptor.getRs());
auto imm = cast<Imm12Attr>(adaptor.getImm()).getAPInt();

auto binary = imm
.concat(llvm::APInt(5, rs.getClassIndex()))
.concat(llvm::APInt(3, }] # funct3 # [{))
.concat(llvm::APInt(5, rd.getClassIndex()))
.concat(llvm::APInt(7, }] # opcode7 # [{));

SmallVector<char> str;
binary.toStringUnsigned(str, 16);
os << str;
}

void $cppClass::printInstructionAssembly(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
os << getOperationName().rsplit('.').second
<< " "
<< cast<rtg::RegisterAttrInterface>(adaptor.getRd())
.getRegisterAssembly()
<< ", "
<< cast<rtg::RegisterAttrInterface>(adaptor.getRs())
.getRegisterAssembly()
<< ", "
// The assembler only accepts signed values here.
<< cast<Imm12Attr>(adaptor.getImm()).getAPInt().getSExtValue();
}
}];
}

class InstFormatShiftOpBase<string mnemonic, int opcode7,
int funct3, int funct7>
: RTGTestOp<"rv32i." # mnemonic, [InstructionOpAdaptor]> {

let arguments = (ins IntegerRegisterType:$rd,
IntegerRegisterType:$rs,
Imm5Type:$imm);

let assemblyFormat = "$rd `,` $rs `,` $imm attr-dict";

let extraClassDefinition = [{
void $cppClass::printInstructionBinary(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
auto rd = cast<rtg::RegisterAttrInterface>(adaptor.getRd());
auto rs = cast<rtg::RegisterAttrInterface>(adaptor.getRs());
auto imm = cast<Imm5Attr>(adaptor.getImm()).getAPInt();

auto binary = llvm::APInt(7, }] # funct7 # [{)
.concat(imm.extractBits(5, 0))
.concat(llvm::APInt(5, rs.getClassIndex()))
.concat(llvm::APInt(3, }] # funct3 # [{))
.concat(llvm::APInt(5, rd.getClassIndex()))
.concat(llvm::APInt(7, }] # opcode7 # [{));

SmallVector<char> str;
binary.toStringUnsigned(str, 16);
os << str;
}

void $cppClass::printInstructionAssembly(llvm::raw_ostream &os,
FoldAdaptor adaptor) {
os << getOperationName().rsplit('.').second
<< " "
<< cast<rtg::RegisterAttrInterface>(adaptor.getRd())
.getRegisterAssembly()
<< ", "
<< cast<rtg::RegisterAttrInterface>(adaptor.getRs())
.getRegisterAssembly()
<< ", "
// The assembler only accepts an unsigned value here.
<< cast<Imm5Attr>(adaptor.getImm()).getValue();
}
}];
}

//===- Instructions -------------------------------------------------------===//

def RV32I_LUI : InstFormatUOpBase<"lui", 0b0110111>;
def RV32I_AUIPC : InstFormatUOpBase<"auipc", 0b0010111>;
def RV32I_JAL : InstFormatJOpBase<"jal", 0b1101111>;
def RV32I_JALROp : InstFormatIOpBase<"jalr", 0b1100111, 0b000>;

def RV32I_BEQ : InstFormatBOpBase<"beq", 0b1100011, 0b000>;
Expand All @@ -296,6 +483,17 @@ def RV32I_SB : InstFormatSOpBase<"sb", 0b0100011, 0b000>;
def RV32I_SH : InstFormatSOpBase<"sh", 0b0100011, 0b001>;
def RV32I_SW : InstFormatSOpBase<"sw", 0b0100011, 0b010>;

def RV32I_ADDI : InstFormatIAOpBase<"addi", 0b0010011, 0b000>;
def RV32I_SLTI : InstFormatIAOpBase<"slti", 0b0010011, 0b010>;
def RV32I_SLTIU : InstFormatIAOpBase<"sltiu", 0b0010011, 0b011>;
def RV32I_XORI : InstFormatIAOpBase<"xori", 0b0010011, 0b100>;
def RV32I_ORI : InstFormatIAOpBase<"ori", 0b0010011, 0b110>;
def RV32I_ANDI : InstFormatIAOpBase<"andi", 0b0010011, 0b111>;

def RV32I_SLLI : InstFormatShiftOpBase<"slli", 0b0010011, 0b001, 0b0000000>;
def RV32I_SRLI : InstFormatShiftOpBase<"srli", 0b0010011, 0b101, 0b0000000>;
def RV32I_SRAI : InstFormatShiftOpBase<"srai", 0b0010011, 0b101, 0b0100000>;

def RV32I_ADD : InstFormatROpBase<"add", 0b110011, 0b000, 0b0000000>;
def RV32I_SUB : InstFormatROpBase<"sub", 0b110011, 0b000, 0b0100000>;
def RV32I_SLL : InstFormatROpBase<"sll", 0b110011, 0b001, 0b0000000>;
Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/RTGTest/IR/RTGTestTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ImmTypeBase<int width> : TypeDef<RTGTestDialect, "Imm" # width, []> {
let mnemonic = "imm" # width;
}

def Imm5Type : ImmTypeBase<5>;
def Imm12Type : ImmTypeBase<12>;
def Imm13Type : ImmTypeBase<13>;
def Imm21Type : ImmTypeBase<21>;
Expand Down
2 changes: 2 additions & 0 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
# CHECK: rtgtest.immediate #rtgtest.imm5<3> : !rtgtest.imm5
rtgtest.ImmediateOp(rtgtest.Imm5Attr.get(3))
# CHECK: rtgtest.immediate #rtgtest.imm12<3> : !rtgtest.imm12
rtgtest.ImmediateOp(rtgtest.Imm12Attr.get(3))
# CHECK: rtgtest.immediate #rtgtest.imm13<3> : !rtgtest.imm13
Expand Down
19 changes: 19 additions & 0 deletions lib/Bindings/Python/RTGTestModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ void circt::python::populateDialectRTGTestSubmodule(nb::module_ &m) {
},
nb::arg("self"), nb::arg("ctxt") = nullptr);

mlir_type_subclass(m, "Imm5Type", rtgtestTypeIsAImm5)
.def_classmethod(
"get",
[](nb::object cls, MlirContext ctxt) {
return cls(rtgtestImm5TypeGet(ctxt));
},
nb::arg("self"), nb::arg("ctxt") = nullptr);

mlir_type_subclass(m, "Imm12Type", rtgtestTypeIsAImm12)
.def_classmethod(
"get",
Expand Down Expand Up @@ -336,6 +344,17 @@ void circt::python::populateDialectRTGTestSubmodule(nb::module_ &m) {
},
nb::arg("self"), nb::arg("ctxt") = nullptr);

mlir_attribute_subclass(m, "Imm5Attr", rtgtestAttrIsAImm5)
.def_classmethod(
"get",
[](nb::object cls, unsigned value, MlirContext ctxt) {
return cls(rtgtestImm5AttrGet(ctxt, value));
},
nb::arg("self"), nb::arg("value"), nb::arg("ctxt") = nullptr)
.def_property_readonly("value", [](MlirAttribute self) {
return rtgtestImm5AttrGetValue(self);
});

mlir_attribute_subclass(m, "Imm12Attr", rtgtestAttrIsAImm12)
.def_classmethod(
"get",
Expand Down
18 changes: 18 additions & 0 deletions lib/CAPI/Dialect/RTGTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ MlirType rtgtestIntegerRegisterTypeGet(MlirContext ctxt) {
// Immediates.
//===----------------------------------------------------------------------===//

bool rtgtestTypeIsAImm5(MlirType type) { return isa<Imm5Type>(unwrap(type)); }

MlirType rtgtestImm5TypeGet(MlirContext ctxt) {
return wrap(Imm5Type::get(unwrap(ctxt)));
}

bool rtgtestTypeIsAImm12(MlirType type) { return isa<Imm12Type>(unwrap(type)); }

MlirType rtgtestImm12TypeGet(MlirContext ctxt) {
Expand Down Expand Up @@ -345,6 +351,18 @@ MlirAttribute rtgtestRegT6AttrGet(MlirContext ctxt) {
// Immediates.
//===----------------------------------------------------------------------===//

bool rtgtestAttrIsAImm5(MlirAttribute attr) {
return isa<Imm5Attr>(unwrap(attr));
}

MlirAttribute rtgtestImm5AttrGet(MlirContext ctxt, unsigned value) {
return wrap(Imm5Attr::get(unwrap(ctxt), value));
}

unsigned rtgtestImm5AttrGetValue(MlirAttribute attr) {
return cast<Imm5Attr>(unwrap(attr)).getValue();
}

bool rtgtestAttrIsAImm12(MlirAttribute attr) {
return isa<Imm12Attr>(unwrap(attr));
}
Expand Down
14 changes: 14 additions & 0 deletions test/CAPI/rtgtest.c
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ static void testRegisters(MlirContext ctx) {
}

static void testImmediates(MlirContext ctx) {
MlirType imm5Type = rtgtestImm5TypeGet(ctx);
// CHECK: is_imm5
fprintf(stderr, rtgtestTypeIsAImm5(imm5Type) ? "is_imm5\n" : "isnot_imm5\n");
// CHECK: !rtgtest.imm5
mlirTypeDump(imm5Type);

MlirType imm12Type = rtgtestImm12TypeGet(ctx);
// CHECK: is_imm12
fprintf(stderr,
Expand Down Expand Up @@ -299,6 +305,14 @@ static void testImmediates(MlirContext ctx) {
// CHECK: !rtgtest.imm32
mlirTypeDump(imm32Type);

MlirAttribute imm5Attr = rtgtestImm5AttrGet(ctx, 3);
// CHECK: is_imm5
fprintf(stderr, rtgtestAttrIsAImm5(imm5Attr) ? "is_imm5\n" : "isnot_imm5\n");
// CHECK: 3
fprintf(stderr, "%u\n", rtgtestImm5AttrGetValue(imm5Attr));
// CHECK: #rtgtest.imm5<3>
mlirAttributeDump(imm5Attr);

MlirAttribute imm12Attr = rtgtestImm12AttrGet(ctx, 3);
// CHECK: is_imm12
fprintf(stderr,
Expand Down
Loading

0 comments on commit 5aa1cc3

Please sign in to comment.