Skip to content

Commit b8453fa

Browse files
committed
Fix: Support signals for Python models
1 parent 1e1ace1 commit b8453fa

File tree

4 files changed

+46
-16
lines changed

4 files changed

+46
-16
lines changed

examples/sushi/models/orders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"end_ts": "int",
3737
"event_date": "date",
3838
},
39+
signals=[("test_signal", {"arg": 1})],
3940
)
4041
def execute(
4142
context: ExecutionContext,

sqlmesh/core/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,7 @@ def _load_python_models(
672672
default_catalog=self.context.default_catalog,
673673
infer_names=self.config.model_naming.infer_names,
674674
audit_definitions=audits,
675+
signal_definitions=signals,
675676
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
676677
):
677678
if model.enabled:

sqlmesh/core/model/decorator.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sqlglot.dialects.dialect import DialectType
1010

1111
from sqlmesh.core.macros import MacroRegistry
12+
from sqlmesh.core.signal import SignalRegistry
1213
from sqlmesh.utils.jinja import JinjaMacroRegistry
1314
from sqlmesh.core import constants as c
1415
from sqlmesh.core.dialect import MacroFunc, parse_one
@@ -48,23 +49,24 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs:
4849
self.kwargs = kwargs
4950

5051
# Make sure that argument values are expressions in order to pass validation in ModelMeta.
51-
calls = self.kwargs.pop("audits", [])
52-
self.kwargs["audits"] = [
53-
(
54-
(call, {})
55-
if isinstance(call, str)
56-
else (
57-
call[0],
58-
{
59-
arg_key: exp.convert(
60-
tuple(arg_value) if isinstance(arg_value, list) else arg_value
61-
)
62-
for arg_key, arg_value in call[1].items()
63-
},
52+
for function_call_attribute in ("audits", "signals"):
53+
calls = self.kwargs.pop(function_call_attribute, [])
54+
self.kwargs[function_call_attribute] = [
55+
(
56+
(call, {})
57+
if isinstance(call, str)
58+
else (
59+
call[0],
60+
{
61+
arg_key: exp.convert(
62+
tuple(arg_value) if isinstance(arg_value, list) else arg_value
63+
)
64+
for arg_key, arg_value in call[1].items()
65+
},
66+
)
6467
)
65-
)
66-
for call in calls
67-
]
68+
for call in calls
69+
]
6870

6971
if "default_catalog" in kwargs:
7072
raise ConfigError("`default_catalog` cannot be set on a per-model basis.")
@@ -142,6 +144,7 @@ def model(
142144
defaults: t.Optional[t.Dict[str, t.Any]] = None,
143145
macros: t.Optional[MacroRegistry] = None,
144146
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
147+
signal_definitions: t.Optional[SignalRegistry] = None,
145148
audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None,
146149
dialect: t.Optional[str] = None,
147150
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT,
@@ -223,6 +226,7 @@ def model(
223226
"macros": macros,
224227
"jinja_macros": jinja_macros,
225228
"audit_definitions": audit_definitions,
229+
"signal_definitions": signal_definitions,
226230
"blueprint_variables": blueprint_variables,
227231
**rendered_fields,
228232
}

tests/core/test_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5303,6 +5303,30 @@ def my_signal(batch):
53035303
)
53045304

53055305

5306+
def test_load_python_model_with_signals():
5307+
@signal()
5308+
def always_true(batch):
5309+
return True
5310+
5311+
@model(
5312+
name="model_with_signal",
5313+
kind="full",
5314+
columns={'"COL"': "int"},
5315+
signals=[("always_true", {})],
5316+
)
5317+
def model_with_signal(context, **kwargs):
5318+
return pd.DataFrame([{"COL": 1}])
5319+
5320+
models = model.get_registry()["model_with_signal"].models(
5321+
get_variables=lambda _: {},
5322+
path=Path("."),
5323+
module_path=Path("."),
5324+
signal_definitions=signal.get_registry(),
5325+
)
5326+
assert len(models) == 1
5327+
assert models[0].signals == [("always_true", {})]
5328+
5329+
53065330
def test_null_column_type():
53075331
expressions = d.parse(
53085332
"""

0 commit comments

Comments
 (0)