Skip to content

Commit

Permalink
[clang][SME] Ignore flatten/clang::always_inline statements for calle…
Browse files Browse the repository at this point in the history
…es with mismatched streaming attributes (llvm#116391)

If `__attribute__((flatten))` is used on a function, or
`[[clang::always_inline]]` on a statement, don't inline any callees with
incompatible streaming attributes. Without this check, clang may produce
incorrect code when these attributes are used in code with streaming
functions.

Note: The docs for flatten say it can be ignored when inlining is
impossible: "causes calls within the attributed function to be inlined
unless it is impossible to do so".

Similarly, the (clang-only) `[[clang::always_inline]]` statement
attribute is more relaxed than the GNU `__attribute__((always_inline))`
(which says it should error if it can't inline), saying only "If a
statement is marked [[clang::always_inline]] and contains calls, the
compiler attempts to inline those calls.". The docs also go on to show
an example of where `[[clang::always_inline]]` has no effect.
  • Loading branch information
MacDue authored Nov 26, 2024
1 parent 624e52b commit db6f627
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 18 deletions.
16 changes: 11 additions & 5 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5111,9 +5111,10 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,

// Some architectures (such as x86-64) have the ABI changed based on
// attribute-target/features. Give them a chance to diagnose.
CGM.getTargetCodeGenInfo().checkFunctionCallABI(
CGM, Loc, dyn_cast_or_null<FunctionDecl>(CurCodeDecl),
dyn_cast_or_null<FunctionDecl>(TargetDecl), CallArgs, RetTy);
const FunctionDecl *CallerDecl = dyn_cast_or_null<FunctionDecl>(CurCodeDecl);
const FunctionDecl *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl);
CGM.getTargetCodeGenInfo().checkFunctionCallABI(CGM, Loc, CallerDecl,
CalleeDecl, CallArgs, RetTy);

// 1. Set up the arguments.

Expand Down Expand Up @@ -5688,7 +5689,10 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
Attrs = Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::NoInline);

// Add call-site always_inline attribute if exists.
if (InAlwaysInlineAttributedStmt)
// Note: This corresponds to the [[clang::always_inline]] statement attribute.
if (InAlwaysInlineAttributedStmt &&
!CGM.getTargetCodeGenInfo().wouldInliningViolateFunctionCallABI(
CallerDecl, CalleeDecl))
Attrs =
Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::AlwaysInline);

Expand All @@ -5704,7 +5708,9 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
// FIXME: should this really take priority over __try, below?
if (CurCodeDecl && CurCodeDecl->hasAttr<FlattenAttr>() &&
!InNoInlineAttributedStmt &&
!(TargetDecl && TargetDecl->hasAttr<NoInlineAttr>())) {
!(TargetDecl && TargetDecl->hasAttr<NoInlineAttr>()) &&
!CGM.getTargetCodeGenInfo().wouldInliningViolateFunctionCallABI(
CallerDecl, CalleeDecl)) {
Attrs =
Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::AlwaysInline);
}
Expand Down
18 changes: 18 additions & 0 deletions clang/lib/CodeGen/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ class TargetCodeGenInfo {
const CallArgList &Args,
QualType ReturnType) const {}

/// Returns true if inlining the function call would produce incorrect code
/// for the current target, in which case the request to inline should be
/// ignored (even with the always_inline or flatten attributes).
///
/// Note: This probably should be handled in LLVM. However, the LLVM
/// `alwaysinline` attribute currently means the inliner will ignore
/// mismatched attributes (which sometimes can generate invalid code). So,
/// this hook allows targets to avoid adding the LLVM `alwaysinline` attribute
/// based on C/C++ attributes or other target-specific reasons.
///
/// See previous discussion here:
/// https://discourse.llvm.org/t/rfc-avoid-inlining-alwaysinline-functions-when-they-cannot-be-inlined/79528
///
/// \param Caller the function containing the call. Call sites pass the result
/// of a dyn_cast_or_null, so this may be null.
/// \param Callee the function being called; may be null for the same reason.
/// \returns false in this default implementation, i.e. targets that do not
/// override the hook never veto inlining.
virtual bool
wouldInliningViolateFunctionCallABI(const FunctionDecl *Caller,
const FunctionDecl *Callee) const {
return false;
}

