Skip to content

Commit

Permalink
Update postprocessing package to work with plugins
Browse files Browse the repository at this point in the history
- diagnostic plugins created with the cookiecutter are now correctly
recognized and implemented
  • Loading branch information
Felix Erdmann committed Jan 29, 2025
1 parent f9bf9cf commit 13922fc
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 38 deletions.
6 changes: 3 additions & 3 deletions pysteps/postprocessing/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
The methods in this module implement the following interface::
diagnostics_xxx(optional arguments)
diagnostic_xxx(optional arguments)
where **xxx** is the name of the diagnostic to be applied.
Expand All @@ -19,9 +19,9 @@
"""


def diagnostics_example1(filename, **kwargs):
def diagnostic_example1(filename, **kwargs):
return "Hello, I am an example diagnostics postprocessor."


def diagnostics_example2(filename, **kwargs):
def diagnostic_example2(filename, **kwargs):
return [[42, 42], [42, 42]]
4 changes: 2 additions & 2 deletions pysteps/postprocessing/ensemblestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ def banddepth(X, thr=None, norm=False):
return depth


def ensemblestats_example1(filename, **kwargs):
def ensemblestat_example1(filename, **kwargs):
return "Hello, I am an example of postprocessing ensemble statistics."


def ensemblestats_example2(filename, **kwargs):
def ensemblestat_example2(filename, **kwargs):
return [[42, 42], [42, 42]]
68 changes: 35 additions & 33 deletions pysteps/postprocessing/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ def add_postprocessor(
@param attributes: the existing functions in the selected module
"""

# module_name ends with an "s", the function prefix is without "s"
function_prefix = module[:-1]
# get funtion name without mo
short_name = postprocessors_function_name.replace(f"{function_prefix}_", "")
short_name = postprocessors_function_name.replace(f"{module}_", "")
if short_name not in methods_dict:
methods_dict[short_name] = _postprocessors
else:
Expand Down Expand Up @@ -86,34 +84,36 @@ def discover_postprocessors():

importlib.reload(pkg_resources)

for entry_point in pkg_resources.iter_entry_points(
group="pysteps.plugins.postprocessors", name=None
):
_postprocessors = entry_point.load()

postprocessors_function_name = _postprocessors.__name__


if "diagnostic" in entry_point.module_name:
add_postprocessor(
postprocessors_function_name,
_postprocessors,
_diagnostics_methods,
"diagnostics",
entry_point.attrs,
)
elif "ensemblestat" in entry_point.module_name:
add_postprocessor(
postprocessors_function_name,
_postprocessors,
_ensemblestats_methods,
"ensemblestats",
entry_point.attrs,
)
else:
raise ValueError(
f"Unknown module {entry_point.module_name} in the entrypoint {entry_point.name}"
)
# Discover the postprocessors available in the plugins
for plugintype in ["diagnostic", "ensemblestat"]:
for entry_point in pkg_resources.iter_entry_points(
group=f"pysteps.plugins.{plugintype}", name=None
):
_postprocessors = entry_point.load()

postprocessors_function_name = _postprocessors.__name__


if "diagnostic" in entry_point.module_name:
add_postprocessor(
postprocessors_function_name,
_postprocessors,
_diagnostics_methods,
"diagnostics",
entry_point.attrs,
)
elif "ensemblestat" in entry_point.module_name:
add_postprocessor(
postprocessors_function_name,
_postprocessors,
_ensemblestats_methods,
"ensemblestats",
entry_point.attrs,
)
else:
raise ValueError(
f"Unknown module {entry_point.module_name} in the entrypoint {entry_point.name}"
)


def print_postprocessors_info(module_name, interface_methods, module_methods):
Expand Down Expand Up @@ -164,7 +164,7 @@ def postprocessors_info():

available_postprocessors = set()
postprocessors_in_the_interface = set()
# Discover the postprocessors available in the plugins
# List the plugins that have been added to the postprocessing.[plugintype] module
for plugintype in ["diagnostics", "ensemblestats"]:
interface_methods = (
_diagnostics_methods
Expand All @@ -175,8 +175,10 @@ def postprocessors_info():
available_module_methods = [
attr
for attr in dir(importlib.import_module(module_name))
if attr.startswith("postprocessors")
if attr.startswith(plugintype[:-1])
]
# add the pre-existing ensemblestats functions (see _ensemblestats_methods above)
if "ensemblestats" in plugintype: available_module_methods += ["mean","excprob","banddepth"]
print_postprocessors_info(
module_name, interface_methods, available_module_methods
)
Expand Down

0 comments on commit 13922fc

Please sign in to comment.