From 51898944a35da8c6c230eff614ae9824c0a96aff Mon Sep 17 00:00:00 2001 From: Zolisa Bleki Date: Sat, 13 Jul 2024 14:48:15 +0200 Subject: [PATCH] Use list literals to specify shard config's codec chain. --- README.md | 10 ++---- lib/codecs/array_to_bytes.ml | 45 +++++++++++++------------- lib/codecs/array_to_bytes.mli | 14 ++++----- lib/codecs/codecs.ml | 59 ++++++++++++++++++++++++++++------- lib/codecs/codecs_intf.ml | 40 +++++++++++------------- test/test_codecs.ml | 17 ++++------ 6 files changed, 104 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index fb353eba..b5ba6cd6 100644 --- a/README.md +++ b/README.md @@ -75,14 +75,8 @@ R[73,1] -INF -INF -INF -INF -INF -INF *) ```ocaml let config = {chunk_shape = [|5; 3; 5|] - ;codecs = - {a2a = [`Transpose [|2; 0; 1|]] - ;a2b = `Bytes Little - ;b2b = [`Gzip L5]} - ;index_codecs = - {a2a = [] - ;a2b = `Bytes Big - ;b2b = [`Crc32c]} + ;codecs = [`Transpose [|2; 0; 1|]; `Bytes Little; `Gzip L5] + ;index_codecs = [`Bytes Big; `Crc32c] ;index_location = Start};; let shard_node = Result.get_ok @@ ArrayNode.(group_node / "another");; diff --git a/lib/codecs/array_to_bytes.ml b/lib/codecs/array_to_bytes.ml index 49d4faa0..ca2667e5 100644 --- a/lib/codecs/array_to_bytes.ml +++ b/lib/codecs/array_to_bytes.ml @@ -87,21 +87,21 @@ end module rec ArrayToBytes : sig val parse : ('a, 'b) Util.array_repr -> - arraytobytes -> + array_tobytes -> (unit, [> error]) result - val compute_encoded_size : int -> arraytobytes -> int - val default : arraytobytes + val compute_encoded_size : int -> array_tobytes -> int + val default : array_tobytes val encode : ('a, 'b) Ndarray.t -> - arraytobytes -> + array_tobytes -> (string, [> error]) result val decode : string -> ('a, 'b) Util.array_repr -> - arraytobytes -> + array_tobytes -> (('a, 'b) Ndarray.t, [> error]) result - val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result - val to_yojson : arraytobytes -> Yojson.Safe.t + val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result + val to_yojson : array_tobytes -> Yojson.Safe.t end = struct let default = `Bytes Little @@ -120,7 +120,7 @@ end = struct let encode : type a b. (a, b) Ndarray.t -> - arraytobytes -> + array_tobytes -> (string, [> error]) result = fun x -> function | `Bytes endian -> BytesCodec.encode x endian @@ -130,7 +130,7 @@ end = struct : type a b. string -> (a, b) Util.array_repr -> - arraytobytes -> + array_tobytes -> ((a, b) Ndarray.t, [> error]) result = fun b repr -> function | `Bytes endian -> BytesCodec.decode b repr endian @@ -150,7 +150,7 @@ end = struct end and ShardingIndexedCodec : sig - type t = shard_config + type t = internal_shard_config val parse : ('a, 'b) Util.array_repr -> t -> @@ -169,7 +169,7 @@ and ShardingIndexedCodec : sig val to_yojson : t -> Yojson.Safe.t end = struct - type t = shard_config + type t = internal_shard_config let parse_chain repr chain = List.fold_left @@ -183,7 +183,7 @@ end = struct let parse : type a b. (a, b) Util.array_repr -> - shard_config -> + internal_shard_config -> (unit, [> error]) result = fun repr t -> (match Array.(length repr.shape = length t.chunk_shape) with @@ -220,7 +220,7 @@ end = struct let rec encode_chain : type a b. - bytestobytes shard_chain -> + bytestobytes internal_chain -> (a, b) Ndarray.t -> (string, [> error]) result = fun t x -> @@ -234,7 +234,7 @@ end = struct and encode : type a b. (a, b) Ndarray.t -> - shard_config -> + internal_shard_config -> (string, [> error]) result = fun x t -> let open Util in @@ -293,7 +293,7 @@ end = struct offset := Int64.add !offset nbytes) cindices (Ok ()) >>= fun () -> (* convert t.index_codecs to a generic bytes-to-bytes chain. *) - encode_chain (t.index_codecs :> bytestobytes shard_chain) shard_idx + encode_chain (t.index_codecs :> bytestobytes internal_chain) shard_idx >>| fun b' -> match t.index_location with | Start -> @@ -307,7 +307,7 @@ end = struct let rec decode_chain : type a b. - bytestobytes shard_chain -> + bytestobytes internal_chain -> string -> (a, b) Util.array_repr -> ((a, b) Ndarray.t, [> error]) result @@ -330,7 +330,7 @@ end = struct and decode_index : string -> int array -> - shard_config -> + internal_shard_config -> ((int64, Bigarray.int64_elt) Ndarray.t * string, [> error]) result = fun b shard_shape t -> let open Util in @@ -347,10 +347,9 @@ end = struct ;kind = Bigarray.Int64 ;shape = Array.append cps [|2|]} in - decode_chain - (t.index_codecs : fixed_bytestobytes shard_chain :> bytestobytes shard_chain) - b' repr >>| fun decoded -> - (decoded, rest) + decode_chain (t.index_codecs :> bytestobytes internal_chain) b' repr + >>= fun decoded -> + Ok (decoded, rest) and index_size t cps = compute_encoded_size (16 * Util.prod cps) t @@ -413,7 +412,7 @@ end = struct and to_yojson t = let codecs = chain_to_yojson t.codecs in let index_codecs = - chain_to_yojson (t.index_codecs :> bytestobytes shard_chain) + chain_to_yojson (t.index_codecs :> bytestobytes internal_chain) in let index_location = match t.index_location with @@ -435,7 +434,7 @@ end = struct ("codecs", codecs)])] let rec chain_of_yojson : - Yojson.Safe.t list -> (bytestobytes shard_chain, string) result + Yojson.Safe.t list -> (bytestobytes internal_chain, string) result = fun codecs -> let filter_partition f encoded = List.fold_right (fun c (l, r) -> diff --git a/lib/codecs/array_to_bytes.mli b/lib/codecs/array_to_bytes.mli index 91ddf1f9..7c3befdb 100644 --- a/lib/codecs/array_to_bytes.mli +++ b/lib/codecs/array_to_bytes.mli @@ -5,19 +5,19 @@ module Ndarray = Owl.Dense.Ndarray.Generic module ArrayToBytes : sig val parse : ('a, 'b) Util.array_repr -> - arraytobytes -> + array_tobytes -> (unit, [> error]) result - val compute_encoded_size : int -> arraytobytes -> int - val default : arraytobytes + val compute_encoded_size : int -> array_tobytes -> int + val default : array_tobytes val encode : ('a, 'b) Ndarray.t -> - arraytobytes -> + array_tobytes -> (string, [> error]) result val decode : string -> ('a, 'b) Util.array_repr -> - arraytobytes -> + array_tobytes -> (('a, 'b) Ndarray.t, [> error]) result - val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result - val to_yojson : arraytobytes -> Yojson.Safe.t + val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result + val to_yojson : array_tobytes -> Yojson.Safe.t end diff --git a/lib/codecs/codecs.ml b/lib/codecs/codecs.ml index df64d656..77a17dc3 100644 --- a/lib/codecs/codecs.ml +++ b/lib/codecs/codecs.ml @@ -5,15 +5,31 @@ open Util.Result_syntax include Codecs_intf -type internal_chain = - {a2a : arraytoarray list - ;a2b : arraytobytes - ;b2b : bytestobytes list} +type arraytobytes = + [ `Bytes of endianness + | `ShardingIndexed of shard_config ] + +and shard_config = + {chunk_shape : int array + ;codecs : + [ arraytoarray + | `Bytes of endianness + | `ShardingIndexed of shard_config + | bytestobytes ] list + ;index_codecs : + [ arraytoarray + | `Bytes of endianness + | `ShardingIndexed of shard_config + | fixed_bytestobytes ] list + ;index_location : loc} + +type codec_chain = + [ arraytoarray | arraytobytes | bytestobytes ] list module Chain = struct - type t = internal_chain + type t = bytestobytes internal_chain - let create : + let rec create : type a b. (a, b) Util.array_repr -> codec_chain -> (t, [> error ]) result = fun repr cc -> let a2a, rest = List.partition_map (function @@ -21,11 +37,32 @@ module Chain = struct | #arraytobytes as c -> Either.right c | #bytestobytes as c -> Either.right c) cc in - (match - List.partition_map (function - | #arraytobytes as c -> Either.left c - | #bytestobytes as c -> Either.right c) rest - with + List.fold_right + (fun c acc -> + acc >>= fun (l, r) -> + match c with + | `Bytes e -> Ok (`Bytes e :: l, r) + | `ShardingIndexed cfg -> + create repr cfg.codecs + >>= fun codecs -> + create + {repr with shape = Array.append repr.shape [|2|]} + (cfg.index_codecs :> codec_chain) + >>= fun index_codecs -> + (* convert to a fixed_bytestobytes internal_chain type *) + let b2b = + List.partition_map (function + | `Crc32c -> Either.left `Crc32c + | c -> Either.right c) index_codecs.b2b |> fst in + let cfg' : internal_shard_config = + {chunk_shape = cfg.chunk_shape; + index_location = cfg.index_location; + index_codecs = {index_codecs with b2b}; + codecs} in + Ok (`ShardingIndexed cfg' :: l, r) + | #bytestobytes as c -> Ok (l, c :: r)) rest (Ok ([], [])) + >>= fun result -> + (match result with | [x], rest -> Ok (x, rest) | _ -> Result.error @@ `CodecChain "Must be exactly one array->bytes codec.") diff --git a/lib/codecs/codecs_intf.ml b/lib/codecs/codecs_intf.ml index 11a9e0c3..59907a5e 100644 --- a/lib/codecs/codecs_intf.ml +++ b/lib/codecs/codecs_intf.ml @@ -17,23 +17,20 @@ type endianness = Little | Big type loc = Start | End -type arraytobytes = +type array_tobytes = [ `Bytes of endianness - | `ShardingIndexed of shard_config ] + | `ShardingIndexed of internal_shard_config ] -and shard_config = +and internal_shard_config = {chunk_shape : int array - ;codecs : bytestobytes shard_chain - ;index_codecs : fixed_bytestobytes shard_chain + ;codecs : bytestobytes internal_chain + ;index_codecs : fixed_bytestobytes internal_chain ;index_location : loc} -and 'a shard_chain = - {a2a: arraytoarray list - ;a2b: arraytobytes - ;b2b: 'a list} - -type codec_chain = - [ arraytoarray | arraytobytes | bytestobytes ] list +and 'a internal_chain = + {a2a : arraytoarray list + ;a2b : array_tobytes + ;b2b : 'a list} type error = [ `Extension of string @@ -82,17 +79,18 @@ module type Interface = sig (** A type representing the Sharding indexed codec's configuration parameters. *) and shard_config = {chunk_shape : int array - ;codecs : bytestobytes shard_chain - ;index_codecs : fixed_bytestobytes shard_chain + ;codecs : + [ arraytoarray + | `Bytes of endianness + | `ShardingIndexed of shard_config + | bytestobytes ] list + ;index_codecs : + [ arraytoarray + | `Bytes of endianness + | `ShardingIndexed of shard_config + | fixed_bytestobytes ] list ;index_location : loc} - (** A type representing the chain of codecs used to encode/decode - a shard's bytes and its index array. *) - and 'a shard_chain = - {a2a: arraytoarray list - ;a2b: arraytobytes - ;b2b: 'a list} - (** A type used to build a user-defined chain of codecs when creating a Zarr array. *) type codec_chain = [ arraytoarray | arraytobytes | bytestobytes ] list diff --git a/test/test_codecs.ml b/test/test_codecs.ml index ac3e65ab..5d439caf 100644 --- a/test/test_codecs.ml +++ b/test/test_codecs.ml @@ -48,8 +48,8 @@ let tests = [ let shard_cfg = {chunk_shape = [|2; 5; 5|] ;index_location = End - ;index_codecs = {a2a = []; a2b = `Bytes Little; b2b = [`Crc32c]} - ;codecs = {a2a = [`Transpose [|0; 1; 2|]]; a2b = `Bytes Big; b2b = [`Gzip L1]}} + ;index_codecs = [`Bytes Little; `Crc32c] + ;codecs = [`Transpose [|0; 1; 2|]; `Bytes Big; `Gzip L1]} in let chain = [`Transpose [|2; 1; 0; 3|]; `ShardingIndexed shard_cfg; `Crc32c; `Gzip L9] @@ -323,11 +323,8 @@ let tests = [ let cfg = {chunk_shape = [|3; 5; 5|] ;index_location = Start - ;index_codecs = - {a2a = [] - ;a2b = `Bytes Little - ;b2b = [`Crc32c]} - ;codecs = {a2a = []; a2b = `Bytes Big; b2b = []}} + ;index_codecs = [`Bytes Little; `Crc32c] + ;codecs = [`Bytes Big]} in let chain = [`ShardingIndexed cfg] in (*test failure for chunk shape not evenly dividing shard. *) @@ -368,11 +365,9 @@ let tests = [ (* test if including a transpose codec for index_codec chain results in a failure. *) let chain' = - [`ShardingIndexed - {cfg with + [`ShardingIndexed {cfg with chunk_shape = [|5; 3; 5|] - ;index_codecs = - {cfg.index_codecs with a2a = [`Transpose [|0; 3; 1; 2|]]}}] + ;index_codecs = `Transpose [|0; 3; 1; 2|] :: cfg.index_codecs}] in let cc = Chain.create decoded_repr chain' |> Result.get_ok in assert_bool