From ecdc7a801a14b66f603655725c192237ceef678d Mon Sep 17 00:00:00 2001
From: Manny Cortes
Date: Tue, 4 Feb 2025 15:20:57 -0800
Subject: [PATCH] refactor: Update README to include examples

refactor: Modify config default values
refactor: Increase CI tests timeout from 30 min to 90 min
refactor: Pin dataloader memory if CUDA device is available
chore: Update version 0.3.6 to 0.3.7
---
 .github/workflows/ci.yml                  |   2 +-
 README.md                                 | 152 +++++++++++++++++++++-
 chem_mrl/__init__.py                      |   2 +-
 chem_mrl/configs/ArgParseHelper.py        |   2 +-
 chem_mrl/configs/ClassifierConfig.py      |   1 +
 chem_mrl/configs/MrlConfig.py             |   2 +-
 chem_mrl/trainers/BaseTrainer.py          |   6 +
 chem_mrl/trainers/ClassifierTrainer.py    |  16 ++-
 chem_mrl/trainers/TrainerExecutor.py      |   4 +-
 pyproject.toml                            |   4 +-
 tests/configs/test_mrl_config.py          |   2 +-
 tests/data/test_chem_mrl.parquet          | Bin 1522 -> 1474 bytes
 tests/trainers/test_chem_mrl_trainer.py   |  50 -------
 tests/trainers/test_classifier_trainer.py |  48 -------
 14 files changed, 176 insertions(+), 115 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bdc1938..6056db4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -10,7 +10,7 @@ on:
 jobs:
   tests:
     runs-on: ubuntu-latest
-    timeout-minutes: 30
+    timeout-minutes: 90
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
diff --git a/README.md b/README.md
index bc14807..e93c24d 100644
--- a/README.md
+++ b/README.md
@@ -12,14 +12,162 @@ Hyperparameter tuning indicates that a custom Tanimoto similarity loss function,
 pip install chem-mrl
 ```
 
+## Usage
+
+### Basic Training Workflow
+
+To train a model, initialize the configuration with dataset paths and model parameters, then pass it to `ChemMRLTrainer` for training.
+
+```python
+from chem_mrl.configs import ChemMRLConfig
+from chem_mrl.constants import BASE_MODEL_NAME
+from chem_mrl.trainers import ChemMRLTrainer
+
+# Define training configuration
+config = ChemMRLConfig(
+    model_name=BASE_MODEL_NAME,  # Predefined model name; can be any transformer model name or path compatible with sentence-transformers
+    train_dataset_path="train.parquet",  # Path to training data
+    val_dataset_path="val.parquet",  # Path to validation data
+    test_dataset_path="test.parquet",  # Optional test dataset
+    smiles_a_column_name="smiles_a",  # Column with the first molecule's SMILES representation
+    smiles_b_column_name="smiles_b",  # Column with the second molecule's SMILES representation
+    label_column_name="similarity",  # Similarity score between the molecules
+    return_eval_metric=True,  # Compute and return the test metric if a test dataset is provided
+    n_dims_per_step=3,  # Model-specific hyperparameter
+)
+
+# Initialize trainer and start training
+trainer = ChemMRLTrainer(config)
+test_eval_metric = trainer.train()  # Returns the test metric if a test dataset exists, otherwise the final validation metric
+```
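+
+The dataset files referenced above are Parquet tables whose column names match the config. As a minimal sketch of producing a compatible file with pandas (the rows here are illustrative toy values, not real training data):
+
+```python
+import pandas as pd
+
+# Hypothetical two-row similarity dataset; real training data would
+# contain many SMILES pairs with precomputed similarity labels
+df = pd.DataFrame(
+    {
+        "smiles_a": ["CCO", "c1ccccc1"],
+        "smiles_b": ["CCN", "c1ccncc1"],
+        "similarity": [0.65, 0.42],
+    }
+)
+df.to_parquet("train.parquet", index=False)
+```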
+
+### Custom Evaluation Callbacks
+
+You can provide a callback function that is executed every `evaluation_steps` steps, allowing custom logic such as logging, early stopping, or model checkpointing. The callback receives the evaluation score, the current epoch, and the current step, as shown below.
+
+```python
+from chem_mrl.configs import Chem2dMRLConfig
+from chem_mrl.constants import BASE_MODEL_NAME
+from chem_mrl.trainers import ChemMRLTrainer
+
+# Define a callback function for logging evaluation metrics
+def eval_callback(score: float, epoch: int, steps: int):
+    print(f"Step {steps}, Epoch {epoch}: Evaluation Score = {score}")
+
+# Define configuration for a 2D MRL model with additional hyperparameters
+config = Chem2dMRLConfig(
+    model_name=BASE_MODEL_NAME,
+    train_dataset_path="train.parquet",
+    val_dataset_path="val.parquet",
+    smiles_a_column_name="smiles_a",
+    smiles_b_column_name="smiles_b",
+    label_column_name="similarity",
+    evaluation_steps=1000,  # Callback execution frequency
+    return_eval_metric=True,  # Returns the validation metric since no test dataset is provided
+    n_dims_per_step=3,  # Model-specific hyperparameter
+    n_layers_per_step=2,  # Additional parameter specific to 2D MRL models
+    kl_div_weight=0.7,  # Weight for KL divergence regularization
+    kl_temperature=0.5,  # Temperature parameter for KL loss
+)
+
+# Train with callback
+trainer = ChemMRLTrainer(config)
+val_eval_metric = trainer.train(eval_callback=eval_callback)  # Callback executed every `evaluation_steps`
+```
+
+### W&B Integration
+
+This library includes a `WandBTrainerExecutor` class for seamless Weights & Biases (W&B) integration. It handles authentication, initialization, and logging at the frequency specified by `evaluation_steps`, enabling experiment tracking and better visualization and monitoring of model performance.
+
+```python
+from chem_mrl.configs import ChemMRLConfig, WandbConfig
+from chem_mrl.constants import BASE_MODEL_NAME
+from chem_mrl.trainers import ChemMRLTrainer, WandBTrainerExecutor
+
+# Define W&B configuration for experiment tracking
+wandb_config = WandbConfig(
+    project_name="chem_mrl_test",  # W&B project name
+    run_name="test",  # Name for the experiment run
+    use_watch=True,  # Enables model watching for tracking gradients
+    watch_log="all",  # Logs all model parameters and gradients
+    watch_log_freq=1000,  # Logging frequency
+    watch_log_graph=True,  # Logs the model computation graph
+)
+
+# Configure training with W&B integration
+config = ChemMRLConfig(
+    model_name=BASE_MODEL_NAME,
+    train_dataset_path="train.parquet",
+    val_dataset_path="val.parquet",
+    evaluation_steps=1000,
+    use_wandb=True,  # Enables W&B logging
+    wandb_config=wandb_config,
+)
+
+# Initialize trainer and W&B executor
+trainer = ChemMRLTrainer(config)
+executor = WandBTrainerExecutor(trainer)
+executor.execute()  # Handles training and W&B logging
+```
+
 ## Classifier
 
-This repository includes code for training a linear SBERT classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features. While demonstrated on the Isomer Design dataset, it is generalizable to any dataset containing `smiles` and `label` columns. The training scripts (see below) allow users to specify these column names.
+This repository includes code for training a linear classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features. While demonstrated on the Isomer Design dataset, it is generalizable to any dataset containing `smiles` and `label` columns. The training scripts (see below) allow users to specify these column names. Currently, the dataset must be in Parquet format.
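+
+Classification data follows the same Parquet convention with the two columns named above. A minimal sketch with pandas (toy values for illustration; real data would have many rows per class):
+
+```python
+import pandas as pd
+
+# Hypothetical labeled molecules; column names are configurable,
+# `smiles`/`label` are the defaults
+df = pd.DataFrame(
+    {
+        "smiles": ["CCO", "c1ccccc1", "CC(=O)O"],
+        "label": [0, 1, 0],
+    }
+)
+df.to_parquet("train_classification.parquet", index=False)
+```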
 
 Hyperparameter tuning shows that cross-entropy loss (`softmax` option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.
 
+## Usage
+
+### Basic Classification Training
+
+To train a classifier, configure the model with dataset paths and column names, then initialize `ClassifierTrainer` to start training.
+
+```python
+from chem_mrl.configs import ClassifierConfig
+from chem_mrl.trainers import ClassifierTrainer
+
+# Define classification training configuration
+config = ClassifierConfig(
+    model_name="path/to/trained_mrl_model",  # Pretrained MRL model path
+    train_dataset_path="train_classification.parquet",  # Path to training dataset
+    val_dataset_path="val_classification.parquet",  # Path to validation dataset
+    smiles_column_name="smiles",  # Column containing SMILES representations of molecules
+    label_column_name="label",  # Column containing classification labels
+)
+
+# Initialize and train the classifier
+trainer = ClassifierTrainer(config)
+trainer.train()
+```
+
+### Training with Dice Loss
+
+For imbalanced classification tasks, **Dice Loss** can improve performance by focusing on hard-to-classify samples. Below is a configuration using `DiceLossClassifierConfig`, which introduces additional hyperparameters.
+
+```python
+from chem_mrl.configs import DiceLossClassifierConfig
+from chem_mrl.trainers import ClassifierTrainer
+from chem_mrl.constants import BASE_MODEL_NAME
+
+# Define classification training configuration with Dice Loss
+config = DiceLossClassifierConfig(
+    model_name=BASE_MODEL_NAME,  # Predefined base model
+    train_dataset_path="train_classification.parquet",
+    val_dataset_path="val_classification.parquet",
+    smiles_column_name="smiles",
+    label_column_name="label",
+    dice_reduction="sum",  # Reduction method for Dice Loss (e.g., 'mean' or 'sum')
+    dice_gamma=1.0,  # Hyperparameter controlling the impact of Dice Loss
+)
+
+# Initialize and train the classifier with Dice Loss
+trainer = ClassifierTrainer(config)
+trainer.train()
+```
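+
+Note that `DiceLossClassifierConfig` pins `loss_func` to `"selfadjdice"` (see the `chem_mrl/configs/ClassifierConfig.py` change below), so only `dice_reduction` and `dice_gamma` need to be chosen on top of the base classifier options.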
+
 ## Scripts
 
 The `scripts` directory contains two training scripts:
@@ -112,7 +260,7 @@ options:
   --smiles_b_column_name SMILES_B_COLUMN_NAME
                         SMILES B column name (default: smiles_b)
   --label_column_name LABEL_COLUMN_NAME
-                        Label column name (default: fingerprint_similarity)
+                        Label column name (default: similarity)
   --embedding_pooling {mean,mean_sqrt_len_tokens,weightedmean,lasttoken}
                         Pooling layer method applied to the embeddings. Pooling layer is required to generate a fixed sized SMILES embedding from a variable sized SMILES. For details visit: https://sbert.net/docs/package_reference/sentence_transformer/models.html#sentence_transformers.models.Pooling (default: mean)
diff --git a/chem_mrl/__init__.py b/chem_mrl/__init__.py
index 00b7669..c87dc08 100644
--- a/chem_mrl/__init__.py
+++ b/chem_mrl/__init__.py
@@ -4,7 +4,7 @@
 
 from sentence_transformers import LoggingHandler
 
-__version__ = "0.3.6"
+__version__ = "0.3.7"
 
 from . import (
     benchmark,
diff --git a/chem_mrl/configs/ArgParseHelper.py b/chem_mrl/configs/ArgParseHelper.py
index d3f7718..17cbed6 100644
--- a/chem_mrl/configs/ArgParseHelper.py
+++ b/chem_mrl/configs/ArgParseHelper.py
@@ -168,7 +168,7 @@ def add_chem_mrl_config_args(
     )
     parser.add_argument(
         "--label_column_name",
-        default="fingerprint_similarity",
+        default="similarity",
         help="Label column name",
     )
     parser.add_argument(
diff --git a/chem_mrl/configs/ClassifierConfig.py b/chem_mrl/configs/ClassifierConfig.py
index e3b6dd6..45bacfc 100644
--- a/chem_mrl/configs/ClassifierConfig.py
+++ b/chem_mrl/configs/ClassifierConfig.py
@@ -59,6 +59,7 @@ def __post_init__(self):
 
 @dataclass(frozen=True)
 class DiceLossClassifierConfig(ClassifierConfig):
+    loss_func = "selfadjdice"
     dice_reduction: DiceReductionOptionType = "mean"  # type: ignore
     dice_gamma: float = 1.0
 
diff --git a/chem_mrl/configs/MrlConfig.py b/chem_mrl/configs/MrlConfig.py
index 9a56159..a4912bd 100644
--- a/chem_mrl/configs/MrlConfig.py
+++ b/chem_mrl/configs/MrlConfig.py
@@ -21,7 +21,7 @@ class ChemMRLConfig(_BaseConfig):
 
     smiles_a_column_name: str = "smiles_a"
     smiles_b_column_name: str = "smiles_b"
-    label_column_name: str = "fingerprint_similarity"
+    label_column_name: str = "similarity"
     embedding_pooling: ChemMrlPoolingOptionType = "mean"  # type: ignore
     loss_func: ChemMrlLossFctOptionType = "tanimotosentloss"  # type: ignore
     tanimoto_similarity_loss_func: TanimotoSimilarityBaseLossFctOptionType | None = None  # type: ignore
diff --git a/chem_mrl/trainers/BaseTrainer.py b/chem_mrl/trainers/BaseTrainer.py
index fe4bbdf..aadcd10 100644
--- a/chem_mrl/trainers/BaseTrainer.py
+++ b/chem_mrl/trainers/BaseTrainer.py
@@ -129,6 +129,12 @@ def _initialize_output_path(self):
     # concrete methods
     ############################################################################
 
+    def _device(self) -> str:
+        cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "-1")
+        use_cuda = torch.cuda.is_available() and cuda_visible_devices != "-1"
+        device = "cuda" if use_cuda else "cpu"
+        return device
+
     def __calculate_training_params(self) -> tuple[float, float, int]:
         total_training_points = self.steps_per_epoch * self.config.train_batch_size
         # Normalized weight decay for adamw optimizer - https://arxiv.org/pdf/1711.05101.pdf
diff --git a/chem_mrl/trainers/ClassifierTrainer.py b/chem_mrl/trainers/ClassifierTrainer.py
index 36ab95f..c5466d4 100644
--- a/chem_mrl/trainers/ClassifierTrainer.py
+++ b/chem_mrl/trainers/ClassifierTrainer.py
@@ -147,6 +147,10 @@ def _initialize_data(
         self,
         test_file: str | None = None,
     ):
         logging.info(f"Loading {train_file} dataset")
+
+        pin_device = self._device()
+        pin_memory = True if pin_device != "cpu" else False
+
         train_df = pd.read_parquet(
             train_file,
             columns=[
@@ -172,8 +176,8 @@
             ),
             batch_size=self._config.train_batch_size,
             shuffle=True,
-            pin_memory=True,
-            pin_memory_device="cuda",
+            pin_memory=pin_memory,
+            pin_memory_device=pin_device,
             num_workers=self._config.n_dataloader_workers,
         )
 
@@ -203,8 +207,8 @@
             ),
             batch_size=self._config.train_batch_size,
             shuffle=False,
-            pin_memory=True,
-            pin_memory_device="cuda",
+            pin_memory=pin_memory,
+            pin_memory_device=pin_device,
             num_workers=self._config.n_dataloader_workers,
         )
 
@@ -236,8 +240,8 @@
             ),
             batch_size=self._config.train_batch_size,
             shuffle=False,
-            pin_memory=True,
-            pin_memory_device="cuda",
+            pin_memory=pin_memory,
+            pin_memory_device=pin_device,
             num_workers=self._config.n_dataloader_workers,
         )
 
diff --git a/chem_mrl/trainers/TrainerExecutor.py b/chem_mrl/trainers/TrainerExecutor.py
index 10cf1f1..dea56ad 100644
--- a/chem_mrl/trainers/TrainerExecutor.py
+++ b/chem_mrl/trainers/TrainerExecutor.py
@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
-from typing import Callable, Generic, TypeVar
+from typing import Generic, TypeVar
 
 import optuna
-import wandb
+
+import wandb
 
 from chem_mrl.configs import BoundConfigType
 
 from .BaseTrainer import BoundTrainerType
diff --git a/pyproject.toml b/pyproject.toml
index e3d08f2..2224e31 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "chem-mrl"
-version = "0.3.6"
+version = "0.3.7"
 description = "SMILES-based Matryoshka Representation Learning Embedding Model"
 license = { text = "Apache 2.0" }
 readme = "README.md"
@@ -14,7 +14,7 @@ requires-python = ">=3.10"
 
 dependencies = [
     "sentence-transformers==3.4.1",
-    "transformers>=4.34.0",
+    "transformers[torch]>=4.34.0",
     "optuna==4.2.0",
     "wandb==0.19.4",
     "torch>=2.0.0",
diff --git a/tests/configs/test_mrl_config.py b/tests/configs/test_mrl_config.py
index 090062e..85b87e3 100644
--- a/tests/configs/test_mrl_config.py
+++ b/tests/configs/test_mrl_config.py
@@ -18,7 +18,7 @@ def test_chem_mrl_config_default_values():
 
     assert config.smiles_a_column_name == "smiles_a"
     assert config.smiles_b_column_name == "smiles_b"
-    assert config.label_column_name == "fingerprint_similarity"
+    assert config.label_column_name == "similarity"
     assert config.model_name == BASE_MODEL_NAME
     assert config.embedding_pooling == "mean"
     assert config.loss_func == "tanimotosentloss"
diff --git a/tests/data/test_chem_mrl.parquet b/tests/data/test_chem_mrl.parquet
index b880df1c21d03c920287432abef4fd4f117ce44a..ebcadb8baf4bd0dca18b56c2cba112bb63600590 100644
GIT binary patch
delta 51
zcmeyweTaL*WJX4=$qN{Jfut2v4buzO%{!Rd87F(QXl$-$k!55wo;;V;mm`6Nfg!*#
G$PfUmvJR;L

delta 101
zcmX@a{fT?SWJX@Gw9LHp)S`l-%)FBL$x|46(L|J(YM2_>HZNdmXXHUqs?MT;BCy$?
WMV67#XmT#AFUJHH28IB~AVUCE&mt87

diff --git a/tests/trainers/test_chem_mrl_trainer.py b/tests/trainers/test_chem_mrl_trainer.py
index faf7cc2..e9b0f4c 100644
--- a/tests/trainers/test_chem_mrl_trainer.py
+++ b/tests/trainers/test_chem_mrl_trainer.py
@@ -112,19 +112,6 @@ def test_chem_mrl_pooling_options(pooling):
     assert isinstance(result, float)
     assert result != -math.inf
 
-    config = Chem2dMRLConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CHEM_MRL_PATH,
-        val_dataset_path=TEST_CHEM_MRL_PATH,
-        embedding_pooling=pooling,
-        return_eval_metric=True,
-    )
-    trainer = ChemMRLTrainer(config)
-    result = trainer.train()
-    assert trainer.config.embedding_pooling == pooling
-    assert isinstance(result, float)
-    assert result != -math.inf
-
 
 @pytest.mark.parametrize(
     "loss_func",
@@ -147,19 +134,6 @@ def test_chem_mrl_loss_functions(loss_func):
     assert isinstance(result, float)
     assert result != -math.inf
 
-    # can't test tanimotosimilarityloss since it requires an additional parameter
-    config = Chem2dMRLConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CHEM_MRL_PATH,
-        val_dataset_path=TEST_CHEM_MRL_PATH,
-        loss_func=loss_func,
-        return_eval_metric=True,
-    )
-    trainer = ChemMRLTrainer(config)
-    result = trainer.train()
-    assert isinstance(result, float)
-    assert result != -math.inf
-
 
 @pytest.mark.parametrize("base_loss", TANIMOTO_SIMILARITY_BASE_LOSS_FCT_OPTIONS)
 def test_chem_mrl_tanimoto_similarity_loss(base_loss):
@@ -204,18 +178,6 @@ def test_chem_mrl_eval_similarity(eval_similarity):
     assert isinstance(result, float)
     assert result != -math.inf
 
-    config = Chem2dMRLConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CHEM_MRL_PATH,
-        val_dataset_path=TEST_CHEM_MRL_PATH,
-        eval_similarity_fct=eval_similarity,
-        return_eval_metric=True,
-    )
-    trainer = ChemMRLTrainer(config)
-    result = trainer.train()
-    assert isinstance(result, float)
-    assert result != -math.inf
-
 
 @pytest.mark.parametrize("eval_metric", CHEM_MRL_EVAL_METRIC_OPTIONS)
 def test_chem_mrl_eval_metrics(eval_metric):
@@ -231,18 +193,6 @@ def test_chem_mrl_eval_metrics(eval_metric):
     assert isinstance(result, float)
     assert result != -math.inf
 
-    config = Chem2dMRLConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CHEM_MRL_PATH,
-        val_dataset_path=TEST_CHEM_MRL_PATH,
-        eval_metric=eval_metric,
-        return_eval_metric=True,
-    )
-    trainer = ChemMRLTrainer(config)
-    result = trainer.train()
-    assert isinstance(result, float)
-    assert result != -math.inf
-
 
 def test_chem_2d_mrl_trainer_instantiation():
     config = Chem2dMRLConfig(
diff --git a/tests/trainers/test_classifier_trainer.py b/tests/trainers/test_classifier_trainer.py
index 4bad150..1a355bd 100644
--- a/tests/trainers/test_classifier_trainer.py
+++ b/tests/trainers/test_classifier_trainer.py
@@ -108,20 +108,6 @@ def test_classifier_classifier_hidden_dimensions(
     assert isinstance(result, float)
     assert result != -math.inf
 
-    config = DiceLossClassifierConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CLASSIFICATION_PATH,
-        val_dataset_path=TEST_CLASSIFICATION_PATH,
-        classifier_hidden_dimension=dimension,
-        return_eval_metric=True,
-    )
-    trainer = ClassifierTrainer(config)
-    result = trainer.train()
-    assert trainer.model.truncate_dim == dimension
-    assert trainer.loss_fct.smiles_embedding_dimension == dimension
-    assert isinstance(result, float)
-    assert result != -math.inf
-
 
 @pytest.mark.parametrize("eval_metric", CLASSIFIER_EVAL_METRIC_OPTIONS)
 def test_classifier_eval_metrics(eval_metric):
@@ -137,18 +123,6 @@ def test_classifier_eval_metrics(eval_metric):
     assert isinstance(result, float)
     assert result != -math.inf
 
-    config = DiceLossClassifierConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CLASSIFICATION_PATH,
-        val_dataset_path=TEST_CLASSIFICATION_PATH,
-        eval_metric=eval_metric,
-        return_eval_metric=True,
-    )
-    trainer = ClassifierTrainer(config)
-    result = trainer.train()
-    assert isinstance(result, float)
-    assert result != -math.inf
-
 
 def test_classifier_freeze_internal_model():
     config = ClassifierConfig(
@@ -188,15 +162,6 @@ def test_classifier_num_labels():
     trainer = ClassifierTrainer(config)
     assert trainer.loss_fct.num_labels == 2  # testing dataset only has two classes
 
-    config = DiceLossClassifierConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CLASSIFICATION_PATH,
-        val_dataset_path=TEST_CLASSIFICATION_PATH,
-        freeze_model=True,
-    )
-    trainer = ClassifierTrainer(config)
-    assert trainer.loss_fct.num_labels == 2  # testing dataset only has two classes
-
 
 @pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5, 1.0])
 def test_classifier_dropout(dropout_p):
@@ -213,19 +178,6 @@ def test_classifier_dropout(dropout_p):
     assert isinstance(result, float)
     assert result != -math.inf
 
-    config = DiceLossClassifierConfig(
-        model_name=BASE_MODEL_NAME,
-        train_dataset_path=TEST_CLASSIFICATION_PATH,
-        val_dataset_path=TEST_CLASSIFICATION_PATH,
-        dropout_p=dropout_p,
-        return_eval_metric=True,
-    )
-    trainer = ClassifierTrainer(config)
-    result = trainer.train()
-    assert trainer.loss_fct.dropout_p == dropout_p
-    assert isinstance(result, float)
-    assert result != -math.inf
-
 
 def test_dice_loss_classifier_trainer_instantiation():
     config = DiceLossClassifierConfig(