Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/concepts/models/python_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,33 @@ def entrypoint(
)
```

Blueprint variables can also be used as **column names and column types** in the `columns` dictionary. For example, if each blueprint produces a model with a different set of column names and types, both can be parameterized using the same `@{variable}` syntax:

```python linenums="1"
import pandas as pd
from sqlmesh import ExecutionContext, model

@model(
"@{customer}.metrics",
kind="FULL",
blueprints=[
{"customer": "customer1", "primary_metric": "revenue", "primary_type": "int", "secondary_metric": "cost", "secondary_type": "double"},
{"customer": "customer2", "primary_metric": "sales", "primary_type": "text", "secondary_metric": "profit", "secondary_type": "double"},
],
columns={
"@{primary_metric}": "@{primary_type}",
"@{secondary_metric}": "@{secondary_type}",
},
)
def entrypoint(context: ExecutionContext, **kwargs) -> pd.DataFrame:
return pd.DataFrame({
context.blueprint_var("primary_metric"): [1],
context.blueprint_var("secondary_metric"): [1.5],
})
```

Global variables (defined in the project config) can also be used as column names and types in the same way.

Note the use of curly brace syntax `@{customer}` in the model name above. It is used to ensure SQLMesh can combine the macro variable into the model name identifier correctly - learn more [here](../../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings).

Blueprint variable mappings can also be constructed dynamically, e.g., by using a macro: `blueprints="@gen_blueprints()"`. This is useful in cases where the `blueprints` list needs to be sourced from external sources, such as CSV files.
Expand Down
5 changes: 3 additions & 2 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,8 +1096,9 @@ def extend_sqlglot() -> None:
DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
Jinja: lambda self, e: e.name,
JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
JinjaStatement: lambda self,
e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};",
JinjaStatement: lambda self, e: (
f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};"
),
VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e),
MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
MacroFunc: _macro_func_sql,
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ def _get_source_queries(
)
for c in target_columns_to_types
]
query_factory = (
lambda: exp.Select()
query_factory = lambda: (
exp.Select()
.select(*select_columns)
.from_(query_or_df.subquery("select_source_columns"))
)
Expand Down
10 changes: 6 additions & 4 deletions sqlmesh/core/model/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs:

self.columns = {
column_name: (
column_type
if isinstance(column_type, exp.DataType)
column_type # Column types with macros (containing @) will be validated later after rendering
if isinstance(column_type, exp.DataType) or "@" in column_type
else exp.DataType.build(
str(column_type), dialect=self.kwargs.get("dialect", self._dialect)
)
)
for column_name, column_type in self.kwargs.pop("columns", {}).items()
for column_name, column_type in self.kwargs.get("columns", {}).items()
}

def __call__(
Expand Down Expand Up @@ -196,6 +196,8 @@ def model(
if isinstance(rendered_name, exp.Expr):
rendered_fields["name"] = rendered_name.sql(dialect=dialect)

rendered_columns = rendered_fields.get("columns")

rendered_defaults = (
render_model_defaults(
defaults=defaults,
Expand Down Expand Up @@ -223,7 +225,7 @@ def model(
"default_catalog": default_catalog,
"variables": variables,
"dialect": dialect,
"columns": self.columns if self.columns else None,
"columns": rendered_columns if rendered_columns else None,
"module_path": module_path,
"macros": macros,
"jinja_macros": jinja_macros,
Expand Down
10 changes: 9 additions & 1 deletion sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2977,7 +2977,15 @@ def render_field_value(value: t.Any) -> t.Any:
if isinstance(field_value, dict):
rendered_dict = {}
for key, value in field_value.items():
if key in RUNTIME_RENDERED_MODEL_FIELDS:
if field == "columns":
column_name = render_field_value(key)
column_type = render_field_value(value)
# If column_type is an Expr (from rendering macros), convert to string.
# Otherwise, leave it as-is (string) for the validator to parse with the correct dialect.
if isinstance(column_type, exp.Expr):
column_type = column_type.sql(dialect=dialect)
rendered_dict[column_name] = column_type
elif key in RUNTIME_RENDERED_MODEL_FIELDS:
rendered_dict[key] = parse_strings_with_macro_refs(value, dialect)
elif (
# don't parse kind auto_restatement_cron="@..." kwargs (e.g. @daily) into MacroVar
Expand Down
18 changes: 11 additions & 7 deletions sqlmesh/lsp/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,9 @@ def get_model_find_all_references(
# Find the model reference at the cursor position
model_at_position = next(
filter(
lambda ref: isinstance(ref, ModelReference)
and _position_within_range(position, ref.range),
lambda ref: (
isinstance(ref, ModelReference) and _position_within_range(position, ref.range)
),
get_model_definitions_for_a_path(lint_context, document_uri),
),
None,
Expand Down Expand Up @@ -486,8 +487,9 @@ def get_macro_find_all_references(
# Find the macro reference at the cursor position
macro_at_position = next(
filter(
lambda ref: isinstance(ref, MacroReference)
and _position_within_range(position, ref.range),
lambda ref: (
isinstance(ref, MacroReference) and _position_within_range(position, ref.range)
),
get_macro_definitions_for_a_path(lsp_context, document_uri),
),
None,
Expand Down Expand Up @@ -517,9 +519,11 @@ def get_macro_find_all_references(

# Get macro references that point to the same macro definition
matching_refs = filter(
lambda ref: isinstance(ref, MacroReference)
and ref.path == target_macro_path
and ref.target_range == target_macro_target_range,
lambda ref: (
isinstance(ref, MacroReference)
and ref.path == target_macro_path
and ref.target_range == target_macro_target_range
),
get_macro_definitions_for_a_path(lsp_context, file_uri),
)

Expand Down
96 changes: 94 additions & 2 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,9 @@ def test_model_union_query(sushi_context, assert_exp_eq):
"@get_date() == '1996-02-10'",
"'all'",
3,
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n",
lambda expected_select: (
f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n"
),
),
# Test case 4: DISTINCT type
(
Expand Down Expand Up @@ -374,7 +376,9 @@ def test_model_union_query(sushi_context, assert_exp_eq):
"",
"",
3,
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n",
lambda expected_select: (
f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n"
),
),
# Test case 9: Missing union type AND condition one table
(
Expand Down Expand Up @@ -10353,6 +10357,94 @@ def entrypoint(context, *args, **kwargs):
assert ctx.fetchdf("SELECT * FROM test_schema2.foo").to_dict() == {"id": {0: 1}}


def test_python_model_blueprint_column_names(tmp_path: Path) -> None:
"""Blueprint variables can be used as column names and types in Python model definitions."""
py_model = tmp_path / "models" / "blueprint_col_names.py"
py_model.parent.mkdir(parents=True, exist_ok=True)
py_model.write_text(
"""
import pandas as pd # noqa: TID253
from sqlmesh import model

@model(
"test_schema.@model_name",
blueprints=[
{"model_name": "hotel_revenue", "col_a": "revenue", "type_a": "int", "col_b": "cost", "type_b": "double"},
{"model_name": "coffee_sales", "col_a": "sales", "type_a": "bigint", "col_b": "profit", "type_b": "text"},
],
kind="FULL",
columns={
"@{col_a}": "@{type_a}",
"@{col_b}": "@{type_b}",
},
)
def entrypoint(context, *args, **kwargs):
return pd.DataFrame({
context.blueprint_var("col_a"): [1],
context.blueprint_var("col_b"): [1.5],
})
"""
)

ctx = Context(
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
paths=tmp_path,
)
assert len(ctx.models) == 2

model1 = ctx.get_model("test_schema.hotel_revenue", raise_if_missing=True)
model2 = ctx.get_model("test_schema.coffee_sales", raise_if_missing=True)

assert model1.columns_to_types_ is not None
assert set(model1.columns_to_types_.keys()) == {"revenue", "cost"}
assert model1.columns_to_types_["revenue"] == exp.DataType.build("int")
assert model1.columns_to_types_["cost"] == exp.DataType.build("double")

assert model2.columns_to_types_ is not None
assert set(model2.columns_to_types_.keys()) == {"sales", "profit"}
assert model2.columns_to_types_["sales"] == exp.DataType.build("bigint")
assert model2.columns_to_types_["profit"] == exp.DataType.build("text")


def test_python_model_variable_column_names(tmp_path: Path) -> None:
"""Global variables can be used as column names in Python model definitions."""
py_model = tmp_path / "models" / "var_col_names.py"
py_model.parent.mkdir(parents=True, exist_ok=True)
py_model.write_text(
"""
import pandas as pd # noqa: TID253
from sqlmesh import model

@model(
"test_schema.model",
kind="FULL",
columns={
"@{metric_col}": "int",
"static_col": "text",
},
)
def entrypoint(context, *args, **kwargs):
return pd.DataFrame({"revenue": [1], "static_col": ["x"]})
"""
)

ctx = Context(
config=Config(
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
variables={"metric_col": "revenue"},
),
paths=tmp_path,
)
assert len(ctx.models) == 1

model = ctx.get_model("test_schema.model", raise_if_missing=True)

assert model.columns_to_types_ is not None
assert set(model.columns_to_types_.keys()) == {"revenue", "static_col"}
assert model.columns_to_types_["revenue"] == exp.DataType.build("int")
assert model.columns_to_types_["static_col"] == exp.DataType.build("text")


@time_machine.travel("2020-01-01 00:00:00 UTC")
def test_dynamic_date_spine_model(assert_exp_eq):
@macro()
Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_plan_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,8 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]):
finalized_ts=to_timestamp("2023-01-02"),
)

state_reader.get_environment.side_effect = (
lambda name: existing_dev_environment if name == "dev" else existing_prod_environment
state_reader.get_environment.side_effect = lambda name: (
existing_dev_environment if name == "dev" else existing_prod_environment
)
state_reader.get_environments_summary.return_value = [
existing_prod_environment.summary,
Expand Down Expand Up @@ -857,8 +857,8 @@ def test_build_plan_stages_restatement_dev_does_not_clear_intervals(
finalized_ts=to_timestamp("2023-01-02"),
)

state_reader.get_environment.side_effect = (
lambda name: existing_dev_environment if name == "dev" else existing_prod_environment
state_reader.get_environment.side_effect = lambda name: (
existing_dev_environment if name == "dev" else existing_prod_environment
)
state_reader.get_environments_summary.return_value = [
existing_prod_environment.summary,
Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_selector_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot)
)

state_reader_mock = mocker.Mock()
state_reader_mock.get_environment.side_effect = (
lambda name: prod_env if name == "prod" else dev_env
state_reader_mock.get_environment.side_effect = lambda name: (
prod_env if name == "prod" else dev_env
)

all_snapshots = {
Expand Down Expand Up @@ -875,8 +875,8 @@ def test_select_models_selected_fqns_fallback(mocker: MockerFixture, make_snapsh
)

state_reader_mock = mocker.Mock()
state_reader_mock.get_environment.side_effect = (
lambda name: fallback_env if name == "prod" else None
state_reader_mock.get_environment.side_effect = lambda name: (
fallback_env if name == "prod" else None
)
state_reader_mock.get_snapshots.return_value = {
deleted_model_snapshot.snapshot_id: deleted_model_snapshot,
Expand Down
Loading