Skip to content

Commit

Permalink
Merge pull request #7 from cospectrum/dev
Browse files (browse the repository at this point in the history)
add typehints and asserts
  • Loading branch information
cospectrum authored Jan 4, 2025
2 parents: 6bfd149 + 985fd26; commit: bebd8d0
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 65 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
[project]
name = "microwink"
version = "0.0.3"
version = "0.0.4"
description = "Lightweight instance segmentation of card IDs"
readme = "README.md"
license = { text = "Apache-2.0" }
Expand Down
171 changes: 109 additions & 62 deletions src/microwink/seg.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,6 @@

H = NewType("H", int)
W = NewType("W", int)
BgrBuf = NewType("BgrBuf", np.ndarray)
RgbBuf = NewType("RgbBuf", np.ndarray)


Expand All @@ -44,11 +43,6 @@ class RawResult:
masks: np.ndarray


def rgb_to_bgr(img: RgbBuf) -> BgrBuf:
bgr = img[..., ::-1]
return BgrBuf(bgr)


class SegModel:
session: ort.InferenceSession
dtype: Dtype
Expand Down Expand Up @@ -79,18 +73,20 @@ def __init__(self, session: ort.InferenceSession) -> None:
self.dtype = np.float16
else:
self.dtype = np.float32
self.model_height, self.model_width = self.input_.shape[-2:]
B, _, H, W = self.input_.shape
assert B == 1, "batching is not supported"
self.model_height = H
self.model_width = W

def apply(
self, image: PILImage, threshold: Threshold = Threshold()
) -> list[SegResult]:
CLASS_ID = 0.0
CLASS_ID = 0

assert image.mode == "RGB"
buf = np.array(image)
img_buf = rgb_to_bgr(RgbBuf(buf))
buf = RgbBuf(np.array(image))

raw = self._forward(img_buf, threshold.confidence, threshold.iou)
raw = self._run(buf, threshold.confidence, threshold.iou)
if raw is None:
return []
assert len(raw.boxes) == len(raw.masks)
Expand All @@ -111,17 +107,18 @@ def apply(
)
return results

def _forward(
self, im0: BgrBuf, conf_threshold: float, iou_threshold: float
def _run(
self, img: RgbBuf, conf_threshold: float, iou_threshold: float
) -> RawResult | None:
NM = 32
assert im0.ndim == 3
blob, ratio, (pad_w, pad_h) = self.preprocess(im0)
ih, iw, _ = img.shape

blob, ratio, (pad_w, pad_h) = self.preprocess(img)
assert blob.ndim == 4
preds = self.session.run(None, {self.input_.name: blob})
out = self.postprocess(
preds,
im0=im0,
img_size=(ih, iw),
ratio=ratio,
pad_w=pad_w,
pad_h=pad_h,
Expand All @@ -132,13 +129,14 @@ def _forward(
if out is None:
return None
boxes, masks = out
assert isinstance(boxes, np.ndarray)
assert isinstance(masks, np.ndarray)
N = len(boxes)
assert boxes.shape == (N, 6)
assert masks.shape == (N, ih, iw), masks.shape
masks = common.sigmoid(masks)
return RawResult(boxes=boxes, masks=masks)

def preprocess(
self, img_buf: BgrBuf
self, img_buf: RgbBuf
) -> tuple[np.ndarray, float, tuple[float, float]]:
BORDER_COLOR = (114, 114, 114)
EPS = 0.1
Expand All @@ -161,79 +159,115 @@ def preprocess(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=BORDER_COLOR
)

img = (1 / 255.0) * np.ascontiguousarray(
np.einsum("HWC->CHW", img)[::-1], # type: ignore
blob = (1 / 255.0) * np.ascontiguousarray(
np.einsum("HWC->CHW", img), # type: ignore
dtype=self.dtype,
)
assert img.ndim == 3
img = img[None]
return img, r, (pad_w, pad_h)
assert blob.ndim == 3
blob = blob[None]
return blob, r, (pad_w, pad_h)

def postprocess(
self,
preds,
im0,
preds: list[np.ndarray],
img_size: tuple[H, W],
ratio: float,
pad_w,
pad_h,
conf_threshold,
iou_threshold,
pad_w: float,
pad_h: float,
conf_threshold: float,
iou_threshold: float,
nm: int,
) -> tuple[np.ndarray, np.ndarray] | None:
x, protos = preds[0], preds[1]
x = np.einsum("bcn->bnc", x)
x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold]
B = 1
NM, MH, MW = (nm, 160, 160)
NUM_CLASSES = 1
C = 4 + NUM_CLASSES + NM

x, protos = preds
assert len(x) == len(protos) == B
protos = protos[0]
x = x[0].T
assert protos.shape == (NM, MH, MW), protos.shape
assert x.shape == (len(x), C)

x = x[x[..., 4 : 4 + NUM_CLASSES].max(axis=1) > conf_threshold]
L = len(x)
assert x.shape == (L, C)
x = np.c_[
x[..., :4],
np.amax(x[..., 4:-nm], axis=-1),
np.argmax(x[..., 4:-nm], axis=-1),
x[..., -nm:],
x[:, :4], # boxes
x[:, 4 : 4 + NUM_CLASSES].max(axis=1), # scores
x[:, 4 : 4 + NUM_CLASSES].argmax(axis=1), # class_ids
x[:, 4 + NUM_CLASSES :], # masks
]
x = x[cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold)]

if len(x) == 0:
assert x.shape == (L, 1 + C)
keep = self.nms(
x[:, :4],
x[:, 4],
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
)
N = len(keep)
x = x[keep]
assert x.shape == (N, 1 + C)
if N == 0:
return None

x[..., [0, 1]] -= x[..., [2, 3]] / 2
x[..., [2, 3]] += x[..., [0, 1]]
x[:, [0, 1]] -= x[:, [2, 3]] / 2
x[:, [2, 3]] += x[:, [0, 1]]

x[..., :4] -= [pad_w, pad_h, pad_w, pad_h]
x[..., :4] /= ratio
x[:, :4] -= [pad_w, pad_h, pad_w, pad_h]
x[:, :4] /= ratio

x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1])
x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0])
ih, iw = img_size
x[:, [0, 2]] = x[:, [0, 2]].clip(0, iw)
x[:, [1, 3]] = x[:, [1, 3]].clip(0, ih)

