Refactor UCE fine-tuning to be standalone
mattwoodx committed Oct 7, 2024
1 parent 044f10f commit ab53dbe
Showing 2 changed files with 28 additions and 30 deletions.
20 changes: 10 additions & 10 deletions docs/model_cards/uce.md
@@ -107,31 +107,31 @@ print(embeddings.shape)
## How To Fine-Tune

```python
from helical.models.uce.model import UCE, UCEConfig
from helical.models.uce.fine_tuning_model import UCEFineTuningModel
from helical import UCEConfig, UCEFineTuningModel
import anndata as ad

configurer=UCEConfig(batch_size=10)
uce = UCE(configurer=configurer)

# Load the data
ann_data = ad.read_h5ad("dataset.h5ad")

# Get unique output labels
label_set = set(cell_types)

# Create the fine-tuning model with the desired configs
configurer=UCEConfig(batch_size=10)
uce_fine_tune = UCEFineTuningModel(uce_config=configurer, fine_tuning_head="classification", output_size=len(label_set))

# Process the data for training
dataset = uce.process_data(ann_data)
dataset = uce_fine_tune.process_data(ann_data)

# Get the desired label class
cell_types = list(ann_data.obs.cell_type)

# Create a dictionary mapping the classes to unique integers for training
label_set = set(cell_types)
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

for i in range(len(cell_types)):
cell_types[i] = class_id_dict[cell_types[i]]

# Create the fine-tuning model
uce_fine_tune = UCEFineTuningModel(uce_model=uce, fine_tuning_head="classification", output_size=len(label_set))

# Fine-tune
uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types)

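Applied together, the hunks above leave the model-card example as a single standalone snippet in which no separate `UCE` instance is created: the fine-tuning model owns both `process_data` and the backbone. A sketch of the resulting flow is shown below; it reorders the label extraction ahead of `label_set` so the snippet runs top to bottom, and the `dataset.h5ad` path and `cell_type` column are the placeholders already used in the model card.

```python
from helical import UCEConfig, UCEFineTuningModel
import anndata as ad

# Load the data
ann_data = ad.read_h5ad("dataset.h5ad")

# Get the desired label class and the unique output labels
cell_types = list(ann_data.obs.cell_type)
label_set = set(cell_types)

# Create the fine-tuning model with the desired configs
configurer = UCEConfig(batch_size=10)
uce_fine_tune = UCEFineTuningModel(uce_config=configurer, fine_tuning_head="classification", output_size=len(label_set))

# Process the data for training
dataset = uce_fine_tune.process_data(ann_data)

# Map each class to a unique integer for training
class_id_dict = {label: i for i, label in enumerate(label_set)}
cell_types = [class_id_dict[c] for c in cell_types]

# Fine-tune
uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types)
```
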
38 changes: 18 additions & 20 deletions helical/models/uce/fine_tuning_model.py
@@ -6,22 +6,22 @@
from torch.utils.data import DataLoader
from helical.models.base_models import HelicalBaseFineTuningHead
from helical.models.base_models import HelicalBaseFineTuningModel
from helical.models.uce import UCE
from helical.models.uce import UCE, UCEConfig
from typing import Literal, Optional
from tqdm import tqdm
from transformers import get_scheduler
import logging

logger = logging.getLogger(__name__)

class UCEFineTuningModel(HelicalBaseFineTuningModel):
class UCEFineTuningModel(HelicalBaseFineTuningModel, UCE):
"""
Fine-tuning model for the UCE model.
Parameters
----------
uce_model : UCE
The initialised UCE model to fine-tune.
uce_config : UCEConfig
The UCE configuration for the fine-tuning model, the same configuration that would be used to instantiate the standard UCE model.
fine_tuning_head : Literal["classification", "regression"] | HelicalBaseFineTuningHead
The fine-tuning head that is appended to the model. This can either be a string (options available: "classification", "regression") specifying the task or a custom fine-tuning head inheriting from HelicalBaseFineTuningHead.
output_size : Optional[int]
@@ -38,15 +38,13 @@ class UCEFineTuningModel(HelicalBaseFineTuningModel):
"""
def __init__(self,
uce_model: UCE,
uce_config: UCEConfig,
fine_tuning_head: Literal["classification"] | HelicalBaseFineTuningHead,
output_size: Optional[int]=None):

super().__init__(fine_tuning_head, output_size)
self.config = uce_model.config
self.uce_model = uce_model.model
self.device = uce_model.device
self.accelerator = uce_model.accelerator
HelicalBaseFineTuningModel.__init__(self, fine_tuning_head, output_size)
UCE.__init__(self, uce_config)

self.fine_tuning_head.set_dim_size(self.config["embsize"])
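
This constructor change is the core of the refactor: rather than wrapping an already-initialised `UCE` object, `UCEFineTuningModel` now inherits from both `HelicalBaseFineTuningModel` and `UCE` and initialises each base explicitly, so `self.model`, `self.config`, `self.device` and `self.accelerator` come from the `UCE` base instead of a wrapped instance. A minimal sketch of the pattern with illustrative stand-in classes (not the Helical implementations):

```python
class BaseFineTuningModel:
    """Stand-in for HelicalBaseFineTuningModel: owns the task head."""
    def __init__(self, fine_tuning_head, output_size):
        self.fine_tuning_head = fine_tuning_head
        self.output_size = output_size


class BaseFoundationModel:
    """Stand-in for UCE: builds the backbone and config."""
    def __init__(self, config):
        self.config = config
        self.model = object()  # placeholder for the loaded backbone


class FineTuningModel(BaseFineTuningModel, BaseFoundationModel):
    def __init__(self, config, fine_tuning_head, output_size):
        # Call each base initialiser explicitly, mirroring the refactored __init__
        BaseFineTuningModel.__init__(self, fine_tuning_head, output_size)
        BaseFoundationModel.__init__(self, config)


model = FineTuningModel({"embsize": 1280}, fine_tuning_head="classification", output_size=10)
print(model.config["embsize"], model.output_size)  # attributes provided by both bases
```

Calling each base's `__init__` directly keeps the two initialisation paths explicit, which is what the new code does in place of the previous attribute copying from a wrapped model.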

def _forward(self, batch_sentences: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
@@ -65,7 +63,7 @@ def _forward(self, batch_sentences: torch.Tensor, mask: torch.Tensor) -> torch.T
torch.Tensor
The output tensor of the fine-tuning model.
"""
_, embeddings = self.uce_model.forward(batch_sentences, mask=mask)
_, embeddings = self.model.forward(batch_sentences, mask=mask)
if self.accelerator is not None:
self.accelerator.wait_for_everyone()
embeddings = self.accelerator.gather_for_metrics((embeddings))
@@ -138,7 +136,7 @@ def train(
if validation_input_data is not None:
validation_dataloader = self.accelerator.prepare(validation_dataloader)

self.uce_model.train()
self.model.train()
self.fine_tuning_head.train()

# disable progress bar if not the main process
@@ -165,9 +163,9 @@
batch_sentences, mask, idxs = batch[0], batch[1], batch[2]
batch_sentences = batch_sentences.permute(1, 0)
if self.config["multi_gpu"]:
batch_sentences = self.uce_model.module.pe_embedding(batch_sentences.long())
batch_sentences = self.model.module.pe_embedding(batch_sentences.long())
else:
batch_sentences = self.uce_model.pe_embedding(batch_sentences.long())
batch_sentences = self.model.pe_embedding(batch_sentences.long())
batch_sentences = torch.nn.functional.normalize(batch_sentences, dim=2) # normalize token outputs
output = self._forward(batch_sentences, mask=mask)
labels = torch.tensor(train_labels[batch_count: batch_count + self.config["batch_size"]], device=self.device)
@@ -195,9 +193,9 @@
batch_sentences, mask, idxs = validation_data[0], validation_data[1], validation_data[2]
batch_sentences = batch_sentences.permute(1, 0)
if self.config["multi_gpu"]:
batch_sentences = self.uce_model.module.pe_embedding(batch_sentences.long())
batch_sentences = self.model.module.pe_embedding(batch_sentences.long())
else:
batch_sentences = self.uce_model.pe_embedding(batch_sentences.long())
batch_sentences = self.model.pe_embedding(batch_sentences.long())
batch_sentences = torch.nn.functional.normalize(batch_sentences, dim=2) # normalize token outputs
output = self._forward(batch_sentences, mask=mask)
val_labels = torch.tensor(validation_labels[validation_batch_count: validation_batch_count + self.config["batch_size"]], device=self.device)
@@ -206,7 +204,7 @@
count += 1.0
testing_loop.set_postfix({"val_loss": val_loss/count})
logger.info(f"Fine-Tuning Complete. Epochs: {epochs}")
self.uce_model.eval()
self.model.eval()
self.fine_tuning_head.eval()
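
Because `train` also accepts `validation_input_data` and `validation_labels` (handled earlier in the method), a fine-tuning run with a held-out split might look roughly like the sketch below. The 90/10 split and variable names are illustrative; it reuses `uce_fine_tune` and the integer-encoded `cell_types` from the model-card example, and it assumes `process_data` can be called on an AnnData slice.

```python
# Illustrative 90/10 split of the AnnData object and its encoded labels
n_train = int(0.9 * ann_data.n_obs)
train_dataset = uce_fine_tune.process_data(ann_data[:n_train].copy())
val_dataset = uce_fine_tune.process_data(ann_data[n_train:].copy())

uce_fine_tune.train(
    train_input_data=train_dataset,
    train_labels=cell_types[:n_train],
    validation_input_data=val_dataset,
    validation_labels=cell_types[n_train:],
)
```

Note that the per-batch label slicing in the loops above assumes the processed dataset preserves the original cell order.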

def get_outputs(
Expand Down Expand Up @@ -239,7 +237,7 @@ def get_outputs(
if self.accelerator is not None:
dataloader = self.accelerator.prepare(dataloader)

self.uce_model.eval()
self.model.eval()
self.fine_tuning_head.eval()

testing_loop = tqdm(dataloader, desc="Fine-Tuning Validation")
@@ -248,9 +246,9 @@
batch_sentences, mask, idxs = validation_data[0], validation_data[1], validation_data[2]
batch_sentences = batch_sentences.permute(1, 0)
if self.config["multi_gpu"]:
batch_sentences = self.uce_model.module.pe_embedding(batch_sentences.long())
batch_sentences = self.model.module.pe_embedding(batch_sentences.long())
else:
batch_sentences = self.uce_model.pe_embedding(batch_sentences.long())
batch_sentences = self.model.pe_embedding(batch_sentences.long())
batch_sentences = torch.nn.functional.normalize(batch_sentences, dim=2) # normalize token outputs
output = self._forward(batch_sentences, mask=mask)
outputs.append(output.detach().cpu().numpy())
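
After fine-tuning, `get_outputs` runs the same preprocessing and forward pass in eval mode and collects the head outputs batch by batch. A hedged usage sketch, assuming it accepts the processed dataset and returns the stacked outputs as a NumPy array of shape `(n_cells, output_size)`:

```python
import numpy as np

# Head outputs for every cell in the processed dataset
outputs = uce_fine_tune.get_outputs(dataset)

# For a classification head, map predicted class ids back to cell type names
# by inverting the class_id_dict built in the model-card example
id_to_class = {i: label for label, i in class_id_dict.items()}
predictions = [id_to_class[i] for i in np.argmax(outputs, axis=1)]
print(predictions[:5])
```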
