diff --git a/src/RimuIO/RimuIO.jl b/src/RimuIO/RimuIO.jl index 05d15bad9..6f292e267 100644 --- a/src/RimuIO/RimuIO.jl +++ b/src/RimuIO/RimuIO.jl @@ -18,7 +18,7 @@ using Rimu.DictVectors: PDVec, DVec, target_segment using Rimu.Interfaces: Interfaces, localpart, storage using Rimu.StochasticStyles: default_style, IsDynamicSemistochastic -import Tables, Arrow, Arrow.ArrowTypes +import Rimu, Tables, Arrow, Arrow.ArrowTypes export save_df, load_df, save_state, load_state @@ -48,6 +48,7 @@ function save_df( if metadata === nothing metadata = [key => string(val) for (key, val) in DataFrames.metadata(df)] end + push!(metadata, "RIMU_PACKAGE_VERSION" => string(Rimu.PACKAGE_VERSION)) Arrow.write(filename, df; compress, metadata, kwargs...) end @@ -86,10 +87,11 @@ state is loaded. See also [`load_state`](@ref). """ function save_state(args...; kwargs...) + new_kwargs = (; RIMU_PACKAGE_VERSION=Rimu.PACKAGE_VERSION, kwargs...) if mpi_size() > 1 - _save_state_mpi(args...; kwargs...) + _save_state_mpi(args...; new_kwargs...) else - _save_state_serial(args...; kwargs...) + _save_state_serial(args...; new_kwargs...) end end @@ -142,6 +144,9 @@ See also [`save_state`](@ref). """ function load_state(::Type{D}, filename; style=nothing, kwargs...) where {D} tbl = Arrow.Table(filename) + if Tables.schema(tbl).names ≠ (:key, :value) + throw(ArgumentError("`$filename` is not a valid Rimu state file")) + end K = eltype(tbl.key) V = eltype(tbl.value) if isnothing(style) @@ -157,12 +162,15 @@ function load_state(::Type{D}, filename; style=nothing, kwargs...) where {D} arrow_meta = Arrow.metadata(tbl)[] if !isnothing(arrow_meta) metadata_pairs = map(collect(arrow_meta)) do (k, v) + k == "RIMU_PACKAGE_VERSION" && return Symbol(k) => VersionNumber(v) v_int = tryparse(Int, v) !isnothing(v_int) && return Symbol(k) => v_int v_float = tryparse(Float64, v) !isnothing(v_float) && return Symbol(k) => v_float v_cmp = tryparse(ComplexF64, v) !isnothing(v_cmp) && return Symbol(k) => v_cmp + v_bool = tryparse(Bool, v) + !isnothing(v_bool) && return Symbol(k) => v_bool Symbol(k) => v end metadata = NamedTuple(metadata_pairs) diff --git a/test/RimuIO.jl b/test/RimuIO.jl index 20fed50a2..913cdda51 100644 --- a/test/RimuIO.jl +++ b/test/RimuIO.jl @@ -4,7 +4,7 @@ using Arrow using Rimu: RimuIO using DataFrames -tmpdir = mktempdir() +const tmpdir = mktempdir() @testset "save_df, load_df" begin file = joinpath(tmpdir, "tmp.arrow") @@ -83,6 +83,18 @@ end @testset "vectors" begin ham = HubbardReal1D(BoseFS(1,1,1)) + @testset "errors" begin + df1 = DataFrame(key=[1,2,3], value=[1,2,3], error=[0,0,0]) + save_df(file, df1) + @test_throws ArgumentError load_state(file) + + df2 = DataFrame(error=[0,0,0]) + save_df(file, df2) + @test_throws ArgumentError load_state(file) + + rm(file) + end + @testset "save DVec" begin dvec = ham * DVec([BoseFS(1,1,1) => 1.0, BoseFS(2,1,0) => π]) save_state(file, dvec) @@ -117,13 +129,18 @@ end @testset "metadata" begin dvec = DVec(BoseFS(1,1,1,1) => 1.0) - save_state(file, dvec; int=1, float=2.3, complex=1.2 + 3im, string="a string") + save_state( + file, dvec; + int=1, float=2.3, complex=1.2 + 3im, string="a string", bool=true + ) _, meta = load_state(file) @test meta.int === 1 @test meta.float === 2.3 @test meta.complex === 1.2 + 3im + @test meta.bool === true @test meta.string === "a string" + @test meta.RIMU_PACKAGE_VERSION == Rimu.PACKAGE_VERSION rm(file) end