Skip to content

Commit

Permalink
move code into module, remove unnecessary dependency, docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsch committed Nov 25, 2024
1 parent e279c9d commit 074a05e
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 121 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.13.2-dev"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Expand Down Expand Up @@ -55,7 +54,6 @@ KrylovKitExt = "KrylovKit"
[compat]
Arpack = "0.5"
Arrow = "1.5, 2"
BSON = "0.3"
Combinatorics = "1"
CommonSolve = "0.2.4"
ConsoleProgressMonitor = "0.1"
Expand Down
105 changes: 93 additions & 12 deletions src/RimuIO/RimuIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,27 @@ Provides convenience functions:
"""
module RimuIO

using Arrow: Arrow, ArrowTypes
using BSON: BSON, bson
using DataFrames: DataFrames, DataFrame, metadata!
using StaticArrays: StaticArrays, SVector

using Rimu: mpi_size, mpi_rank, mpi_barrier
using Rimu.BitStringAddresses: BitStringAddresses, BitString, BoseFS,
CompositeFS, FermiFS, SortedParticleList,
num_modes, num_particles
using Rimu.DictVectors: PDVec, DVec, target_segment
using Rimu.Interfaces: Interfaces, localpart, storage
using Rimu.StochasticStyles: default_style, IsDynamicSemistochastic

import Tables, Arrow, Arrow.ArrowTypes

export save_df, load_df

include("tables.jl")
include("arrowtypes.jl")

"""
RimuIO.save_df(filename, df::DataFrame; kwargs...)
save_df(filename, df::DataFrame; kwargs...)
Save dataframe in Arrow format.
Keyword arguments are passed on to
Expand All @@ -50,7 +52,8 @@ function save_df(
end

"""
RimuIO.load_df(filename; propagate_metadata = true, add_filename = true) -> DataFrame
load_df(filename; propagate_metadata = true, add_filename = true) -> DataFrame
Load Arrow file into `DataFrame`. Optionally propagate metadata to `DataFrame` and
add the file name as metadata.
Expand All @@ -69,15 +72,76 @@ function load_df(filename; propagate_metadata = true, add_filename = true)
return df
end

function save_state(filename, vector; kwargs...)
"""
save_state(filename, vector; io, kwargs...)
Save [`PDVec`](@ref) or [`DVec`](@ref) `vector` to an arrow file `filename`.
`io` is the output stream that progress messages are written to. It defaults to `stderr` when running on more than one MPI rank and to `devnull` otherwise.
All other `kwargs` are saved as strings to the arrow file and will be parsed back when the
state is loaded.
See also [`load_state`](@ref).
"""
function save_state(args...; kwargs...)
    # Choose the writer by MPI world size: the MPI variant when more than one rank is
    # running, the serial variant otherwise. All arguments are forwarded unchanged.
    writer = mpi_size() > 1 ? _save_state_mpi : _save_state_serial
    return writer(args...; kwargs...)
end

# Write `vector` to the Arrow file `filename` from a single process.
#
# - `io`: stream that receives the progress message and timing (default `devnull`).
# - `kwargs`: stored as Arrow metadata; both keys and values are converted to strings.
#
# The file is written exactly once, via `Tables.table(vector)`. An earlier, redundant
# `Arrow.write` call (with columns named `keys`/`values`) was dropped: it was untimed and
# its output was immediately overwritten by the second write.
function _save_state_serial(filename, vector; io=devnull, kwargs...)
    metadata = [string(k) => string(v) for (k, v) in kwargs]
    print(io, "saving vector...")
    elapsed = @elapsed Arrow.write(
        filename, Tables.table(vector); compress=:zstd, metadata
    )
    println(io, " done in $(round(elapsed, sigdigits=3)) s")
end

# Write `vector` to the Arrow file `filename` when running on several MPI ranks.
#
# - `io`: stream that receives progress and per-rank timings (default `stderr`).
# - `kwargs`: stored as Arrow metadata by rank 0; keys and values are converted to strings.
#
# Rank 0 creates the file; every other rank then appends its own data, one rank at a time,
# with a barrier before each rank's turn so the appends happen in rank order.
function _save_state_mpi(filename, vector; io=stderr, kwargs...)
    total_time = @elapsed begin
        # First rank creates the file and saves metadata.
        if mpi_rank() == 0
            println(io, "saving vector...")
            metadata = [string(k) => string(v) for (k, v) in kwargs]
            time = @elapsed begin
                # NOTE(review): `file=false` presumably selects the Arrow stream format so
                # that `Arrow.append` can extend the file -- confirm against Arrow.jl docs.
                Arrow.write(
                    filename, Tables.table(vector);
                    compress=:zstd, metadata, file=false
                )
            end
            println(io, " rank 0: $(round(time, sigdigits=3)) s")
        end
        # Other ranks append their data to the file in order.
        for rank in 1:(mpi_size() - 1)
            mpi_barrier()
            if mpi_rank() == rank
                time = @elapsed Arrow.append(filename, Tables.table(vector))
                println(io, " rank $rank: $(round(time, sigdigits=3)) s")
            end
        end
    end
    # The comparison operator was missing here (`if io devnull`), which does not parse.
    # The extra barrier is only taken when progress output is enabled.
    if io != devnull
        mpi_barrier()
    end
    if mpi_rank() == 0
        println(io, "done in $(round(total_time, sigdigits=3)) s")
    end
end

function load_state(filename; style=nothing, kwargs...)
"""
load_state(filename; kwargs...)
load_state(PDVec, filename; kwargs...)
load_state(DVec, filename; kwargs...)
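Load the state saved in the Arrow file `filename`. `kwargs` are passed to the constructor of
the requested vector type (`PDVec` or `DVec`). Any metadata stored in the file will be
parsed, evaluated and returned alongside the vector in a `NamedTuple`; metadata values that
cannot be parsed or evaluated are returned as strings.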
See also [`save_state`](@ref).
"""
function load_state(::Type{D}, filename; style=nothing, kwargs...) where {D}
tbl = Arrow.Table(filename)
K = eltype(tbl.key)
V = eltype(tbl.value)
Expand All @@ -88,19 +152,36 @@ function load_state(filename; style=nothing, kwargs...)
style = default_style(V)
end
end
vector = PDVec{K,V}(; style, kwargs...)
vector = D{K,V}(; style, kwargs...)
fill_vector!(vector, tbl.key, tbl.value)

arrow_meta = Arrow.metadata(tbl)[]
if !isnothing(arrow_meta)
metadata = NamedTuple(Symbol(k) => eval(Meta.parse(v)) for (k, v) in arrow_meta)
metadata = NamedTuple(Symbol(k) => try_parse_eval(v) for (k, v) in arrow_meta)
else
metadata = (;)
end

return vector, metadata
end

# Convenience method that picks the vector type from the number of Julia threads:
# `DVec` when running single-threaded, `PDVec` otherwise. Keyword arguments are forwarded.
function load_state(filename; kwargs...)
    D = Threads.nthreads() == 1 ? DVec : PDVec
    return load_state(D, filename; kwargs...)
end

# Best-effort conversion of a metadata string back into a Julia value.
#
# The string is parsed as Julia code and evaluated; if parsing or evaluation throws, the
# original string is returned unchanged.
#
# NOTE(review): `eval` runs arbitrary code taken from the file's metadata, so state files
# from untrusted sources can execute code on load. Consider restricting this to literals.
function try_parse_eval(str)
    try
        return eval(Meta.parse(str))
    catch
        # Deliberate fallback: anything that is not valid, evaluable Julia stays a string.
        return str
    end
end

# TODO: move me to (P)DVec?
function fill_vector!(vector::PDVec, keys, vals)
Threads.@threads for seg_id in eachindex(vector.segments)
seg = vector.segments[seg_id]
Expand Down
108 changes: 1 addition & 107 deletions src/RimuIO/save-state.jl
Original file line number Diff line number Diff line change
@@ -1,107 +1 @@
using Rimu, Arrow, Tables, MPI, KrylovKit

"""
    DVecTable{K,V}

Thin wrapper around a `Dict{K,V}` that iterates its entries as `(; key, value)` named
tuples, one per stored pair.
"""
struct DVecTable{K,V}
    dict::Dict{K,V}
end

# Delegate to the wrapped `Dict`'s own iteration. The state is forwarded opaquely (`st...`)
# rather than seeded with a hard-coded `0`, because the meaning of a `Dict` iteration
# state is an implementation detail of `Base`.
function Base.iterate(tbl::DVecTable, st...)
    itr = iterate(tbl.dict, st...)
    isnothing(itr) && return nothing
    pair, next = itr
    return (; key=pair[1], value=pair[2]), next
end

# Number of rows equals the number of stored pairs.
Base.length(tbl::DVecTable) = length(tbl.dict)

# Compact one-line summary, e.g. `3-row DVecTable{K,V}`.
function Base.show(io::IO, tbl::DVecTable{K,V}) where {K,V}
    print(io, length(tbl), "-row DVecTable{$K,$V}")
end

# Tables.jl interface for `DVecTable`.
# Wrap the `storage` field of a `DVec` in a `DVecTable`.
Tables.table(dvec::DVec) = DVecTable(dvec.storage)
# Declare `DVecTable` a table with row access.
Tables.istable(::Type{<:DVecTable}) = true
Tables.rowaccess(::Type{<:DVecTable}) = true
# Two columns, `key` and `value`, typed by the wrapped dict's type parameters.
Tables.schema(tbl::DVecTable{K,V}) where {K,V} = Tables.Schema((:key, :value), (K, V))
# The table object is returned as its own row iterator.
Tables.rows(tbl::DVecTable) = tbl

"""
    PDVecTable{K,V,N}

Wrapper around a tuple of `N` dictionaries (`segments`) that iterates all their entries,
segment by segment, as `(; key, value)` named tuples.
"""
struct PDVecTable{K,V,N}
    segments::NTuple{N,Dict{K,V}}
end

# The iteration state is `(st, i)`: `i` is the index of the current segment and `st` is the
# state of that segment's `Dict` iteration, or `nothing` when the segment has not been
# started yet. Starting each segment with `iterate(dict)` avoids seeding the `Dict` with a
# hard-coded state `0`, whose meaning is an implementation detail of `Base`.
function Base.iterate(tbl::PDVecTable, (st, i)=(nothing, 1))
    while i <= length(tbl.segments)
        segment = tbl.segments[i]
        itr = isnothing(st) ? iterate(segment) : iterate(segment, st)
        if !isnothing(itr)
            pair, st = itr
            return (; key=pair[1], value=pair[2]), (st, i)
        end
        # Current segment is exhausted (or empty): move on to a fresh next segment.
        st, i = nothing, i + 1
    end
    return nothing
end

# Total number of rows across all segments.
Base.length(tbl::PDVecTable) = sum(length, tbl.segments)

# Compact one-line summary, e.g. `3-row PDVecTable{K,V,N}`.
function Base.show(io::IO, tbl::PDVecTable{K,V,N}) where {K,V,N}
    print(io, length(tbl), "-row PDVecTable{$K,$V,$N}")
end

# Tables.jl interface for `PDVecTable`.
# Wrap the `segments` field of a `PDVec` in a `PDVecTable`.
Tables.table(pdvec::PDVec) = PDVecTable(pdvec.segments)
# Declare `PDVecTable` a table with row access.
Tables.istable(::Type{<:PDVecTable}) = true
Tables.rowaccess(::Type{<:PDVecTable}) = true
# The table object is returned as its own row iterator.
Tables.rows(tbl::PDVecTable) = tbl
# Two columns, `key` and `value`, typed by the segments' dict type parameters.
Tables.schema(tbl::PDVecTable{K,V}) where {K,V} = Tables.Schema((:key, :value), (K, V))
# One partition per segment, each wrapped as a `DVecTable`.
Tables.partitions(tbl::PDVecTable) = map(DVecTable, tbl.segments)

# Dispatch to the MPI or the serial writer depending on the size of `MPI.COMM_WORLD`.
# All positional and keyword arguments are forwarded unchanged.
function save_state(args...; kwargs...)
    nranks = MPI.Comm_size(MPI.COMM_WORLD)
    writer = nranks > 1 ? save_state_mpi : save_state_serial
    return writer(args...; kwargs...)
end

# Write `vector` to the Arrow file `filename` from a single process.
# `io` receives the progress message and timing; `kwargs` become string-valued metadata.
function save_state_serial(filename, vector; io=devnull, kwargs...)
    meta = [string(name) => string(val) for (name, val) in kwargs]
    print(io, "saving vector...")
    elapsed = @elapsed begin
        Arrow.write(filename, Tables.table(vector); compress=:zstd, metadata=meta)
    end
    seconds = round(elapsed, sigdigits=3)
    println(io, "done in $(seconds) s")
end

using MPI

# Write `vector` to the Arrow file `filename` when running on several MPI ranks.
#
# - `io`: stream that receives progress and per-rank timings (default `stderr`).
# - `kwargs`: stored as Arrow metadata by rank 0; keys and values are converted to strings.
#
# Rank 0 creates the file; every other rank then appends its own chunk, one rank at a
# time, with a barrier before each rank's turn so the appends happen in rank order.
function save_state_mpi(filename, vector; io=stderr, kwargs...)
    comm = MPI.COMM_WORLD

    total_time = @elapsed begin
        # First rank creates the file and saves metadata.
        if MPI.Comm_rank(comm) == 0
            println(io, "saving vector...")
            metadata = [string(k) => string(v) for (k, v) in kwargs]
            time = @elapsed begin
                # NOTE(review): `file=false` presumably selects the Arrow stream format so
                # that `Arrow.append` can extend the file -- confirm against Arrow.jl docs.
                Arrow.write(
                    filename, Tables.table(vector);
                    compress=:zstd, metadata, file=false
                )
            end
            println(io, " rank 0: $(round(time, sigdigits=3)) s")
        end
        # Other ranks save their chunks in order.
        for rank in 1:(MPI.Comm_size(comm) - 1)
            MPI.Barrier(comm)
            if MPI.Comm_rank(comm) == rank
                time = @elapsed Arrow.append(filename, Tables.table(vector))
                println(io, " rank $rank: $(round(time, sigdigits=3)) s")
            end
        end
    end
    # The comparison operator was missing here (`if io devnull`), which does not parse.
    # The extra barrier is only taken when progress output is enabled.
    if io != devnull
        MPI.Barrier(comm)
    end
    if MPI.Comm_rank(comm) == 0
        println(io, "done in $(round(total_time, sigdigits=3)) s")
    end
end
using Rimu, Arrow, Tables, KrylovKit

0 comments on commit 074a05e

Please sign in to comment.