From f6051bf0b909974b26248b44501aa4c1d5b1419b Mon Sep 17 00:00:00 2001
From: z3z1ma
Date: Sat, 4 Jan 2025 20:26:03 -0700
Subject: [PATCH] feat: add working dbt proxy experiment

---
Reviewer notes (text in this section is ignored when the patch is applied):
 * ALTER_MODIFY_COL_PATTERN: the names of the regex groups read by
   parse_alter_modify_column are restored; the opening quote of the comment is
   matched through a named backreference instead of \1; the comment body
   `(?:.|[^"'])*` is replaced by the equivalent `.*`.
 * One-line docstrings added; hunk size and diffstat updated accordingly.
 * The blob id on the `index` line of proxy.py was not regenerated.
 * Indentation of the pyproject.toml context lines is assumed (2 spaces);
   confirm against the target tree.

 pyproject.toml               |   1 +
 src/dbt_osmosis/sql/proxy.py | 188 +++++++++++++++++++++++++++++++++++
 uv.lock                      |  33 ++++++-
 3 files changed, 217 insertions(+), 5 deletions(-)
 create mode 100644 src/dbt_osmosis/sql/proxy.py

diff --git a/pyproject.toml b/pyproject.toml
index 72861c1..aed1edf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
   "rich>=10",
   "pluggy>=1.5.0,<2",
   "typing-extensions~=4.12.2 ; python_version < '3.10'",
+  "mysql-mimic>=2.5.7",
 ]
 
 [project.optional-dependencies]
