Skip to content

Commit

Permalink
loading vectorizer and label_enc in _load_model
Browse files Browse the repository at this point in the history
  • Loading branch information
saanikat committed Sep 16, 2024
1 parent b71112f commit 50bc889
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 141 deletions.
3 changes: 3 additions & 0 deletions bedms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
"""
This module initializes 'bedms' package.
"""
from .attr_standardizer import AttrStandardizer
98 changes: 74 additions & 24 deletions bedms/attr_standardizer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""
This module has the class AttrStandardizer for 'bedms'.
"""
import logging
from typing import Dict, Tuple, Union

import pickle
import peppy
import torch
import torch.nn as nn
from torch import nn
import torch.nn.functional as torch_functional

from .const import (
Expand All @@ -20,6 +23,13 @@
OUTPUT_SIZE_FAIRTRACKS,
PROJECT_NAME,
SENTENCE_TRANSFORMER_MODEL,
REPO_ID,
ENCODE_VECTORIZER_FILENAME,
ENCODE_LABEL_ENCODER_FILENAME,
FAIRTRACKS_VECTORIZER_FILENAME,
FAIRTRACKS_LABEL_ENCODER_FILENAME,
BEDBASE_VECTORIZER_FILENAME,
BEDBASE_LABEL_ENCODER_FILENAME,
)
from .model import BoWSTModel
from .utils import (
Expand All @@ -28,13 +38,17 @@
fetch_from_pephub,
get_any_pep,
load_from_huggingface,
hf_hub_download,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(PROJECT_NAME)


class AttrStandardizer:
"""
This is the AttrStandardizer class which holds the models for Attribute Standardization.
"""
def __init__(self, schema: str, confidence: int = CONFIDENCE_THRESHOLD) -> None:
"""
Initializes the attribute standardizer with user provided schema, loads the model.
Expand All @@ -43,7 +57,7 @@ def __init__(self, schema: str, confidence: int = CONFIDENCE_THRESHOLD) -> None:
:param int confidence: Confidence threshold for the predictions.
"""
self.schema = schema
self.model = self._load_model()
self.model, self.vectorizer, self.label_encoder = self._load_model()
self.conf_threshold = confidence

def _get_parameters(self) -> Tuple[int, int, int, int, int, float]:
Expand All @@ -61,7 +75,7 @@ def _get_parameters(self) -> Tuple[int, int, int, int, int, float]:
OUTPUT_SIZE_ENCODE,
DROPOUT_PROB,
)
elif self.schema == "FAIRTRACKS":
if self.schema == "FAIRTRACKS":
return (
INPUT_SIZE_BOW_FAIRTRACKS,
EMBEDDING_SIZE,
Expand All @@ -70,7 +84,7 @@ def _get_parameters(self) -> Tuple[int, int, int, int, int, float]:
OUTPUT_SIZE_FAIRTRACKS,
DROPOUT_PROB,
)
elif self.schema == "BEDBASE":
if self.schema == "BEDBASE":
return (
INPUT_SIZE_BOW_BEDBASE,
EMBEDDING_SIZE,
Expand All @@ -79,17 +93,50 @@ def _get_parameters(self) -> Tuple[int, int, int, int, int, float]:
OUTPUT_SIZE_BEDBASE,
DROPOUT_PROB,
)
else:
raise ValueError(
f"Schema not available: {self.schema}. Presently, three schemas are available: ENCODE , FAIRTRACKS, BEDBASE"

raise ValueError(
f"Schema not available: {self.schema}."
"Presently, three schemas are available: ENCODE , FAIRTRACKS, BEDBASE"
)

def _load_model(self) -> nn.Module:
def _load_model(self) -> tuple[nn.Module, object, object]:
"""
Calls function to load the model from HuggingFace repository and sets to eval().
Calls function to load the model from HuggingFace repository
load vectorizer and label encoder and sets to eval().
:return nn.Module: Loaded Neural Network Model.
:return object: The scikit learn vectorizer for bag of words encoding.
:return object: Label encoder object for the labels (y).
"""
try:
if self.schema == "ENCODE":
filename_vc = ENCODE_VECTORIZER_FILENAME
filename_lb = ENCODE_LABEL_ENCODER_FILENAME
elif self.schema == "FAIRTRACKS":
filename_vc = FAIRTRACKS_VECTORIZER_FILENAME
filename_lb = FAIRTRACKS_LABEL_ENCODER_FILENAME
elif self.schema == "BEDBASE":
filename_vc = BEDBASE_VECTORIZER_FILENAME
filename_lb = BEDBASE_LABEL_ENCODER_FILENAME

vectorizer = None
label_encoder = None

vc_path = hf_hub_download(
repo_id=REPO_ID,
filename=filename_vc,
)

with open(vc_path, "rb") as f:
vectorizer = pickle.load(f)

lb_path = hf_hub_download(
repo_id=REPO_ID,
filename=filename_lb,
)

with open(lb_path, "rb") as f:
label_encoder = pickle.load(f)

model = load_from_huggingface(self.schema)
state_dict = torch.load(model)

Expand All @@ -112,7 +159,7 @@ def _load_model(self) -> nn.Module:
)
model.load_state_dict(state_dict)
model.eval()
return model
return model, vectorizer, label_encoder

except Exception as e:
logger.error(f"Error loading the model: {str(e)}")
Expand All @@ -122,7 +169,9 @@ def standardize(
self, pep: Union[str, peppy.Project]
) -> Dict[str, Dict[str, float]]:
"""
Fetches the user provided PEP from the PEPHub registry path, returns the predictions.
Fetches the user provided PEP
from the PEPHub registry path,
returns the predictions.
:param str pep: peppy.Project object or PEPHub registry path to PEP.
:return Dict[str, Dict[str, float]]: Suggestions to the user.
Expand All @@ -138,30 +187,31 @@ def standardize(
try:
csv_file = fetch_from_pephub(pep)

X_values_st, X_headers_st, X_values_bow, num_rows = data_preprocessing(
x_values_st, x_headers_st, x_values_bow, num_rows = data_preprocessing(
csv_file
)
(
X_headers_embeddings_tensor,
X_values_embeddings_tensor,
X_values_bow_tensor,
x_headers_embeddings_tensor,
x_values_embeddings_tensor,
x_values_bow_tensor,
label_encoder,
) = data_encoding(
self.vectorizer,
self.label_encoder,
num_rows,
X_values_st,
X_headers_st,
X_values_bow,
self.schema,
x_values_st,
x_headers_st,
x_values_bow,
model_name=SENTENCE_TRANSFORMER_MODEL,
)

logger.info("Data Preprocessing completed.")

with torch.no_grad():
outputs = self.model(
X_values_bow_tensor,
X_values_embeddings_tensor,
X_headers_embeddings_tensor,
x_values_bow_tensor,
x_values_embeddings_tensor,
x_headers_embeddings_tensor,
)
probabilities = torch_functional.softmax(outputs, dim=1)

Expand All @@ -174,7 +224,7 @@ def standardize(
]

suggestions = {}
for i, category in enumerate(X_headers_st):
for i, category in enumerate(x_headers_st):
category_suggestions = {}
if top_confidences[i][0] >= self.conf_threshold:
for j in range(3):
Expand Down
6 changes: 5 additions & 1 deletion bedms/const.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""
This module contains constant values used in the 'bedms' package.
"""

PROJECT_NAME = "bedmess"

AVAILABLE_SCHEMAS = ["ENCODE", "FAIRTRACKS", "BEDBASE"]

PEP_FILE_TYPES = ["yaml", "csv"]
REPO_ID = "databio/attribute-standardizer-model6"
MODEL_ENCODE = "model_encode.pth"
MODEL_FAIRTRACKS = "model_fairtracks.pth"
Expand Down
6 changes: 4 additions & 2 deletions bedms/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def __init__(
Initializes the BoWSTModel.
:param int input_size_values: Size of the input for the values (BoW).
:param int inout_size_values_embeddings: Size of the input for the values sentence transformer embeddings.
:param int input_size_headers: Size of the input for the headers with sentence transformer embeddings.
:param int inout_size_values_embeddings: Size of the input
for the values sentence transformer embeddings.
:param int input_size_headers: Size of the input
for the headers with sentence transformer embeddings.
:param int hidden_size: Size of the hidden layer.
:param int output_size: Size of the output layer.
:param float dropout_prob: Dropout probability for regularization.
Expand Down
Loading

0 comments on commit 50bc889

Please sign in to comment.