From f1131a97b628ca0ba131e5694a5db9c301476223 Mon Sep 17 00:00:00 2001 From: Apostol Fet Date: Tue, 3 Dec 2024 22:31:27 +0300 Subject: [PATCH 1/6] Bug with kwargs reproduction in tests --- tests/integrations/taskiq/test_taskiq.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/integrations/taskiq/test_taskiq.py b/tests/integrations/taskiq/test_taskiq.py index 8c70794e..d8e6e67c 100644 --- a/tests/integrations/taskiq/test_taskiq.py +++ b/tests/integrations/taskiq/test_taskiq.py @@ -17,6 +17,14 @@ async def return_int_task(data: FromDishka[int]) -> int: return data +@inject +async def task_with_kwargs( + _: FromDishka[int], + **kwargs: str, +) -> dict[str, str]: + return kwargs + + @asynccontextmanager async def create_broker() -> AsyncIterator[AsyncBroker]: in_memory_broker = InMemoryBroker().with_result_backend( @@ -41,3 +49,15 @@ async def test_return_int_task() -> None: kiq = await task.kiq() result = await kiq.wait_result() assert result.return_value == hash("dishka") + + +@pytest.mark.asyncio +async def test_task_with_kwargs() -> None: + async with create_broker() as broker: + task = broker.task(task_with_kwargs) + kwargs = {"key": "value"} + + kiq = await task.kiq(**kwargs) + result = await kiq.wait_result() + + assert result.return_value == kwargs From 23d803d4f8f55797fed7c984d1ad9d8b5b280d5a Mon Sep 17 00:00:00 2001 From: Apostol Fet Date: Tue, 3 Dec 2024 22:46:09 +0300 Subject: [PATCH 2/6] Add parameters in the right order --- src/dishka/integrations/base.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 21ec49b1..92728bac 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -2,6 +2,7 @@ from inspect import ( Parameter, Signature, + _ParameterKind, isasyncgenfunction, isgeneratorfunction, signature, @@ -119,7 +120,7 @@ def wrap_injection( auto_injected_func: Callable[P, T | Awaitable[T]] if additional_params: - new_params.extend(additional_params) + new_params = _add_params(new_params, additional_params) for param in additional_params: new_annotations[param.name] = param.annotation @@ -222,3 +223,27 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: return func(*args, **kwargs, **solved) return auto_injected_func + + +def _add_params( + params: list[Parameter], + additional_params: Sequence[Parameter], +): + params_kind_dict: dict[_ParameterKind, list[Parameter]] = {} + + for param in params: + params_kind_dict.setdefault(param.kind, []).append(param) + + for param in additional_params: + params_kind_dict.setdefault(param.kind, []).append(param) + + result_params = [] + result_params.extend(params_kind_dict.get(Parameter.POSITIONAL_ONLY, [])) + result_params.extend( + params_kind_dict.get(Parameter.POSITIONAL_OR_KEYWORD, []), + ) + result_params.extend(params_kind_dict.get(Parameter.VAR_POSITIONAL, [])) + result_params.extend(params_kind_dict.get(Parameter.KEYWORD_ONLY, [])) + result_params.extend(params_kind_dict.get(Parameter.VAR_KEYWORD, [])) + + return result_params From 3ff963df8edbffc770870a42bf554ee37debe36b Mon Sep 17 00:00:00 2001 From: Apostol Fet Date: Wed, 4 Dec 2024 23:43:49 +0300 Subject: [PATCH 3/6] Added tests for add params with expected errors --- tests/integrations/base/__init__.py | 0 tests/integrations/base/test_add_params.py | 78 ++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 tests/integrations/base/__init__.py create mode 100644 tests/integrations/base/test_add_params.py diff --git a/tests/integrations/base/__init__.py b/tests/integrations/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integrations/base/test_add_params.py b/tests/integrations/base/test_add_params.py new file mode 100644 index 00000000..87bcfb43 --- /dev/null +++ b/tests/integrations/base/test_add_params.py @@ -0,0 +1,78 @@ +from inspect import Parameter, Signature, signature + +import pytest + +from dishka.integrations.base import _add_params + + +def func( + pos_only, + /, + pos_keyword, + *, + keyword_only, +) -> None: ... + + +def func_expected( + pos_only, + add_pos_only, + /, + pos_keyword, + add_pos_keyword, + *add_var_pos, + keyword_only, + add_keyword_only, + **add_var_keyword, +) -> None: ... + + +def func_with_args_kwargs(*args, **kwargs): ... + + +def test_add_all_params(): + additional_params = [ + Parameter("add_pos_only", Parameter.POSITIONAL_ONLY), + Parameter("add_pos_keyword", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("add_var_pos", Parameter.VAR_POSITIONAL), + Parameter("add_keyword_only", Parameter.KEYWORD_ONLY), + Parameter("add_var_keyword", Parameter.VAR_KEYWORD), + ] + func_signature = signature(func) + func_params = list(func_signature.parameters.values()) + + result_params = _add_params(func_params, additional_params) + new_signature = Signature( + parameters=result_params, + return_annotation=func_signature.return_annotation, + ) + + assert new_signature == signature(func_expected) + + +def test_fail_add_second_args(): + additional_params = [ + Parameter("add_var_pos", Parameter.VAR_POSITIONAL), + ] + + func_signature = signature(func_with_args_kwargs) + func_params = list(func_signature.parameters.values()) + + with pytest.raises( + ValueError, match="more than one var positional parameter", + ): + _add_params(func_params, additional_params) + + +def test_fail_add_second_kwargs(): + additional_params = [ + Parameter("add_var_keyword", Parameter.VAR_KEYWORD), + ] + + func_signature = signature(func_with_args_kwargs) + func_params = list(func_signature.parameters.values()) + + with pytest.raises( + ValueError, match="more than one var keyword parameter", + ): + _add_params(func_params, additional_params) From e4e0c1557d86559398743cca30aab33920133d27 Mon Sep 17 00:00:00 2001 From: Apostol Fet Date: Wed, 4 Dec 2024 23:46:46 +0300 Subject: [PATCH 4/6] Added handling of cases with double *args and **kwargs addition --- src/dishka/integrations/base.py | 36 +++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 92728bac..751a40d0 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -226,7 +226,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _add_params( - params: list[Parameter], + params: Sequence[Parameter], additional_params: Sequence[Parameter], ): params_kind_dict: dict[_ParameterKind, list[Parameter]] = {} @@ -237,13 +237,33 @@ def _add_params( for param in additional_params: params_kind_dict.setdefault(param.kind, []).append(param) - result_params = [] - result_params.extend(params_kind_dict.get(Parameter.POSITIONAL_ONLY, [])) - result_params.extend( - params_kind_dict.get(Parameter.POSITIONAL_OR_KEYWORD, []), + + var_positional = params_kind_dict.get(Parameter.VAR_POSITIONAL, []) + if len(var_positional) > 1: + param_names = (param.name for param in var_positional) + var_positional_names = ", *".join(param_names) + base_msg = "more than one var positional parameter: *" + msg = base_msg + var_positional_names + raise ValueError(msg) + + var_keyword = params_kind_dict.get(Parameter.VAR_KEYWORD, []) + if len(var_keyword) > 1: + var_keyword_names = ", **".join(param.name for param in var_keyword) + msg = "more than one var keyword parameter: " + var_keyword_names + raise ValueError(msg) + + positional_only = params_kind_dict.get(Parameter.POSITIONAL_ONLY, []) + positional_or_keyword = params_kind_dict.get( + Parameter.POSITIONAL_OR_KEYWORD, + [], ) - result_params.extend(params_kind_dict.get(Parameter.VAR_POSITIONAL, [])) - result_params.extend(params_kind_dict.get(Parameter.KEYWORD_ONLY, [])) - result_params.extend(params_kind_dict.get(Parameter.VAR_KEYWORD, [])) + keyword_only = params_kind_dict.get(Parameter.KEYWORD_ONLY, []) + + result_params = [] + result_params.extend(positional_only) + result_params.extend(positional_or_keyword) + result_params.extend(var_positional) + result_params.extend(keyword_only) + result_params.extend(var_keyword) return result_params From 7b05799eb585d8f04c1e249541d0446ea68ab3e1 Mon Sep 17 00:00:00 2001 From: Apostol Fet Date: Sat, 7 Dec 2024 20:22:17 +0300 Subject: [PATCH 5/6] Add integration/base tests in tox --- tox.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tox.ini b/tox.ini index 1d0b69c7..cfaa0bbd 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,7 @@ requires = env_list = unit, real_world_example, + integrations-base, fastapi-{0096,0109}, aiohttp-393, flask-302, @@ -25,6 +26,7 @@ use_develop = true deps = pytest pytest-cov + integrations-base: -r requirements/test.txt aiohttp-393: -r requirements/aiohttp-393.txt aiohttp-latest: -r requirements/aiohttp-latest.txt fastapi-latest: -r requirements/fastapi-latest.txt @@ -57,6 +59,7 @@ deps = click-latest: -r requirements/click-latest.txt commands = + integrations-base: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/base aiohttp: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/aiohttp fastapi: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/fastapi aiogram: pytest --cov=dishka --cov-append --cov-report=term-missing -v tests/integrations/aiogram From 7b600b7bdb0edd34545208505679c4ee8d131d39 Mon Sep 17 00:00:00 2001 From: Apostol Fet Date: Sat, 7 Dec 2024 20:29:56 +0300 Subject: [PATCH 6/6] Change the error message according to kind.description --- src/dishka/integrations/base.py | 4 ++-- tests/integrations/base/test_add_params.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 751a40d0..44304776 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -242,14 +242,14 @@ def _add_params( if len(var_positional) > 1: param_names = (param.name for param in var_positional) var_positional_names = ", *".join(param_names) - base_msg = "more than one var positional parameter: *" + base_msg = "more than one variadic positional parameter: *" msg = base_msg + var_positional_names raise ValueError(msg) var_keyword = params_kind_dict.get(Parameter.VAR_KEYWORD, []) if len(var_keyword) > 1: var_keyword_names = ", **".join(param.name for param in var_keyword) - msg = "more than one var keyword parameter: " + var_keyword_names + msg = "more than one variadic keyword parameter: " + var_keyword_names raise ValueError(msg) positional_only = params_kind_dict.get(Parameter.POSITIONAL_ONLY, []) diff --git a/tests/integrations/base/test_add_params.py b/tests/integrations/base/test_add_params.py index 87bcfb43..d07ee3e0 100644 --- a/tests/integrations/base/test_add_params.py +++ b/tests/integrations/base/test_add_params.py @@ -59,7 +59,7 @@ def test_fail_add_second_args(): func_params = list(func_signature.parameters.values()) with pytest.raises( - ValueError, match="more than one var positional parameter", + ValueError, match="more than one variadic positional parameter", ): _add_params(func_params, additional_params) @@ -73,6 +73,6 @@ def test_fail_add_second_kwargs(): func_params = list(func_signature.parameters.values()) with pytest.raises( - ValueError, match="more than one var keyword parameter", + ValueError, match="more than one variadic keyword parameter", ): _add_params(func_params, additional_params)