Skip to content

Commit

Permalink
fix: scores dimensionnality
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirkey authored Dec 27, 2023
1 parent eefa054 commit cbace0f
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions bindings/tests/test_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

np.random.seed(42)

SCORES = np.random.random((100, 1))

@pytest.fixture
def scores():
return np.random.random((100,))

@pytest.mark.benchmark(group="giou_distance")
@pytest.mark.parametrize("dtype", supported_dtypes)
Expand Down Expand Up @@ -115,18 +116,18 @@ def test_masks_to_boxes(benchmark):

@pytest.mark.benchmark(group="nms")
@pytest.mark.parametrize("dtype", supported_dtypes)
def test_nms(benchmark, dtype, generate_boxes):
def test_nms(benchmark, dtype, generate_boxes, scores):
boxes = generate_boxes
boxes = boxes.astype(dtype)
benchmark(nms, boxes, SCORES, 0.5, 0.5)
benchmark(nms, boxes, scores, 0.5, 0.5)


@pytest.mark.benchmark(group="nms")
@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32", "int16"])
def test_rtree_nms(benchmark, dtype, generate_boxes):
def test_rtree_nms(benchmark, dtype, generate_boxes, scores):
boxes = generate_boxes
boxes = boxes.astype(dtype)
benchmark(rtree_nms, boxes, SCORES, 0.5, 0.5)
benchmark(rtree_nms, boxes, scores, 0.5, 0.5)


@pytest.mark.benchmark(group="nms_many_boxes")
Expand Down

0 comments on commit cbace0f

Please sign in to comment.