diff --git a/src/dbt_osmosis/sql/proxy.py b/src/dbt_osmosis/sql/proxy.py
new file mode 100644
index 0000000..2937732
--- /dev/null
+++ b/src/dbt_osmosis/sql/proxy.py
@@ -0,0 +1,188 @@
+"""Proxy server experiment that any MySQL client (including BI tools) can connect to."""
+
+# pyright: reportMissingTypeStubs=false, reportAny=false, reportImplicitOverride=false, reportUnknownMemberType=false, reportUnusedImport=false, reportUnknownParameterType=false
+import asyncio
+import functools
+import re
+import typing as t
+from collections import defaultdict
+from collections.abc import Iterator
+
+from dbt.adapters.contracts.connection import AdapterResponse
+from mysql_mimic import MysqlServer, Session
+from mysql_mimic.errors import MysqlError
+from mysql_mimic.schema import (
+    Column,
+    InfoSchema,
+    dict_depth,  # pyright: ignore[reportUnknownVariableType, reportPrivateLocalImportUsage]
+    info_schema_tables,
+)
+from sqlglot import exp
+
+import dbt_osmosis.core.logger as logger
+from dbt_osmosis.core.osmosis import (
+    DbtConfiguration,
+    DbtProjectContext,
+    _has_jinja,  # pyright: ignore[reportPrivateUsage]
+    compile_sql_code,
+    create_dbt_project_context,
+    execute_sql_code,
+)
+
+# Groups are referenced by name only: named groups are numbered as well, so a numeric
+# backreference such as \1 would point at the schema group, not at the opening quote.
+ALTER_MODIFY_COL_PATTERN = re.compile(
+    r"""
+    ^\s*                                # start, allow leading whitespace
+    ALTER\s+TABLE                       # "ALTER TABLE" (case-insensitive via flags=)
+    \s+
+    (?:                                 # optional schema part:
+        (?:"(?P<schema>[^"]+)"          # double-quoted schema
+           |(?P<schema_unquoted>\w+)    # or unquoted schema
+        )
+        \s*\.
+    )?
+    (?:"(?P<table>[^"]+)"               # table in double-quotes
+       |(?P<table_unquoted>\w+)         # or unquoted table
+    )
+    \s+
+    MODIFY\s+COLUMN\s+                  # "MODIFY COLUMN" (case-insensitive via flags=)
+    (?:"(?P<column>[^"]+)"              # column in double-quotes
+       |(?P<column_unquoted>\w+)        # or unquoted column
+    )
+    .*?                                 # lazily consume anything until we might see COMMENT
+    (?:                                 # optional comment group
+        COMMENT\s+                      # must have "COMMENT" then space(s)
+        (?P<quote>["'])                 # remember which quote symbol opened the comment
+        (?P<comment>
+            .*                          # greedy: runs up to the last matching quote
+        )
+        (?P=quote)                      # match the same quote symbol
+    )?
+    \s*;?\s*$                           # optional whitespace, optional semicolon
+    """,
+    flags=re.IGNORECASE | re.DOTALL | re.VERBOSE,
+)
+
+
+def parse_alter_modify_column(sql: str) -> dict[str, str | None] | None:
+    """
+    Attempt to parse a statement like:
+        ALTER TABLE schema.table MODIFY COLUMN col TYPE ... COMMENT 'some text';
+
+    Returns None if the pattern does not match, otherwise a dict with:
+        {
+            "schema": ... or None,
+            "table": ...,
+            "column": ...,
+            "comment": ... or None
+        }
+    """
+    match = ALTER_MODIFY_COL_PATTERN.match(sql)
+    if not match:
+        return None
+
+    # Because we have both quoted and unquoted named groups, pick whichever matched:
+    schema = match.group("schema") or match.group("schema_unquoted")
+    table = match.group("table") or match.group("table_unquoted")
+    column = match.group("column") or match.group("column_unquoted")
+    comment = match.group("comment")  # can be None if COMMENT was not present
+
+    return {"schema": schema, "table": table, "column": column, "comment": comment}
+
+
+class QueryException(MysqlError):
+    """Raised by DbtSession.query when the adapter response has a truthy `code`."""
+
+    def __init__(self, response: AdapterResponse) -> None:
+        self.response: AdapterResponse = response
+        super().__init__(response._message)  # pyright: ignore[reportPrivateUsage]
+
+
+class DbtSession(Session):
+    """mysql-mimic session that compiles and runs queries through a dbt project context."""
+
+    def __init__(self, project: DbtProjectContext, *args: t.Any, **kwargs: t.Any) -> None:
+        self.project: DbtProjectContext = project
+        super().__init__(*args, **kwargs)
+
+    def _parse(self, sql: str) -> list[exp.Expression]:
+        """Compile jinja in `sql` (if present), then parse it with the session dialect."""
+        if _has_jinja(sql):
+            node = compile_sql_code(self.project, sql)
+            sql = node.compiled_code or node.raw_code
+        return [e for e in self.dialect().parse(sql) if e]
+
+    async def query(self, expression: exp.Expression, sql: str, attrs: dict[str, t.Any]):
+        """Run `expression` via execute_sql_code in a worker thread; return (rows, column names)."""
+        logger.info("Query: %s", sql)
+        if isinstance(expression, exp.Command):
+            # Commands are only logged, never executed; an empty result is returned.
+            cmd = f"{expression.this} {expression.expression}"
+            doc_update = "alter" in sql.lower() and parse_alter_modify_column(cmd)
+            if doc_update:
+                logger.info("Will update doc: %s", doc_update)
+            else:
+                logger.info("Ignoring command %s", sql)
+            return (), []  # pyright: ignore[reportUnknownVariableType]
+        resp, table = await asyncio.to_thread(
+            execute_sql_code, self.project, expression.sql(dialect=self.project.adapter.type())
+        )
+        if resp.code:
+            raise QueryException(resp)
+        logger.info(table)
+        return [
+            t.cast(tuple[t.Any], row.values()) for row in t.cast(tuple[t.Any], table.rows.values())
+        ], t.cast(tuple[str], table.column_names)
+
+    async def schema(self):
+        """Build the info schema from the columns of the manifest's sources and nodes."""
+        schema: defaultdict[str, dict[str, dict[str, tuple[str, str | None]]]] = defaultdict(dict)
+        for source in self.project.manifest.sources.values():
+            schema[source.schema][source.name] = {
+                c.name: (c.data_type or "UNKOWN", c.description) for c in source.columns.values()
+            }
+        for node in self.project.manifest.nodes.values():
+            schema[node.schema][node.name] = {
+                c.name: (c.data_type or "UNKOWN", c.description) for c in node.columns.values()
+            }
+        iter_columns = mapping_to_columns(schema)
+        return InfoSchema(info_schema_tables(iter_columns))
+
+
+def mapping_to_columns(schema: dict[str, t.Any]) -> Iterator[Column]:
+    """Yield a Column for every column of a nested schema mapping."""
+    depth = dict_depth(schema)
+    if depth < 2:
+        return
+    if depth == 2:
+        # {table: {col: type}}
+        schema = {"": schema}
+        depth += 1
+    if depth == 3:
+        # {db: {table: {col: type}}}
+        schema = {"def": schema}  # def is the default MySQL catalog
+        depth += 1
+    if depth != 4:
+        raise MysqlError("Invalid schema mapping")
+
+    for catalog, dbs in schema.items():
+        for db, tables in dbs.items():
+            for table, cols in tables.items():
+                for column, (coltype, comment) in cols.items():
+                    yield Column(
+                        name=column,
+                        type=coltype,
+                        table=table,
+                        schema=db,
+                        catalog=catalog,
+                        comment=comment,
+                    )
+
+
+if __name__ == "__main__":
+    c = DbtConfiguration()
+    server = MysqlServer(
+        session_factory=functools.partial(DbtSession, create_dbt_project_context(c))
+    )
+    asyncio.run(server.serve_forever())
diff --git a/uv.lock b/uv.lock
index 6932b6a..133a827 100644
--- a/uv.lock
+++ b/uv.lock
@@ -17,7 +17,7 @@ dependencies = [
     { name = "parsedatetime" },
     { name = "python-slugify" },
     { name = "pytimeparse" },
-    { name = "tzdata", marker = "platform_system == 'Windows'" },
+    { name = "tzdata", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/29/77/6f5df1c68bf056f5fdefc60ccc616303c6211e71cd6033c830c12735f605/agate-1.9.1.tar.gz", hash = "sha256:bc60880c2ee59636a2a80cd8603d63f995be64526abf3cbba12f00767bcd5b3d", size = 202303 }
 wheels = [
@@ -184,7 +184,7 @@ name = "click"
 version = "8.1.8"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
 wheels = [
@@ -396,6 +396,7 @@ source = { editable = "." }
 dependencies = [
     { name = "click" },
     { name = "dbt-core" },
+    { name = "mysql-mimic" },
     { name = "pluggy" },
     { name = "rich" },
     { name = "ruamel-yaml" },
@@ -429,6 +430,7 @@ requires-dist = [
     { name = "dbt-duckdb", marker = "extra == 'dev'", specifier = ">=1.0.0,<2" },
     { name = "dbt-duckdb", marker = "extra == 'workbench'", specifier = ">=1.0.0,<2" },
     { name = "feedparser", marker = "extra == 'workbench'", specifier = "~=6.0.11" },
+    { name = "mysql-mimic", specifier = ">=2.5.7" },
     { name = "openai", marker = "extra == 'openai'", specifier = "~=1.58.1" },
     { name = "pluggy", specifier = ">=1.5.0,<2" },
     { name = "pre-commit", marker = "extra == 'dev'", specifier = ">3.0.0,<5" },
@@ -1195,6 +1197,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/af/98/cff14d53a2f2f67d7fe8a4e235a383ee71aba6a1da12aeea24b325d0c72a/multimethod-1.12-py3-none-any.whl", hash = "sha256:fd0c473c43558908d97cc06e4d68e8f69202f167db46f7b4e4058893e7dbdf60", size = 10646 },
 ]
 
+[[package]]
+name = "mysql-mimic"
+version = "2.5.7"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "sqlglot" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/da/1d/f216617d6f042698a51e9377ca94ffdb9c7485c12cbee317257f0124c343/mysql-mimic-2.5.7.tar.gz", hash = "sha256:11c1cc387dce6c6ee72759ed048bf249bf3d51af9b20fbfa1e79208807583e0f", size = 52374 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/20/40/dc111963785614d1fcdcbe8ed1bd25345a0ea12100aeaa731504756ffd56/mysql_mimic-2.5.7-py3-none-any.whl", hash = "sha256:f4b88cd48109428a2cc221f935725a5ff25fc23651c7ebb670ba4c14b35590a1", size = 46159 },
+]
+
 [[package]]
 name = "narwhals"
 version = "1.20.1"
@@ -2161,6 +2175,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
 ]
 
+[[package]]
+name = "sqlglot"
+version = "26.0.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/99/39/35cee255a3de5a4bfbe8780d200761423bb1949249ff541ba81420eebbf5/sqlglot-26.0.1.tar.gz", hash = "sha256:588cde7739029fda310fb7dd49afdc0a20b79e760e4cd6d5e1cd083e7e458b90", size = 19785413 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/62/ac/7cf4f8c133cd2cec68937c87322a5052987f3995f21b87e3d545b4d4aa02/sqlglot-26.0.1-py3-none-any.whl", hash = "sha256:ced4967ce3a4a713d35e2037492fbe1a5187936fdfbd72d7b9ace7815c2d2225", size = 437917 },
+]
+
 [[package]]
 name = "sqlparse"
 version = "0.5.3"
@@ -2235,7 +2258,7 @@ dependencies = [
     { name = "typing-extensions" },
     { name = "tzlocal" },
     { name = "validators" },
-    { name = "watchdog", marker = "platform_system != 'Darwin'" },
+    { name = "watchdog", marker = "sys_platform != 'darwin'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/9d/2d/5273692001363f9534422707a8a8382b7b0b250832fdff473ed992680aa9/streamlit-1.29.0.tar.gz", hash = "sha256:b6dfff9c5e132e5518c92150efcd452980db492a45fafeac3d4688d2334efa07", size = 8033351 }
 wheels = [
@@ -2345,7 +2368,7 @@ name = "tqdm"
 version = "4.67.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
 wheels = [
@@ -2388,7 +2411,7 @@ name = "tzlocal"
 version = "5.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "tzdata", marker = "platform_system == 'Windows'" },
+    { name = "tzdata", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/04/d3/c19d65ae67636fe63953b20c2e4a8ced4497ea232c43ff8d01db16de8dc0/tzlocal-5.2.tar.gz", hash = "sha256:8d399205578f1a9342816409cc1e46a93ebd5755e39ea2d85334bea911bf0e6e", size = 30201 }
 wheels = [