
Commit

Update Notebook
mattwoodx committed Oct 7, 2024
1 parent 1532b0d commit 13b3f70
Showing 2 changed files with 362 additions and 59 deletions.
383 changes: 343 additions & 40 deletions examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb

Large diffs are not rendered by default.

38 changes: 19 additions & 19 deletions helical/models/scgpt/fine_tuning_model.py
@@ -43,7 +43,7 @@ def __init__(self,
output_size: Optional[int]=None):
HelicalBaseFineTuningModel.__init__(self, fine_tuning_head, output_size)
scGPT.__init__(self, scGPT_config)

self.fine_tuning_head.set_dim_size(self.config["embsize"])

def _forward(self,
@@ -193,25 +193,25 @@ def train(
                training_loop.set_postfix({"loss": batch_loss/batches_processed})
                training_loop.set_description(f"Fine-Tuning: epoch {j+1}/{epochs}")

-                if lr_scheduler is not None:
-                    lr_scheduler.step()
+            if lr_scheduler is not None:
+                lr_scheduler.step()

-                if validation_input_data is not None:
-                    testing_loop = tqdm(validation_data_loader, desc="Fine-Tuning Validation")
-                    val_loss = 0.0
-                    count = 0.0
-                    validation_batch_count = 0
-                    for validation_data_dict in testing_loop:
-                        input_gene_ids = validation_data_dict["gene"].to(device)
-                        src_key_padding_mask = input_gene_ids.eq(
-                            self.vocab[self.config["pad_token"]]
-                        )
-                        output = self._forward(input_gene_ids, validation_data_dict, src_key_padding_mask, use_batch_labels, device)
-                        val_labels = torch.tensor(validation_labels[validation_batch_count: validation_batch_count + self.config["batch_size"]], device=device)
-                        val_loss += loss_function(output, val_labels).item()
-                        validation_batch_count += self.config["batch_size"]
-                        count += 1.0
-                        testing_loop.set_postfix({"val_loss": val_loss/count})
+            if validation_input_data is not None:
+                testing_loop = tqdm(validation_data_loader, desc="Fine-Tuning Validation")
+                val_loss = 0.0
+                count = 0.0
+                validation_batch_count = 0
+                for validation_data_dict in testing_loop:
+                    input_gene_ids = validation_data_dict["gene"].to(device)
+                    src_key_padding_mask = input_gene_ids.eq(
+                        self.vocab[self.config["pad_token"]]
+                    )
+                    output = self._forward(input_gene_ids, validation_data_dict, src_key_padding_mask, use_batch_labels, device)
+                    val_labels = torch.tensor(validation_labels[validation_batch_count: validation_batch_count + self.config["batch_size"]], device=device)
+                    val_loss += loss_function(output, val_labels).item()
+                    validation_batch_count += self.config["batch_size"]
+                    count += 1.0
+                    testing_loop.set_postfix({"val_loss": val_loss/count})
logger.info(f"Fine-Tuning Complete. Epochs: {epochs}")

def get_outputs(

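The second hunk moves lr_scheduler.step() and the validation pass out of the inner batch loop, so both now run once per epoch rather than once per batch. Below is a minimal sketch of that loop structure using a generic PyTorch model and data loaders rather than the helical API; every name in it is an illustrative assumption, not taken from the repository.

# Sketch only: per-epoch scheduler step and validation, as in the hunk above.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
loss_function = nn.CrossEntropyLoss()

train_loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=16
)
val_loader = DataLoader(
    TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,))), batch_size=16
)

epochs = 3
for epoch in range(epochs):
    model.train()
    for inputs, labels in train_loader:  # inner batch loop: optimisation only
        optimizer.zero_grad()
        loss = loss_function(model(inputs), labels)
        loss.backward()
        optimizer.step()

    if lr_scheduler is not None:  # stepped once per epoch, after the batch loop
        lr_scheduler.step()

    model.eval()
    val_loss, batches = 0.0, 0
    with torch.no_grad():  # validation also runs once per epoch
        for inputs, labels in val_loader:
            val_loss += loss_function(model(inputs), labels).item()
            batches += 1
    print(f"epoch {epoch + 1}/{epochs}: val_loss {val_loss / batches:.4f}")

Stepping the scheduler and validating per epoch keeps the learning-rate schedule aligned with the number of epochs and avoids paying the validation cost on every training batch.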