Skip to content

Commit

Permalink
feat: add working middleware to dbt proxy to sync alter table modify …
Browse files Browse the repository at this point in the history
…column comment stmts to manifest
  • Loading branch information
z3z1ma committed Jan 5, 2025
1 parent f6051bf commit 1dba753
Showing 1 changed file with 47 additions and 65 deletions.
112 changes: 47 additions & 65 deletions src/dbt_osmosis/sql/proxy.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
# pyright: reportMissingTypeStubs=false, reportAny=false, reportImplicitOverride=false, reportUnknownMemberType=false, reportUnusedImport=false, reportUnknownParameterType=false
"""Proxy server experiment that any MySQL client (including BI tools) can connect to."""

# pyright: reportMissingTypeStubs=false, reportAny=false, reportImplicitOverride=false, reportUnknownMemberType=false, reportUnusedImport=false, reportUnknownParameterType=false
import asyncio
import functools
import re
import typing as t
from collections import defaultdict
from collections.abc import Iterator
from itertools import chain

from dbt.adapters.contracts.connection import AdapterResponse
from mysql_mimic import MysqlServer, Session
from mysql_mimic.errors import MysqlError
from mysql_mimic.results import AllowedResult
from mysql_mimic.schema import (
Column,
InfoSchema,
dict_depth, # pyright: ignore[reportUnknownVariableType, reportPrivateLocalImportUsage]
info_schema_tables,
)
from mysql_mimic.session import Query
from sqlglot import exp

import dbt_osmosis.core.logger as logger
Expand All @@ -29,38 +32,8 @@
execute_sql_code,
)

# TODO: this doesn't capture comment body consistently
ALTER_MODIFY_COL_PATTERN = re.compile(
r"""
^\s* # start, allow leading whitespace
ALTER\s+TABLE # "ALTER TABLE" (case-insensitive via flags=)
\s+
(?: # optional schema part:
(?:"(?P<schema>[^"]+)" # double-quoted schema
|(?P<schema_unquoted>\w+) # or unquoted schema
)
\s*\.
)?
(?:"(?P<table>[^"]+)" # table in double-quotes
|(?P<table_unquoted>\w+) # or unquoted table
)
\s+
MODIFY\s+COLUMN\s+ # "MODIFY COLUMN" (case-insensitive via flags=)
(?:"(?P<column>[^"]+)" # column in double-quotes
|(?P<column_unquoted>\w+) # or unquoted column
)
.*? # lazily consume anything until we might see COMMENT
(?: # optional comment group
COMMENT\s+ # must have "COMMENT" then space(s)
(["']) # capture the quote symbol in group 1
(?P<comment>
(?:.|[^"'])* # any escaped char or anything that isn't ' or "
)
\1 # match the same quote symbol (group 1)
)?
\s*;?\s*$ # optional whitespace, optional semicolon
""",
flags=re.IGNORECASE | re.DOTALL | re.VERBOSE,
r"(?i)(?:/\*.*?\*/\s*)?ALTER TABLE\s+(?:(?P<schema>[^\s\.]+)\.)?(?P<table>[^\s\.]+)\s+MODIFY COLUMN\s+(?P<column>[^\s]+)\s+.*?COMMENT\s+'(?P<comment>[^']*)';?"
)


