diff --git a/graphene_federation/__init__.py b/graphene_federation/__init__.py index 4d6bcf6..73786d0 100644 --- a/graphene_federation/__init__.py +++ b/graphene_federation/__init__.py @@ -7,3 +7,4 @@ from .inaccessible import inaccessible from .provides import provides from .override import override +from .compose_directive import mark_composable, is_composable diff --git a/graphene_federation/compose_directive.py b/graphene_federation/compose_directive.py new file mode 100644 index 0000000..c6aa505 --- /dev/null +++ b/graphene_federation/compose_directive.py @@ -0,0 +1,62 @@ +from typing import Optional + +from graphql import GraphQLDirective + + +def is_composable(directive: GraphQLDirective) -> bool: + """ + Checks if the directive will be composed to supergraph. + Validates the presence of _compose_import_url attribute + """ + return hasattr(directive, "_compose_import_url") + + +def mark_composable( + directive: GraphQLDirective, import_url: str, import_as: Optional[str] = None +) -> GraphQLDirective: + """ + Marks directive with _compose_import_url and _compose_import_as + Enables Identification of directives which are to be composed to supergraph + """ + setattr(directive, "_compose_import_url", import_url) + if import_as: + setattr(directive, "_compose_import_as", import_as) + return directive + + +def compose_directive_schema_extensions(directives: list[GraphQLDirective]): + """ + Generates schema extends string for ComposeDirective + """ + link_schema = "" + compose_directive_schema = "" + # Using dictionary to generate cleaner schema when multiple directives imports from same URL. + links: dict = {} + + for directive in directives: + # TODO: Replace with walrus operator when dropping Python 3.8 support + if hasattr(directive, "_compose_import_url"): + compose_import_url = getattr(directive, "_compose_import_url") + if hasattr(directive, "_compose_import_as"): + compose_import_as = getattr(directive, "_compose_import_as") + import_value = f'{{ name: "@{directive.name}, as: "@{compose_import_as}" }}' + imported_name = compose_import_as + else: + import_value = f'"@{directive.name}"' + imported_name = directive.name + + import_url = compose_import_url + + if links.get(import_url): + links[import_url] = links[import_url].append(import_value) + else: + links[import_url] = [import_value] + + compose_directive_schema += ( + f' @composeDirective(name: "@{imported_name}")\n' + ) + + for import_url in links: + link_schema += f' @link(url: "{import_url}", import: [{",".join(value for value in links[import_url])}])\n' + + return link_schema + compose_directive_schema diff --git a/graphene_federation/service.py b/graphene_federation/service.py index f8dd488..27655c9 100644 --- a/graphene_federation/service.py +++ b/graphene_federation/service.py @@ -5,6 +5,7 @@ from graphene.types.union import UnionOptions from graphql import GraphQLInterfaceType, GraphQLObjectType +from .compose_directive import is_composable, compose_directive_schema_extensions from .external import get_external_fields from .inaccessible import get_inaccessible_types, get_inaccessible_fields from .override import get_override_fields @@ -120,7 +121,7 @@ def get_sdl(schema: Schema) -> str: external_fields = get_external_fields(schema) override_fields = get_override_fields(schema) - _schema = "" + schema_extensions = [] if schema.federation_version >= 2: shareable_types = get_shareable_types(schema) @@ -129,28 +130,41 @@ def get_sdl(schema: Schema) -> str: tagged_fields = get_tagged_fields(schema) inaccessible_fields = get_inaccessible_fields(schema) - _schema_import = [] + federation_spec_import = [] if extended_types: - _schema_import.append('"@extends"') + federation_spec_import.append('"@extends"') if external_fields: - _schema_import.append('"@external"') + federation_spec_import.append('"@external"') if entities: - _schema_import.append('"@key"') + federation_spec_import.append('"@key"') if override_fields: - _schema_import.append('"@override"') + federation_spec_import.append('"@override"') if provides_parent_types or provides_fields: - _schema_import.append('"@provides"') + federation_spec_import.append('"@provides"') if required_fields: - _schema_import.append('"@requires"') + federation_spec_import.append('"@requires"') if inaccessible_types or inaccessible_fields: - _schema_import.append('"@inaccessible"') + federation_spec_import.append('"@inaccessible"') if shareable_types or shareable_fields: - _schema_import.append('"@shareable"') + federation_spec_import.append('"@shareable"') if tagged_fields: - _schema_import.append('"@tag"') - schema_import = ", ".join(_schema_import) - _schema = f'extend schema @link(url: "https://specs.apollo.dev/federation/v{schema.federation_version}", import: [{schema_import}])\n' + federation_spec_import.append('"@tag"') + + if schema.federation_version >= 2.1: + preserved_directives = [ + directive for directive in schema.directives if is_composable(directive) + ] + if preserved_directives: + federation_spec_import.append('"@composeDirective"') + schema_extensions.append( + compose_directive_schema_extensions(preserved_directives) + ) + + schema_import = ", ".join(federation_spec_import) + schema_extensions = [ + f'@link(url: "https://specs.apollo.dev/federation/v{schema.federation_version}", import: [{schema_import}])' + ] + schema_extensions # Add fields directives (@external, @provides, @requires, @shareable, @inaccessible) entities_ = ( @@ -229,7 +243,11 @@ def get_sdl(schema: Schema) -> str: pattern = re.compile(type_def_re) string_schema = pattern.sub(repl_str, string_schema) - return re.sub(r"[ ]+", " ", re.sub(r"\n+", "\n", _schema + string_schema)) # noqa + if schema_extensions: + string_schema = ( + "extend schema\n " + "\n ".join(schema_extensions) + "\n" + string_schema + ) + return re.sub(r"[ ]+", " ", re.sub(r"\n+", "\n", string_schema)) # noqa def get_service_query(schema: Schema):