diff --git a/src/microwink/seg.py b/src/microwink/seg.py index b7c30e5..fbb405b 100644 --- a/src/microwink/seg.py +++ b/src/microwink/seg.py @@ -194,28 +194,33 @@ def postprocess( L = len(x) assert x.shape == (L, C) x = np.c_[ - x[..., :4], # boxes - x[..., 4 : 4 + NUM_CLASSES].max(axis=1), # scores - x[..., 4 : 4 + NUM_CLASSES].argmax(axis=1), # class_ids - x[..., 4 + NUM_CLASSES :], # masks + x[:, :4], # boxes + x[:, 4 : 4 + NUM_CLASSES].max(axis=1), # scores + x[:, 4 : 4 + NUM_CLASSES].argmax(axis=1), # class_ids + x[:, 4 + NUM_CLASSES :], # masks ] assert x.shape == (L, 1 + C) - keep = cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold) + keep = self.nms( + x[:, :4], + x[:, 4], + conf_threshold=conf_threshold, + iou_threshold=iou_threshold, + ) N = len(keep) x = x[keep] assert x.shape == (N, 1 + C) if N == 0: return None - x[..., [0, 1]] -= x[..., [2, 3]] / 2 - x[..., [2, 3]] += x[..., [0, 1]] + x[:, [0, 1]] -= x[:, [2, 3]] / 2 + x[:, [2, 3]] += x[:, [0, 1]] - x[..., :4] -= [pad_w, pad_h, pad_w, pad_h] - x[..., :4] /= ratio + x[:, :4] -= [pad_w, pad_h, pad_w, pad_h] + x[:, :4] /= ratio ih, iw = img_size - x[..., [0, 2]] = x[:, [0, 2]].clip(0, iw) - x[..., [1, 3]] = x[:, [1, 3]].clip(0, ih) + x[:, [0, 2]] = x[:, [0, 2]].clip(0, iw) + x[:, [1, 3]] = x[:, [1, 3]].clip(0, ih) boxes = x[:, :4] masks = self.process_mask(protos, x[:, 6:], boxes, (ih, iw)) @@ -224,6 +229,25 @@ def postprocess( assert dets.shape == (N, 6) return dets, masks + @staticmethod + def nms( + boxes: np.ndarray, + scores: np.ndarray, + *, + conf_threshold: float, + iou_threshold: float, + ) -> list[int]: + N = len(boxes) + assert boxes.shape == (N, 4) + assert scores.shape == (N,) + keep = cv2.dnn.NMSBoxes( + boxes, # type: ignore + scores, # type: ignore + conf_threshold, + iou_threshold, + ) + return list(keep) + def process_mask( self, protos: np.ndarray,