diff --git a/dlio_benchmark/data_loader/tf_data_loader.py b/dlio_benchmark/data_loader/tf_data_loader.py
index 162e6c1f..695c4523 100644
--- a/dlio_benchmark/data_loader/tf_data_loader.py
+++ b/dlio_benchmark/data_loader/tf_data_loader.py
@@ -95,7 +95,7 @@ def read(self):
         self._dataset = ReaderFactory.get_reader(type=self.format_type,
                                                  dataset_type=self.dataset_type,
                                                  thread_index=-1,
-                                                 epoch_number=0).next()
+                                                 epoch_number=self.epoch_number).next()
 
     @dlp.log
     def next(self):
diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/tf_reader.py
index ce37b925..dc496039 100644
--- a/dlio_benchmark/reader/tf_reader.py
+++ b/dlio_benchmark/reader/tf_reader.py
@@ -81,12 +81,13 @@ def _parse_image(self, serialized):
 
     @dlp.log
     def next(self):
-        logging.debug(f"{utcnow()} Reading {len(self._file_list)} files thread {self.thread_index} rank {self._args.my_rank}")
-        filenames = tf.data.Dataset.list_files(self._file_list, shuffle=True)
+        logging.debug(f"{utcnow()} Reading {self._file_list} files thread {self.thread_index} rank {self._args.my_rank}")
+        filenames = tf.data.Dataset.list_files(self._file_list, shuffle=False)
         # sharding in the file list if we have enought files.
         if (len(self._file_list) >= self._args.comm_size):
             filenames = filenames.shard(num_shards=self._args.comm_size, index=self._args.my_rank)
-
+        logging.debug(f"{utcnow()} shard {filenames} files index {self._args.my_rank} number {self._args.comm_size}")
+
         self._dataset = tf.data.TFRecordDataset(filenames=filenames, buffer_size=self._args.transfer_size,
                                                 num_parallel_reads=self._args.read_threads)
 
diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py
index 8e25b148..a0264553 100644
--- a/dlio_benchmark/utils/config.py
+++ b/dlio_benchmark/utils/config.py
@@ -138,6 +138,7 @@ class ConfigArguments:
     data_loader_class = None
     reader_class = None
     checkpoint_mechanism_class = None
+    native_data_loader = False
 
     def __init__(self):
         """ Virtually private constructor. """
@@ -300,6 +301,13 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
                     logging.info(f"Discovered custom data reader {class_name}")
                     self.reader_class = obj
                     break
+        self.native_data_loader = False
+        if self.data_loader == DataLoaderType.TENSORFLOW:
+            if self.format == FormatType.TFRECORD:
+                self.native_data_loader = True
+        elif self.data_loader == DataLoaderType.NATIVE_DALI:
+            if self.format in [FormatType.JPEG, FormatType.PNG, FormatType.NPY, FormatType.TFRECORD]:
+                self.native_data_loader = True
 
     @dlp.log
     def build_sample_map_iter(self, file_list, total_samples, epoch_number):
@@ -363,24 +371,27 @@ def get_global_map_index(self, file_list, total_samples):
     def reconfigure(self, epoch_number, dataset_type):
         if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
             if self.file_shuffle is not Shuffle.OFF:
-                if self.seed_change_epoch:
-                    np.random.seed(self.seed + epoch_number)
-                else:
-                    np.random.seed(self.seed)
+                if self.file_shuffle is Shuffle.SEED:
+                    if self.seed_change_epoch:
+                        np.random.seed(self.seed + epoch_number)
+                    else:
+                        np.random.seed(self.seed)
                 np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle(
                     self.file_list_eval)
-        if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            if dataset_type is DatasetType.TRAIN:
-                global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
-                                                             epoch_number)
-            else:
-                global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
-            self.file_map = global_file_map[self.my_rank]
-        elif self.data_loader_sampler == DataLoaderSampler.INDEX:
-            if dataset_type is DatasetType.TRAIN:
-                self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
-            else:
-                self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
+        # the code assumes that file and sample shuffling is handled by the native data loader code.
+        if not self.native_data_loader:
+            if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
+                if dataset_type is DatasetType.TRAIN:
+                    global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
+                                                                 epoch_number)
+                else:
+                    global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
+                self.file_map = global_file_map[self.my_rank]
+            elif self.data_loader_sampler == DataLoaderSampler.INDEX:
+                if dataset_type is DatasetType.TRAIN:
+                    self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
+                else:
+                    self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
 
 
 def LoadConfig(args, config):