Batch writer and reader support #15

Open · wants to merge 5 commits into base: master
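In brief, the diff adds batch-oriented writers and readers for persisting model outputs column by column. A minimal round trip with the new BatchWriter might look like this (a sketch distilled from the tests below; the path and column names are illustrative):

import torch
import yann
from yann.data.storage.batch_files import BatchWriter

# buffer predictions batch by batch, then persist them as one pickle file
with BatchWriter('preds.pkl') as writer:
  for _ in range(5):
    writer.batch(
      ids=list(range(8)),
      outputs=torch.rand(8, 10),
    )

columns = yann.load('preds.pkl')  # dict: column name -> list of per-batch values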
34 changes: 34 additions & 0 deletions tests/data/io/test_batch_files.py
@@ -0,0 +1,34 @@
from yann.data.io.batch_files import BatchWriter
import torch
import numpy as np
import pytest


def test_batch_write_kwargs():
  with BatchWriter('asfasd') as write:
    for n in range(20):
      write.batch(
        ids=list(range(10)),
        targets=np.random.randn(10),
        outputs=torch.rand(10, 12),
      )


def test_batch_write_args():
  with BatchWriter('asfasd', names=('id', 'target', 'output')) as write:
    for n in range(20):
      write.batch(
        list(range(10)),
        np.random.randn(10),
        torch.rand(10, 12),
      )

  with pytest.raises(ValueError, match='names and encoders must be the same length'):
    bw = BatchWriter('asfsd', names=(1, 2, 3), encoders=(1, 2))


def test_meta():
  BatchWriter(path=lambda x: 'foo', meta={
    'checkpoint_id': 'asfads',
    'dataset': 'MNIST'
  })
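The encoders argument validated above isn't otherwise exercised here; per encode_batch in yann/data/storage/batch_files.py, encoders are per-column callables applied before a batch is buffered. A hedged sketch (the encoder functions are hypothetical):

import numpy as np
import torch
from yann.data.io.batch_files import BatchWriter

# each positional column is transformed by its matching encoder before buffering
with BatchWriter(
  'encoded.pkl',
  names=('ids', 'outputs'),
  encoders=(np.asarray, lambda t: t.softmax(dim=-1).numpy()),
) as write:
  write.batch(list(range(10)), torch.rand(10, 12))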
Empty file added tests/data/storage/__init__.py
66 changes: 66 additions & 0 deletions tests/data/storage/test_batch_files.py
@@ -0,0 +1,66 @@
from yann.data.storage.batch_files import BatchWriter, PartitionedBatchWriter, BatchReader
import numpy as np
import torch
import pathlib

import yann


def same_type(*items):
  return all(type(x) == type(items[0]) for x in items)


def test_pickle(tmpdir: pathlib.Path):
  path = tmpdir / 'batches.pkl'
  w = BatchWriter(path, names=('ids', 'targets', 'outputs', 'paths'))

  batches = []

  for i in range(10):
    batches.append((
      list(range(10)),
      torch.zeros(10, 12),
      torch.rand(10, 12),
      [f"{i}-{n}.jpg" for n in range(10)]
    ))

    w.batch(*batches[-1])

  w.close()

  assert path.exists()
  # assert path.stat().st_size > 400

  assert w.meta_path.exists()

  assert w.path == path

  loaded_batches = yann.load(w.path)

  assert len(loaded_batches) == 10


def test_use_case(tmpdir):
  model = torch.nn.Identity()

  w = BatchWriter(tmpdir / 'MNIST-preds.pkl')

  iw = BatchWriter(tmpdir / 'inputs.pkl')

  for inputs, targets in iw.through(yann.batches('MNIST', size=32, workers=10, transform=())):
    preds = model(inputs)
    w.batch(
      targets=targets,
      preds=preds
    )
  w.close()

  processed = 0
  correct = 0
  r = BatchReader(w.path)
  for batch in r.batches():
    processed += len(batch['targets'])
    correct += sum(batch['targets'] == batch['preds'])
16 changes: 16 additions & 0 deletions tests/data/storage/test_parquet.py
@@ -0,0 +1,16 @@
from yann.data.storage.parquet import BatchParquetFileWriter

import torch


def test_parquet_batch_writer(tmpdir):
  path = tmpdir / 'test.parquet'
  with BatchParquetFileWriter(path) as write:
    for i in range(10):
      write.batch(
        ids=list(range(10)),
        labels=torch.ones(10, 12)
      )

  assert path.exists()
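Assuming BatchParquetFileWriter emits a standard Parquet file, it should be readable back with plain pyarrow; a quick check along these lines (not part of the diff) could round out the test:

import pyarrow.parquet as pq

table = pq.read_table(str(path))  # path from the test above
print(table.num_rows, table.column_names)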
59 changes: 40 additions & 19 deletions yann/data/io/__init__.py
@@ -6,41 +6,57 @@
 import csv
 import gzip
 from pathlib import Path
 
 import torch
 
 
 class Loader:
   """
 
   gs://bucket/file.th
   ./foo/**/*.jpg
 
   Args:
     path:
 
   Returns:
   """
-  def __call__(self, path, **kwargs):
+  def __call__(self, path, format=None, deserialize=None, filesystem=None, **kwargs):
     path = Path(path)
-    if hasattr(self, path.suffix):
-      return getattr(self, path.suffix)(**kwargs)
+    format = format or path.suffix[1:]
+    if hasattr(self, format):
+      return getattr(self, format)(str(path), **kwargs)
+    raise ValueError(f'File format not supported ({format})')
 
   def th(self, path, **kwargs):
     return torch.load(path, **kwargs)
 
-  def csv(self):
-    pass
+  def json(self, path, **kwargs):
+    return load_json(path, **kwargs)
 
-  def json(self):
-    pass
+  def pickle(self, path, **kwargs):
+    return load_pickle(path, **kwargs)
 
-  def jsonlines(self):
-    pass
+  pkl = pickle
 
 
 load = Loader()
 
 
 class Saver:
-  def __call__(self, x, path, **kwargs):
-    pass
+  def __call__(
+    self, x, path, format=None, serialize=None, filesystem=None, **kwargs
+  ):
+    path = Path(path)
+    format = format or path.suffix[1:]
+    if hasattr(self, format):
+      return getattr(self, format)(x, path, **kwargs)
+    raise ValueError(f'File format not supported ({format})')
+
+  def th(self, x, path, **kwargs):
+    return torch.save(x, path, **kwargs)
+
+  def json(self, x, path, **kwargs):
+    return save_json(x, path, **kwargs)
+
+  def pickle(self, x, path, **kwargs):
+    return save_pickle(x, path, **kwargs)
+
+  pkl = pickle
 
 
 save = Saver()
@@ -95,6 +111,12 @@ def untar(path):
     tar.extractall()
 
 
+def unzip(zip, dest):
+  import zipfile
+  with zipfile.ZipFile(zip, 'r') as f:
+    f.extractall(dest)
+
+
 def iter_csv(path, header=True, tuples=True, sep=',', quote='"', **kwargs):
   with open(path) as f:
     reader = csv.reader(f, delimiter=sep, quotechar=quote, **kwargs)
@@ -122,4 +144,3 @@ def write_csv(data, path, header=None):
       writer.writerow(header)
     for row in data:
       writer.writerow(row)
-
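For reference, the reworked Loader and Saver now dispatch on the file suffix unless an explicit format is passed; a usage sketch (file names illustrative):

from yann.data.io import load, save

save({'accuracy': 0.97}, '/tmp/metrics.json')  # dispatches on the .json suffix
metrics = load('/tmp/metrics.json')

blob = load('/tmp/export.bin', format='pickle')  # explicit format overrides the suffix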
183 changes: 183 additions & 0 deletions yann/data/storage/batch_files.py
@@ -0,0 +1,183 @@
from collections import defaultdict
import os.path
from pathlib import Path
import datetime
import yann

import torch
from ...utils import fully_qulified_name, timestr
from ...utils.ids import memorable_id

"""
TODO:
- pickle
- parquet
- numpy
- csv
- json lines
- hdf5
- tfrecords
- lmdb
"""


class BatchWriter:
  def __init__(self, path, encoders=None, names=None, meta=None):
    self.path = path

    if isinstance(encoders, (list, tuple)):
      if not names:
        raise ValueError('Names must be provided if encoders are a tuple')
      if len(encoders) != len(names):
        raise ValueError('names and encoders must be the same length if provided as tuples')
      encoders = dict(zip(names, encoders))

    self.encoders = encoders or {}
    self.names = names
    self.writer = None
    self.buffers = defaultdict(list)

    self.meta = meta or {}
    self.meta['encoders'] = {
      k: {
        'path': fully_qulified_name(v),
        'name': getattr(v, '__name__', None)
      } for k, v in self.encoders.items()
    }
    self.meta['time_created'] = timestr()
    self.meta['write_id'] = memorable_id()
    self.meta['path'] = str(self.path)
    self.save_meta()

  def encode_batch(self, *args, **kwargs):
    if args:
      items = zip(self.names, args)
    else:
      items = kwargs.items()

    data = {}
    for k, v in items:
      if self.encoders and k in self.encoders:
        v = self.encoders[k](v)  # encoders are callables, applied per column
      elif torch.is_tensor(v):
        v = v.detach().cpu()

      data[k] = v
    return data

  def batch(self, *args, **kwargs):
    data = self.encode_batch(*args, **kwargs)
    for k, v in data.items():
      self.buffers[k].append(v)

  def through(self, batches):
    for b in batches:
      if isinstance(b, (tuple, list)):
        self.batch(*b)
      else:
        self.batch(**b)
      yield b

  def all(self, batches):
    for b in batches:
      if isinstance(b, (tuple, list)):
        self.batch(*b)
      else:
        self.batch(**b)

  def collate(self, buffers):
    return buffers

  def flush(self):
    self._write()
    self._wipe_buffers()

  @property
  def meta_path(self) -> Path:
    return Path(self.path).parent / 'writer-meta.json'

  def save_meta(self):
    yann.save(self.meta, self.meta_path)

  def close(self):
    self.flush()
    if self.writer and hasattr(self.writer, 'close'):
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.close()

  def _wipe_buffers(self):
    for k in self.buffers:
      self.buffers[k] = []

  def _write(self):
    Path(self.path).parent.mkdir(parents=True, exist_ok=True)
    collated = self.collate(self.buffers)
    self._save(dict(collated), self.path)

  def _save(self, data, path):
    yann.save(data, path)

  def _num_buffered_batches(self):
    return len(next(iter(self.buffers.values())))


class BatchStreamWriter(BatchWriter):
  pass

class PartitionedBatchWriter(BatchWriter):
  def __init__(self, path, batches_per_file=256, encoders=None, names=None, meta=None):
    super().__init__(path, encoders=encoders, names=names, meta=meta)

    self.part = 0
    self.batches_per_file = batches_per_file

  def batch(self, *args, **kwargs):
    super().batch(*args, **kwargs)
    if self._num_buffered_batches() >= self.batches_per_file:
      self.flush()

  def get_part_path(self, part):
    if callable(self.path):
      return self.path(part=part, batches=self.buffers)
    elif '{' in self.path and '}' in self.path:
      return self.path.format(
        part=part,
        time=datetime.datetime.utcnow()
      )
    else:
      name, ext = os.path.splitext(self.path)
      return f"{name}-{part}{ext}"

  def _write(self):
    path = self.get_part_path(self.part)
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    collated = self.collate(self.buffers)
    self._save(dict(collated), path)
    self.part += 1  # advance so the next flush writes a new part file
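
# Usage sketch (illustrative, not part of this diff): with a '{part}'-templated
# path, every `batches_per_file` buffered batches are flushed to their own file:
#
#   w = PartitionedBatchWriter('preds-{part}.pkl', batches_per_file=2)
#   for _ in range(3):
#     w.batch(ids=list(range(8)), outputs=torch.rand(8, 4))
#   w.close()  # -> preds-0.pkl (2 batches), preds-1.pkl (1 batch)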



class BatchReader:
  def __init__(self, path):
    pass

  def batches(self):
    pass

  def samples(self):
    pass

  def __iter__(self):
    return self.batches()


def writer() -> BatchWriter:
  raise NotImplementedError()


def reader() -> BatchReader:
  raise NotImplementedError()
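BatchReader is stubbed out above. Against the pickle payload the default BatchWriter produces (a dict mapping column names to lists of per-batch values), one possible reading counterpart, offered as a sketch rather than as the PR's intended design:

import yann

class PickleBatchReader:
  def __init__(self, path):
    self.path = path
    self.columns = yann.load(path)  # dict: column name -> list of batch values

  def batches(self):
    names = list(self.columns)
    # zip the per-column lists back into per-batch dicts
    for values in zip(*(self.columns[n] for n in names)):
      yield dict(zip(names, values))

  def __iter__(self):
    return self.batches()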