
Commit 36a1836

Author: TensorFlow Recommenders Authors
Parent: 685694e

Add support for multipoint query retrieval to TFRS Retrieval task.

PiperOrigin-RevId: 689532029

2 files changed (+62 -4)


tensorflow_recommenders/tasks/retrieval.py (+14 -4)
@@ -141,7 +141,10 @@ def call(
 
     Args:
       query_embeddings: [num_queries, embedding_dim] tensor of query
-        representations.
+        representations, or [num_queries, num_heads, embedding_dim]. If the
+        latter, we do "maxsim" scoring over those multiple query heads. This
+        applies to the loss computation and batch metrics. Factorized metrics
+        won't be computed in this case.
       candidate_embeddings: [num_candidates, embedding_dim] tensor of candidate
         representations. Normally, `num_candidates` is the same as
         `num_queries`: there is a positive candidate corresponding for every
@@ -166,8 +169,15 @@ def call(
       loss: Tensor of loss values.
     """
 
-    scores = tf.linalg.matmul(
-        query_embeddings, candidate_embeddings, transpose_b=True)
+    if len(tf.shape(query_embeddings)) == 3:
+      scores = tf.einsum(
+          "qne,ce->qnc", query_embeddings, candidate_embeddings
+      )
+      scores = tf.math.reduce_max(scores, axis=1)
+    else:
+      scores = tf.linalg.matmul(
+          query_embeddings, candidate_embeddings, transpose_b=True
+      )
 
     num_queries = tf.shape(scores)[0]
     num_candidates = tf.shape(scores)[1]
@@ -203,7 +213,7 @@ def call(
     for metric in self._loss_metrics:
       update_ops.append(metric.update_state(loss))
 
-    if compute_metrics:
+    if compute_metrics and len(tf.shape(query_embeddings)) == 2:
       for metric in self._factorized_metrics:
         update_ops.append(
             metric.update_state(
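
For reference, here is the "maxsim" scoring path from the hunks above in isolation, as a minimal standalone sketch. The shapes (2 queries, 2 heads, embedding dim 3, 3 candidates) and the random inputs are illustrative, not part of the commit:

import tensorflow as tf

# Hypothetical inputs: [num_queries, num_heads, embedding_dim] queries
# and [num_candidates, embedding_dim] candidates.
query_embeddings = tf.random.normal([2, 2, 3])
candidate_embeddings = tf.random.normal([3, 3])

# Score every query head against every candidate:
# result has shape [num_queries, num_heads, num_candidates].
per_head_scores = tf.einsum(
    "qne,ce->qnc", query_embeddings, candidate_embeddings
)

# "Maxsim": keep the best-scoring head per (query, candidate) pair,
# reducing back to the usual [num_queries, num_candidates] score
# matrix that the loss and batch metrics consume.
scores = tf.math.reduce_max(per_head_scores, axis=1)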

tensorflow_recommenders/tasks/retrieval_test.py (+48)
@@ -250,5 +250,53 @@ def test_task(self):
     self.assertAllClose(expected_metrics2, metrics2_)
 
 
+class RetrievalTestWithMultipointQueries(tf.test.TestCase):
+
+  def test_task(self):
+
+    query = tf.constant(
+        [[[3, 2, 1], [1, 2, 3]], [[2, 3, 4], [4, 3, 2]]], dtype=tf.float32
+    )
+    candidate = tf.constant([[0, 1, 0], [0, 1, 1], [1, 1, 0]], dtype=tf.float32)
+    candidate_dataset = tf.data.Dataset.from_tensor_slices(
+        np.array([[0, 0, 0]] * 20, dtype=np.float32)
+    )
+
+    task = retrieval.Retrieval(
+        metrics=metrics.FactorizedTopK(
+            candidates=candidate_dataset.batch(16), ks=[5]
+        ),
+        batch_metrics=[
+            tf.keras.metrics.TopKCategoricalAccuracy(
+                k=1, name="batch_categorical_accuracy_at_1"
+            )
+        ],
+    )
+
+    # Final max-sim scores will have shape [num_queries, num_candidates].
+    # All-pair scores [q, c, heads]: [[[2, 2], [3, 5], [5, 3]], [[3, 3], [7, 5], [5, 7]]].
+    # Max-sim scores: [[2, 5, 5], [3, 7, 7]].
+    # Logits shifted by -2 (softmax is shift-invariant): [[0, 3, 3], [1, 5, 5]].
+    expected_loss = -np.log(1 / (1 + np.exp(3) + np.exp(3))) - np.log(
+        np.exp(5) / (np.exp(1) + np.exp(5) + np.exp(5))
+    )
+
+    expected_metrics = {
+        "factorized_top_k/top_5_categorical_accuracy": (
+            0.0
+        ),  # Not computed for multipoint queries.
+        "batch_categorical_accuracy_at_1": 0.5,
+    }
+    loss = task(
+        query_embeddings=query,
+        candidate_embeddings=candidate,
+    )
+    metrics_ = {metric.name: metric.result().numpy() for metric in task.metrics}
+
+    self.assertIsNotNone(loss)
+    self.assertAllClose(expected_loss, loss)
+    self.assertAllClose(expected_metrics, metrics_)
+
+
 if __name__ == "__main__":
   tf.test.main()
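
The expected values in the new test can be reproduced by hand. The NumPy sketch below reuses the test's constants and assumes the in-batch softmax loss with the i-th candidate as query i's positive, which is what the test's expected_loss encodes; it mirrors the test's arithmetic rather than the task's full implementation:

import numpy as np

query = np.array(
    [[[3, 2, 1], [1, 2, 3]], [[2, 3, 4], [4, 3, 2]]], dtype=np.float32
)
candidate = np.array([[0, 1, 0], [0, 1, 1], [1, 1, 0]], dtype=np.float32)

# Per-head scores, shape [num_queries, num_heads, num_candidates].
per_head = np.einsum("qne,ce->qnc", query, candidate)

# Max-sim reduction over heads: [[2, 5, 5], [3, 7, 7]].
scores = per_head.max(axis=1)

# In-batch softmax: log-probability of each candidate per query,
# with candidate i serving as the positive for query i.
log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
loss = -(log_probs[0, 0] + log_probs[1, 1])

print(scores)  # [[2. 5. 5.] [3. 7. 7.]]
print(loss)    # matches expected_loss in the test above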
