From 9843a0f8c14232a98385d3a88c4622d5f8081dcd Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 26 May 2021 14:06:26 -0400 Subject: [PATCH 1/3] Use at register for `sign` --- src/extra_functions.jl | 4 ++-- test/overloads.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/extra_functions.jl b/src/extra_functions.jl index 34d101458..8b38a23cb 100644 --- a/src/extra_functions.jl +++ b/src/extra_functions.jl @@ -2,9 +2,9 @@ @register Base.getindex(x,i) # define one and only one promotion rule @register Base.binomial(n,k) -Base.sign(x::Symbolic) = Term{Int}(sign, [x]) -Base.sign(x::Num) = Num(sign(value(x))) +@register Base.sign(x)::Int derivative(::typeof(sign), args::NTuple{1,Any}, ::Val{1}) = 0 + @register Base.signbit(x)::Bool derivative(::typeof(signbit), args::NTuple{1,Any}, ::Val{1}) = 0 derivative(::typeof(abs), args::NTuple{1,Any}, ::Val{1}) = IfElse.ifelse(signbit(args[1]),-one(args[1]),one(args[1])) diff --git a/test/overloads.jl b/test/overloads.jl index db1fa1019..713f174b4 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -168,7 +168,7 @@ eqs = [ @test [2 1 -1; -3 1 -1; 0 1 -5] * Symbolics.solve_for(eqs, [x, y, z]) == [2; -2; -2] @test isequal(Symbolics.solve_for(2//1*x + y - 2//1*z ~ 9//1*x, 1//1*x), 1//7*y - 2//7*z) -@test isequal(sign(x), Num(SymbolicUtils.Term{Int}(sign, [x]))) +@test isequal(sign(x), Num(SymbolicUtils.Term{Int}(sign, [Symbolics.value(x)]))) @test isequal(sign(Num(1)), Num(1)) @test isequal(sign(Num(-1)), Num(-1)) From fd6250cc807ddef930da11cd5d688263f163b112 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 26 May 2021 16:15:06 -0400 Subject: [PATCH 2/3] Eagerly evaluate registered function if arguments are not symbolic --- src/register.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/register.jl b/src/register.jl index f387caeec..d85f64c80 100644 --- a/src/register.jl +++ b/src/register.jl @@ -1,3 +1,4 @@ +using SymbolicUtils: Symbolic """ @register(expr, define_promotion = true, Ts = [Num, Symbolic, Real]) @@ -50,7 +51,12 @@ macro register(expr, define_promotion = true, Ts = [Num, Symbolic, Real]) push!(ex.args, quote function $f($(setinds(args, symbolic_args, ts)...)) wrap = any(x->typeof(x) <: $Num, tuple($(setinds(args, symbolic_args, ts)...),)) ? $Num : $identity - wrap($Term{$ret_type}($f, [$(map(name, args)...)])) + args = ($(map(name, args)...),) + if all(arg -> !(arg isa $Symbolic), args) + $f(args...,) + else + wrap($Term{$ret_type}($f, collect(args))) + end end end) end From 40b9447f4c185cc2f860c4b64de715cdaba5038c Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 26 May 2021 16:26:12 -0400 Subject: [PATCH 3/3] Remember to wrap --- src/register.jl | 2 +- test/overloads.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/register.jl b/src/register.jl index d85f64c80..1570e438d 100644 --- a/src/register.jl +++ b/src/register.jl @@ -53,7 +53,7 @@ macro register(expr, define_promotion = true, Ts = [Num, Symbolic, Real]) wrap = any(x->typeof(x) <: $Num, tuple($(setinds(args, symbolic_args, ts)...),)) ? $Num : $identity args = ($(map(name, args)...),) if all(arg -> !(arg isa $Symbolic), args) - $f(args...,) + wrap($f(args...,)) else wrap($Term{$ret_type}($f, collect(args))) end diff --git a/test/overloads.jl b/test/overloads.jl index 713f174b4..5c3e6fcee 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -169,6 +169,7 @@ eqs = [ @test isequal(Symbolics.solve_for(2//1*x + y - 2//1*z ~ 9//1*x, 1//1*x), 1//7*y - 2//7*z) @test isequal(sign(x), Num(SymbolicUtils.Term{Int}(sign, [Symbolics.value(x)]))) +@test sign(Num(1)) isa Num @test isequal(sign(Num(1)), Num(1)) @test isequal(sign(Num(-1)), Num(-1))