refactor: Update README to include examples
refactor: Modify config default values
refactor: increased CI tests timeout from 30 min to 90 min.
refactor: pin dataloader memory if cuda device is available
chore: update version 0.3.6 to 0.3.7
emapco committed Feb 4, 2025
1 parent 83c8bb1 commit ecdc7a8
Showing 14 changed files with 176 additions and 115 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -10,7 +10,7 @@ on:
jobs:
tests:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 90
steps:
- name: Checkout repository
uses: actions/checkout@v4
152 changes: 150 additions & 2 deletions README.md
@@ -12,14 +12,162 @@ Hyperparameter tuning indicates that a custom Tanimoto similarity loss function,
pip install chem-mrl
```
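
A quick sanity check that the install worked — the package exposes a `__version__` attribute (defined in `chem_mrl/__init__.py`):

```python
import chem_mrl

print(chem_mrl.__version__)  # e.g. "0.3.7"
```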

## Usage

### Basic Training Workflow

To train a model, initialize the configuration with dataset paths and model parameters, then pass it to `ChemMRLTrainer` for training.

```python
from chem_mrl.configs import ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

# Define training configuration
config = ChemMRLConfig(
    model_name=BASE_MODEL_NAME,  # Predefined model name; can be any transformer model name or path compatible with sentence-transformers
    train_dataset_path="train.parquet",  # Path to training data
    val_dataset_path="val.parquet",  # Path to validation data
    test_dataset_path="test.parquet",  # Optional test dataset
    smiles_a_column_name="smiles_a",  # Column with the first molecule's SMILES representation
    smiles_b_column_name="smiles_b",  # Column with the second molecule's SMILES representation
    label_column_name="similarity",  # Similarity score between the two molecules
    return_eval_metric=True,  # Compute and return the test metric if a test dataset is provided
    n_dims_per_step=3,  # Model-specific hyperparameter
)

# Initialize trainer and start training
trainer = ChemMRLTrainer(config)
test_eval_metric = trainer.train()  # Returns the test metric if a test dataset is provided, otherwise the final validation metric
```
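
Once training finishes, the saved checkpoint can typically be loaded like any sentence-transformers model and used to embed SMILES strings. Because the model is trained with Matryoshka Representation Learning, the embeddings can be truncated to a smaller dimensionality. A minimal sketch, assuming the trainer wrote a standard sentence-transformers checkpoint; the output path and truncation size are hypothetical:

```python
from sentence_transformers import SentenceTransformer

# Hypothetical checkpoint directory - replace with your trainer's actual output path
model = SentenceTransformer("output/chem-mrl")

smiles = ["CCO", "c1ccccc1O"]  # ethanol, phenol
embeddings = model.encode(smiles)  # full-dimensional embeddings
truncated = embeddings[:, :256]    # MRL: keep only the leading 256 dimensions
print(embeddings.shape, truncated.shape)
```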

### Custom Evaluation Callbacks

You can provide a callback function that is executed every `evaluation_steps` steps, allowing custom logic such as logging, early stopping, or model checkpointing.

```python
from chem_mrl.configs import Chem2dMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

# Define a callback function for logging evaluation metrics
def eval_callback(score: float, epoch: int, steps: int):
    print(f"Step {steps}, Epoch {epoch}: Evaluation Score = {score}")

# Define configuration for a 2D MRL model with additional hyperparameters
config = Chem2dMRLConfig(
    model_name=BASE_MODEL_NAME,
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
    evaluation_steps=1000,  # Callback execution frequency
    return_eval_metric=True,  # Return the final validation metric (no test dataset is provided here)
    n_dims_per_step=3,  # Model-specific hyperparameter
    n_layers_per_step=2,  # Additional parameter specific to 2D MRL models
    kl_div_weight=0.7,  # Weight for KL divergence regularization
    kl_temperature=0.5,  # Temperature parameter for the KL loss
)

# Train with callback
trainer = ChemMRLTrainer(config)
val_eval_metric = trainer.train(eval_callback=eval_callback)  # Callback executed every `evaluation_steps` steps
```
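
The callback can also carry state. Below is a minimal sketch (assuming only the `(score, epoch, steps)` signature shown above) that tracks the best validation score seen so far — a starting point for custom checkpointing or early-stopping logic:

```python
best_score = -float("inf")

def best_score_callback(score: float, epoch: int, steps: int):
    """Track the best validation score observed during training."""
    global best_score
    if score > best_score:
        best_score = score
        print(f"New best score {score:.4f} at epoch {epoch}, step {steps}")

# trainer.train(eval_callback=best_score_callback)
```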

### W&B Integration

This library includes a `WandBTrainerExecutor` class for Weights & Biases (W&B) integration. It handles authentication, initialization, and logging at the frequency specified by `evaluation_steps`, enabling experiment tracking and easier visualization and monitoring of model performance.


```python
from chem_mrl.configs import ChemMRLConfig, WandbConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer, WandBTrainerExecutor

# Define W&B configuration for experiment tracking
wandb_config = WandbConfig(
    project_name="chem_mrl_test",  # W&B project name
    run_name="test",  # Name for the experiment run
    use_watch=True,  # Enables model watching for tracking gradients
    watch_log="all",  # Logs all model parameters and gradients
    watch_log_freq=1000,  # Logging frequency
    watch_log_graph=True,  # Logs model computation graph
)

# Configure training with W&B integration
config = ChemMRLConfig(
    model_name=BASE_MODEL_NAME,
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    evaluation_steps=1000,
    use_wandb=True,  # Enables W&B logging
    wandb_config=wandb_config,
)

# Initialize trainer and W&B executor
trainer = ChemMRLTrainer(config)
executor = WandBTrainerExecutor(trainer)
executor.execute() # Handles training and W&B logging
```

## Classifier

This repository includes code for training a linear SBERT classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features. While demonstrated on the Isomer Design dataset, it is generalizable to any dataset containing `smiles` and `label` columns. The training scripts (see below) allow users to specify these column names.
This repository includes code for training a linear classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features. While demonstrated on the Isomer Design dataset, it is generalizable to any dataset containing `smiles` and `label` columns. The training scripts (see below) allow users to specify these column names.

Currently, the dataset must be in Parquet format.
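
If your data is in CSV, a minimal pandas sketch for producing the expected Parquet file is shown below. The source file and column names are hypothetical; `smiles` and `label` match the default column names used by the training scripts:

```python
import pandas as pd  # to_parquet requires pyarrow or fastparquet

# Hypothetical raw CSV with a SMILES column and a category column
df = pd.read_csv("isomer_design_raw.csv")
df = df.rename(columns={"molecule_smiles": "smiles", "category": "label"})
df[["smiles", "label"]].to_parquet("train_classification.parquet", index=False)
```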

Hyperparameter tuning shows that cross-entropy loss (`softmax` option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.

## Usage

### Basic Classification Training

To train a classifier, configure the model with dataset paths and column names, then initialize `ClassifierTrainer` to start training.

```python
from chem_mrl.configs import ClassifierConfig
from chem_mrl.trainers import ClassifierTrainer

# Define classification training configuration
config = ClassifierConfig(
    model_name="path/to/trained_mrl_model",  # Pretrained MRL model path
    train_dataset_path="train_classification.parquet",  # Path to training dataset
    val_dataset_path="val_classification.parquet",  # Path to validation dataset
    smiles_column_name="smiles",  # Column containing SMILES representations of molecules
    label_column_name="label",  # Column containing classification labels
)

# Initialize and train the classifier
trainer = ClassifierTrainer(config)
trainer.train()
```

### Training with Dice Loss

For imbalanced classification tasks, **Dice Loss** can improve performance by focusing on hard-to-classify samples. Below is a configuration using `DiceLossClassifierConfig`, which introduces additional hyperparameters.

```python
from chem_mrl.configs import DiceLossClassifierConfig
from chem_mrl.trainers import ClassifierTrainer
from chem_mrl.constants import BASE_MODEL_NAME

# Define classification training configuration with Dice Loss
config = DiceLossClassifierConfig(
    model_name=BASE_MODEL_NAME,  # Predefined base model
    train_dataset_path="train_classification.parquet",
    val_dataset_path="val_classification.parquet",
    smiles_column_name="smiles",
    label_column_name="label",
    dice_reduction="sum",  # Reduction method for the dice loss ('mean' or 'sum')
    dice_gamma=1.0,  # Gamma hyperparameter of the self-adjusting dice loss
)

# Initialize and train the classifier with Dice Loss
trainer = ClassifierTrainer(config)
trainer.train()
```

## Scripts

The `scripts` directory contains two training scripts:
@@ -112,7 +260,7 @@ options:
--smiles_b_column_name SMILES_B_COLUMN_NAME
SMILES B column name (default: smiles_b)
--label_column_name LABEL_COLUMN_NAME
Label column name (default: fingerprint_similarity)
Label column name (default: similarity)
--embedding_pooling {mean,mean_sqrt_len_tokens,weightedmean,lasttoken}
Pooling layer method applied to the embeddings. A pooling layer is required to generate a fixed-sized SMILES embedding from a variable-sized SMILES. For details
visit: https://sbert.net/docs/package_reference/sentence_transformer/models.html#sentence_transformers.models.Pooling (default: mean)
2 changes: 1 addition & 1 deletion chem_mrl/__init__.py
@@ -4,7 +4,7 @@

from sentence_transformers import LoggingHandler

__version__ = "0.3.6"
__version__ = "0.3.7"

from . import (
benchmark,
2 changes: 1 addition & 1 deletion chem_mrl/configs/ArgParseHelper.py
@@ -168,7 +168,7 @@ def add_chem_mrl_config_args(
)
parser.add_argument(
"--label_column_name",
default="fingerprint_similarity",
default="similarity",
help="Label column name",
)
parser.add_argument(
1 change: 1 addition & 0 deletions chem_mrl/configs/ClassifierConfig.py
@@ -59,6 +59,7 @@ def __post_init__(self):

@dataclass(frozen=True)
class DiceLossClassifierConfig(ClassifierConfig):
loss_func = "selfadjdice"
dice_reduction: DiceReductionOptionType = "mean" # type: ignore
dice_gamma: float = 1.0

2 changes: 1 addition & 1 deletion chem_mrl/configs/MrlConfig.py
@@ -21,7 +21,7 @@
class ChemMRLConfig(_BaseConfig):
smiles_a_column_name: str = "smiles_a"
smiles_b_column_name: str = "smiles_b"
label_column_name: str = "fingerprint_similarity"
label_column_name: str = "similarity"
embedding_pooling: ChemMrlPoolingOptionType = "mean" # type: ignore
loss_func: ChemMrlLossFctOptionType = "tanimotosentloss" # type: ignore
tanimoto_similarity_loss_func: TanimotoSimilarityBaseLossFctOptionType | None = None # type: ignore
6 changes: 6 additions & 0 deletions chem_mrl/trainers/BaseTrainer.py
@@ -129,6 +129,12 @@ def _initialize_output_path(self):
# concrete methods
############################################################################

def _device(self) -> str:
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "-1")
use_cuda = torch.cuda.is_available() and cuda_visible_devices != "-1"
device = "cuda" if use_cuda else "cpu"
return device

def __calculate_training_params(self) -> tuple[float, float, int]:
total_training_points = self.steps_per_epoch * self.config.train_batch_size
# Normalized weight decay for adamw optimizer - https://arxiv.org/pdf/1711.05101.pdf
16 changes: 10 additions & 6 deletions chem_mrl/trainers/ClassifierTrainer.py
@@ -147,6 +147,10 @@ def _initialize_data(
test_file: str | None = None,
):
logging.info(f"Loading {train_file} dataset")

pin_device = self._device()
pin_memory = True if pin_device != "cpu" else False

train_df = pd.read_parquet(
train_file,
columns=[
@@ -172,8 +176,8 @@
),
batch_size=self._config.train_batch_size,
shuffle=True,
pin_memory=True,
pin_memory_device="cuda",
pin_memory=pin_memory,
pin_memory_device=pin_device,
num_workers=self._config.n_dataloader_workers,
)

@@ -203,8 +207,8 @@ def _initialize_data(
),
batch_size=self._config.train_batch_size,
shuffle=False,
pin_memory=True,
pin_memory_device="cuda",
pin_memory=pin_memory,
pin_memory_device=pin_device,
num_workers=self._config.n_dataloader_workers,
)

@@ -236,8 +240,8 @@ def _initialize_data(
),
batch_size=self._config.train_batch_size,
shuffle=False,
pin_memory=True,
pin_memory_device="cuda",
pin_memory=pin_memory,
pin_memory_device=pin_device,
num_workers=self._config.n_dataloader_workers,
)

4 changes: 2 additions & 2 deletions chem_mrl/trainers/TrainerExecutor.py
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Callable, Generic, TypeVar
from typing import Generic, TypeVar

import optuna
import wandb

import wandb
from chem_mrl.configs import BoundConfigType

from .BaseTrainer import BoundTrainerType
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "chem-mrl"
version = "0.3.6"
version = "0.3.7"
description = "SMILES-based Matryoshka Representation Learning Embedding Model"
license = { text = "Apache 2.0" }
readme = "README.md"
@@ -14,7 +14,7 @@ requires-python = ">=3.10"

dependencies = [
"sentence-transformers==3.4.1",
"transformers>=4.34.0",
"transformers[torch]>=4.34.0",
"optuna==4.2.0",
"wandb==0.19.4",
"torch>=2.0.0",
2 changes: 1 addition & 1 deletion tests/configs/test_mrl_config.py
@@ -18,7 +18,7 @@ def test_chem_mrl_config_default_values():

assert config.smiles_a_column_name == "smiles_a"
assert config.smiles_b_column_name == "smiles_b"
assert config.label_column_name == "fingerprint_similarity"
assert config.label_column_name == "similarity"
assert config.model_name == BASE_MODEL_NAME
assert config.embedding_pooling == "mean"
assert config.loss_func == "tanimotosentloss"
Binary file modified tests/data/test_chem_mrl.parquet
Binary file not shown.
50 changes: 0 additions & 50 deletions tests/trainers/test_chem_mrl_trainer.py
@@ -112,19 +112,6 @@ def test_chem_mrl_pooling_options(pooling):
assert isinstance(result, float)
assert result != -math.inf

config = Chem2dMRLConfig(
model_name=BASE_MODEL_NAME,
train_dataset_path=TEST_CHEM_MRL_PATH,
val_dataset_path=TEST_CHEM_MRL_PATH,
embedding_pooling=pooling,
return_eval_metric=True,
)
trainer = ChemMRLTrainer(config)
result = trainer.train()
assert trainer.config.embedding_pooling == pooling
assert isinstance(result, float)
assert result != -math.inf


@pytest.mark.parametrize(
"loss_func",
@@ -147,19 +134,6 @@ def test_chem_mrl_loss_functions(loss_func):
assert isinstance(result, float)
assert result != -math.inf

# can't test tanimotosimilarityloss since it requires an additional parameter
config = Chem2dMRLConfig(
model_name=BASE_MODEL_NAME,
train_dataset_path=TEST_CHEM_MRL_PATH,
val_dataset_path=TEST_CHEM_MRL_PATH,
loss_func=loss_func,
return_eval_metric=True,
)
trainer = ChemMRLTrainer(config)
result = trainer.train()
assert isinstance(result, float)
assert result != -math.inf


@pytest.mark.parametrize("base_loss", TANIMOTO_SIMILARITY_BASE_LOSS_FCT_OPTIONS)
def test_chem_mrl_tanimoto_similarity_loss(base_loss):
@@ -204,18 +178,6 @@ def test_chem_mrl_eval_similarity(eval_similarity):
assert isinstance(result, float)
assert result != -math.inf

config = Chem2dMRLConfig(
model_name=BASE_MODEL_NAME,
train_dataset_path=TEST_CHEM_MRL_PATH,
val_dataset_path=TEST_CHEM_MRL_PATH,
eval_similarity_fct=eval_similarity,
return_eval_metric=True,
)
trainer = ChemMRLTrainer(config)
result = trainer.train()
assert isinstance(result, float)
assert result != -math.inf


@pytest.mark.parametrize("eval_metric", CHEM_MRL_EVAL_METRIC_OPTIONS)
def test_chem_mrl_eval_metrics(eval_metric):
@@ -231,18 +193,6 @@ def test_chem_mrl_eval_metrics(eval_metric):
assert isinstance(result, float)
assert result != -math.inf

config = Chem2dMRLConfig(
model_name=BASE_MODEL_NAME,
train_dataset_path=TEST_CHEM_MRL_PATH,
val_dataset_path=TEST_CHEM_MRL_PATH,
eval_metric=eval_metric,
return_eval_metric=True,
)
trainer = ChemMRLTrainer(config)
result = trainer.train()
assert isinstance(result, float)
assert result != -math.inf


def test_chem_2d_mrl_trainer_instantiation():
config = Chem2dMRLConfig(