diff --git a/teeplot/__init__.py b/teeplot/__init__.py index 53fa7a2..bba9a13 100644 --- a/teeplot/__init__.py +++ b/teeplot/__init__.py @@ -1,5 +1,5 @@ """Top-level package for teeplot.""" __author__ = """Matthew Andres Moreno""" -__email__ = 'm.more500@gmail.com' -__version__ = '1.2.0' +__email__ = "m.more500@gmail.com" +__version__ = "1.2.0" diff --git a/teeplot/teeplot.py b/teeplot/teeplot.py index a43ec01..82e30f2 100644 --- a/teeplot/teeplot.py +++ b/teeplot/teeplot.py @@ -1,6 +1,7 @@ from collections import abc, Counter from contextlib import contextmanager import copy +import functools import os import pathlib import typing @@ -359,7 +360,7 @@ def save_callback(): @contextmanager -def teed(*args: list, **kwargs: dict): +def teed(*args, **kwargs): """Context manager interface to `teeplot.tee`. Plot save is dispatched upon exiting the context. Return value is the @@ -377,3 +378,34 @@ def teed(*args: list, **kwargs: dict): yield handle finally: saveit() + + +def teewrap( + **teeplot_kwargs: object, +): + """Decorator interface to `teeplot.tee` + + Works by returning a decorator that wraps `f` by calling `teeplot.tee` using + `f` and any passed in arguments and keyword arguments. However, using + `teeplot_outattrs` like in `teeplot.tee` will cause printed attributes to be + the same across function calls. For printing attributes on a per-call basis, + see `teeplot_outinclude` in `teeplot.tee`. + """ + if not all(k.startswith("teeplot_") for k in teeplot_kwargs): + raise ValueError( + "The `teewrap` decorator only accepts teeplot_* keyword arguments" + ) + + def decorator(f: typing.Callable): + @functools.wraps(f) + def inner(*args, **kwargs): + return tee( + f, + *args, + **teeplot_kwargs, + **kwargs, + ) + + return inner + + return decorator diff --git a/tests/test_teewrap.py b/tests/test_teewrap.py new file mode 100644 index 0000000..35094bb --- /dev/null +++ b/tests/test_teewrap.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +''' +`tee` tests for `teeplot` package. +''' + +import functools +import os + +import numpy as np +import pytest +import seaborn as sns + +from teeplot import teeplot as tp + + +@tp.teewrap( + teeplot_outattrs={ + 'additional' : 'teedmetadata', + 'for' : 'output-filename', + '_one-for' : 'exclusion', + }, +) +@functools.wraps(sns.lineplot) +def teed_snslineplot_outattrs(*args, **kwargs): + return sns.lineplot(*args, **kwargs) + +def test(): + + teed_snslineplot_outattrs( + x='timepoint', + y='signal', + hue='region', + style='event', + data=sns.load_dataset('fmri'), + ) + + for ext in '.pdf', '.png': + assert os.path.exists( + os.path.join('teeplots', f'additional=teedmetadata+for=output-filename+hue=region+style=event+viz=lineplot+x=timepoint+y=signal+ext={ext}'), + ) + + +@pytest.mark.parametrize("format", [".png", ".pdf", ".ps", ".eps", ".svg"]) +def test_outformat(format): + + # adapted from https://seaborn.pydata.org/generated/seaborn.lineplot.html + np.random.seed(1) + x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1) + + @tp.teewrap( + teeplot_outattrs={ + 'outformat' : 'teedmetadata', + }, + teeplot_subdir='mydirectory', + teeplot_save={format}, + ) + @functools.wraps(sns.lineplot) + def teed_lineplot_outformat(*args, **kwargs): + return sns.lineplot(*args, **kwargs) + + teed_lineplot_outformat( + x=x, + y=y, + sort=False, + lw=1, + ) + + assert os.path.exists( + os.path.join('teeplots', 'mydirectory', f'outformat=teedmetadata+viz=lineplot+ext={format}'), + ) + + +@tp.teewrap(teeplot_outinclude=['a', 'b']) +@functools.wraps(sns.lineplot) +def teed_snslineplot_extra_args(*args, a, b, **kwargs): + return sns.lineplot(*args, **kwargs) + + +@pytest.mark.parametrize('a', [False, 1, 1]) +@pytest.mark.parametrize('b', ['asdf', '']) +def test_included_outattrs(a, b): + + teed_snslineplot_extra_args( + a=a, + b=b, + x='timepoint', + y='signal', + hue='region', + data=sns.load_dataset('fmri'), + ) + + for ext in '.pdf', '.png': + assert os.path.exists( + os.path.join('teeplots', f'a={a}+b={b}+hue=region+viz=lineplot+x=timepoint+y=signal+ext={ext}'), + )