Expand All @@ -70,70 +43,79 @@ def parse_alter_modify_column(sql: str) -> dict[str, str] | None:
ALTER TABLE schema.table MODIFY COLUMN col TYPE ... COMMENT 'some text';
Returns None if the pattern does not match, otherwise a dict with:
{
"schema": ... or None,
"table": ...,
"column": ...,
"comment": ... or None
}
{"schema": ..., "table": ..., "column": ..., "comment": ...}
"""
match = ALTER_MODIFY_COL_PATTERN.match(sql)
if not match:
return None

# Because we have both quoted and unquoted named groups, pick whichever matched:
schema = match.group("schema") or match.group("schema_unquoted")
table = match.group("table") or match.group("table_unquoted")
column = match.group("column") or match.group("column_unquoted")
comment = match.group("comment") # can be None if COMMENT was not present

schema = match.group("schema")
table = match.group("table")
column = match.group("column")
comment = match.group("comment")
if not all((schema, table, column, comment)):
return None
return {"schema": schema, "table": table, "column": column, "comment": comment}


class QueryException(MysqlError):
def __init__(self, response: AdapterResponse) -> None:
self.response: AdapterResponse = response
super().__init__(response._message) # pyright: ignore[reportPrivateUsage]
self.response: AdapterResponse = response


class DbtSession(Session):
def __init__(self, project: DbtProjectContext, *args: t.Any, **kwargs: t.Any) -> None:
self.project: DbtProjectContext = project
super().__init__(*args, **kwargs)
self.project: DbtProjectContext = project
self.middlewares.append(self._alter_column_middleware)

def _parse(self, sql: str) -> list[exp.Expression]:
if _has_jinja(sql):
node = compile_sql_code(self.project, sql)
sql = node.compiled_code or node.raw_code
return [e for e in self.dialect().parse(sql) if e]

async def query(self, expression: exp.Expression, sql: str, attrs: dict[str, t.Any]):
async def _alter_column_middleware(self, q: Query) -> AllowedResult:
"""Intercept ALTER TABLE ... MODIFY COLUMN ... COMMENT statements
This middleware will update the column description in the dbt project manifest. Eventually
it could use our Yaml context class to actually write the changes to disk.
"""
if isinstance(q.expression, exp.Command):
lower_sql = q.sql.lower()
likely_alter_column = all(
k in lower_sql for k in ("alter", "modify", "column", "comment")
)
if doc_update_req := (likely_alter_column and parse_alter_modify_column(q.sql)):
rel = (doc_update_req["schema"], doc_update_req["table"])
for node in chain(
self.project.manifest.sources.values(), self.project.manifest.nodes.values()
):
if rel == (node.schema, node.name):
for column in node.columns.values():
if column.name == doc_update_req["column"]:
column.description = doc_update_req["comment"]
break
return [], []
return await q.next()

async def query(
self, expression: exp.Expression, sql: str, attrs: dict[str, t.Any]
) -> AllowedResult:
logger.info("Query: %s", sql)
if isinstance(expression, exp.Command):
cmd = f"{expression.this} {expression.expression}"
doc_update = "alter" in sql.lower() and parse_alter_modify_column(cmd)
if doc_update:
logger.info("Will update doc: %s", doc_update)
else:
logger.info("Ignoring command %s", sql)
return (), [] # pyright: ignore[reportUnknownVariableType]
resp, table = await asyncio.to_thread(
execute_sql_code, self.project, expression.sql(dialect=self.project.adapter.type())
)
if resp.code:
raise QueryException(resp)
logger.info(table)
return [
t.cast(tuple[t.Any], row.values()) for row in t.cast(tuple[t.Any], table.rows.values())
], t.cast(tuple[str], table.column_names)
rows = t.cast(tuple[t.Any], table.rows.values())
return [row.values() for row in rows], t.cast(tuple[str], table.column_names)

async def schema(self):
schema: defaultdict[str, dict[str, dict[str, tuple[str, str | None]]]] = defaultdict(dict)
for source in self.project.manifest.sources.values():
schema[source.schema][source.name] = {
c.name: (c.data_type or "UNKOWN", c.description) for c in source.columns.values()
}
for node in self.project.manifest.nodes.values():
for node in chain(
self.project.manifest.sources.values(), self.project.manifest.nodes.values()
):
schema[node.schema][node.name] = {
c.name: (c.data_type or "UNKOWN", c.description) for c in node.columns.values()
}
Expand Down

0 comments on commit 1dba753

Please sign in to comment.