Skip to content

Commit

Permalink
Add RIMU_VERSION_NUMBER, error checks and Bool parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsch committed Jan 21, 2025
1 parent 857dd93 commit 7bbb2c5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
14 changes: 11 additions & 3 deletions src/RimuIO/RimuIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using Rimu.DictVectors: PDVec, DVec, target_segment
using Rimu.Interfaces: Interfaces, localpart, storage
using Rimu.StochasticStyles: default_style, IsDynamicSemistochastic

import Tables, Arrow, Arrow.ArrowTypes
import Rimu, Tables, Arrow, Arrow.ArrowTypes

export save_df, load_df, save_state, load_state

Expand Down Expand Up @@ -48,6 +48,7 @@ function save_df(
if metadata === nothing
metadata = [key => string(val) for (key, val) in DataFrames.metadata(df)]
end
push!(metadata, "RIMU_PACKAGE_VERSION" => string(Rimu.PACKAGE_VERSION))
Arrow.write(filename, df; compress, metadata, kwargs...)
end

Expand Down Expand Up @@ -86,10 +87,11 @@ state is loaded.
See also [`load_state`](@ref).
"""
function save_state(args...; kwargs...)
new_kwargs = (; RIMU_PACKAGE_VERSION=Rimu.PACKAGE_VERSION, kwargs...)
if mpi_size() > 1
_save_state_mpi(args...; kwargs...)
_save_state_mpi(args...; new_kwargs...)
else
_save_state_serial(args...; kwargs...)
_save_state_serial(args...; new_kwargs...)
end
end

Expand Down Expand Up @@ -142,6 +144,9 @@ See also [`save_state`](@ref).
"""
function load_state(::Type{D}, filename; style=nothing, kwargs...) where {D}
tbl = Arrow.Table(filename)
if Tables.schema(tbl).names (:key, :value)
throw(ArgumentError("`$filename` is not a valid Rimu state file"))
end
K = eltype(tbl.key)
V = eltype(tbl.value)
if isnothing(style)
Expand All @@ -157,12 +162,15 @@ function load_state(::Type{D}, filename; style=nothing, kwargs...) where {D}
arrow_meta = Arrow.metadata(tbl)[]
if !isnothing(arrow_meta)
metadata_pairs = map(collect(arrow_meta)) do (k, v)
k == "RIMU_PACKAGE_VERSION" && return Symbol(k) => VersionNumber(v)
v_int = tryparse(Int, v)
!isnothing(v_int) && return Symbol(k) => v_int
v_float = tryparse(Float64, v)
!isnothing(v_float) && return Symbol(k) => v_float
v_cmp = tryparse(ComplexF64, v)
!isnothing(v_cmp) && return Symbol(k) => v_cmp
v_bool = tryparse(Bool, v)
!isnothing(v_bool) && return Symbol(k) => v_bool
Symbol(k) => v
end
metadata = NamedTuple(metadata_pairs)
Expand Down
21 changes: 19 additions & 2 deletions test/RimuIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Arrow
using Rimu: RimuIO
using DataFrames

tmpdir = mktempdir()
const tmpdir = mktempdir()

@testset "save_df, load_df" begin
file = joinpath(tmpdir, "tmp.arrow")
Expand Down Expand Up @@ -83,6 +83,18 @@ end

@testset "vectors" begin
ham = HubbardReal1D(BoseFS(1,1,1))
@testset "errors" begin
df1 = DataFrame(key=[1,2,3], value=[1,2,3], error=[0,0,0])
save_df(file, df1)
@test_throws ArgumentError load_state(file)

df2 = DataFrame(error=[0,0,0])
save_df(file, df2)
@test_throws ArgumentError load_state(file)

rm(file)
end

@testset "save DVec" begin
dvec = ham * DVec([BoseFS(1,1,1) => 1.0, BoseFS(2,1,0) => π])
save_state(file, dvec)
Expand Down Expand Up @@ -117,13 +129,18 @@ end

@testset "metadata" begin
dvec = DVec(BoseFS(1,1,1,1) => 1.0)
save_state(file, dvec; int=1, float=2.3, complex=1.2 + 3im, string="a string")
save_state(
file, dvec;
int=1, float=2.3, complex=1.2 + 3im, string="a string", bool=true
)
_, meta = load_state(file)

@test meta.int === 1
@test meta.float === 2.3
@test meta.complex === 1.2 + 3im
@test meta.bool === true
@test meta.string === "a string"
@test meta.RIMU_PACKAGE_VERSION == Rimu.PACKAGE_VERSION
rm(file)
end

Expand Down

0 comments on commit 7bbb2c5

Please sign in to comment.