Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
GrogusBall committed Feb 25, 2021
1 parent 82ee504 commit 142fc0d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
4 changes: 3 additions & 1 deletion deep_recommenders/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@
@Author: Wang Yao
@Date: 2021-02-22 21:50:08
@LastEditors: Wang Yao
@LastEditTime: 2021-02-22 21:50:09
@LastEditTime: 2021-02-24 17:02:36
"""
from deep_recommenders.tasks.ranking import Ranking
from deep_recommenders.tasks.retrieval import Retrieval
9 changes: 6 additions & 3 deletions deep_recommenders/tasks/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@Author: Wang Yao
@Date: 2021-02-24 11:23:29
@LastEditors: Wang Yao
@LastEditTime: 2021-02-24 14:42:08
@LastEditTime: 2021-02-24 19:36:34
"""
from typing import Optional

Expand Down Expand Up @@ -58,7 +58,10 @@ def call(self,

scores = tf.matmul(query_embeddings, candidate_embeddings, transpose_b=True)

labels = tf.eye(*tf.shape(scores))
num_queries = tf.shape(scores)[0]
num_candidates = tf.shape(scores)[1]

labels = tf.eye(num_queries, num_candidates)

if candidate_sampling_probability is not None:
scores = layers.loss.SamplingProbablityCorrection()(
Expand All @@ -75,7 +78,7 @@ def call(self,
if self._temperature is not None:
scores = scores / self._temperature

loss = self._loss(y_ture=labels, y_pred=scores, sample_weight=sample_weight)
loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)

if compute_metrics is False:
return loss
Expand Down

0 comments on commit 142fc0d

Please sign in to comment.