diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3532e81a8e..8552cd7619 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -645,6 +645,12 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) +class SchemaEnum(sa_Enum): + def __init__(self, *args, **kwargs): + kwargs['inherit_schema'] = True + super().__init__(*args, **kwargs) + + def get_sqlalchemy_type(field: Any) -> Any: if IS_PYDANTIC_V2: field_info = field @@ -659,7 +665,7 @@ def get_sqlalchemy_type(field: Any) -> Any: # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): - return sa_Enum(type_) + return SchemaEnum(type_) if issubclass( type_, (