Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for :foreigncall to transition to GC safe automatically #49933

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3409,7 +3409,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, sstate:
abstract_eval_value(interp, x, sstate, sv)
end
cconv = e.args[5]
if isa(cconv, QuoteNode) && (v = cconv.value; isa(v, Tuple{Symbol, UInt16}))
if isa(cconv, QuoteNode) && (v = cconv.value; isa(v, Tuple{Symbol, UInt16, Bool}))
override = decode_effects_override(v[2])
effects = override_effects(effects, override)
end
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
:meta => 0:typemax(Int),
:global => 1:1,
:globaldecl => 1:2,
:foreigncall => 5:typemax(Int), # name, RT, AT, nreq, (cconv, effects), args..., roots...
:foreigncall => 5:typemax(Int), # name, RT, AT, nreq, (cconv, effects, gc_safe), args..., roots...
:cfunction => 5:5,
:isdefined => 1:2,
:code_coverage_effect => 0:0,
Expand Down
50 changes: 43 additions & 7 deletions base/c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,31 @@ The above input outputs this:

(:printf, :Cvoid, [:Cstring, :Cuint], ["%d", :value])
"""
function ccall_macro_parse(expr::Expr)
function ccall_macro_parse(exprs)
gc_safe = false
expr = nothing
if exprs isa Expr
expr = exprs
elseif length(exprs) == 1
expr = exprs[1]
elseif length(exprs) == 2
gc_expr = exprs[1]
expr = exprs[2]
if gc_expr.head == :(=) && gc_expr.args[1] == :gc_safe
if gc_expr.args[2] == true
gc_safe = true
elseif gc_expr.args[2] == false
gc_safe = false
else
throw(ArgumentError("gc_safe must be true or false"))
end
else
throw(ArgumentError("@ccall option must be `gc_safe=true` or `gc_safe=false`"))
end
else
throw(ArgumentError("@ccall needs a function signature with a return type"))
end

# setup and check for errors
if !isexpr(expr, :(::))
throw(ArgumentError("@ccall needs a function signature with a return type"))
Expand Down Expand Up @@ -328,12 +352,11 @@ function ccall_macro_parse(expr::Expr)
pusharg!(a)
end
end

return func, rettype, types, args, nreq
return func, rettype, types, args, gc_safe, nreq
end


function ccall_macro_lower(convention, func, rettype, types, args, nreq)
function ccall_macro_lower(convention, func, rettype, types, args, gc_safe, nreq)
statements = []

# if interpolation was used, ensure the value is a function pointer at runtime.
Expand All @@ -351,9 +374,15 @@ function ccall_macro_lower(convention, func, rettype, types, args, nreq)
else
func = esc(func)
end
cconv = nothing
if convention isa Tuple
cconv = Expr(:cconv, (convention..., gc_safe), nreq)
else
cconv = Expr(:cconv, (convention, UInt16(0), gc_safe), nreq)
end

return Expr(:block, statements...,
Expr(:call, :ccall, func, Expr(:cconv, convention, nreq), esc(rettype),
Expr(:call, :ccall, func, cconv, esc(rettype),
Expr(:tuple, map(esc, types)...), map(esc, args)...))
end

Expand Down Expand Up @@ -404,9 +433,16 @@ Example using an external library:

The string literal could also be used directly before the function
name, if desired `"libglib-2.0".g_uri_escape_string(...`

It's possible to declare the ccall as `gc_safe` by using the `gc_safe = true` option:
@ccall gc_safe=true strlen(s::Cstring)::Csize_t
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚲 If gc_safe=false is the default, perhaps marking a ccall as gc_safe could just be spelled like:

@ccall :gc_safe strlen(s::Cstring)::Csize_t

(similar to how the @atomic and @spawn and @assume_effects macros take Symbols to control behaviour)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We hope to be able to flip the default.

This allows the garbage collector to run concurrently with the ccall, which can be useful whenever
the `ccall` may block outside of julia.
WARNING: This option should be used with caution, as it can lead to undefined behavior if the ccall
calls back into the julia runtime. (`@cfunction`/`@ccallables` are safe however)
"""
macro ccall(expr)
return ccall_macro_lower(:ccall, ccall_macro_parse(expr)...)
macro ccall(exprs...)
return ccall_macro_lower((:ccall), ccall_macro_parse(exprs)...)
end

