diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ce285d..9dfeeaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ## Unreleased ### Changed * Python >=3.9 required +* `pytest.param` supported ## [1.5](https://pypi.org/project/pytest-parametrized/1.5/) - 2023-11-03 ### Changed diff --git a/README.md b/README.md index b3b6709..41499e7 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,8 @@ def test_foo(x=[0, 1], y=[2, 3]): pass ``` +`pytest.param` is supported for single values or `.product`. + ## fixtures [Parametrized fixtures](https://docs.pytest.org/en/latest/how-to/fixtures.html#fixture-parametrize) which simply return their param. diff --git a/parametrized.py b/parametrized.py index 989854f..af29e97 100644 --- a/parametrized.py +++ b/parametrized.py @@ -11,11 +11,13 @@ def parametrized(func, combine=None, **kwargs): argspec = inspect.getfullargspec(func) params = dict(zip(reversed(argspec.args), reversed(argspec.defaults))) func.__defaults__ = () # pytest ignores params with defaults - if combine is None: - (args,) = params.items() # multiple keywords require combine function, e.g., zip - else: - args = ','.join(params), combine(*params.values()) - return pytest.mark.parametrize(*args, **kwargs)(func) + if combine is None and len(params) > 1: + raise ValueError("multiple keywords require combine function, e.g., zip") + if combine not in (None, itertools.product): + params = {','.join(params): combine(*params.values())} + for param in params.items(): + func = pytest.mark.parametrize(*param, **kwargs)(func) + return func def fixture(*params, **kwargs): diff --git a/pyproject.toml b/pyproject.toml index db6f3f0..c4bae37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ Issues = "https://github.com/coady/pytest-parametrized/issues" line-length = 100 [tool.ruff.format] -quote-style = "single" +quote-style = "preserve" [tool.coverage.run] source = ["parametrized"] diff --git a/tests/test_all.py b/tests/test_all.py index 4753127..c57ef0a 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -35,3 +35,8 @@ def test_error(name='abc', value=range(3)): @parametrized def test(name=(), value=()): pass + + +@parametrized.product +def test_param(key=[0], value=[0, pytest.param(1, marks=pytest.mark.xfail())]): + assert key == value