From 6b45122ebe8c6af9321bc2dec7adc2181f6acf5e Mon Sep 17 00:00:00 2001
From: Vinay Damodaran
Date: Tue, 2 Aug 2022 14:01:55 -0700
Subject: [PATCH] Fixing torch-crf ignoring params

---
 mindmeld/components/_config.py         |  4 ++--
 mindmeld/components/classifier.py      |  4 ++--
 mindmeld/models/taggers/pytorch_crf.py | 13 ++++++++++---
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/mindmeld/components/_config.py b/mindmeld/components/_config.py
index fe91c40eb..913681022 100644
--- a/mindmeld/components/_config.py
+++ b/mindmeld/components/_config.py
@@ -783,8 +783,8 @@ def get_classifier_config(
                 copy.deepcopy(getattr(module_conf, attr_name)))
         except AttributeError:
             try:
-                result = copy.deepcopy(
-                    getattr(module_conf, CONFIG_DEPRECATION_MAPPING[attr_name])
+                result = merge_param_configs(_get_default_classifier_config(clf_type), copy.deepcopy(
+                    getattr(module_conf, CONFIG_DEPRECATION_MAPPING[attr_name]))
                 )
                 msg = (
                     "%s config is deprecated. Please use the equivalent %s config "
diff --git a/mindmeld/components/classifier.py b/mindmeld/components/classifier.py
index 590141c02..26550c449 100644
--- a/mindmeld/components/classifier.py
+++ b/mindmeld/components/classifier.py
@@ -403,6 +403,8 @@ def _get_model_config(loaded_config=None, **kwargs):
     Returns:
         ModelConfig: The model configuration corresponding to the provided config name
     """
+    if 'params' in loaded_config and 'params' in kwargs:
+        kwargs['params'] = {**loaded_config['params'], **kwargs['params']}
     try:
         # If all params required for model config were passed in, use kwargs
         return ModelConfig(**kwargs)
@@ -411,8 +413,6 @@ def _get_model_config(loaded_config=None, **kwargs):
         if not loaded_config:
             logger.warning("loaded_config is not passed in")
         model_config = loaded_config or {}
-        if 'params' in model_config and 'params' in kwargs:
-            kwargs['params'].update(model_config['params'])
 
         model_config.update(kwargs)
 
diff --git a/mindmeld/models/taggers/pytorch_crf.py b/mindmeld/models/taggers/pytorch_crf.py
index 11f493431..428b0a544 100644
--- a/mindmeld/models/taggers/pytorch_crf.py
+++ b/mindmeld/models/taggers/pytorch_crf.py
@@ -294,8 +294,15 @@ def save_best_weights_path(self, path):
         else:
             raise MindMeldError("CRF weights not saved. Please re-train model from scratch.")
 
-    def validate_params(self):
+    def validate_params(self, kwargs):
         """Validate the argument values saved into the CRF model. """
+        for key in kwargs:
+            msg = (
+                "Unexpected param `{param}`, dropping it from model config.".format(
+                    param=key
+                )
+            )
+            logger.warning(msg)
         if self.optimizer not in ["sgd", "adam"]:
             raise MindMeldError(
                 f"Optimizer type {self.optimizer_type} not supported. Supported options are ['sgd', 'adam']")
@@ -431,7 +438,7 @@ def compute_marginal_probabilities(self, inputs, mask):
     # pylint: disable=too-many-arguments
     def set_params(self, feat_type="hash", feat_num=50000, stratify_train_val_split=True, drop_input=0.2,
                    batch_size=8, number_of_epochs=100, patience=3, dev_split_ratio=0.2, optimizer="sgd",
-                   random_state=None):
+                   random_state=None, **kwargs):
         """Set the parameters for the PyTorch CRF model and also validates the parameters.
 
         Args:
@@ -459,7 +466,7 @@
         self.optimizer = optimizer  # ["sgd", "adam"]
         self.random_state = random_state or randint(1, 10000001)
-        self.validate_params()
+        self.validate_params(kwargs)
         logger.debug("Random state for torch-crf is %s", self.random_state)
 
         if self.feat_type == "dict":
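
Note on the behavior this patch establishes: the removed `kwargs['params'].update(model_config['params'])` let the loaded config overwrite caller-supplied params (the "torch-crf ignoring params" of the subject line), while the new `{**loaded_config['params'], **kwargs['params']}` merge gives caller kwargs precedence; unexpected params reaching the torch-crf tagger are now warned about and dropped rather than raising a TypeError. The following is a minimal self-contained Python sketch of those two behaviors, not MindMeld code: `merge_params`, `validate_params`, and `EXPECTED_PARAMS` here are hypothetical stand-ins for `_get_model_config` and the tagger's `validate_params(kwargs)`.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Hypothetical whitelist standing in for set_params' named arguments.
EXPECTED_PARAMS = {"feat_type", "feat_num", "optimizer", "random_state"}


def merge_params(loaded_config, **kwargs):
    # Same precedence as the new _get_model_config logic: when both sides
    # carry a 'params' dict, caller-supplied kwargs win on key collisions.
    if "params" in loaded_config and "params" in kwargs:
        kwargs["params"] = {**loaded_config["params"], **kwargs["params"]}
    return kwargs


def validate_params(params):
    # Mirrors the patched validate_params(kwargs): warn about unexpected
    # keys and drop them instead of failing hard.
    for key in set(params) - EXPECTED_PARAMS:
        logger.warning("Unexpected param `%s`, dropping it from model config.", key)
    return {k: v for k, v in params.items() if k in EXPECTED_PARAMS}


loaded = {"params": {"optimizer": "sgd", "feat_num": 50000}}
merged = merge_params(loaded, params={"optimizer": "adam", "bogus_knob": 1})
print(merged["params"])  # {'optimizer': 'adam', 'feat_num': 50000, 'bogus_knob': 1}
print(validate_params(merged["params"]))  # warns about `bogus_knob`, then drops it

Warning-and-dropping unknown keys (rather than raising) appears to be a deliberate choice here, presumably so that configs carrying params meant for other model types still load; that reading is an inference from the diff, not a documented guarantee.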