Skip to content

Commit

Permalink
[AArch64] SME implementation for agnostic-ZA functions (#120150)
Browse files Browse the repository at this point in the history
This implements the lowering of calls from agnostic-ZA functions to
non-agnostic-ZA functions, using the ABI routines
`__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`.

This implements the proposal described in the following PRs:
* ARM-software/acle#336
* ARM-software/abi-aa#264
  • Loading branch information
sdesmalen-arm authored Dec 23, 2024
1 parent d8e7929 commit 2ce168b
Show file tree
Hide file tree
Showing 13 changed files with 486 additions and 41 deletions.
24 changes: 14 additions & 10 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2268,19 +2268,23 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
Attrs.hasFnAttr("aarch64_inout_za") +
Attrs.hasFnAttr("aarch64_out_za") +
Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
Attrs.hasFnAttr("aarch64_preserves_za") +
Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
"'aarch64_inout_za', 'aarch64_preserves_za' and "
"'aarch64_za_state_agnostic' are mutually exclusive",
V);

Check(
(Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
Attrs.hasFnAttr("aarch64_inout_zt0") +
Attrs.hasFnAttr("aarch64_out_zt0") +
Attrs.hasFnAttr("aarch64_preserves_zt0")) <= 1,
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive",
V);
Check((Attrs.hasFnAttr("aarch64_new_zt0") +
Attrs.hasFnAttr("aarch64_in_zt0") +
Attrs.hasFnAttr("aarch64_inout_zt0") +
Attrs.hasFnAttr("aarch64_out_zt0") +
Attrs.hasFnAttr("aarch64_preserves_zt0") +
Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0', 'aarch64_preserves_zt0' and "
"'aarch64_za_state_agnostic' are mutually exclusive",
V);

if (Attrs.hasFnAttr(Attribute::JumpTable)) {
const GlobalValue *GV = cast<GlobalValue>(V);
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64FastISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5197,7 +5197,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
SMEAttrs CallerAttrs(*FuncInfo.Fn);
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
CallerAttrs.hasStreamingCompatibleInterface())
CallerAttrs.hasStreamingCompatibleInterface() ||
CallerAttrs.hasAgnosticZAInterface())
return nullptr;
return new AArch64FastISel(FuncInfo, LibInfo);
}
134 changes: 132 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2643,6 +2643,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
break;
MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
MAKE_CASE(AArch64ISD::VG_SAVE)
MAKE_CASE(AArch64ISD::VG_RESTORE)
Expand Down Expand Up @@ -3230,6 +3232,64 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
return BB;
}

// Custom-inserter expansion of the AllocateSMESaveBuffer pseudo.
//
// Operand 0 of MI is the register that receives the buffer address and
// operand 1 is the register holding the size to allocate. When the function
// has recorded that the SME save buffer is used, the size is subtracted from
// SP and the new SP value is copied into the destination. Otherwise no stack
// is allocated and the destination is simply marked IMPLICIT_DEF.
//
// The pseudo is unlinked from BB before returning; BB itself is returned
// unchanged because no new blocks are created.
//
// TODO: Find a way to merge this with EmitAllocateZABuffer.
MachineBasicBlock *
AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
                                                 MachineBasicBlock *BB) const {
  MachineFunction *MF = BB->getParent();
  MachineFrameInfo &MFI = MF->getFrameInfo();
  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
  // LowerFormalArguments only creates this pseudo for non-Windows targets, so
  // reaching this point on Windows indicates a lowering bug.
  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
         "SME save buffer pseudo is not yet supported on Windows");

  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
  if (FuncInfo->isSMESaveBufferUsed()) {
    // Allocate a buffer object of the size given by MI.getOperand(1).
    auto Size = MI.getOperand(1).getReg();
    auto Dest = MI.getOperand(0).getReg();
    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), AArch64::SP)
        .addReg(AArch64::SP)
        .addReg(Size)
        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), Dest)
        .addReg(AArch64::SP);

    // We have just allocated a variable sized object, tell this to PEI.
    MFI.CreateVariableSizedObject(Align(16), nullptr);
  } else {
    // The buffer is never referenced, so avoid touching SP and just give the
    // destination register a definition.
    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
            MI.getOperand(0).getReg());
  }

  BB->remove_instr(&MI);
  return BB;
}

