Skip to content

Commit

Permalink
add saving/loading vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsch committed Nov 25, 2024
1 parent 1a780bc commit e279c9d
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 1 deletion.
52 changes: 51 additions & 1 deletion src/RimuIO/RimuIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ using StaticArrays: StaticArrays, SVector
using Rimu.BitStringAddresses: BitStringAddresses, BitString, BoseFS,
CompositeFS, FermiFS, SortedParticleList,
num_modes, num_particles
using Rimu.DictVectors: DictVectors
using Rimu.DictVectors: PDVec, DVec, target_segment
using Rimu.Interfaces: Interfaces, localpart, storage
using Rimu.StochasticStyles: default_style, IsDynamicSemistochastic


export save_df, load_df
Expand Down Expand Up @@ -68,4 +69,53 @@ function load_df(filename; propagate_metadata = true, add_filename = true)
return df
end

function save_state(filename, vector; kwargs...)
metadata = [string(k) => string(v) for (k, v) in kwargs]
Arrow.write(
filename, (; keys=collect(keys(vector)), values=collect(values(vector)));
compress=:zstd, metadata,
)
end

function load_state(filename; style=nothing, kwargs...)
tbl = Arrow.Table(filename)
K = eltype(tbl.key)
V = eltype(tbl.value)
if isnothing(style)
if V <: AbstractFloat
style = IsDynamicSemistochastic()
else
style = default_style(V)
end
end
vector = PDVec{K,V}(; style, kwargs...)
fill_vector!(vector, tbl.key, tbl.value)

arrow_meta = Arrow.metadata(tbl)[]
if !isnothing(arrow_meta)
metadata = NamedTuple(Symbol(k) => eval(Meta.parse(v)) for (k, v) in arrow_meta)
else
metadata = (;)
end

return vector, metadata
end

function fill_vector!(vector::PDVec, keys, vals)
Threads.@threads for seg_id in eachindex(vector.segments)
seg = vector.segments[seg_id]
sizehint!(seg, length(keys) ÷ length(vector.segments))
for (k, v) in zip(keys, vals)
if target_segment(vector, k) == (seg_id, true)
seg[k] = v
end
end
end
end
function fill_vector!(vector::DVec, keys, vals)
for (k, v) in zip(keys, vals)
vector[k] = v
end
end

end # module RimuIO
107 changes: 107 additions & 0 deletions src/RimuIO/save-state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using Rimu, Arrow, Tables, MPI, KrylovKit

struct DVecTable{K,V}
dict::Dict{K,V}
end
function Base.iterate(tbl::DVecTable, st=0)
itr = iterate(tbl.dict, st)
if !isnothing(itr)
pair, st = itr
return (; key=pair[1], value=pair[2]), st
else
return nothing
end
end

Base.length(tbl::DVecTable) = length(tbl.dict)

function Base.show(io::IO, tbl::DVecTable{K,V}) where {K,V}
print(io, length(tbl), "-row DVecTable{$K,$V}")
end

Tables.table(dvec::DVec) = DVecTable(dvec.storage)
Tables.istable(::Type{<:DVecTable}) = true
Tables.rowaccess(::Type{<:DVecTable}) = true
Tables.schema(tbl::DVecTable{K,V}) where {K,V} = Tables.Schema((:key, :value), (K, V))
Tables.rows(tbl::DVecTable) = tbl

struct PDVecTable{K,V,N}
segments::NTuple{N,Dict{K,V}}
end
function Base.iterate(tbl::PDVecTable, (st,i)=(0, 1))
if i > length(tbl.segments)
return nothing
end

itr = iterate(tbl.segments[i], st)
if !isnothing(itr)
pair, st = itr
return (; key=pair[1], value=pair[2]), (st, i)
else
return iterate(tbl, (0, i+1))
end
end

Base.length(tbl::PDVecTable) = sum(length, tbl.segments)

function Base.show(io::IO, tbl::PDVecTable{K,V,N}) where {K,V,N}
print(io, length(tbl), "-row PDVecTable{$K,$V,$N}")
end

Tables.table(pdvec::PDVec) = PDVecTable(pdvec.segments)
Tables.istable(::Type{<:PDVecTable}) = true
Tables.rowaccess(::Type{<:PDVecTable}) = true
Tables.rows(tbl::PDVecTable) = tbl
Tables.schema(tbl::PDVecTable{K,V}) where {K,V} = Tables.Schema((:key, :value), (K, V))
Tables.partitions(tbl::PDVecTable) = map(DVecTable, tbl.segments)

function save_state(args...; kwargs...)
comm = MPI.COMM_WORLD
if MPI.Comm_size(comm) > 1
save_state_mpi(args...; kwargs...)
else
save_state_serial(args...; kwargs...)
end
end

function save_state_serial(filename, vector; io=devnull, kwargs...)
metadata = [string(k) => string(v) for (k, v) in kwargs]
print(io, "saving vector...")
time = @elapsed Arrow.write(filename, Tables.table(vector); compress=:zstd, metadata)
println(io, "done in $(round(time, sigdigits=3)) s")
end

using MPI

function save_state_mpi(filename, vector; io=stderr, kwargs...)
comm = MPI.COMM_WORLD

# First rank creates the file and saves metadata.
total_time = @elapsed begin
if MPI.Comm_rank(comm) == 0
println(io, "saving vector...")
metadata = [string(k) => string(v) for (k, v) in kwargs]
time = @elapsed begin
Arrow.write(
filename, Tables.table(vector);
compress=:zstd, metadata, file=false
)
end
println(io, " rank 0: $(round(time, sigdigits=3)) s")
end
# Other ranks save their chunks in order.
for rank in 1:(MPI.Comm_size(comm) - 1)
MPI.Barrier(comm)
if MPI.Comm_rank(comm) == rank
time = @elapsed Arrow.append(filename, Tables.table(vector))
println(io, " rank $rank: $(round(time, sigdigits=3)) s")
end
end
end
if io devnull
MPI.Barrier(comm)
end
if MPI.Comm_rank(comm) == 0
println(io, "done in $(round(total_time, sigdigits=3)) s")
end
end

0 comments on commit e279c9d

Please sign in to comment.