Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a decorator for teeplot called teewrap #9

Merged
merged 15 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions teeplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Top-level package for teeplot."""

__author__ = """Matthew Andres Moreno"""
__email__ = 'm.more500@gmail.com'
__version__ = '1.2.0'
__email__ = "m.more500@gmail.com"
__version__ = "1.2.0"
116 changes: 86 additions & 30 deletions teeplot/teeplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections import abc, Counter
from contextlib import contextmanager
import copy
import functools
import inspect
import os
import pathlib
import typing
Expand All @@ -16,16 +18,15 @@


def _is_running_on_ci() -> bool:
ci_envs = ['CI', 'TRAVIS', 'GITHUB_ACTIONS', 'GITLAB_CI', 'JENKINS_URL']
ci_envs = ["CI", "TRAVIS", "GITHUB_ACTIONS", "GITLAB_CI", "JENKINS_URL"]
return any(env in os.environ for env in ci_envs)


draftmode: bool = False

oncollision: typext.Literal[
"error", "fix", "ignore", "warn"
] = os.environ.get(
oncollision: typext.Literal["error", "fix", "ignore", "warn"] = os.environ.get(
"TEEPLOT_ONCOLLISION",
"warn" if (_is_running_on_ci() or not hasattr(sys, 'ps1')) else "ignore",
"warn" if (_is_running_on_ci() or not hasattr(sys, "ps1")) else "ignore",
).lower()
if not oncollision in ("error", "fix", "ignore", "warn"):
raise RuntimeError(
Expand All @@ -51,8 +52,8 @@ def _is_running_on_ci() -> bool:
# see https://gecco-2021.sigevo.org/Paper-Submission-Instructions
@matplotlib.rc_context(
{
'pdf.fonttype': 42,
'ps.fonttype': 42,
"pdf.fonttype": 42,
"ps.fonttype": 42,
},
)
def tee(
Expand All @@ -61,18 +62,19 @@ def tee(
teeplot_callback: bool = False,
teeplot_dpi: int = 300,
teeplot_oncollision: typing.Optional[
typext.Literal["error", "fix", "ignore", "warn"]] = None,
typext.Literal["error", "fix", "ignore", "warn"]
] = None,
teeplot_outattrs: typing.Dict[str, str] = {},
teeplot_outdir: str = "teeplots",
teeplot_outinclude: typing.Iterable[str] = tuple(),
teeplot_outexclude: typing.Iterable[str] = tuple(),
teeplot_postprocess: typing.Union[str, typing.Callable] = "",
teeplot_save: typing.Union[typing.Iterable[str], bool] = True,
teeplot_show: typing.Optional[bool] = None,
teeplot_subdir: str = '',
teeplot_subdir: str = "",
teeplot_transparent: bool = True,
teeplot_verbose: bool = True,
**kwargs: typing.Any
**kwargs: typing.Any,
) -> typing.Any:
"""Executes a plotting function and saves the resulting plot to specified
formats using a descriptive filename automatically generated from plotting
Expand Down Expand Up @@ -183,12 +185,11 @@ def tee(
elif isinstance(teeplot_save, str):
if not teeplot_save in formats:
raise ValueError(
f"only {[*formats]} save formats are supported, "
f"not {teeplot_save}",
f"only {[*formats]} save formats are supported, " f"not {teeplot_save}",
)
# remove explicitly disabled outputs
blacklist = set(k for k, v in formats.items() if v is False)
exclusions = {teeplot_save} & blacklist
exclusions = {teeplot_save} & blacklist
if teeplot_verbose and exclusions:
print(f"skipping {exclusions}")
teeplot_save = {teeplot_save} - exclusions
Expand All @@ -201,7 +202,7 @@ def tee(
)
# remove explicitly disabled outputs
blacklist = set(k for k, v in formats.items() if v is False)
exclusions = set(teeplot_save) & blacklist
exclusions = set(teeplot_save) & blacklist
if teeplot_verbose and exclusions:
print(f"skipping {exclusions}")
teeplot_save = set(teeplot_save) - exclusions
Expand Down Expand Up @@ -266,29 +267,33 @@ def tee(
incl = [*teeplot_outinclude]
attr_maker = lambda ext: {
**{
slugify(k) : slugify(str(v))
slugify(k): slugify(str(v))
for k, v in kwargs.items()
if isinstance(v, str) or k in incl
if isinstance(v, (str, int, float)) or k in incl
},
**{
'viz' : slugify(plotter.__name__),
'ext' : ext,
"viz": slugify(plotter.__name__),
"ext": ext,
},
**(
{"post": teeplot_postprocess.__name__}
if teeplot_postprocess and isinstance(teeplot_postprocess, abc.Callable)
else {"post": slugify(teeplot_postprocess)}
if teeplot_postprocess and not teeplot_postprocess.endswith(";")
else {}
else (
{"post": slugify(teeplot_postprocess)}
if teeplot_postprocess and not teeplot_postprocess.endswith(";")
else {}
)
),
**teeplot_outattrs,
}
excl = [*teeplot_outexclude]
out_filenamer = lambda ext: kn.pack({
k : v
for k, v in attr_maker(ext).items()
if not k.startswith('_') and not k in excl
})
out_filenamer = lambda ext: kn.pack(
{
k: v
for k, v in attr_maker(ext).items()
if not k.startswith("_") and not k in excl
}
)

out_folder = pathlib.Path(teeplot_outdir, teeplot_subdir)
out_folder.mkdir(parents=True, exist_ok=True)
Expand All @@ -315,7 +320,7 @@ def save_callback():
count = _history[out_path]
suffix = f"ext={ext}"
assert str(out_path).endswith(suffix)
out_path = str(out_path)[:-len(suffix)] + f"#={count}+" + suffix
out_path = str(out_path)[: -len(suffix)] + f"#={count}+" + suffix
elif teeplot_oncollision == "ignore":
pass
elif teeplot_oncollision == "warn":
Expand All @@ -333,7 +338,7 @@ def save_callback():
print(out_path)
plt.savefig(
str(out_path),
bbox_inches='tight',
bbox_inches="tight",
transparent=teeplot_transparent,
dpi=teeplot_dpi,
# see https://matplotlib.org/2.1.1/users/whats_new.html#reproducible-ps-pdf-and-svg-output
Expand All @@ -347,7 +352,7 @@ def save_callback():
},
)

if teeplot_show or (teeplot_show is None and hasattr(sys, 'ps1')):
if teeplot_show or (teeplot_show is None and hasattr(sys, "ps1")):
plt.show()

return teed
Expand All @@ -359,7 +364,7 @@ def save_callback():


@contextmanager
def teed(*args: list, **kwargs: dict):
def teed(*args, **kwargs):
"""Context manager interface to `teeplot.tee`.

Plot save is dispatched upon exiting the context. Return value is the
Expand All @@ -377,3 +382,54 @@ def teed(*args: list, **kwargs: dict):
yield handle
finally:
saveit()


def validate_teewrap_kwargs(teeplot_kwargs):
params = {k for k in inspect.signature(tee).parameters if k.startswith("teeplot")}
if not all(k in params for k in teeplot_kwargs):
raise ValueError(
"The only keyword arguments passed into the `teewrap` decorator can be teeplot arguments"
)
if "teeplot_outattrs" in params:
raise ValueError(
"`teeplot_outattrs` cannot be used with `teewrap`. Use `teeplot_outattr_names` instead."
)


def teewrap(
**teeplot_kwargs: object,
):
"""Decorator interface to `teeplot.tee`

Works by returning a decorator that wraps `f` by calling `teeplot.tee` using
`f` and any passed in arguments and keyword arguments. However `teeplot_outattrs`
is not allowed with this function, as it would not make sense to have hardcoded
attributes as a decorator. Instead, see `teeplot_outinclude` in `teeplot.tee`.
`teeplot.teewrap` defaults to including all keyword arguments.
"""
validate_teewrap_kwargs(teeplot_kwargs)

def decorator(f: typing.Callable):
@functools.wraps(f)
def inner(*args, **kwargs):

teeplot_outattr_names = teeplot_kwargs.get("teeplot_outinclude")
if teeplot_outattr_names is None:
return tee(
f,
*args,
**teeplot_kwargs,
teeplot_outattrs=kwargs,
**kwargs,
)

return tee(
f,
*args,
**teeplot_kwargs,
**kwargs,
)

return inner

return decorator
Loading