diff --git a/visual_chatgpt.py b/visual_chatgpt.py
index 0ad8cf09..a7977397 100644
--- a/visual_chatgpt.py
+++ b/visual_chatgpt.py
@@ -238,78 +238,6 @@ def get_new_image_name(org_img_name, func_name="update"):
     new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
     return os.path.join(head, new_file_name)
 
-class MaskFormer:
-    def __init__(self, device):
-        print(f"Initializing MaskFormer to {device}")
-        self.device = device
-        self.revision = 'fp16' if 'cuda' in self.device else None
-        self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32
-        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
-        self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
-
-
-    def get_mask(self, image_path, text):
-        threshold = 0.5
-        min_area = 0.02
-        padding = 20
-        original_image = Image.open(image_path)
-        image = original_image.resize((512, 512))
-        inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-        mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
-        area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
-        if area_ratio < min_area:
-            return None
-        true_indices = np.argwhere(mask)
-        mask_array = np.zeros_like(mask, dtype=bool)
-        for idx in true_indices:
-            padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
-            mask_array[padded_slice] = True
-        visual_mask = (mask_array * 255).astype(np.uint8)
-        image_mask = Image.fromarray(visual_mask)
-        return image_mask.resize(original_image.size)
-
-class ImageEditing:
-    def __init__(self, device):
-        print(f"Initializing ImageEditing to {device}")
-        self.device = device
-        self.mask_former = MaskFormer(device=self.device)
-        self.revision = 'fp16' if 'cuda' in device else None
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
-        self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-            "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
-
-    @prompts(name="Remove Something From The Photo",
-             description="useful when you want to remove and object or something from the photo "
-                         "from its description or location. "
-                         "The input to this tool should be a comma separated string of two, "
-                         "representing the image_path and the object need to be removed. ")
-    def inference_remove(self, inputs):
-        image_path, to_be_removed_txt = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
-        return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
-
-    @prompts(name="Replace Something From The Photo",
-             description="useful when you want to replace an object from the object description or "
-                         "location with another object from its description. "
-                         "The input to this tool should be a comma separated string of three, "
-                         "representing the image_path, the object to be replaced, the object to be replaced with ")
-    def inference_replace(self, inputs):
-        image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
-        original_image = Image.open(image_path)
-        original_size = original_image.size
-        mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
-        updated_image = self.inpaint(prompt=replace_with_txt, image=original_image.resize((512, 512)),
-                                     mask_image=mask_image.resize((512, 512))).images[0]
-        updated_image_path = get_new_image_name(image_path, func_name="replace-something")
-        updated_image = updated_image.resize(original_size)
-        updated_image.save(updated_image_path)
-        print(
-            f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
-            f"Output Image: {updated_image_path}")
-        return updated_image_path
-
-
 class InstructPix2Pix:
     def __init__(self, device):
         print(f"Initializing InstructPix2Pix to {device}")