Skip to content

Commit

Permalink
Use list literals to specify shard config's codec chain.
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 committed Jul 13, 2024
1 parent 71216f8 commit 5189894
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 81 deletions.
10 changes: 2 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,8 @@ R[73,1] -INF -INF -INF -INF -INF -INF *)
```ocaml
let config =
{chunk_shape = [|5; 3; 5|]
;codecs =
{a2a = [`Transpose [|2; 0; 1|]]
;a2b = `Bytes Little
;b2b = [`Gzip L5]}
;index_codecs =
{a2a = []
;a2b = `Bytes Big
;b2b = [`Crc32c]}
;codecs = [`Transpose [|2; 0; 1|]; `Bytes Little; `Gzip L5]
;index_codecs = [`Bytes Big; `Crc32c]
;index_location = Start};;
let shard_node = Result.get_ok @@ ArrayNode.(group_node / "another");;
Expand Down
45 changes: 22 additions & 23 deletions lib/codecs/array_to_bytes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,21 @@ end
module rec ArrayToBytes : sig
val parse
: ('a, 'b) Util.array_repr ->
arraytobytes ->
array_tobytes ->
(unit, [> error]) result
val compute_encoded_size : int -> arraytobytes -> int
val default : arraytobytes
val compute_encoded_size : int -> array_tobytes -> int
val default : array_tobytes
val encode
: ('a, 'b) Ndarray.t ->
arraytobytes ->
array_tobytes ->
(string, [> error]) result
val decode
: string ->
('a, 'b) Util.array_repr ->
arraytobytes ->
array_tobytes ->
(('a, 'b) Ndarray.t, [> error]) result
val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result
val to_yojson : arraytobytes -> Yojson.Safe.t
val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result
val to_yojson : array_tobytes -> Yojson.Safe.t
end = struct

let default = `Bytes Little
Expand All @@ -120,7 +120,7 @@ end = struct
let encode
: type a b.
(a, b) Ndarray.t ->
arraytobytes ->
array_tobytes ->
(string, [> error]) result
= fun x -> function
| `Bytes endian -> BytesCodec.encode x endian
Expand All @@ -130,7 +130,7 @@ end = struct
: type a b.
string ->
(a, b) Util.array_repr ->
arraytobytes ->
array_tobytes ->
((a, b) Ndarray.t, [> error]) result
= fun b repr -> function
| `Bytes endian -> BytesCodec.decode b repr endian
Expand All @@ -150,7 +150,7 @@ end = struct
end

and ShardingIndexedCodec : sig
type t = shard_config
type t = internal_shard_config
val parse
: ('a, 'b) Util.array_repr ->
t ->
Expand All @@ -169,7 +169,7 @@ and ShardingIndexedCodec : sig
val to_yojson : t -> Yojson.Safe.t
end = struct

type t = shard_config
type t = internal_shard_config

let parse_chain repr chain =
List.fold_left
Expand All @@ -183,7 +183,7 @@ end = struct
let parse
: type a b.
(a, b) Util.array_repr ->
shard_config ->
internal_shard_config ->
(unit, [> error]) result
= fun repr t ->
(match Array.(length repr.shape = length t.chunk_shape) with
Expand Down Expand Up @@ -220,7 +220,7 @@ end = struct

let rec encode_chain
: type a b.
bytestobytes shard_chain ->
bytestobytes internal_chain ->
(a, b) Ndarray.t ->
(string, [> error]) result
= fun t x ->
Expand All @@ -234,7 +234,7 @@ end = struct
and encode
: type a b.
(a, b) Ndarray.t ->
shard_config ->
internal_shard_config ->
(string, [> error]) result
= fun x t ->
let open Util in
Expand Down Expand Up @@ -293,7 +293,7 @@ end = struct
offset := Int64.add !offset nbytes) cindices (Ok ())
>>= fun () ->
(* convert t.index_codecs to a generic bytes-to-bytes chain. *)
encode_chain (t.index_codecs :> bytestobytes shard_chain) shard_idx
encode_chain (t.index_codecs :> bytestobytes internal_chain) shard_idx
>>| fun b' ->
match t.index_location with
| Start ->
Expand All @@ -307,7 +307,7 @@ end = struct

let rec decode_chain
: type a b.
bytestobytes shard_chain ->
bytestobytes internal_chain ->
string ->
(a, b) Util.array_repr ->
((a, b) Ndarray.t, [> error]) result
Expand All @@ -330,7 +330,7 @@ end = struct
and decode_index
: string ->
int array ->
shard_config ->
internal_shard_config ->
((int64, Bigarray.int64_elt) Ndarray.t * string, [> error]) result
= fun b shard_shape t ->
let open Util in
Expand All @@ -347,10 +347,9 @@ end = struct
;kind = Bigarray.Int64
;shape = Array.append cps [|2|]}
in
decode_chain
(t.index_codecs : fixed_bytestobytes shard_chain :> bytestobytes shard_chain)
b' repr >>| fun decoded ->
(decoded, rest)
decode_chain (t.index_codecs :> bytestobytes internal_chain) b' repr
>>= fun decoded ->
Ok (decoded, rest)

