Skip to content

Commit

Permalink
Merge pull request #20 from supervisely-ecosystem/raw-mode
Browse files Browse the repository at this point in the history
made raw mask generation mode to predict the same class
  • Loading branch information
MaxTeselkin authored Dec 16, 2024
2 parents 4250ac8 + 7fd6397 commit 8738662
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def load_on_device(
# build predictor
self.predictor = SamPredictor(self.sam)
# define class names
self.class_names = ["target_mask"]
self.class_names = ["object_mask"]
# list for storing mask colors
self.mask_colors = [[255, 0, 0]]
# variable for storing image ids from previous inference iterations
Expand Down Expand Up @@ -191,18 +191,10 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
output_mode=settings["output_mode"],
)
masks = mask_generator.generate(input_image)
for i, mask in enumerate(masks):
class_name = "object_" + str(i)
# add new class to model meta if necessary
if not self._model_meta.get_obj_class(class_name):
color = generate_rgb(self.mask_colors)
self.mask_colors.append(color)
self.class_names.append(class_name)
new_class = sly.ObjClass(class_name, sly.Bitmap, color)
self._model_meta = self._model_meta.add_obj_class(new_class)
for mask in masks:
# get predicted mask
mask = mask["segmentation"]
predictions.append(sly.nn.PredictionMask(class_name=class_name, mask=mask))
predictions.append(sly.nn.PredictionMask(class_name="object_mask", mask=mask))
elif settings["mode"] == "bbox":
# get bbox coordinates
if "rectangle" not in settings:
Expand Down Expand Up @@ -449,7 +441,9 @@ def smart_segmentation(response: Response, request: Request):
init_mask = None
if init_mask is not None:
image_info = api.image.get_info_by_id(image_id)
init_mask = functional.bitmap_to_mask(init_mask, image_info.height, image_info.width)
init_mask = functional.bitmap_to_mask(
init_mask, image_info.height, image_info.width
)
# init_mask = functional.crop_image(crop, init_mask)
assert init_mask.shape[:2] == image_np.shape[:2]
settings["init_mask"] = init_mask
Expand Down

0 comments on commit 8738662

Please sign in to comment.