Skip to content

Commit

Permalink
feat: add similarto keyword to build_function
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 29, 2025
1 parent 4d6af94 commit 061625a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
outputidxs=nothing,
skipzeros = false,
force_SA = false,
similarto = nothing,
wrap_code = (nothing, nothing),
fillzeros = skipzeros && !(rhss isa SparseMatrixCSC),
states = LazyState(),
Expand All @@ -301,7 +302,9 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))
i = findfirst(x->x isa DestructuredArgs, dargs)
similarto = force_SA ? SArray : i === nothing ? Array : dargs[i].name
if similarto === nothing
similarto = force_SA ? SArray : i === nothing ? Array : dargs[i].name
end

oop, iip = iip_config
oop_body = if oop
Expand Down
9 changes: 9 additions & 0 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,12 @@ end
T = value(x .^ 2)
@test_nowarn toexpr(T, NameState())
end

@testset "`similarto` keyword argument" begin
@variables x[1:2]
T = collect(value(x .^ 2))
fn = build_function(T, collect(x); expression = false)[1]
@test_throws MethodError fn((1.0, 2.0))
fn = build_function(T, collect(x); similarto = Array, expression = false)[1]
@test fn((1.0, 2.0)) [1.0, 4.0]
end

0 comments on commit 061625a

Please sign in to comment.