h, w, _ = im0.shape
masks = self.process_mask(protos[0], x[:, 6:], x[:, :4], (h, w))
return x[..., :6], masks
boxes = x[:, :4]
masks = self.process_mask(protos, x[:, 6:], boxes, (ih, iw))
dets = x[:, :6]
assert masks.shape == (N, ih, iw)
assert dets.shape == (N, 6)
return dets, masks

@staticmethod
def crop_mask(masks: np.ndarray, boxes: np.ndarray) -> np.ndarray:
N, h, w = masks.shape
def nms(
boxes: np.ndarray,
scores: np.ndarray,
*,
conf_threshold: float,
iou_threshold: float,
) -> list[int]:
N = len(boxes)
assert boxes.shape == (N, 4)
x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)
r = np.arange(w, dtype=x1.dtype)[None, None, :]
c = np.arange(h, dtype=x1.dtype)[None, :, None]
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
assert scores.shape == (N,)
keep = cv2.dnn.NMSBoxes(
boxes, # type: ignore
scores, # type: ignore
conf_threshold,
iou_threshold,
)
return list(keep)

def process_mask(
self,
protos: np.ndarray,
masks_in: np.ndarray,
bboxes: np.ndarray,
boxes: np.ndarray,
img_size: tuple[H, W],
) -> np.ndarray:
N = len(masks_in)
nm, mh, mw = protos.shape
assert boxes.shape == (N, 4)
assert masks_in.shape == (N, nm)

masks = np.matmul(masks_in, protos.reshape((nm, -1))).reshape((N, mh, mw))
ih, iw = img_size
c, mh, mw = protos.shape
masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw))
assert masks.shape == (N, mh, mw)
masks = self.scale_mask(np.ascontiguousarray(masks), (ih, iw))
masks = self.scale_masks(np.ascontiguousarray(masks), (ih, iw))
assert masks.shape == (N, ih, iw)
return self.crop_mask(masks, bboxes)
return self.crop_masks(masks, boxes)

@staticmethod
def scale_mask(masks: np.ndarray, img_size: tuple[H, W]) -> np.ndarray:
def scale_masks(masks: np.ndarray, img_size: tuple[H, W]) -> np.ndarray:
EPS = 0.1
ih, iw = img_size
N, mh, mw = masks.shape
Expand All @@ -256,6 +290,19 @@ def scale_mask(masks: np.ndarray, img_size: tuple[H, W]) -> np.ndarray:
masks_out[i] = resized_mask
return masks_out

@staticmethod
def crop_masks(masks: np.ndarray, boxes: np.ndarray) -> np.ndarray:
N, mh, mw = masks.shape
assert boxes.shape == (N, 4)
x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)
r = np.arange(mw, dtype=x1.dtype)[None, None, :]
c = np.arange(mh, dtype=x1.dtype)[None, :, None]
assert r.shape == (1, 1, mw)
assert c.shape == (1, mh, 1)
masks_out = masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
assert masks_out.shape == (N, mh, mw)
return masks_out


def resize(buf: np.ndarray, size: tuple[W, H]) -> np.ndarray:
w, h = size
Expand Down
2 changes: 1 addition & 1 deletion tests/test_props.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,7 @@

@settings(
deadline=15 * 1000,
max_examples=40,
max_examples=20,
)
@given(
img=arb_img((1, 1000), (1, 1000)),
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit bebd8d0

Please sign in to comment.