Skip to content

Commit

Permalink
[DBT] Mimic dbt behavior for config() jinja within sql models (#634)
Browse files Browse the repository at this point in the history
  • Loading branch information
crericha authored Apr 7, 2023
1 parent 33f8cc0 commit 5fecad2
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 54 deletions.
7 changes: 6 additions & 1 deletion sqlmesh/core/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,9 @@ def update_with(self: T, other: t.Union[t.Dict[str, t.Any], T]) -> T:
else:
updated_fields[field] = getattr(other, field)

return self.copy(update=updated_fields)
# Assign each field to trigger assignment validators
updated = self.copy()
for field, value in updated_fields.items():
setattr(updated, field, value)

return updated
70 changes: 55 additions & 15 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from dbt.adapters.base import BaseRelation
from dbt.contracts.relation import RelationType
from jinja2 import nodes
from jinja2.exceptions import UndefinedError
from pydantic import Field, validator
from sqlglot.helper import ensure_list

Expand Down Expand Up @@ -140,6 +142,14 @@ def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:

@property
def all_sql(self) -> SqlStr:
return SqlStr("\n".join(self.pre_hook + [self.sql_no_config] + self.post_hook))

@property
def sql_no_config(self) -> SqlStr:
return SqlStr("")

@property
def sql_embedded_config(self) -> SqlStr:
return SqlStr("")

@property
Expand Down Expand Up @@ -190,13 +200,17 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
}
)

def attribute_dict(self) -> AttributeDict[str, t.Any]:
return AttributeDict(self.dict())

def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:
"""Get common sqlmesh model parameters"""
jinja_macros = model_context.jinja_macros.trim(self._dependencies.macros)
jinja_macros.global_objs.update(
{
"this": self.relation_info,
"schema": self.table_schema,
"config": self.attribute_dict(),
**model_context.jinja_globals, # type: ignore
}
)
Expand All @@ -220,7 +234,6 @@ def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:

def render_config(self: BMC, context: DbtContext) -> BMC:
rendered = super().render_config(context)
rendered._dependencies = Dependencies(macros=extract_macro_references(rendered.all_sql))
rendered = ModelSqlRenderer(context, rendered).enriched_config

rendered_dependencies = rendered._dependencies
Expand Down Expand Up @@ -275,7 +288,7 @@ def __init__(self, context: DbtContext, config: BMC):
jinja_globals={
**context.jinja_globals,
**date_dict(c.EPOCH, c.EPOCH, c.EPOCH),
"config": self._config,
"config": lambda *args, **kwargs: "",
"ref": self._ref,
"var": self._var,
"source": self._source,
Expand All @@ -293,9 +306,15 @@ def __init__(self, context: DbtContext, config: BMC):
dialect=context.engine_adapter.dialect if context.engine_adapter else "",
)

self.jinja_env = self.context.jinja_macros.build_environment(**self._jinja_globals)