/// Determines the size of struct _Unwind_Exception on this platform,
/// in 8-bit units. The Itanium ABI defines this as:
/// struct _Unwind_Exception {
Expand Down
72 changes: 59 additions & 13 deletions clang/lib/CodeGen/Targets/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ class AArch64TargetCodeGenInfo : public TargetCodeGenInfo {
const FunctionDecl *Callee, const CallArgList &Args,
QualType ReturnType) const override;

bool wouldInliningViolateFunctionCallABI(
const FunctionDecl *Caller, const FunctionDecl *Callee) const override;

private:
// Diagnose calls between functions with incompatible Streaming SVE
// attributes.
Expand Down Expand Up @@ -1143,30 +1146,67 @@ void AArch64TargetCodeGenInfo::checkFunctionABI(
}
}

void AArch64TargetCodeGenInfo::checkFunctionCallABIStreaming(
CodeGenModule &CGM, SourceLocation CallLoc, const FunctionDecl *Caller,
const FunctionDecl *Callee) const {
if (!Caller || !Callee || !Callee->hasAttr<AlwaysInlineAttr>())
return;
/// Bitmask of the Arm SME problems that GetArmSMEInlinability can find when
/// considering inlining a callee into a caller. Several bits may be set at
/// once; `Ok` (no bits set) means no problem was found.
enum class ArmSMEInlinability : uint8_t {
// No issue found.
Ok = 0,
// The callee carries an ArmNewAttr for which isNewZA() is true.
ErrorCalleeRequiresNewZA = 1 << 0,
// Streaming modes are incompatible and the callee is not streaming;
// reported as a warning for always_inline callees.
WarnIncompatibleStreamingModes = 1 << 1,
// Streaming modes are incompatible and the callee is streaming;
// reported as an error for always_inline callees.
ErrorIncompatibleStreamingModes = 1 << 2,

// Mask covering both streaming-mode bits, regardless of severity.
IncompatibleStreamingModes =
WarnIncompatibleStreamingModes | ErrorIncompatibleStreamingModes,

LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/ErrorIncompatibleStreamingModes),
};

/// Determines if there are any Arm SME ABI issues with inlining \p Callee into
/// \p Caller. Returns the issue (if any) in the ArmSMEInlinability bit enum.
///
/// This helper only classifies the problem; it emits no diagnostics, so it can
/// be shared by the diagnostic path and by the "should we add alwaysinline?"
/// query. Both \p Caller and \p Callee must be non-null (callers check this
/// before calling).
static ArmSMEInlinability GetArmSMEInlinability(const FunctionDecl *Caller,
                                                const FunctionDecl *Callee) {
  bool CallerIsStreaming =
      IsArmStreamingFunction(Caller, /*IncludeLocallyStreaming=*/true);
  bool CalleeIsStreaming =
      IsArmStreamingFunction(Callee, /*IncludeLocallyStreaming=*/true);
  bool CallerIsStreamingCompatible = isStreamingCompatible(Caller);
  bool CalleeIsStreamingCompatible = isStreamingCompatible(Callee);

  ArmSMEInlinability Inlinability = ArmSMEInlinability::Ok;

  // A streaming-compatible callee is never flagged by this check. Any other
  // callee is flagged when its streaming mode differs from the caller's, or
  // when the caller is streaming-compatible.
  if (!CalleeIsStreamingCompatible &&
      (CallerIsStreaming != CalleeIsStreaming || CallerIsStreamingCompatible)) {
    // Severity depends on the callee: a streaming callee gets the error bit,
    // a non-streaming callee only the warning bit.
    if (CalleeIsStreaming)
      Inlinability |= ArmSMEInlinability::ErrorIncompatibleStreamingModes;
    else
      Inlinability |= ArmSMEInlinability::WarnIncompatibleStreamingModes;
  }

  // A callee marked as creating new ZA state is recorded as its own issue.
  if (auto *NewAttr = Callee->getAttr<ArmNewAttr>())
    if (NewAttr->isNewZA())
      Inlinability |= ArmSMEInlinability::ErrorCalleeRequiresNewZA;

  return Inlinability;
}

