Skip to content

Commit

Permalink
only copare top 100 classes in image classification
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jan 29, 2025
1 parent 5ede25a commit eedff54
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2867,9 +2867,6 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin):
ORTMODEL_CLASS = ORTModelForImageClassification
TASK = "image-classification"

ATOL = 2e-3 # 0.02 difference in logits
RTOL = 1e-2 # 1% difference in logits

def _get_model_ids(self, model_arch):
model_ids = MODEL_NAMES[model_arch]
if isinstance(model_ids, dict):
Expand Down Expand Up @@ -3040,8 +3037,16 @@ def test_compare_to_io_binding(self, model_arch):
onnx_outputs = onnx_model(**inputs)
io_outputs = io_model(**inputs)

print("shape of logits", io_outputs.logits.shape)

self.assertTrue("logits" in io_outputs)
self.assertIsInstance(io_outputs.logits, torch.Tensor)
self.assertEqual(io_outputs.logits.shape, onnx_outputs.logits.shape)

if io_outputs.logits.shape[1] > 100:
# we compare only the top 100 classes (biggest 100 values in order)
io_outputs.logits = torch.topk(io_outputs.logits, 100, dim=1).values
onnx_outputs.logits = torch.topk(onnx_outputs.logits, 100, dim=1).values

# compare tensor outputs
torch.testing.assert_close(onnx_outputs.logits, io_outputs.logits, atol=self.ATOL, rtol=self.RTOL)
Expand Down

0 comments on commit eedff54

Please sign in to comment.