diff --git a/README.md b/README.md index c33d083..415bcc1 100644 --- a/README.md +++ b/README.md @@ -183,8 +183,9 @@ Defaults are set in both [`config/defaults.yaml`](config/defaults.yaml) and [`sr > ``` > > ```shell -> python main.py method=fedprox # fedprox.mu = 1 -> python main.py --config-name your_config method=fedprox # fedprox.mu = 0.01 +> python main.py method=fedprox # fedprox.mu = 1 +> python main.py --config-name your_config method=fedprox # fedprox.mu = 0.01 +> python main.py --config-name your_config method=fedprox fedprox.mu=0.001 # fedprox.mu = 0.001 > ``` ### Monitor 📈 diff --git a/src/utils/tools.py b/src/utils/tools.py index 7d99c0c..1b73d76 100644 --- a/src/utils/tools.py +++ b/src/utils/tools.py @@ -150,8 +150,17 @@ def _merge_configs(defaults: DictConfig, config: DictConfig) -> DictConfig: final_args = _merge_configs(final_args, config) + if hasattr(config, method_name): + final_args[method_name] = config[method_name] + if get_method_args_func is not None: - final_args[method_name] = DictConfig(get_method_args_func([]).__dict__) + default_method_args = DictConfig(get_method_args_func([]).__dict__) + if hasattr(final_args, method_name): + for key in default_method_args.keys(): + if key not in final_args[method_name].keys(): + final_args[method_name][key] = default_method_args[key] + else: + final_args[method_name] = default_method_args assert final_args.mode in [ "serial",