improve: support new open clip models

samedii committed Sep 15, 2022
1 parent 5599547 commit 0ca3146

Showing 4 changed files with 69 additions and 26 deletions.
30 changes: 19 additions & 11 deletions perceptor/losses/open_clip.py
@@ -5,30 +5,38 @@


class OpenCLIP(LossInterface):
def __init__(self, architecture="ViT-B-32", weights="laion2b_e16"):
def __init__(
self,
architecture="ViT-H-14",
weights="laion2b_s32b_b79k",
):
"""
Args:
architecture (str): name of the clip model
weights (str): name of the weights
Available weight/model combinations are (in order of relevance):
- ("ViT-B-32", "laion2b_e16") (65.62%)
- ("ViT-B-16-plus-240", "laion400m_e32") (69.21%)
- ("ViT-B-16", "laion400m_e32") (67.07%)
- ("ViT-B-32", "laion400m_e32") (62.96%)
- ("ViT-L-14", "laion400m_e32") (72.77%)
- ("ViT-H-14", "laion2b_s32b_b79k") (78.0%)
- ("ViT-g-14", "laion2b_s12b_b42k") (76.6%)
- ("ViT-L-14", "laion2b_s32b_b82k") (75.3%)
- ("ViT-B-32", "laion2b_s34b_b79k") (66.6%)
- ("ViT-B-16-plus-240", "laion400m_e32") (69.2%)
- ("ViT-B-32", "laion2b_e16") (65.7%)
- ("ViT-B-16", "laion400m_e32") (67.0%)
- ("ViT-B-32", "laion400m_e32") (62.9%)
- ("ViT-L-14", "laion400m_e32") (72.8%)
- ("RN101", "yfcc15m") (34.8%)
- ("RN50", "yfcc15m") (32.7%)
- ("RN50", "cc12m") (36.45%)
- ("RN50-quickgelu", "openai")
- ("RN50-quickgelu", "openai") (59.6%)
- ("RN101-quickgelu", "openai")
- ("RN50x4", "openai")
- ("RN50x16", "openai")
- ("RN50x64", "openai")
- ("ViT-B-32-quickgelu", "openai")
- ("ViT-B-16", "openai")
- ("ViT-L-14", "openai")
- ("ViT-L-14-336", "openai")
- ("ViT-B-32-quickgelu", "openai") (63.3%)
- ("ViT-B-16", "openai") (68.3%)
- ("ViT-L-14", "openai") (75.6%)
- ("ViT-L-14-336", "openai") (76.6%)
"""
super().__init__()
self.architecture = architecture
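For orientation, a minimal usage sketch of the updated loss constructor with one of the combinations listed in the docstring. This is illustrative only: the import path is assumed from the file location (`perceptor/losses/open_clip.py`), and the loss's call signature is not shown in this diff, so only construction and combination lookup are sketched.

```python
# Minimal sketch: pick one of the (architecture, weights) pairs from the docstring.
# Assumption: the class is importable from its file location shown above.
import open_clip
from perceptor.losses.open_clip import OpenCLIP

# open_clip.list_pretrained() enumerates the (architecture, weights) pairs
# shipped with the installed open_clip release.
print(open_clip.list_pretrained()[:5])

# Construct the loss with the new default: ViT-H-14 trained on LAION-2B (78.0%).
loss = OpenCLIP(architecture="ViT-H-14", weights="laion2b_s32b_b79k")
```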
56 changes: 45 additions & 11 deletions perceptor/models/open_clip.py
@@ -10,16 +10,24 @@
@utils.cache
class OpenCLIP(torch.nn.Module):
def __init__(
self, architecture="ViT-B-32", weights="laion2b_e16", precision=None, jit=False
self,
architecture="ViT-L-14",
weights="laion2b_s32b_b82k",
precision=None,
jit=False,
):
"""
Args:
architecture (str): name of the clip model
weights (str): name of the weights
Available weight/model combinations are (in order of relevance):
- ("ViT-B-32", "laion2b_e16") (65.7%)
- ("ViT-H-14", "laion2b_s32b_b79k") (78.0%)
- ("ViT-g-14", "laion2b_s12b_b42k") (76.6%)
- ("ViT-L-14", "laion2b_s32b_b82k") (75.3%)
- ("ViT-B-32", "laion2b_s34b_b79k") (66.6%)
- ("ViT-B-16-plus-240", "laion400m_e32") (69.2%)
- ("ViT-B-32", "laion2b_e16") (65.7%)
- ("ViT-B-16", "laion400m_e32") (67.0%)
- ("ViT-B-32", "laion400m_e32") (62.9%)
- ("ViT-L-14", "laion400m_e32") (72.8%)
@@ -43,9 +51,10 @@ def __init__(
if (architecture, weights) not in open_clip.list_pretrained():
raise ValueError(f"Invalid architecture/weights: {architecture}/{weights}")

pretrained_cfg = open_clip.pretrained.get_pretrained_cfg(architecture, weights)
weights_path = open_clip.pretrained.download_pretrained(
open_clip.pretrained.get_pretrained_url(architecture, weights),
root="models",
pretrained_cfg,
cache_dir="models",
)

# softmax on cpu does not support half precision
@@ -58,6 +67,7 @@
else:
precision = "fp32"

# hack: needed to specify path to weights
if weights == "openai":
self.model = open_clip.load_openai_model(
weights_path, start_device, jit=jit
@@ -73,12 +83,32 @@
jit=jit,
).eval()

# hack: since we specified the weights path instead of the model name the config isn't loaded right
setattr(
self.model.visual,
"image_mean",
pretrained_cfg.get(
"mean",
getattr(self.model.visual, "image_mean", None),
)
or (0.48145466, 0.4578275, 0.40821073),
)
setattr(
self.model.visual,
"image_std",
pretrained_cfg.get(
"std",
getattr(self.model.visual, "image_std", None),
)
or (0.26862954, 0.26130258, 0.27577711),
)

if jit is False:
self.model = self.model.requires_grad_(False)

self.normalize = transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
self.model.visual.image_mean,
self.model.visual.image_std,
)

def to(self, device):
@@ -90,6 +120,13 @@ def to(self, device):
def device(self):
return next(iter(self.parameters())).device

@property
def image_size(self):
if isinstance(self.model.visual.image_size, tuple):
return self.model.visual.image_size
else:
return (self.model.visual.image_size, self.model.visual.image_size)

@torch.cuda.amp.autocast()
def encode_texts(self, text_prompts, normalize=True):
encodings = self.model.encode_text(
@@ -106,10 +143,7 @@ def encode_images(self, images, normalize=True):
self.normalize(
resize(
images.to(self.device),
out_shape=(
self.model.visual.image_size,
self.model.visual.image_size,
),
out_shape=self.image_size,
)
)
)
@@ -126,7 +160,7 @@ def forward(self, _):
def test_open_clip():
import torch

model = OpenCLIP("ViT-B-32", "laion2b_e16")
model = OpenCLIP()

image = torch.randn((1, 3, 256, 256)).requires_grad_()
with torch.enable_grad():
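The model-side changes adapt to the open-clip-torch 2.x pretrained API: instead of looking up a bare weight URL, the code now fetches a pretrained config dict and passes it to `download_pretrained`, and the normalization statistics come from that config with the OpenAI CLIP constants as a fallback. Below is a condensed, illustrative sketch of that resolution step, using only calls that appear in the diff above; the `"models"` cache directory mirrors the code, and the chosen architecture/weights pair is just an example.

```python
# Illustrative sketch of the open-clip-torch 2.x weight resolution used above.
import open_clip

architecture, weights = "ViT-L-14", "laion2b_s32b_b82k"  # example pair from the docstring

# Validate the combination against the pairs shipped with the installed open_clip.
if (architecture, weights) not in open_clip.list_pretrained():
    raise ValueError(f"Invalid architecture/weights: {architecture}/{weights}")

# 2.x takes the pretrained config dict (not a URL) and a cache_dir keyword.
pretrained_cfg = open_clip.pretrained.get_pretrained_cfg(architecture, weights)
weights_path = open_clip.pretrained.download_pretrained(pretrained_cfg, cache_dir="models")

# Normalization statistics: prefer the pretrained config, otherwise fall back to the
# OpenAI CLIP defaults, the same fallback the wrapper applies via setattr above.
image_mean = pretrained_cfg.get("mean") or (0.48145466, 0.4578275, 0.40821073)
image_std = pretrained_cfg.get("std") or (0.26862954, 0.26130258, 0.27577711)

print(weights_path, image_mean, image_std)
```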
5 changes: 3 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "perceptor"
version = "0.5.11"
version = "0.5.12"
description = "Modular image generation library"
authors = ["Richard Löwenström <samedii@gmail.com>", "dribnet"]
readme = "README.md"
@@ -28,10 +28,10 @@ more-itertools = "^8.12.0"
dill = "^0.3.4"
ninja = "^1.10.2"
lpips = "^0.1.4"
open-clip-torch = "^1.3.0"
pytorch-lantern = "^0.12.0"
taming-transformers-rom1504 = "^0.0.6"
diffusers = "^0.2.4"
open-clip-torch = "^2.0.0"

[tool.poetry.dev-dependencies]
ipykernel = "^6.8.0"
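The dependency constraint for open-clip-torch moves from ^1.3.0 to ^2.0.0, which is what the new pretrained-config API above requires. A quick, illustrative way to confirm that an environment has resolved a 2.x release after installing:

```python
# Illustrative check: confirm the resolved open-clip-torch release is 2.x,
# which the new open_clip.pretrained API in this commit needs.
from importlib.metadata import version

print(version("open-clip-torch"))  # expected: a 2.x version after this bump
```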
