15
15
# Lint-as: python3
16
16
"""A factorized retrieval task."""
17
17
18
- from typing import Optional , Sequence , Union , Text , List
18
+ from typing import List , Optional , Sequence , Text , Union
19
19
20
+ import numpy as np
20
21
import tensorflow as tf
21
-
22
22
from tensorflow_recommenders import layers
23
23
from tensorflow_recommenders import metrics as tfrs_metrics
24
24
from tensorflow_recommenders .tasks import base
25
25
26
+ MIN_FLOAT = np .finfo (np .float32 ).min / 100.0
27
+
26
28
27
29
class Retrieval (tf .keras .layers .Layer , base .Task ):
28
30
"""A factorized retrieval task.
@@ -116,14 +118,17 @@ def factorized_metrics(self,
116
118
117
119
self ._factorized_metrics = value
118
120
119
- def call (self ,
120
- query_embeddings : tf .Tensor ,
121
- candidate_embeddings : tf .Tensor ,
122
- sample_weight : Optional [tf .Tensor ] = None ,
123
- candidate_sampling_probability : Optional [tf .Tensor ] = None ,
124
- candidate_ids : Optional [tf .Tensor ] = None ,
125
- compute_metrics : bool = True ,
126
- compute_batch_metrics : bool = True ) -> tf .Tensor :
121
+ def call (
122
+ self ,
123
+ query_embeddings : tf .Tensor ,
124
+ candidate_embeddings : tf .Tensor ,
125
+ sample_weight : Optional [tf .Tensor ] = None ,
126
+ candidate_sampling_probability : Optional [tf .Tensor ] = None ,
127
+ candidate_ids : Optional [tf .Tensor ] = None ,
128
+ compute_metrics : bool = True ,
129
+ compute_batch_metrics : bool = True ,
130
+ score_mask : Optional [tf .Tensor ] = None ,
131
+ ) -> tf .Tensor :
127
132
"""Computes the task loss and metrics.
128
133
129
134
The main argument are pairs of query and candidate embeddings: the first row
@@ -149,10 +154,14 @@ def call(self,
149
154
reflect the sampling probability of negative candidates.
150
155
candidate_ids: Optional tensor containing candidate ids. When given,
151
156
factorized top-K evaluation will be id-based rather than score-based.
152
- compute_metrics: Whether to compute metrics. Set this to False
153
- during training for faster training.
154
- compute_batch_metrics: Whether to compute batch level metrics.
155
- In-batch loss_metrics will still be computed.
157
+ compute_metrics: Whether to compute metrics. Set this to False during
158
+ training for faster training.
159
+ compute_batch_metrics: Whether to compute batch level metrics. In-batch
160
+ loss_metrics will still be computed.
161
+ score_mask: [num_queries, num_candidates] boolean tensor indicating for
162
+ each query, which candidates should be considered for loss and
163
+ metrics computation (false means the candidate is not considered).
164
+
156
165
Returns:
157
166
loss: Tensor of loss values.
158
167
"""
@@ -180,6 +189,9 @@ def call(self,
180
189
)
181
190
scores = layers .loss .RemoveAccidentalHits ()(labels , scores , candidate_ids )
182
191
192
+ if score_mask is not None :
193
+ scores = tf .where (score_mask , scores , MIN_FLOAT )
194
+
183
195
if self ._num_hard_negatives is not None :
184
196
scores , labels = layers .loss .HardNegativeMining (self ._num_hard_negatives )(
185
197
scores ,
@@ -189,8 +201,7 @@ def call(self,
189
201
190
202
update_ops = []
191
203
for metric in self ._loss_metrics :
192
- update_ops .append (
193
- metric .update_state (loss , sample_weight = sample_weight ))
204
+ update_ops .append (metric .update_state (loss ))
194
205
195
206
if compute_metrics :
196
207
for metric in self ._factorized_metrics :
0 commit comments