Skip to content

Commit

Permalink
[SYCL][Matrix] Extend W/A for more corner cases of AccessChain usage (#…
Browse files Browse the repository at this point in the history
…16370)

The new corner case is: AccessChain is used on arrays of Joint Matrices
Fix for CMPLRLLVM-64465
  • Loading branch information
YuriPlyakhin authored Dec 24, 2024
1 parent a813b55 commit aeddb8d
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 29 deletions.
132 changes: 113 additions & 19 deletions llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,74 @@ namespace {
// Itanium-mangled name prefix of __spirv_AccessChain: "_Z", the identifier
// length (19), then the identifier itself. Template/parameter encoding that
// follows in a full mangled name is deliberately not part of this constant.
// NOTE(review): presumably matched against callee names in
// transformAccessChain; the matching code is not visible in this hunk.
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
// Name of the target extension type that represents a cooperative matrix.
// TargetExtType::getName() results are compared against this string.
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";

// Peels every enclosing array layer off Ty and returns the first element type
// that is not itself an array. A non-array Ty is returned unchanged.
Type *getInnermostType(Type *Ty) {
  Type *Current = Ty;
  for (;;) {
    auto *AsArray = dyn_cast<ArrayType>(Current);
    if (!AsArray)
      return Current;
    // Descend one array level and look again.
    Current = AsArray->getElementType();
  }
}

// Rebuilds Ty with the same (possibly nested) array shape, but with the
// innermost non-array element type swapped for NewInnermostTy. When Ty is not
// an array at all, the result is simply NewInnermostTy.
Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) {
  auto *AsArray = dyn_cast<ArrayType>(Ty);
  if (!AsArray)
    return NewInnermostTy;

  // Rewrite the element type first, then wrap it in an array of the same
  // length as the original level.
  Type *RebuiltElemTy =
      replaceInnermostType(AsArray->getElementType(), NewInnermostTy);
  return ArrayType::get(RebuiltElemTy, AsArray->getNumElements());
}

// Walks up the definition chain of a pointer value and returns the first value
// it cannot look through. Modelled on stripPointerCastsAndOffsets from
// Value.cpp, but simplified and changed so that GEPs are looked through
// regardless of their indices (non-zero ones included). When StopOnGEP is set,
// the walk ends at the nearest GetElementPtrInst and returns that instruction.
// Non-pointer values are returned as is.
Value *stripPointerCastsAndOffsets(Value *V, bool StopOnGEP = false) {
  if (!V->getType()->isPointerTy())
    return V;

  // PHI nodes are never looked through, yet the value may belong to an
  // unreachable block that sits on a cycle, so track what was already seen.
  SmallPtrSet<Value *, 4> Seen;
  Seen.insert(V);

  for (;;) {
    Value *Next = nullptr;
    if (auto *GEP = dyn_cast<GEPOperator>(V)) {
      // Only a real GEP instruction can end the walk; a GEPOperator that is
      // not a GetElementPtrInst is looked through even with StopOnGEP set.
      if (StopOnGEP && isa<GetElementPtrInst>(GEP))
        return V;
      Next = GEP->getPointerOperand();
      assert(Next->getType()->isPointerTy() && "Unexpected operand type!");
    } else if (auto *BC = dyn_cast<BitCastOperator>(V)) {
      Next = BC->getOperand(0);
      // A bitcast from a non-pointer value is where the walk has to stop.
      if (!Next->getType()->isPointerTy())
        return V;
    } else if (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V)) {
      Next = ASC->getOperand(0);
      assert(Next->getType()->isPointerTy() && "Unexpected operand type!");
    } else {
      // A call whose callee returns one of its arguments is transparent:
      // keep stripping from that argument. Anything else ends the walk.
      auto *Call = dyn_cast<CallBase>(V);
      Next = Call ? Call->getReturnedArgOperand() : nullptr;
      if (!Next)
        return V;
    }

    V = Next;
    // Revisiting a value means we are going in circles; give up here.
    if (!Seen.insert(V).second)
      return V;
  }
}

// Returns the spirv.CooperativeMatrixKHR target extension type stored as the
// first element of WrapperMatrixTy, or nullptr when WrapperMatrixTy is null,
// has no elements, or its first element is not that target extension type.
TargetExtType *extractMatrixType(StructType *WrapperMatrixTy) {
  if (!WrapperMatrixTy)
    return nullptr;
  // Opaque structs and empty literal structs have no element 0; asking for it
  // would trip the "Element number out of range!" assertion (or read out of
  // bounds in builds without assertions), so reject them up front.
  if (WrapperMatrixTy->getNumElements() == 0)
    return nullptr;

  auto *MatrixTy = dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
  if (!MatrixTy)
    return nullptr;
  // Other target extension types are of no interest to the caller.
  if (MatrixTy->getName() != MATRIX_TYPE)
    return nullptr;
  return MatrixTy;
}

