From 3f73a6a1a58c97ce8f1789b49560ca57fefa7774 Mon Sep 17 00:00:00 2001 From: Denis Artyushin Date: Wed, 31 Jul 2024 13:35:49 +0300 Subject: [PATCH] Fix list of unions for graphql_fields (#41) --- graphql_query/base_model.py | 15 ++++++++++++++- tests/tests_base_model/test_inline_fragment.py | 18 +++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/graphql_query/base_model.py b/graphql_query/base_model.py index cc2d66f..dab1202 100644 --- a/graphql_query/base_model.py +++ b/graphql_query/base_model.py @@ -38,7 +38,20 @@ def _get_fields(model: Type['GraphQLQueryBaseModel']) -> List[Union[str, Field, list_args = get_args(f.annotation)[0] _field_template.name = f_name - _field_template.fields = _get_fields(list_args) + + if get_origin(list_args) is Union: + union_args = [union_arg for union_arg in get_args(list_args) if union_arg is not type(None)] + + if len(union_args) == 1: + _field_template.fields = _get_fields(union_args[0]) + + else: + _field_template.fields = [ + InlineFragment(type=union_arg.__name__, fields=_get_fields(union_arg)) + for union_arg in union_args + ] + else: + _field_template.fields = _get_fields(list_args) # # union type diff --git a/tests/tests_base_model/test_inline_fragment.py b/tests/tests_base_model/test_inline_fragment.py index 0e9e8bb..7bd62d5 100644 --- a/tests/tests_base_model/test_inline_fragment.py +++ b/tests/tests_base_model/test_inline_fragment.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, List from graphql_query import Field, GraphQLQueryBaseModel, InlineFragment @@ -13,6 +13,7 @@ class Human(GraphQLQueryBaseModel): class Hero(GraphQLQueryBaseModel): name: str type: Union[Human, Droid] + types: List[Union[Human, Droid]] correct = [ Field(name="name", fields=[]), @@ -23,6 +24,13 @@ class Hero(GraphQLQueryBaseModel): InlineFragment(type="Droid", fields=[Field(name="primaryFunction", fields=[])]), ], ), + Field( + name="types", + fields=[ + InlineFragment(type="Human", fields=[Field(name="height", fields=[])]), + InlineFragment(type="Droid", fields=[Field(name="primaryFunction", fields=[])]), + ], + ), ] generated = Hero.graphql_fields() @@ -39,5 +47,13 @@ class Hero(GraphQLQueryBaseModel): primaryFunction } } + types { + ... on Human { + height + } + ... on Droid { + primaryFunction + } + } }""" )