Added support for VITMatte (Matte Anything) #42

Open · wants to merge 20 commits into base: main
Empty file added: .gitmodules
8 changes: 8 additions & 0 deletions __init__.py
@@ -5,8 +5,16 @@
'SAMModelLoader (segment anything)': SAMModelLoader,
'GroundingDinoModelLoader (segment anything)': GroundingDinoModelLoader,
'GroundingDinoSAMSegment (segment anything)': GroundingDinoSAMSegment,
'GenerateVITMatte (segment anything)': GenerateVITMatte,
'VITMatteTransformersModelLoader (segment anything)': VITMatteTransformersModelLoader,
'MaskToTrimap (segment anything)': MaskToTrimap,
'TrimapToMask (segment anything)': TrimapToMask,
'InvertMask (segment anything)': InvertMask,
"IsMaskEmpty": IsMaskEmptyNode,
"BoundingBox (segment anything)":BoundingBox,
"MaskToBoundingBox (segment anything)":MaskToBoundingBox,
"BoundingBoxSAMSegment (segment anything)":BoundingBoxSAMSegment,
"GroundingDinoBoundingBoxes (segment anything)":GroundingDinoBoundingBoxes
}

__all__ = ['NODE_CLASS_MAPPINGS']
9 changes: 5 additions & 4 deletions install.py
@@ -1,7 +1,7 @@
import sys
import subprocess

import os
custom_nodes_path = os.path.dirname(os.path.abspath(__file__))

def build_pip_install_cmds(args):
@@ -11,7 +11,8 @@ def build_pip_install_cmds(args):
return [sys.executable, '-m', 'pip', 'install'] + args

def ensure_package():
    if os.environ.get('COMFY_SAM_ENSURE_PACKAGES', None):
        cmds = build_pip_install_cmds(['-r', 'requirements.txt'])
        subprocess.run(cmds, cwd=custom_nodes_path)

ensure_package()
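
With this change, installing requirements.txt is opt-in rather than automatic: the pip install only runs when the COMFY_SAM_ENSURE_PACKAGES environment variable is set. A minimal sketch of opting in from a launcher script (the variable name comes from the diff above; any non-empty value enables it):

    import os
    import subprocess
    import sys

    # Enable the dependency install for this one invocation of install.py.
    env = dict(os.environ, COMFY_SAM_ENSURE_PACKAGES='1')
    subprocess.run([sys.executable, 'install.py'], env=env, check=True)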
272 changes: 270 additions & 2 deletions node.py
@@ -1,13 +1,17 @@
import os
import sys

import cv2


sys.path.append(
os.path.dirname(os.path.abspath(__file__))
)

import copy
import torch
import numpy as np
from PIL import Image, ImageDraw
import logging
from torch.hub import download_url_to_file
from urllib.parse import urlparse
@@ -21,6 +25,7 @@
from local_groundingdino.models import build_model as local_groundingdino_build_model
import glob
import folder_paths
from transformers import VitMatteImageProcessor, VitMatteForImageMatting

logger = logging.getLogger('comfyui_segment_anything')

@@ -367,4 +372,267 @@ def INPUT_TYPES(s):
CATEGORY = "segment_anything"

def main(self, mask):
        return (torch.all(mask == 0).int().item(), )

def tensor2pil(image: torch.Tensor) -> Image.Image:
    # ComfyUI images are float32 tensors in [0, 1]; scale to 8-bit and drop the batch dim.
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))

def pil2tensor(image: Image.Image) -> torch.Tensor:
    # Inverse of tensor2pil: normalize to [0, 1] and restore the leading batch dim.
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
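
These two helpers convert between ComfyUI's image tensors (float32, shape (B, H, W, C), values in [0, 1]) and PIL images. A quick round-trip sketch of the expected shapes:

    # Round-trip sketch: 64x64 RGB image in, (1, 64, 64, 3) tensor out, and back.
    img = Image.new('RGB', (64, 64), 'red')
    t = pil2tensor(img)
    assert tuple(t.shape) == (1, 64, 64, 3)
    back = tensor2pil(t)   # 64x64 RGB PIL image again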


class VITMatteModel:
    def __init__(self, model, processor):
        self.model = model
        self.processor = processor


class VITMatteTransformersModelLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_name": ("STRING",{"default":"hustvl/vitmatte-small-composition-1k"}),

}
}

RETURN_TYPES = ("VIT_MATTE_MODEL",)
FUNCTION = "load_model"

CATEGORY = "segment_anything"

def load_model(self, model_name):
model = VitMatteForImageMatting.from_pretrained(model_name)
processor = VitMatteImageProcessor.from_pretrained(model_name)
vitmatte = VITMatteModel(
model,
processor,
)
return (vitmatte,)

class GenerateVITMatte:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE", {}),
"trimap": ("TRIMAP", {}),
"vit_matte_model": ("VIT_MATTE_MODEL", {}),
}
}

RETURN_TYPES = ("IMAGE","MASK")
FUNCTION = "generate_matte"

CATEGORY = "Matte Anything"

def generate_matte(self, image, trimap, vit_matte_model):
image = tensor2pil(image)
trimap = tensor2pil(trimap).convert("L")

# prepare image + trimap for the model
inputs = vit_matte_model.processor(images=image, trimaps=trimap, return_tensors="pt")

with torch.no_grad():
predictions = vit_matte_model.model(**inputs).alphas


        mask = tensor2pil(predictions).convert('L')
        # The processor pads inputs to a multiple of 32 px; crop the prediction back to the original size.
        mask = mask.crop((0, 0, image.width, image.height))
image.putalpha(mask)
image = pil2tensor(image)
mask = pil2tensor(mask)
return (image,mask,)
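
For reference, a minimal sketch of the same VitMatte inference path outside ComfyUI, assuming image is an RGB PIL image and trimap a single-channel PIL image with 0 = background, 128 = unknown, 255 = foreground (checkpoint name as in the loader's default):

    processor = VitMatteImageProcessor.from_pretrained('hustvl/vitmatte-small-composition-1k')
    model = VitMatteForImageMatting.from_pretrained('hustvl/vitmatte-small-composition-1k')

    inputs = processor(images=image, trimaps=trimap, return_tensors='pt')
    with torch.no_grad():
        alphas = model(**inputs).alphas            # (1, 1, H_pad, W_pad); padded to a multiple of 32
    alpha = alphas[0, 0, :image.height, :image.width]   # crop the padding back off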

class MaskToTrimap:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK", {}),
"erode_kernel_size": ("INT", {"default":10,"min": 1, "step": 1}),
"dilate_kernel_size": ("INT", {"default":10,"min": 1, "step": 1}),
}
}

RETURN_TYPES = ("TRIMAP",)
FUNCTION = "get_trimap"

CATEGORY = "segment_anything"

    def get_trimap(self, mask: torch.Tensor, erode_kernel_size: int, dilate_kernel_size: int):
        mask = mask.squeeze(0).cpu().detach().numpy().astype(np.uint8) * 255
        trimap = self.generate_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32)
        # Re-encode for the tensor pipeline: 0 = background, 0.5 = unknown, 1 = foreground.
        trimap[trimap == 128] = 0.5
        trimap[trimap == 255] = 1
        trimap = torch.from_numpy(trimap).unsqueeze(0)

return (trimap,)

    def generate_trimap(self, mask, erode_kernel_size=10, dilate_kernel_size=10):
        erode_kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8)
        dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8)
        eroded = cv2.erode(mask, erode_kernel, iterations=5)
        dilated = cv2.dilate(mask, dilate_kernel, iterations=5)
        trimap = np.zeros_like(mask)
        trimap[dilated == 255] = 128   # unknown band around the mask
        trimap[eroded == 255] = 255    # confident foreground core
        return trimap
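
For intuition, a toy sketch of what generate_trimap produces: the eroded core stays confident foreground (255), the dilated halo becomes the unknown band (128), and everything else is background (0).

    import numpy as np
    import cv2

    m = np.zeros((20, 20), np.uint8)
    m[5:15, 5:15] = 255                          # square foreground mask
    core = cv2.erode(m, np.ones((3, 3), np.uint8))
    halo = cv2.dilate(m, np.ones((3, 3), np.uint8))
    tri = np.zeros_like(m)
    tri[halo == 255] = 128                       # unknown band
    tri[core == 255] = 255                       # confident foreground
    print(np.unique(tri))                        # -> [  0 128 255]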


class TrimapToMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"trimap": ("TRIMAP", {}),
}
}

RETURN_TYPES = ("MASK",)
FUNCTION = "to_mask"

CATEGORY = "segment_anything"

    def to_mask(self, trimap: torch.Tensor):
        # Pass-through: the trimap tensor already matches the MASK layout (values 0 / 0.5 / 1).
        return (trimap,)
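
Because the conversion is a pass-through, downstream MASK consumers will see the 0.5 unknown band. If a hard mask is wanted instead, a hedged one-liner (not part of this PR):

    hard = (trimap >= 0.5).float()   # count unknown and foreground as mask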

class MaskToBoundingBox:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK", {}),
}
}
CATEGORY = "mask"
FUNCTION = "main"
RETURN_TYPES = ("BOUNDING_BOX",)


    def main(self, mask):
        mask_np = np.clip(255. * mask.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
        # Rows/columns containing any nonzero pixel define the tight bounding box.
        axis = np.where(mask_np != 0)
        rmin = np.min(axis[0])
        rmax = np.max(axis[0])
        cmin = np.min(axis[1])
        cmax = np.max(axis[1])
        # Returned as (left, top, right, bottom).
        return (torch.FloatTensor([cmin, rmin, cmax, rmax]),)
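
Note that np.min/np.max raise a ValueError when the mask is all zeros. A hedged sketch of a guarded variant (the helper name is illustrative, not part of the PR):

    def mask_to_bbox_safe(mask_np: np.ndarray):
        ys, xs = np.nonzero(mask_np)
        if ys.size == 0:
            return None                                   # nothing segmented; caller picks a fallback
        return (xs.min(), ys.min(), xs.max(), ys.max())   # (left, top, right, bottom)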

class BoundingBox:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"left": ("INT", {"default":0,"step":1}),
"top": ("INT", {"default":0,"step":1}),
"right": ("INT", {"default":0,"step":1}),
"bottom": ("INT", {"default":0,"step":1}),

}
}
CATEGORY = "mask"
    FUNCTION = "main"
RETURN_TYPES = ("BOUNDING_BOX",)


    def main(self, left, top, right, bottom):
        return (torch.FloatTensor([left, top, right, bottom]),)


class BoundingBoxSAMSegment:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"sam_model": ('SAM_MODEL', {}),
"image": ('IMAGE', {}),
"bounding_box": ("BOUNDING_BOX", {}),

}
}
CATEGORY = "segment_anything"
FUNCTION = "main"
RETURN_TYPES = ("IMAGE", "MASK",)

def main(self, sam_model, image, bounding_box):
res_images = []
res_masks = []

for item in image:
item = Image.fromarray(
np.clip(255. * item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA')

(images, masks) = sam_segment(
sam_model,
item,
bounding_box
)
res_images.extend(images)
res_masks.extend(masks)

if len(res_images) == 0:
_, height, width, _ = image.size()
empty_mask = torch.zeros((1, height, width), dtype=torch.uint8, device="cpu")
return (empty_mask, empty_mask)
return (torch.cat(res_images, dim=0), torch.cat(res_masks, dim=0))


class GroundingDinoBoundingBoxes:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"grounding_dino_model": ('GROUNDING_DINO_MODEL', {}),
"image": ('IMAGE', {}),
"prompt": ("STRING", {}),
"threshold": ("FLOAT", {
"default": 0.3,
"min": 0,
"max": 1.0,
"step": 0.01
}),
}
}
CATEGORY = "segment_anything"
FUNCTION = "main"
    MULTIPLE_OUTPUTS = True
RETURN_TYPES = ("IMAGE", "MASK")
RETURN_NAMES = ("IMAGE", "MASK")

def main(self, grounding_dino_model, image, prompt, threshold):
res_images = []
res_masks = []
for item in image:
item = Image.fromarray(
np.clip(255. * item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA')
boxes = groundingdino_predict(
grounding_dino_model,
item,
prompt,
threshold
)
            if boxes.shape[0] == 0:
                continue  # no detections for this image; move on to the next one
            for box in boxes:
                # Draw each detected box as a filled white rectangle on a black canvas.
                x0, y0, x1, y1 = (int(v) for v in box)
                mask = Image.new("RGB", item.size, "#000000")
                drawer = ImageDraw.Draw(mask)
                drawer.rectangle(((x0, y0), (x1, y1)), fill="#ffffff")
                res_images.append(pil2tensor(mask))
                res_masks.append(pil2tensor(mask.convert("L")))
if len(res_images) == 0:
_, height, width, _ = image.size()
empty_mask = torch.zeros((1, height, width), dtype=torch.uint8, device="cpu")
return (empty_mask, empty_mask)
return (torch.cat(res_images), torch.cat(res_masks))
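
Each detection currently becomes its own full-frame rectangle mask in the batch. If a single union mask is preferred, a hedged sketch of combining them, assuming at least one detection (not part of this PR):

    # Union of the per-box masks into one MASK tensor (illustrative only).
    union = res_masks[0]
    for m in res_masks[1:]:
        union = torch.maximum(union, m)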
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,4 +1,5 @@
segment_anything
timm
addict
yapf
transformers>=4.36.2