From 7fd6397c8083e9d4cbc1d15c17678223932bad72 Mon Sep 17 00:00:00 2001 From: MaxTeselkin Date: Mon, 16 Dec 2024 21:11:12 +0400 Subject: [PATCH] made raw mask generation mode to predict the same class --- src/main.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/main.py b/src/main.py index 1e9ce32..589d01c 100644 --- a/src/main.py +++ b/src/main.py @@ -116,7 +116,7 @@ def load_on_device( # build predictor self.predictor = SamPredictor(self.sam) # define class names - self.class_names = ["target_mask"] + self.class_names = ["object_mask"] # list for storing mask colors self.mask_colors = [[255, 0, 0]] # variable for storing image ids from previous inference iterations @@ -191,18 +191,10 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred output_mode=settings["output_mode"], ) masks = mask_generator.generate(input_image) - for i, mask in enumerate(masks): - class_name = "object_" + str(i) - # add new class to model meta if necessary - if not self._model_meta.get_obj_class(class_name): - color = generate_rgb(self.mask_colors) - self.mask_colors.append(color) - self.class_names.append(class_name) - new_class = sly.ObjClass(class_name, sly.Bitmap, color) - self._model_meta = self._model_meta.add_obj_class(new_class) + for mask in masks: # get predicted mask mask = mask["segmentation"] - predictions.append(sly.nn.PredictionMask(class_name=class_name, mask=mask)) + predictions.append(sly.nn.PredictionMask(class_name="object_mask", mask=mask)) elif settings["mode"] == "bbox": # get bbox coordinates if "rectangle" not in settings: @@ -449,7 +441,9 @@ def smart_segmentation(response: Response, request: Request): init_mask = None if init_mask is not None: image_info = api.image.get_info_by_id(image_id) - init_mask = functional.bitmap_to_mask(init_mask, image_info.height, image_info.width) + init_mask = functional.bitmap_to_mask( + init_mask, image_info.height, image_info.width + ) # init_mask = functional.crop_image(crop, init_mask) assert init_mask.shape[:2] == image_np.shape[:2] settings["init_mask"] = init_mask