From c3381dff51c75da9cea99ef3d6bf8c85e9f5e1e1 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Wed, 11 Sep 2024 09:57:51 -0400 Subject: [PATCH] fix: revert back to old style dataset sharding --- flaxdiff/data/online_loader.py | 16 ++++++++-------- setup.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/flaxdiff/data/online_loader.py b/flaxdiff/data/online_loader.py index cc51446..0f85f6d 100644 --- a/flaxdiff/data/online_loader.py +++ b/flaxdiff/data/online_loader.py @@ -86,7 +86,8 @@ def default_feature_extractor(sample): def map_sample( - sample, + url, + caption, image_shape=(256, 256), min_image_shape=(128, 128), timeout=15, @@ -94,11 +95,8 @@ def map_sample( upscale_interpolation=cv2.INTER_CUBIC, downscale_interpolation=cv2.INTER_AREA, image_processor=default_image_processor, - feature_extractor=default_feature_extractor, ): try: - features = feature_extractor(sample) - url, caption = features["url"], features["caption"] # Assuming fetch_single_image is defined elsewhere image = fetch_single_image(url, timeout=timeout, retries=retries) if image is None: @@ -147,8 +145,10 @@ def map_batch( downscale_interpolation=downscale_interpolation, feature_extractor=feature_extractor ) + features = feature_extractor(batch) + url, caption = features["url"], features["caption"] with ThreadPoolExecutor(max_workers=num_threads) as executor: - executor.map(map_sample_fn, batch) + executor.map(map_sample_fn, url, caption) except Exception as e: print(f"Error maping batch", e) traceback.print_exc() @@ -214,9 +214,9 @@ def parallel_image_loader( iteration = 0 while True: # Repeat forever - # shards = [dataset[i*shard_len:(i+1)*shard_len] - # for i in range(num_workers)] - shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)] + shards = [dataset[i*shard_len:(i+1)*shard_len] + for i in range(num_workers)] + # shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)] print(f"mapping {len(shards)} shards") pool.map(map_batch_fn, shards) iteration += 1 diff --git a/setup.py b/setup.py index bb67a6f..201a6a5 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='flaxdiff', packages=find_packages(), - version='0.1.35', + version='0.1.35.1', description='A versatile and easy to understand Diffusion library', long_description=open('README.md').read(), long_description_content_type='text/markdown',