Merge branch 'master' of github.com:keras-team/keras
fchollet committed Mar 16, 2024
2 parents c1dfba3 + f3a01a7 commit 838c7da
Showing 8 changed files with 216 additions and 87 deletions.
1 change: 1 addition & 0 deletions .github/workflows/actions.yml
@@ -88,6 +88,7 @@ jobs:
env_vars: PYTHON,KERAS_HOME
flags: keras,keras-${{ matrix.backend }}
files: core-coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false

format:
50 changes: 50 additions & 0 deletions keras/trainers/data_adapters/data_adapter_utils.py
@@ -4,6 +4,8 @@
from keras.api_export import keras_export
from keras.utils import tree

NUM_BATCHES_FOR_TENSOR_SPEC = 2


@keras_export("keras.utils.unpack_x_y_sample_weight")
def unpack_x_y_sample_weight(data):
@@ -125,6 +127,54 @@ def class_weight_to_sample_weights(y, class_weight):
return sample_weight


def get_tensor_spec(batches):
"""Return the common tensor spec for a list of batches.
Args:
batches: list of structures of tensors. The structures must be
identical, but the shape at each leaf may be different.
Returns: the common tensor spec for all the batches.
"""
from keras.utils.module_utils import tensorflow as tf

def get_single_tensor_spec(*tensors):
x = tensors[0]
rank = len(x.shape)
if rank < 1:
raise ValueError(
"When passing a dataset to a Keras model, the arrays must "
f"be at least rank 1. Received: {x} of rank {len(x.shape)}."
)
for t in tensors:
if len(t.shape) != rank:
raise ValueError(
"When passing a dataset to a Keras model, the "
"corresponding arrays in each batch must have the same "
f"rank. Received: {x} and {t}"
)
shape = []
# Merge shapes: go through each dimension one by one and keep the
# common values
for dims in zip(*[list(x.shape) for x in tensors]):
dims_set = set(dims)
shape.append(dims_set.pop() if len(dims_set) == 1 else None)
shape[0] = None # batch size may not be static

dtype = backend.standardize_dtype(x.dtype)
if isinstance(x, tf.RaggedTensor):
return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
if (
isinstance(x, tf.SparseTensor)
or is_scipy_sparse(x)
or is_jax_sparse(x)
):
return tf.SparseTensorSpec(shape=shape, dtype=dtype)
else:
return tf.TensorSpec(shape=shape, dtype=dtype)

return tree.map_structure(get_single_tensor_spec, *batches)


def get_jax_iterator(iterable):
from keras.backend.jax.core import convert_to_tensor

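For illustration only (not part of the commit): the merging rule inside the new `get_tensor_spec()` can be sketched with plain Python shape tuples. Dimensions that agree across the peeked batches stay static, dimensions that differ collapse to `None`, and the batch dimension is always relaxed to `None`.

```python
# Minimal sketch of the shape-merging rule above, using plain shape
# tuples instead of tensors. merge_shapes() is a hypothetical stand-in,
# not a Keras API.
def merge_shapes(shapes):
    merged = []
    for dims in zip(*shapes):
        dims_set = set(dims)
        # Keep a dimension only if every batch agrees on it.
        merged.append(dims_set.pop() if len(dims_set) == 1 else None)
    merged[0] = None  # The batch size is not guaranteed to be static.
    return merged

print(merge_shapes([(16, 4), (16, 5)]))  # [None, None]
print(merge_shapes([(16, 4), (16, 4)]))  # [None, 4]
```

This is why peeking `NUM_BATCHES_FOR_TENSOR_SPEC = 2` batches instead of one matters: a single batch cannot reveal which non-batch dimensions vary.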
51 changes: 13 additions & 38 deletions keras/trainers/data_adapters/generator_data_adapter.py
@@ -1,6 +1,5 @@
import itertools

from keras import backend
from keras.trainers.data_adapters import data_adapter_utils
from keras.trainers.data_adapters.data_adapter import DataAdapter
from keras.utils import tree
@@ -10,49 +9,19 @@ class GeneratorDataAdapter(DataAdapter):
"""Adapter for Python generators."""

def __init__(self, generator):
first_batch, generator = peek_and_restore(generator)
first_batches, generator = peek_and_restore(generator)
self.generator = generator
self._first_batch = first_batch
self._first_batches = first_batches
self._output_signature = None
if not isinstance(first_batch, tuple):
if not isinstance(first_batches[0], tuple):
raise ValueError(
"When passing a Python generator to a Keras model, "
"the generator must return a tuple, either "
"(input,) or (inputs, targets) or "
"(inputs, targets, sample_weights). "
f"Received: {first_batch}"
f"Received: {first_batches[0]}"
)

def _set_tf_output_signature(self):
from keras.utils.module_utils import tensorflow as tf

def get_tensor_spec(x):
shape = x.shape
if len(shape) < 1:
raise ValueError(
"When passing a Python generator to a Keras model, "
"the arrays returned by the generator "
"must be at least rank 1. Received: "
f"{x} of rank {len(x.shape)}"
)
shape = list(shape)
shape[0] = None # The batch size is not guaranteed to be static.
dtype = backend.standardize_dtype(x.dtype)
if isinstance(x, tf.RaggedTensor):
return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
if (
isinstance(x, tf.SparseTensor)
or data_adapter_utils.is_scipy_sparse(x)
or data_adapter_utils.is_jax_sparse(x)
):
return tf.SparseTensorSpec(shape=shape, dtype=dtype)
else:
return tf.TensorSpec(shape=shape, dtype=dtype)

self._output_signature = tree.map_structure(
get_tensor_spec, self._first_batch
)

def get_numpy_iterator(self):
return data_adapter_utils.get_numpy_iterator(self.generator)

@@ -85,7 +54,9 @@ def get_tf_iterator():
yield batch

if self._output_signature is None:
self._set_tf_output_signature()
self._output_signature = data_adapter_utils.get_tensor_spec(
self._first_batches
)
ds = tf.data.Dataset.from_generator(
get_tf_iterator,
output_signature=self._output_signature,
@@ -106,5 +77,9 @@ def batch_size(self):


def peek_and_restore(generator):
element = next(generator)
return element, itertools.chain([element], generator)
batches = list(
itertools.islice(
generator, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
)
)
return batches, itertools.chain(batches, generator)
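For illustration only (not part of the commit): the reworked `peek_and_restore()` buffers the first `NUM_BATCHES_FOR_TENSOR_SPEC` batches and chains them back in front of the generator, so the caller still sees every batch exactly once. A minimal sketch with a hypothetical generator:

```python
import itertools

NUM_BATCHES = 2  # stands in for data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC

def make_batches():  # hypothetical generator, for illustration
    for i in range(5):
        yield (i,)

gen = make_batches()
peeked = list(itertools.islice(gen, NUM_BATCHES))
restored = itertools.chain(peeked, gen)

print(peeked)          # [(0,), (1,)]
print(list(restored))  # [(0,), (1,), (2,), (3,), (4,)]  (nothing is lost)
```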
35 changes: 35 additions & 0 deletions keras/trainers/data_adapters/generator_data_adapter_test.py
@@ -101,6 +101,41 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type):
sample_order.append(by[i, 0])
self.assertAllClose(sample_order, list(range(34)))

@parameterized.named_parameters(
named_product(iterator_type=["np", "tf", "jax", "torch"])
)
def test_with_different_shapes(self, iterator_type):
def generator():
yield np.ones([16, 4], "float32"), np.ones([16, 2], "float32")
yield np.ones([16, 5], "float32"), np.ones([16, 2], "float32")
yield np.ones([2, 6], "float32"), np.ones([2, 2], "float32")

adapter = generator_data_adapter.GeneratorDataAdapter(generator())

if iterator_type == "np":
it = adapter.get_numpy_iterator()
elif iterator_type == "tf":
it = adapter.get_tf_dataset()
elif iterator_type == "jax":
it = adapter.get_jax_iterator()
elif iterator_type == "torch":
it = adapter.get_torch_dataloader()

for i, batch in enumerate(it):
self.assertEqual(len(batch), 2)
bx, by = batch
self.assertEqual(bx.dtype, by.dtype)
self.assertContainsExactSubsequence(str(bx.dtype), "float32")
if i == 0:
self.assertEqual(bx.shape, (16, 4))
self.assertEqual(by.shape, (16, 2))
elif i == 1:
self.assertEqual(bx.shape, (16, 5))
self.assertEqual(by.shape, (16, 2))
else:
self.assertEqual(bx.shape, (2, 6))
self.assertEqual(by.shape, (2, 2))

@parameterized.named_parameters(
named_product(
generator_type=["tf", "jax", "scipy"], iterator_type=["tf", "jax"]
34 changes: 9 additions & 25 deletions keras/trainers/data_adapters/py_dataset_adapter.py
@@ -9,11 +9,9 @@

import numpy as np

from keras import backend
from keras.api_export import keras_export
from keras.trainers.data_adapters import data_adapter_utils
from keras.trainers.data_adapters.data_adapter import DataAdapter
from keras.utils import tree


@keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"])
@@ -188,28 +186,6 @@ def __init__(
self.shuffle = shuffle
self._output_signature = None

def _set_tf_output_signature(self):
from keras.utils.module_utils import tensorflow as tf

def get_tensor_spec(x):
shape = x.shape
if len(shape) < 1:
raise ValueError(
"The arrays returned by PyDataset.__getitem__() "
"must be at least rank 1. Received: "
f"{x} of rank {len(x.shape)}"
)
shape = list(shape)
shape[0] = None # The batch size is not guaranteed to be static.
dtype = backend.standardize_dtype(x.dtype)
return tf.TensorSpec(shape=shape, dtype=dtype)

# Grab the first example
batch = self.py_dataset[0]
# Run checks on it and format it
batch = self._standardize_batch(batch)
self._output_signature = tree.map_structure(get_tensor_spec, batch)

def _standardize_batch(self, batch):
if isinstance(batch, dict):
return batch
@@ -287,7 +263,15 @@ def get_tf_dataset(self):
from keras.utils.module_utils import tensorflow as tf

if self._output_signature is None:
self._set_tf_output_signature()
num_samples = min(
data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
len(self.py_dataset),
)
batches = [
self._standardize_batch(self.py_dataset[i])
for i in range(num_samples)
]
self._output_signature = data_adapter_utils.get_tensor_spec(batches)

ds = tf.data.Dataset.from_generator(
self._get_iterator,
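For illustration only (not part of the commit): because a `PyDataset` is indexable, the adapter peeks by index rather than by iteration, clamping the number of peeked batches to the dataset length. A sketch with a hypothetical dataset whose second dimension varies:

```python
NUM_BATCHES_FOR_TENSOR_SPEC = 2  # mirrors the constant in data_adapter_utils

class VariableWidthDataset:  # hypothetical stand-in for a PyDataset
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        # Return shape tuples rather than arrays to keep the sketch short.
        return [(16, 4), (16, 5), (2, 6)][idx]

ds = VariableWidthDataset()
num_samples = min(NUM_BATCHES_FOR_TENSOR_SPEC, len(ds))
batches = [ds[i] for i in range(num_samples)]
print(batches)  # [(16, 4), (16, 5)] -> merged spec would be (None, None)
```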
51 changes: 51 additions & 0 deletions keras/trainers/data_adapters/py_dataset_adapter_test.py
@@ -233,3 +233,54 @@ def test_dict_inputs(self):
self.assertEqual(bx.dtype, by.dtype)
self.assertEqual(tuple(bx.shape), (4, 4))
self.assertEqual(tuple(by.shape), (4, 2))

@parameterized.named_parameters(
named_product(iterator_type=["np", "tf", "jax", "torch"])
)
def test_with_different_shapes(self, iterator_type):

class TestPyDataset(py_dataset_adapter.PyDataset):
def __len__(self):
return 3

def __getitem__(self, idx):
if idx == 0:
return np.ones([16, 4], "float32"), np.ones(
[16, 2], "float32"
)
if idx == 1:
return np.ones([16, 5], "float32"), np.ones(
[16, 2], "float32"
)
else:
return np.ones([2, 6], "float32"), np.ones(
[2, 2], "float32"
)

adapter = py_dataset_adapter.PyDatasetAdapter(
TestPyDataset(), shuffle=False
)

if iterator_type == "np":
it = adapter.get_numpy_iterator()
elif iterator_type == "tf":
it = adapter.get_tf_dataset()
elif iterator_type == "jax":
it = adapter.get_jax_iterator()
elif iterator_type == "torch":
it = adapter.get_torch_dataloader()

for i, batch in enumerate(it):
self.assertEqual(len(batch), 2)
bx, by = batch
self.assertEqual(bx.dtype, by.dtype)
self.assertContainsExactSubsequence(str(bx.dtype), "float32")
if i == 0:
self.assertEqual(bx.shape, (16, 4))
self.assertEqual(by.shape, (16, 2))
elif i == 1:
self.assertEqual(bx.shape, (16, 5))
self.assertEqual(by.shape, (16, 2))
else:
self.assertEqual(bx.shape, (2, 6))
self.assertEqual(by.shape, (2, 2))
38 changes: 14 additions & 24 deletions keras/trainers/data_adapters/torch_data_loader_adapter.py
@@ -1,6 +1,7 @@
import itertools

import numpy as np

from keras import backend
from keras.trainers.data_adapters import data_adapter_utils
from keras.trainers.data_adapters.data_adapter import DataAdapter
from keras.utils import tree
@@ -19,6 +20,7 @@ def __init__(self, dataloader):
)

self._dataloader = dataloader
self._output_signature = None
self._batch_size = dataloader.batch_size
self._num_batches = None
self._partial_batch_size = None
@@ -44,36 +46,24 @@ def get_jax_iterator(self):
def get_tf_dataset(self):
from keras.utils.module_utils import tensorflow as tf

output_signature = self.peek_and_get_tensor_spec()
if self._output_signature is None:
batches = list(
itertools.islice(
self._dataloader,
data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
)
)
self._output_signature = tuple(
data_adapter_utils.get_tensor_spec(batches)
)
return tf.data.Dataset.from_generator(
self.get_numpy_iterator,
output_signature=output_signature,
output_signature=self._output_signature,
)

def get_torch_dataloader(self):
return self._dataloader

def peek_and_get_tensor_spec(self):
from keras.utils.module_utils import tensorflow as tf

batch_data = next(iter(self._dataloader))

def get_tensor_spec(x):
shape = x.shape
if len(shape) < 1:
raise ValueError(
"When passing a Pytorch DataLoader to a Keras model, "
"the arrays returned by the generator "
"must be at least rank 1. Received: "
f"{x} of rank {len(x.shape)}"
)
shape = list(shape)
shape[0] = None # The batch size is not guaranteed to be static.
dtype = backend.standardize_dtype(x.dtype)
return tf.TensorSpec(shape=shape, dtype=dtype)

return tuple(tree.map_structure(get_tensor_spec, batch_data))

@property
def num_batches(self):
return self._num_batches
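For illustration only (not part of the commit): the adapter now peeks the first batches of the `DataLoader` with `itertools.islice` and caches the merged signature, instead of inspecting a single batch in the removed `peek_and_get_tensor_spec()`. A minimal sketch, assuming PyTorch is installed:

```python
import itertools

import torch
from torch.utils.data import DataLoader, TensorDataset

NUM_BATCHES = 2  # stands in for data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC

dataset = TensorDataset(torch.ones(34, 4), torch.ones(34, 2))
dataloader = DataLoader(dataset, batch_size=16)

# islice() consumes a fresh iterator; a DataLoader can be re-iterated
# from the start afterwards, unlike a plain Python generator.
batches = list(itertools.islice(dataloader, NUM_BATCHES))
for bx, by in batches:
    print(bx.shape, by.shape)  # torch.Size([16, 4]) torch.Size([16, 2])
```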