/// Diagnoses calls to always_inline callees whose Arm SME attributes are
/// incompatible with the caller's. Does nothing when either declaration is
/// missing or when \p Callee is not marked always_inline.
void AArch64TargetCodeGenInfo::checkFunctionCallABIStreaming(
    CodeGenModule &CGM, SourceLocation CallLoc, const FunctionDecl *Caller,
    const FunctionDecl *Callee) const {
  const bool HaveBothDecls = Caller && Callee;
  if (!HaveBothDecls || !Callee->hasAttr<AlwaysInlineAttr>())
    return;

  const ArmSMEInlinability Issues = GetArmSMEInlinability(Caller, Callee);

  // True when any bit of Mask is present in the computed issue set.
  const auto HasIssue = [Issues](ArmSMEInlinability Mask) {
    return (Issues & Mask) != ArmSMEInlinability::Ok;
  };

  if (HasIssue(ArmSMEInlinability::IncompatibleStreamingModes)) {
    // The error-severity bit selects the error diagnostic; otherwise only a
    // warning is issued.
    const unsigned DiagID =
        HasIssue(ArmSMEInlinability::ErrorIncompatibleStreamingModes)
            ? diag::err_function_always_inline_attribute_mismatch
            : diag::warn_function_always_inline_attribute_mismatch;
    CGM.getDiags().Report(CallLoc, DiagID)
        << Caller->getDeclName() << Callee->getDeclName() << "streaming";
  }

  if (HasIssue(ArmSMEInlinability::ErrorCalleeRequiresNewZA)) {
    CGM.getDiags().Report(CallLoc, diag::err_function_always_inline_new_za)
        << Callee->getDeclName();
  }
}

// If the target does not have floating-point registers, but we are using a
Expand Down Expand Up @@ -1200,6 +1240,12 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABI(CodeGenModule &CGM,
checkFunctionCallABISoftFloat(CGM, CallLoc, Caller, Callee, Args, ReturnType);
}

bool AArch64TargetCodeGenInfo::wouldInliningViolateFunctionCallABI(
const FunctionDecl *Caller, const FunctionDecl *Callee) const {
return Caller && Callee &&
GetArmSMEInlinability(Caller, Callee) != ArmSMEInlinability::Ok;
}

