From 59d25b4d64e82c837aa54666b152193892559ec7 Mon Sep 17 00:00:00 2001 From: Manuel Lopez Antequera Date: Fri, 8 Jan 2021 15:17:52 +0100 Subject: [PATCH 1/5] Add new test for accuracy calculator This test covers one-dimensional labels with a custom comparison fn --- tests/utils/test_calculate_accuracies.py | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py index 1bd222a0..f220142e 100644 --- a/tests/utils/test_calculate_accuracies.py +++ b/tests/utils/test_calculate_accuracies.py @@ -142,6 +142,38 @@ def correct_mean_average_precision( else: return np.mean([(acc0 + acc1) / 2, acc2, acc3, acc4]) + def test_get_lone_query_labels_custom(self): + def comparison_fn(x, y): + return abs(x - y) < 2 + + query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100]) + + label_counts, num_k = accuracy_calculator.get_label_match_counts( + query_labels, + query_labels, + comparison_fn, + ) + unique_labels, counts = label_counts + + correct_unique_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100]) + correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1]) + self.assertTrue(np.all(unique_labels == correct_unique_labels)) + self.assertTrue(np.all(counts == correct_counts)) + + ( + lone_query_labels, + not_lone_query_mask, + ) = accuracy_calculator.get_lone_query_labels( + query_labels, label_counts, True, comparison_fn + ) + + correct_lone_query_labels = np.array([100]) + correct_not_lone_query_mask = np.array( + [True, True, True, True, True, True, True, True, True, True, True, False] + ) + self.assertTrue(np.all(lone_query_labels == correct_lone_query_labels)) + self.assertTrue(np.all(not_lone_query_mask == correct_not_lone_query_mask)) + def test_get_lone_query_labels_multi_dim(self): def custom_label_comparison_fn(x, y): return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1]) From b003fa0a4d0def9e3b2432b2cb8d771cd02d090a Mon Sep 17 00:00:00 2001 From: Manuel Lopez Antequera Date: Fri, 8 Jan 2021 15:21:38 +0100 Subject: [PATCH 2/5] Keep numpy indices while iterating over 1D labels --- src/pytorch_metric_learning/utils/accuracy_calculator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pytorch_metric_learning/utils/accuracy_calculator.py b/src/pytorch_metric_learning/utils/accuracy_calculator.py index 80bb455d..d428bbda 100644 --- a/src/pytorch_metric_learning/utils/accuracy_calculator.py +++ b/src/pytorch_metric_learning/utils/accuracy_calculator.py @@ -133,9 +133,11 @@ def get_label_match_counts(query_labels, reference_labels, label_comparison_fn): # Labels are compared with a custom function. # They might be non-categorical or multidimensional labels. match_counts = np.array([0 for _ in unique_query_labels]) - for ix_a, label_a in enumerate(unique_query_labels): - for label_b in reference_labels: - if label_comparison_fn(label_a[None, :], label_b[None, :]): + for ix_a in range(len(unique_query_labels)): + label_a = unique_query_labels[ix_a : ix_a + 1] + for ix_b in range(len(reference_labels)): + label_b = reference_labels[ix_b : ix_b + 1] + if label_comparison_fn(label_a, label_b): match_counts[ix_a] += 1 # faiss can only do a max of k=1024, and we have to do k+1 From 46325698465da8cb45eb0bcca3d7dcef3061cf62 Mon Sep 17 00:00:00 2001 From: Manuel Lopez Antequera Date: Fri, 8 Jan 2021 19:40:37 +0100 Subject: [PATCH 3/5] Use custom label comparison in get_lone_query_labels If the embeddings come from the same source, the lone condition is also dependant on whether the labels match to themselves or not. I can't imagine a metric learning problem where this happens, but the test case that we have is such a case. It is easy to support --- .../utils/accuracy_calculator.py | 5 +-- tests/utils/test_calculate_accuracies.py | 32 +++++++++++++++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/pytorch_metric_learning/utils/accuracy_calculator.py b/src/pytorch_metric_learning/utils/accuracy_calculator.py index d428bbda..85467dba 100644 --- a/src/pytorch_metric_learning/utils/accuracy_calculator.py +++ b/src/pytorch_metric_learning/utils/accuracy_calculator.py @@ -153,8 +153,9 @@ def get_lone_query_labels( label_comparison_fn, ): unique_labels, match_counts = label_counts - if label_comparison_fn is EQUALITY and embeddings_come_from_same_source: - lone_condition = match_counts <= 1 + label_matches_itself = label_comparison_fn(unique_labels, unique_labels) + if embeddings_come_from_same_source: + lone_condition = match_counts - ~label_matches_itself <= 1 else: lone_condition = match_counts == 0 lone_query_labels = unique_labels[lone_condition] diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py index f220142e..8b2edf7d 100644 --- a/tests/utils/test_calculate_accuracies.py +++ b/tests/utils/test_calculate_accuracies.py @@ -175,6 +175,9 @@ def comparison_fn(x, y): self.assertTrue(np.all(not_lone_query_mask == correct_not_lone_query_mask)) def test_get_lone_query_labels_multi_dim(self): + def equality2D(x, y): + return (x[..., 0] == y[..., 0]) & (x[..., 1] == y[..., 1]) + def custom_label_comparison_fn(x, y): return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1]) @@ -184,20 +187,29 @@ def custom_label_comparison_fn(x, y): (0, 3), (0, 3), (0, 3), - (0, 2), (1, 2), (4, 5), ] ) - for comparison_fn in [accuracy_calculator.EQUALITY, custom_label_comparison_fn]: + for comparison_fn in [custom_label_comparison_fn]: label_counts, num_k = accuracy_calculator.get_label_match_counts( query_labels, query_labels, comparison_fn, ) - if comparison_fn is accuracy_calculator.EQUALITY: + unique_labels, counts = label_counts + correct_unique_labels = np.array([[0, 3], [1, 2], [1, 3], [4, 5]]) + if comparison_fn is equality2D: + correct_counts = np.array([3, 1, 1, 1]) + else: + correct_counts = np.array([0, 1, 1, 0]) + + self.assertTrue(np.all(correct_counts == counts)) + self.assertTrue(np.all(correct_unique_labels == unique_labels)) + + if comparison_fn is equality2D: correct = [ ( True, @@ -211,11 +223,17 @@ def custom_label_comparison_fn(x, y): ), ] else: - correct_lone = np.array([[4, 5]]) - correct_mask = np.array([True, True, True, True, True, True, False]) correct = [ - (True, correct_lone, correct_mask), - (False, correct_lone, correct_mask), + ( + True, + np.array([[0, 3], [1, 2], [1, 3], [4, 5]]), + np.array([False, False, False, False, False, False]), + ), + ( + False, + np.array([[0, 3], [4, 5]]), + np.array([True, False, False, False, True, False]), + ), ] for same_source, correct_lone, correct_mask in correct: From 94815e69dd3de5a9b055269339f6dd4ae249826c Mon Sep 17 00:00:00 2001 From: Manuel Lopez Antequera Date: Fri, 8 Jan 2021 20:08:06 +0100 Subject: [PATCH 4/5] fix test: adapted expected output to new input --- tests/utils/test_calculate_accuracies.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py index 8b2edf7d..39ad776a 100644 --- a/tests/utils/test_calculate_accuracies.py +++ b/tests/utils/test_calculate_accuracies.py @@ -192,7 +192,7 @@ def custom_label_comparison_fn(x, y): ] ) - for comparison_fn in [custom_label_comparison_fn]: + for comparison_fn in [equality2D, custom_label_comparison_fn]: label_counts, num_k = accuracy_calculator.get_label_match_counts( query_labels, query_labels, @@ -213,13 +213,13 @@ def custom_label_comparison_fn(x, y): correct = [ ( True, - np.array([[0, 2], [1, 2], [1, 3], [4, 5]]), - np.array([False, True, True, True, False, False, False]), + np.array([[1, 2], [1, 3], [4, 5]]), + np.array([False, True, True, True, False, False]), ), ( False, np.array([[]]), - np.array([True, True, True, True, True, True, True]), + np.array([True, True, True, True, True, True]), ), ] else: From 76003103546e77c2d61d7019767df7689725503c Mon Sep 17 00:00:00 2001 From: TakeshiMusgrave Date: Sat, 9 Jan 2021 17:47:33 -0500 Subject: [PATCH 5/5] Fixed the lone_condition for when embeddings_come_from_same_source is True. Added a test case to test_get_lone_query_labels_custom and fixed one of the correct arrays in test_get_lone_query_labels_multi_dim --- setup.py | 2 +- src/pytorch_metric_learning/__init__.py | 2 +- .../utils/accuracy_calculator.py | 2 +- tests/utils/test_calculate_accuracies.py | 93 +++++++++++++------ 4 files changed, 69 insertions(+), 30 deletions(-) diff --git a/setup.py b/setup.py index b7ee53e6..cf96fee2 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setuptools.setup( name="pytorch-metric-learning", - version="0.9.96.dev2", + version="0.9.96", author="Kevin Musgrave", author_email="tkm45@cornell.edu", description="The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.", diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index 789ea798..02784fbc 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "0.9.96.dev2" +__version__ = "0.9.96" diff --git a/src/pytorch_metric_learning/utils/accuracy_calculator.py b/src/pytorch_metric_learning/utils/accuracy_calculator.py index 85467dba..920cb278 100644 --- a/src/pytorch_metric_learning/utils/accuracy_calculator.py +++ b/src/pytorch_metric_learning/utils/accuracy_calculator.py @@ -155,7 +155,7 @@ def get_lone_query_labels( unique_labels, match_counts = label_counts label_matches_itself = label_comparison_fn(unique_labels, unique_labels) if embeddings_come_from_same_source: - lone_condition = match_counts - ~label_matches_itself <= 1 + lone_condition = match_counts - label_matches_itself <= 0 else: lone_condition = match_counts == 0 lone_query_labels = unique_labels[lone_condition] diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py index 39ad776a..d4a4e200 100644 --- a/tests/utils/test_calculate_accuracies.py +++ b/tests/utils/test_calculate_accuracies.py @@ -143,36 +143,75 @@ def correct_mean_average_precision( return np.mean([(acc0 + acc1) / 2, acc2, acc3, acc4]) def test_get_lone_query_labels_custom(self): - def comparison_fn(x, y): + def fn1(x, y): return abs(x - y) < 2 + def fn2(x, y): + return abs(x - y) > 99 + query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100]) - label_counts, num_k = accuracy_calculator.get_label_match_counts( - query_labels, - query_labels, - comparison_fn, - ) - unique_labels, counts = label_counts - - correct_unique_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100]) - correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1]) - self.assertTrue(np.all(unique_labels == correct_unique_labels)) - self.assertTrue(np.all(counts == correct_counts)) - - ( - lone_query_labels, - not_lone_query_mask, - ) = accuracy_calculator.get_lone_query_labels( - query_labels, label_counts, True, comparison_fn - ) + for comparison_fn in [fn1, fn2]: + correct_unique_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100]) - correct_lone_query_labels = np.array([100]) - correct_not_lone_query_mask = np.array( - [True, True, True, True, True, True, True, True, True, True, True, False] - ) - self.assertTrue(np.all(lone_query_labels == correct_lone_query_labels)) - self.assertTrue(np.all(not_lone_query_mask == correct_not_lone_query_mask)) + if comparison_fn is fn1: + correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1]) + correct_lone_query_labels = np.array([100]) + correct_not_lone_query_mask = np.array( + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + ] + ) + elif comparison_fn is fn2: + correct_counts = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]) + correct_lone_query_labels = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + correct_not_lone_query_mask = np.array( + [ + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + ] + ) + + label_counts, num_k = accuracy_calculator.get_label_match_counts( + query_labels, + query_labels, + comparison_fn, + ) + unique_labels, counts = label_counts + + self.assertTrue(np.all(unique_labels == correct_unique_labels)) + self.assertTrue(np.all(counts == correct_counts)) + + ( + lone_query_labels, + not_lone_query_mask, + ) = accuracy_calculator.get_lone_query_labels( + query_labels, label_counts, True, comparison_fn + ) + + self.assertTrue(np.all(lone_query_labels == correct_lone_query_labels)) + self.assertTrue(np.all(not_lone_query_mask == correct_not_lone_query_mask)) def test_get_lone_query_labels_multi_dim(self): def equality2D(x, y): @@ -226,8 +265,8 @@ def custom_label_comparison_fn(x, y): correct = [ ( True, - np.array([[0, 3], [1, 2], [1, 3], [4, 5]]), - np.array([False, False, False, False, False, False]), + np.array([[0, 3], [4, 5]]), + np.array([True, False, False, False, True, False]), ), ( False,