diff --git a/README.md b/README.md
index b25ee6a9..582ea15b 100644
--- a/README.md
+++ b/README.md
@@ -81,11 +81,12 @@ python visual_chatgpt.py --load ImageCaptioning_cpu,Text2Image_cpu
 python visual_chatgpt.py --load "ImageCaptioning_cuda:0,Text2Image_cuda:0"
 
 # Advice for 4 Tesla V100 32GB
-python visual_chatgpt.py --load "ImageCaptioning_cuda:0,ImageEditing_cuda:0,
+python visual_chatgpt.py --load "Text2Box_cuda:0,Segmenting_cuda:0,
+    MaskFormer_cuda:0,Inpainting_cuda:0,ImageCaptioning_cuda:0,
     Text2Image_cuda:1,Image2Canny_cpu,CannyText2Image_cuda:1,
     Image2Depth_cpu,DepthText2Image_cuda:1,VisualQuestionAnswering_cuda:2,
     InstructPix2Pix_cuda:2,Image2Scribble_cpu,ScribbleText2Image_cuda:2,
-    Image2Seg_cpu,SegText2Image_cuda:2,Image2Pose_cpu,PoseText2Image_cuda:2,
+    SegText2Image_cuda:2,Image2Pose_cpu,PoseText2Image_cuda:2,
     Image2Hed_cpu,HedText2Image_cuda:3,Image2Normal_cpu,
     NormalText2Image_cuda:3,Image2Line_cpu,LineText2Image_cuda:3"
 
diff --git a/visual_chatgpt.py b/visual_chatgpt.py
index a58e4f93..0ad8cf09 100644
--- a/visual_chatgpt.py
+++ b/visual_chatgpt.py
@@ -6,7 +6,7 @@ import cv2
 import re
 import uuid
-from PIL import Image, ImageDraw, ImageOps
+from PIL import Image, ImageDraw, ImageOps, ImageFont
 import math
 import numpy as np
 import argparse
 
@@ -26,6 +26,18 @@ from langchain.chains.conversation.memory import ConversationBufferMemory
 from langchain.llms.openai import OpenAI
 
+# Grounding DINO
+import extensions.GroundingDINO.groundingdino.datasets.transforms as T
+from extensions.GroundingDINO.groundingdino.models import build_model
+from extensions.GroundingDINO.groundingdino.util import box_ops
+from extensions.GroundingDINO.groundingdino.util.slconfig import SLConfig
+from extensions.GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+
+# Segment Anything (cv2 and numpy are already imported at the top of this file)
+from extensions.segment_anything.segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+import wget
+
 VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated. 
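Note on the README examples above: each `--load` entry follows the `ClassName_device` convention that `visual_chatgpt.py` turns into the `load_dict` consumed by `ConversationBot` (see the `load_dict` comment in its `__init__` below). A minimal sketch of that parsing, under the assumption that the underscore-separated format shown in the README is split exactly once per entry:

```python
# Minimal sketch (assumption: mirrors the "ClassName_device" parsing implied by
# the README examples and ConversationBot's load_dict comment; not copied
# verbatim from this diff).
load = "Text2Box_cuda:0,Segmenting_cuda:0,ImageCaptioning_cuda:0,Text2Image_cuda:1"
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in load.split(',')}
print(load_dict)
# {'Text2Box': 'cuda:0', 'Segmenting': 'cuda:0', 'ImageCaptioning': 'cuda:0', 'Text2Image': 'cuda:1'}
```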
@@ -224,16 +236,17 @@ def get_new_image_name(org_img_name, func_name="update"):
     new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
     return os.path.join(head, new_file_name)
 
-
-
 
 class MaskFormer:
     def __init__(self, device):
         print(f"Initializing MaskFormer to {device}")
         self.device = device
+        self.revision = 'fp16' if 'cuda' in self.device else None
+        self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32
         self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
         self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
-    def inference(self, image_path, text):
+
+    def get_mask(self, image_path, text):
         threshold = 0.5
         min_area = 0.02
         padding = 20
@@ -255,7 +268,6 @@ def inference(self, image_path, text):
         image_mask = Image.fromarray(visual_mask)
         return image_mask.resize(original_image.size)
 
-
 class ImageEditing:
     def __init__(self, device):
         print(f"Initializing ImageEditing to {device}")
@@ -660,74 +672,6 @@ def inference(self, inputs):
               f"Output Image: {updated_image_path}")
         return updated_image_path
 
-
-class Image2Seg:
-    def __init__(self, device):
-        print("Initializing Image2Seg")
-        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
-        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
-        self.ade_palette = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
-                            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
-                            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
-                            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
-                            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
-                            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
-                            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
-                            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
-                            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
-                            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
-                            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
-                            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
-                            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
-                            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
-                            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
-                            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
-                            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
-                            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
-                            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
-                            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
-                            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
-                            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
-                            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
-                            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
-                            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
-                            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
-                            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
-                            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
-                            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
-                            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
-                            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
-                            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
-                            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
-                            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
-                            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
-                            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
-                            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
-                            [102, 255, 0], [92, 0, 255]]
-
-    @prompts(name="Segmentation On Image",
-             description="useful when you want to detect segmentations of the image. "
-                         "like: segment this image, or generate segmentations on this image, "
-                         "or perform segmentation on this image. "
-                         "The input to this tool should be a string, representing the image_path")
-    def inference(self, inputs):
-        image = Image.open(inputs)
-        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
-        with torch.no_grad():
-            outputs = self.image_segmentor(pixel_values)
-        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
-        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
-        palette = np.array(self.ade_palette)
-        for label, color in enumerate(palette):
-            color_seg[seg == label, :] = color
-        color_seg = color_seg.astype(np.uint8)
-        segmentation = Image.fromarray(color_seg)
-        updated_image_path = get_new_image_name(inputs, func_name="segmentation")
-        segmentation.save(updated_image_path)
-        print(f"\nProcessed Image2Seg, Input Image: {inputs}, Output Pose: {updated_image_path}")
-        return updated_image_path
-
-
 class SegText2Image:
     def __init__(self, device):
         print(f"Initializing SegText2Image to {device}")
@@ -919,6 +863,267 @@ def inference(self, inputs):
         return answer
 
 
+
+class Segmenting:
+    def __init__(self, device):
+        print(f"Initializing Segmenting to {device}")
+        self.device = device
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.model_checkpoint_path = os.path.join("checkpoints", "sam")
+
+        self.download_parameters()
+        self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
+        self.sam_predictor = SamPredictor(self.sam)
+        self.mask_generator = SamAutomaticMaskGenerator(self.sam)
+
+    def download_parameters(self):
+        url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+        if not os.path.exists(self.model_checkpoint_path):
+            wget.download(url, out=self.model_checkpoint_path)
+
+    def show_mask(self, mask, ax, random_color=False):
+        if random_color:
+            color = np.concatenate([np.random.random(3), np.array([1])], axis=0)
+        else:
+            color = np.array([30 / 255, 144 / 255, 255 / 255, 1])
+        h, w = mask.shape[-2:]
+        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+        ax.imshow(mask_image)
+
+    def show_box(self, box, ax, label):
+        x0, y0 = box[0], box[1]
+        w, h = box[2] - box[0], box[3] - box[1]
+        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+        ax.text(x0, y0, label)
+
+
+    def get_mask_with_boxes(self, image_pil, image, boxes_filt):
+
+        size = image_pil.size
+        H, W = size[1], size[0]
+        # boxes arrive as normalized cxcywh; convert to absolute xyxy
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]
+
+        boxes_filt = boxes_filt.cpu()
+        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(self.device)
+
+        masks, _, _ = self.sam_predictor.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=transformed_boxes,
+            multimask_output=False,
+        )
+        return masks
+
+    def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, pred_phrases):
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        self.sam_predictor.set_image(image)
+
+        masks = self.get_mask_with_boxes(image_pil, image, boxes_filt)
+
+        # draw output image
+        plt.figure(figsize=(10, 10))
+        plt.imshow(image)
+        for mask in masks:
+            self.show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+
+        updated_image_path = get_new_image_name(image_path, func_name="segmentation")
+        plt.axis('off')
+        plt.savefig(
+            updated_image_path,
+            bbox_inches="tight", dpi=300, pad_inches=0.0
+        )
+        return updated_image_path
+
+
+    @prompts(name="Segment the Image",
+             description="useful when you want to detect segmentations of the image. "
+                         "like: segment this image, or generate segmentations on this image, "
+                         "or perform segmentation on this image, "
+                         "or segment all the objects in this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference_all(self, image_path):
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        masks = self.mask_generator.generate(image)
+        plt.figure(figsize=(20, 20))
+        plt.imshow(image)
+        if len(masks) == 0:
+            return
+        sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
+        ax = plt.gca()
+        ax.set_autoscale_on(False)
+        for ann in sorted_anns:
+            m = ann['segmentation']
+            # overlay each mask with a random color and the mask as alpha channel
+            img = np.ones((m.shape[0], m.shape[1], 3))
+            color_mask = np.random.random((1, 3)).tolist()[0]
+            for i in range(3):
+                img[:, :, i] = color_mask[i]
+            ax.imshow(np.dstack((img, m)))
+
+        updated_image_path = get_new_image_name(image_path, func_name="segment-image")
+        plt.axis('off')
+        plt.savefig(
+            updated_image_path,
+            bbox_inches="tight", dpi=300, pad_inches=0.0
+        )
+        return updated_image_path
+
+class Text2Box:
+    def __init__(self, device):
+        print(f"Initializing ObjectDetection to {device}")
+        self.device = device
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.model_checkpoint_path = os.path.join("checkpoints", "groundingdino")
+        self.download_parameters()
+        self.box_threshold = 0.3
+        self.text_threshold = 0.25
+        self.model_config_path = "extensions/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
+        self.grounding = (self.load_model()).to(self.device)
+
+    def download_parameters(self):
+        url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
+        if not os.path.exists(self.model_checkpoint_path):
+            wget.download(url, out=self.model_checkpoint_path)
+
+    def load_image(self, image_path):
+        image_pil = Image.open(image_path).convert("RGB")
+
+        transform = T.Compose(
+            [
+                T.RandomResize([512], max_size=1333),
+                T.ToTensor(),
+                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        image, _ = transform(image_pil, None)  # 3, h, w
+        return image_pil, image
+
+    def load_model(self):
+        args = SLConfig.fromfile(self.model_config_path)
+        args.device = self.device
+        model = build_model(args)
+        checkpoint = torch.load(self.model_checkpoint_path, map_location="cpu")
+        load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+        print(load_res)
+        _ = model.eval()
+        return model
+
+    def get_grounding_boxes(self, image, caption, with_logits=True):
+        caption = caption.lower()
+        caption = caption.strip()
+        if not caption.endswith("."):
+            caption = caption + "."
+        image = image.to(self.device)
+        with torch.no_grad():
+            outputs = self.grounding(image[None], captions=[caption])
+        logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+        boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+
+        # filter output
+        logits_filt = logits.clone()
+        boxes_filt = boxes.clone()
+        filt_mask = logits_filt.max(dim=1)[0] > self.box_threshold
+        logits_filt = logits_filt[filt_mask]  # num_filt, 256
+        boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+
+        # get phrase
+        tokenizer = self.grounding.tokenizer
+        tokenized = tokenizer(caption)
+        # build pred
+        pred_phrases = []
+        for logit, box in zip(logits_filt, boxes_filt):
+            pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, tokenized, tokenizer)
+            if with_logits:
+                pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+            else:
+                pred_phrases.append(pred_phrase)
+
+        return boxes_filt, pred_phrases
+
+    def plot_boxes_to_image(self, image_pil, tgt):
+        H, W = tgt["size"]
+        boxes = tgt["boxes"]
+        labels = tgt["labels"]
+        assert len(boxes) == len(labels), "boxes and labels must have same length"
+
+        draw = ImageDraw.Draw(image_pil)
+        mask = Image.new("L", image_pil.size, 0)
+        mask_draw = ImageDraw.Draw(mask)
+
+        # draw boxes and masks
+        for box, label in zip(boxes, labels):
+            # from 0..1 to 0..W, 0..H
+            box = box * torch.Tensor([W, H, W, H])
+            # from xywh to xyxy
+            box[:2] -= box[2:] / 2
+            box[2:] += box[:2]
+            # random color
+            color = tuple(np.random.randint(0, 255, size=3).tolist())
+            # draw
+            x0, y0, x1, y1 = box
+            x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
+
+            draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
+
+            font = ImageFont.load_default()
+            if hasattr(font, "getbbox"):
+                bbox = draw.textbbox((x0, y0), str(label), font)
+            else:
+                w, h = draw.textsize(str(label), font)
+                bbox = (x0, y0, w + x0, y0 + h)
+            draw.rectangle(bbox, fill=color)
+            draw.text((x0, y0), str(label), fill="white")
+
+            mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=2)
+
+        return image_pil, mask
+
+    @prompts(name="Detect the Given Object",
+             description="useful when you only want to detect or find out given objects in the picture. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path, the text description of the object to be found")
+    def inference(self, inputs):
+        image_path, det_prompt = inputs.split(",")
+        print(f"image_path={image_path}, text_prompt={det_prompt}")
+        image_pil, image = self.load_image(image_path)
+
+        boxes_filt, pred_phrases = self.get_grounding_boxes(image, det_prompt)
+
+        size = image_pil.size
+        pred_dict = {
+            "boxes": boxes_filt,
+            "size": [size[1], size[0]],  # H, W
+            "labels": pred_phrases}
+
+        image_with_box = self.plot_boxes_to_image(image_pil, pred_dict)[0]
+
+        updated_image_path = get_new_image_name(image_path, func_name="detect-something")
+        updated_image = image_with_box.resize(size)
+        updated_image.save(updated_image_path)
+        print(
+            f"\nProcessed ObjectDetecting, Input Image: {image_path}, Object to be Detected: {det_prompt}, "
+            f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class Inpainting:
+    def __init__(self, device):
+        self.device = device
+        self.revision = 'fp16' if 'cuda' in self.device else None
+        self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32
+
+        self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
+    def __call__(self, prompt, original_image, mask_image):
+        update_image = self.inpaint(prompt=prompt, image=original_image.resize((512, 512)),
+                                    mask_image=mask_image.resize((512, 512))).images[0]
+        return update_image
+
 class InfinityOutPainting:
     template_model = True # Add this line to show this is a template model.
     def __init__(self, ImageCaptioning, ImageEditing, VisualQuestionAnswering):
@@ -1017,6 +1222,93 @@ def inference(self, inputs):
         return updated_image_path
 
 
+
+class ObjectSegmenting:
+    template_model = True # Add this line to show this is a template model.
+    def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting):
+        # self.llm = OpenAI(temperature=0)
+        self.grounding = Text2Box
+        self.sam = Segmenting
+
+    @prompts(name="Segment the given object",
+             description="useful when you only want to segment certain objects in the picture "
+                         "according to the given text, "
+                         "like: segment the cat, "
+                         "or can you segment an object for me. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path, the text description of the object to be found")
+    def inference(self, inputs):
+        image_path, det_prompt = inputs.split(",")
+        print(f"image_path={image_path}, text_prompt={det_prompt}")
+        image_pil, image = self.grounding.load_image(image_path)
+        boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, det_prompt)
+        updated_image_path = self.sam.segment_image_with_boxes(image_pil, image_path, boxes_filt, pred_phrases)
+        print(
+            f"\nProcessed ObjectSegmenting, Input Image: {image_path}, Object to be Segmented: {det_prompt}, "
+            f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class ImageEditing:
+    template_model = True
+    def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting, Inpainting: Inpainting):
+        print("Initializing ImageEditing")
+        self.sam = Segmenting
+        self.grounding = Text2Box
+        self.inpaint = Inpainting
+
+    def pad_edge(self, mask, padding):
+        # mask: Tensor [H, W]
+        mask = mask.numpy()
+        true_indices = np.argwhere(mask)
+        mask_array = np.zeros_like(mask, dtype=bool)
+        for idx in true_indices:
+            padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
+            mask_array[padded_slice] = True
+        new_mask = (mask_array * 255).astype(np.uint8)
+        return new_mask
+
+    @prompts(name="Remove Something From The Photo",
+             description="useful when you want to remove an object or something from the photo "
+                         "from its description or location. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the object that needs to be removed. ")
+    def inference_remove(self, inputs):
+        image_path, to_be_removed_txt = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        return self.inference_replace_sam(f"{image_path},{to_be_removed_txt},background")
+
+    @prompts(name="Replace Something From The Photo",
+             description="useful when you want to replace an object from the object description or "
+                         "location with another object from its description. "
+                         "The input to this tool should be a comma separated string of three, "
+                         "representing the image_path, the object to be replaced, and the object to replace it with ")
+    def inference_replace_sam(self, inputs):
+        image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
+
+        print(f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}")
+        image_pil, image = self.grounding.load_image(image_path)
+        boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, to_be_replaced_txt)
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        self.sam.sam_predictor.set_image(image)
+        masks = self.sam.get_mask_with_boxes(image_pil, image, boxes_filt)
+        mask = torch.sum(masks, dim=0).unsqueeze(0)
+        mask = torch.where(mask > 0, True, False)
+        mask = mask.squeeze(0).squeeze(0).cpu()  # tensor
+
+        mask = self.pad_edge(mask, padding=20)  # numpy
+        mask_image = Image.fromarray(mask)
+
+        updated_image = self.inpaint(prompt=replace_with_txt, original_image=image_pil,
+                                     mask_image=mask_image)
+        updated_image_path = get_new_image_name(image_path, func_name="replace-something")
+        updated_image = updated_image.resize(image_pil.size)
+        updated_image.save(updated_image_path)
+        print(
+            f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} with {replace_with_txt}, "
+            f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
 class ConversationBot:
     def __init__(self, load_dict):
         # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
@@ -1037,6 +1329,9 @@ def __init__(self, load_dict):
             if template_required_names.issubset(loaded_names):
                 self.models[class_name] = globals()[class_name](
                     **{name: self.models[name] for name in template_required_names})
+
+        print(f"All the Available Functions: {self.models}")
+
         self.tools = []
         for instance in self.models.values():
             for e in dir(instance):
@@ -1045,6 +1340,7 @@ def __init__(self, load_dict):
                     self.tools.append(Tool(name=func.name, description=func.description, func=func))
         self.llm = OpenAI(temperature=0)
         self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
+
     def init_agent(self, lang):
         self.memory.clear() #clear previous history
         if lang=='English':
@@ -1104,6 +1400,8 @@ def run_image(self, image, state, txt, lang):
 
 
 if __name__ == '__main__':
+    if not os.path.exists("checkpoints"):
+        os.mkdir("checkpoints")
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
    args = parser.parse_args()