Integration with Hugging Face CLIPModel #26 #266

Merged · 21 commits · Jan 6, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -169,7 +169,7 @@ clip_inference turn a set of text+image into clip embeddings
* **write_batch_size** Write batch size (default *10**6*)
* **wds_image_key** Key to use for images in webdataset. (default *jpg*)
* **wds_caption_key** Key to use for captions in webdataset. (default *txt*)
* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip).
* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use an [open_clip](https://github.com/mlfoundations/open_clip) model, or as `"hf_clip:patrickjohncyh/fashion-clip"` to use a [Hugging Face](https://huggingface.co/docs/transformers/model_doc/clip) CLIP model.
* **mclip_model** MCLIP model to load (default *sentence-transformers/clip-ViT-B-32-multilingual-v1*)
* **use_mclip** If False it performs the inference using CLIP; MCLIP otherwise (default *False*)
* **use_jit** uses jit for the clip model (default *True*)
@@ -223,7 +223,7 @@ The API is very similar to `clip-retrieval inference` with some minor changes:
* **enable_metadata** Enable metadata processing (default *False*)
* **wds_image_key** Key to use for images in webdataset. (default *jpg*)
* **wds_caption_key** Key to use for captions in webdataset. (default *txt*)
* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use the [open_clip](https://github.com/mlfoundations/open_clip).
* **clip_model** CLIP model to load (default *ViT-B/32*). Specify it as `"open_clip:ViT-B-32-quickgelu"` to use an [open_clip](https://github.com/mlfoundations/open_clip) model, or as `"hf_clip:patrickjohncyh/fashion-clip"` to use a [Hugging Face](https://huggingface.co/docs/transformers/model_doc/clip) CLIP model.
* **mclip_model** MCLIP model to load (default *sentence-transformers/clip-ViT-B-32-multilingual-v1*)
* **use_mclip** If False it performs the inference using CLIP; MCLIP otherwise (default *False*)
* **use_jit** uses jit for the clip model (default *True*)
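
The new `hf_clip:` prefix slots into the existing `clip_model` option. Below is a minimal, hedged sketch of driving the inference entry point with it from Python; the `input_dataset` and `output_folder` parameter names are assumed from the rest of the README, and only the `clip_model` string format is introduced by this PR.

```python
# Hedged sketch: input_dataset / output_folder are assumed parameter names;
# the "hf_clip:" model string is what this PR adds.
from clip_retrieval.clip_inference.main import main

main(
    input_dataset="images/",      # assumption: a local folder of images
    output_folder="embeddings/",  # assumption: where embeddings are written
    input_format="files",
    clip_model="hf_clip:patrickjohncyh/fashion-clip",  # Hugging Face CLIP checkpoint
    enable_text=False,
    batch_size=256,
    num_prepro_workers=4,         # matches the new default in this PR
)
```
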
2 changes: 1 addition & 1 deletion clip_retrieval/clip_inference/main.py
@@ -72,7 +72,7 @@ def main(
input_format="files",
cache_path=None,
batch_size=256,
num_prepro_workers=8,
num_prepro_workers=4,
enable_text=True,
enable_image=True,
enable_metadata=False,
2 changes: 1 addition & 1 deletion clip_retrieval/clip_inference/worker.py
@@ -27,7 +27,7 @@ def worker(
input_format="files",
cache_path=None,
batch_size=256,
num_prepro_workers=8,
num_prepro_workers=4,
enable_text=True,
enable_image=True,
enable_metadata=False,
41 changes: 41 additions & 0 deletions clip_retrieval/load_clip.py
@@ -7,6 +7,33 @@
import time


class HFClipWrapper(nn.Module):
"""
Wrap the Hugging Face CLIPModel
"""

def __init__(self, inner_model, device):
super().__init__()
self.inner_model = inner_model
self.device = torch.device(device=device)
if self.device.type == "cpu":
self.dtype = torch.float32
else:
self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

def encode_image(self, image):
if self.device.type == "cpu":
return self.inner_model.get_image_features(image.squeeze(1))
with autocast(device_type=self.device.type, dtype=self.dtype):
return self.inner_model.get_image_features(image.squeeze(1))

def encode_text(self, text):
if self.device.type == "cpu":
return self.inner_model.get_text_features(text)
with autocast(device_type=self.device.type, dtype=self.dtype):
return self.inner_model.get_text_features(text)


class OpenClipWrapper(nn.Module):
"""
Wrap OpenClip for managing input types
@@ -37,6 +64,17 @@ def forward(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)


def load_hf_clip(clip_model, device="cuda"):
"""load hf clip"""
from transformers import CLIPProcessor, CLIPModel # pylint: disable=import-outside-toplevel

model = CLIPModel.from_pretrained(clip_model)
preprocess = CLIPProcessor.from_pretrained(clip_model).image_processor
model = HFClipWrapper(inner_model=model, device=device)
model.to(device=device)
return model, lambda x: preprocess(x, return_tensors="pt").pixel_values


def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None):
"""load open clip"""

@@ -72,6 +110,9 @@ def load_clip_without_warmup(clip_model, use_jit, device, clip_cache_path):
if clip_model.startswith("open_clip:"):
clip_model = clip_model[len("open_clip:") :]
model, preprocess = load_open_clip(clip_model, use_jit, device, clip_cache_path)
elif clip_model.startswith("hf_clip:"):
clip_model = clip_model[len("hf_clip:") :]
model, preprocess = load_hf_clip(clip_model, device)
else:
model, preprocess = clip.load(clip_model, device=device, jit=use_jit, download_root=clip_cache_path)
return model, preprocess
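
For the library path, here is a short sketch of how the `hf_clip:` prefix is resolved by `load_clip_without_warmup` and how the returned wrapper and preprocess function are used; the image path is illustrative.

```python
import torch
from PIL import Image
from clip_retrieval.load_clip import load_clip_without_warmup

device = "cuda" if torch.cuda.is_available() else "cpu"
# "hf_clip:" routes to load_hf_clip, which wraps transformers' CLIPModel and
# returns the CLIPProcessor's image_processor as the preprocess callable.
model, preprocess = load_clip_without_warmup(
    "hf_clip:patrickjohncyh/fashion-clip", use_jit=False, device=device, clip_cache_path=None
)

image = preprocess(Image.open("example.jpg"))        # pixel_values, shape (1, 3, H, W)
with torch.no_grad():
    features = model.encode_image(image.to(device))  # (1, embedding_dim) image features
print(features.shape)
```
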
3 changes: 2 additions & 1 deletion requirements.txt
@@ -22,5 +22,6 @@ open-clip-torch>=2.0.0,<3.0.0
requests>=2.27.1,<3
aiohttp>=3.8.1,<4
multilingual-clip>=1.0.10,<2
transformers
urllib3<2
scipy<1.9.2
scipy<1.9.2
2 changes: 1 addition & 1 deletion tests/test_clip_inference/test_mapper.py
@@ -6,7 +6,7 @@
from clip_retrieval.clip_inference.mapper import ClipMapper


@pytest.mark.parametrize("model", ["ViT-B/32", "open_clip:ViT-B-32-quickgelu"])
@pytest.mark.parametrize("model", ["ViT-B/32", "open_clip:ViT-B-32-quickgelu", "hf_clip:patrickjohncyh/fashion-clip"])
def test_mapper(model):
os.environ["CUDA_VISIBLE_DEVICES"] = ""

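
The parametrized mapper test now exercises the Hugging Face backend as well. A sketch of selecting just that case with pytest's `-k` filter (the expression simply matches the model id added above; the test forces CPU by clearing `CUDA_VISIBLE_DEVICES`):

```python
import pytest

# Run only the hf_clip parametrization of test_mapper.
pytest.main(["tests/test_clip_inference/test_mapper.py", "-k", "hf_clip"])
```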