Skip to content

Commit

Permalink
Update examples for UCE
Browse files Browse the repository at this point in the history
  • Loading branch information
mattwoodx committed Oct 7, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent ab53dbe commit 1532b0d
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions examples/fine_tune_models/fine_tune_UCE.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
from helical import UCEConfig, UCE, UCEFineTuningModel
from helical import UCEConfig, UCEFineTuningModel
from helical.utils import get_anndata_from_hf_dataset
from datasets import load_dataset
from omegaconf import DictConfig
import hydra

@hydra.main(version_base=None, config_path="../run_models/configs", config_name="uce_config")
def run_fine_tuning(cfg: DictConfig):
uce_config=UCEConfig(**cfg)
uce = UCE(configurer=uce_config)

hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
ann_data = get_anndata_from_hf_dataset(hf_dataset)

dataset = uce.process_data(ann_data[:10], name="train")

cell_types = ann_data.obs["LVL1"][:10].tolist()

label_set = set(cell_types)

uce_config=UCEConfig(**cfg)
uce_fine_tune = UCEFineTuningModel(uce_config=uce_config, fine_tuning_head="classification", output_size=len(label_set))

dataset = uce_fine_tune.process_data(ann_data[:10], name="train")

class_id_dict = {label: i for i, label in enumerate(label_set)}
cell_types = [class_id_dict[cell] for cell in cell_types]

uce_fine_tune = UCEFineTuningModel(uce_model=uce, fine_tuning_head="classification", output_size=len(label_set))
uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types)

if __name__ == "__main__":

0 comments on commit 1532b0d

Please sign in to comment.