// Custom-inserter expansion of the GetSMESaveSize pseudo.
//
// If the function has recorded that the SME save buffer is used, a call to
// __arm_sme_state_size() is emitted and its result (X0) is copied into the
// pseudo's destination register. Otherwise the destination is set from XZR,
// i.e. a size of zero, and no call is emitted.
MachineBasicBlock *
AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
                                          MachineBasicBlock *BB) const {
  MachineFunction *MF = BB->getParent();
  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
  const bool NeedsBuffer =
      MF->getInfo<AArch64FunctionInfo>()->isSMESaveBufferUsed();

  if (NeedsBuffer) {
    // BL __arm_sme_state_size, defining X0 and clobbering only what the
    // support-routine calling convention does not preserve.
    const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
        .addExternalSymbol("__arm_sme_state_size")
        .addReg(AArch64::X0, RegState::ImplicitDefine)
        .addRegMask(TRI->getCallPreservedMask(
            *MF, CallingConv::
                     AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
  }

  // The size comes from the call result when the buffer is needed, and from
  // the zero register otherwise.
  BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
          MI.getOperand(0).getReg())
      .addReg(NeedsBuffer ? AArch64::X0 : AArch64::XZR);

  BB->remove_instr(&MI);
  return BB;
}

MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {

Expand Down Expand Up @@ -3264,6 +3324,10 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
return EmitInitTPIDR2Object(MI, BB);
case AArch64::AllocateZABuffer:
return EmitAllocateZABuffer(MI, BB);
case AArch64::AllocateSMESaveBuffer:
return EmitAllocateSMESaveBuffer(MI, BB);
case AArch64::GetSMESaveSize:
return EmitGetSMESaveSize(MI, BB);
case AArch64::F128CSEL:
return EmitF128CSEL(MI, BB);
case TargetOpcode::STATEPOINT:
Expand Down Expand Up @@ -7663,6 +7727,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
case CallingConv::AArch64_VectorCall:
case CallingConv::AArch64_SVE_VectorCall:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
return CC_AArch64_AAPCS;
case CallingConv::ARM64EC_Thunk_X64:
Expand Down Expand Up @@ -8122,6 +8187,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
} else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
DAG.getVTList(MVT::i64, MVT::Other), Chain);
Chain = BufferSize.getValue(1);

SDValue Buffer;
if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
Buffer =
DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
} else {
// Allocate space dynamically.
Buffer = DAG.getNode(
ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
{Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
}

// Copy the value to a virtual register, and save that in FuncInfo.
Register BufferPtr =
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
FuncInfo->setSMESaveBufferAddr(BufferPtr);
Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}

if (CallConv == CallingConv::PreserveNone) {
Expand Down Expand Up @@ -8410,6 +8500,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
CallerAttrs.hasStreamingBody())
return false;

Expand Down Expand Up @@ -8734,6 +8825,33 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}

// Lowers a call to one of the SME state support routines: __arm_sme_save when
// IsSave is true, __arm_sme_restore otherwise. The routine is passed a single
// pointer argument, the save buffer address recorded in Info, and is called
// with the AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1 convention.
// Returns the second element of the LowerCallTo result, which callers use as
// the new chain.
static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
                                       SelectionDAG &DAG,
                                       AArch64FunctionInfo *Info, SDLoc DL,
                                       SDValue Chain, bool IsSave) {
  // Record on the current function that the save buffer is needed, so the
  // GetSMESaveSize/AllocateSMESaveBuffer pseudos actually materialise it.
  // NOTE(review): Info is presumably this same object; confirm with callers.
  DAG.getMachineFunction()
      .getInfo<AArch64FunctionInfo>()
      ->setSMESaveBufferUsed();

  auto &Ctx = *DAG.getContext();

  // The only argument is the buffer pointer, read back from the virtual
  // register stored in Info.
  TargetLowering::ArgListEntry BufferArg;
  BufferArg.Ty = PointerType::getUnqual(Ctx);
  BufferArg.Node =
      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
  TargetLowering::ArgListTy CallArgs;
  CallArgs.push_back(BufferArg);

  const char *RoutineName = IsSave ? "__arm_sme_save" : "__arm_sme_restore";
  SDValue Routine = DAG.getExternalSymbol(
      RoutineName, TLI.getPointerTy(DAG.getDataLayout()));

  TargetLowering::CallLoweringInfo CLI(DAG);
  CLI.setDebugLoc(DL);
  CLI.setChain(Chain);
  CLI.setLibCallee(
      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1,
      Type::getVoidTy(Ctx), Routine, std::move(CallArgs));
  return TLI.LowerCallTo(CLI).second;
}

