diff --git a/lib/codecs/array_to_bytes.ml b/lib/codecs/array_to_bytes.ml index b0aeb216..98c33d70 100644 --- a/lib/codecs/array_to_bytes.ml +++ b/lib/codecs/array_to_bytes.ml @@ -2,6 +2,7 @@ open Array_to_array open Bytes_to_bytes open Util.Result_syntax open Codecs_intf +open Util module Ndarray = Owl.Dense.Ndarray.Generic @@ -43,7 +44,7 @@ module BytesCodec = struct string -> (a, b) Util.array_repr -> endianness -> - ((a, b) Ndarray.t, [> error]) result + ((a, b) Ndarray.t, [> `Store_read of string | error]) result = fun buf decoded t -> let open Bigarray in let open (val endian_module t) in @@ -99,7 +100,7 @@ module rec ArrayToBytes : sig array_tobytes -> ('a, 'b) Util.array_repr -> string -> - (('a, 'b) Ndarray.t, [> error]) result + (('a, 'b) Ndarray.t, [> `Store_read of string | error]) result val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result val to_yojson : array_tobytes -> Yojson.Safe.t end = struct @@ -126,11 +127,10 @@ end = struct | `ShardingIndexed c -> ShardingIndexedCodec.encode c x let decode : - type a b. array_tobytes -> - (a, b) Util.array_repr -> + ('a, 'b) Util.array_repr -> string -> - ((a, b) Ndarray.t, [> error]) result + (('a, 'b) Ndarray.t, [> error]) result = fun t repr b -> match t with | `Bytes endian -> BytesCodec.decode b repr endian @@ -142,8 +142,7 @@ end = struct let of_yojson x = match Util.get_name x with - | "bytes" -> - BytesCodec.of_yojson x >>| fun e -> `Bytes e + | "bytes" -> BytesCodec.of_yojson x >>| fun e -> `Bytes e | "sharding_indexed" -> ShardingIndexedCodec.of_yojson x >>| fun c -> `ShardingIndexed c | _ -> Error ("array->bytes codec not supported: ") @@ -151,16 +150,31 @@ end and ShardingIndexedCodec : sig type t = internal_shard_config - val parse : - t -> ('a, 'b) Util.array_repr -> (unit, [> error]) result + val parse : t -> ('a, 'b) Util.array_repr -> (unit, [> error]) result val compute_encoded_size : int -> t -> int - val encode : - t -> ('a, 'b) Ndarray.t -> (string, [> error]) result + val encode : t -> ('a, 'b) Ndarray.t -> (string, [> error]) result + val partial_encode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ] as 'c) result) -> + partial_setter -> + int -> + ('a, 'b) Util.array_repr -> + (int array * 'a) list -> + (unit, 'c) result + val partial_decode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ] as 'c) result) -> + int -> + ('a, 'b) Util.array_repr -> + (int * int array) list -> + ((int * 'a) list, 'c) result val decode : t -> ('a, 'b) Util.array_repr -> string -> - (('a, 'b) Ndarray.t, [> error]) result + (('a, 'b) Ndarray.t, [> `Store_read of string | error]) result val of_yojson : Yojson.Safe.t -> (t, string) result val to_yojson : t -> Yojson.Safe.t end = struct @@ -179,29 +193,27 @@ end = struct internal_shard_config -> ('a, 'b) Util.array_repr -> (unit, [> error]) result - = fun t repr -> - (match Array.(length repr.shape = length t.chunk_shape) with + = fun t r -> + (match Array.(length r.shape = length t.chunk_shape) with | true -> Ok () | false -> let msg = "sharding chunk_shape length must equal the dimensionality of the decoded representaton of a shard." in - Result.error @@ `Sharding (t.chunk_shape, repr.shape, msg)) + Result.error @@ `Sharding (t.chunk_shape, r.shape, msg)) >>= fun () -> (match - Array.for_all2 (fun x y -> (x mod y) = 0) repr.shape t.chunk_shape + Array.for_all2 (fun x y -> (x mod y) = 0) r.shape t.chunk_shape with | true -> Ok () | false -> let msg = "sharding chunk_shape must evenly divide the size of the shard shape." in - Result.error @@ `Sharding (t.chunk_shape, repr.shape, msg)) + Result.error @@ `Sharding (t.chunk_shape, r.shape, msg)) >>= fun () -> - parse_chain repr t.codecs >>= fun () -> - (* must add one dimension to the representation of the index array. *) - parse_chain - {repr with shape = Array.append repr.shape [|2|]} t.index_codecs + parse_chain r t.codecs >>= fun () -> + parse_chain {r with shape = Array.append r.shape [|2|]} t.index_codecs let compute_encoded_size input_size t = List.fold_left BytesToBytes.compute_encoded_size @@ -230,7 +242,6 @@ end = struct ('a, 'b) Ndarray.t -> (string, [> error]) result = fun t x -> - let open Util in let open Extensions in let shard_shape = Ndarray.shape x in let cps = Array.map2 (/) shard_shape t.chunk_shape in @@ -244,49 +255,39 @@ end = struct let k, c = RegularGrid.index_coord_pair grid coords.(i) in Arraytbl.add tbl k (c, y)) x; let kind = Ndarray.kind x in - let cindices = ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl in let buf = Buffer.create @@ Ndarray.size_in_bytes x in - let coord = idx_shp in - let len = Array.length shard_shape in - ArraySet.fold (fun idx acc -> - acc >>= fun offset -> - (* find_all returns bindings in reverse order. To restore the - * C-ordering of elements we must call List.rev. *) - let vals = - Array.of_list @@ - snd @@ - List.split @@ - List.rev @@ - Arraytbl.find_all tbl idx - in - let x' = Ndarray.of_array kind vals t.chunk_shape in - encode_chain t.codecs x' >>| fun b -> - Buffer.add_string buf b; - Array.blit idx 0 coord 0 len; - coord.(len) <- 0; - Ndarray.set shard_idx coord offset; - coord.(len) <- 1; - let nbytes = Int64.of_int @@ String.length b in - Ndarray.set shard_idx coord nbytes; - Int64.add offset nbytes) cindices (Ok 0L) + let icoords = Array.map (fun v -> [|v; v|]) idx_shp in + icoords.(Array.length shard_shape) <- [|0; 1|]; + ArraySet.fold + (fun idx acc -> + acc >>= fun offset -> + (* find_all returns bindings in reverse order. To restore the + * C-ordering of elements we must use List.rev. *) + let v = + Array.of_list @@ snd @@ List.split @@ List.rev @@ + Arraytbl.find_all tbl idx in + let x' = Ndarray.of_array kind v t.chunk_shape in + encode_chain t.codecs x' >>| fun b -> + Buffer.add_string buf b; + Array.iteri (fun i v -> icoords.(i).(0) <- v; icoords.(i).(1) <- v) idx; + let nb = Int64.of_int @@ String.length b in + Ndarray.set_index shard_idx icoords [|offset; nb|]; + Int64.add offset nb) + (ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl) (Ok 0L) >>= fun _ -> - (* convert t.index_codecs to a generic bytes-to-bytes chain. *) encode_chain (t.index_codecs :> bytestobytes internal_chain) shard_idx - >>| fun b' -> + >>| fun ib -> match t.index_location with - | Start -> b' ^ Buffer.contents buf - | End -> Buffer.(add_string buf b'; contents buf) + | Start -> ib ^ Buffer.contents buf + | End -> Buffer.(add_string buf ib; contents buf) let rec decode_chain : type a b. bytestobytes internal_chain -> - string -> (a, b) Util.array_repr -> + string -> ((a, b) Ndarray.t, [> error]) result - = fun t x repr -> - (* compute the last encoded representation of array->array codec chain. - This becomes the decoded representation of the array->bytes decode - procedure. *) + = fun t repr x -> List.fold_left (fun acc c -> acc >>= ArrayToArray.compute_encoded_representation c) @@ -300,13 +301,11 @@ end = struct t.a2a (ArrayToBytes.decode t.a2b repr' y) and decode_index : - string -> - int array -> internal_shard_config -> + int array -> + string -> ((int64, Bigarray.int64_elt) Ndarray.t * string, [> error]) result - = fun b shard_shape t -> - let open Util in - let cps = Array.map2 (/) shard_shape t.chunk_shape in + = fun t cps b -> let l = index_size t cps in let o = String.length b - l in let index_bytes, rest = @@ -316,58 +315,146 @@ end = struct in decode_chain (t.index_codecs :> bytestobytes internal_chain) - index_bytes {fill_value = Int64.max_int ;kind = Bigarray.Int64 ;shape = Array.append cps [|2|]} + index_bytes >>| fun decoded -> decoded, rest and index_size t cps = compute_encoded_size (16 * Util.prod cps) t + and partial_encode : + internal_shard_config -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ]) result) -> + partial_setter -> + int -> + ('a, 'b) Util.array_repr -> + (int array * 'a) list -> + (unit, [> `Store_read of string | error ]) result + = fun t get_partial set_partial bytesize repr pairs -> + let open Extensions in + let cps = Array.map2 (/) repr.shape t.chunk_shape in + let is = index_size t cps in + let ibytes, pad = + match t.index_location with + | Start -> get_partial [0, Some is], is + | End -> get_partial [bytesize - is, None], 0 in + ibytes >>| + List.hd >>= + decode_index t cps >>= fun (idx_arr, _) -> + let tbl = Arraytbl.create @@ List.length pairs in + RegularGrid.create ~array_shape:repr.shape t.chunk_shape >>= fun grid -> + List.iter + (fun (c, v) -> + let id, co = RegularGrid.index_coord_pair grid c in + Arraytbl.add tbl id (co, v)) pairs; + let inner = {repr with shape = t.chunk_shape} in + let icoords = Array.map (fun v -> [|v; v|]) @@ Ndarray.shape idx_arr in + icoords.(Array.length t.chunk_shape) <- [|0; 1|]; + ArraySet.fold + (fun idx acc -> + acc >>= fun (bsize, shard_idx) -> + let z = Arraytbl.find_all tbl idx in + let p = Bigarray.array1_of_genarray @@ Ndarray.slice_left shard_idx idx in + let offset = Int64.to_int p.{0} in + let nb = Int64.to_int p.{1} in + get_partial [pad + offset, Some nb] >>| + List.hd >>= + decode_chain t.codecs inner >>= fun arr -> + List.iter (fun (c, v) -> Ndarray.set arr c v) z; + encode_chain t.codecs arr >>| fun s -> + let nb' = String.length s in + (* if codec chain doesnt contain compressions so update chunk in-place *) + if nb' = nb then + (set_partial [pad + offset, s]; bsize, shard_idx) + else + (Array.iteri + (fun i v -> icoords.(i).(0) <- v; icoords.(i).(1) <- v) idx; + Ndarray.set_index + shard_idx icoords Int64.[|of_int bsize; of_int nb'|]; + set_partial ~append:true [bsize, s]; + bsize + nb', shard_idx)) + (ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl) (Ok (bytesize - pad, idx_arr)) + >>= fun (bytesize, shard_idx) -> + encode_chain (t.index_codecs :> bytestobytes internal_chain) shard_idx + >>| fun ib -> + match t.index_location with + | Start -> set_partial [0, ib] + | End -> set_partial ~append:true [bytesize, ib] + and decode : t -> ('a, 'b) Util.array_repr -> string -> - (('a, 'b) Ndarray.t, [> error]) result + (('a, 'b) Ndarray.t, [> `Store_read of string | error]) result = fun t repr b -> - let open Util in let open Extensions in - let open Util.Result_syntax in - decode_index b repr.shape t >>= fun (shard_idx, b') -> - if Ndarray.equal_scalar shard_idx Int64.max_int then - Ok (Ndarray.create repr.kind repr.shape repr.fill_value) - else - RegularGrid.create ~array_shape:repr.shape t.chunk_shape >>= fun sg -> - let slice = Array.make (Array.length repr.shape) @@ Owl_types.R [] in - (* pair (i, c) is a pair of shard chunk index (i) and shard coordinate c *) - let pair = - Array.map - (RegularGrid.index_coord_pair sg) - (Indexing.coords_of_slice slice repr.shape) in - let tbl = Arraytbl.create @@ Array.length pair in - let inner = - {kind = repr.kind - ;shape = t.chunk_shape - ;fill_value = repr.fill_value} - in - Array.fold_right (fun (idx, coord) acc -> - acc >>= fun l -> - match Arraytbl.find_opt tbl idx with - | Some arr -> - Ok (Ndarray.get arr coord :: l) - | None -> - match Ndarray.(slice_left shard_idx idx) with - | pair when Ndarray.equal_scalar pair Int64.max_int -> - Ok (inner.fill_value :: l) - | pair -> - let p = Bigarray.array1_of_genarray pair in - let c = String.sub b' (Int64.to_int p.{0}) (Int64.to_int p.{1}) in - decode_chain t.codecs c inner >>= fun x -> - Arraytbl.add tbl idx x; - Ok (Ndarray.get x coord :: l)) pair (Ok []) - >>| Array.of_list >>| fun vals -> - Ndarray.of_array inner.kind vals repr.shape + let cps = Array.map2 (/) repr.shape t.chunk_shape in + decode_index t cps b >>= fun (shard_idx, b') -> + RegularGrid.create ~array_shape:repr.shape t.chunk_shape >>= fun grid -> + let slice = Array.make (Array.length repr.shape) @@ Owl_types.R [] in + let coords = Indexing.coords_of_slice slice repr.shape in + let tbl = Arraytbl.create @@ Array.length coords in + Array.iteri + (fun i y -> + let k, c = RegularGrid.index_coord_pair grid y in + Arraytbl.add tbl k (i, c)) coords; + let inner = {repr with shape = t.chunk_shape} in + ArraySet.fold + (fun idx acc -> + acc >>= fun xs -> + let pairs = Arraytbl.find_all tbl idx in + let p = Bigarray.array1_of_genarray @@ Ndarray.slice_left shard_idx idx in + let c = String.sub b' (Int64.to_int p.{0}) (Int64.to_int p.{1}) in + decode_chain t.codecs inner c >>| fun arr -> + List.fold_left + (fun a (i, c) -> (i, Ndarray.get arr c) :: a) xs pairs) + (ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl) (Ok []) + >>| fun pairs -> + let v = + Array.of_list @@ snd @@ List.split @@ + List.fast_sort (fun (x, _) (y, _) -> Int.compare x y) pairs in + Ndarray.of_array inner.kind v repr.shape + + and partial_decode : + internal_shard_config -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ]) result) -> + int -> + ('a, 'b) Util.array_repr -> + (int * int array) list -> + ((int * 'a) list, [> `Store_read of string | error ]) result + = fun t get_partial bsize repr pairs -> + let open Extensions in + let cps = Array.map2 (/) repr.shape t.chunk_shape in + let is = index_size t cps in + let ibytes, pad = + match t.index_location with + | Start -> get_partial [0, Some is], is + | End -> get_partial [bsize - is, None], 0 in + ibytes >>| + List.hd >>= + decode_index t cps >>= fun (idx_arr, _) -> + RegularGrid.create ~array_shape:repr.shape t.chunk_shape >>= fun grid -> + let tbl = Arraytbl.create @@ List.length pairs in + List.iter + (fun (i, y) -> + let id, c = RegularGrid.index_coord_pair grid y in + Arraytbl.add tbl id (i, c)) pairs; + let inner = {repr with shape = t.chunk_shape} in + ArraySet.fold + (fun idx acc -> + acc >>= fun (shard_idx, xs) -> + let z = Arraytbl.find_all tbl idx in + let p = Bigarray.array1_of_genarray @@ Ndarray.slice_left shard_idx idx in + get_partial Int64.[pad + to_int p.{0}, Some (to_int p.{1})] >>| + List.hd >>= + decode_chain t.codecs inner >>| fun arr -> + shard_idx, xs @ List.map (fun (i, c) -> (i, Ndarray.get arr c)) z) + (ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl) (Ok (idx_arr, [])) + >>| snd let rec chain_to_yojson chain = `List diff --git a/lib/codecs/array_to_bytes.mli b/lib/codecs/array_to_bytes.mli index e6a0cf14..15185900 100644 --- a/lib/codecs/array_to_bytes.mli +++ b/lib/codecs/array_to_bytes.mli @@ -15,7 +15,29 @@ module ArrayToBytes : sig array_tobytes -> ('a, 'b) Util.array_repr -> string -> - (('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t, [> error]) result + (('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t + ,[> `Store_read of string | error]) result val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result val to_yojson : array_tobytes -> Yojson.Safe.t end + +module ShardingIndexedCodec : sig + type t = internal_shard_config + val partial_encode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ] as 'c) result) -> + partial_setter -> + int -> + ('a, 'b) Util.array_repr -> + (int array * 'a) list -> + (unit, 'c) result + val partial_decode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ] as 'c) result) -> + int -> + ('a, 'b) Util.array_repr -> + (int * int array) list -> + ((int * 'a) list, 'c) result +end diff --git a/lib/codecs/codecs.ml b/lib/codecs/codecs.ml index 6dd7feba..20d3f4de 100644 --- a/lib/codecs/codecs.ml +++ b/lib/codecs/codecs.ml @@ -93,11 +93,45 @@ module Chain = struct (fun acc c -> acc >>= BytesToBytes.encode c) (ArrayToBytes.encode t.a2b y) t.b2b + let is_just_sharding : t -> bool = function + | {a2a = []; a2b = `ShardingIndexed _; b2b = []} -> true + | _ -> false + + let partial_encode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error] as 'c) result) -> + partial_setter -> + int -> + ('a, 'b) Util.array_repr -> + (int array * 'a) list -> + (unit, 'c) result + = fun t f g bsize repr pairs -> + match t.a2b with + | `ShardingIndexed c -> + ShardingIndexedCodec.partial_encode c f g bsize repr pairs + | `Bytes _ -> failwith "bytes codec does not support partial encoding." + + let partial_decode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ] as 'c) result) -> + int -> + ('a, 'b) Util.array_repr -> + (int * int array) list -> + ((int * 'a) list, 'c) result + = fun t f s repr pairs -> + match t.a2b with + | `ShardingIndexed c -> + ShardingIndexedCodec.partial_decode c f s repr pairs + | `Bytes _ -> failwith "bytes codec does not support partial decoding." + let decode : t -> ('a, 'b) Util.array_repr -> string -> - (('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t, [> error]) result + (('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t + ,[> `Store_read of string | error]) result = fun t repr x -> List.fold_right (fun c acc -> acc >>= BytesToBytes.decode c) t.b2b (Ok x) diff --git a/lib/codecs/codecs.mli b/lib/codecs/codecs.mli index c200e2f7..55afae02 100644 --- a/lib/codecs/codecs.mli +++ b/lib/codecs/codecs.mli @@ -21,6 +21,10 @@ module Chain : sig the required codecs as defined in the Zarr Version 3 specification. *) val default : t + (** [is_just_sharding t] is [true] if the codec chain [t] contains only + the [sharding_indexed] codec. *) + val is_just_sharding : t -> bool + (** [encode t x] computes the encoded byte string representation of array chunk [x]. Returns an error upon failure. *) val encode : @@ -34,7 +38,27 @@ module Chain : sig t -> ('a, 'b) Util.array_repr -> string -> - (('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t, [> error ]) result + (('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t + ,[> `Store_read of string | error ]) result + + val partial_encode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ] as 'c) result) -> + partial_setter -> + int -> + ('a, 'b) Util.array_repr -> + (int array * 'a) list -> + (unit, 'c) result + + val partial_decode : + t -> + ((int * int option) list -> + (string list, [> `Store_read of string | error ] as 'c) result) -> + int -> + ('a, 'b) Util.array_repr -> + (int * int array) list -> + ((int * 'a) list, 'c) result (** [x = y] returns true if chain [x] is equal to chain [y], and false otherwise. *) diff --git a/lib/codecs/codecs_intf.ml b/lib/codecs/codecs_intf.ml index 4b45497e..0a7f9c9d 100644 --- a/lib/codecs/codecs_intf.ml +++ b/lib/codecs/codecs_intf.ml @@ -39,6 +39,7 @@ type error = | `CodecChain of string | `Sharding of int array * int array * string ] +type partial_setter = ?append:bool -> (int * string) list -> unit module type Interface = sig (** The type of [array -> array] codecs. *) @@ -102,4 +103,6 @@ module type Interface = sig | `Transpose_order of int array * string | `CodecChain of string | `Sharding of int array * int array * string ] + + type partial_setter = ?append:bool -> (int * string) list -> unit end diff --git a/lib/storage/filesystem.ml b/lib/storage/filesystem.ml index 8c182301..b315f4fb 100644 --- a/lib/storage/filesystem.ml +++ b/lib/storage/filesystem.ml @@ -28,6 +28,29 @@ module Impl = struct with | Sys_error _ -> Error (`Store_read fpath) + let get_partial_values t key ranges = + let open Util.Result_syntax in + In_channel.with_open_gen + In_channel.[Open_rdonly] + t.file_perm + (key_to_fspath t key) + (fun ic -> + let size = In_channel.length ic |> Int64.to_int in + List.fold_right + (fun (rs, len) acc -> + acc >>= fun xs -> + let len' = + match len with + | Some l -> l + | None -> size - rs + in + In_channel.seek ic @@ Int64.of_int rs; + match In_channel.really_input_string ic len' with + | Some s -> Ok (s :: xs) + | None -> + Error (`Store_read "EOF reached before all bytes are read.")) + ranges (Ok [])) + let set t key value = let filename = key_to_fspath t key in create_parent_dir filename t.file_perm; @@ -35,7 +58,19 @@ module Impl = struct Out_channel.[Open_wronly; Open_trunc; Open_creat] t.file_perm filename - (fun oc -> Out_channel.output_string oc value) + (fun oc -> Out_channel.output_string oc value; Out_channel.flush oc) + + let set_partial_values t key ?(append=false) rvs = + let open Out_channel in + Out_channel.with_open_gen + [if append then Open_append else Open_wronly] + t.file_perm + (key_to_fspath t key) + (fun oc -> + List.iter + (fun (rs, value) -> + Out_channel.seek oc @@ Int64.of_int rs; + Out_channel.output_string oc value) rvs; Out_channel.flush oc) let list t = let module StrSet = Storage_intf.Base.StrSet in @@ -68,23 +103,18 @@ module Impl = struct try Sys.remove @@ key_to_fspath t key with | Sys_error _ -> () - let get_partial_values t kr_pairs = - Storage_intf.Base.get_partial_values - ~get_fn:get t kr_pairs - - let set_partial_values t krv_triplet = - Storage_intf.Base.set_partial_values - ~set_fn:set ~get_fn:get t krv_triplet - - let erase_values t keys = - Storage_intf.Base.erase_values - ~erase_fn:erase t keys + let size t key = + In_channel.with_open_gen + In_channel.[Open_rdonly] + t.file_perm + (key_to_fspath t key) + (fun ic -> In_channel.length ic |> Int64.to_int) let erase_prefix t pre = Storage_intf.Base.erase_prefix ~list_fn:list ~erase_fn:erase t pre - let list_prefix pre t = + let list_prefix t pre = Storage_intf.Base.list_prefix ~list_fn:list t pre let list_dir t pre = diff --git a/lib/storage/memory.ml b/lib/storage/memory.ml index 263935a2..f77eab3e 100644 --- a/lib/storage/memory.ml +++ b/lib/storage/memory.ml @@ -20,23 +20,23 @@ module Impl = struct let erase = StrMap.remove + let size t key = + get t key |> Result.get_ok |> String.length + let erase_prefix t pre = StrMap.filter_map_inplace (fun k v -> if String.starts_with ~prefix:pre k then None else Some v) t - let get_partial_values t kr_pairs = + let get_partial_values t key ranges = Storage_intf.Base.get_partial_values - ~get_fn:get t kr_pairs + ~get_fn:get t key ranges - let set_partial_values t krv_triplet = + let set_partial_values t key ?(append=false) rv = Storage_intf.Base.set_partial_values - ~set_fn:set ~get_fn:get t krv_triplet - - let erase_values t keys = - Storage_intf.Base.erase_values ~erase_fn:erase t keys + ~set_fn:set ~get_fn:get t key append rv - let list_prefix pre t = + let list_prefix t pre = Storage_intf.Base.list_prefix ~list_fn:list t pre let list_dir t pre = diff --git a/lib/storage/storage.ml b/lib/storage/storage.ml index e2397521..9816afee 100644 --- a/lib/storage/storage.ml +++ b/lib/storage/storage.ml @@ -1,5 +1,6 @@ include Storage_intf +open Util open Util.Result_syntax open Node @@ -43,7 +44,6 @@ module Make (M : STORE) : S with type t = M.t = struct node t = - let open Util in let repr = {kind; fill_value; shape = chunks} in (match codecs with | Some c -> Codecs.Chain.create repr c @@ -89,7 +89,7 @@ module Make (M : STORE) : S with type t = M.t = struct (Result.get_ok @@ ArrayNode.of_path p) :: l, r else l, (Result.get_ok @@ GroupNode.of_path p) :: r - else acc) ([], []) (list_prefix "" t) + else acc) ([], []) (list_prefix t "") with | [], [] as xs -> xs | l, r -> l, GroupNode.root :: r @@ -102,71 +102,68 @@ module Make (M : STORE) : S with type t = M.t = struct let erase_all_nodes t = erase_prefix t "" - let set_array - : type a b. + let set_array : ArrayNode.t -> Owl_types.index array -> - (a, b) Ndarray.t -> + ('a, 'b) Ndarray.t -> t -> (unit, [> error ]) result = fun node slice x t -> - let open Util in get t @@ ArrayNode.to_metakey node >>= fun bytes -> AM.decode bytes >>? (fun msg -> `Store_write msg) >>= fun meta -> - (if Ndarray.shape x = Indexing.slice_shape slice @@ AM.shape meta then + let arr_shape = AM.shape meta in + (if Ndarray.shape x = Indexing.slice_shape slice arr_shape then Ok () else Error (`Store_write "slice and input array shapes are unequal.")) >>= fun () -> - (if AM.is_valid_kind meta @@ Ndarray.kind x then + let kind = Ndarray.kind x in + (if AM.is_valid_kind meta kind then Ok () else Result.error @@ `Store_write ( "input array's kind is not compatible with node's data type.")) >>= fun () -> - let coords = Indexing.coords_of_slice slice @@ AM.shape meta in - let tbl = Arraytbl.create @@ Array.length coords - in + let coords = Indexing.coords_of_slice slice arr_shape in + let tbl = Arraytbl.create @@ Array.length coords in Ndarray.iteri (fun i y -> let k, c = AM.index_coord_pair meta coords.(i) in Arraytbl.add tbl k (c, y)) x; - let repr = - {kind = Ndarray.kind x - ;shape = AM.chunk_shape meta - ;fill_value = AM.fillvalue_of_kind meta @@ Ndarray.kind x} - in - let codecs = AM.codecs meta in + let fill_value = AM.fillvalue_of_kind meta kind in + let repr = {kind; fill_value; shape = AM.chunk_shape meta} in let prefix = ArrayNode.to_key node ^ "/" in - let cindices = ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl in - ArraySet.fold (fun idx acc -> - acc >>= fun () -> - let chunkkey = prefix ^ AM.chunk_key meta idx in - (match get t chunkkey with - | Ok b -> - Codecs.Chain.decode codecs repr b - | Error _ -> - Ok (Ndarray.create repr.kind repr.shape repr.fill_value)) - >>= fun arr -> - (* NOTE: Ndarray.set_fancy* functions unfortunately don't work for array - kinds other than Float32, Float64, Complex32 and Complex64. - See: https://github.com/owlbarn/owl/issues/671 . As a workaround - we manually set each coordinate one-at-time using the basic - set function which does not suffer from this bug. It is likely - much slower for large Zarr chunks but necessary for usability.*) - List.iter - (fun (c, v) -> Ndarray.set arr c v) @@ Arraytbl.find_all tbl idx; - Codecs.Chain.encode codecs arr >>| set t chunkkey) cindices (Ok ()) - - let get_array - : type a b. + let chain = AM.codecs meta in + ArraySet.fold + (fun idx acc -> + acc >>= fun () -> + let pairs = Arraytbl.find_all tbl idx in + let ckey = prefix ^ AM.chunk_key meta idx in + if Codecs.Chain.is_just_sharding chain && is_member t ckey then + Codecs.Chain.partial_encode + chain + (get_partial_values t ckey) + (set_partial_values t ckey) + (size t ckey) + repr + pairs + else + (match get t ckey with + | Ok b -> Codecs.Chain.decode chain repr b + | Error `Store_read _ -> + Result.ok @@ Ndarray.create repr.kind repr.shape repr.fill_value) + >>= fun arr -> + List.iter (fun (c, v) -> Ndarray.set arr c v) pairs; + Codecs.Chain.encode chain arr >>| set t ckey) + (ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl) (Ok ()) + + let get_array : ArrayNode.t -> Owl_types.index array -> - (a, b) Bigarray.kind -> + ('a, 'b) Bigarray.kind -> t -> - ((a, b) Ndarray.t, [> error]) result + (('a, 'b) Ndarray.t, [> error]) result = fun node slice kind t -> - let open Util in get t @@ ArrayNode.to_metakey node >>= fun bytes -> AM.decode bytes >>? (fun msg -> `Store_read msg) >>= fun meta -> (if AM.is_valid_kind meta kind then @@ -175,68 +172,64 @@ module Make (M : STORE) : S with type t = M.t = struct Result.error @@ `Store_read ("input kind is not compatible with node's data type.")) >>= fun () -> - (try - Ok (Indexing.slice_shape slice @@ AM.shape meta) - with + let arr_shape = AM.shape meta in + (try Result.ok @@ Indexing.slice_shape slice arr_shape with | Assert_failure _ -> Result.error @@ `Store_read "slice shape is not compatible with node's shape.") - >>= fun sshape -> - let coords = Indexing.coords_of_slice slice @@ AM.shape meta in + >>= fun slice_shape -> + let coords = Indexing.coords_of_slice slice arr_shape in let tbl = Arraytbl.create @@ Array.length coords in Array.iteri (fun i y -> let k, c = AM.index_coord_pair meta y in Arraytbl.add tbl k (i, c)) coords; - let prefix = ArrayNode.to_key node ^ "/" in let chain = AM.codecs meta in - let repr = - {kind - ;shape = AM.chunk_shape meta - ;fill_value = AM.fillvalue_of_kind meta kind} - in + let prefix = ArrayNode.to_key node ^ "/" in + let fill_value = AM.fillvalue_of_kind meta kind in + let repr = {kind; fill_value; shape = AM.chunk_shape meta} in ArraySet.fold (fun idx acc -> - acc >>= fun l -> - let xs = Arraytbl.find_all tbl idx in - match get t @@ prefix ^ AM.chunk_key meta idx with - | Ok b -> - Codecs.Chain.decode chain repr b >>| fun arr -> - List.fold_left - (fun accu (i, c) -> (i, Ndarray.get arr c) :: accu) l xs - | Error _ -> - Result.ok @@ - List.fold_left - (fun accu (i, _) -> (i, repr.fill_value) :: accu) l xs) - (ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl) (Ok []) + acc >>= fun xs -> + let pairs = Arraytbl.find_all tbl idx in + let ckey = prefix ^ AM.chunk_key meta idx in + if Codecs.Chain.is_just_sharding chain && is_member t ckey then + Codecs.Chain.partial_decode + chain (get_partial_values t ckey) (size t ckey) repr pairs >>| + List.append xs + else + match get t ckey with + | Ok b -> + Codecs.Chain.decode chain repr b >>| fun arr -> + List.fold_left + (fun a (i, c) -> (i, Ndarray.get arr c) :: a) xs pairs + | Error `Store_read _ -> + Result.ok @@ + List.fold_left (fun a (i, _) -> (i, fill_value) :: a) xs pairs) + (ArraySet.of_seq @@ Arraytbl.to_seq_keys tbl) (Ok []) >>| fun pairs -> - let vals = - Array.of_list @@ - snd @@ - List.split @@ - List.fast_sort (fun (x, _) (y, _) -> Int.compare x y) pairs - in - Ndarray.of_array kind vals sshape - - let reshape t node shape = + (* sorting restores the C-order of the decoded array coordinates. *) + let v = + Array.of_list @@ snd @@ List.split @@ + List.fast_sort (fun (x, _) (y, _) -> Int.compare x y) pairs in + Ndarray.of_array kind v slice_shape + + let reshape t node newshape = let mkey = ArrayNode.to_metakey node in get t mkey >>= fun bytes -> - AM.decode bytes >>? (fun msg -> `Store_write msg) - >>= fun meta -> - (if Array.length shape = Array.length @@ AM.shape meta then + AM.decode bytes >>? (fun msg -> `Store_write msg) >>= fun meta -> + let oldshape = AM.shape meta in + (if Array.length newshape = Array.length oldshape then Ok () else Error (`Store_write "new shape must have same number of dimensions.")) >>| fun () -> let pre = ArrayNode.to_key node ^ "/" in - let s = - ArraySet.of_list @@ AM.chunk_indices meta @@ AM.shape meta in - let s' = - ArraySet.of_list @@ AM.chunk_indices meta shape in + let s = ArraySet.of_list @@ AM.chunk_indices meta oldshape in + let s' = ArraySet.of_list @@ AM.chunk_indices meta newshape in ArraySet.iter - (fun v -> erase t @@ pre ^ AM.chunk_key meta v) - ArraySet.(diff s s'); - set t mkey @@ AM.encode @@ AM.update_shape meta shape + (fun v -> erase t @@ pre ^ AM.chunk_key meta v) ArraySet.(diff s s'); + set t mkey @@ AM.encode @@ AM.update_shape meta newshape end module MemoryStore = struct diff --git a/lib/storage/storage_intf.ml b/lib/storage/storage_intf.ml index c88b98bb..374f48f0 100644 --- a/lib/storage/storage_intf.ml +++ b/lib/storage/storage_intf.ml @@ -3,7 +3,7 @@ open Node type key = string -type range = ByteRange of int * int option +type range = int * int option type error = [ `Store_read of string @@ -27,15 +27,16 @@ module type STORE = sig in keys and ending with a trailing / character. *) type t - val get : t -> key -> (string, [> error]) result - val get_partial_values : t -> (key * range) list -> (string list, [> error ]) result + val size : t -> key -> int + val get : t -> key -> (string, [> `Store_read of string ]) result + val get_partial_values : + t -> key -> range list -> (string list, [> `Store_read of string ]) result val set : t -> key -> string -> unit - val set_partial_values : t -> (key * int * string) list -> (unit, [> error]) result + val set_partial_values : t -> key -> ?append:bool -> (int * string) list -> unit val erase : t -> key -> unit - val erase_values : t -> key list -> unit val erase_prefix : t -> key -> unit val list : t -> key list - val list_prefix : key -> t -> key list + val list_prefix : t -> key -> key list val list_dir : t -> key -> key list * string list val is_member : t -> key -> bool end @@ -188,11 +189,8 @@ module Base = struct (String.starts_with ~prefix:pre) (list_fn t) - let erase_values ~erase_fn t keys = - StrSet.iter (erase_fn t) @@ StrSet.of_list keys - let erase_prefix ~list_fn ~erase_fn t pre = - erase_values ~erase_fn t @@ list_prefix ~list_fn t pre + List.iter (erase_fn t) @@ list_prefix ~list_fn t pre let list_dir ~list_fn t pre = let n = String.length pre in @@ -214,29 +212,25 @@ module Base = struct in StrSet.(elements keys, elements prefixes) - let get_partial_values ~get_fn t kr_pairs = + let get_partial_values ~get_fn t key ranges = let open Util.Result_syntax in List.fold_right - (fun (k, ByteRange (rs, len)) acc -> + (fun (rs, len) acc -> acc >>= fun xs -> - get_fn t k >>| fun v -> + get_fn t key >>| fun v -> (match len with | None -> String.sub v rs @@ String.length v - rs - | Some l -> String.sub v rs l) :: xs) kr_pairs (Ok []) - - let set_partial_values ~set_fn ~get_fn t krv = - let open Util.Result_syntax in - let module StrMap = Util.StrMap in - let tbl = StrMap.create @@ List.length krv in - List.fold_right - (fun (k, rs, v) acc -> - acc >>= fun () -> - (match StrMap.find_opt tbl k with - | None -> - get_fn t k - | Some ov -> Ok ov) - >>| fun ov -> - let ov' = Bytes.of_string ov in - String.(length v |> blit v 0 ov' rs); - set_fn t k @@ Bytes.to_string ov') krv (Ok ()) + | Some l -> String.sub v rs l) :: xs) ranges (Ok []) + + let set_partial_values ~set_fn ~get_fn t key append rv = + List.iter + (fun (rs, v) -> + let ov = get_fn t key |> Result.get_ok in + if append then + let ov' = ov ^ v in + set_fn t key ov' + else + let ov' = Bytes.of_string ov in + String.(length v |> Bytes.blit_string v 0 ov' rs); + set_fn t key @@ Bytes.to_string ov') rv end diff --git a/test/test_storage.ml b/test/test_storage.ml index 0ef0a7b3..13320be9 100644 --- a/test/test_storage.ml +++ b/test/test_storage.ml @@ -2,6 +2,7 @@ open OUnit2 open Zarr open Zarr.Node open Zarr.Storage +open Zarr.Codecs module Ndarray = Owl.Dense.Ndarray.Generic @@ -55,14 +56,25 @@ let test_store have metadata with said values."); let fake = ArrayNode.(gnode / "non-member") |> Result.get_ok in - assert_equal - ~printer:string_of_bool - false @@ - M.array_exists store fake; + assert_equal ~printer:string_of_bool false @@ M.array_exists store fake; + let nested_sharding = + `ShardingIndexed { + chunk_shape = [|5; 1; 5|]; + index_location = Start; + index_codecs = [`Bytes Big]; + codecs = [`Bytes Little; `Gzip L2]} + in + let cfg = + {chunk_shape = [|2; 5; 5|] + ;index_location = End + ;index_codecs = [`Bytes Little; `Crc32c] + ;codecs = [`Transpose [|2; 0; 1|]; nested_sharding; `Gzip L1]} in let anode = ArrayNode.(gnode / "arrnode") |> Result.get_ok in let r = M.create_array + ~sep:`Dot + ~codecs:[`ShardingIndexed cfg] ~shape:[|100; 100; 50|] ~chunks:[|10; 15; 20|] Bigarray.Complex64 @@ -71,38 +83,37 @@ let test_store store in assert_equal (Ok ()) r; - (* should work with a custom chain too *) + let slice = Owl_types.[|R [0; 20]; I 10; R []|] in + let r = M.get_array anode slice Bigarray.Complex64 store in + assert_bool "" @@ Result.is_ok r; + let expected = Ndarray.create Bigarray.Complex64 [|21; 1; 50|] Complex.one in + let r = M.set_array anode slice expected store in + assert_equal (Ok ()) r; + let got = Result.get_ok @@ M.get_array anode slice Bigarray.Complex64 store in + assert_equal ~printer:Owl_pretty.dsnda_to_string expected got; + let x' = Ndarray.map (fun v -> Complex.(add v one)) got in + let r = M.set_array anode slice x' store in + assert_equal (Ok ()) r; + let r = M.get_array anode slice Bigarray.Complex64 store in + assert_bool "" @@ Result.is_ok r; + let r = M.create_array - ~sep:`Dot ~codecs:[`Bytes Big] ~shape:[|100; 100; 50|] ~chunks:[|10; 15; 20|] Bigarray.Complex64 Complex.zero anode - store - in + store in assert_equal (Ok ()) r; - - let slice = Owl_types.[|R [0; 20]; I 10; R []|] in - let expected = - Ndarray.create Bigarray.Complex64 [|21; 1; 50|] Complex.zero in - let got = - Result.get_ok @@ - M.get_array anode slice Bigarray.Complex64 store in - assert_equal - ~printer:Owl_pretty.dsnda_to_string - expected - got; - - let x' = Ndarray.map (fun _ -> Complex.one) got in + let expected = Ndarray.create Bigarray.Complex64 [|21; 1; 50|] Complex.zero in + let got = Result.get_ok @@ M.get_array anode slice Bigarray.Complex64 store in + assert_equal ~printer:Owl_pretty.dsnda_to_string expected got; + let x' = Ndarray.map (fun v -> Complex.(add v one)) got in let r = M.set_array anode slice x' store in assert_equal (Ok ()) r; - let got = - Result.get_ok @@ - M.get_array anode slice Bigarray.Complex64 store - in + let got = Result.get_ok @@ M.get_array anode slice Bigarray.Complex64 store in assert_equal ~printer:Owl_pretty.dsnda_to_string x' got; assert_bool "get_array can only work with the correct array kind" @@