diff --git a/ci/tests/test_hyena_dna/test_model.py b/ci/tests/test_hyena_dna/test_model.py new file mode 100644 index 00000000..248d44ed --- /dev/null +++ b/ci/tests/test_hyena_dna/test_model.py @@ -0,0 +1,39 @@ +from helical.models.hyena_dna.model import HyenaDNA,HyenaDNAConfig +import pytest +import pytest + +@pytest.mark.parametrize("model_name, d_model, d_inner", [ + ("hyenadna-tiny-1k-seqlen", 128, 512), + ("hyenadna-tiny-1k-seqlen-d256", 256, 1024) +]) +def test_hyena_dna__ok(model_name, d_model, d_inner): + """ + Test case for the HyenaDNA class initialization. + + Args: + model_name (str): The name of the model. + d_model (int): The dimensionality of the model. + d_inner (int): The dimensionality of the inner layers. + """ + configurer = HyenaDNAConfig(model_name=model_name) + model = HyenaDNA(configurer=configurer) + assert model.config["model_name"] == model_name + assert model.config["d_model"] == d_model + assert model.config["d_inner"] == d_inner + +@pytest.mark.parametrize("model_name", [ + ("wrong_name") +]) +def test_hyena_dna__nok(model_name): + """ + Test case when an invalid model name is provided. + Verifies that a ValueError is raised when an invalid model name is passed to the HyenaDNAConfig constructor. + + Parameters: + - model_name (str): The invalid model name. + + Raises: + - ValueError: If the model name is invalid. + """ + with pytest.raises(ValueError): + HyenaDNAConfig(model_name=model_name) diff --git a/helical/models/hyena_dna/model.py b/helical/models/hyena_dna/model.py index 1bb0de1a..7fbee700 100644 --- a/helical/models/hyena_dna/model.py +++ b/helical/models/hyena_dna/model.py @@ -13,20 +13,20 @@ class HyenaDNA(HelicalBaseModel): """HyenaDNA model.""" - default_config = HyenaDNAConfig() + default_configurer = HyenaDNAConfig() - def __init__(self, model_dir: Optional[str] = None, model_config: HyenaDNAConfig = default_config) -> None: + def __init__(self, model_dir: Optional[str] = None, configurer: HyenaDNAConfig = default_configurer) -> None: super().__init__() - self.model_config = model_config.config + self.config = configurer.config self.log = logging.getLogger("Hyena-DNA-Model") if model_dir is None: self.downloader = Downloader() - model_path = f"hyena_dna/{self.model_config['model_name']}.ckpt" + model_path = f"hyena_dna/{self.config['model_name']}.ckpt" self.downloader.download_via_name(model_path) self.model_path = Path(os.path.join(self.downloader.CACHE_DIR_HELICAL, model_path)) else: - self.model_path = Path(os.path.join(model_dir, f"{self.model_config['model_name']}.ckpt")) + self.model_path = Path(os.path.join(model_dir, f"{self.config['model_name']}.ckpt"))