Skip to content

Commit

Permalink
[core] Simplify autogenerated code. (#2083)
Browse files Browse the repository at this point in the history
Use the cc.offsetof operation to reduce the size of the code.
  • Loading branch information
schweitzpgi authored Aug 14, 2024
1 parent d117ad7 commit a817c65
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 37 deletions.
34 changes: 31 additions & 3 deletions lib/Optimizer/CodeGen/CCToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,33 @@ class SizeOfOpPattern : public ConvertOpToLLVMPattern<cudaq::cc::SizeOfOp> {
}
};

class OffsetOfOpPattern : public ConvertOpToLLVMPattern<cudaq::cc::OffsetOfOp> {
public:
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Lower `cc.offsetof` by forming a GEP off of a null pointer and casting
  /// the resulting address to an integer. This relies on the classic
  /// offsetof-via-GEP trick, which LLVM is planning to drop support for at
  /// some point. See: https://github.com/llvm/llvm-project/issues/71507
  LogicalResult
  matchAndRewrite(cudaq::cc::OffsetOfOp offsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = offsetOp.getLoc();
    auto structTy = offsetOp.getInputType();
    auto intResultTy = offsetOp.getType();

    // Collect the op's constant member indices as compute_ptr arguments.
    SmallVector<cudaq::cc::ComputePtrArg> indices;
    for (std::int32_t idx : offsetOp.getConstantIndices())
      indices.push_back(idx);

    // TODO: replace this with some target-specific memory layout computation
    // when we upgrade to a newer MLIR.
    auto ptrTy = cudaq::cc::PointerType::get(structTy);
    auto zeroConst = rewriter.create<arith::ConstantIntOp>(loc, 0, 64);
    auto nullPtr = rewriter.create<cudaq::cc::CastOp>(loc, ptrTy, zeroConst);
    Value offsetPtr =
        rewriter.create<cudaq::cc::ComputePtrOp>(loc, ptrTy, nullPtr, indices);
    rewriter.replaceOpWithNewOp<cudaq::cc::CastOp>(offsetOp, intResultTy,
                                                   offsetPtr);
    return success();
  }
};

class StdvecDataOpPattern
: public ConvertOpToLLVMPattern<cudaq::cc::StdvecDataOp> {
public:
Expand Down Expand Up @@ -647,7 +674,8 @@ void cudaq::opt::populateCCToLLVMPatterns(LLVMTypeConverter &typeConverter,
ComputePtrOpPattern, CreateStringLiteralOpPattern,
ExtractValueOpPattern, FuncToPtrOpPattern, GlobalOpPattern,
InsertValueOpPattern, InstantiateCallableOpPattern,
LoadOpPattern, PoisonOpPattern, SizeOfOpPattern,
StdvecDataOpPattern, StdvecInitOpPattern, StdvecSizeOpPattern,
StoreOpPattern, UndefOpPattern>(typeConverter);
LoadOpPattern, OffsetOfOpPattern, PoisonOpPattern,
SizeOfOpPattern, StdvecDataOpPattern, StdvecInitOpPattern,
StdvecSizeOpPattern, StoreOpPattern, UndefOpPattern>(
typeConverter);
}
22 changes: 6 additions & 16 deletions lib/Optimizer/Transforms/GenKernelExecution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,18 +287,13 @@ class GenerateKernelExecution

Value genComputeReturnOffset(Location loc, OpBuilder &builder,
FunctionType funcTy,
cudaq::cc::StructType msgStructTy,
Value nullSt) {
auto i64Ty = builder.getI64Type();
cudaq::cc::StructType msgStructTy) {
if (funcTy.getNumResults() == 0)
return builder.create<arith::ConstantIntOp>(loc, NoResultOffset, 64);
auto members = msgStructTy.getMembers();
std::int32_t numKernelArgs = funcTy.getNumInputs();
auto resTy = cudaq::cc::PointerType::get(members[numKernelArgs]);
auto gep = builder.create<cudaq::cc::ComputePtrOp>(
loc, resTy, nullSt,
SmallVector<cudaq::cc::ComputePtrArg>{numKernelArgs});
return builder.create<cudaq::cc::CastOp>(loc, i64Ty, gep);
auto i64Ty = builder.getI64Type();
return builder.create<cudaq::cc::OffsetOfOp>(
loc, i64Ty, msgStructTy, ArrayRef<std::int32_t>{numKernelArgs});
}

/// Create a function that determines the return value offset in the message
Expand All @@ -315,11 +310,8 @@ class GenerateKernelExecution
OpBuilder::InsertionGuard guard(builder);
auto *entry = returnOffsetFunc.addEntryBlock();
builder.setInsertionPointToStart(entry);
auto ptrTy = cudaq::cc::PointerType::get(msgStructTy);
auto zero = builder.create<arith::ConstantIntOp>(loc, 0, 64);
auto basePtr = builder.create<cudaq::cc::CastOp>(loc, ptrTy, zero);
auto result =
genComputeReturnOffset(loc, builder, devKernelTy, msgStructTy, basePtr);
genComputeReturnOffset(loc, builder, devKernelTy, msgStructTy);
builder.create<func::ReturnOp>(loc, result);
}

