diff --git a/src/variable.jl b/src/variable.jl index 9d9cc260a..c74793856 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -291,6 +291,16 @@ CallWithMetadata(f) = CallWithMetadata(f, nothing) SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic() +# HACK: +# A `DestructuredArgs` with `create_bindings = false` doesn't create a `Let` block, and +# instead adds the assignments to the rewrites dictionary. This is problematic, because +# if the `DestructuredArgs` contains a `CallWithMetadata` the key in the `Dict` will be +# a `CallWithMetadata` which won't match against the operation of the called symbolic. +# This is the _only_ hook we have and relies on the `DestructuredArgs` being converted +# into a list of `Assignment`s before being addded to the `Dict` inside `toexpr(::Let, st)`. +# The callable symbolic is unwrapped so it matches the operation of the called version. +SymbolicUtils.Code.Assignment(f::CallWithMetadata, x) = SymbolicUtils.Code.Assignment(f.f, x) + function Base.show(io::IO, c::CallWithMetadata) show(io, c.f) print(io, "⋆") diff --git a/test/build_function.jl b/test/build_function.jl index f7df7b763..8e9cdd0a3 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -1,7 +1,7 @@ using Symbolics, SparseArrays, LinearAlgebra, Test using ReferenceTests using Symbolics: value -using SymbolicUtils.Code: DestructuredArgs, Func, NameState +using SymbolicUtils.Code: DestructuredArgs, Func, NameState, Let, cse @variables a b c1 c2 c3 d e g oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true) @test all(isnan, eval(oop)([-1, Inf])) @@ -301,3 +301,18 @@ end @test buf ≈ ones(2) end end + +@testset "cse with arrayops" begin + @variables x[1:3] y f(..) + t = x .+ y + t = t .* f(t) + res = cse(value(t)) + @test res isa Let + @test !isempty(res.pairs) +end + +@testset "`CallWithMetadata` in `DestructuredArgs` with `create_bindings = false`" begin + @variables x f(..) + fn = build_function(f(x), DestructuredArgs([f]; create_bindings = false), x; expression = Val{false}) + @test fn([isodd], 3) +end