Skip to content

Commit

Permalink
Merge pull request #9 from mmore500/teewrap
Browse files Browse the repository at this point in the history
Adds a decorator for `teeplot` called `teewrap`
  • Loading branch information
vivaansinghvi07 authored Feb 15, 2025
2 parents 0dc830d + 5eec37c commit 4a2963b
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 3 deletions.
4 changes: 2 additions & 2 deletions teeplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Top-level package for teeplot."""

__author__ = """Matthew Andres Moreno"""
__email__ = 'm.more500@gmail.com'
__version__ = '1.2.0'
__email__ = "m.more500@gmail.com"
__version__ = "1.2.0"
34 changes: 33 additions & 1 deletion teeplot/teeplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import abc, Counter
from contextlib import contextmanager
import copy
import functools
import os
import pathlib
import typing
Expand Down Expand Up @@ -359,7 +360,7 @@ def save_callback():


@contextmanager
def teed(*args: list, **kwargs: dict):
def teed(*args, **kwargs):
"""Context manager interface to `teeplot.tee`.
Plot save is dispatched upon exiting the context. Return value is the
Expand All @@ -377,3 +378,34 @@ def teed(*args: list, **kwargs: dict):
yield handle
finally:
saveit()


def teewrap(
**teeplot_kwargs: object,
):
"""Decorator interface to `teeplot.tee`
Works by returning a decorator that wraps `f` by calling `teeplot.tee` using
`f` and any passed in arguments and keyword arguments. However, using
`teeplot_outattrs` like in `teeplot.tee` will cause printed attributes to be
the same across function calls. For printing attributes on a per-call basis,
see `teeplot_outinclude` in `teeplot.tee`.
"""
if not all(k.startswith("teeplot_") for k in teeplot_kwargs):
raise ValueError(
"The `teewrap` decorator only accepts teeplot_* keyword arguments"
)

def decorator(f: typing.Callable):
@functools.wraps(f)
def inner(*args, **kwargs):
return tee(
f,
*args,
**teeplot_kwargs,
**kwargs,
)

return inner

return decorator
96 changes: 96 additions & 0 deletions tests/test_teewrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python

'''
`tee` tests for `teeplot` package.
'''

import functools
import os

import numpy as np
import pytest
import seaborn as sns

from teeplot import teeplot as tp


@tp.teewrap(
teeplot_outattrs={
'additional' : 'teedmetadata',
'for' : 'output-filename',
'_one-for' : 'exclusion',
},
)
@functools.wraps(sns.lineplot)
def teed_snslineplot_outattrs(*args, **kwargs):
return sns.lineplot(*args, **kwargs)

def test():

teed_snslineplot_outattrs(
x='timepoint',
y='signal',
hue='region',
style='event',
data=sns.load_dataset('fmri'),
)

for ext in '.pdf', '.png':
assert os.path.exists(
os.path.join('teeplots', f'additional=teedmetadata+for=output-filename+hue=region+style=event+viz=lineplot+x=timepoint+y=signal+ext={ext}'),
)


@pytest.mark.parametrize("format", [".png", ".pdf", ".ps", ".eps", ".svg"])
def test_outformat(format):

# adapted from https://seaborn.pydata.org/generated/seaborn.lineplot.html
np.random.seed(1)
x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1)

@tp.teewrap(
teeplot_outattrs={
'outformat' : 'teedmetadata',
},
teeplot_subdir='mydirectory',
teeplot_save={format},
)
@functools.wraps(sns.lineplot)
def teed_lineplot_outformat(*args, **kwargs):
return sns.lineplot(*args, **kwargs)

teed_lineplot_outformat(
x=x,
y=y,
sort=False,
lw=1,
)

assert os.path.exists(
os.path.join('teeplots', 'mydirectory', f'outformat=teedmetadata+viz=lineplot+ext={format}'),
)


@tp.teewrap(teeplot_outinclude=['a', 'b'])
@functools.wraps(sns.lineplot)
def teed_snslineplot_extra_args(*args, a, b, **kwargs):
return sns.lineplot(*args, **kwargs)


@pytest.mark.parametrize('a', [False, 1, 1])
@pytest.mark.parametrize('b', ['asdf', ''])
def test_included_outattrs(a, b):

teed_snslineplot_extra_args(
a=a,
b=b,
x='timepoint',
y='signal',
hue='region',
data=sns.load_dataset('fmri'),
)

for ext in '.pdf', '.png':
assert os.path.exists(
os.path.join('teeplots', f'a={a}+b={b}+hue=region+viz=lineplot+x=timepoint+y=signal+ext={ext}'),
)

0 comments on commit 4a2963b

Please sign in to comment.