Skip to content

Commit

Permalink
Merge pull request #429 from cisco/vidamoda/ignore_params_torch_crf
Browse files Browse the repository at this point in the history
Fixing torch-crf ignoring params
  • Loading branch information
vrdn-23 authored Aug 2, 2022
2 parents 834e23f + 8fdc245 commit cc107cf
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
4 changes: 2 additions & 2 deletions mindmeld/components/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,8 +783,8 @@ def get_classifier_config(
copy.deepcopy(getattr(module_conf, attr_name)))
except AttributeError:
try:
result = copy.deepcopy(
getattr(module_conf, CONFIG_DEPRECATION_MAPPING[attr_name])
result = merge_param_configs(_get_default_classifier_config(clf_type), copy.deepcopy(
getattr(module_conf, CONFIG_DEPRECATION_MAPPING[attr_name]))
)
msg = (
"%s config is deprecated. Please use the equivalent %s config "
Expand Down
4 changes: 2 additions & 2 deletions mindmeld/components/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ def _get_model_config(loaded_config=None, **kwargs):
Returns:
ModelConfig: The model configuration corresponding to the provided config name
"""
if 'params' in loaded_config and 'params' in kwargs:
kwargs['params'] = {**loaded_config['params'], **kwargs['params']}
try:
# If all params required for model config were passed in, use kwargs
return ModelConfig(**kwargs)
Expand All @@ -411,8 +413,6 @@ def _get_model_config(loaded_config=None, **kwargs):
if not loaded_config:
logger.warning("loaded_config is not passed in")
model_config = loaded_config or {}
if 'params' in model_config and 'params' in kwargs:
kwargs['params'].update(model_config['params'])

model_config.update(kwargs)

Expand Down
13 changes: 10 additions & 3 deletions mindmeld/models/taggers/pytorch_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,15 @@ def save_best_weights_path(self, path):
else:
raise MindMeldError("CRF weights not saved. Please re-train model from scratch.")

def validate_params(self):
def validate_params(self, kwargs):
"""Validate the argument values saved into the CRF model. """
for key in kwargs:
msg = (
"Unexpected param `{param}`, dropping it from model config.".format(
param=key
)
)
logger.warning(msg)
if self.optimizer not in ["sgd", "adam"]:
raise MindMeldError(
f"Optimizer type {self.optimizer_type} not supported. Supported options are ['sgd', 'adam']")
Expand Down Expand Up @@ -431,7 +438,7 @@ def compute_marginal_probabilities(self, inputs, mask):
# pylint: disable=too-many-arguments
def set_params(self, feat_type="hash", feat_num=50000, stratify_train_val_split=True, drop_input=0.2, batch_size=8,
number_of_epochs=100, patience=3, dev_split_ratio=0.2, optimizer="sgd",
random_state=None):
random_state=None, **kwargs):
"""Set the parameters for the PyTorch CRF model and also validates the parameters.
Args:
Expand Down Expand Up @@ -459,7 +466,7 @@ def set_params(self, feat_type="hash", feat_num=50000, stratify_train_val_split=
self.optimizer = optimizer # ["sgd", "adam"]
self.random_state = random_state or randint(1, 10000001)

self.validate_params()
self.validate_params(kwargs)

logger.debug("Random state for torch-crf is %s", self.random_state)
if self.feat_type == "dict":
Expand Down
3 changes: 1 addition & 2 deletions tests/models/test_tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def test_get_boundary_counts_sequential(

@pytest.mark.parametrize(
"model_type,params",
[("memm", {"penalty": "l2", "C": 10000}), ("crf", {"c1": 0.01, "c2": 0.01}), ("torch-crf", {"feat_type": "dict"}),
[("memm", {"penalty": "l2", "C": 10000}), ("torch-crf", {"feat_type": "dict"}),
("torch-crf", {"feat_type": "hash"})],
)
def test_view_extracted_features(kwik_e_mart_nlp, model_type, params):
Expand Down Expand Up @@ -311,7 +311,6 @@ def test_view_extracted_features(kwik_e_mart_nlp, model_type, params):
"query,model_type,params",
[
("Main st store hours", "memm", {"penalty": "l2", "C": 10000}),
("Main st store hours", "crf", {"c1": 0.01, "c2": 0.01}),
("Main st store hours", "torch-crf", {"feat_type": "dict"}),
("Main st store hours", "torch-crf", {"feat_type": "hash"})
],
Expand Down

0 comments on commit cc107cf

Please sign in to comment.