@@ -250,5 +250,53 @@ def test_task(self):
250
250
self .assertAllClose (expected_metrics2 , metrics2_ )
251
251
252
252
253
class RetrievalTestWithMultipointQueries(tf.test.TestCase):
  """Retrieval task test for rank-3 (multipoint) query embeddings.

  Each query is represented by several embedding points; the score of a
  candidate against a query is reduced with max-sim over those points.
  """

  def test_task(self):
    # Two queries, each consisting of two 3-d points: shape [2, 2, 3].
    queries = tf.constant(
        [[[3, 2, 1], [1, 2, 3]], [[2, 3, 4], [4, 3, 2]]], dtype=tf.float32
    )
    # Three in-batch candidate embeddings: shape [3, 3].
    candidates = tf.constant(
        [[0, 1, 0], [0, 1, 1], [1, 1, 0]], dtype=tf.float32
    )
    # All-zero candidate stream feeding the factorized top-k metric.
    zero_candidates = tf.data.Dataset.from_tensor_slices(
        np.array([[0, 0, 0]] * 20, dtype=np.float32)
    )

    task = retrieval.Retrieval(
        metrics=metrics.FactorizedTopK(
            candidates=zero_candidates.batch(16), ks=[5]
        ),
        batch_metrics=[
            tf.keras.metrics.TopKCategoricalAccuracy(
                k=1, name="batch_categorical_accuracy_at_1"
            )
        ],
    )

    # Scores will have shape [num_queries, num_candidates].
    # All_pair_scores: [[[2,2], [3,5], [5,3]], [[3, 3], [7,5], [5,7]]].
    # Max-sim scores: [[2, 5, 5], [3, 7, 7]].
    # Normalized logits: [[0, 3, 3], [1, 5, 5]].
    loss_for_query_0 = -np.log(1 / (1 + np.exp(3) + np.exp(3)))
    loss_for_query_1 = -np.log(
        np.exp(5) / (np.exp(1) + np.exp(5) + np.exp(5))
    )
    expected_loss = loss_for_query_0 + loss_for_query_1

    expected_metrics = {
        # Not computed for multipoint queries.
        "factorized_top_k/top_5_categorical_accuracy": 0.0,
        "batch_categorical_accuracy_at_1": 0.5,
    }

    loss = task(
        query_embeddings=queries,
        candidate_embeddings=candidates,
    )
    observed_metrics = {
        metric.name: metric.result().numpy() for metric in task.metrics
    }

    self.assertIsNotNone(loss)
    self.assertAllClose(expected_loss, loss)
    self.assertAllClose(expected_metrics, observed_metrics)
300
+
253
301
# Run the test suite when executed as a script.
if __name__ == "__main__":
  tf.test.main()
0 commit comments