Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dynamic schema name #350

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,14 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
):
super().__init__(metadata, bind, options)
self.indentation: str = indentation
# TODO add check if there is a "." in the value if set?
self.dynamic_schema_import_path: str | None = dynamic_schema_import_path
self.dynamic_schema_value: str | None = dynamic_schema_value
self.imports: dict[str, set[str]] = defaultdict(set)
self.module_imports: set[str] = set()

Expand Down Expand Up @@ -197,6 +202,8 @@ def collect_imports(self, models: Iterable[Model]) -> None:

for model in models:
self.collect_imports_for_model(model)
if self.dynamic_schema_import_path:
self.add_literal_import(*self.dynamic_schema_import_path.rsplit(".", 1))

def collect_imports_for_model(self, model: Model) -> None:
if model.__class__ is Model:
Expand Down Expand Up @@ -374,7 +381,9 @@ def render_table(self, table: Table) -> str:
if len(index.columns) > 1 or not uses_default_name(index):
args.append(self.render_index(index))

if table.schema:
if self.dynamic_schema_value:
kwargs["schema"] = self.dynamic_schema_value
elif table.schema:
kwargs["schema"] = repr(table.schema)

table_comment = getattr(table, "comment", None)
Expand Down Expand Up @@ -722,9 +731,18 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
base_class_name: str = "Base",
):
super().__init__(metadata, bind, options, indentation=indentation)
super().__init__(
metadata,
bind,
options,
indentation=indentation,
dynamic_schema_import_path=dynamic_schema_import_path,
dynamic_schema_value=dynamic_schema_value,
)
self.base_class_name: str = base_class_name
self.inflect_engine = inflect.engine()

Expand Down Expand Up @@ -1159,14 +1177,23 @@ def render_table_args(self, table: Table) -> str:
if len(index.columns) > 1 or not uses_default_name(index):
args.append(self.render_index(index))

if table.schema:
if self.dynamic_schema_value:
kwargs["schema"] = self.dynamic_schema_value
elif table.schema:
kwargs["schema"] = table.schema

if table.comment:
kwargs["comment"] = table.comment

if kwargs:
formatted_kwargs = pformat(kwargs)
# NB: using pformat on the dict turns schema value (python code) to a string
formatted_kwargs = f",\n{self.indentation}".join(
f"'{k}': {pformat(v)}"
if v != self.dynamic_schema_value
else f"'{k}': {v}"
for k, v in kwargs.items()
)
formatted_kwargs = f"{{{formatted_kwargs}}}"
if not args:
return formatted_kwargs
else:
Expand Down Expand Up @@ -1309,6 +1336,8 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
base_class_name: str = "Base",
quote_annotations: bool = False,
metadata_key: str = "sa",
Expand All @@ -1318,6 +1347,8 @@ def __init__(
bind,
options,
indentation=indentation,
dynamic_schema_import_path=dynamic_schema_import_path,
dynamic_schema_value=dynamic_schema_value,
base_class_name=base_class_name,
)
self.metadata_key: str = metadata_key
Expand Down Expand Up @@ -1348,13 +1379,17 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
base_class_name: str = "SQLModel",
):
super().__init__(
metadata,
bind,
options,
indentation=indentation,
dynamic_schema_import_path=dynamic_schema_import_path,
dynamic_schema_value=dynamic_schema_value,
base_class_name=base_class_name,
)

Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from textwrap import dedent

import pytest
Expand Down Expand Up @@ -31,3 +32,12 @@ def validate_code(generated_code: str, expected_code: str) -> None:
configure_mappers()
finally:
clear_mappers()


@dataclass
class SchemaObject:
name: str


# NB: not a fixture on purpose
schema_obj = SchemaObject(name="best_schema")
50 changes: 50 additions & 0 deletions tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from textwrap import dedent

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import PrimaryKeyConstraint
Expand Down Expand Up @@ -30,6 +32,20 @@ def generator(
return DeclarativeGenerator(metadata, engine, options)


@pytest.fixture
def generator_dynamic_schema(
request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
schema_import_path, schema_value = getattr(request, "param", (None, None))
return DeclarativeGenerator(
metadata,
engine,
[],
dynamic_schema_import_path=schema_import_path,
dynamic_schema_value=schema_value,
)


def test_indexes(generator: CodeGenerator) -> None:
simple_items = Table(
"simple_items",
Expand Down Expand Up @@ -1509,3 +1525,37 @@ class Simple(Base):
server_default=text("'test'"))
""",
)


@pytest.mark.parametrize(
"generator_dynamic_schema",
[[".conftest.schema_obj", "schema_obj.name"]],
indirect=True,
)
def test_use_dynamic_schema(generator_dynamic_schema: CodeGenerator) -> None:
Table(
"simple_items",
generator_dynamic_schema.metadata,
Column("id", INTEGER, primary_key=True),
)

expected_code = """\
from .conftest import schema_obj
from sqlalchemy import Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
pass


class SimpleItems(Base):
__tablename__ = 'simple_items'
__table_args__ = {'schema': schema_obj.name}

id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""
generated_code = generator_dynamic_schema.generate()
expected_code = dedent(expected_code)
assert generated_code == expected_code
# TODO: code execution fails with KeyError: "'__name__' not in globals", any idea?
# validate_code(generated_code, expected_code)
Loading