From 22f788b9727eee94d8a707b0c9b549cbc6973b83 Mon Sep 17 00:00:00 2001 From: roman-bushuiev Date: Wed, 14 Aug 2024 22:40:33 +0200 Subject: [PATCH] Refactor `decode_smiles` and `stage` connection --- massspecgym/models/de_novo/smiles_tranformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/massspecgym/models/de_novo/smiles_tranformer.py b/massspecgym/models/de_novo/smiles_tranformer.py index 3d1836c..0c334c9 100644 --- a/massspecgym/models/de_novo/smiles_tranformer.py +++ b/massspecgym/models/de_novo/smiles_tranformer.py @@ -107,9 +107,9 @@ def step(self, batch: dict, stage: Stage = Stage.NONE) -> dict: # Generate SMILES strings if stage in self.log_only_loss_at_stages: - mols_pred = self.decode_smiles(batch["spec"]) - else: mols_pred = None + else: + mols_pred = self.decode_smiles(batch["spec"]) return dict(loss=loss, mols_pred=mols_pred)