Skip to content

Commit

Permalink
fix: correct dims handling in mapreducedim! (#728)
Browse files Browse the repository at this point in the history
* feat: add sign dispatches

* fix: correct dims handling in mapreducedim!

* Update test/basic.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
avik-pal and github-actions[bot] authored Feb 11, 2025
1 parent b70d614 commit 0760f4b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,14 @@ function Base.mapreducedim!(
@nospecialize(R::TracedRArray),
A::Base.AbstractArrayOrBroadcasted,
)
tmp = TracedUtils.broadcast_to_size(
Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...)
)
R.mlir_data = broadcast(op, R, tmp).mlir_data
@assert length(size(R)) == length(size(A))
dims = map(enumerate(zip(size(R), size(A)))) do (i, (sR, sA))
sR == sA && return nothing
@assert sR == 1
return i
end
tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
set_mlir_data!(R, get_mlir_data(tmp))
return R
end

Expand Down
27 changes: 27 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,30 @@ end

@test @jit(fn(x)) fn(Array(x))
end

function fntest1(x)
y = similar(x, 1, 1, 8)
sum!(y, x)
return y
end

function fntest2(x)
y = similar(x, 2, 1, 8)
sum!(y, x)
return y
end

function fntest3(x)
y = similar(x, 2, 1, 1)
sum!(abs2, y, x)
return y
end

@testset "mapreducedim!" begin
x = reshape(collect(Float32, 1:64), 2, 4, 8) ./ 64
x_ra = Reactant.to_rarray(x)

@test Array(@jit(fntest1(x_ra))) fntest1(x)
@test Array(@jit(fntest2(x_ra))) fntest2(x)
@test Array(@jit(fntest3(x_ra))) fntest3(x)
end

0 comments on commit 0760f4b

Please sign in to comment.