Skip to content

Commit

Permalink
Add fallback gpu setindex (#137)
Browse files Browse the repository at this point in the history
* Add fallback gpu setindex

* Update ConcreteRArray.jl

* Update ConcreteRArray.jl

* Update ConcreteRArray.jl

* Update ConcreteRArray.jl

* Update ConcreteRArray.jl

* Update ConcreteRArray.jl

* Update ConcreteRArray.jl
  • Loading branch information
wsmoses authored Oct 3, 2024
1 parent f3c65db commit 9698fa8
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ function Base.show(io::IO, X::ConcreteRArray)
return Base.show(io, convert(Array, X))
end

# One-shot flag: flipped to `true` after the scalar get-indexing warning has been
# emitted once, so later slow-path `getindex` calls stay quiet.
const getindex_warned = Ref{Bool}(false)
function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
if a.data == XLA.AsyncEmptyBuffer
throw("Cannot getindex from empty buffer")
Expand All @@ -143,9 +144,26 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
return unsafe_load(ptr, start)
end
end
if !getindex_warned[]
@warn(
"""Performing scalar get-indexing on task $(current_task()).
Invocation resulted in scalar indexing of a ConcreteRArray.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on device, but very slowly on the CPU,
and require expensive copies and synchronization each time and therefore should be avoided."""
)
getindex_warned[] = true
end
return convert(Array, a)[args...]
end

"""
    mysetindex!(a, v, args::Int...)

Store `v` into `a` at the integer indices `args` and return `nothing`.

This is the function that the scalar `setindex!` fallback for `ConcreteRArray`
hands to `Reactant.compile`. It deliberately discards the value of the
assignment and returns `nothing` (presumably so the compiled function produces
no outputs; confirm against `Reactant.compile` semantics).
"""
function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
    # Indexed assignment lowers to `Base.setindex!(a, v, args...)`.
    a[args...] = v
    return nothing
end

# One-shot flag: flipped to `true` after the scalar set-indexing warning has been
# emitted once, so later slow-path `setindex!` calls stay quiet.
const setindex_warned = Ref{Bool}(false)

function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N}
if a.data == XLA.AsyncEmptyBuffer
throw("Cannot setindex! to empty buffer")
Expand All @@ -167,7 +185,22 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
end
return a
end
throw("Cannot setindex! to non-CPU buffer")
if !setindex_warned[]
@warn(
"""Performing scalar set-indexing on task $(current_task()).
Invocation resulted in scalar indexing of a ConcreteRArray.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on device, but very slowly on the CPU,
and require expensive copies and synchronization each time and therefore should be avoided.
This error message will only be printed for the first invocation for brevity.
"""
)
setindex_warned[] = true
end
fn = Reactant.compile(mysetindex!, (a, v, args...))
fn(a, v, args...)
return a
end

# TODO is there any way to allocate an uninitialized buffer in XLA?
Expand Down

1 comment on commit 9698fa8

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 9698fa8 Previous: f3c65db Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1570954734 ns 1458455508 ns 1.08
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 235920037 ns 251236706 ns 0.94
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5154736106 ns 5585368365 ns 0.92
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 19482335000 ns 18510942516 ns 1.05
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1454393866 ns 1416635330 ns 1.03
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 9344541.5 ns 8999966.5 ns 1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1654229125 ns 1626174510 ns 1.02
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2474213271 ns 2507000840 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1427339607 ns 1464452980 ns 0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 88867803.5 ns 87282999 ns 1.02
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2208391050 ns 2196952516 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 4720676124 ns 5380501243 ns 0.88
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1356859332 ns 1539579056 ns 0.88
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7831299 ns 7997757 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1464435737 ns 1478484824 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1568895073 ns 1648231085 ns 0.95
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1298556694.5 ns 1433497500 ns 0.91
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 86786255 ns 79774892 ns 1.09
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1773492797 ns 1780124071 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 2387188682 ns 2802545314 ns 0.85
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1337103435 ns 1366374529 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 105193558 ns 87273858 ns 1.21
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2258439114 ns 2290981311 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 4730673199 ns 4409558714 ns 1.07
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1379404181 ns 1481595475 ns 0.93
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 116791463.5 ns 117880246.5 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3050409788 ns 3059007138 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 17858116236 ns 8537922961 ns 2.09
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1299115963 ns 1475455846 ns 0.88
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 136028347 ns 137789966 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3226928458 ns 3144114552 ns 1.03
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 9095328291 ns 11348055235 ns 0.80
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1375900813 ns 1431172027 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 86587539 ns 84404630.5 ns 1.03
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1902300110 ns 1904321145.5 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 2520101863 ns 2628751204 ns 0.96

This comment was automatically generated by a workflow using github-action-benchmark.

Please sign in to comment.