Skip to content

Commit

Permalink
use load_image primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Dec 5, 2023
1 parent 52d8b8a commit 23c7052
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ pip3 install autodistill-clip

```python
from autodistill_clip import CLIP
from autodistill.detection import CaptionOntology

# define an ontology to map class names to our GroundingDINO prompt
# define an ontology to map class names to our CLIP prompt
# the ontology dictionary has the format {caption: class}
# where caption is the prompt sent to the base model, and class is the label that will
# be saved for that caption in the generated annotations
Expand All @@ -46,10 +47,14 @@ base_model = CLIP(
}
)
)

results = base_model.predict("./context_images/test.jpg")

print(results)

base_model.label("./context_images", extension=".jpeg")
```


## License

The code in this repository is licensed under an [MIT license](LICENSE.md).
Expand Down
2 changes: 1 addition & 1 deletion autodistill_clip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from autodistill_clip.clip_model import CLIP

__version__ = "0.1.4"
__version__ = "0.1.5"
20 changes: 9 additions & 11 deletions autodistill_clip/clip_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import os
import sys
from dataclasses import dataclass
from typing import Union

import sys
from PIL import Image
import numpy as np
import supervision as sv
import torch
from autodistill.core.embedding_ontology import (
EmbeddingOntology,
ONTOLOGY_WITH_EMBEDDINGS,
compare_embeddings,
)
from autodistill.classification import ClassificationBaseModel
from autodistill.core.embedding_model import EmbeddingModel
from autodistill.core.embedding_ontology import EmbeddingOntology, compare_embeddings
from autodistill.detection import CaptionOntology
from autodistill.classification import ClassificationBaseModel
from typing import Union
from autodistill.helpers import load_image

HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -54,7 +50,8 @@ def __init__(self, ontology: Union[EmbeddingOntology, CaptionOntology]):
self.ontology_type = self.ontology.__class__.__name__

def embed_image(self, input: str) -> np.ndarray:
image = self.clip_preprocess(Image.open(input)).unsqueeze(0).to(DEVICE)
image = load_image(input, return_format="PIL")
image = self.clip_preprocess(image).unsqueeze(0).to(DEVICE)

with torch.no_grad():
image_features = self.clip_model.encode_image(image)
Expand All @@ -67,7 +64,8 @@ def embed_text(self, input: str) -> np.ndarray:
)

def predict(self, input: str) -> sv.Classifications:
image = self.clip_preprocess(Image.open(input)).unsqueeze(0).to(DEVICE)
image = load_image(input, return_format="PIL")
image = self.clip_preprocess(image).unsqueeze(0).to(DEVICE)

if isinstance(self.ontology, EmbeddingOntology):
with torch.no_grad():
Expand Down

0 comments on commit 23c7052

Please sign in to comment.