Merge pull request #36 from anton-bushuiev/main
Enhance evaluation
anton-bushuiev authored Aug 13, 2024
2 parents 79cf5da + 1d3f502 commit d67eeb3
Showing 11 changed files with 385 additions and 88 deletions.
8 changes: 8 additions & 0 deletions massspecgym/definitions.py
@@ -0,0 +1,8 @@
"""Global variables used across the package."""
import pathlib

# Dirs
MASSSPECGYM_ROOT_DIR = pathlib.Path(__file__).parent.absolute()
MASSSPECGYM_REPO_DIR = MASSSPECGYM_ROOT_DIR.parent
MASSSPECGYM_DATA_DIR = MASSSPECGYM_REPO_DIR / 'data'
MASSSPECGYM_TEST_RESULTS_DIR = MASSSPECGYM_DATA_DIR / 'test_results'
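
For orientation, a hypothetical snippet showing how downstream code might combine these constants; the `de_novo_test.pkl` file name is illustrative and not part of this commit:

from massspecgym.definitions import MASSSPECGYM_TEST_RESULTS_DIR

# Hypothetical usage: store a per-sample test dataframe under data/test_results.
df_test_path = MASSSPECGYM_TEST_RESULTS_DIR / "de_novo_test.pkl"
df_test_path.parent.mkdir(parents=True, exist_ok=True)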
50 changes: 44 additions & 6 deletions massspecgym/models/base.py
@@ -1,10 +1,13 @@
import typing as T
import collections
from enum import Enum
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import pytorch_lightning as pl
from torchmetrics import Metric, SumMetric
from massspecgym.utils import ReturnScalarBootStrapper


class Stage(Enum):
@@ -25,15 +28,24 @@ def __init__(
lr: float = 1e-4,
weight_decay: float = 0.0,
log_only_loss_at_stages: T.Sequence[Stage | str] = (),
bootstrap_metrics: bool = True,
df_test_path: T.Optional[str | Path] = None,
*args,
**kwargs
):
super().__init__()
self.lr = lr
self.weight_decay = weight_decay
self.save_hyperparameters()

# Set up metric logging
self.log_only_loss_at_stages = [
Stage(s) if isinstance(s, str) else s for s in log_only_loss_at_stages
]
self.bootstrap_metrics = bootstrap_metrics

# Init dictionary to store dataframe columns where rows correspond to samples
# (for constructing test dataframe with predictions and metrics for each sample)
self.df_test_path = Path(df_test_path) if df_test_path is not None else None
self.df_test = collections.defaultdict(list)

@abstractmethod
def step(
@@ -81,7 +93,7 @@ def on_test_batch_end(self, *args, **kwargs):

def configure_optimizers(self):
return torch.optim.Adam(
self.parameters(), lr=self.lr, weight_decay=self.weight_decay
self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
)

def get_checkpoint_monitors(self) -> list[dict]:
@@ -100,14 +112,19 @@ def _update_metric(
metric_kwargs: T.Optional[dict] = None,
log: bool = True,
log_n_samples: bool = False,
bootstrap: bool = False,
num_bootstraps: int = 100
) -> None:
"""
This method enables updating and logging metrics without instantiating them in advance in
the __init__ method. The metrics are aggregated over batches and logged at the end of the
epoch. If the metric does not exist yet, it is instantiated and added as an attribute to the
model.
"""
# Log total number of samples for debugging
# Process arguments
bootstrap = bootstrap and self.bootstrap_metrics

# Log total number of samples (useful for debugging)
if log_n_samples:
self._update_metric(
name=name + "_n_samples",
@@ -122,7 +139,8 @@
else:
if metric_kwargs is None:
metric_kwargs = dict()
metric = metric_class(**metric_kwargs).to(self.device)
metric = metric_class(**metric_kwargs)
metric = metric.to(self.device)
setattr(self, name, metric)

# Update
Expand All @@ -138,5 +156,25 @@ def _update_metric(
on_step=False,
on_epoch=True,
add_dataloader_idx=False,
metric_attribute=name, # Suggested by a torchmetrics error
metric_attribute=name # Suggested by a torchmetrics error
)

# Bootstrap
if bootstrap:
def _bootstrapped_metric_class(**metric_kwargs):
metric = metric_class(**metric_kwargs)
return ReturnScalarBootStrapper(metric, std=True, num_bootstraps=num_bootstraps)

self._update_metric(
name=name + "_std",
metric_class=_bootstrapped_metric_class,
update_args=update_args,
batch_size=batch_size,
metric_kwargs=metric_kwargs,
)

def _update_df_test(self, dct: dict) -> None:
for col, vals in dct.items():
if isinstance(vals, torch.Tensor):
vals = vals.tolist()
self.df_test[col].extend(vals)
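
`ReturnScalarBootStrapper` is imported from `massspecgym/utils.py`, which this diff does not show. A minimal sketch of what such a wrapper could look like, assuming it builds on torchmetrics' `BootStrapper` and reduces the dict that `compute()` normally returns to the single bootstrap standard deviation, so it can be logged like any scalar metric:

import torch
from torchmetrics import Metric
from torchmetrics.wrappers import BootStrapper


class ReturnScalarBootStrapper(BootStrapper):
    """Hypothetical sketch: a BootStrapper whose compute() returns one scalar.

    torchmetrics' BootStrapper.compute() returns a dict (e.g. {"mean": ..., "std": ...}),
    while Lightning's self.log expects a scalar, so only the bootstrap std is kept.
    """

    def __init__(self, base_metric: Metric, std: bool = True, num_bootstraps: int = 100, **kwargs):
        super().__init__(base_metric, num_bootstraps=num_bootstraps, mean=False, std=std, **kwargs)

    def compute(self) -> torch.Tensor:
        return super().compute()["std"]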
69 changes: 59 additions & 10 deletions massspecgym/models/de_novo/base.py
@@ -1,7 +1,8 @@
import typing as T
from abc import ABC

import pulp
import torch
import pandas as pd
from rdkit import Chem
from rdkit.DataStructs import TanimotoSimilarity
from torchmetrics.aggregation import MeanMetric
@@ -46,18 +47,21 @@ def on_batch_end(
if stage in self.log_only_loss_at_stages:
return

self.evaluate_de_novo_step(
metric_vals = self.evaluate_de_novo_step(
outputs["mols_pred"], # (bs, k) list of generated rdkit molecules or SMILES strings
batch["mol"], # (bs) list of ground truth SMILES strings
stage=stage
)

if stage == Stage.TEST and self.df_test_path is not None:
self._update_df_test(metric_vals)

def evaluate_de_novo_step(
self,
mols_pred: list[list[T.Optional[Chem.Mol | str]]],
mol_true: list[str],
stage: Stage,
) -> None:
) -> dict[str, torch.Tensor]:
"""
# TODO: refactor to compute only for max(k) and then use the result to obtain the rest by
subsetting.
@@ -69,6 +73,9 @@ def evaluate_de_novo_step(
strings with possible Nones if no molecule was generated
mol_true (list[str]): (bs) list of ground-truth SMILES strings
"""
# Initialize return dictionary to store metric values per sample
metric_vals = {}

# Get SMILES and RDKit molecule objects for all predictions
if self.mol_pred_kind == "smiles":
smiles_pred_valid, mols_pred_valid = [], []
@@ -112,15 +119,13 @@ def _get_morgan_fp_with_cache(mol):
self.mol_2_morgan_fp[mol] = morgan_fp(mol, to_np=False)
return self.mol_2_morgan_fp[mol]



# Evaluate top-k metrics
for top_k in self.top_ks:
# Get top-k predicted molecules for each ground-truth sample
smiles_pred_top_k = [smiles_pred_sample[:top_k] for smiles_pred_sample in smiles_pred]
mols_pred_top_k = [mols_pred_sample[:top_k] for mols_pred_sample in mols_pred]

# 1. Evaluate minimum common edge subgraph:
# Calculate MCES distance between top-k predicted molecules and ground truth and
# report the minimum distance. The minimum distances for each sample in the batch are
# averaged across the epoch.
@@ -139,12 +144,18 @@ def _get_morgan_fp_with_cache(mol):
self.mces_cache[(true, pred)] = mce_val
dists.append(self.mces_cache[(true, pred)])
min_mces_dists.append(min(min(dists), mces_thld))
min_mces_dists = torch.tensor(min_mces_dists, device=self.device)

# Log
metric_name = stage.to_pref() + f"top_{top_k}_mces_dist"
self._update_metric(
stage.to_pref() + f"top_{top_k}_min_mces_dist",
metric_name,
MeanMetric,
(min_mces_dists,),
batch_size=len(min_mces_dists),
bootstrap=stage == Stage.TEST
)
metric_vals[metric_name] = min_mces_dists

# 2. Evaluate Tanimoto similarity:
# Calculate Tanimoto similarity between top-k predicted molecules and ground truth and
@@ -166,12 +177,18 @@ def _get_morgan_fp_with_cache(mol):
for pred in preds
]
max_tanimoto_sims.append(max(sims))
max_tanimoto_sims = torch.tensor(max_tanimoto_sims, device=self.device)

# Log
metric_name = stage.to_pref() + f"top_{top_k}_max_tanimoto_sim"
self._update_metric(
stage.to_pref() + f"top_{top_k}_max_tanimoto_sim",
metric_name,
MeanMetric,
(max_tanimoto_sims,),
batch_size=len(max_tanimoto_sims),
bootstrap=stage == Stage.TEST
)
metric_vals[metric_name] = max_tanimoto_sims

# 3. Evaluate exact match (accuracy):
# Calculate if the ground truth molecule is in the top-k predicted molecules and report
@@ -184,9 +201,41 @@ def _get_morgan_fp_with_cache(mol):
]
for true, preds in zip(mol_true, mols_pred_top_k)
]
in_top_k = torch.tensor(in_top_k, device=self.device)

# Log
metric_name = stage.to_pref() + f"top_{top_k}_accuracy"
self._update_metric(
stage.to_pref() + f"top_{top_k}_accuracy",
metric_name,
MeanMetric,
(in_top_k,),
batch_size=len(in_top_k)
batch_size=len(in_top_k),
bootstrap=stage == Stage.TEST
)
metric_vals[metric_name] = in_top_k

return metric_vals
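
The Tanimoto step above relies on a `morgan_fp` helper from `massspecgym.utils` that this diff does not show. A self-contained sketch of the same caching idea in plain RDKit, keyed by canonical SMILES (an assumption; the committed helper caches `Chem.Mol` objects directly):

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.DataStructs import TanimotoSimilarity

_fp_cache: dict[str, object] = {}

def _morgan_fp_cached(mol: Chem.Mol):
    # Cache fingerprints by canonical SMILES so repeated molecules are computed once.
    smiles = Chem.MolToSmiles(mol)
    if smiles not in _fp_cache:
        _fp_cache[smiles] = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)  # radius=2
    return _fp_cache[smiles]

def max_tanimoto_to_true(mol_true: Chem.Mol, mols_pred: list[Chem.Mol]) -> float:
    # Best Tanimoto similarity between the ground truth and any top-k prediction.
    fp_true = _morgan_fp_cached(mol_true)
    return max(TanimotoSimilarity(fp_true, _morgan_fp_cached(m)) for m in mols_pred)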


def test_step(
self,
batch: dict,
batch_idx: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
outputs = super().test_step(batch, batch_idx)

# Get generated (i.e., predicted) SMILES
if self.df_test_path is not None:
self._update_df_test({
'identifier': batch['identifier'],
'mols_pred': outputs['mols_pred']
})

return outputs

def on_test_epoch_end(self):
# Save test data frame to disk
if self.df_test_path is not None:
df_test = pd.DataFrame(self.df_test)
self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
df_test.to_pickle(self.df_test_path)
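
After a test run, the pickled dataframe can be inspected directly, one row per test sample; column names follow the metric names constructed above, and the path below is a placeholder for whatever `df_test_path` was passed to the model:

import pandas as pd

df_test = pd.read_pickle("data/test_results/de_novo_test.pkl")
print(df_test.columns)  # e.g. identifier, mols_pred, test_top_1_mces_dist, ...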