From ab62dbea5df2e0cba97c5d74b68fafea69c24cff Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 4 Feb 2025 16:34:57 +0530 Subject: [PATCH 1/2] test: test cse with arrayops --- test/build_function.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/build_function.jl b/test/build_function.jl index f7df7b763..fe0e122fa 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -1,7 +1,7 @@ using Symbolics, SparseArrays, LinearAlgebra, Test using ReferenceTests using Symbolics: value -using SymbolicUtils.Code: DestructuredArgs, Func, NameState +using SymbolicUtils.Code: DestructuredArgs, Func, NameState, Let, cse @variables a b c1 c2 c3 d e g oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true) @test all(isnan, eval(oop)([-1, Inf])) @@ -301,3 +301,12 @@ end @test buf ≈ ones(2) end end + +@testset "cse with arrayops" begin + @variables x[1:3] y f(..) + t = x .+ y + t = t .* f(t) + res = cse(value(t)) + @test res isa Let + @test !isempty(res.pairs) +end From 7ca5fb51b85aa75641a21590ea2eefed6364508b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 4 Feb 2025 17:53:40 +0530 Subject: [PATCH 2/2] fix: fix `CallWithMetadata` inside `DestructuredArgs` with `create_bindings = false` --- src/variable.jl | 10 ++++++++++ test/build_function.jl | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/src/variable.jl b/src/variable.jl index 9d9cc260a..c74793856 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -291,6 +291,16 @@ CallWithMetadata(f) = CallWithMetadata(f, nothing) SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic() +# HACK: +# A `DestructuredArgs` with `create_bindings = false` doesn't create a `Let` block, and +# instead adds the assignments to the rewrites dictionary. This is problematic, because +# if the `DestructuredArgs` contains a `CallWithMetadata` the key in the `Dict` will be +# a `CallWithMetadata` which won't match against the operation of the called symbolic. +# This is the _only_ hook we have and relies on the `DestructuredArgs` being converted +# into a list of `Assignment`s before being addded to the `Dict` inside `toexpr(::Let, st)`. +# The callable symbolic is unwrapped so it matches the operation of the called version. +SymbolicUtils.Code.Assignment(f::CallWithMetadata, x) = SymbolicUtils.Code.Assignment(f.f, x) + function Base.show(io::IO, c::CallWithMetadata) show(io, c.f) print(io, "⋆") diff --git a/test/build_function.jl b/test/build_function.jl index fe0e122fa..8e9cdd0a3 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -310,3 +310,9 @@ end @test res isa Let @test !isempty(res.pairs) end + +@testset "`CallWithMetadata` in `DestructuredArgs` with `create_bindings = false`" begin + @variables x f(..) + fn = build_function(f(x), DestructuredArgs([f]; create_bindings = false), x; expression = Val{false}) + @test fn([isodd], 3) +end