Ignore file indexing for native data loader. (#215)
* Ignore file indexing for native data loader.

Sample-map building is needed only for data loaders that DLIO constructs itself. Native data loaders expose their own APIs and do their own indexing, so this sampling step can be skipped for them.

* Remove shuffling from tf_reader, as it is already done correctly during reconfigure.
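To make the first point concrete, a minimal sketch of the two paths (illustrative only; names and values here are not from this commit): a native loader such as tf.data indexes, shuffles, and shards its input by itself, whereas a DLIO-built loader walks a sample map computed up front by reconfigure().

    import tensorflow as tf

    # Native path: the framework loader owns indexing and shuffling, so a
    # precomputed DLIO sample map would duplicate work it already does.
    native_ds = tf.data.TFRecordDataset(["train_0.tfrecord", "train_1.tfrecord"])
    native_ds = native_ds.shuffle(buffer_size=1024, seed=42)

    # DLIO path (hypothetical shape): samples are drawn from a map built
    # ahead of time, e.g. {filename: [sample indices owned by this rank]}.
    file_map = {"train_0.tfrecord": [3, 0, 2, 1]}
    for filename, sample_indices in file_map.items():
        for idx in sample_indices:
            pass  # read sample `idx` from `filename`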
hariharan-devarajan authored Aug 13, 2024
1 parent 5aec234 commit 866828c
Showing 3 changed files with 32 additions and 20 deletions.
2 changes: 1 addition & 1 deletion dlio_benchmark/data_loader/tf_data_loader.py
@@ -95,7 +95,7 @@ def read(self):
         self._dataset = ReaderFactory.get_reader(type=self.format_type,
                                                  dataset_type=self.dataset_type,
                                                  thread_index=-1,
-                                                 epoch_number=0).next()
+                                                 epoch_number=self.epoch_number).next()
 
     @dlp.log
     def next(self):
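The one-line fix above matters because a reader may derive per-epoch state from the epoch number it is given; with a hard-coded 0, every epoch replays epoch 0's behavior. A self-contained sketch of the pattern (illustrative, not DLIO's reader code):

    import numpy as np

    def epoch_sample_order(num_samples, base_seed, epoch_number):
        # base_seed + epoch_number gives a fresh but reproducible order per
        # epoch; a hard-coded epoch_number=0 would freeze the order forever.
        rng = np.random.default_rng(base_seed + epoch_number)
        return rng.permutation(num_samples)

    print(epoch_sample_order(8, base_seed=42, epoch_number=0))
    print(epoch_sample_order(8, base_seed=42, epoch_number=1))  # differs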
7 changes: 4 additions & 3 deletions dlio_benchmark/reader/tf_reader.py
@@ -81,12 +81,13 @@ def _parse_image(self, serialized):
 
     @dlp.log
     def next(self):
-        logging.debug(f"{utcnow()} Reading {len(self._file_list)} files thread {self.thread_index} rank {self._args.my_rank}")
-        filenames = tf.data.Dataset.list_files(self._file_list, shuffle=True)
+        logging.debug(f"{utcnow()} Reading {self._file_list} files thread {self.thread_index} rank {self._args.my_rank}")
+        filenames = tf.data.Dataset.list_files(self._file_list, shuffle=False)
         # sharding in the file list if we have enought files.
         if (len(self._file_list) >= self._args.comm_size):
             filenames = filenames.shard(num_shards=self._args.comm_size, index=self._args.my_rank)
+
+        logging.debug(f"{utcnow()} shard {filenames} files index {self._args.my_rank} number {self._args.comm_size}")
 
         self._dataset = tf.data.TFRecordDataset(filenames=filenames, buffer_size=self._args.transfer_size,
                                                 num_parallel_reads=self._args.read_threads)
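The switch to shuffle=False is what makes the shard() call above it correct: shard(num_shards, index) keeps every num_shards-th element, so all ranks must enumerate the files in the same order or their shards will overlap and drop files. A standalone sketch (from_tensor_slices stands in for list_files so this runs without real TFRecord files):

    import tensorflow as tf

    files = [f"part_{i}.tfrecord" for i in range(8)]
    dataset = tf.data.Dataset.from_tensor_slices(files)  # deterministic order

    # Each rank keeps every num_shards-th file; the shards are disjoint and
    # cover all files exactly once because every rank saw the same order.
    for rank in range(2):
        shard = dataset.shard(num_shards=2, index=rank)
        print(rank, [name.numpy().decode() for name in shard])
    # 0 ['part_0.tfrecord', 'part_2.tfrecord', 'part_4.tfrecord', 'part_6.tfrecord']
    # 1 ['part_1.tfrecord', 'part_3.tfrecord', 'part_5.tfrecord', 'part_7.tfrecord']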
43 changes: 27 additions & 16 deletions dlio_benchmark/utils/config.py
@@ -138,6 +138,7 @@ class ConfigArguments:
     data_loader_class = None
     reader_class = None
     checkpoint_mechanism_class = None
+    native_data_loader = False
 
     def __init__(self):
         """ Virtually private constructor. """
@@ -300,6 +301,13 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
                 logging.info(f"Discovered custom data reader {class_name}")
                 self.reader_class = obj
                 break
+        self.native_data_loader = False
+        if self.data_loader == DataLoaderType.TENSORFLOW:
+            if self.format == FormatType.TFRECORD:
+                self.native_data_loader = True
+        elif self.data_loader == DataLoaderType.NATIVE_DALI:
+            if self.format in [FormatType.JPEG, FormatType.PNG, FormatType.NPY, FormatType.TFRECORD]:
+                self.native_data_loader = True
 
     @dlp.log
     def build_sample_map_iter(self, file_list, total_samples, epoch_number):
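The new flag is a pure function of the (data_loader, format) pair. Restated as a standalone sketch with the enums stubbed as plain strings:

    def is_native_data_loader(data_loader: str, fmt: str) -> bool:
        # Mirrors the derivation above: only combinations whose framework
        # loader performs its own indexing are marked native.
        if data_loader == "tensorflow":
            return fmt == "tfrecord"
        if data_loader == "native_dali":
            return fmt in ("jpeg", "png", "npy", "tfrecord")
        return False

    assert is_native_data_loader("native_dali", "npy")
    assert not is_native_data_loader("tensorflow", "hdf5")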
@@ -363,24 +371,27 @@ def get_global_map_index(self, file_list, total_samples):
     def reconfigure(self, epoch_number, dataset_type):
         if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
             if self.file_shuffle is not Shuffle.OFF:
-                if self.seed_change_epoch:
-                    np.random.seed(self.seed + epoch_number)
-                else:
-                    np.random.seed(self.seed)
+                if self.file_shuffle is Shuffle.SEED:
+                    if self.seed_change_epoch:
+                        np.random.seed(self.seed + epoch_number)
+                    else:
+                        np.random.seed(self.seed)
                 np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle(
                     self.file_list_eval)
-        if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            if dataset_type is DatasetType.TRAIN:
-                global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
-                                                             epoch_number)
-            else:
-                global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
-            self.file_map = global_file_map[self.my_rank]
-        elif self.data_loader_sampler == DataLoaderSampler.INDEX:
-            if dataset_type is DatasetType.TRAIN:
-                self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
-            else:
-                self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
+        # the code assumes that file and sample shuffling is handled by the native data loader code.
+        if not self.native_data_loader:
+            if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
+                if dataset_type is DatasetType.TRAIN:
+                    global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
+                                                                 epoch_number)
+                else:
+                    global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
+                self.file_map = global_file_map[self.my_rank]
+            elif self.data_loader_sampler == DataLoaderSampler.INDEX:
+                if dataset_type is DatasetType.TRAIN:
+                    self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
+                else:
+                    self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
 
 
 def LoadConfig(args, config):
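One property of the Shuffle.SEED path above is worth spelling out: because every rank seeds NumPy identically before the in-place shuffle, all ranks agree on the epoch's file order without any communication. A standalone sketch of that invariant:

    import numpy as np

    def epoch_file_order(file_list, seed, epoch_number, seed_change_epoch=True):
        # Same (seed, epoch) on every rank means the same permutation on every
        # rank; seed_change_epoch reshuffles between epochs, still reproducibly.
        files = list(file_list)
        np.random.seed(seed + epoch_number if seed_change_epoch else seed)
        np.random.shuffle(files)
        return files

    files = [f"img_{i}.npz" for i in range(6)]
    rank0 = epoch_file_order(files, seed=123, epoch_number=1)
    rank1 = epoch_file_order(files, seed=123, epoch_number=1)
    assert rank0 == rank1  # every rank derives the same order independently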
