diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e4a7847d5..4f25707a8 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -847,3 +847,11 @@ function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} dims ≤ N && return x return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims)) end + +for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) + @eval function Base.clamp!(x::TracedRArray{T}, min::$(minT), max::$(maxT)) where {T} + y = clamp.(x, min, max) + x.mlir_data = y.mlir_data + return x + end +end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 530d8c1a9..ef8261ff6 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -270,6 +270,19 @@ Base.abs2(x::TracedRNumber{<:Real}) = x^2 Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T)) +for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) + @eval function Base.clamp(x::TracedRNumber{T}, min::$(minT), max::$(maxT)) where {T} + min = promote_to(TracedRNumber{T}, min) + max = promote_to(TracedRNumber{T}, max) + return TracedRNumber{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.clamp(min.mlir_data, x.mlir_data, max.mlir_data), 1 + ), + ) + end +end + struct TypeCast{T<:ReactantPrimitive} <: Function end (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) diff --git a/test/basic.jl b/test/basic.jl index 737097e8c..5137379bc 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -538,3 +538,22 @@ end @test Int(x) isa Int @test float(x) isa ConcreteRNumber{Float64} end + +@testset "clamp" begin + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) + + y = @jit(clamp!(x_ra, 0.0, 0.25)) + @test maximum(y) ≤ 0.25 + @test minimum(y) ≥ 0.0 + @test maximum(x_ra) == maximum(y) + @test minimum(x_ra) == minimum(y) + + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) + + y = @jit(clamp.(x_ra, 0.0, 0.25)) + @test maximum(y) ≤ 0.25 + @test minimum(y) ≥ 0.0 + @test x_ra ≈ x +end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 91a1a8d33..65240ed88 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -13,7 +13,7 @@ using NNlib, Reactant, Enzyme x_act_ca = Reactant.ConcreteRArray(x_act) @testset "Activation: $act" for act in ( - identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2 + identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6 ) f_compile = Reactant.compile(sumabs2, (act, x_act))