Skip to content

Commit

Permalink
improve: better error if clip model does not exist
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Sep 12, 2022
1 parent eefd975 commit 0edc917
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion perceptor/losses/open_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class OpenCLIP(LossInterface):
def __init__(self, architecture="ViT-B-32", weights="laion2b_e16"):
"""
Args:
archicture (str): name of the clip model
architecture (str): name of the clip model
weights (str): name of the weights
Available weight/model combinations are (in order of relevance):
Expand Down
14 changes: 9 additions & 5 deletions perceptor/models/open_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
@utils.cache
class OpenCLIP(torch.nn.Module):
def __init__(
self, archicture="ViT-B-32", weights="laion2b_e16", precision=None, jit=False
self, architecture="ViT-B-32", weights="laion2b_e16", precision=None, jit=False
):
"""
Args:
archicture (str): name of the clip model
architecture (str): name of the clip model
weights (str): name of the weights
Available weight/model combinations are (in order of relevance):
Expand All @@ -37,11 +37,15 @@ def __init__(
- ("ViT-L-14-336", "openai") (76.6%)
"""
super().__init__()
self.archicture = archicture
self.architecture = architecture
self.weights = weights

if (architecture, weights) not in open_clip.list_pretrained():
raise ValueError(f"Invalid architecture/weights: {architecture}/{weights}")

weights_path = open_clip.pretrained.download_pretrained(
open_clip.pretrained.get_pretrained_url(archicture, weights), root="models"
open_clip.pretrained.get_pretrained_url(architecture, weights),
root="models",
)

# softmax on cpu does not support half precision
Expand All @@ -62,7 +66,7 @@ def __init__(
model = model.float()
else:
self.model = open_clip.create_model(
archicture,
architecture,
weights_path,
device=start_device,
precision=precision,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "perceptor"
version = "0.5.9"
version = "0.5.10"
description = "Modular image generation library"
authors = ["Richard Löwenström <samedii@gmail.com>", "dribnet"]
readme = "README.md"
Expand Down

0 comments on commit 0edc917

Please sign in to comment.