diff --git a/src/microwink/seg.py b/src/microwink/seg.py index d7e3468..fd8b5e7 100644 --- a/src/microwink/seg.py +++ b/src/microwink/seg.py @@ -179,12 +179,12 @@ def postprocess( likely = x[:, 4 : 4 + NUM_CLASSES].max(axis=1) > conf_threshold x = x[likely] - boxes = x[:, :4] scores = x[:, 4 : 4 + NUM_CLASSES].max(axis=1) + boxes = x[:, :4] + boxes = self.postprocess_boxes(boxes, img_size, ratio, pad_w=pad_w, pad_h=pad_h) keep = self.nms( boxes, scores, - conf_threshold=conf_threshold, iou_threshold=iou_threshold, ) N = len(keep) @@ -199,7 +199,6 @@ def postprocess( masks_in = masks_in[keep] ih, iw = img_size - boxes = self.postprocess_boxes(boxes, img_size, ratio, pad_w=pad_w, pad_h=pad_h) masks = self.postprocess_masks(protos, masks_in, boxes, (ih, iw)) assert masks.shape == (N, ih, iw) @@ -293,21 +292,41 @@ def nms( boxes: np.ndarray, scores: np.ndarray, *, - conf_threshold: float, iou_threshold: float, ) -> list[int]: - from cv2.dnn import NMSBoxes - + sorted_indices = np.argsort(scores)[::-1] N = len(boxes) assert boxes.shape == (N, 4) assert scores.shape == (N,) - keep = NMSBoxes( - boxes, # type: ignore - scores, # type: ignore - conf_threshold, - iou_threshold, - ) - return list(keep) + assert sorted_indices.shape == (N,) + + keep_boxes = [] + while sorted_indices.size > 0: + box_id = int(sorted_indices[0]) + ious = SegModel._compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :]) + keep_indices = np.where(ious < iou_threshold)[0] + sorted_indices = sorted_indices[keep_indices + 1] + + keep_boxes.append(box_id) + return keep_boxes + + @staticmethod + def _compute_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray: + assert box.shape == (4,) + assert boxes.shape == (len(boxes), 4) + xmin = np.maximum(box[0], boxes[:, 0]) + ymin = np.maximum(box[1], boxes[:, 1]) + xmax = np.minimum(box[2], boxes[:, 2]) + ymax = np.minimum(box[3], boxes[:, 3]) + + intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin) + + box_area = (box[2] - box[0]) * (box[3] - box[1]) + boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + union_area = box_area + boxes_area - intersection_area + + iou = intersection_area / union_area + return iou @staticmethod def with_border( @@ -318,11 +337,11 @@ def with_border( right: int, color: tuple[int, int, int], ) -> np.ndarray: - from cv2 import BORDER_CONSTANT, copyMakeBorder + import cv2 assert img.ndim == 3 - return copyMakeBorder( - img, top, bottom, left, right, BORDER_CONSTANT, value=color + return cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color )