diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d95c498507..2c3f71f06a 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -61,6 +61,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) nullable = kwargs.pop("nullable", Undefined) foreign_key = kwargs.pop("foreign_key", Undefined) + foreign_key_kwargs = kwargs.pop("foreign_key_kwargs", Undefined) unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) sa_column = kwargs.pop("sa_column", Undefined) @@ -81,6 +82,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.primary_key = primary_key self.nullable = nullable self.foreign_key = foreign_key + self.foreign_key_kwargs = foreign_key_kwargs self.unique = unique self.index = index self.sa_column = sa_column @@ -143,6 +145,7 @@ def Field( regex: Optional[str] = None, primary_key: bool = False, foreign_key: Optional[Any] = None, + foreign_key_kwargs: Optional[Mapping[str, Any]] = None, unique: bool = False, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, @@ -174,6 +177,7 @@ def Field( regex=regex, primary_key=primary_key, foreign_key=foreign_key, + foreign_key_kwargs=foreign_key_kwargs, unique=unique, nullable=nullable, index=index, @@ -432,9 +436,10 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore nullable = field_nullable args = [] foreign_key = getattr(field.field_info, "foreign_key", None) + foreign_key_kwargs = getattr(field.field_info, "foreign_key_kwargs", None) unique = getattr(field.field_info, "unique", False) if foreign_key: - args.append(ForeignKey(foreign_key)) + args.append(ForeignKey(foreign_key, **(foreign_key_kwargs or dict()))) kwargs = { "primary_key": primary_key, "nullable": nullable,