
Commit d905941

Fix!: Allow python models to emit DataFrames with a different column order (#4348)

1 parent c7571fe commit d905941

File tree: 11 files changed (+245 −41 lines)


pyproject.toml

Lines changed: 3 additions & 4 deletions
@@ -40,9 +40,9 @@ athena = ["PyAthena[Pandas]"]
 azuresql = ["pymssql"]
 bigquery = [
     "google-cloud-bigquery[pandas]",
-    "google-cloud-bigquery-storage"
+    "google-cloud-bigquery-storage",
+    "bigframes>=1.32.0"
 ]
-bigframes = ["bigframes>=1.32.0"]
 clickhouse = ["clickhouse-connect"]
 databricks = ["databricks-sql-connector[pyarrow]"]
 dev = [
@@ -107,8 +107,7 @@ slack = ["slack_sdk"]
 snowflake = [
     "cryptography",
     "snowflake-connector-python[pandas,secure-local-storage]",
-    # as at 2024-08-05, snowflake-snowpark-python is only available up to Python 3.11
-    "snowflake-snowpark-python; python_version<'3.12'",
+    "snowflake-snowpark-python",
 ]
 trino = ["trino"]
 web = [

sqlmesh/core/engine_adapter/base.py

Lines changed: 5 additions & 0 deletions
@@ -246,7 +246,12 @@ def _df_to_source_queries(
         assert isinstance(df, pd.DataFrame)
         num_rows = len(df.index)
         batch_size = sys.maxsize if batch_size == 0 else batch_size
+
+        # we need to ensure that the order of the columns in columns_to_types matches the order of the values;
+        # they can differ if a user specifies columns() on a python model in a different order than what's in the DataFrames emitted by that model
+        df = df[list(columns_to_types)]
         values = list(df.itertuples(index=False, name=None))
+
         return [
             SourceQuery(
                 query_factory=partial(
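
For context, a minimal standalone sketch (not part of the commit; the columns_to_types dict and sample row below are made up) of what the added reorder buys: once the DataFrame is reindexed to the declared column order, the tuples produced by itertuples line up with the columns the model advertised.

import pandas as pd

# hypothetical model declaration order, e.g. columns={"id": "int", "name": "varchar"}
columns_to_types = {"id": "int", "name": "varchar"}

# the model's DataFrame emits the same columns, but in a different order
df = pd.DataFrame([{"name": "foo", "id": 1}])

# reorder the columns to match columns_to_types before flattening into value tuples
df = df[list(columns_to_types)]
values = list(df.itertuples(index=False, name=None))
print(values)  # [(1, 'foo')] -- "id" first, matching the declared order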

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 5 additions & 2 deletions
@@ -218,10 +218,13 @@ def query_factory() -> Query:
             # as later calls.
             if not self.table_exists(temp_table):
                 columns_to_types_create = columns_to_types.copy()
-                self._convert_df_datetime(df, columns_to_types_create)
+                ordered_df = df[
+                    list(columns_to_types_create)
+                ]  # reorder DataFrame so it matches columns_to_types
+                self._convert_df_datetime(ordered_df, columns_to_types_create)
                 self.create_table(temp_table, columns_to_types_create)
                 rows: t.List[t.Tuple[t.Any, ...]] = list(
-                    df.replace({np.nan: None}).itertuples(index=False, name=None)  # type: ignore
+                    ordered_df.replace({np.nan: None}).itertuples(index=False, name=None)  # type: ignore
                 )
                 conn = self._connection_pool.get()
                 conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows)
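
A quick aside on the replace({np.nan: None}) call this hunk keeps: pandas encodes missing values as NaN, which drivers generally do not treat as SQL NULL, so the rows are converted before being handed to bulk_copy. A small sketch with made-up data, independent of the adapter:

import numpy as np
import pandas as pd

# hypothetical frame whose column order already matches columns_to_types
ordered_df = pd.DataFrame({"id": [1, 2], "name": ["foo", np.nan]})

# NaN is a float sentinel, not SQL NULL; map it to None before building the row tuples
rows = list(ordered_df.replace({np.nan: None}).itertuples(index=False, name=None))
print(rows)  # [(1, 'foo'), (2, None)]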

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 19 additions & 6 deletions
@@ -288,8 +288,25 @@ def _df_to_source_queries(
         is_snowpark_dataframe = snowpark and isinstance(df, snowpark.dataframe.DataFrame)

         def query_factory() -> Query:
+            # The catalog needs to be normalized before being passed to Snowflake's library functions because they
+            # just wrap whatever they are given in quotes without checking if it's already quoted
+            database = (
+                normalize_identifiers(temp_table.catalog, dialect=self.dialect)
+                if temp_table.catalog
+                else None
+            )
+
             if is_snowpark_dataframe:
-                df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect, identify=True))  # type: ignore
+                temp_table.set("catalog", database)
+                df_renamed = df.rename(
+                    {
+                        col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True)
+                        for col in columns_to_types
+                    }
+                )  # type: ignore
+                df_renamed.createOrReplaceTempView(
+                    temp_table.sql(dialect=self.dialect, identify=True)
+                )  # type: ignore
             elif isinstance(df, pd.DataFrame):
                 from snowflake.connector.pandas_tools import write_pandas

@@ -325,11 +342,7 @@ def query_factory() -> Query:
                     df,
                     temp_table.name,
                     schema=temp_table.db or None,
-                    database=normalize_identifiers(temp_table.catalog, dialect=self.dialect).sql(
-                        dialect=self.dialect
-                    )
-                    if temp_table.catalog
-                    else None,
+                    database=database.sql(dialect=self.dialect) if database else None,
                     chunk_size=self.DEFAULT_BATCH_SIZE,
                     overwrite=True,
                     table_type="temp",
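
The rename above exists because Snowflake resolves unquoted identifiers case-insensitively (folding them to upper case); quoting each column name via sqlglot keeps the case the model declared. A minimal sqlglot-only sketch (the column names and catalog are illustrative, not taken from the commit):

from sqlglot import exp
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

columns_to_types = {"id": "int", "name": "varchar"}  # hypothetical model columns

# quote each column name the way the adapter does before registering the temp view
renames = {
    col: exp.to_identifier(col).sql(dialect="snowflake", identify=True)
    for col in columns_to_types
}
print(renames)  # {'id': '"id"', 'name': '"name"'}

# an unquoted catalog name is normalized to Snowflake's canonical (upper) case first,
# so downstream quoting does not freeze the wrong case
catalog = normalize_identifiers(exp.to_identifier("my_db"), dialect="snowflake")
print(catalog.sql(dialect="snowflake"))  # MY_DB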

sqlmesh/core/engine_adapter/spark.py

Lines changed: 6 additions & 0 deletions
@@ -279,10 +279,16 @@ def _ensure_pyspark_df(
     ) -> PySparkDataFrame:
         pyspark_df = self.try_get_pyspark_df(generic_df)
         if pyspark_df:
+            if columns_to_types:
+                # ensure Spark dataframe column order matches columns_to_types
+                pyspark_df = pyspark_df.select(*columns_to_types)
             return pyspark_df
         df = self.try_get_pandas_df(generic_df)
         if df is None:
             raise SQLMeshError("Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame")
+        if columns_to_types:
+            # ensure Pandas dataframe column order matches columns_to_types
+            df = df[list(columns_to_types)]
         kwargs = (
             dict(schema=self.sqlglot_to_spark_types(columns_to_types)) if columns_to_types else {}
         )
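
On the PySpark path the same idea is expressed as a projection; a minimal sketch assuming a local SparkSession (column names illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

columns_to_types = {"id": "int", "name": "varchar"}  # hypothetical model columns

# DataFrame emitted with the columns in a different order than declared
pyspark_df = spark.createDataFrame([("foo", 1)], schema=["name", "id"])

# select(*columns_to_types) projects the columns back into the declared order
pyspark_df = pyspark_df.select(*columns_to_types)
print(pyspark_df.columns)  # ['id', 'name']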

tests/core/engine_adapter/integration/__init__.py

Lines changed: 10 additions & 11 deletions
@@ -359,7 +359,7 @@ def get_table_comment(
            FROM pg_class c
            INNER JOIN pg_description d ON c.oid = d.objoid AND d.objsubid = 0
            INNER JOIN pg_namespace n ON c.relnamespace = n.oid
-           WHERE
+           WHERE 
            c.relname = '{table_name}'
            AND n.nspname= '{schema_name}'
            AND c.relkind = '{"v" if table_kind == "VIEW" else "r"}'
@@ -465,12 +465,12 @@ def get_column_comments(
            INNER JOIN pg_namespace n ON c.relnamespace = n.oid
            INNER JOIN pg_attribute a ON c.oid = a.attrelid
            INNER JOIN pg_description d
-               ON
+               ON 
                a.attnum = d.objsubid
                AND d.objoid = c.oid
            WHERE
            n.nspname = '{schema_name}'
-           AND c.relname = '{table_name}'
+           AND c.relname = '{table_name}' 
            AND c.relkind = '{"v" if table_kind == "VIEW" else "r"}'
            ;
            """
@@ -494,6 +494,7 @@ def create_context(
         self,
         config_mutator: t.Optional[t.Callable[[str, Config], None]] = None,
         path: t.Optional[pathlib.Path] = None,
+        ephemeral_state_connection: bool = True,
     ) -> Context:
         private_sqlmesh_dir = pathlib.Path(pathlib.Path().home(), ".sqlmesh")
         config = load_config_from_paths(
@@ -509,14 +510,12 @@
             config.gateways = {self.gateway: config.gateways[self.gateway]}

         gateway_config = config.gateways[self.gateway]
-        if (
-            (sc := gateway_config.state_connection)
-            and (conn := gateway_config.connection)
-            and sc.type_ == "duckdb"
-        ):
-            # if duckdb is being used as the state connection, set concurrent_tasks=1 on the main connection
-            # to prevent duckdb from being accessed from multiple threads and getting deadlocked
-            conn.concurrent_tasks = 1
+        if ephemeral_state_connection:
+            # Override whatever state connection has been configured on the integration test config to use in-memory DuckDB instead
+            # This is so tests that initialize a SQLMesh context can run concurrently without clobbering each other's state
+            from sqlmesh.core.config.connection import DuckDBConnectionConfig
+
+            gateway_config.state_connection = DuckDBConnectionConfig()

         if "athena" in self.gateway:
             conn = gateway_config.connection
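
The design choice here, per the comments in the hunk, is that every test context gets a throwaway state store so concurrent tests cannot clobber each other; a hedged sketch of what the override amounts to (assuming a DuckDBConnectionConfig with no database path means an in-memory DuckDB, as the comment implies):

from sqlmesh.core.config.connection import DuckDBConnectionConfig

# no `database` file given, so DuckDB runs in-memory: state is private to this process
# and is discarded when the test context is torn down
state_connection = DuckDBConnectionConfig()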

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 83 additions & 1 deletion
@@ -2721,7 +2721,9 @@ def _use_warehouse_as_state_connection(gateway_name: str, config: Config):

     config.gateways[gateway_name].state_schema = test_schema

-    sqlmesh_context = ctx.create_context(config_mutator=_use_warehouse_as_state_connection)
+    sqlmesh_context = ctx.create_context(
+        config_mutator=_use_warehouse_as_state_connection, ephemeral_state_connection=False
+    )
     assert sqlmesh_context.config.get_state_schema(ctx.gateway) == test_schema

     state_sync = (
@@ -2732,3 +2734,83 @@ def _use_warehouse_as_state_connection(gateway_name: str, config: Config):

     # will throw if one of the migrations produces an error, which can happen if we forget to take quoting or normalization into account
     sqlmesh_context.migrate()
+
+
+def test_python_model_column_order(ctx: TestContext, tmp_path: pathlib.Path):
+    if ctx.test_type == "pyspark" and ctx.dialect in ("spark", "databricks"):
+        # don't skip
+        pass
+    elif ctx.test_type != "df":
+        pytest.skip("python model column order test only needs to be run once per db")
+
+    schema = ctx.add_test_suffix(TEST_SCHEMA)
+
+    (tmp_path / "models").mkdir()
+
+    # note: this model deliberately defines the columns in the @model definition to be in a different order than what
+    # is returned by the DataFrame within the model
+    model_path = tmp_path / "models" / "python_model.py"
+    if ctx.test_type == "pyspark":
+        # python model that emits a PySpark dataframe
+        model_path.write_text(
+            """
+from pyspark.sql import DataFrame, Row
+import typing as t
+from sqlmesh import ExecutionContext, model
+
+@model(
+    "TEST_SCHEMA.model",
+    columns={
+        "id": "int",
+        "name": "varchar"
+    }
+)
+def execute(
+    context: ExecutionContext,
+    **kwargs: t.Any,
+) -> DataFrame:
+    return context.spark.createDataFrame([
+        Row(name="foo", id=1)
+    ])
+""".replace("TEST_SCHEMA", schema)
+        )
+    else:
+        # python model that emits a Pandas DataFrame
+        model_path.write_text(
+            """
+import pandas as pd
+import typing as t
+from sqlmesh import ExecutionContext, model
+
+@model(
+    "TEST_SCHEMA.model",
+    columns={
+        "id": "int",
+        "name": "varchar"
+    }
+)
+def execute(
+    context: ExecutionContext,
+    **kwargs: t.Any,
+) -> pd.DataFrame:
+    return pd.DataFrame([
+        {"name": "foo", "id": 1}
+    ])
+""".replace("TEST_SCHEMA", schema)
+        )
+
+    sqlmesh_ctx = ctx.create_context(path=tmp_path)
+
+    assert len(sqlmesh_ctx.models) == 1
+
+    plan = sqlmesh_ctx.plan(auto_apply=True)
+    assert len(plan.new_snapshots) == 1
+
+    engine_adapter = sqlmesh_ctx.engine_adapter
+
+    query = exp.select("*").from_(
+        exp.to_table(f"{schema}.model", dialect=ctx.dialect), dialect=ctx.dialect
+    )
+    df = engine_adapter.fetchdf(query, quote_identifiers=True)
+    assert len(df) == 1
+    assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

tests/core/engine_adapter/integration/test_integration_bigquery.py

Lines changed: 47 additions & 0 deletions
@@ -433,3 +433,50 @@ def test_table_diff_table_name_matches_column_name(ctx: TestContext):

     assert row_diff.stats["join_count"] == 1
     assert row_diff.full_match_count == 1
+
+
+def test_bigframe_python_model_column_order(ctx: TestContext, tmp_path: Path):
+    model_name = ctx.table("TEST")
+
+    (tmp_path / "models").mkdir()
+
+    # note: this model deliberately defines the columns in the @model definition to be in a different order than what
+    # is returned by the DataFrame within the model
+    model_path = tmp_path / "models" / "python_model.py"
+
+    # python model that emits a BigFrame dataframe
+    model_path.write_text(
+        """
+from bigframes.pandas import DataFrame
+import typing as t
+from sqlmesh import ExecutionContext, model
+
+@model(
+    'MODEL_NAME',
+    columns={
+        "id": "int",
+        "name": "varchar"
+    },
+    dialect="bigquery"
+)
+def execute(
+    context: ExecutionContext,
+    **kwargs: t.Any,
+) -> DataFrame:
+    return DataFrame({'name': ['foo'], 'id': [1]}, session=context.bigframe)
+""".replace("MODEL_NAME", model_name.sql(dialect="bigquery"))
+    )
+
+    sqlmesh_ctx = ctx.create_context(path=tmp_path)
+
+    assert len(sqlmesh_ctx.models) == 1
+
+    plan = sqlmesh_ctx.plan(auto_apply=True)
+    assert len(plan.new_snapshots) == 1
+
+    engine_adapter = sqlmesh_ctx.engine_adapter
+
+    query = exp.select("*").from_(model_name)
+    df = engine_adapter.fetchdf(query, quote_identifiers=True)
+    assert len(df) == 1
+    assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}

tests/core/engine_adapter/integration/test_integration_snowflake.py

Lines changed: 47 additions & 0 deletions
@@ -1,6 +1,7 @@
 import typing as t
 import pytest
 from sqlglot import exp
+from pathlib import Path
 from sqlglot.optimizer.qualify_columns import quote_identifiers
 from sqlglot.helper import seq_get
 from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter
@@ -210,3 +211,49 @@ def test_create_iceberg_table(ctx: TestContext, engine_adapter: SnowflakeEngineAdapter):
     result = sqlmesh.plan(auto_apply=True)

     assert len(result.new_snapshots) == 2
+
+
+def test_snowpark_python_model_column_order(ctx: TestContext, tmp_path: Path):
+    model_name = ctx.table("TEST")
+
+    (tmp_path / "models").mkdir()
+
+    # note: this model deliberately defines the columns in the @model definition to be in a different order than what
+    # is returned by the DataFrame within the model
+    model_path = tmp_path / "models" / "python_model.py"
+
+    # python model that emits a Snowpark DataFrame
+    model_path.write_text(
+        """
+from snowflake.snowpark.dataframe import DataFrame
+import typing as t
+from sqlmesh import ExecutionContext, model
+
+@model(
+    'MODEL_NAME',
+    columns={
+        "id": "int",
+        "name": "varchar"
+    }
+)
+def execute(
+    context: ExecutionContext,
+    **kwargs: t.Any,
+) -> DataFrame:
+    return context.snowpark.create_dataframe([["foo", 1]], schema=["name", "id"])
+""".replace("MODEL_NAME", model_name.sql(dialect="snowflake"))
+    )
+
+    sqlmesh_ctx = ctx.create_context(path=tmp_path)
+
+    assert len(sqlmesh_ctx.models) == 1
+
+    plan = sqlmesh_ctx.plan(auto_apply=True)
+    assert len(plan.new_snapshots) == 1
+
+    engine_adapter = sqlmesh_ctx.engine_adapter
+
+    query = exp.select("*").from_(plan.environment.snapshots[0].fully_qualified_table)
+    df = engine_adapter.fetchdf(query, quote_identifiers=True)
+    assert len(df) == 1
+    assert df.iloc[0].to_dict() == {"id": 1, "name": "foo"}
