Merge pull request #133 from EnzymeAD/ap/view
feat: robust handling of wrapped arrays of reactant arrays
avik-pal authored Sep 30, 2024
2 parents e2ca620 + e8fe0c8 commit 831fbdc
Showing 10 changed files with 213 additions and 51 deletions.
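
In practice, the change lets functions that build wrapper types (adjoints, transposes, views, reshapes) on top of Reactant arrays be traced and compiled. A minimal sketch of the usage this enables, mirroring the test added to test/basic.jl below; Reactant.to_rarray and @compile are the existing entry points, and the final ≈ check reflects the expectation asserted in that test:

using Reactant

sum_xxᵀ(x) = sum(x .* x')            # x' is an adjoint wrapper around the traced array

x = rand(4, 4)
x_ra = Reactant.to_rarray(x)         # move the host array onto the Reactant runtime

sum_xxᵀ_compiled = @compile sum_xxᵀ(x_ra)
sum_xxᵀ_compiled(x_ra) ≈ sum_xxᵀ(x)  # should hold, as in the new test
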
3 changes: 1 addition & 2 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu
version = "0.2.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -14,13 +15,11 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
ReactantAdaptExt = "Adapt"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"
8 changes: 0 additions & 8 deletions ext/ReactantAdaptExt.jl

This file was deleted.

36 changes: 20 additions & 16 deletions ext/ReactantNNlibExt.jl
@@ -1,15 +1,15 @@
module ReactantNNlibExt

using NNlib
using Reactant
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array

for (jlop, hloop) in (
(:(NNlib.tanh_fast), :tanh),
(:(NNlib.sigmoid_fast), :logistic),
(:(NNlib.sigmoid), :logistic),
)
@eval function $(jlop)(x::Reactant.TracedRArray{T,0}) where {T}
return Reactant.TracedRArray{T,0}(
@eval function $(jlop)(x::TracedRArray{T,0}) where {T}
return TracedRArray{T,0}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
@@ -19,18 +19,16 @@ for (jlop, hloop) in (
end
end

NNlib.relu(x::Reactant.TracedRArray{T,0}) where {T} = max(x, zero(T))
NNlib.relu(x::TracedRArray{T,0}) where {T} = max(x, zero(T))

function NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T}
function NNlib.gelu(x::TracedRArray{T,0}) where {T}
α = T(0.044715)
λλ = T(√(8 / π))
return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
end

# TODO handle non finite cases
function NNlib.softmax!(
out::Reactant.TracedRArray{T,N}, x::AbstractArray; dims=1
) where {T,N}
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
max_ = NNlib.fast_maximum(x; dims)
#if all(isfinite, max_)
@fastmath out .= exp.(x .- max_)
@@ -43,8 +41,11 @@ end
end

function NNlib.conv(
x::Reactant.TracedRArray{T,N}, W::Reactant.TracedRArray{T}, cdims::DenseConvDims
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
) where {T,N}
x = materialize_traced_array(x)
W = materialize_traced_array(W)

kernel_size = NNlib.kernel_size(cdims)
padding = NNlib.padding(cdims)
stride = NNlib.stride(cdims)
@@ -119,10 +120,12 @@ function NNlib.conv(
batch_group_count=1,
)

return Reactant.TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
end

function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N}
function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
x = materialize_traced_array(x)

num_spatial_dims = N - 2
input_spatial_dims = 1:num_spatial_dims

@@ -185,21 +188,22 @@ function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N
body,
)

return Reactant.TracedRArray{T,N}(
(), Reactant.MLIR.IR.result(reduction), size(result_type)
)
return TracedRArray{T,N}((), Reactant.MLIR.IR.result(reduction), size(result_type))
end

function NNlib.maxpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
function NNlib.maxpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
return reduce_window(
Reactant.MLIR.Dialects.stablehlo.maximum, x, pdims; init=typemin(T)
)
end

function NNlib.meanpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
function NNlib.meanpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
numel = prod(NNlib.kernel_size(pdims))
return reduce_window(Reactant.MLIR.Dialects.stablehlo.add, x, pdims; init=zero(T)) ./
T(numel)
end

NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)

end # module ReactantNNlibExt
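
A hedged sketch of what the new AnyTracedRArray signatures and the batched_transpose / batched_adjoint definitions above allow once compiled; the function name flip_batches and the array shape are illustrative only, not part of this PR:

using NNlib, Reactant

# For real element types, batched_adjoint falls back to batched_transpose,
# i.e. permutedims(x, (2, 1, 3)) traced as a single transpose op.
flip_batches(x) = NNlib.batched_adjoint(x)

x = rand(Float32, 3, 5, 2)
x_ra = Reactant.to_rarray(x)

flip_batches_compiled = @compile flip_batches(x_ra)
# flip_batches_compiled(x_ra) is expected to match permutedims(x, (2, 1, 3))
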
8 changes: 5 additions & 3 deletions ext/ReactantStatisticsExt.jl
@@ -1,16 +1,18 @@
module ReactantStatisticsExt

using Reactant: TracedRArray
using Reactant: AnyTracedRArray, materialize_traced_array
using Statistics: Statistics

function Statistics.mean(A::TracedRArray{T,N}; dims=:) where {T,N}
function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N}
A = materialize_traced_array(A)
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
return mapreduce(identity, +, A; dims) / denom
end

function Statistics.var(
A::TracedRArray{T,N}; dims=:, mean=nothing, corrected=true
A::AnyTracedRArray{T,N}; dims=:, mean=nothing, corrected=true
) where {T,N}
A = materialize_traced_array(A)
mean === nothing && (mean = Statistics.mean(A; dims))
denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
return mapreduce(abs2, +, A .- mean; dims) / denom
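
Since mean and var now accept AnyTracedRArray and materialize the wrapper before reducing, statistics over transposed or viewed traced arrays can be compiled directly. A small illustrative sketch; mean_of_transpose and the shown shape are hypothetical, not taken from the test suite:

using Statistics, Reactant

mean_of_transpose(x) = Statistics.mean(x')   # x' becomes a wrapped traced array under tracing

x = rand(Float32, 4, 6)
x_ra = Reactant.to_rarray(x)

mean_of_transpose_compiled = @compile mean_of_transpose(x_ra)
# mean_of_transpose_compiled(x_ra) should be ≈ mean(x), since transposition does not change the mean
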
2 changes: 2 additions & 0 deletions src/ConcreteRArray.jl
@@ -10,6 +10,8 @@ end

ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray{T,0}(data, ())

Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x)

function ConcreteRArray(
data::Array{T,N}; client=XLA.default_backend[], idx=XLA.default_device_idx[]
) where {T,N}
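
With Adapt promoted to a hard dependency and the new adapt_storage method converting any AbstractArray into a ConcreteRArray, host data can be moved onto the Reactant runtime through a plain Adapt.adapt call, including arrays nested in structures Adapt knows how to walk. A rough sketch under that assumption; the concrete values are illustrative:

using Adapt, Reactant

x = Float32[1 2 3; 4 5 6]

# Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x),
# so adapting against the ConcreteRArray type converts every array leaf it reaches.
x_ra = Adapt.adapt(Reactant.ConcreteRArray, x)
# x_ra is expected to be a Reactant.ConcreteRArray{Float32,2} equal to ConcreteRArray(x)
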
2 changes: 2 additions & 0 deletions src/Reactant.jl
@@ -1,5 +1,7 @@
module Reactant

using Adapt: Adapt, WrappedArray

# auxiliary types and functions
include("OrderedIdDict.jl")

77 changes: 55 additions & 22 deletions src/TracedRArray.jl
@@ -16,11 +16,32 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
end
end

function Base.getindex(a::TracedRArray{T,0}) where {T}
return a
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
const AnyTracedRScalar{T} = AnyTracedRArray{T,0}
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}

materialize_traced_array(x::TracedRArray) = x
materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]

get_mlir_data(x::TracedRArray) = x.mlir_data
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x))

ancestor(x::TracedRArray) = x
ancestor(x::WrappedTracedRArray) = ancestor(parent(x))

get_ancestor_indices(::TracedRArray, indices...) = indices
function get_ancestor_indices(
x::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices...
) where {T,N}
return get_ancestor_indices(parent(x), Base.reindex(x.indices, indices)...)
end

function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Integer,N}) where {T,N}
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a

function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
@warn(
"""Performing scalar indexing on task $(current_task()).
Invocation resulted in scalar indexing of a TracedRArray.
@@ -47,9 +68,7 @@ and require expensive copies and synchronization each time and therefore should
return TracedRArray{T,0}((), res2, ())
end

function Base.getindex(
a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
) where {T,N}
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.slice(
@@ -62,14 +81,19 @@ function Base.getindex(
),
1,
)
return TracedRArray{T,N}((), res, Tuple(length.(indices)))
x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
ddims = findall(x -> x isa Integer, indices)
!isempty(ddims) && return dropdims(x; dims=Tuple(ddims))
return x
end

function Base.view(
a::TracedRArray{T,N}, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
) where {T,N}
# TODO: Implement before merging the PR
return error("view is not supported yet")
# Prevent ambiguity
function Base.getindex(a::WrappedTracedRArray, index::Int...)
return getindex(ancestor(a), get_ancestor_indices(a, index...)...)
end

function Base.getindex(a::WrappedTracedRArray, indices...)
return getindex(ancestor(a), get_ancestor_indices(a, indices...)...)
end

function Base.setindex!(
@@ -101,15 +125,15 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
# return print(io, X.mlir_data, ")")
end

Base.only(A::TracedRArray{T,0}) where {T} = A
Base.only(A::AnyTracedRScalar{T}) where {T} = A

function Base.reshape(A::TracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A)))

# HLO reshape semantics collapse the opposite way
res1 = MLIR.IR.result(
MLIR.Dialects.stablehlo.transpose(
A.mlir_data;
get_mlir_data(A);
permutation=MLIR.IR.DenseArrayAttribute([Int64(N - 1 - i) for i in 0:(N - 1)]),
),
1,
@@ -137,12 +161,12 @@ function Base.reshape(A::TracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
return TracedRArray{T,NT}((), res3, dims)
end

function Base.permutedims(A::TracedRArray{T,N}, perm) where {T,N}
function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.transpose(
A.mlir_data;
get_mlir_data(A);
permutation=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in perm]),
),
1,
@@ -151,13 +175,19 @@ function Base.permutedims(A::TracedRArray{T,N}, perm) where {T,N}
)
end

function Base.transpose(A::AnyTracedRVecOrMat)
A = ndims(A) == 1 ? reshape(A, :, 1) : A
return permutedims(A, (2, 1))
end
Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)

function Base.promote_rule(
::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}}
) where {T,S,N}
return TracedRArray{Base.promote_type(T, S),N}
end

function Base.promote_rule(A::Type{T}, B::Type{TracedRArray{S,N}}) where {T,S,N}
function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
return TracedRArray{Base.promote_type(T, S),N}
end

@@ -194,7 +224,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
)
end

function promote_to(lhs::TracedRArray{T,N}, rhs) where {T,N}
function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
return promote_to(TracedRArray{T,N}, rhs)
end

@@ -668,6 +698,7 @@ function Base.mapreducedim!(
end

struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end

AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle{N}()

@@ -678,7 +709,9 @@ AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle
# copy(inst)
# end

BroadcastStyle(::Type{T}) where {T<:TracedRArray} = AbstractReactantArrayStyle{ndims(T)}()
function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N}
return AbstractReactantArrayStyle{N}()
end

function Base.similar(
bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
@@ -746,8 +779,8 @@ function broadcast_to_size(arg::AbstractArray, rsize)
return arg
end

function broadcast_to_size(arg::TracedRArray, rsize)
return arg
function broadcast_to_size(arg::AnyTracedRArray, rsize)
return materialize_traced_array(arg)
end

function broadcast_to_size(arg::Base.RefValue, rsize)
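
The hard error in Base.view is gone; SubArrays and other wrappers are now resolved through ancestor and get_ancestor_indices and materialized as a stablehlo slice when needed. A hypothetical sketch of what that enables; corner_sum and the 4x4 input are made up here, and the new test/wrapped_arrays.jl (not shown on this page) is the authoritative coverage:

using Reactant

# A 2x2 view keeps the parent's dimensionality, so get_ancestor_indices can
# reindex it into the parent TracedRArray when broadcasting materializes it.
corner_sum(x) = sum(view(x, 1:2, 1:2) .* view(x, 1:2, 1:2))

x = rand(Float32, 4, 4)
x_ra = Reactant.to_rarray(x)

corner_sum_compiled = @compile corner_sum(x_ra)
# corner_sum_compiled(x_ra) should be ≈ sum(abs2, x[1:2, 1:2])
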
12 changes: 12 additions & 0 deletions test/basic.jl
@@ -261,3 +261,15 @@ tuple_byref2(x) = abs2.(x), tuple_byref2(x)
# @test r2[2].a.b.data === x.data
# @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0])
end

sum_xxᵀ(x) = sum(x .* x')

@testset "sum(x .* x')" begin
@testset "size(x): $(size(x))" for x in (rand(4, 4), rand(4))
x_ca = Reactant.to_rarray(x)

sum_xxᵀ_compiled = @compile sum_xxᵀ(x_ca)

@test sum_xxᵀ_compiled(x_ca) ≈ sum_xxᵀ(x)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -48,6 +48,7 @@ end
@safetestset "Closure" include("closure.jl")
@safetestset "Compile" include("compile.jl")
@safetestset "Buffer Donation" include("buffer_donation.jl")
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")

@testset "Neural Networks" begin
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
test/wrapped_arrays.jl (new file; its diff did not load on this page)
