Commit 685694e

Author: TensorFlow Recommenders Authors

Fix a bug in loss metrics calculation, and allow specifying score_mask in tfrs.tasks.Retrieval.call

PiperOrigin-RevId: 663811507
Parent: 5e0629c

File tree: 2 files changed, +70 -22 lines


tensorflow_recommenders/tasks/retrieval.py (+27 -16)

@@ -15,14 +15,16 @@
 # Lint-as: python3
 """A factorized retrieval task."""
 
-from typing import Optional, Sequence, Union, Text, List
+from typing import List, Optional, Sequence, Text, Union
 
+import numpy as np
 import tensorflow as tf
-
 from tensorflow_recommenders import layers
 from tensorflow_recommenders import metrics as tfrs_metrics
 from tensorflow_recommenders.tasks import base
 
+MIN_FLOAT = np.finfo(np.float32).min / 100.0
+
 
 class Retrieval(tf.keras.layers.Layer, base.Task):
   """A factorized retrieval task.
@@ -116,14 +118,17 @@ def factorized_metrics(self,
 
     self._factorized_metrics = value
 
-  def call(self,
-           query_embeddings: tf.Tensor,
-           candidate_embeddings: tf.Tensor,
-           sample_weight: Optional[tf.Tensor] = None,
-           candidate_sampling_probability: Optional[tf.Tensor] = None,
-           candidate_ids: Optional[tf.Tensor] = None,
-           compute_metrics: bool = True,
-           compute_batch_metrics: bool = True) -> tf.Tensor:
+  def call(
+      self,
+      query_embeddings: tf.Tensor,
+      candidate_embeddings: tf.Tensor,
+      sample_weight: Optional[tf.Tensor] = None,
+      candidate_sampling_probability: Optional[tf.Tensor] = None,
+      candidate_ids: Optional[tf.Tensor] = None,
+      compute_metrics: bool = True,
+      compute_batch_metrics: bool = True,
+      score_mask: Optional[tf.Tensor] = None,
+  ) -> tf.Tensor:
     """Computes the task loss and metrics.
 
     The main argument are pairs of query and candidate embeddings: the first row
@@ -149,10 +154,14 @@ def call(self,
         reflect the sampling probability of negative candidates.
       candidate_ids: Optional tensor containing candidate ids. When given,
        factorized top-K evaluation will be id-based rather than score-based.
-      compute_metrics: Whether to compute metrics. Set this to False
-        during training for faster training.
-      compute_batch_metrics: Whether to compute batch level metrics.
-        In-batch loss_metrics will still be computed.
+      compute_metrics: Whether to compute metrics. Set this to False during
+        training for faster training.
+      compute_batch_metrics: Whether to compute batch level metrics. In-batch
+        loss_metrics will still be computed.
+      score_mask: [num_queries, num_candidates] boolean tensor indicating for
+        each query, which candidates should be considered for loss and
+        metrics computation (false means the candidate is not considered).
 
     Returns:
       loss: Tensor of loss values.
     """
@@ -180,6 +189,9 @@ def call(self,
       )
       scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)
 
+    if score_mask is not None:
+      scores = tf.where(score_mask, scores, MIN_FLOAT)
+
     if self._num_hard_negatives is not None:
       scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(
           scores,
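
To make the new argument concrete, here is a hedged usage sketch; the embedding values and the mask are invented for illustration, and only score_mask itself comes from this commit. Each row of score_mask selects which in-batch candidates the corresponding query is scored against. The diagonal entries, which correspond to the positive pairs, should stay True: the labels are the identity matrix, so masking a positive would force its own score to MIN_FLOAT.

import tensorflow as tf
import tensorflow_recommenders as tfrs

task = tfrs.tasks.Retrieval()

query_embeddings = tf.random.normal([4, 32])
candidate_embeddings = tf.random.normal([4, 32])

# Drop one arbitrary negative per query; keep the positives (the diagonal).
score_mask = tf.constant([
    [True, False, True, True],
    [True, True, True, False],
    [False, True, True, True],
    [True, True, False, True],
])

loss = task(
    query_embeddings=query_embeddings,
    candidate_embeddings=candidate_embeddings,
    score_mask=score_mask,
    # No factorized or batch metrics are configured in this sketch.
    compute_metrics=False,
    compute_batch_metrics=False,
)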
@@ -189,8 +201,7 @@ def call(self,
 
     update_ops = []
     for metric in self._loss_metrics:
-      update_ops.append(
-          metric.update_state(loss, sample_weight=sample_weight))
+      update_ops.append(metric.update_state(loss))
 
     if compute_metrics:
       for metric in self._factorized_metrics:
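
The hunk above is the loss-metrics bug fix named in the commit message. On my reading of the diff, the loss handed to update_state already reflects sample_weight, which is applied when self._loss is computed earlier in call, so passing sample_weight to the metric a second time was at best redundant and in general skewed the tracked value. After the fix, a loss metric such as tf.keras.metrics.Mean simply tracks the returned loss. A minimal sketch, assuming a task configured with only a loss metric:

import tensorflow as tf
import tensorflow_recommenders as tfrs

task = tfrs.tasks.Retrieval(
    loss_metrics=[tf.keras.metrics.Mean(name="batch_loss")],
)

loss = task(
    query_embeddings=tf.random.normal([2, 8]),
    candidate_embeddings=tf.random.normal([2, 8]),
    sample_weight=tf.constant([0.7, 0.3]),
    # Per the docstring, loss_metrics are still updated with these off.
    compute_metrics=False,
    compute_batch_metrics=False,
)

# After the fix, the tracked mean matches the (already weighted) loss.
batch_loss = {m.name: m.result() for m in task.metrics}["batch_loss"]
tf.debugging.assert_near(loss, batch_loss)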

tensorflow_recommenders/tasks/retrieval_test.py (+43 -6)

@@ -37,13 +37,20 @@ def test_task(self):
 
     task = retrieval.Retrieval(
         metrics=metrics.FactorizedTopK(
-            candidates=candidate_dataset.batch(16),
-            ks=[5]
+            candidates=candidate_dataset.batch(16), ks=[5]
         ),
         batch_metrics=[
             tf.keras.metrics.TopKCategoricalAccuracy(
-                k=1, name="batch_categorical_accuracy_at_1")
-        ])
+                k=1, name="batch_categorical_accuracy_at_1"
+            )
+        ],
+        loss_metrics=[
+            tf.keras.metrics.Mean(
+                name="batch_loss",
+                dtype=tf.float32,
+            )
+        ],
+    )
 
     # All_pair_scores: [[6, 3], [9, 5]].
     # Normalized logits: [[3, 0], [4, 0]].
@@ -52,6 +59,7 @@ def test_task(self):
     expected_metrics = {
         "factorized_top_k/top_5_categorical_accuracy": 1.0,
         "batch_categorical_accuracy_at_1": 0.5,
+        "batch_loss": expected_loss,
     }
     loss = task(query_embeddings=query, candidate_embeddings=candidate)
     metrics_ = {
@@ -70,7 +78,8 @@ def test_task(self):
         compute_metrics=False)
     expected_metrics1 = {
         "factorized_top_k/top_5_categorical_accuracy": 0.0,
-        "batch_categorical_accuracy_at_1": 0.5
+        "batch_categorical_accuracy_at_1": 0.5,
+        "batch_loss": loss,
     }
     metrics1_ = {
         metric.name: metric.result().numpy() for metric in task.metrics
@@ -89,7 +98,8 @@ def test_task(self):
         compute_batch_metrics=False)
     expected_metrics2 = {
         "factorized_top_k/top_5_categorical_accuracy": 1.0,
-        "batch_categorical_accuracy_at_1": 0.0
+        "batch_categorical_accuracy_at_1": 0.0,
+        "batch_loss": loss,
     }
     metrics2_ = {
         metric.name: metric.result().numpy() for metric in task.metrics
@@ -99,6 +109,33 @@ def test_task(self):
     self.assertAllClose(expected_loss, loss)
     self.assertAllClose(expected_metrics2, metrics2_)
 
+    # Test computation of metrics with sample_weight
+    for metric in task.metrics:
+      metric.reset_states()
+    loss = task(
+        query_embeddings=query,
+        candidate_embeddings=candidate,
+        sample_weight=tf.constant([0.7, 0.3], dtype=tf.float32),
+    )
+
+    # All_pair_scores: [[6, 3], [9, 5]].
+    # Normalized logits: [[3, 0], [4, 0]].
+    expected_loss3 = -0.7 * np.log(_sigmoid(3.0)) - 0.3 * np.log(
+        1 - _sigmoid(4.0)
+    )
+
+    expected_metrics3 = {
+        "factorized_top_k/top_5_categorical_accuracy": 1.0,
+        "batch_categorical_accuracy_at_1": 0.7,
+        "batch_loss": expected_loss3,
+    }
+    metrics3_ = {
+        metric.name: metric.result().numpy() for metric in task.metrics
+    }
+    self.assertIsNotNone(loss)
+    self.assertAllClose(expected_loss3, loss)
+    self.assertAllClose(expected_metrics3, metrics3_)
+
   def test_task_graph(self):
 
     with tf.Graph().as_default():
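
As a sanity check on the new test case: with all-pair scores [[6, 3], [9, 5]], the logits normalize to [[3, 0], [4, 0]], and a softmax over two logits reduces to a sigmoid of their difference. Query 0 (weight 0.7) has its positive at index 0 and query 1 (weight 0.3) at index 1, so the weighted loss the test asserts can be re-derived in plain numpy, mirroring the test's own _sigmoid helper:

import numpy as np

def _sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

# Query 0: logits [3, 0], positive at index 0 -> p = sigmoid(3).
# Query 1: logits [4, 0], positive at index 1 -> p = 1 - sigmoid(4).
expected_loss3 = -0.7 * np.log(_sigmoid(3.0)) - 0.3 * np.log(1 - _sigmoid(4.0))
print(expected_loss3)  # ~ 1.2395

The weighted batch_categorical_accuracy_at_1 of 0.7 follows the same way: only query 0 ranks its positive first, and it carries weight 0.7.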
