From d8509cc2ecc745412c576c3382950ebe0ca21de8 Mon Sep 17 00:00:00 2001 From: Arun Suresh Kumar Date: Sun, 10 Dec 2023 21:12:33 +0530 Subject: [PATCH] Bug Fix: build_schema --- graphene_federation/main.py | 20 ++++++++++---------- graphene_federation/service.py | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/graphene_federation/main.py b/graphene_federation/main.py index e1a1fa0..0144209 100644 --- a/graphene_federation/main.py +++ b/graphene_federation/main.py @@ -21,18 +21,18 @@ def _get_query(schema: Schema, query_cls: Optional[ObjectType] = None) -> Object def build_schema( - query: Optional[ObjectType] = None, - mutation: Optional[ObjectType] = None, - enable_federation_2=False, - schema: Optional[Schema] = None, - **kwargs + query: Optional[ObjectType] = None, + mutation: Optional[ObjectType] = None, + enable_federation_2=False, + schema: Optional[Schema] = None, + **kwargs ) -> Schema: schema = schema or Schema(query=query, mutation=mutation, **kwargs) schema.auto_camelcase = kwargs.get("auto_camelcase", True) schema.federation_version = 2 if enable_federation_2 else 1 federation_query = _get_query(schema, schema.query if schema else query) - return Schema( - query=federation_query, - mutation=schema.mutation if schema else mutation, - **kwargs - ) + kwargs = schema.__dict__ + kwargs.pop("query") + kwargs.pop("graphql_schema") + kwargs.pop("federation_version") + return type(schema)(query=federation_query, **kwargs) diff --git a/graphene_federation/service.py b/graphene_federation/service.py index 786ea19..568827e 100644 --- a/graphene_federation/service.py +++ b/graphene_federation/service.py @@ -195,7 +195,7 @@ def get_sdl(schema: Schema) -> str: ( " ".join( [ - f'@key(fields: "{get_field_name(key)}"' + f' @key(fields: "{get_field_name(key)}"' for key in entity._keys ] ) @@ -206,7 +206,7 @@ def get_sdl(schema: Schema) -> str: else: type_annotation = ( " ".join( - [f'@key(fields: "{get_field_name(key)}")' for key in entity._keys] + [f' @key(fields: "{get_field_name(key)}")' for key in entity._keys] ) ) + " " repl_str = rf"\1{type_annotation}" @@ -238,7 +238,7 @@ def get_sdl(schema: Schema) -> str: pattern = re.compile(type_def_re) string_schema = pattern.sub(repl_str, string_schema) - return _schema + string_schema + return re.sub(r"[ ]+", " ", re.sub(r"\n+", "\n", _schema + string_schema)) # noqa def get_service_query(schema: Schema):