Skip to content

Commit

Permalink
scimlneuralde with devicearray again
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilian-gelbrecht committed Jul 30, 2024
1 parent 5e7cc83 commit 1fb4208
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/gpu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using CUDA, LuxCUDA
using CUDA, Lux, LuxCUDA, LuxDeviceUtils

"""
DetermineDevice(; gpu::Union{Nothing, Bool}=nothing)
Expand All @@ -23,3 +23,7 @@ function DetermineDevice(x::AbstractArray)
error("Can't determine Device based on input array ")
end
end

DeviceArray(device::LuxCUDADevice{D}, x) where D = CuArray(x)
DeviceArray(device::LuxCPUDevice, x) = Array(x)
DeviceArray(device::LuxDeviceUtils.AbstractLuxDevice, x) = Array(x)
2 changes: 1 addition & 1 deletion src/neuralde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function (m::SciMLNeuralDE)(X, ps, st)

prob = ODEProblem{false}(ODEFunction{false}(rhs), x[..,1], (t[1],t[end]), ps)

m.device(solve(prob, m.alg; saveat=t, m.kwargs...)), st
DeviceArray(m.device, solve(prob, m.alg; saveat=t, m.kwargs...)), st
end

"""
Expand Down

0 comments on commit 1fb4208

Please sign in to comment.