// This function finds all calls to __spirv_AccessChain function and transforms
// its users and operands to make LLVM IR more SPIR-V friendly.
bool transformAccessChain(Function *F) {
Expand Down Expand Up @@ -60,33 +128,59 @@ bool transformAccessChain(Function *F) {
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
// function call. It's necessary because otherwise OpAccessChain indices
// would be wrong.
Instruction *Ptr =
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
Instruction *Ptr = dyn_cast<Instruction>(
stripPointerCastsAndOffsets(CI->getArgOperand(0)));
if (!Ptr || !isa<AllocaInst>(Ptr))
continue;
StructType *WrapperMatrixTy =
dyn_cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType());
if (!WrapperMatrixTy)
continue;
TargetExtType *MatrixTy =
dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
if (!MatrixTy)

Type *AllocaTy = cast<AllocaInst>(Ptr)->getAllocatedType();
// It may happen that sycl::joint_matrix class object is wrapped into
// nested arrays. We need to find the innermost type to extract the matrix
// type from it.
if (StructType *WrapperMatrixTy =
dyn_cast<StructType>(getInnermostType(AllocaTy))) {
TargetExtType *MatrixTy = extractMatrixType(WrapperMatrixTy);
if (!MatrixTy)
continue;

AllocaInst *Alloca = nullptr;
{
IRBuilder Builder(CI);
IRBuilderBase::InsertPointGuard IG(Builder);
Builder.SetInsertPointPastAllocas(CI->getFunction());
Alloca = Builder.CreateAlloca(replaceInnermostType(AllocaTy, MatrixTy));
Alloca->takeName(Ptr);
}
Ptr->replaceAllUsesWith(Alloca);
Ptr->dropAllReferences();
Ptr->eraseFromParent();
ModuleChanged = true;
}

// In case spirv.CooperativeMatrixKHR is used in arrays, we also need to
// insert GEP to get pointer to target extension type and use it instead of
// pointer to sycl::joint_matrix class object when it is passed to
// __spirv_AccessChain
// First we check if the argument came from a GEP instruction
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(
stripPointerCastsAndOffsets(CI->getArgOperand(0), /*StopOnGEP=*/true));
if (!GEP)
continue;
StringRef Name = MatrixTy->getName();
if (Name != MATRIX_TYPE)

// Check if GEP return type is a pointer to sycl::joint_matrix class object
StructType *WrapperMatrixTy =
dyn_cast<StructType>(GEP->getResultElementType());
if (!extractMatrixType(WrapperMatrixTy))
continue;

AllocaInst *Alloca = nullptr;
// Insert GEP right before the __spirv_AccessChain call
{
IRBuilder Builder(CI);
IRBuilderBase::InsertPointGuard IG(Builder);
Builder.SetInsertPointPastAllocas(CI->getFunction());
Alloca = Builder.CreateAlloca(MatrixTy);
Value *NewGEP =
Builder.CreateInBoundsGEP(WrapperMatrixTy, CI->getArgOperand(0),
{Builder.getInt64(0), Builder.getInt32(0)});
CI->setArgOperand(0, NewGEP);
ModuleChanged = true;
}
Ptr->replaceAllUsesWith(Alloca);
Ptr->dropAllReferences();
Ptr->eraseFromParent();
ModuleChanged = true;
}
return ModuleChanged;
}
Expand Down
65 changes: 55 additions & 10 deletions llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,69 @@

; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s

; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)

; ModuleID = 'test.bc'
source_filename = "test.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
target triple = "spir64-unknown-unknown"

%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
%"struct.sycl::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
%"struct.sycl::_V1::long" = type { i64 }

define weak_odr dso_local spir_kernel void @test(i64 %ind) {
; CHECK-LABEL: define weak_odr dso_local spir_kernel void @test(
; CHECK-SAME: i64 [[IND:%.*]]) {

; non-matrix alloca not touched
; CHECK: [[NOT_MATR:%.*]] = alloca [2 x [4 x %"struct.sycl::_V1::long"]]
; both matrix-related allocas updated to use target extension types
; CHECK-NEXT: [[MATR:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK-NEXT: [[MATR_ARR:%.*]] = alloca [2 x [4 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]

; CHECK-NEXT: [[ASCAST:%.*]] = addrspacecast ptr [[MATR]] to ptr addrspace(4)
; no gep inserted, since not needed
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST]], i64 noundef 0)

; CHECK: [[GEP:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr [[MATR_ARR]], i64 0, i64 [[IND]], i64 [[IND]]
; CHECK-NEXT: [[ASCAST_1:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
; CHECK-NEXT: [[ASCAST_2:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
; gep is inserted for each of the accesschain calls to extract target extension type
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_1]], i64 0, i32 0
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP2]], i64 noundef 0)
; CHECK: [[TMP5:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_2]], i64 0, i32 0
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0)

; negative test - not touching non-matrix code
; CHECK: [[GEP_1:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr [[NOT_MATR]], i64 0, i64 [[IND]], i64 [[IND]]
; CHECK-NEXT: [[ASCAST_3:%.*]] = addrspacecast ptr [[GEP_1]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST_3]], i64 noundef 0)

define weak_odr dso_local spir_kernel void @test() {
entry:
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
%1 = addrspacecast ptr %0 to ptr addrspace(4)
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
; allocas
%matr = alloca %"struct.sycl::joint_matrix", align 8
%matr.arr = alloca [2 x [4 x %"struct.sycl::joint_matrix"]], align 8
%not.matr = alloca [2 x [4 x %"struct.sycl::_V1::long"]], align 8

; simple case
%ascast = addrspacecast ptr %matr to ptr addrspace(4)
%0 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast, i64 noundef 0)
%1 = load i8, ptr addrspace(4) %0

; gep with non-zero indices and multiple access chains per 1 alloca
%gep = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr %matr.arr, i64 0, i64 %ind, i64 %ind
%ascast.1 = addrspacecast ptr %gep to ptr addrspace(4)
%ascast.2 = addrspacecast ptr %gep to ptr addrspace(4)
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.1, i64 noundef 0)
%3 = load i8, ptr addrspace(4) %2
%4 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.2, i64 noundef 0)
%5 = load i8, ptr addrspace(4) %4

; negative test - not touching non-matrix code
%gep.1 = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr %not.matr, i64 0, i64 %ind, i64 %ind
%ascast.3 = addrspacecast ptr %gep.1 to ptr addrspace(4)
%6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.3, i64 noundef 0)
%7 = load i8, ptr addrspace(4) %6

ret void
}

declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef, i64 noundef)

0 comments on commit aeddb8d

Please sign in to comment.