and index_size t cps =
compute_encoded_size (16 * Util.prod cps) t
Expand Down Expand Up @@ -413,7 +412,7 @@ end = struct
and to_yojson t =
let codecs = chain_to_yojson t.codecs in
let index_codecs =
chain_to_yojson (t.index_codecs :> bytestobytes shard_chain)
chain_to_yojson (t.index_codecs :> bytestobytes internal_chain)
in
let index_location =
match t.index_location with
Expand All @@ -435,7 +434,7 @@ end = struct
("codecs", codecs)])]

let rec chain_of_yojson :
Yojson.Safe.t list -> (bytestobytes shard_chain, string) result
Yojson.Safe.t list -> (bytestobytes internal_chain, string) result
= fun codecs ->
let filter_partition f encoded =
List.fold_right (fun c (l, r) ->
Expand Down
14 changes: 7 additions & 7 deletions lib/codecs/array_to_bytes.mli
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ module Ndarray = Owl.Dense.Ndarray.Generic
module ArrayToBytes : sig
val parse
: ('a, 'b) Util.array_repr ->
arraytobytes ->
array_tobytes ->
(unit, [> error]) result
val compute_encoded_size : int -> arraytobytes -> int
val default : arraytobytes
val compute_encoded_size : int -> array_tobytes -> int
val default : array_tobytes
val encode
: ('a, 'b) Ndarray.t ->
arraytobytes ->
array_tobytes ->
(string, [> error]) result
val decode
: string ->
('a, 'b) Util.array_repr ->
arraytobytes ->
array_tobytes ->
(('a, 'b) Ndarray.t, [> error]) result
val of_yojson : Yojson.Safe.t -> (arraytobytes, string) result
val to_yojson : arraytobytes -> Yojson.Safe.t
val of_yojson : Yojson.Safe.t -> (array_tobytes, string) result
val to_yojson : array_tobytes -> Yojson.Safe.t
end
59 changes: 48 additions & 11 deletions lib/codecs/codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,64 @@ open Util.Result_syntax

include Codecs_intf

type internal_chain =
{a2a : arraytoarray list
;a2b : arraytobytes
;b2b : bytestobytes list}
type arraytobytes =
[ `Bytes of endianness
| `ShardingIndexed of shard_config ]

and shard_config =
{chunk_shape : int array
;codecs :
[ arraytoarray
| `Bytes of endianness
| `ShardingIndexed of shard_config
| bytestobytes ] list
;index_codecs :
[ arraytoarray
| `Bytes of endianness
| `ShardingIndexed of shard_config
| fixed_bytestobytes ] list
;index_location : loc}

type codec_chain =
[ arraytoarray | arraytobytes | bytestobytes ] list

module Chain = struct
type t = internal_chain
type t = bytestobytes internal_chain

let create :
let rec create :
type a b. (a, b) Util.array_repr -> codec_chain -> (t, [> error ]) result
= fun repr cc ->
let a2a, rest = List.partition_map (function
| #arraytoarray as c -> Either.left c
| #arraytobytes as c -> Either.right c
| #bytestobytes as c -> Either.right c) cc
in
(match
List.partition_map (function
| #arraytobytes as c -> Either.left c
| #bytestobytes as c -> Either.right c) rest
with
List.fold_right
(fun c acc ->
acc >>= fun (l, r) ->
match c with
| `Bytes e -> Ok (`Bytes e :: l, r)
| `ShardingIndexed cfg ->
create repr cfg.codecs
>>= fun codecs ->
create
{repr with shape = Array.append repr.shape [|2|]}
(cfg.index_codecs :> codec_chain)
>>= fun index_codecs ->
(* convert to a fixed_bytestobytes internal_chain type *)
let b2b =
List.partition_map (function
| `Crc32c -> Either.left `Crc32c
| c -> Either.right c) index_codecs.b2b |> fst in

Check warning on line 56 in lib/codecs/codecs.ml

View check run for this annotation

Codecov / codecov/patch

lib/codecs/codecs.ml#L56

Added line #L56 was not covered by tests
let cfg' : internal_shard_config =
{chunk_shape = cfg.chunk_shape;
index_location = cfg.index_location;
index_codecs = {index_codecs with b2b};
codecs} in
Ok (`ShardingIndexed cfg' :: l, r)
| #bytestobytes as c -> Ok (l, c :: r)) rest (Ok ([], []))
>>= fun result ->
(match result with
| [x], rest -> Ok (x, rest)
| _ ->
Result.error @@ `CodecChain "Must be exactly one array->bytes codec.")
Expand Down
40 changes: 19 additions & 21 deletions lib/codecs/codecs_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,20 @@ type endianness = Little | Big

type loc = Start | End

type arraytobytes =
type array_tobytes =
[ `Bytes of endianness
| `ShardingIndexed of shard_config ]
| `ShardingIndexed of internal_shard_config ]

and shard_config =
and internal_shard_config =
{chunk_shape : int array
;codecs : bytestobytes shard_chain
;index_codecs : fixed_bytestobytes shard_chain
;codecs : bytestobytes internal_chain
;index_codecs : fixed_bytestobytes internal_chain
;index_location : loc}

and 'a shard_chain =
{a2a: arraytoarray list
;a2b: arraytobytes
;b2b: 'a list}

type codec_chain =
[ arraytoarray | arraytobytes | bytestobytes ] list
and 'a internal_chain =
{a2a : arraytoarray list
;a2b : array_tobytes
;b2b : 'a list}

type error =
[ `Extension of string
Expand Down Expand Up @@ -82,17 +79,18 @@ module type Interface = sig
(** A type representing the Sharding indexed codec's configuration parameters. *)
and shard_config =
{chunk_shape : int array
;codecs : bytestobytes shard_chain
;index_codecs : fixed_bytestobytes shard_chain
;codecs :
[ arraytoarray
| `Bytes of endianness
| `ShardingIndexed of shard_config
| bytestobytes ] list
;index_codecs :
[ arraytoarray
| `Bytes of endianness
| `ShardingIndexed of shard_config
| fixed_bytestobytes ] list
;index_location : loc}

(** A type representing the chain of codecs used to encode/decode
a shard's bytes and its index array. *)
and 'a shard_chain =
{a2a: arraytoarray list
;a2b: arraytobytes
;b2b: 'a list}

(** A type used to build a user-defined chain of codecs when creating a Zarr array. *)
type codec_chain =
[ arraytoarray | arraytobytes | bytestobytes ] list
Expand Down
17 changes: 6 additions & 11 deletions test/test_codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ let tests = [
let shard_cfg =
{chunk_shape = [|2; 5; 5|]
;index_location = End
;index_codecs = {a2a = []; a2b = `Bytes Little; b2b = [`Crc32c]}
;codecs = {a2a = [`Transpose [|0; 1; 2|]]; a2b = `Bytes Big; b2b = [`Gzip L1]}}
;index_codecs = [`Bytes Little; `Crc32c]
;codecs = [`Transpose [|0; 1; 2|]; `Bytes Big; `Gzip L1]}
in
let chain =
[`Transpose [|2; 1; 0; 3|]; `ShardingIndexed shard_cfg; `Crc32c; `Gzip L9]
Expand Down Expand Up @@ -323,11 +323,8 @@ let tests = [
let cfg =
{chunk_shape = [|3; 5; 5|]
;index_location = Start
;index_codecs =
{a2a = []
;a2b = `Bytes Little
;b2b = [`Crc32c]}
;codecs = {a2a = []; a2b = `Bytes Big; b2b = []}}
;index_codecs = [`Bytes Little; `Crc32c]
;codecs = [`Bytes Big]}
in
let chain = [`ShardingIndexed cfg] in
(*test failure for chunk shape not evenly dividing shard. *)
Expand Down Expand Up @@ -368,11 +365,9 @@ let tests = [
(* test if including a transpose codec for index_codec chain results in
a failure. *)
let chain' =
[`ShardingIndexed
{cfg with
[`ShardingIndexed {cfg with
chunk_shape = [|5; 3; 5|]
;index_codecs =
{cfg.index_codecs with a2a = [`Transpose [|0; 3; 1; 2|]]}}]
;index_codecs = `Transpose [|0; 3; 1; 2|] :: cfg.index_codecs}]
in
let cc = Chain.create decoded_repr chain' |> Result.get_ok in
assert_bool
Expand Down

0 comments on commit 5189894

Please sign in to comment.