Skip to content

Commit 011c0cf

Browse files
authored
fix(udf): don't allow silent wrong schema merges (#1460)
1 parent 4162955 commit 011c0cf

File tree

5 files changed, with 100 additions (+100) and 8 deletions (-8).

5 files changed, with 100 additions (+100) and 8 deletions (-8).

src/datachain/lib/dc/datachain.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -850,14 +850,13 @@ def map(
850850
if (prefetch := self._settings.prefetch) is not None:
851851
udf_obj.prefetch = prefetch
852852

853+
sys_schema = SignalSchema({"sys": Sys})
853854
return self._evolve(
854855
query=self._query.add_signals(
855856
udf_obj.to_udf_wrapper(self._settings.batch_size),
856857
**self._settings.to_dict(),
857858
),
858-
signal_schema=SignalSchema({"sys": Sys})
859-
| self.signals_schema
860-
| udf_obj.output,
859+
signal_schema=sys_schema | self.signals_schema | udf_obj.output,
861860
)
862861

863862
def gen(

src/datachain/lib/signal_schema.py

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from datachain.lib.data_model import DataModel, DataType, DataValue
3535
from datachain.lib.file import File
3636
from datachain.lib.model_store import ModelStore
37-
from datachain.lib.utils import DataChainParamsError
37+
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
3838
from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta
3939
from datachain.sql.types import SQLType
4040

@@ -1038,7 +1038,28 @@ def get_headers_with_length(self, include_hidden: bool = True):
10381038
], max_length
10391039

10401040
def __or__(self, other):
1041-
return self.__class__(self.values | other.values)
1041+
new_values = dict(self.values)
1042+
1043+
for name, new_type in other.values.items():
1044+
if name in new_values:
1045+
current_type = new_values[name]
1046+
if current_type != new_type:
1047+
raise DataChainColumnError(
1048+
name,
1049+
"signal already exists with a different type",
1050+
)
1051+
continue
1052+
1053+
root = self._extract_root(name)
1054+
if any(self._extract_root(existing) == root for existing in new_values):
1055+
raise DataChainColumnError(
1056+
name,
1057+
"signal root already exists in schema",
1058+
)
1059+
1060+
new_values[name] = new_type
1061+
1062+
return self.__class__(new_values)
10421063

10431064
def __contains__(self, name: str):
10441065
return name in self.values

src/datachain/query/dataset.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -683,11 +683,26 @@ def create_result_query(
683683
signal_name_cols = {c.name: c for c in signal_cols}
684684
cols = signal_cols
685685

686-
overlap = {c.name for c in original_cols} & {c.name for c in cols}
686+
original_names = {c.name for c in original_cols}
687+
new_names = {c.name for c in cols}
688+
689+
overlap = original_names & new_names
687690
if overlap:
688691
raise ValueError(
689692
"Column already exists or added in the previous steps: "
690-
+ ", ".join(overlap)
693+
+ ", ".join(sorted(overlap))
694+
)
695+
696+
def _root(name: str) -> str:
697+
return name.split(DEFAULT_DELIMITER, 1)[0]
698+
699+
existing_roots = {_root(name) for name in original_names}
700+
new_roots = {_root(name) for name in new_names}
701+
root_conflicts = existing_roots & new_roots
702+
if root_conflicts:
703+
raise ValueError(
704+
"Signals already exist in the previous steps: "
705+
+ ", ".join(sorted(root_conflicts))
691706
)
692707

693708
def q(*columns):

tests/func/test_udf.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from datachain.func import path as pathfunc
1313
from datachain.lib.file import AudioFile, AudioFragment, File
1414
from datachain.lib.udf import Mapper
15-
from datachain.lib.utils import DataChainError
15+
from datachain.lib.utils import DataChainColumnError, DataChainError
1616
from tests.utils import LARGE_TREE, NUM_TREE
1717

1818

@@ -390,6 +390,23 @@ def test_types():
390390
]
391391

392392

393+
def test_udf_rejects_root_override(test_session):
394+
class X(dc.DataModel):
395+
x: int
396+
397+
chain = dc.read_values(x=[X(x=0), X(x=1)], session=test_session)
398+
399+
with pytest.raises(
400+
DataChainColumnError,
401+
match="Error for column x: signal already exists with a different type",
402+
):
403+
chain.map(
404+
lambda x: x.model_dump(),
405+
params=["x"],
406+
output={"x": dict},
407+
)
408+
409+
393410
@pytest.mark.parametrize("use_cache", [False, True])
394411
@pytest.mark.parametrize("prefetch", [0, 2])
395412
def test_map_file(cloud_test_catalog, use_cache, prefetch, monkeypatch):

tests/unit/lib/test_signal_schema.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
SignalSchemaError,
1818
SignalSchemaWarning,
1919
)
20+
from datachain.lib.utils import DataChainColumnError
2021
from datachain.sql.types import (
2122
JSON,
2223
Array,
@@ -185,6 +186,45 @@ def test_feature_schema_serialize_list():
185186
assert deserialized_schema.values == schema
186187

187188

189+
def test_schema_or_rejects_type_change():
190+
base = SignalSchema({"foo": int})
191+
new = SignalSchema({"foo": str})
192+
193+
with pytest.raises(DataChainColumnError, match="different type"):
194+
_ = base | new
195+
196+
197+
def test_schema_or_rejects_root_conflict():
198+
base = SignalSchema({"feature": MyType1})
199+
new = SignalSchema({"feature.extra": int})
200+
201+
with pytest.raises(DataChainColumnError, match="root"):
202+
_ = base | new
203+
204+
205+
def test_schema_or_allows_sys_root():
206+
base = SignalSchema({"foo": int})
207+
new = SignalSchema({"sys": Sys})
208+
209+
combined = base | new
210+
assert combined.values["sys"] is Sys
211+
212+
213+
def test_schema_or_rejects_sys_override():
214+
base = SignalSchema({"sys": Sys})
215+
new = SignalSchema({"sys": dict})
216+
217+
with pytest.raises(DataChainColumnError, match="different type"):
218+
_ = base | new
219+
220+
221+
def test_schema_or_allows_identical_signal():
222+
base = SignalSchema({"foo": int})
223+
combined = base | SignalSchema({"foo": int})
224+
225+
assert combined.values["foo"] is int
226+
227+
188228
def test_feature_schema_serialize_list_old():
189229
schema = {
190230
"name": str | None,

0 commit comments

Comments
 (0)