@property
def enriched_config(self) -> BMC:
if self._rendered_sql is None:
self._enriched_config = self._update_with_sql_config(self._enriched_config)
self._enriched_config._dependencies = Dependencies(
macros=extract_macro_references(self._enriched_config.all_sql)
)
self.render()
self._enriched_config._dependencies = self._enriched_config._dependencies.union(
self._captured_dependencies
Expand All @@ -304,14 +323,42 @@ def enriched_config(self) -> BMC:

def render(self) -> str:
if self._rendered_sql is None:
registry = self.context.jinja_macros
self._rendered_sql = (
registry.build_environment(**self._jinja_globals)
.from_string(self.config.all_sql)
.render()
)
try:
self._rendered_sql = self.jinja_env.from_string(
self._enriched_config.all_sql
).render()
except UndefinedError as e:
raise ConfigError(e.message)
return self._rendered_sql

def _update_with_sql_config(self, config: BMC) -> BMC:
def _extract_value(node: t.Any) -> t.Any:
if not isinstance(node, nodes.Node):
return node
if isinstance(node, nodes.Const):
return _extract_value(node.value)
if isinstance(node, nodes.TemplateData):
return _extract_value(node.data)
if isinstance(node, nodes.List):
return [_extract_value(val) for val in node.items]
if isinstance(node, nodes.Dict):
return {_extract_value(pair.key): _extract_value(pair.value) for pair in node.items}
if isinstance(node, nodes.Tuple):
return tuple(_extract_value(val) for val in node.items)

return self.jinja_env.from_string(nodes.Template([nodes.Output([node])])).render()

for call in self.jinja_env.parse(self._enriched_config.sql_embedded_config).find_all(
nodes.Call
):
if not isinstance(call.node, nodes.Name) or call.node.name != "config":
continue
config = config.update_with(
{kwarg.key: _extract_value(kwarg.value) for kwarg in call.kwargs}
)

return config

def _ref(self, package_name: str, model_name: t.Optional[str] = None) -> BaseRelation:
if package_name in self.context.models:
relation = BaseRelation.create(**self.context.models[package_name].relation_info)
Expand Down Expand Up @@ -341,13 +388,6 @@ def _source(self, source_name: str, table_name: str) -> BaseRelation:
self._captured_dependencies.sources.add(full_name)
return BaseRelation.create(**self.context.sources[full_name].relation_info)

def _config(self, *args: t.Any, **kwargs: t.Any) -> str:
if args and isinstance(args[0], dict):
self._enriched_config = self._enriched_config.update_with(args[0])
if kwargs:
self._enriched_config = self._enriched_config.update_with(kwargs)
return ""

class TrackingAdapter(ParsetimeAdapter):
def __init__(self, outer_self: ModelSqlRenderer, *args: t.Any, **kwargs: t.Any):
super().__init__(*args, **kwargs)
Expand Down
5 changes: 0 additions & 5 deletions sqlmesh/dbt/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,6 @@ def no_log(msg: str, info: bool = False) -> str:
return ""


def config(*args: t.Any, **kwargs: t.Any) -> str:
return ""


def generate_var(variables: t.Dict[str, t.Any]) -> t.Callable:
def var(name: str, default: t.Optional[str] = None) -> str:
return variables.get(name, default)
Expand Down Expand Up @@ -252,7 +248,6 @@ def _try_literal_eval(value: str) -> t.Any:

BUILTIN_GLOBALS = {
"api": Api(),
"config": config,
"env_var": env_var,
"exceptions": Exceptions(),
"flags": Flags(),
Expand Down
5 changes: 4 additions & 1 deletion sqlmesh/dbt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class DbtConfig(PydanticModel):
class Config:
extra = "allow"
allow_mutation = True
validate_assignment = True


class GeneralConfig(DbtConfig, BaseConfig):
Expand Down Expand Up @@ -285,7 +286,9 @@ def render_value(val: t.Any) -> t.Any:

rendered = self.copy(deep=True)
for name in rendered.__fields__:
setattr(rendered, name, render_value(getattr(rendered, name)))
value = getattr(rendered, name)
if value is not None:
setattr(rendered, name, render_value(value))

return rendered

Expand Down
57 changes: 41 additions & 16 deletions sqlmesh/dbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class ModelConfig(BaseModelConfig):
# redshift
bind: t.Optional[bool] = None

# Private fields
_sql_embedded_config: t.Optional[SqlStr] = None
_sql_no_config: t.Optional[SqlStr] = None

@validator(
"unique_key",
"cluster_by",
Expand Down Expand Up @@ -157,26 +161,47 @@ def model_kind(self, target: TargetConfig) -> ModelKind:
raise ConfigError(f"{materialization.value} materialization not supported.")

@property
def sql_no_config(self) -> str:
matches = re.findall(r"{{\s*config\(", self.sql)
if matches:
config_macro_start = self.sql.index(matches[0])
cursor = config_macro_start
def sql_no_config(self) -> SqlStr:
if self._sql_no_config is None:
self._sql_no_config = SqlStr("")
self._extract_sql_config()
return self._sql_no_config

@property
def sql_embedded_config(self) -> SqlStr:
if self._sql_embedded_config is None:
self._sql_embedded_config = SqlStr("")
self._extract_sql_config()
return self._sql_embedded_config

def _extract_sql_config(self) -> None:
def jinja_end(sql: str, start: int) -> int:
cursor = start
quote = None
while cursor < len(self.sql):
if self.sql[cursor] in ('"', "'"):
while cursor < len(sql):
if sql[cursor] in ('"', "'"):
if quote is None:
quote = self.sql[cursor]
elif quote == self.sql[cursor]:
quote = sql[cursor]
elif quote == sql[cursor]:
quote = None
if self.sql[cursor : cursor + 2] == "}}" and quote is None:
return "".join([self.sql[:config_macro_start], self.sql[cursor + 2 :]])
if sql[cursor : cursor + 2] == "}}" and quote is None:
return cursor + 2
cursor += 1
return self.sql

@property
def all_sql(self) -> SqlStr:
return SqlStr(";\n".join(self.pre_hook + [self.sql] + self.post_hook))
return cursor

self._sql_no_config = self.sql
matches = re.findall(r"{{\s*config\s*\(", self._sql_no_config)
for match in matches:
start = self._sql_no_config.find(match)
if start == -1:
continue
extracted = self._sql_no_config[start : jinja_end(self._sql_no_config, start)]
self._sql_embedded_config = SqlStr(
"\n".join([self._sql_embedded_config, extracted])
if self._sql_embedded_config
else extracted
)
self._sql_no_config = SqlStr(self._sql_no_config.replace(extracted, "").strip())

def to_sqlmesh(self, context: DbtContext) -> Model:
"""Converts the dbt model into a SQLMesh model."""
Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/dbt/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def _validate_quoting(cls, v: t.Dict[str, t.Any]) -> t.Dict[str, bool]:

@validator("columns", pre=True)
def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
if not isinstance(v, dict) or all(isinstance(col, ColumnConfig) for col in v.values()):
return v

return yaml_to_columns(v)

_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/utils/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:


def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str:
return ENVIRONMENT.from_string(query).render(methods)
return ENVIRONMENT.from_string(query).render(methods or {})


def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t.Tuple[str, ...]]:
Expand Down
38 changes: 23 additions & 15 deletions tests/dbt/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest

from sqlmesh.core.model import SqlModel
from sqlmesh.dbt.basemodel import Dependencies
from sqlmesh.dbt.common import DbtContext
from sqlmesh.dbt.model import Materialization, ModelConfig
from sqlmesh.dbt.project import Project
Expand Down Expand Up @@ -130,6 +129,7 @@ def test_to_sqlmesh_fields(sushi_test_project: Project):


def test_model_config_sql_no_config():
context = DbtContext()
assert (
ModelConfig(
sql="""{{
Expand All @@ -139,64 +139,72 @@ def test_model_config_sql_no_config():
)
}}
query"""
).sql_no_config.strip()
)
.render_config(context)
.sql_no_config.strip()
== "query"
)

context.variables = {"new": "old"}
assert (
ModelConfig(
sql="""{{
config(
materialized='"table"',
materialized='table',
incremental_strategy='delete+insert',
post_hook=" '{{ macro_call(this) }}' "
post_hook=" '{{ var('new') }}' "
)
}}
query"""
).sql_no_config.strip()
)
.render_config(context)
.sql_no_config.strip()
== "query"
)

