diff --git a/graphene_federation/main.py b/graphene_federation/main.py index 247381d..c723983 100644 --- a/graphene_federation/main.py +++ b/graphene_federation/main.py @@ -23,13 +23,16 @@ def _get_query(schema: Schema, query_cls: Optional[ObjectType] = None) -> Object def build_schema( query: Optional[ObjectType] = None, mutation: Optional[ObjectType] = None, - enable_federation_2=False, + federation_version: Optional[float] = None, + enable_federation_2: bool = False, schema: Optional[Schema] = None, **kwargs ) -> Schema: schema = schema or Schema(query=query, mutation=mutation, **kwargs) schema.auto_camelcase = kwargs.get("auto_camelcase", True) - schema.federation_version = 2 if enable_federation_2 else 1 + schema.federation_version = ( + (federation_version or 2) if (enable_federation_2 or federation_version) else 1 + ) federation_query = _get_query(schema, schema.query if schema else query) kwargs = schema.__dict__ kwargs.pop("query") diff --git a/graphene_federation/service.py b/graphene_federation/service.py index ca3aab7..edbc8cf 100644 --- a/graphene_federation/service.py +++ b/graphene_federation/service.py @@ -122,7 +122,7 @@ def get_sdl(schema: Schema) -> str: _schema = "" - if schema.federation_version == 2: + if schema.federation_version >= 2: shareable_types = get_shareable_types(schema) inaccessible_types = get_inaccessible_types(schema) shareable_fields = get_shareable_fields(schema) @@ -161,7 +161,7 @@ def get_sdl(schema: Schema) -> str: | set(provides_fields.values()) ) - if schema.federation_version == 2: + if schema.federation_version >= 2: entities_ = ( entities_ | set(shareable_types.values()) @@ -187,7 +187,7 @@ def get_sdl(schema: Schema) -> str: # resolvable argument of @key directive is true by default. If false, we add 'resolvable: false' to sdl. if ( - schema.federation_version == 2 + schema.federation_version >= 2 and hasattr(entity, "_resolvable") and not entity._resolvable ): @@ -204,7 +204,7 @@ def get_sdl(schema: Schema) -> str: pattern = re.compile(type_def_re) string_schema = pattern.sub(repl_str, string_schema) - if schema.federation_version == 2: + if schema.federation_version >= 2: for type_name, type in shareable_types.items(): # noinspection PyProtectedMember if isinstance(type._meta, UnionOptions):