macro ccall_effects(effects::UInt16, expr)
Expand Down
2 changes: 1 addition & 1 deletion base/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ function _partially_inline!(@nospecialize(x), slot_replacements::Vector{Any},
elseif i == 4
@assert isa(x.args[4], Int)
elseif i == 5
@assert isa((x.args[5]::QuoteNode).value, Union{Symbol, Tuple{Symbol, UInt8}})
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
@assert isa((x.args[5]::QuoteNode).value, Union{Symbol, Tuple{Symbol, UInt16, Bool}})
else
x.args[i] = _partially_inline!(x.args[i], slot_replacements,
type_signature, static_param_values,
Expand Down
2 changes: 1 addition & 1 deletion base/strings/string.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ end
# but the macro is not available at this time in bootstrap, so we write it manually.
const _string_n_override = 0x04ee
@eval _string_n(n::Integer) = $(Expr(:foreigncall, QuoteNode(:jl_alloc_string), Ref{String},
:(Core.svec(Csize_t)), 1, QuoteNode((:ccall, _string_n_override)), :(convert(Csize_t, n))))
:(Core.svec(Csize_t)), 1, QuoteNode((:ccall, _string_n_override, false)), :(convert(Csize_t, n))))

"""
String(s::AbstractString)
Expand Down
4 changes: 2 additions & 2 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -498,9 +498,9 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form.

The number of required arguments for a varargs function definition.

* `args[5]::QuoteNode{<:Union{Symbol,Tuple{Symbol,UInt16}}`: calling convention
* `args[5]::QuoteNode{<:Union{Symbol,Tuple{Symbol,UInt16}, Tuple{Symbol,UInt16,Bool}}`: calling convention

The calling convention for the call, optionally with effects.
The calling convention for the call, optionally with effects, and `gc_safe` (safe to execute concurrently to GC.).

* `args[6:5+length(args[3])]` : arguments

Expand Down
14 changes: 10 additions & 4 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ class function_sig_t {
AttributeList attributes; // vector of function call site attributes
Type *lrt; // input parameter of the llvm return type (from julia_struct_to_llvm)
bool retboxed; // input parameter indicating whether lrt is jl_value_t*
bool gc_safe; // input parameter indicating whether the call is safe to execute concurrently to GC
Type *prt; // out parameter of the llvm return type for the function signature
int sret; // out parameter for indicating whether return value has been moved to the first argument position
std::string err_msg;
Expand All @@ -1146,8 +1147,8 @@ class function_sig_t {
size_t nreqargs; // number of required arguments in ccall function definition
jl_codegen_params_t *ctx;

function_sig_t(const char *fname, Type *lrt, jl_value_t *rt, bool retboxed, jl_svec_t *at, jl_unionall_t *unionall_env, size_t nreqargs, CallingConv::ID cc, bool llvmcall, jl_codegen_params_t *ctx)
: lrt(lrt), retboxed(retboxed),
function_sig_t(const char *fname, Type *lrt, jl_value_t *rt, bool retboxed, bool gc_safe, jl_svec_t *at, jl_unionall_t *unionall_env, size_t nreqargs, CallingConv::ID cc, bool llvmcall, jl_codegen_params_t *ctx)
: lrt(lrt), retboxed(retboxed), gc_safe(gc_safe),
prt(NULL), sret(0), cc(cc), llvmcall(llvmcall),
at(at), rt(rt), unionall_env(unionall_env),
nccallargs(jl_svec_len(at)), nreqargs(nreqargs),
Expand Down Expand Up @@ -1295,6 +1296,9 @@ std::string generate_func_sig(const char *fname)
RetAttrs = RetAttrs.addAttribute(LLVMCtx, Attribute::NonNull);
if (rt == jl_bottom_type)
FnAttrs = FnAttrs.addAttribute(LLVMCtx, Attribute::NoReturn);
if (gc_safe)
FnAttrs = FnAttrs.addAttribute(LLVMCtx, "julia.gc_safe");

assert(attributes.isEmpty());
attributes = AttributeList::get(LLVMCtx, FnAttrs, RetAttrs, paramattrs);
return "";
Expand Down Expand Up @@ -1412,7 +1416,7 @@ static const std::string verify_ccall_sig(jl_value_t *&rt, jl_value_t *at,

const int fc_args_start = 6;

// Expr(:foreigncall, pointer, rettype, (argtypes...), nreq, [cconv | (cconv, effects)], args..., roots...)
// Expr(:foreigncall, pointer, rettype, (argtypes...), nreq, gc_safe, [cconv | (cconv, effects)], args..., roots...)
static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
{
JL_NARGSV(ccall, 5);
Expand All @@ -1424,11 +1428,13 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
assert(jl_is_quotenode(args[5]));
jl_value_t *jlcc = jl_quotenode_value(args[5]);
jl_sym_t *cc_sym = NULL;
bool gc_safe = false;
if (jl_is_symbol(jlcc)) {
cc_sym = (jl_sym_t*)jlcc;
}
else if (jl_is_tuple(jlcc)) {
cc_sym = (jl_sym_t*)jl_get_nth_field_noalloc(jlcc, 0);
gc_safe = jl_unbox_bool(jl_get_nth_field_checked(jlcc, 2));
}
assert(jl_is_symbol(cc_sym));
native_sym_arg_t symarg = {};
Expand Down Expand Up @@ -1547,7 +1553,7 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
}
if (rt != args[2] && rt != (jl_value_t*)jl_any_type)
jl_temporary_root(ctx, rt);
function_sig_t sig("ccall", lrt, rt, retboxed,
function_sig_t sig("ccall", lrt, rt, retboxed, gc_safe,
(jl_svec_t*)at, unionall, nreqargs,
cc, llvmcall, &ctx.emission_context);
for (size_t i = 0; i < nccallargs; i++) {
Expand Down
4 changes: 2 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8065,7 +8065,7 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
if (rt != declrt && rt != (jl_value_t*)jl_any_type)
jl_temporary_root(ctx, rt);

function_sig_t sig("cfunction", lrt, rt, retboxed, argt, unionall_env, false, CallingConv::C, false, &ctx.emission_context);
function_sig_t sig("cfunction", lrt, rt, retboxed, false, argt, unionall_env, false, CallingConv::C, false, &ctx.emission_context);
assert(sig.fargt.size() + sig.sret == sig.fargt_sig.size());
if (!sig.err_msg.empty()) {
emit_error(ctx, sig.err_msg);
Expand Down Expand Up @@ -8205,7 +8205,7 @@ const char *jl_generate_ccallable(Module *llvmmod, void *sysimg_handle, jl_value
}
jl_value_t *err;
{ // scope block for sig
function_sig_t sig("cfunction", lcrt, crt, toboxed,
function_sig_t sig("cfunction", lcrt, crt, toboxed, false,
argtypes, NULL, false, CallingConv::C, false, &params);
if (sig.err_msg.empty()) {
if (sysimg_handle) {
Expand Down
14 changes: 5 additions & 9 deletions src/llvm-codegen-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,21 +244,17 @@ static inline llvm::Value *emit_gc_state_set(llvm::IRBuilder<> &builder, llvm::T
unsigned offset = offsetof(jl_tls_states_t, gc_state);
Value *gc_state = builder.CreateConstInBoundsGEP1_32(T_int8, ptls, offset, "gc_state");
if (old_state == nullptr) {
old_state = builder.CreateLoad(T_int8, gc_state);
old_state = builder.CreateLoad(T_int8, gc_state, "old_state");
cast<LoadInst>(old_state)->setOrdering(AtomicOrdering::Monotonic);
}
builder.CreateAlignedStore(state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
if (auto *C = dyn_cast<ConstantInt>(old_state))
if (C->isZero())
return old_state;
if (auto *C = dyn_cast<ConstantInt>(state))
if (!C->isZero())
return old_state;
if (auto *C2 = dyn_cast<ConstantInt>(state))
if (C->getZExtValue() == C2->getZExtValue())
return old_state;
BasicBlock *passBB = BasicBlock::Create(builder.getContext(), "safepoint", builder.GetInsertBlock()->getParent());
BasicBlock *exitBB = BasicBlock::Create(builder.getContext(), "after_safepoint", builder.GetInsertBlock()->getParent());
Constant *zero8 = ConstantInt::get(T_int8, 0);
builder.CreateCondBr(builder.CreateOr(builder.CreateICmpEQ(old_state, zero8), // if (!old_state || !state)
builder.CreateICmpEQ(state, zero8)),
builder.CreateCondBr(builder.CreateICmpEQ(old_state, state, "is_new_state"), // Safepoint whenever we change the GC state
passBB, exitBB);
builder.SetInsertPoint(passBB);
MDNode *tbaa = get_tbaa_const(builder.getContext());
Expand Down
41 changes: 32 additions & 9 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2181,16 +2181,39 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S, bool *CFGModified) {
NewCall->copyMetadata(*CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
} else if (CI->arg_size() == CI->getNumOperands()) {
/* No operand bundle to lower */
++it;
continue;
} else {
CallInst *NewCall = CallInst::Create(CI, None, CI);
NewCall->takeName(CI);
NewCall->copyMetadata(*CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
if (CI->getFnAttr("julia.gc_safe").isValid()) {
// Insert the operations to switch to gc_safe if necessary.
IRBuilder<> builder(CI);
Value *pgcstack = getOrAddPGCstack(F);
assert(pgcstack);
// We dont use emit_state_set here because safepoints are unconditional for any code that reaches this
// We are basically guaranteed to go from gc_unsafe to gc_safe and back, and both transitions need a safepoint
// We also can't add any BBs here, so just avoiding the branches is good
Value *ptls = get_current_ptls_from_task(builder, get_current_task_from_pgcstack(builder, pgcstack), tbaa_gcframe);
unsigned offset = offsetof(jl_tls_states_t, gc_state);
Value *gc_state = builder.CreateConstInBoundsGEP1_32(Type::getInt8Ty(builder.getContext()), ptls, offset, "gc_state");
LoadInst *last_gc_state = builder.CreateAlignedLoad(Type::getInt8Ty(builder.getContext()), gc_state, Align(sizeof(void*)));
last_gc_state->setOrdering(AtomicOrdering::Monotonic);
builder.CreateAlignedStore(builder.getInt8(JL_GC_STATE_SAFE), gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
MDNode *tbaa = get_tbaa_const(builder.getContext());
emit_gc_safepoint(builder, T_size, ptls, tbaa, false);
builder.SetInsertPoint(CI->getNextNode());
builder.CreateAlignedStore(last_gc_state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
emit_gc_safepoint(builder, T_size, ptls, tbaa, false);
}
if (CI->arg_size() == CI->getNumOperands()) {
/* No operand bundle to lower */
++it;
continue;
} else {
// remove operand bundle
CallInst *NewCall = CallInst::Create(CI, None, CI);
NewCall->takeName(CI);
NewCall->copyMetadata(*CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
}
}
if (!CI->use_empty()) {
CI->replaceAllUsesWith(UndefValue::get(CI->getType()));
Expand Down
21 changes: 21 additions & 0 deletions src/llvm-pass-helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,27 @@ llvm::CallInst *JuliaPassContext::getPGCstack(llvm::Function &F) const
return nullptr;
}

llvm::CallInst *JuliaPassContext::getOrAddPGCstack(llvm::Function &F)
{
if (pgcstack_getter || adoptthread_func)
for (auto &I : F.getEntryBlock()) {
if (CallInst *callInst = dyn_cast<CallInst>(&I)) {
Value *callee = callInst->getCalledOperand();
if ((pgcstack_getter && callee == pgcstack_getter) ||
(adoptthread_func && callee == adoptthread_func)) {
return callInst;
}
}
}
IRBuilder<> builder(&F.getEntryBlock().front());
if (pgcstack_getter)
return builder.CreateCall(pgcstack_getter);
auto FT = FunctionType::get(PointerType::get(F.getContext(), 0), false);
auto F2 = Function::Create(FT, Function::ExternalLinkage, "julia.get_pgcstack", F.getParent());
pgcstack_getter = F2;
return builder.CreateCall( F2);
}

llvm::Function *JuliaPassContext::getOrNull(
const jl_intrinsics::IntrinsicDescription &desc) const
{
Expand Down
5 changes: 4 additions & 1 deletion src/llvm-pass-helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ struct JuliaPassContext {
// point of the given function, if there exists such a call.
// Otherwise, `nullptr` is returned.
llvm::CallInst *getPGCstack(llvm::Function &F) const;

// Gets a call to the `julia.get_pgcstack' intrinsic in the entry
// point of the given function, if there exists such a call.
// Otherwise, creates a new call to the intrinsic
llvm::CallInst *getOrAddPGCstack(llvm::Function &F);
// Gets the intrinsic or well-known function that conforms to
// the given description if it exists in the module. If not,
// `nullptr` is returned.
Expand Down
7 changes: 6 additions & 1 deletion src/llvm-ptls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,13 @@ void LowerPTLS::fix_pgcstack_use(CallInst *pgcstack, Function *pgcstack_getter,
last_gc_state->addIncoming(prior, fastTerm->getParent());
for (auto &BB : *pgcstack->getParent()->getParent()) {
if (isa<ReturnInst>(BB.getTerminator())) {
// Don't use emit_gc_safe_leave here, as that introduces a new BB while iterating BBs
builder.SetInsertPoint(BB.getTerminator());
emit_gc_unsafe_leave(builder, T_size, get_current_ptls_from_task(builder, get_current_task_from_pgcstack(builder, phi), tbaa), last_gc_state, true);
Value *ptls = get_current_ptls_from_task(builder, get_current_task_from_pgcstack(builder, phi), tbaa_gcframe);
unsigned offset = offsetof(jl_tls_states_t, gc_state);
Value *gc_state = builder.CreateConstInBoundsGEP1_32(Type::getInt8Ty(builder.getContext()), ptls, offset, "gc_state");
builder.CreateAlignedStore(last_gc_state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
emit_gc_safepoint(builder, T_size, ptls, tbaa, true);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ static jl_value_t *resolve_definition_effects(jl_value_t *expr, jl_module_t *mod
return expr;
}
if (e->head == jl_foreigncall_sym) {
JL_NARGSV(ccall method definition, 5); // (fptr, rt, at, nreq, (cc, effects))
JL_NARGSV(ccall method definition, 5); // (fptr, rt, at, nreq, (cc, effects, gc_safe))
jl_task_t *ct = jl_current_task;
jl_value_t *rt = jl_exprarg(e, 1);
jl_value_t *at = jl_exprarg(e, 2);
Expand Down Expand Up @@ -172,11 +172,12 @@ static jl_value_t *resolve_definition_effects(jl_value_t *expr, jl_module_t *mod
jl_value_t *cc = jl_quotenode_value(jl_exprarg(e, 4));
if (!jl_is_symbol(cc)) {
JL_TYPECHK(ccall method definition, tuple, cc);
if (jl_nfields(cc) != 2) {
if (jl_nfields(cc) != 3) {
jl_error("In ccall calling convention, expected two argument tuple or symbol.");
}
JL_TYPECHK(ccall method definition, symbol, jl_get_nth_field(cc, 0));
JL_TYPECHK(ccall method definition, uint16, jl_get_nth_field(cc, 1));
JL_TYPECHK(ccall method definition, bool, jl_get_nth_field(cc, 2));
}
}
if (e->head == jl_call_sym && nargs > 0 &&
Expand Down
Loading