-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathh5ds.py
143 lines (134 loc) · 4.97 KB
/
h5ds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import h5py
import numpy as np
import torch
from assert_eq import assert_eq
class H5DS:
    """Typed, shape-checked handle for a single dataset inside an HDF5 file.

    An inextensible dataset stores exactly one value of ``shape``.
    An extensible dataset stores a growable sequence of such values along a
    leading, unlimited axis, so its on-disk shape is ``(N,) + shape``.

    Datasets whose per-value ``shape`` is ``(1,)`` are treated as scalars:
    writes accept a bare scalar and reads return a bare (numpy) scalar.
    """

    def __init__(self, name, dtype, shape=(1,), extensible=False):
        """Describe the dataset; nothing is read from or written to disk here.

        Args:
            name: Key of the dataset inside the HDF5 file.
            dtype: Element type; compared against on-disk and incoming data.
            shape: Shape of one stored value, as a tuple of ints.
            extensible: If True, values are stored as rows along an unlimited
                leading axis and more rows can be added with ``append``.
        """
        assert isinstance(name, str)
        assert isinstance(shape, tuple)
        assert all(isinstance(s, int) for s in shape)
        assert isinstance(extensible, bool)
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.extensible = extensible

    @staticmethod
    def _fmt_shape(shape):
        """Render a shape tuple as ``"d0*d1*..."`` for error messages."""
        # str.join only accepts strings; shape entries are ints.
        return "*".join(str(s) for s in shape)

    def _to_array(self, value):
        """Convert ``value`` to an ndarray of exactly ``self.dtype``/``self.shape``.

        Torch tensors are detached and copied to host memory (their dtype is
        checked as-is, not cast). For scalar datasets (``shape == (1,)``) a
        0-d value -- a Python scalar, numpy scalar or 0-d array -- is wrapped
        into a one-element array of ``self.dtype``.

        Raises:
            AssertionError: if ``value`` is None or does not end up as an
                ndarray with the expected dtype and shape.
        """
        # np.array([None], dtype=float32) would silently become [nan].
        assert value is not None
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        elif self.shape == (1,) and np.ndim(value) == 0:
            # np.ndim also recognises numpy scalars (e.g. the np.float32 that
            # read() returns), which an int/float isinstance check rejects.
            value = np.array([value], dtype=self.dtype)
        assert isinstance(value, np.ndarray)
        assert_eq(value.dtype, self.dtype)
        assert_eq(value.shape, self.shape)
        return value

    def create(self, h5file, value=None):
        """Create the dataset in ``h5file``, optionally initialising it.

        An extensible dataset is created empty (``value is None``) or holding
        ``value`` as its single first row. An inextensible dataset must be
        given a ``value``. Data is gzip-compressed at level 9.

        Raises:
            AssertionError: if ``h5file`` is not a truthy ``h5py.File``, the
                dataset already exists, or ``value`` is missing or has the
                wrong dtype/shape.
        """
        assert isinstance(h5file, h5py.File)
        assert h5file
        assert not self.exists(h5file)
        if self.extensible:
            # Unlimited leading axis so append() can resize it; one row per chunk.
            maxshape = (None,) + self.shape
            chunks = (1,) + self.shape
            if value is None:
                shape = (0,) + self.shape
            else:
                # Validate exactly like append(), then add the row axis.
                value = self._to_array(value)[np.newaxis, ...]
                shape = (1,) + self.shape
        else:
            assert value is not None
            value = self._to_array(value)
            shape = maxshape = chunks = self.shape
        ds = h5file.create_dataset(
            name=self.name,
            shape=shape,
            maxshape=maxshape,
            chunks=chunks,
            dtype=self.dtype,
            compression="gzip",
            compression_opts=9,
        )
        if value is not None:
            ds[...] = value

    def read(self, h5file, index=None):
        """Read the stored value, or row ``index`` of an extensible dataset.

        Returns:
            A numpy scalar for ``shape == (1,)`` datasets, otherwise an ndarray.

        Raises:
            StopIteration: if ``index`` is outside ``[0, count)`` for an
                extensible dataset. If this propagates out of a generator,
                Python converts it to RuntimeError (PEP 479).
        """
        assert isinstance(h5file, h5py.File)
        assert h5file
        assert self.exists(h5file)
        ds = h5file[self.name]
        if self.extensible:
            assert isinstance(index, int)
            if index < 0 or index >= ds.shape[0]:
                raise StopIteration()
            if self.shape == (1,):
                return ds[index, 0]
            return np.array(ds[index])
        assert index is None
        if self.shape == (1,):
            return ds[0]
        return np.array(ds)

    def write(self, h5file, value, index=None):
        """Overwrite the stored value, or row ``index`` of an extensible dataset.

        ``value`` may be an ndarray, a torch tensor, or (for scalar datasets)
        a Python/numpy scalar; it must match ``dtype`` and ``shape``.

        Raises:
            StopIteration: if ``index`` is outside ``[0, count)`` for an
                extensible dataset.
        """
        assert isinstance(h5file, h5py.File)
        assert h5file
        value = self._to_array(value)
        assert self.exists(h5file)
        ds = h5file[self.name]
        if self.extensible:
            assert isinstance(index, int)
            if index < 0 or index >= ds.shape[0]:
                raise StopIteration()
            ds[index] = value
        else:
            assert index is None
            ds[...] = value

    def append(self, h5file, value):
        """Add ``value`` as a new last row of an existing extensible dataset."""
        assert isinstance(h5file, h5py.File)
        assert h5file
        value = self._to_array(value)
        assert self.extensible
        assert self.exists(h5file)
        ds = h5file[self.name]
        # Grow the unlimited leading axis by one row, then fill that row.
        ds.resize(ds.shape[0] + 1, axis=0)
        ds[-1] = value

    def count(self, h5file):
        """Return the number of rows stored in an extensible dataset."""
        assert isinstance(h5file, h5py.File)
        assert h5file
        assert self.extensible
        return h5file[self.name].shape[0]

    def exists(self, h5file):
        """Return whether the dataset is present, validating it if so.

        Returns:
            False if ``self.name`` is absent from ``h5file``; True if it is
            present with the expected dtype and shape.

        Raises:
            Exception: if the dataset is present but its dtype or shape does
                not match this descriptor.
        """
        assert isinstance(h5file, h5py.File)
        assert h5file
        if self.name not in h5file:
            return False
        ds = h5file[self.name]
        if ds.dtype != self.dtype:
            raise Exception(
                f"Incorrect dtype found in HDF5 dataset. Expected '{self.name}' to have type {self.dtype} but found {ds.dtype} instead."
            )
        if self.extensible:
            # Only the per-row shape is fixed; the leading row count may vary.
            if ds.shape[1:] != self.shape:
                raise Exception(
                    f"Incorrect shape found in HDF5 dataset. Expected '{self.name}' to have extensible shape N*{self._fmt_shape(self.shape)} but found {self._fmt_shape(ds.shape)} instead."
                )
        elif ds.shape != self.shape:
            raise Exception(
                f"Incorrect shape found in HDF5 dataset. Expected '{self.name}' to have inextensible shape {self._fmt_shape(self.shape)} but found {self._fmt_shape(ds.shape)} instead."
            )
        return True