Adding Native Dali Data Loader support for TFRecord, Images, and NPZ files (#118)

* fixed readthedocs build issue

* partially merged the following PR: #81

* added back npz_reader

* fixed bugs

* fixed bugs

* fixed image reader issue

* fixed Profile, PerfTrace

* removed unnecessary logs

* fixed dali_image_reader

* fixed dali_image_reader

* added support for npy format

* added support for npy format

* changed enumerations

* added back the removed dali base reader

* fixed a bug

* added native-dali-loader tests in github action

* corrected github action formats

* fixed read return

* removed abstractmethod

* fixed bugs

* added dont_use_mmap

* fixed indent

* fixed csvreader

* native_dali test with npy format instead of npz

* fixed enum issue

* modified the action so that DLIO is always installed

* [skip ci] added documentation for dali

* removed read and defined it as a pipeline instead

* added exceptions for unimplemented methods

* added preprocessing

* conditional cache for DLIO installation

* fixed bugs

* fixed bugs

* fixed bugs

* fixing again

* tests again
zhenghh04 authored Dec 15, 2023
1 parent 5dd23af commit 657d4b9
Showing 20 changed files with 515 additions and 47 deletions.
29 changes: 24 additions & 5 deletions .github/workflows/python-package-conda.yml
@@ -47,8 +47,18 @@ jobs:
      - name: Install System Tools
        run: |
          sudo apt update
-         sudo apt-get install $CC $CXX libc6
+         sudo apt-get install $CC $CXX libc6 git
          sudo apt-get install mpich libhwloc-dev
      - name: Install DLIO code only
        if: steps.cache-modules.outputs.cache-hit == 'true'
        run: |
          source ${VENV}/bin/activate
          rm -rf *.egg*
          rm -rf build
          rm -rf dist
          pip uninstall -y dlio_benchmark
          python setup.py build
          python setup.py install
      - name: Install DLIO
        if: steps.cache-modules.outputs.cache-hit != 'true'
        run: |
@@ -57,8 +67,7 @@
          pip install virtualenv
          python -m venv ${VENV}
          source ${VENV}/bin/activate
-         pip install .[test]
-         rm -rf dlio_benchmark
+         pip install .[test]
      - name: Install DLIO Profiler
        run: |
          echo "Profiler ${DLIO_PROFILER} gcc $CC"
@@ -152,8 +161,18 @@ jobs:
      - name: test-tf-loader-npz
        run: |
          source ${VENV}/bin/activate
-         mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=8 ++workload.dataset.num_files_eval=8 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
-         mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=8 ++workload.dataset.num_files_eval=8 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
+         mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
+         mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
      - name: test-torch-native-dali-loader-npy
        run: |
          source ${VENV}/bin/activate
          mpirun -np 2 dlio_benchmark workload=unet3d ++workload.reader.data_loader=native_dali ++workload.dataset.format=npy ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
          mpirun -np 2 dlio_benchmark workload=unet3d ++workload.reader.data_loader=native_dali ++workload.dataset.format=npy ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
      - name: test-tf-native-dali-loader-npy
        run: |
          source ${VENV}/bin/activate
          mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.dataset.format=npy ++workload.reader.data_loader=native_dali ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
          mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.dataset.format=npy ++workload.reader.data_loader=native_dali ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
      - name: test_subset
        run: |
          source ${VENV}/bin/activate
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ dlio_benchmark ++workload.workflow.generate_data=True
git clone https://github.com/argonne-lcf/dlio_benchmark
cd dlio_benchmark/
pip install .[dlio_profiler]
-
+```
## Container

```bash
6 changes: 5 additions & 1 deletion dlio_benchmark/common/enumerations.py
@@ -93,14 +93,15 @@ class FormatType(Enum):
    HDF5 = 'hdf5'
    CSV = 'csv'
    NPZ = 'npz'
    NPY = 'npy'
    HDF5_OPT = 'hdf5_opt'
    JPEG = 'jpeg'
    PNG = 'png'

    def __str__(self):
        return self.value

-   @ staticmethod
+   @staticmethod
    def get_enum(value):
        if FormatType.TFRECORD.value == value:
            return FormatType.TFRECORD
@@ -110,6 +111,8 @@ def get_enum(value):
return FormatType.CSV
elif FormatType.NPZ.value == value:
return FormatType.NPZ
elif FormatType.NPY.value == value:
return FormatType.NPY
elif FormatType.HDF5_OPT.value == value:
return FormatType.HDF5_OPT
elif FormatType.JPEG.value == value:
@@ -124,6 +127,7 @@ class DataLoaderType(Enum):
    TENSORFLOW='tensorflow'
    PYTORCH='pytorch'
    DALI='dali'
    NATIVE_DALI='native_dali'
    CUSTOM='custom'
    NONE='none'

3 changes: 3 additions & 0 deletions dlio_benchmark/data_generator/generator_factory.py
@@ -38,6 +38,9 @@ def get_generator(type):
    elif type == FormatType.NPZ:
        from dlio_benchmark.data_generator.npz_generator import NPZGenerator
        return NPZGenerator()
    elif type == FormatType.NPY:
        from dlio_benchmark.data_generator.npy_generator import NPYGenerator
        return NPYGenerator()
    elif type == FormatType.JPEG:
        from dlio_benchmark.data_generator.jpeg_generator import JPEGGenerator
        return JPEGGenerator()
53 changes: 53 additions & 0 deletions dlio_benchmark/data_generator/npy_generator.py
@@ -0,0 +1,53 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dlio_benchmark.common.enumerations import Compression
from dlio_benchmark.data_generator.data_generator import DataGenerator

import logging
import numpy as np

from dlio_benchmark.utils.utility import progress, utcnow
from dlio_profiler.logger import fn_interceptor as Profile
from shutil import copyfile
from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR

dlp = Profile(MODULE_DATA_GENERATOR)

"""
Generator for creating data in NPZ format.
"""
class NPYGenerator(DataGenerator):
def __init__(self):
super().__init__()

@dlp.log
def generate(self):
"""
Generator for creating data in NPY format of 3d dataset.
"""
super().generate()
np.random.seed(10)
record_labels = [0] * self.num_samples
for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)):
dim1, dim2 = self.get_dimension()
records = np.random.randint(255, size=(dim1, dim2, self.num_samples), dtype=np.uint8)
out_path_spec = self.storage.get_uri(self._file_list[i])
progress(i+1, self.total_files_to_generate, "Generating NPY Data")
prev_out_spec = out_path_spec
np.save(out_path_spec, records)
np.random.seed()
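As a quick check of the generated layout: each `.npy` file stores `num_samples` samples stacked along the last axis. A minimal sketch of inspecting one file (the path below is hypothetical; real names come from the benchmark's file list):

```python
import numpy as np

# Hypothetical path; actual file names are produced by the benchmark.
arr = np.load("data/unet3d/train/img_0_of_8.npy")
print(arr.shape, arr.dtype)  # (dim1, dim2, num_samples) uint8
```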
15 changes: 14 additions & 1 deletion dlio_benchmark/data_generator/tf_generator.py
@@ -14,12 +14,15 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from subprocess import call

from dlio_benchmark.data_generator.data_generator import DataGenerator
import numpy as np
import tensorflow as tf
-from dlio_benchmark.utils.utility import progress, utcnow
from dlio_profiler.logger import fn_interceptor as Profile

+from dlio_benchmark.utils.utility import progress, utcnow
from shutil import copyfile
from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR

@@ -64,4 +67,14 @@ def generate(self):
                serialized = example.SerializeToString()
                # Write the serialized data to the TFRecords file.
                writer.write(serialized)
            tfrecord2idx_script = "tfrecord2idx"
            folder = "train"
            if "valid" in out_path_spec:
                folder = "valid"
            index_folder = f"{self._args.data_folder}/index/{folder}"
            filename = os.path.basename(out_path_spec)
            self.storage.create_node(index_folder, exist_ok=True)
            tfrecord_idx = f"{index_folder}/{filename}.idx"
            if not os.path.isfile(tfrecord_idx):
                call([tfrecord2idx_script, out_path_spec, tfrecord_idx])
        np.random.seed()
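The `tfrecord2idx` step above generates the index files that DALI's TFRecord reader needs for sharding and random access. A minimal sketch of how such indexes are consumed, assuming hypothetical file lists and a single `image` feature (neither is part of this commit):

```python
import nvidia.dali.fn as fn
import nvidia.dali.tfrecord as tfrec

# Assumed inputs: one .idx file per TFRecord file, as generated above.
tfrecord_files = ["data/train/img_0.tfrecord"]
index_files = ["data/index/train/img_0.tfrecord.idx"]

inputs = fn.readers.tfrecord(
    path=tfrecord_files,
    index_path=index_files,
    features={"image": tfrec.FixedLenFeature((), tfrec.string, "")})
```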
3 changes: 3 additions & 0 deletions dlio_benchmark/data_loader/data_loader_factory.py
@@ -45,6 +45,9 @@ def get_loader(type, format_type, dataset_type, epoch):
    elif type == DataLoaderType.DALI:
        from dlio_benchmark.data_loader.dali_data_loader import DaliDataLoader
        return DaliDataLoader(format_type, dataset_type, epoch)
    elif type == DataLoaderType.NATIVE_DALI:
        from dlio_benchmark.data_loader.native_dali_data_loader import NativeDaliDataLoader
        return NativeDaliDataLoader(format_type, dataset_type, epoch)
    else:
        print("Data Loader %s not supported or plugins not found" % type)
        raise Exception(str(ErrorCodes.EC1004))
60 changes: 60 additions & 0 deletions dlio_benchmark/data_loader/native_dali_data_loader.py
@@ -0,0 +1,60 @@
from time import time
import logging
import math
import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali as dali
from nvidia.dali.plugin.pytorch import DALIGenericIterator

from dlio_benchmark.common.constants import MODULE_DATA_LOADER
from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType
from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
from dlio_benchmark.reader.reader_factory import ReaderFactory
from dlio_benchmark.utils.utility import utcnow, get_rank, timeit
from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile

dlp = Profile(MODULE_DATA_LOADER)


class NativeDaliDataLoader(BaseDataLoader):
    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch):
        super().__init__(format_type, dataset_type, epoch, DataLoaderType.NATIVE_DALI)
        self.pipelines = []

    @dlp.log
    def read(self):
        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
        parallel = True if self._args.read_threads > 0 else False
        self.pipelines = []
        num_threads = 1
        if self._args.read_threads > 0:
            num_threads = self._args.read_threads
        # device_id=None executes the pipeline on CPU, and the reader does the batching.
        pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None,
                            py_num_workers=num_threads, exec_async=False, exec_pipelined=False)
        with pipeline:
            images = ReaderFactory.get_reader(type=self.format_type,
                                              dataset_type=self.dataset_type,
                                              thread_index=-1,
                                              epoch_number=self.epoch_number).pipeline()
            pipeline.set_outputs(images)
        self.pipelines.append(pipeline)

    @dlp.log
    def next(self):
        super().next()
        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
        # Create the iterator once; constructing it inside the step loop would
        # restart iteration over the pipelines on every step.
        _dataset = DALIGenericIterator(self.pipelines, ['data'])
        for step in range(num_samples // batch_size):
            for batch in _dataset:
                logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ")
                yield batch

    @dlp.log
    def finalize(self):
        pass
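For context, a sketch of how the benchmark is expected to drive this loader; the surrounding configuration setup is assumed and not shown in this commit:

```python
from dlio_benchmark.common.enumerations import FormatType, DatasetType
from dlio_benchmark.data_loader.native_dali_data_loader import NativeDaliDataLoader

# Assumes the benchmark's argument/configuration state is already initialized.
loader = NativeDaliDataLoader(FormatType.NPY, DatasetType.TRAIN, epoch=1)
loader.read()                 # builds the DALI pipeline(s)
for batch in loader.next():   # generator yielding DALI batches
    pass                      # a training step would consume `batch` here
loader.finalize()
```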
2 changes: 1 addition & 1 deletion dlio_benchmark/reader/csv_reader.py
@@ -57,4 +57,4 @@ def read_index(self, image_idx, step):

    @dlp.log
    def finalize(self):
        return super().finalize()
96 changes: 96 additions & 0 deletions dlio_benchmark/reader/dali_image_reader.py
@@ -0,0 +1,96 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
import logging
from time import time, sleep
import numpy as np

import nvidia.dali.fn as fn
from dlio_benchmark.common.constants import MODULE_DATA_READER
from dlio_benchmark.reader.reader_handler import FormatReader
from dlio_benchmark.utils.utility import utcnow
from dlio_benchmark.common.enumerations import DatasetType, Shuffle
import nvidia.dali.tfrecord as tfrec
from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile

dlp = Profile(MODULE_DATA_READER)


class DaliImageReader(FormatReader):
    @dlp.log_init
    def __init__(self, dataset_type, thread_index, epoch):
        super().__init__(dataset_type, thread_index)

    @dlp.log
    def open(self, filename):
        super().open(filename)

    def close(self):
        super().close()

    def get_sample(self, filename, sample_index):
        super().get_sample(filename, sample_index)
        raise Exception("get_sample method is not implemented in dali readers")

    def next(self):
        super().next()
        raise Exception("next method is not implemented in dali readers")

    def read_index(self):
        super().read_index()
        raise Exception("read_index method is not implemented in dali readers")

    @dlp.log
    def pipeline(self):
        logging.debug(
            f"{utcnow()} Reading {len(self._file_list)} files rank {self._args.my_rank}")
        random_shuffle = False
        seed = -1
        seed_change_epoch = False
        if self._args.sample_shuffle is not Shuffle.OFF:
            if self._args.sample_shuffle is not Shuffle.SEED:
                seed = self._args.seed
            random_shuffle = True
            seed_change_epoch = True
        initial_fill = 1024
        if self._args.shuffle_size > 0:
            initial_fill = self._args.shuffle_size
        prefetch_size = 1
        if self._args.prefetch_size > 0:
            prefetch_size = self._args.prefetch_size
        # Sticking to a shard is incompatible with reshuffling across epochs.
        stick_to_shard = True
        if seed_change_epoch:
            stick_to_shard = False
        images, labels = fn.readers.file(files=self._file_list, num_shards=self._args.comm_size,
                                         prefetch_queue_depth=prefetch_size,
                                         initial_fill=initial_fill, random_shuffle=random_shuffle,
                                         shuffle_after_epoch=seed_change_epoch,
                                         stick_to_shard=stick_to_shard, pad_last_batch=True,
                                         dont_use_mmap=self._args.dont_use_mmap)
        images = fn.decoders.image(images, device='cpu')
        # Resize first, then register the preprocessing hook; the original
        # order used `dataset` before it was assigned.
        dataset = self._resize(images)
        fn.python_function(dataset, function=self.preprocess, num_outputs=0)
        return dataset

    @dlp.log
    def _resize(self, dataset):
        return fn.resize(dataset, size=[self._args.max_dimension, self._args.max_dimension])

    @dlp.log
    def finalize(self):
        pass
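For reference, a standalone sketch of the CPU-only pipeline pattern this reader builds, with hypothetical file names and a fixed resize; this is an illustration, not the benchmark's code path:

```python
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn

# device_id=None keeps the whole pipeline on CPU, as in the reader above.
pipe = Pipeline(batch_size=2, num_threads=1, device_id=None,
                exec_async=False, exec_pipelined=False)
with pipe:
    # The file reader emits (contents, labels); decode and resize on CPU.
    jpegs, labels = fn.readers.file(files=["imgs/a.jpg", "imgs/b.jpg"])
    images = fn.decoders.image(jpegs, device='cpu')
    pipe.set_outputs(fn.resize(images, size=[64, 64]))
pipe.build()
images_batch, = pipe.run()  # one batch of decoded, resized images
```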