Expand Down Expand Up @@ -1272,7 +1264,6 @@ class GenerateKernelExecution

// Compute the struct size without the trailing bytes, structSize, and
// with the trailing bytes, extendedStructSize.
auto nullSt = builder.create<cudaq::cc::CastOp>(loc, structPtrTy, zero);
Value structSize =
builder.create<cudaq::cc::SizeOfOp>(loc, i64Ty, structTy);
extendedStructSize =
Expand Down Expand Up @@ -1332,8 +1323,7 @@ class GenerateKernelExecution
castLoadThunk =
builder.create<cudaq::cc::FuncToPtrOp>(loc, ptrI8Ty, loadThunk);
castTemp = builder.create<cudaq::cc::CastOp>(loc, ptrI8Ty, temp);
resultOffset =
genComputeReturnOffset(loc, builder, devFuncTy, structTy, nullSt);
resultOffset = genComputeReturnOffset(loc, builder, devFuncTy, structTy);
}

Value vecArgPtrs;
Expand Down
16 changes: 6 additions & 10 deletions test/Quake/kernel_exec-1.qke
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ module attributes {quake.mangled_name_map = {

// CHECK-LABEL: func.func @_ZN3ghzclEi(
// CHECK-SAME: %[[VAL_0:.*]]: !cc.ptr<i8>, %[[VAL_1:.*]]: i32) -> f64 {
// CHECK: %[[VAL_2:.*]] = cc.undef !cc.struct<{i32, f64}>
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[VAL_2:.*]] = cc.undef !cc.struct<{i32, f64}>
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_4:.*]] = cc.insert_value %[[VAL_1]], %[[VAL_2]][0] : (!cc.struct<{i32, f64}>, i32) -> !cc.struct<{i32, f64}>
// CHECK: %[[VAL_5:.*]] = cc.cast %[[VAL_3]] : (i64) -> !cc.ptr<!cc.struct<{i32, f64}>>
// CHECK: %[[VAL_7:.*]] = cc.sizeof !cc.struct<{i32, f64}> : i64
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : i64
// CHECK: %[[VAL_9:.*]] = cc.alloca i8[%[[VAL_8]] : i64]
Expand All @@ -100,8 +99,7 @@ module attributes {quake.mangled_name_map = {
// CHECK: %[[VAL_13:.*]] = constant @ghz.thunk : (!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>
// CHECK: %[[VAL_15:.*]] = cc.func_ptr %[[VAL_13]] : ((!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>) -> !cc.ptr<i8>
// CHECK: %[[VAL_16:.*]] = cc.cast %[[VAL_11]] : (!cc.ptr<!cc.array<!cc.struct<{i32, f64}> x ?>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_17:.*]] = cc.compute_ptr %[[VAL_5]][1] : (!cc.ptr<!cc.struct<{i32, f64}>>) -> !cc.ptr<f64>
// CHECK: %[[VAL_18:.*]] = cc.cast %[[VAL_17]] : (!cc.ptr<f64>) -> i64
// CHECK: %[[VAL_18:.*]] = cc.offsetof !cc.struct<{i32, f64}> [1] : i64
// CHECK: %[[VAL_12:.*]] = llvm.mlir.addressof @ghz.kernelName : !llvm.ptr<array<4 x i8>>
// CHECK: %[[VAL_14:.*]] = cc.cast %[[VAL_12]] : (!llvm.ptr<array<4 x i8>>) -> !cc.ptr<i8>
// CHECK: call @altLaunchKernel(%[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[VAL_8]], %[[VAL_18]]) : (!cc.ptr<i8>, !cc.ptr<i8>, !cc.ptr<i8>, i64, i64) -> ()
Expand Down Expand Up @@ -132,8 +130,8 @@ module attributes {quake.mangled_name_map = {
// CHECK: }

// CHECK-LABEL: func.func @ghz.argsCreator(
// CHECK-SAME: %[[VAL_0:.*]]: !cc.ptr<!cc.ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !cc.ptr<!cc.ptr<i8>>) -> i64 {
// CHECK-SAME: %[[VAL_0:.*]]: !cc.ptr<!cc.ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !cc.ptr<!cc.ptr<i8>>) -> i64 {
// CHECK: %[[VAL_2:.*]] = cc.undef !cc.struct<{i32, f64}>
// CHECK: %[[VAL_14:.*]] = cc.cast %[[VAL_0]] : (!cc.ptr<!cc.ptr<i8>>) -> !cc.ptr<!cc.array<!cc.ptr<i8> x ?>>
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64
Expand Down Expand Up @@ -194,7 +192,6 @@ module attributes {quake.mangled_name_map = {
// HYBRID: %[[VAL_2:.*]] = cc.undef !cc.struct<{i32, f64}>
// HYBRID: %[[VAL_3:.*]] = arith.constant 0 : i64
// HYBRID: %[[VAL_4:.*]] = cc.insert_value %[[VAL_1]], %[[VAL_2]][0] : (!cc.struct<{i32, f64}>, i32) -> !cc.struct<{i32, f64}>
// HYBRID: %[[VAL_5:.*]] = cc.cast %[[VAL_3]] : (i64) -> !cc.ptr<!cc.struct<{i32, f64}>>
// HYBRID: %[[VAL_6:.*]] = cc.sizeof !cc.struct<{i32, f64}> : i64
// HYBRID: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_3]] : i64
// HYBRID: %[[VAL_8:.*]] = cc.alloca i8{{\[}}%[[VAL_7]] : i64]
Expand All @@ -204,8 +201,7 @@ module attributes {quake.mangled_name_map = {
// HYBRID: %[[VAL_11:.*]] = constant @ghz.thunk : (!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>
// HYBRID: %[[VAL_12:.*]] = cc.func_ptr %[[VAL_11]] : ((!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>) -> !cc.ptr<i8>
// HYBRID: %[[VAL_13:.*]] = cc.cast %[[VAL_10]] : (!cc.ptr<!cc.array<!cc.struct<{i32, f64}> x ?>>) -> !cc.ptr<i8>
// HYBRID: %[[VAL_14:.*]] = cc.compute_ptr %[[VAL_5]][1] : (!cc.ptr<!cc.struct<{i32, f64}>>) -> !cc.ptr<f64>
// HYBRID: %[[VAL_15:.*]] = cc.cast %[[VAL_14]] : (!cc.ptr<f64>) -> i64
// HYBRID: %[[VAL_15:.*]] = cc.offsetof !cc.struct<{i32, f64}> [1] : i64
// HYBRID: %[[VAL_16:.*]] = cc.alloca !cc.struct<{!cc.ptr<!cc.ptr<i8>>, !cc.ptr<!cc.ptr<i8>>, !cc.ptr<!cc.ptr<i8>>}>
// HYBRID: %[[VAL_17:.*]] = cc.alloca !cc.array<!cc.ptr<i8> x 1>
// HYBRID: %[[VAL_18:.*]] = cc.sizeof !cc.array<!cc.ptr<i8> x 1> : i64
Expand Down
10 changes: 2 additions & 8 deletions test/Quake/return_vector.qke
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,15 @@ func.func @test_0(%0: !cc.ptr<!cc.struct<{!cc.ptr<i32>, !cc.ptr<i32>, !cc.ptr<i3
// CHECK-SAME: %[[VAL_1:.*]]: !cc.ptr<i8>, %[[VAL_2:.*]]: i32) {
// CHECK: %[[VAL_3:.*]] = arith.constant 4 : i64
// CHECK: %[[VAL_4:.*]] = constant @test_0.thunk : (!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_6:.*]] = cc.undef !cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}>
// CHECK: %[[VAL_7:.*]] = cc.insert_value %[[VAL_2]], %[[VAL_6]][0] : (!cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}>, i32) -> !cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}>
// CHECK: %[[VAL_8:.*]] = cc.cast %[[VAL_5]] : (i64) -> !cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}>>
// CHECK: %[[VAL_9:.*]] = cc.sizeof !cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}> : i64
// CHECK: %[[VAL_10:.*]] = cc.alloca i8{{\[}}%[[VAL_9]] : i64]
// CHECK: %[[VAL_11:.*]] = cc.cast %[[VAL_10]] : (!cc.ptr<!cc.array<i8 x ?>>) -> !cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}>>
// CHECK: cc.store %[[VAL_7]], %[[VAL_11]] : !cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}>>
// CHECK: %[[VAL_14:.*]] = cc.func_ptr %[[VAL_4]] : ((!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>) -> !cc.ptr<i8>
// CHECK: %[[VAL_15:.*]] = cc.cast %[[VAL_10]] : (!cc.ptr<!cc.array<i8 x ?>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_16:.*]] = cc.compute_ptr %[[VAL_8]][1] : (!cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}>>) -> !cc.ptr<!cc.struct<{!cc.ptr<i32>, i64}>>
// CHECK: %[[VAL_17:.*]] = cc.cast %[[VAL_16]] : (!cc.ptr<!cc.struct<{!cc.ptr<i32>, i64}>>) -> i64
// CHECK: %[[VAL_17:.*]] = cc.offsetof !cc.struct<{i32, !cc.struct<{!cc.ptr<i32>, i64}>}> [1] : i64
// CHECK: %[[VAL_12:.*]] = llvm.mlir.addressof @test_0.kernelName : !llvm.ptr<array<7 x i8>>
// CHECK: %[[VAL_13:.*]] = cc.cast %[[VAL_12]] : (!llvm.ptr<array<7 x i8>>) -> !cc.ptr<i8>
// CHECK: call @altLaunchKernel(%[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_9]], %[[VAL_17]]) : (!cc.ptr<i8>, !cc.ptr<i8>, !cc.ptr<i8>, i64, i64) -> ()
Expand Down Expand Up @@ -113,18 +110,15 @@ func.func @test_1(%0: !cc.ptr<!cc.struct<{!cc.ptr<f64>, !cc.ptr<f64>, !cc.ptr<f6
// CHECK-SAME: %[[VAL_1:.*]]: !cc.ptr<i8>, %[[VAL_2:.*]]: i32) {
// CHECK: %[[VAL_3:.*]] = arith.constant 8 : i64
// CHECK: %[[VAL_4:.*]] = constant @test_1.thunk : (!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_6:.*]] = cc.undef !cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}>
// CHECK: %[[VAL_7:.*]] = cc.insert_value %[[VAL_2]], %[[VAL_6]][0] : (!cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}>, i32) -> !cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}>
// CHECK: %[[VAL_8:.*]] = cc.cast %[[VAL_5]] : (i64) -> !cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}>>
// CHECK: %[[VAL_9:.*]] = cc.sizeof !cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}> : i64
// CHECK: %[[VAL_10:.*]] = cc.alloca i8{{\[}}%[[VAL_9]] : i64]
// CHECK: %[[VAL_11:.*]] = cc.cast %[[VAL_10]] : (!cc.ptr<!cc.array<i8 x ?>>) -> !cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}>>
// CHECK: cc.store %[[VAL_7]], %[[VAL_11]] : !cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}>>
// CHECK: %[[VAL_14:.*]] = cc.func_ptr %[[VAL_4]] : ((!cc.ptr<i8>, i1) -> !cc.struct<{!cc.ptr<i8>, i64}>) -> !cc.ptr<i8>
// CHECK: %[[VAL_15:.*]] = cc.cast %[[VAL_10]] : (!cc.ptr<!cc.array<i8 x ?>>) -> !cc.ptr<i8>
// CHECK: %[[VAL_16:.*]] = cc.compute_ptr %[[VAL_8]][1] : (!cc.ptr<!cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}>>) -> !cc.ptr<!cc.struct<{!cc.ptr<f64>, i64}>>
// CHECK: %[[VAL_17:.*]] = cc.cast %[[VAL_16]] : (!cc.ptr<!cc.struct<{!cc.ptr<f64>, i64}>>) -> i64
// CHECK: %[[VAL_17:.*]] = cc.offsetof !cc.struct<{i32, !cc.struct<{!cc.ptr<f64>, i64}>}> [1] : i64
// CHECK: %[[VAL_12:.*]] = llvm.mlir.addressof @test_1.kernelName : !llvm.ptr<array<7 x i8>>
// CHECK: %[[VAL_13:.*]] = cc.cast %[[VAL_12]] : (!llvm.ptr<array<7 x i8>>) -> !cc.ptr<i8>
// CHECK: call @altLaunchKernel(%[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_9]], %[[VAL_17]]) : (!cc.ptr<i8>, !cc.ptr<i8>, !cc.ptr<i8>, i64, i64) -> ()
Expand Down

0 comments on commit a817c65

Please sign in to comment.