Skip to content

Commit

Permalink
Add more unittests to hyena configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
bputzeys committed May 20, 2024
1 parent 5b17f54 commit c1ee014
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
39 changes: 39 additions & 0 deletions ci/tests/test_hyena_dna/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from helical.models.hyena_dna.model import HyenaDNA,HyenaDNAConfig
import pytest
import pytest

@pytest.mark.parametrize("model_name, d_model, d_inner", [
("hyenadna-tiny-1k-seqlen", 128, 512),
("hyenadna-tiny-1k-seqlen-d256", 256, 1024)
])
def test_hyena_dna__ok(model_name, d_model, d_inner):
"""
Test case for the HyenaDNA class initialization.
Args:
model_name (str): The name of the model.
d_model (int): The dimensionality of the model.
d_inner (int): The dimensionality of the inner layers.
"""
configurer = HyenaDNAConfig(model_name=model_name)
model = HyenaDNA(configurer=configurer)
assert model.config["model_name"] == model_name
assert model.config["d_model"] == d_model
assert model.config["d_inner"] == d_inner

@pytest.mark.parametrize("model_name", [
("wrong_name")
])
def test_hyena_dna__nok(model_name):
"""
Test case when an invalid model name is provided.
Verifies that a ValueError is raised when an invalid model name is passed to the HyenaDNAConfig constructor.
Parameters:
- model_name (str): The invalid model name.
Raises:
- ValueError: If the model name is invalid.
"""
with pytest.raises(ValueError):
HyenaDNAConfig(model_name=model_name)
10 changes: 5 additions & 5 deletions helical/models/hyena_dna/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@

class HyenaDNA(HelicalBaseModel):
"""HyenaDNA model."""
default_config = HyenaDNAConfig()
default_configurer = HyenaDNAConfig()

def __init__(self, model_dir: Optional[str] = None, model_config: HyenaDNAConfig = default_config) -> None:
def __init__(self, model_dir: Optional[str] = None, configurer: HyenaDNAConfig = default_configurer) -> None:
super().__init__()
self.model_config = model_config.config
self.config = configurer.config
self.log = logging.getLogger("Hyena-DNA-Model")

if model_dir is None:
self.downloader = Downloader()
model_path = f"hyena_dna/{self.model_config['model_name']}.ckpt"
model_path = f"hyena_dna/{self.config['model_name']}.ckpt"
self.downloader.download_via_name(model_path)
self.model_path = Path(os.path.join(self.downloader.CACHE_DIR_HELICAL, model_path))
else:
self.model_path = Path(os.path.join(model_dir, f"{self.model_config['model_name']}.ckpt"))
self.model_path = Path(os.path.join(model_dir, f"{self.config['model_name']}.ckpt"))



Expand Down

0 comments on commit c1ee014

Please sign in to comment.