
Fixing torch-crf ignoring params
vrdn-23 committed Aug 2, 2022
1 parent 834e23f commit 6b45122
Showing 3 changed files with 14 additions and 7 deletions.
4 changes: 2 additions & 2 deletions mindmeld/components/_config.py
@@ -783,8 +783,8 @@ def get_classifier_config(
                 copy.deepcopy(getattr(module_conf, attr_name)))
         except AttributeError:
             try:
-                result = copy.deepcopy(
-                    getattr(module_conf, CONFIG_DEPRECATION_MAPPING[attr_name])
+                result = merge_param_configs(_get_default_classifier_config(clf_type), copy.deepcopy(
+                    getattr(module_conf, CONFIG_DEPRECATION_MAPPING[attr_name]))
                 )
                 msg = (
                     "%s config is deprecated. Please use the equivalent %s config "
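
Before this change, a deprecated config object replaced the default classifier config wholesale, so any key it did not set was lost rather than falling back to a default. merge_param_configs and _get_default_classifier_config are mindmeld internals; the stand-in below is only a sketch of the merge semantics this fix appears to rely on (user values win, missing keys keep their defaults), not the library's actual implementation:

import copy

def merge_params_sketch(default_config, user_config):
    # Hypothetical stand-in for mindmeld's merge_param_configs: keys the
    # user config omits fall back to the defaults instead of disappearing.
    merged = copy.deepcopy(default_config)
    for key, value in user_config.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_params_sketch(merged[key], value)
        else:
            merged[key] = value
    return merged

default = {'model_type': 'tagger', 'params': {'feat_type': 'hash', 'patience': 3}}
deprecated = {'params': {'patience': 5}}

print(merge_params_sketch(default, deprecated))
# {'model_type': 'tagger', 'params': {'feat_type': 'hash', 'patience': 5}}
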
4 changes: 2 additions & 2 deletions mindmeld/components/classifier.py
@@ -403,6 +403,8 @@ def _get_model_config(loaded_config=None, **kwargs):
         Returns:
             ModelConfig: The model configuration corresponding to the provided config name
         """
+        if 'params' in loaded_config and 'params' in kwargs:
+            kwargs['params'] = {**loaded_config['params'], **kwargs['params']}
         try:
             # If all params required for model config were passed in, use kwargs
             return ModelConfig(**kwargs)
@@ -411,8 +413,6 @@ def _get_model_config(loaded_config=None, **kwargs):
             if not loaded_config:
                 logger.warning("loaded_config is not passed in")
             model_config = loaded_config or {}
-            if 'params' in model_config and 'params' in kwargs:
-                kwargs['params'].update(model_config['params'])

             model_config.update(kwargs)
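
The substance of this hunk is merge order. The removed lines ran only on the fallback path and let the loaded config's params overwrite values the caller passed explicitly (dict.update gives its argument precedence); the added lines run before the ModelConfig(**kwargs) fast path and give the caller's kwargs the last word. A quick plain-Python check of the two behaviors:

loaded = {'params': {'patience': 3, 'optimizer': 'sgd'}}
passed = {'params': {'patience': 5}}

# Old behavior: update() lets the loaded config clobber the caller's value.
old = dict(passed['params'])
old.update(loaded['params'])
print(old)  # {'patience': 3, 'optimizer': 'sgd'} -- the explicit patience=5 is lost

# New behavior: the right-most unpacking wins, so explicit kwargs take precedence.
new = {**loaded['params'], **passed['params']}
print(new)  # {'patience': 5, 'optimizer': 'sgd'}
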
13 changes: 10 additions & 3 deletions mindmeld/models/taggers/pytorch_crf.py
@@ -294,8 +294,15 @@ def save_best_weights_path(self, path):
         else:
             raise MindMeldError("CRF weights not saved. Please re-train model from scratch.")

-    def validate_params(self):
+    def validate_params(self, kwargs):
         """Validate the argument values saved into the CRF model. """
+        for key in kwargs:
+            msg = (
+                "Unexpected param `{param}`, dropping it from model config.".format(
+                    param=key
+                )
+            )
+            logger.warning(msg)
         if self.optimizer not in ["sgd", "adam"]:
             raise MindMeldError(
                 f"Optimizer type {self.optimizer_type} not supported. Supported options are ['sgd', 'adam']")
@@ -431,7 +438,7 @@ def compute_marginal_probabilities(self, inputs, mask):
     # pylint: disable=too-many-arguments
     def set_params(self, feat_type="hash", feat_num=50000, stratify_train_val_split=True, drop_input=0.2, batch_size=8,
                    number_of_epochs=100, patience=3, dev_split_ratio=0.2, optimizer="sgd",
-                   random_state=None):
+                   random_state=None, **kwargs):
         """Set the parameters for the PyTorch CRF model and also validates the parameters.

         Args:
@@ -459,7 +466,7 @@ def set_params(self, feat_type="hash", feat_num=50000, stratify_train_val_split=
         self.optimizer = optimizer  # ["sgd", "adam"]
         self.random_state = random_state or randint(1, 10000001)

-        self.validate_params()
+        self.validate_params(kwargs)

         logger.debug("Random state for torch-crf is %s", self.random_state)
         if self.feat_type == "dict":
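
With **kwargs absorbing anything set_params does not declare, an unknown key in the params config is now warned about and dropped, where passing it as a keyword argument previously raised a TypeError. A minimal runnable sketch of the pattern; the trimmed class and the es_patience key are invented for illustration, and the real method lives in mindmeld/models/taggers/pytorch_crf.py:

import logging

logger = logging.getLogger(__name__)

class TorchCrfSketch:
    def set_params(self, optimizer="sgd", patience=3, **kwargs):
        # Known params are bound explicitly; everything else lands in kwargs.
        self.optimizer = optimizer
        self.patience = patience
        self.validate_params(kwargs)

    def validate_params(self, kwargs):
        # Warn about (and effectively drop) any leftover, unrecognized params.
        for key in kwargs:
            logger.warning("Unexpected param `%s`, dropping it from model config.", key)
        if self.optimizer not in ["sgd", "adam"]:
            raise ValueError(f"Optimizer type {self.optimizer} not supported.")

model = TorchCrfSketch()
# Without **kwargs this call would raise:
#   TypeError: set_params() got an unexpected keyword argument 'es_patience'
model.set_params(optimizer="adam", es_patience=4)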
