diff --git a/README.md b/README.md
index f593cfb..998cde3 100644
--- a/README.md
+++ b/README.md
@@ -32,8 +32,9 @@ pip3 install autodistill-clip
 
 ```python
 from autodistill_clip import CLIP
+from autodistill.detection import CaptionOntology
 
-# define an ontology to map class names to our GroundingDINO prompt
+# define an ontology to map class names to our CLIP prompt
 # the ontology dictionary has the format {caption: class}
 # where caption is the prompt sent to the base model, and class is the label that will
 # be saved for that caption in the generated annotations
@@ -46,10 +47,14 @@ base_model = CLIP(
         }
     )
 )
+
+results = base_model.predict("./context_images/test.jpg")
+
+print(results)
+
 base_model.label("./context_images", extension=".jpeg")
 ```
 
-
 ## License
 
 The code in this repository is licensed under an [MIT license](LICENSE.md).
diff --git a/autodistill_clip/__init__.py b/autodistill_clip/__init__.py
index 8e531db..cdf81ee 100644
--- a/autodistill_clip/__init__.py
+++ b/autodistill_clip/__init__.py
@@ -1,3 +1,3 @@
 from autodistill_clip.clip_model import CLIP
 
-__version__ = "0.1.4"
+__version__ = "0.1.5"
diff --git a/autodistill_clip/clip_model.py b/autodistill_clip/clip_model.py
index 365a61a..82522d1 100644
--- a/autodistill_clip/clip_model.py
+++ b/autodistill_clip/clip_model.py
@@ -1,20 +1,16 @@
 import os
+import sys
 from dataclasses import dataclass
+from typing import Union
 
-import sys
-from PIL import Image
 import numpy as np
 import supervision as sv
 import torch
-from autodistill.core.embedding_ontology import (
-    EmbeddingOntology,
-    ONTOLOGY_WITH_EMBEDDINGS,
-    compare_embeddings,
-)
+from autodistill.classification import ClassificationBaseModel
 from autodistill.core.embedding_model import EmbeddingModel
+from autodistill.core.embedding_ontology import EmbeddingOntology, compare_embeddings
 from autodistill.detection import CaptionOntology
-from autodistill.classification import ClassificationBaseModel
-from typing import Union
+from autodistill.helpers import load_image
 
 HOME = os.path.expanduser("~")
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -54,7 +50,8 @@ def __init__(self, ontology: Union[EmbeddingOntology, CaptionOntology]):
         self.ontology_type = self.ontology.__class__.__name__
 
     def embed_image(self, input: str) -> np.ndarray:
-        image = self.clip_preprocess(Image.open(input)).unsqueeze(0).to(DEVICE)
+        image = load_image(input, return_format="PIL")
+        image = self.clip_preprocess(image).unsqueeze(0).to(DEVICE)
 
         with torch.no_grad():
             image_features = self.clip_model.encode_image(image)
@@ -67,7 +64,8 @@ def embed_text(self, input: str) -> np.ndarray:
         )
 
     def predict(self, input: str) -> sv.Classifications:
-        image = self.clip_preprocess(Image.open(input)).unsqueeze(0).to(DEVICE)
+        image = load_image(input, return_format="PIL")
+        image = self.clip_preprocess(image).unsqueeze(0).to(DEVICE)
 
         if isinstance(self.ontology, EmbeddingOntology):
             with torch.no_grad():
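
Note: the `predict()` call added to the README quickstart returns a `supervision` `Classifications` object rather than a printed label. Below is a minimal sketch of how that result might be inspected; the image path and class names reuse the placeholder values from the quickstart, and treating `class_id` as an index into the ontology's class list is an assumption about this model's output rather than documented behavior.

```python
import numpy as np

from autodistill.detection import CaptionOntology
from autodistill_clip import CLIP

# same placeholder ontology as the README quickstart: {caption: label}
ontology = CaptionOntology({"person": "person", "a forklift": "forklift"})
base_model = CLIP(ontology=ontology)

# predict() returns sv.Classifications with parallel class_id / confidence arrays
results = base_model.predict("./context_images/test.jpg")

# assumption: class_id values index into the ontology's class list
top = int(np.argmax(results.confidence))
print(ontology.classes()[results.class_id[top]], float(results.confidence[top]))
```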