Support get_tokenizer in clip back inf (#236)
* Support get_tokenizer in clip back inf

* remove truncate

* fix

* test fixes
rom1504 authored Jan 13, 2024
1 parent 7a4959d commit 2c56c07
Showing 4 changed files with 20 additions and 10 deletions.
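
This commit threads the tokenizer returned by load_clip through both readers instead of hard-coding clip.tokenize inside reader.py, so tokenization matches whatever model load_clip loaded. Below is a minimal sketch of the kind of callable the readers now expect, assuming open_clip as the backend; the model name and caption are illustrative. open_clip's tokenizer truncates to the context length by default, which is presumably why the explicit truncate=True could be dropped ("remove truncate" above).

    import open_clip

    # get_tokenizer returns a callable mapping a list of strings to a
    # LongTensor of shape (batch_size, context_length)
    tokenizer = open_clip.get_tokenizer("ViT-B-32")  # model name is illustrative

    tokens = tokenizer(["a photo of a cat"])
    print(tokens.shape)  # torch.Size([1, 77]) for CLIP-style models

    # reader.py wraps it per caption as: lambda text: tokenizer([text])[0],
    # producing one (context_length,) tensor per text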
19 changes: 12 additions & 7 deletions clip_retrieval/clip_inference/reader.py
@@ -62,14 +62,14 @@ class ImageDataset(Dataset):
     def __init__(
         self,
         preprocess,
+        tokenizer,
         folder,
         enable_text=True,
         enable_image=True,
         enable_metadata=False,
         input_sampler=lambda a: a,
     ):
         super().__init__()
-        import clip  # pylint: disable=import-outside-toplevel
 
         self.keys, text_files, image_files, metadata_files = folder_to_keys(
             folder, enable_text, enable_image, enable_metadata
@@ -80,7 +80,7 @@ def __init__(
         self.enable_metadata = enable_metadata
         keys_set = set(self.keys)
         if self.enable_text:
-            self.tokenizer = lambda text: clip.tokenize([text], truncate=True)[0]
+            self.tokenizer = lambda text: tokenizer([text])[0]
             self.text_files = {k: v for k, v in text_files.items() if k in keys_set}
         if self.enable_image:
             self.image_files = {k: v for k, v in image_files.items() if k in keys_set}
@@ -125,6 +125,7 @@ def __getitem__(self, ind):
 def create_webdataset(
     urls,
     image_transform,
+    tokenizer,
     enable_text=True,
     enable_image=True,
     image_key="jpg",
@@ -134,15 +135,14 @@ def __init__(
     input_sampler=lambda a: a,
 ):
     """Create a WebDataset reader, it can read a webdataset of image, text and json"""
-    import clip  # pylint: disable=import-outside-toplevel
     import webdataset as wds  # pylint: disable=import-outside-toplevel
 
     urls = input_sampler(urls)
 
     dataset = wds.WebDataset(urls, cache_dir=cache_path, cache_size=10**10, handler=wds.handlers.warn_and_continue)
 
-    def tokenizer(text):
-        return clip.tokenize([text], truncate=True)[0]
+    def _tokenizer(text):
+        return tokenizer([text])[0]
 
     def filter_dataset(item):
         if enable_text and caption_key not in item:
@@ -167,7 +167,7 @@ def preprocess_dataset(item):
         if enable_text:
             text = item[caption_key]
             caption = text.decode("utf-8")
-            tokenized_text = tokenizer(caption)
+            tokenized_text = _tokenizer(caption)
             output["text_tokens"] = tokenized_text
             output["text"] = caption
 
@@ -207,6 +207,7 @@ def __init__(
         self,
         sampler,
         preprocess,
+        tokenizer,
         input_dataset,
         batch_size,
         num_prepro_workers,
@@ -215,7 +216,9 @@ def __init__(
         enable_metadata=False,
     ) -> None:
         super().__init__()
-        dataset = get_image_dataset()(preprocess, input_dataset, enable_text, enable_image, enable_metadata, sampler)
+        dataset = get_image_dataset()(
+            preprocess, tokenizer, input_dataset, enable_text, enable_image, enable_metadata, sampler
+        )
         self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers, "files")
 
     def __iter__(self):
@@ -230,6 +233,7 @@ def __init__(
         self,
         sampler,
         preprocess,
+        tokenizer,
         input_dataset,
         batch_size,
         num_prepro_workers,
@@ -244,6 +248,7 @@ def __init__(
         dataset = create_webdataset(
             input_dataset,
             preprocess,
+            tokenizer,
             enable_text=enable_text,
             enable_image=enable_image,
             image_key=wds_image_key,
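
With the new signatures, callers obtain the tokenizer from load_clip and pass it as the third positional argument. A minimal sketch of the files path, assuming load_clip is importable from clip_retrieval.load_clip; the folder path and identity sampler are illustrative:

    from clip_retrieval.load_clip import load_clip
    from clip_retrieval.clip_inference.reader import FilesReader

    _, preprocess, tokenizer = load_clip(warmup_batch_size=2)  # tokenizer is no longer discarded

    reader = FilesReader(
        lambda a: a,      # identity sampler, keeps every key
        preprocess,
        tokenizer,        # new third positional argument
        "images_folder",  # illustrative input folder
        2,                # batch_size
        2,                # num_prepro_workers
    )

    for batch in reader:
        pass  # each batch carries the preprocessed images and the tokenized captions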
4 changes: 3 additions & 1 deletion clip_retrieval/clip_inference/worker.py
@@ -49,7 +49,7 @@ def worker(
     print(f"dataset is {len(input_dataset)}", flush=True)
 
     def reader_builder(sampler):
-        _, preprocess, _ = load_clip(
+        _, preprocess, tokenizer = load_clip(
             clip_model=clip_model,
             use_jit=use_jit,
             warmup_batch_size=batch_size,
@@ -59,6 +59,7 @@ def reader_builder(sampler):
         return FilesReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
@@ -70,6 +71,7 @@ def reader_builder(sampler):
         return WebdatasetReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
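
The webdataset path mirrors reader_builder above; a sketch with an illustrative shard list, leaving the wds_image_key/wds_caption_key keyword arguments at the defaults shown in reader.py:

    from clip_retrieval.load_clip import load_clip
    from clip_retrieval.clip_inference.reader import WebdatasetReader

    _, preprocess, tokenizer = load_clip(warmup_batch_size=2)

    reader = WebdatasetReader(
        lambda a: a,     # identity sampler
        preprocess,
        tokenizer,       # forwarded to create_webdataset, which wraps it as _tokenizer
        ["images.tar"],  # illustrative shard list
        2,               # batch_size
        2,               # num_prepro_workers
    )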
4 changes: 3 additions & 1 deletion tests/test_clip_inference/test_reader.py
@@ -17,7 +17,7 @@ def test_reader(file_format):
     input_dataset = [tar_folder + "/image1.tar", tar_folder + "/image2.tar"]
     batch_size = 2
     num_prepro_workers = 2
-    _, preprocess, _ = load_clip(warmup_batch_size=batch_size)
+    _, preprocess, tokenizer = load_clip(warmup_batch_size=batch_size)
 
     output_partition_count = 2
     actual_values = []
@@ -27,6 +27,7 @@
         reader = FilesReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
@@ -38,6 +39,7 @@
         reader = WebdatasetReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
3 changes: 2 additions & 1 deletion tests/test_clip_inference/test_runner.py
@@ -21,10 +21,11 @@ def test_runner():
     with tempfile.TemporaryDirectory() as tmpdir:
 
         def reader_builder(sampler):
-            _, preprocess, _ = load_clip(warmup_batch_size=batch_size)
+            _, preprocess, tokenizer = load_clip(warmup_batch_size=batch_size)
             return FilesReader(
                 sampler,
                 preprocess,
+                tokenizer,
                 folder,
                 batch_size,
                 num_prepro_workers,
