Skip to content

Commit

Permalink
Merge pull request #85 from bio-ontology-research-group/develop
Browse files Browse the repository at this point in the history
🐛 fix device discrepance in evaluator
  • Loading branch information
ferzcam authored Oct 16, 2024
2 parents 8b9dfb0 + 3bb953b commit b7e1302
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mowl/evaluation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(self, dataset, device="cpu", batch_size=16):
eval_heads, eval_tails = self.dataset.evaluation_classes

print(f"Number of evaluation classes: {len(eval_heads)}")
self.evaluation_heads = th.tensor([self.class_to_id[c] for c in eval_heads.as_str], dtype=th.long)
self.evaluation_tails = th.tensor([self.class_to_id[c] for c in eval_tails.as_str], dtype=th.long)
self.evaluation_heads = th.tensor([self.class_to_id[c] for c in eval_heads.as_str], dtype=th.long).to(self.device)
self.evaluation_tails = th.tensor([self.class_to_id[c] for c in eval_tails.as_str], dtype=th.long).to(self.device)


@property
Expand All @@ -71,6 +71,7 @@ def evaluate_base(self, model, eval_tuples, mode="test",
filter_deductive_closure=False,
**kwargs):

model = model.to(self.device)
num_heads, num_tails = len(self.evaluation_heads), len(self.evaluation_tails)
model.eval()
if not mode in ["valid", "test"]:
Expand Down

0 comments on commit b7e1302

Please sign in to comment.