def main() -> None:
    """Render model detections for every sample image and print the boxes.

    Loads the segmentation model from ``MODEL_PATH``, runs it on each regular
    file found directly inside ``./assets/data/``, draws every detected box
    and mask onto the image, prints the image name followed by the rounded
    boxes, and saves the annotated image into ``SAVE_TO`` under the same
    file name.
    """
    SAVE_TO.mkdir(exist_ok=True)
    seg_model = SegModel.from_path(MODEL_PATH)

    # iterdir() yields entries in arbitrary order; sorting keeps the printed
    # output stable between runs. Non-file entries (e.g. sub-directories)
    # cannot be opened as images, so they are skipped.
    img_paths = sorted(p for p in Path("./assets/data/").iterdir() if p.is_file())

    for img_path in img_paths:
        # Close the underlying file as soon as the RGB copy exists.
        with Image.open(img_path) as source:
            img = source.convert("RGB")
        cards = seg_model.apply(img)

        print(img_path.name)
        for card in cards:
            img = draw_box(img, card.box)
            img = draw_mask(img, card.mask > 0.5)
            print(round_box(card.box))
        print()
        img.save(SAVE_TO / img_path.name)
@staticmethod
def _compute_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return the intersection-over-union of ``box`` with each row of ``boxes``.

    Coordinates are read as ``[x_min, y_min, x_max, y_max]``: indices 0/1 are
    treated as the lower corner and indices 2/3 as the upper corner.

    Args:
        box: Array of shape ``(4,)``.
        boxes: Array of shape ``(N, 4)``; ``N`` may be zero.

    Returns:
        Array of shape ``(N,)`` with one IoU value per row of ``boxes``.
        Pairs whose union has no area (both boxes degenerate) yield ``0``
        instead of ``NaN``.
    """
    assert box.shape == (4,)
    assert boxes.shape == (len(boxes), 4)

    # Corners of the overlapping rectangle (may be inverted when disjoint).
    xmin = np.maximum(box[0], boxes[:, 0])
    ymin = np.maximum(box[1], boxes[:, 1])
    xmax = np.minimum(box[2], boxes[:, 2])
    ymax = np.minimum(box[3], boxes[:, 3])

    # Clamp at zero so disjoint boxes contribute no intersection.
    intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

    box_area = (box[2] - box[0]) * (box[3] - box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union_area = box_area + boxes_area - intersection_area

    # A zero union would give 0/0 = NaN (plus a RuntimeWarning), and NaN
    # compares false against any threshold. Divide by a placeholder there
    # and report an IoU of 0 instead.
    has_area = union_area > 0
    safe_union = np.where(has_area, union_area, 1)
    return np.where(has_area, intersection_area / safe_union, 0.0)
Script = str


def parse_readme_code(path: Path, start_tag: str, end_tag: str) -> list[Script]:
    """Extract the fenced code snippets from a text file.

    Everything between an occurrence of ``start_tag`` and the next
    occurrence of ``end_tag`` is treated as one script. If a section has no
    ``end_tag``, the rest of that section is used as the script.

    Args:
        path: File to read; it must exist.
        start_tag: Marker that opens a snippet, e.g. "```python".
        end_tag: Marker that closes a snippet, e.g. "```".

    Returns:
        The snippets in file order, each stripped of surrounding whitespace.
    """
    assert path.exists()
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding.
    text = path.read_text(encoding="utf-8")
    # The text before the first start tag is not part of any snippet.
    _, *sections = text.split(start_tag)
    return [section.partition(end_tag)[0].strip() for section in sections]