Skip to content

Commit

Permalink
Merge pull request #1426 from AayushSabharwal/as/cse-arrayop
Browse files Browse the repository at this point in the history
test: fix array codegen tests
  • Loading branch information
ChrisRackauckas authored Feb 5, 2025
2 parents 4726cae + eced900 commit 914806f
Show file tree
Hide file tree
Showing 18 changed files with 162 additions and 183 deletions.
81 changes: 30 additions & 51 deletions test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Symbolics: symtype, shape, wrap, unwrap, Unknown, Arr, array_term, jacobia
using Base: Slice
using SymbolicUtils: Sym, term, operation
import LinearAlgebra: dot
import ..limit2

struct TestMetaT end
Symbolics.option_to_metadata_type(::Val{:test_meta}) = TestMetaT
Expand Down Expand Up @@ -185,15 +186,10 @@ end
n = 2
A = randn(n, n)
foo(x) = A * x # a function to represent symbolically, note, if this function is defined inside the testset, it's not found by the function fun_eval = eval(fun_ex)
function Symbolics.propagate_ndims(::typeof(foo), x)
ndims(x)
end
function Symbolics.propagate_shape(::typeof(foo), x)
shape(x)
end
@wrapped function foo(x::AbstractVector)
t = array_term(foo, x)
setmetadata(t, Symbolics.ScalarizeCache, Ref{Any}(nothing))
@register_array_symbolic foo(x::Vector{Real}) begin
size = (n,)
eltype = eltype(x)
ndims = 1
end

#=
Expand All @@ -203,33 +199,17 @@ The following two testsets test jacobians for symbolic functions of symbolic arr
@testset "Functions and Jacobians using @syms" begin
@variables x[1:n]

function symbolic_call(x)
@syms foo(x::Symbolics.Arr{Num,1})::Symbolics.Arr{Num,1} # symbolic foo can not be created in global scope due to conflict with function foo
foo(x) # return a symbolic call to foo
end

x0 = randn(n)
@test foo(x0) == A * x0
ex = symbolic_call(x)
ex = foo(x)

fun_genf = build_function(ex, x, expression=Val{false})
@test_broken fun_genf(x0) == A * x0# UndefVarError: foo not defined
fun_oop, fun_iip = build_function(ex, x, expression=Val{false})
@test fun_oop(x0) == A * x0# UndefVarError: foo not defined

# Generate an expression instead and eval it manually
fun_ex = build_function(ex, x, expression=Val{true})
fun_eval = eval(fun_ex)
fun_ex_oop, fun_ex_iip = build_function(ex, x, expression=Val{true})
fun_eval = eval(fun_ex_oop)
@test fun_eval(x0) == foo(x0)

# Try to provide the hidden argument `expression_module` to solve the scoping issue
@test_skip begin
fun_genf = build_function(ex, x, expression=Val{false}, expression_module=Main) # UndefVarError: #_RGF_ModTag not defined
fun_genf(x0) == A * x0
end

## Jacobians
@test_broken Symbolics.value.(Symbolics.jacobian(foo(x), x)) == A
@test_throws ErrorException Symbolics.value.(Symbolics.jacobian(ex , x))

end


Expand All @@ -242,11 +222,11 @@ end

@test shape(ex) == shape(x)

fun_iip, fun_genf = build_function(ex, x, expression=Val{false})
@test fun_genf(x0) == A * x0
fun_oop, fun_iif = build_function(ex, x, expression=Val{false})
@test fun_oop(x0) == A * x0

# Generate an expression instead and eval it manually
fun_ex_ip, fun_ex_oop = build_function(ex, x, expression=Val{true})
fun_ex_oop, fun_ex_ip = build_function(ex, x, expression=Val{true})
fun_eval = eval(fun_ex_oop)
@test fun_eval(x0) == foo(x0)

Expand Down Expand Up @@ -357,29 +337,28 @@ end
A = 3.4
alpha = 10.0

