From 13922fc47c1171a5908eff36f0cb2a4d286f9794 Mon Sep 17 00:00:00 2001 From: Felix Erdmann Date: Wed, 29 Jan 2025 18:30:57 +0100 Subject: [PATCH] Update postprocessing package to work with plugins - diagnostic plugins created with the cookiecutter are now correctly recognized and implemented --- pysteps/postprocessing/diagnostics.py | 6 +-- pysteps/postprocessing/ensemblestats.py | 4 +- pysteps/postprocessing/interface.py | 68 +++++++++++++------------ 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/pysteps/postprocessing/diagnostics.py b/pysteps/postprocessing/diagnostics.py index 2093f43d..ac7802e1 100644 --- a/pysteps/postprocessing/diagnostics.py +++ b/pysteps/postprocessing/diagnostics.py @@ -6,7 +6,7 @@ The methods in this module implement the following interface:: - diagnostics_xxx(optional arguments) + diagnostic_xxx(optional arguments) where **xxx** is the name of the diagnostic to be applied. @@ -19,9 +19,9 @@ """ -def diagnostics_example1(filename, **kwargs): +def diagnostic_example1(filename, **kwargs): return "Hello, I am an example diagnostics postprocessor." -def diagnostics_example2(filename, **kwargs): +def diagnostic_example2(filename, **kwargs): return [[42, 42], [42, 42]] diff --git a/pysteps/postprocessing/ensemblestats.py b/pysteps/postprocessing/ensemblestats.py index 2093bf22..8249ac66 100644 --- a/pysteps/postprocessing/ensemblestats.py +++ b/pysteps/postprocessing/ensemblestats.py @@ -179,9 +179,9 @@ def banddepth(X, thr=None, norm=False): return depth -def ensemblestats_example1(filename, **kwargs): +def ensemblestat_example1(filename, **kwargs): return "Hello, I am an example of postprocessing ensemble statistics." -def ensemblestats_example2(filename, **kwargs): +def ensemblestat_example2(filename, **kwargs): return [[42, 42], [42, 42]] diff --git a/pysteps/postprocessing/interface.py b/pysteps/postprocessing/interface.py index 0a51182f..6be8139d 100644 --- a/pysteps/postprocessing/interface.py +++ b/pysteps/postprocessing/interface.py @@ -48,10 +48,8 @@ def add_postprocessor( @param attributes: the existing functions in the selected module """ - # module_name ends with an "s", the function prefix is without "s" - function_prefix = module[:-1] # get funtion name without mo - short_name = postprocessors_function_name.replace(f"{function_prefix}_", "") + short_name = postprocessors_function_name.replace(f"{module}_", "") if short_name not in methods_dict: methods_dict[short_name] = _postprocessors else: @@ -86,34 +84,36 @@ def discover_postprocessors(): importlib.reload(pkg_resources) - for entry_point in pkg_resources.iter_entry_points( - group="pysteps.plugins.postprocessors", name=None - ): - _postprocessors = entry_point.load() - - postprocessors_function_name = _postprocessors.__name__ - - - if "diagnostic" in entry_point.module_name: - add_postprocessor( - postprocessors_function_name, - _postprocessors, - _diagnostics_methods, - "diagnostics", - entry_point.attrs, - ) - elif "ensemblestat" in entry_point.module_name: - add_postprocessor( - postprocessors_function_name, - _postprocessors, - _ensemblestats_methods, - "ensemblestats", - entry_point.attrs, - ) - else: - raise ValueError( - f"Unknown module {entry_point.module_name} in the entrypoint {entry_point.name}" - ) + # Discover the postprocessors available in the plugins + for plugintype in ["diagnostic", "ensemblestat"]: + for entry_point in pkg_resources.iter_entry_points( + group=f"pysteps.plugins.{plugintype}", name=None + ): + _postprocessors = entry_point.load() + + postprocessors_function_name = _postprocessors.__name__ + + + if "diagnostic" in entry_point.module_name: + add_postprocessor( + postprocessors_function_name, + _postprocessors, + _diagnostics_methods, + "diagnostics", + entry_point.attrs, + ) + elif "ensemblestat" in entry_point.module_name: + add_postprocessor( + postprocessors_function_name, + _postprocessors, + _ensemblestats_methods, + "ensemblestats", + entry_point.attrs, + ) + else: + raise ValueError( + f"Unknown module {entry_point.module_name} in the entrypoint {entry_point.name}" + ) def print_postprocessors_info(module_name, interface_methods, module_methods): @@ -164,7 +164,7 @@ def postprocessors_info(): available_postprocessors = set() postprocessors_in_the_interface = set() - # Discover the postprocessors available in the plugins + # List the plugins that have been added to the postprocessing.[plugintype] module for plugintype in ["diagnostics", "ensemblestats"]: interface_methods = ( _diagnostics_methods @@ -175,8 +175,10 @@ def postprocessors_info(): available_module_methods = [ attr for attr in dir(importlib.import_module(module_name)) - if attr.startswith("postprocessors") + if attr.startswith(plugintype[:-1]) ] + # add the pre-existing ensemblestats functions (see _ensemblestats_methods above) + if "ensemblestats" in plugintype: available_module_methods += ["mean","excprob","banddepth"] print_postprocessors_info( module_name, interface_methods, available_module_methods )