Skip to content

Commit

Permalink
fix: bypass segfault with fill complex (#245)
Browse files Browse the repository at this point in the history
* fix: bypass segfault with fill complex

* test: simpler promote_to testing

* fix: `MLIR.IR.DenseElementsAttribute` definition for bool arrays
  • Loading branch information
avik-pal authored Nov 8, 2024
1 parent f34ec8f commit 1ff11c9
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 4 deletions.
3 changes: 1 addition & 2 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,7 @@ function broadcast_to_size(arg::Base.RefValue, rsize)
end

function broadcast_to_size(arg::T, rsize) where {T<:Number}
TT = MLIR.IR.TensorType([Int64(s) for s in rsize], MLIR.IR.Type(typeof(arg)))
attr = Base.fill(arg, TT)
attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize)))
return arg = TracedRArray{T,length(rsize)}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize
)
Expand Down
2 changes: 1 addition & 1 deletion src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
)
end
if isa(rhs, Number)
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T}))
attr = MLIR.IR.DenseElementsAttribute(fill(T(rhs)))
return TracedRNumber{T}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
)
Expand Down
2 changes: 1 addition & 1 deletion src/mlir/IR/Attribute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ end
Creates a dense elements attribute with the given shaped type from elements of a specific type. Expects the element type of the shaped type to match the data element type.
"""
function DenseElementsAttribute(values::AbstractVector{Bool})
function DenseElementsAttribute(values::AbstractArray{Bool})
shaped_type = TensorType(size(values), Type(Bool))
return Attribute(
API.mlirDenseElementsAttrBoolGet(shaped_type, length(values), pointer(values))
Expand Down
11 changes: 11 additions & 0 deletions test/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,14 @@ end
x_concrete = Reactant.to_rarray(x)
@test @jit(abs.(x_concrete)) ≈ abs.(x)
end

@testset "promote_to Complex" begin
    # Regression check: promoting a complex Julia scalar to a traced scalar
    # inside a compiled function must yield a usable constant.
    input = Reactant.ConcreteRNumber(1.0 + 2.0im)

    # The compiled closure adds a promoted complex constant to its argument.
    compiled = Reactant.compile((input,)) do z
        promoted = Reactant.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im)
        return z + promoted
    end

    # (1 + 2im) + (1 - 3im) == 2 - 1im
    @test compiled(input) ≈ 2.0 - 1.0im
end
5 changes: 5 additions & 0 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,9 @@ end

pad_fn2 = Base.Fix2(NNlib.pad_constant, (1, 0, 1, 3))
@test @jit(∇sumabs2(pad_fn2, x_ra)) ≈ ∇sumabs2(pad_fn2, x)

x = rand(ComplexF32, 4, 4)
x_ra = Reactant.ConcreteRArray(x)

@test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1))
end

1 comment on commit 1ff11c9

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 1ff11c9 Previous: f34ec8f Ratio
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 6384860404 ns 7027955167 ns 0.91
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5301328245 ns 5911790203 ns 0.90
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5203462205 ns 5039003346 ns 1.03
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7305272771 ns 7620308614 ns 0.96
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 34987991089 ns 31357074001 ns 1.12
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1561591277 ns 1572787384 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1557307113 ns 1548995246 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1542360845 ns 1555509740 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3316796244 ns 3467552492 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 3043968768 ns 3419731143 ns 0.89
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2146877682 ns 2180418814 ns 0.98
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2131581565 ns 2173460673 ns 0.98
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2138865967 ns 2158353524 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3910063623 ns 3927830322 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 5708530964 ns 6253378741.5 ns 0.91
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1416785120 ns 1440134479 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1421169200 ns 1421557207 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1406671878 ns 1428360591 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3159231592 ns 3208488833 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1143922805 ns 1139085595.5 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1719831635 ns 1712651271 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1705729225 ns 1706763760 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1700558831 ns 1704312368 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3451683364 ns 3479746822 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3140401890 ns 3154626911 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2158380749 ns 2175602283 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2164881935 ns 2162683110 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2126575115 ns 2187544254 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3935604871 ns 3928792736 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 5762722471.5 ns 6172504320 ns 0.93
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3016029772 ns 2997646442 ns 1.01
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 2981006759 ns 2957532725 ns 1.01
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2989746439 ns 2960803098 ns 1.01
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4846949677 ns 4863944327 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 10864111732 ns 22191688400 ns 0.49
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3166559857 ns 3157390991 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3144854123 ns 3212937029 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3165468463 ns 3419812778 ns 0.93
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5000879641 ns 5010112071 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 15188174420 ns 10940648930 ns 1.39
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1819847550 ns 1843742771 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1826616736 ns 1827524417 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1850176920 ns 1842645194 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3562720585 ns 3593953850 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 4485724051 ns 3544018917 ns 1.27

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.