Skip to content

Commit

Permalink
update solver docs strings
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilian-gelbrecht committed Aug 28, 2024
1 parent 55789ac commit 94ee943
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/NeuralDELux.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module NeuralDELux

using DocStringExtensions

include("gpu.jl")
include("neuralde.jl")
include("solver.jl")
include("training.jl")
include("utils.jl")
include("loss.jl")

export ADNeuralDE, SciMLNeuralDE, ADEulerStep, ADRK4Step
export ADNeuralDE, SciMLNeuralDE, ADEulerStep, ADRK4Step, MultiStepRK4

end
5 changes: 5 additions & 0 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ function solve(prob::AbstractDEProblem, solver::SciMLRK4Step; kwargs...)
@muladd cat(u, u + (dt/6) .* (k₁ + 2 .* (k₂ + k₃) + f(u + dt .* k₃, p, t + dt)), dims=ndims(u)+1)
end

"""
($TYPEDSIGNATURES)
Zygote and GPU compatible fixed step size RK4 solver. Needs to be called similar to the solvers of `OrdinaryDiffEq.jl`. `dt` is a mandatory kwarg, setting the step size.
"""
struct MultiStepRK4
end

Expand Down

0 comments on commit 94ee943

Please sign in to comment.