limit = Main.limit
dtu = @arrayop (i, j) alpha * (u[limit(i - 1, n), j] +
u[limit(i + 1, n), j] +
u[i, limit(j + 1, n)] +
u[i, limit(j - 1, n)] -
dtu = @arrayop (i, j) alpha * (u[limit2(i - 1, n), j] +
u[limit2(i + 1, n), j] +
u[i, limit2(j + 1, n)] +
u[i, limit2(j - 1, n)] -
4u[i, j]) +
1.0 + u[i, j]^2 * v[i, j] - (A + 1) *
u[i, j] + brusselator_f(x[i], y[j], t) i in 1:n j in 1:n
dtv = @arrayop (i, j) alpha * (v[limit(i - 1, n), j] +
v[limit(i + 1, n), j] +
v[i, limit(j + 1, n)] +
v[i, limit(j - 1, n)] -
dtv = @arrayop (i, j) alpha * (v[limit2(i - 1, n), j] +
v[limit2(i + 1, n), j] +
v[i, limit2(j + 1, n)] +
v[i, limit2(j - 1, n)] -
4v[i, j]) -
u[i, j]^2 * v[i, j] + A * u[i, j] i in 1:n j in 1:n
lapu = @arrayop (i, j) (u[limit(i - 1, n), j] +
u[limit(i + 1, n), j] +
u[i, limit(j + 1, n)] +
u[i, limit(j - 1, n)] -
lapu = @arrayop (i, j) (u[limit2(i - 1, n), j] +
u[limit2(i + 1, n), j] +
u[i, limit2(j + 1, n)] +
u[i, limit2(j - 1, n)] -
4u[i, j]) i in 1:n j in 1:n
lapv = @arrayop (i, j) (v[limit(i - 1, n), j] +
v[limit(i + 1, n), j] +
v[i, limit(j + 1, n)] +
v[i, limit(j - 1, n)] -
lapv = @arrayop (i, j) (v[limit2(i - 1, n), j] +
v[limit2(i + 1, n), j] +
v[i, limit2(j + 1, n)] +
v[i, limit2(j - 1, n)] -
4v[i, j]) i in 1:n j in 1:n
s = brusselator_f.(x, y', t)

Expand All @@ -388,7 +367,7 @@ end
lapu = wrap(lapu)
lapv = wrap(lapv)

f, g = build_function(dtu, u, v, t, expression=Val{false}, nanmath = false)
g, f = build_function(dtu, u, v, t, expression=Val{false}, nanmath = false)
du = zeros(Num, 8, 8)
f(du, u,v,t)
@test isequal(collect(du), collect(dtu))
Expand Down
20 changes: 10 additions & 10 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ expr = toexpr(Func([value(D(x))], [], value(D(x))))

a = rand(4)
@variables x[1:4]
@test eval(build_function(sin.(cos.(x)), cos.(x))[2])(a) == sin.(a)
@test eval(build_function(sin.(cos.(x)), cos.(x))[1])(a) == sin.(a)

# more skipzeros
@variables x,y
Expand Down Expand Up @@ -272,7 +272,7 @@ end
let #658
using Symbolics
@variables a, X1[1:3], X2[1:3]
k = eval(build_function(a * X1 + X2, X1, X2, a)[2])
k = eval(build_function(a * X1 + X2, X1, X2, a)[1])
@test k(ones(3), ones(3), 1.5) == [2.5, 2.5, 2.5]
end

Expand Down Expand Up @@ -302,14 +302,14 @@ end
end
end

@testset "cse with arrayops" begin
@variables x[1:3] y f(..)
t = x .+ y
t = t .* f(t)
res = cse(value(t))
@test res isa Let
@test !isempty(res.pairs)
end
# @testset "cse with arrayops" begin
# @variables x[1:3] y f(..)
# t = x .+ y
# t = t .* f(t)
# res = cse(value(t))
# @test res isa Let
# @test !isempty(res.pairs)
# end

@testset "`CallWithMetadata` in `DestructuredArgs` with `create_bindings = false`" begin
@variables x f(..)
Expand Down
27 changes: 12 additions & 15 deletions test/build_function_tests/intermediate-exprs-inplace.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
:(function (u,)
let ˍ₋out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
ˍ₋out_input_1 = let _out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
:(function (ˍ₋out, u)
begin
ˍ₋out_input_1 = let _out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5)))
end
end
_out
end
for (j, j′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
for (i, i′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(ˍ₋out_input_1, j, i))
end
_out
end
for (j, j′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
for (i, i′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(ˍ₋out_input_1, j, i))
end
end
ˍ₋out
end
end)
27 changes: 15 additions & 12 deletions test/build_function_tests/intermediate-exprs-outplace.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
:(function (ˍ₋out, u)
begin
ˍ₋out_input_1 = let _out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
:(function (u,)
let ˍ₋out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
ˍ₋out_input_1 = let _out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5)))
end
end
end
_out
end
for (j, j′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
for (i, i′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(ˍ₋out_input_1, j, i))
end
_out
end
for (j, j′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
for (i, i′) = zip(Base.OneTo(5), reset_to_one(Base.OneTo(5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(ˍ₋out_input_1, j, i))
end
end
ˍ₋out
end
end)
13 changes: 5 additions & 8 deletions test/build_function_tests/manual-limits-inplace.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
:(function (u,)
let ˍ₋out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
:(function (ˍ₋out, u)
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5)))
end
end
ˍ₋out
end
end)
13 changes: 8 additions & 5 deletions test/build_function_tests/manual-limits-outplace.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
:(function (ˍ₋out, u)
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
:(function (u,)
let ˍ₋out = zeros(Float64, map(length, (Base.OneTo(5), Base.OneTo(5))))
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5)))
end
end
end
ˍ₋out
end
end)
21 changes: 9 additions & 12 deletions test/build_function_tests/stencil-broadcast-inplace.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
:(function (x,)
let ˍ₋out = zeros(Float64, map(length, (1:6, 1:6)))
begin
ˍ₋out_2_input_1 = (broadcast)(+, x, (adjoint)(x))
ˍ₋out_1 = (view)(ˍ₋out, 1:6, 1:6)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:5, 2:5)
for (j, j′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
for (i, i′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (+)(1, (getindex)(ˍ₋out_2_input_1, i, j)))
end
:(function (ˍ₋out, x)
begin
ˍ₋out_2_input_1 = (broadcast)(+, x, (adjoint)(x))
ˍ₋out_1 = (view)(ˍ₋out, 1:6, 1:6)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:5, 2:5)
for (j, j′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
for (i, i′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (+)(1, (getindex)(ˍ₋out_2_input_1, i, j)))
end
end
ˍ₋out
end
end)
21 changes: 12 additions & 9 deletions test/build_function_tests/stencil-broadcast-outplace.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
:(function (ˍ₋out, x)
begin
ˍ₋out_2_input_1 = (broadcast)(+, x, (adjoint)(x))
ˍ₋out_1 = (view)(ˍ₋out, 1:6, 1:6)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:5, 2:5)
for (j, j′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
for (i, i′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (+)(1, (getindex)(ˍ₋out_2_input_1, i, j)))
:(function (x,)
let ˍ₋out = zeros(Float64, map(length, (1:6, 1:6)))
begin
ˍ₋out_2_input_1 = (broadcast)(+, x, (adjoint)(x))
ˍ₋out_1 = (view)(ˍ₋out, 1:6, 1:6)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:5, 2:5)
for (j, j′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
for (i, i′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (+)(1, (getindex)(ˍ₋out_2_input_1, i, j)))
end
end
end
ˍ₋out
end
end)
19 changes: 8 additions & 11 deletions test/build_function_tests/stencil-extents-inplace.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
:(function (x,)
let ˍ₋out = zeros(Float64, map(length, (1:5, 1:5)))
begin
ˍ₋out_1 = (view)(ˍ₋out, 1:5, 1:5)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:4, 2:4)
for (j, j′) = zip(2:4, reset_to_one(2:4))
for (i, i′) = zip(2:4, reset_to_one(2:4))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (*)(1//2, (+)((+)((+)((getindex)(x, (+)(-1, i), j), (getindex)(x, (+)(1, i), j)), (getindex)(x, i, (+)(-1, j))), (getindex)(x, i, (+)(1, j)))))
end
:(function (ˍ₋out, x)
begin
ˍ₋out_1 = (view)(ˍ₋out, 1:5, 1:5)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:4, 2:4)
for (j, j′) = zip(2:4, reset_to_one(2:4))
for (i, i′) = zip(2:4, reset_to_one(2:4))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (*)(1//2, (+)((+)((+)((getindex)(x, (+)(-1, i), j), (getindex)(x, (+)(1, i), j)), (getindex)(x, i, (+)(-1, j))), (getindex)(x, i, (+)(1, j)))))
end
end
ˍ₋out
end
end)
19 changes: 11 additions & 8 deletions test/build_function_tests/stencil-extents-outplace.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
:(function (ˍ₋out, x)
begin
ˍ₋out_1 = (view)(ˍ₋out, 1:5, 1:5)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:4, 2:4)
for (j, j′) = zip(2:4, reset_to_one(2:4))
for (i, i′) = zip(2:4, reset_to_one(2:4))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (*)(1//2, (+)((+)((+)((getindex)(x, (+)(-1, i), j), (getindex)(x, (+)(1, i), j)), (getindex)(x, i, (+)(-1, j))), (getindex)(x, i, (+)(1, j)))))
:(function (x,)
let ˍ₋out = zeros(Float64, map(length, (1:5, 1:5)))
begin
ˍ₋out_1 = (view)(ˍ₋out, 1:5, 1:5)
ˍ₋out_1 .= 0
ˍ₋out_2 = (view)(ˍ₋out, 2:4, 2:4)
for (j, j′) = zip(2:4, reset_to_one(2:4))
for (i, i′) = zip(2:4, reset_to_one(2:4))
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (*)(1//2, (+)((+)((+)((getindex)(x, (+)(-1, i), j), (getindex)(x, (+)(1, i), j)), (getindex)(x, i, (+)(-1, j))), (getindex)(x, i, (+)(1, j)))))
end
end
end
ˍ₋out
end
end)
Loading

0 comments on commit 914806f

Please sign in to comment.