From 1532b0dce92a5c614bd2812d3d4be152bc0a18c0 Mon Sep 17 00:00:00 2001 From: Matthew Wood Date: Mon, 7 Oct 2024 15:08:25 +0200 Subject: [PATCH] Update examples for UCE --- examples/fine_tune_models/fine_tune_UCE.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/fine_tune_models/fine_tune_UCE.py b/examples/fine_tune_models/fine_tune_UCE.py index e54810aa..342f82ee 100644 --- a/examples/fine_tune_models/fine_tune_UCE.py +++ b/examples/fine_tune_models/fine_tune_UCE.py @@ -1,4 +1,4 @@ -from helical import UCEConfig, UCE, UCEFineTuningModel +from helical import UCEConfig, UCEFineTuningModel from helical.utils import get_anndata_from_hf_dataset from datasets import load_dataset from omegaconf import DictConfig @@ -6,21 +6,21 @@ @hydra.main(version_base=None, config_path="../run_models/configs", config_name="uce_config") def run_fine_tuning(cfg: DictConfig): - uce_config=UCEConfig(**cfg) - uce = UCE(configurer=uce_config) - hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists") ann_data = get_anndata_from_hf_dataset(hf_dataset) - dataset = uce.process_data(ann_data[:10], name="train") - cell_types = ann_data.obs["LVL1"][:10].tolist() label_set = set(cell_types) + + uce_config=UCEConfig(**cfg) + uce_fine_tune = UCEFineTuningModel(uce_config=uce_config, fine_tuning_head="classification", output_size=len(label_set)) + + dataset = uce_fine_tune.process_data(ann_data[:10], name="train") + class_id_dict = {label: i for i, label in enumerate(label_set)} cell_types = [class_id_dict[cell] for cell in cell_types] - uce_fine_tune = UCEFineTuningModel(uce_model=uce, fine_tuning_head="classification", output_size=len(label_set)) uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types) if __name__ == "__main__":