void AArch64ABIInfo::appendAttributeMangling(TargetClonesAttr *Attr,
unsigned Index,
raw_ostream &Out) const {
Expand Down
84 changes: 84 additions & 0 deletions clang/test/CodeGen/AArch64/sme-inline-callees-streaming-attrs.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -emit-llvm -target-feature +sme %s -DUSE_FLATTEN -o - | FileCheck %s
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -emit-llvm -target-feature +sme %s -DUSE_ALWAYS_INLINE_STMT -o - | FileCheck %s

// REQUIRES: aarch64-registered-target

// Marker function. Every callee defined in this test calls it, so a direct
// call to it appearing in a caller's IR means that callee was inlined.
extern void was_inlined(void);

// Pick how inlining is requested, based on the -D flag from the RUN lines:
//  - USE_FLATTEN: the caller function carries __attribute__((flatten)).
//  - USE_ALWAYS_INLINE_STMT: each call statement carries
//    [[clang::always_inline]].
// Exactly one of FN_ATTR / STMT_ATTR is non-empty in each configuration.
#if defined(USE_FLATTEN)
#define FN_ATTR __attribute__((flatten))
#define STMT_ATTR
#elif defined(USE_ALWAYS_INLINE_STMT)
#define FN_ATTR
#define STMT_ATTR [[clang::always_inline]]
#else
#error Expected USE_FLATTEN or USE_ALWAYS_INLINE_STMT to be defined.
#endif

// Callees under test, one per streaming flavour. Each body is just a call to
// was_inlined(), so the callers' IR shows whether the callee was inlined.
// No streaming attributes.
void fn(void) { was_inlined(); }
// Declared __arm_streaming_compatible.
void fn_streaming_compatible(void) __arm_streaming_compatible { was_inlined(); }
// Declared __arm_streaming.
void fn_streaming(void) __arm_streaming { was_inlined(); }
// Declared with the __arm_locally_streaming keyword.
__arm_locally_streaming void fn_locally_streaming(void) { was_inlined(); }
// Declared __arm_streaming and additionally __arm_new("za").
__arm_new("za") void fn_streaming_new_za(void) __arm_streaming { was_inlined(); }

// Caller with no streaming attributes. Per the FileCheck lines that follow,
// fn and fn_streaming_compatible are expected to be inlined, while
// fn_streaming, fn_locally_streaming and fn_streaming_new_za must remain
// as calls.
FN_ATTR
void caller(void) {
STMT_ATTR fn();
STMT_ATTR fn_streaming_compatible();
STMT_ATTR fn_streaming();
STMT_ATTR fn_locally_streaming();
STMT_ATTR fn_streaming_new_za();
}
// CHECK-LABEL: void @caller()
// CHECK-NEXT: entry:
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @fn_streaming
// CHECK-NEXT: call void @fn_locally_streaming
// CHECK-NEXT: call void @fn_streaming_new_za

// Streaming-compatible caller. Per the FileCheck lines that follow, only
// fn_streaming_compatible is expected to be inlined; every other callee must
// remain as a call.
FN_ATTR void caller_streaming_compatible(void) __arm_streaming_compatible {
STMT_ATTR fn();
STMT_ATTR fn_streaming_compatible();
STMT_ATTR fn_streaming();
STMT_ATTR fn_locally_streaming();
STMT_ATTR fn_streaming_new_za();
}
// CHECK-LABEL: void @caller_streaming_compatible()
// CHECK-NEXT: entry:
// CHECK-NEXT: call void @fn
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @fn_streaming
// CHECK-NEXT: call void @fn_locally_streaming
// CHECK-NEXT: call void @fn_streaming_new_za

// Streaming caller. Per the FileCheck lines that follow,
// fn_streaming_compatible, fn_streaming and fn_locally_streaming are expected
// to be inlined, while fn and fn_streaming_new_za must remain as calls.
FN_ATTR void caller_streaming(void) __arm_streaming {
STMT_ATTR fn();
STMT_ATTR fn_streaming_compatible();
STMT_ATTR fn_streaming();
STMT_ATTR fn_locally_streaming();
STMT_ATTR fn_streaming_new_za();
}
// CHECK-LABEL: void @caller_streaming()
// CHECK-NEXT: entry:
// CHECK-NEXT: call void @fn
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @fn_streaming_new_za

// Locally-streaming caller. Per the FileCheck lines that follow, the expected
// results match the streaming caller: fn_streaming_compatible, fn_streaming
// and fn_locally_streaming are inlined, while fn and fn_streaming_new_za
// must remain as calls.
FN_ATTR __arm_locally_streaming
void caller_locally_streaming(void) {
STMT_ATTR fn();
STMT_ATTR fn_streaming_compatible();
STMT_ATTR fn_streaming();
STMT_ATTR fn_locally_streaming();
STMT_ATTR fn_streaming_new_za();
}
// CHECK-LABEL: void @caller_locally_streaming()
// CHECK-NEXT: entry:
// CHECK-NEXT: call void @fn
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @was_inlined
// CHECK-NEXT: call void @fn_streaming_new_za

0 comments on commit db6f627

Please sign in to comment.