From f34ec8f1a5038614675d08d1a70fbfea80f3f124 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 7 Nov 2024 17:06:23 -0500 Subject: [PATCH] feat: support setindex with views (#240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support setindex with views * Update test/basic.jl Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --------- Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/TracedRArray.jl | 30 +++++++++++++++++++++++++++--- test/basic.jl | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 1bd4caee9..48bb7074f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -35,6 +35,19 @@ materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) +function set_mlir_data!(x::TracedRArray, data) + x.mlir_data = data + return x +end +function set_mlir_data!(x::AnyTracedRArray, data) + data_type = MLIR.IR.type(data) + data = TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( + (), data, size(data_type) + ) + setindex!(x, data, axes(x)...) + return x +end + ancestor(x::TracedRArray) = x ancestor(x::WrappedTracedRArray) = ancestor(parent(x)) @@ -115,12 +128,23 @@ function Base.setindex!( i in indices ] res = MLIR.IR.result( - MLIR.Dialects.stablehlo.dynamic_update_slice(a.mlir_data, v.mlir_data, indices), 1 + MLIR.Dialects.stablehlo.dynamic_update_slice( + a.mlir_data, get_mlir_data(v), indices + ), + 1, ) a.mlir_data = res return v end +function Base.setindex!( + a::AnyTracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N} +) where {T,N} + ancestor_indices = get_ancestor_indices(a, indices...) + setindex!(ancestor(a), v, ancestor_indices...) + return a +end + Base.size(x::TracedRArray) = x.shape Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A)) @@ -727,7 +751,7 @@ function broadcast_to_size_internal(x::TracedRArray, rsize) ) end -function _copyto!(dest::TracedRArray, bc::Broadcasted) +function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest @@ -736,7 +760,7 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted) args = (broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) res = elem_apply(bc.f, args...) - dest.mlir_data = res.mlir_data + set_mlir_data!(dest, res.mlir_data) return dest end diff --git a/test/basic.jl b/test/basic.jl index e3b049a37..737097e8c 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -396,6 +396,43 @@ end # get_view_compiled = @compile get_view(x_concrete) end +function masking(x) + y = similar(x) + y[1:2, :] .= 0 + y[3:4, :] .= 1 + return y +end + +function masking!(x) + x[1:2, :] .= 0 + x[3:4, :] .= 1 + return x +end + +@testset "setindex! with views" begin + x = rand(4, 4) .+ 2.0 + x_ra = Reactant.to_rarray(x) + + y = masking(x) + y_ra = @jit(masking(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test !(any(iszero, x_ra_array[1, :])) + @test !(any(iszero, x_ra_array[2, :])) + @test !(any(isone, x_ra_array[3, :])) + @test !(any(isone, x_ra_array[4, :])) + + y_ra = @jit(masking!(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test all(iszero, x_ra_array[1, :]) + @test all(iszero, x_ra_array[2, :]) + @test all(isone, x_ra_array[3, :]) + @test all(isone, x_ra_array[4, :]) +end + tuple_byref(x) = (; a=(; b=x)) tuple_byref2(x) = abs2.(x), tuple_byref2(x)