Skip to content

Commit

Permalink
Added new momentum-sampling strategy
Browse files Browse the repository at this point in the history
Now supports sampling momentum with quasi-monte-carlo sequence (low-discrepancy sequence).
  • Loading branch information
B1inkFox committed Nov 3, 2024
1 parent 5d56902 commit e88634a
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 3 deletions.
18 changes: 18 additions & 0 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,21 @@ refresh(
z.θ,
ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ),
)



include("quasi_MC.jl")

"Quasi-random momentum refreshment."
struct QuasiRandomMomentumRefreshment <: AbstractMomentumRefreshment
quasi_seed::Quasi_MC_seed
end

function refresh(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
ref::QuasiRandomMomentumRefreshment,
h::Hamiltonian,
z::PhasePoint,
)
return phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic, z.θ, ref.quasi_seed))
end
42 changes: 39 additions & 3 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ $(TYPEDEF)
Abstract type for preconditioning metrics.
"""

include("quasi_MC.jl")

abstract type AbstractMetric end

_string_M⁻¹(mat::AbstractMatrix, n_chars::Int = 32) = _string_M⁻¹(diag(mat), n_chars)
Expand Down Expand Up @@ -104,17 +107,27 @@ function _rand(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::UnitEuclideanMetric{T},
kinetic::GaussianKinetic,
quasi_seed::Union{Nothing, Quasi_MC_seed} = nothing
) where {T}
r = randn(rng, T, size(metric)...)
if isnothing(quasi_seed)
r = randn(rng, T, size(metric)...)
else
r = get_next_vector(quasi_seed)
end
return r
end

function _rand(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::DiagEuclideanMetric{T},
kinetic::GaussianKinetic,
quasi_seed::Union{Nothing, Quasi_MC_seed} = nothing
) where {T}
r = randn(rng, T, size(metric)...)
if isnothing(quasi_seed)
r = randn(rng, T, size(metric)...)
else
r = get_next_vector(quasi_seed)
end
r ./= metric.sqrtM⁻¹
return r
end
Expand All @@ -123,8 +136,13 @@ function _rand(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::DenseEuclideanMetric{T},
kinetic::GaussianKinetic,
quasi_seed::Union{Nothing, Quasi_MC_seed} = nothing
) where {T}
r = randn(rng, T, size(metric)...)
if isnothing(quasi_seed)
r = randn(rng, T, size(metric)...)
else
r = get_next_vector(quasi_seed)
end
ldiv!(metric.cholM⁻¹, r)
return r
end
Expand Down Expand Up @@ -153,5 +171,23 @@ Base.rand(
kinetic::AbstractKinetic,
θ::AbstractVecOrMat,
) = rand(rng, metric, kinetic)


Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) =
rand(metric, kinetic)


Base.rand(
rng::AbstractRNG,
metric::AbstractMetric,
kinetic::AbstractKinetic,
θ::AbstractVecOrMat,
quasi_mc::Quasi_MC_seed
) = _rand(rng, metric, kinetic, quasi_mc)
Base.rand(
rng::AbstractVector{<:AbstractRNG},
metric::AbstractMetric,
kinetic::AbstractKinetic,
θ::AbstractVecOrMat,
quasi_mc::Quasi_MC_seed
) = _rand(rng, metric, kinetic, quasi_mc)
26 changes: 26 additions & 0 deletions src/quasi_MC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using QuasiMonteCarlo, Distributions

mutable struct Quasi_MC_seed
array::AbstractVecOrMat{<:Real}
counter::Int

# Custom constructor with a default value for `counter`
function Quasi_MC_seed(array::AbstractVecOrMat{<:Real}, counter::Int = 1)
new(array, counter)
end
end

function get_next_vector(q::Quasi_MC_seed; normalized = true)
val = q.array[:, q.counter]
q.counter += 1
if normalized
return uniform_to_normal(val)
else
return val
end
end

function uniform_to_normal(uniform_vector::Vector{Float64})::Vector{Float64}
normal_dist = Normal(0, 1)
return [quantile(normal_dist, u) for u in uniform_vector]
end

0 comments on commit e88634a

Please sign in to comment.