Skip to content

Commit

Permalink
Merge pull request #427 from cisco/vidamoda/change_param_behaviour
Browse files Browse the repository at this point in the history
Change behavior of param and param selection merging
  • Loading branch information
vrdn-23 authored Jul 22, 2022
2 parents 238412f + b85faf9 commit 562dc8d
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 16 deletions.
31 changes: 26 additions & 5 deletions mindmeld/components/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
DEFAULT_DOMAIN_CLASSIFIER_CONFIG = {
"model_type": "text",
"model_settings": {"classifier_type": "logreg"},
"params": {
"solver": "liblinear",
},
"param_selection": {
"type": "k-fold",
"k": 10,
Expand All @@ -55,6 +58,9 @@
DEFAULT_INTENT_CLASSIFIER_CONFIG = {
"model_type": "text",
"model_settings": {"classifier_type": "logreg"},
"params": {
"solver": "liblinear",
},
"param_selection": {
"type": "k-fold",
"k": 10,
Expand All @@ -80,6 +86,9 @@
"tag_scheme": "IOB",
"feature_scaler": "max-abs",
},
"params": {
"solver": "liblinear",
},
"param_selection": {
"type": "k-fold",
"k": 5,
Expand Down Expand Up @@ -109,7 +118,7 @@
DEFAULT_ROLE_CLASSIFIER_CONFIG = {
"model_type": "text",
"model_settings": {"classifier_type": "logreg"},
"params": {"C": 100, "penalty": "l1"},
"params": {"C": 100, "penalty": "l1", "solver": "liblinear"},
"features": {
"bag-of-words-before": {
"ngram_lengths_to_start_positions": {1: [-2, -1], 2: [-2, -1]}
Expand Down Expand Up @@ -549,6 +558,17 @@ class NlpConfigError(Exception):
pass


def merge_param_configs(default_dict, user_defined_dict):
    """Merge the default ``params`` of a classifier config into a user config.

    User-defined params take precedence over defaults on a per-key basis.
    All other top-level keys come from ``user_defined_dict`` untouched.

    Args:
        default_dict (dict): The default classifier config (may contain a
            "params" dict, e.g. the module-level DEFAULT_*_CLASSIFIER_CONFIG).
        user_defined_dict (dict): The user/application-supplied config.

    Returns:
        dict: A new config dict whose "params" is the key-wise merge of the
        default and user params. Neither input dict is mutated, and the
        returned "params" never aliases ``default_dict["params"]`` (the
        original aliased the shared default dict, so later mutation of the
        returned config could silently corrupt the module-level defaults).
    """
    new_dict = dict(user_defined_dict)
    if "params" not in default_dict:
        # Nothing to merge; return the (shallow-copied) user config as-is.
        return new_dict
    if "params" in user_defined_dict:
        # User keys win over default keys; ** unpacking builds a fresh dict.
        new_dict["params"] = {**default_dict["params"], **user_defined_dict["params"]}
    else:
        # Copy rather than alias, so callers can't mutate the shared defaults.
        new_dict["params"] = {**default_dict["params"]}
    return new_dict


def get_custom_action_config(app_path):
if not app_path:
return None
Expand Down Expand Up @@ -669,8 +689,8 @@ def get_system_entity_url_config(app_path):

return (
get_nlp_config(app_path)
.get("system_entity_recognizer", {})
.get("url", DEFAULT_DUCKLING_URL)
.get("system_entity_recognizer", {})
.get("url", DEFAULT_DUCKLING_URL)
)


Expand Down Expand Up @@ -743,7 +763,7 @@ def get_classifier_config(
try:
raw_args = {"domain": domain, "intent": intent, "entity": entity}
args = {k: raw_args[k] for k in func_args}
return copy.deepcopy(func(**args))
return merge_param_configs(_get_default_classifier_config(clf_type), copy.deepcopy(func(**args)))
except Exception as exc: # pylint: disable=broad-except
# Note: this is intentionally broad -- provider could raise any exception
logger.warning(
Expand All @@ -759,7 +779,8 @@ def get_classifier_config(
"question_answering": "QUESTION_ANSWERER_CONFIG",
}[clf_type]
try:
return copy.deepcopy(getattr(module_conf, attr_name))
return merge_param_configs(_get_default_classifier_config(clf_type),
copy.deepcopy(getattr(module_conf, attr_name)))
except AttributeError:
try:
result = copy.deepcopy(
Expand Down
7 changes: 3 additions & 4 deletions mindmeld/components/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,11 @@ def _get_model_config(loaded_config=None, **kwargs):
if not loaded_config:
logger.warning("loaded_config is not passed in")
model_config = loaded_config or {}
if 'params' in model_config and 'params' in kwargs:
kwargs['params'].update(model_config['params'])

model_config.update(kwargs)

# If a parameter selection grid was passed in at runtime, override params set in the
# application specified or default config
if kwargs.get("param_selection") and not kwargs.get("params"):
model_config.pop("params", None)
return ModelConfig(**model_config)

def dump(self, model_path, incremental_model_path=None):
Expand Down
12 changes: 11 additions & 1 deletion mindmeld/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def _fit(self, examples, labels, params=None):
def _get_model_constructor(self):
raise NotImplementedError

def _fit_cv(self, examples, labels, groups=None, selection_settings=None):
def _fit_cv(self, examples, labels, groups=None, selection_settings=None, fixed_params=None):
"""Called by the fit method when cross validation parameters are passed in. Runs cross
validation and returns the best estimator and parameters.
Expand Down Expand Up @@ -463,6 +463,16 @@ def _fit_cv(self, examples, labels, groups=None, selection_settings=None):

param_grid = self._convert_params(selection_settings["grid"], labels)
model_class = self._get_model_constructor()
if fixed_params:
for key, val in fixed_params.items():
if key not in param_grid:
param_grid[key] = [val]
else:
logger.info(
"Found parameter %s both in params and param_selection. Proceeding with param_selection.. \
(If you did not set this, it could be a Mindmeld default.)",
key
)
estimator, param_grid = self._get_cv_estimator_and_params(
model_class, param_grid
)
Expand Down
4 changes: 2 additions & 2 deletions mindmeld/models/tagger_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def fit(self, examples, labels, params=None):
labels (ProcessedQueryList.EntitiesIterator): A list of expected labels.
params (dict): Parameters of the classifier.
"""
skip_param_selection = params is not None or self.config.param_selection is None
skip_param_selection = self.config.param_selection is None
params = params or self.config.params

# Shuffle to prevent order effects
Expand Down Expand Up @@ -253,7 +253,7 @@ def fit(self, examples, labels, params=None):
if isinstance(self._clf, non_supported_classes):
raise MindMeldError(f"The {type(self._clf).__name__} model does not support cross-validation")

_, best_params = self._fit_cv(X, y, groups)
_, best_params = self._fit_cv(X, y, groups, fixed_params=params)
self._clf = self._fit(X, y, best_params)
self._current_params = best_params

Expand Down
4 changes: 2 additions & 2 deletions mindmeld/models/text_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def fit(self, examples, labels, params=None):
interfaces.
"""
params = params or self.config.params
skip_param_selection = params is not None or self.config.param_selection is None
skip_param_selection = self.config.param_selection is None

# Shuffle to prevent order effects
indices = list(range(len(labels)))
Expand All @@ -454,7 +454,7 @@ def fit(self, examples, labels, params=None):
self._current_params = params
else:
# run cross validation to select params
best_clf, best_params = self._fit_cv(X, y, groups)
best_clf, best_params = self._fit_cv(X, y, groups, fixed_params=params)
self._clf = best_clf
self._current_params = best_params

Expand Down
4 changes: 2 additions & 2 deletions tests/components/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_get_classifier_config_func():
"entity", APP_PATH, domain="domain", intent="intent"
)["params"]

expected = {"penalty": "l2", "C": 100}
expected = {"penalty": "l2", "C": 100, "solver": "liblinear"}

assert actual == expected

Expand All @@ -139,7 +139,7 @@ def test_get_classifier_config_func_error():
"params"
]

expected = {"error": "intent", "penalty": "l2", "C": 100}
expected = {"error": "intent", "penalty": "l2", "C": 100, "solver": "liblinear"}

assert actual == expected

Expand Down
1 change: 1 addition & 0 deletions tests/components/test_intent_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,4 @@ def test_intent_classifier_random_forest(kwik_e_mart_app_path, caplog):
mock.assert_any_call(
"Unexpected param `fit_intercept`, dropping it from model config."
)
mock.assert_any_call("Unexpected param `solver`, dropping it from model config.")

0 comments on commit 562dc8d

Please sign in to comment.