assert (
ModelConfig(
sql="""before {{config(materialized='table', post_hook=" {{ macro_call(this) }} ")}} after"""
).sql_no_config
sql="""before {{config(materialized='table', post_hook=" {{ var('new') }} ")}} after"""
)
.render_config(context)
.sql_no_config.strip()
== "before after"
)


def test_variables(assert_exp_eq, sushi_test_project):
# Case 1: using an undefined variable without a default value
defined_variables = {}
model_variables = {"foo"}

model_config = ModelConfig(alias="test", sql="SELECT {{ var('foo') }}")
model_config._dependencies = Dependencies(variables=model_variables)

context = sushi_test_project.context
context.variables = defined_variables

model_config = ModelConfig(alias="test", sql="SELECT {{ var('foo') }}")

kwargs = {"context": context}

with pytest.raises(ConfigError, match=r".*Variable 'foo' was not found.*"):
model_config = model_config.render_config(context)
rendered = model_config.render_config(context)
model_config.to_sqlmesh(**kwargs)

# Case 2: using a defined variable without a default value
defined_variables["foo"] = 6
context.variables = defined_variables
assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')
rendered = model_config.render_config(context)
assert_exp_eq(rendered.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')

# Case 3: using a defined variable with a default value
model_config.sql = "SELECT {{ var('foo', 5) }}"

assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')
rendered = model_config.render_config(context)
assert_exp_eq(rendered.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')

# Case 4: using an undefined variable with a default value
del defined_variables["foo"]
context.variables = defined_variables

assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 5 AS "5"')
rendered = model_config.render_config(context)
assert_exp_eq(rendered.to_sqlmesh(**kwargs).render_query(), 'SELECT 5 AS "5"')

# Finally, check that variable scoping & overwriting (some_var) works as expected
expected_sushi_variables = {
Expand Down
Loading

0 comments on commit 5fecad2

Please sign in to comment.