static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
const SMEAttrs &CalleeAttrs) {
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
Expand Down Expand Up @@ -8894,6 +9012,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
};

bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
bool RequiresSaveAllZA =
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
MachinePointerInfo MPI =
Expand All @@ -8920,6 +9040,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
&MF.getFunction());
return DescribeCallsite(R) << " sets up a lazy save for ZA";
});
} else if (RequiresSaveAllZA) {
assert(!CalleeAttrs.hasSharedZAInterface() &&
"Cannot share state that may not exist");
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
/*IsSave=*/true);
}

SDValue PStateSM;
Expand Down Expand Up @@ -9467,9 +9592,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64));
TPIDR2.Uses++;
} else if (RequiresSaveAllZA) {
Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
/*IsSave=*/false);
}

if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
RequiresSaveAllZA) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
Expand Down Expand Up @@ -28084,7 +28213,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
auto CalleeAttrs = SMEAttrs(*Base);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs))
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
return true;
}
return false;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,10 @@ enum NodeType : unsigned {
ALLOCATE_ZA_BUFFER,
INIT_TPIDR2OBJ,

// Needed for __arm_agnostic("sme_za_state")
GET_SME_SAVE_SIZE,
ALLOC_SME_SAVE_BUFFER,

// Asserts that a function argument (i32) is zero-extended to i8 by
// the caller
ASSERT_ZEXT_BOOL,
Expand Down Expand Up @@ -667,6 +671,10 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI,
MachineBasicBlock *BB) const;

MachineBasicBlock *
EmitInstrWithCustomInserter(MachineInstr &MI,
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;

// Holds a pointer to a buffer that is large enough to represent
// all SME ZA state and any additional state required by the
// __arm_sme_save/restore support routines.
Register SMESaveBufferAddr = MCRegister::NoRegister;

// true if SMESaveBufferAddr is used.
bool SMESaveBufferUsed = false;

// Has the PNReg used to build PTRUE instruction.
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;
Expand All @@ -252,6 +260,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
return PredicateRegForFillSpill;
}

Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };

unsigned isSMESaveBufferUsed() const { return SMESaveBufferUsed; };
void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };

Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
}

// Nodes to allocate a save buffer for SME.
//
// AArch64SMESaveSize is a chained node with no value operands and a single
// integer result: the size of the SME save buffer.
def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
[SDTCisInt<0>]>, [SDNPHasChain]>;
// GetSMESaveSize is expanded by a custom inserter. X0 is listed as a def
// because the expansion may emit a call to __arm_sme_state_size and copy its
// result out of X0.
let usesCustomInserter = 1, Defs = [X0] in {
def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
}
// Select the node to the pseudo; the pseudo itself carries no pattern.
def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;

// AArch64AllocateSMESaveBuffer is a chained node taking one integer operand
// (the buffer size) and producing one integer result (the buffer address).
def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
[SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
// AllocateSMESaveBuffer is expanded by a custom inserter. SP is listed as a
// def because the expansion may subtract $size from SP to carve out the
// buffer.
let usesCustomInserter = 1, Defs = [SP] in {
def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
}
// Select the node to the pseudo; the pseudo itself carries no pattern.
def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
(AllocateSMESaveBuffer $size)>;

//===----------------------------------------------------------------------===//
// Instruction naming conventions.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,

if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs)) {
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
if (hasPossibleIncompatibleOps(Callee))
return false;
}
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
isPreservesZT0())) &&
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");

assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) &&
"Function cannot have a shared-ZA interface and an agnostic-ZA "
"interface");
}

SMEAttrs::SMEAttrs(const CallBase &CB) {
Expand All @@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
Bitmask |= SMEAttrs::SM_Compatible;
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
FuncName == "__arm_sme_state_size")
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
}

SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Expand All @@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= SM_Compatible;
if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
Bitmask |= SM_Body;
if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
Bitmask |= ZA_State_Agnostic;
if (Attrs.hasFnAttr("aarch64_in_za"))
Bitmask |= encodeZAState(StateValue::In);
if (Attrs.hasFnAttr("aarch64_out_za"))
Expand Down
Loading

0 comments on commit 2ce168b

Please sign in to comment.