From 04ffd97ab870a6bf0a1fff5da189162de2d93d5e Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Fri, 22 Dec 2023 15:56:54 -0800 Subject: [PATCH] Fix type conditional check and `UnboundLocalError` for `params_results` (#770) --- e3sm_diags/run.py | 18 ++++++++++++------ tests/integration/test_all_sets_image_diffs.py | 8 ++++++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/e3sm_diags/run.py b/e3sm_diags/run.py index c92f5ee83..370086df8 100644 --- a/e3sm_diags/run.py +++ b/e3sm_diags/run.py @@ -1,6 +1,6 @@ import copy from itertools import chain -from typing import List +from typing import List, Union import e3sm_diags # noqa: F401 from e3sm_diags.e3sm_diags_driver import get_default_diags_path, main @@ -47,7 +47,7 @@ def is_cfg_file_arg_set(self): def run_diags( self, parameters: List[CoreParameter], use_cfg: bool = True - ) -> List[CoreParameter]: + ) -> Union[List[CoreParameter], None]: """Run a set of diagnostics with a list of parameters. Parameters @@ -68,8 +68,8 @@ def run_diags( Returns ------- - List[CoreParameter] - A list of parameter objects with their results. + Union[List[CoreParameter], None] + A list of parameter objects with their results (if successful). Raises ------ @@ -77,6 +77,7 @@ def run_diags( If a diagnostic run using a parameter fails for any reason. """ params = self.get_run_parameters(parameters, use_cfg) + params_results = None if params is None or len(params) == 0: raise RuntimeError( @@ -89,7 +90,9 @@ def run_diags( except Exception: logger.exception("Error traceback:", exc_info=True) - move_log_to_prov_dir(params_results[0].results_dir) + # param_results might be None because the run(s) failed, so move + # the log using the `params[0].results_dir` instead. + move_log_to_prov_dir(params[0].results_dir) return params_results @@ -449,7 +452,10 @@ def _get_instance_of_param_class(self, cls, parameters): for cls_type in class_types: for p in parameters: - if isinstance(p, cls_type): + # NOTE: This conditional is used instead of + # `isinstance(p, cls_type)` because we want to check for exact + # type matching and exclude sub-class matching. + if type(p) is cls_type: return p msg = "There's weren't any class of types {} in your parameters." diff --git a/tests/integration/test_all_sets_image_diffs.py b/tests/integration/test_all_sets_image_diffs.py index 2373d2ad4..0e15be4e8 100644 --- a/tests/integration/test_all_sets_image_diffs.py +++ b/tests/integration/test_all_sets_image_diffs.py @@ -34,10 +34,14 @@ def run_diags_and_get_results_dir() -> str: params = _get_test_params() results = runner.run_diags(params) - results_dir = results[0].results_dir + # If results is None then that means some/all diagnostic set(s) failed. + # We use params[0].results_dir to check if any diagnostic sets passed. + if results is not None: + results_dir = results[0].results_dir + else: + results_dir = params[0].results_dir logger.info(f"results_dir={results_dir}") - return results_dir