From 061625af3295046952f378967b48b212265bd935 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 29 Jan 2025 14:56:32 +0530 Subject: [PATCH] feat: add `similarto` keyword to `build_function` --- src/build_function.jl | 5 ++++- test/build_function.jl | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/build_function.jl b/src/build_function.jl index 93e6af290..7af31d3b9 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -288,6 +288,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; outputidxs=nothing, skipzeros = false, force_SA = false, + similarto = nothing, wrap_code = (nothing, nothing), fillzeros = skipzeros && !(rhss isa SparseMatrixCSC), states = LazyState(), @@ -301,7 +302,9 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...])) i = findfirst(x->x isa DestructuredArgs, dargs) - similarto = force_SA ? SArray : i === nothing ? Array : dargs[i].name + if similarto === nothing + similarto = force_SA ? SArray : i === nothing ? Array : dargs[i].name + end oop, iip = iip_config oop_body = if oop diff --git a/test/build_function.jl b/test/build_function.jl index 5e7266402..0c09fdd6d 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -281,3 +281,12 @@ end T = value(x .^ 2) @test_nowarn toexpr(T, NameState()) end + +@testset "`similarto` keyword argument" begin + @variables x[1:2] + T = collect(value(x .^ 2)) + fn = build_function(T, collect(x); expression = false)[1] + @test_throws MethodError fn((1.0, 2.0)) + fn = build_function(T, collect(x); similarto = Array, expression = false)[1] + @test fn((1.0, 2.0)) ≈ [1.0, 4.0] +end