Skip to content

Commit

Permalink
support GroundingDINO and segment-anything
Browse files Browse the repository at this point in the history
  • Loading branch information
jordddan committed Apr 17, 2023
1 parent ba55d64 commit 23fbef9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 46 deletions.
43 changes: 0 additions & 43 deletions extensions/grounding_config.py

This file was deleted.

8 changes: 5 additions & 3 deletions visual_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import argparse
import inspect

import tempfile
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
Expand Down Expand Up @@ -911,17 +911,19 @@ def __init__(self, device):
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
self.model_checkpoint_path = os.path.join("checkpoints","groundingdino")
self.model_config_path = os.path.join("checkpoints","grounding_config.py")
self.download_parameters()
self.box_threshold = 0.3
self.text_threshold = 0.25
self.model_config_path = "extensions/grounding_config.py"
self.grounding = (self.load_model()).to(self.device)

def download_parameters(self):
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)

config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
if not os.path.exists(self.model_config_path):
wget.download(config_url,out=self.model_config_path)
def load_image(self,image_path):
# load image
image_pil = Image.open(image_path).convert("RGB") # load image
Expand Down

0 comments on commit 23fbef9

Please sign in to comment.