fix: update modules to use updated constant name
emapco committed Feb 9, 2025
1 parent 30e0d57 commit 0d9b12b
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions chem_mrl/benchmark/DatabaseBenchmark.py
@@ -5,9 +5,9 @@
 from sqlalchemy import create_engine, text

 from chem_mrl.constants import (
+    BASE_MODEL_HIDDEN_DIM,
     BASE_MODEL_NAME,
     CHEM_MRL_DIMENSIONS,
-    EMBEDDING_MODEL_HIDDEN_DIM,
 )
 from chem_mrl.molecular_embedder import ChemMRL
 from chem_mrl.molecular_fingerprinter import MorganFingerprinter
@@ -144,7 +144,7 @@ def run_benchmark(
     model_name: str,
     model_mrl_dimensions: list[int] = CHEM_MRL_DIMENSIONS,
     base_model_name: str = BASE_MODEL_NAME,
-    base_model_hidden_dim: int = EMBEDDING_MODEL_HIDDEN_DIM,
+    base_model_hidden_dim: int = BASE_MODEL_HIDDEN_DIM,
     smiles_column_name: str = "smiles",
 ):
     print("Starting benchmark...")
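For reference, a minimal sketch of how the renamed default surfaces at the call site. This assumes run_benchmark is importable as a module-level function from chem_mrl.benchmark.DatabaseBenchmark, that the model identifier shown is hypothetical, and that any required arguments not visible in this hunk are omitted:

from chem_mrl.benchmark.DatabaseBenchmark import run_benchmark
from chem_mrl.constants import BASE_MODEL_HIDDEN_DIM

# base_model_hidden_dim now defaults to the renamed constant, so passing
# it explicitly is equivalent to omitting it.
run_benchmark(
    model_name="path/to/finetuned-chem-mrl",  # hypothetical model identifier
    base_model_hidden_dim=BASE_MODEL_HIDDEN_DIM,
    smiles_column_name="smiles",
)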
12 changes: 6 additions & 6 deletions chem_mrl/molecular_embedder.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from sentence_transformers import SentenceTransformer, models

-from chem_mrl.constants import BASE_MODEL_NAME, EMBEDDING_MODEL_HIDDEN_DIM
+from chem_mrl.constants import BASE_MODEL_HIDDEN_DIM, BASE_MODEL_NAME


 class ChemMRL:
@@ -33,19 +33,19 @@ def __init__(
         """
         if fp_size is not None and fp_size < 32:
             raise ValueError("fp_size must be greater than 32")
-        if fp_size is not None and fp_size > EMBEDDING_MODEL_HIDDEN_DIM:
-            raise ValueError(f"fp_size must be less than {EMBEDDING_MODEL_HIDDEN_DIM}")
+        if fp_size is not None and fp_size > BASE_MODEL_HIDDEN_DIM:
+            raise ValueError(f"fp_size must be less than {BASE_MODEL_HIDDEN_DIM}")
         self._model_name = model_name
         self._fp_size = fp_size
         self._use_half_precision = use_half_precision
         self._device = device
         self._batch_size = batch_size
         if normalize_embeddings is None:
-            normalize_embeddings = fp_size is not None and fp_size < EMBEDDING_MODEL_HIDDEN_DIM
+            normalize_embeddings = fp_size is not None and fp_size < BASE_MODEL_HIDDEN_DIM
         self._normalize_embeddings = normalize_embeddings

         if model_name == BASE_MODEL_NAME:
-            if fp_size is not None and fp_size != EMBEDDING_MODEL_HIDDEN_DIM:
+            if fp_size is not None and fp_size != BASE_MODEL_HIDDEN_DIM:
                 raise ValueError(f"{BASE_MODEL_NAME} only supports embeddings of size 768")
             word_embedding_model = models.Transformer(model_name)
             pooling_model = models.Pooling(
@@ -56,7 +56,7 @@ def __init__(
                 device=device,
             )
         else:
-            enable_truncate_dim = fp_size is not None and fp_size < EMBEDDING_MODEL_HIDDEN_DIM
+            enable_truncate_dim = fp_size is not None and fp_size < BASE_MODEL_HIDDEN_DIM
             self._model = SentenceTransformer(
                 model_name,
                 device=device,
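To illustrate the constraints the renamed constant enforces, a minimal sketch of constructing ChemMRL. The constructor keywords come from this diff; the checkpoint path is hypothetical, and BASE_MODEL_HIDDEN_DIM is assumed to be 768 per the error message above:

from chem_mrl.constants import BASE_MODEL_HIDDEN_DIM, BASE_MODEL_NAME
from chem_mrl.molecular_embedder import ChemMRL

# The base model only supports full-size embeddings, so fp_size must be
# None or exactly BASE_MODEL_HIDDEN_DIM (768).
base_embedder = ChemMRL(model_name=BASE_MODEL_NAME, fp_size=BASE_MODEL_HIDDEN_DIM)

# For other checkpoints, fp_size may range from 32 to BASE_MODEL_HIDDEN_DIM;
# any value strictly below the hidden dimension enables truncation to that
# size and, by default, normalization of the embeddings.
mrl_embedder = ChemMRL(
    model_name="path/to/chem-mrl-checkpoint",  # hypothetical checkpoint
    fp_size=256,
)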
