Skip to content

Commit 902f944

Browse files
felbrocopybara-github
authored andcommitted
Fix scalar vs non-scalar issue in kd.attr
``` x = kd.obj() kd.attr(x, kd.slice(['x']), kd.slice([1])) ``` Crashes due to an incorrect assumption that `x.GetObjSchema()` is never an item PiperOrigin-RevId: 863658648 Change-Id: Iae5d3423e9ee726df4064ff084e1550505224d3d
1 parent 13f9e6f commit 902f944

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

koladata/operators/core.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ absl::Status UpdateSchemaForConflictDetection(
113113
absl::Status UpdateSchemaForConflictDetection(const DataSlice& obj,
114114
const DataSlice& attr_names,
115115
DataBagPtr& result_db) {
116-
bool object_mode = obj.GetSchemaImpl() == schema::kObject;
117-
ASSIGN_OR_RETURN(DataSlice src_schema,
118-
object_mode ? obj.GetObjSchema() : obj.GetSchema());
116+
ASSIGN_OR_RETURN(DataSlice src_schema, obj.GetSchemaImpl() == schema::kObject
117+
? obj.GetObjSchema()
118+
: obj.GetSchema());
119119
DataSlice dst_schema = src_schema.WithBag(result_db);
120120

121121
ASSIGN_OR_RETURN(auto aligned_slices,
@@ -125,9 +125,7 @@ absl::Status UpdateSchemaForConflictDetection(const DataSlice& obj,
125125
aligned_slices[1].impl<internal::DataSliceImpl>();
126126

127127
absl::Status status = absl::OkStatus();
128-
if (object_mode) {
129-
DCHECK(!src_schema.is_item());
130-
DCHECK(!dst_schema.is_item());
128+
if (!src_schema.is_item()) {
131129
// In this case, src_schema and dst_schema are not DataItems, so we need to
132130
// iterate over all sub-items.
133131
DCHECK(aligned_obj.dtype() == arolla::GetQType<internal::ObjectId>());

py/koladata/operators/tests/core_attr_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ def test_attr_update_with_ds_attr(self):
9696
testing.assert_equal(ds1.updated(db).x.no_bag(), ds([42, 3]))
9797
testing.assert_equal(ds1.updated(db).y.no_bag(), ds([2, 42]))
9898

99+
def test_attr_update_with_ds_attr_scalar_obj(self):
100+
with self.subTest('object'):
101+
ds1 = kd.obj(x=1)
102+
db = kd.core.attr(ds1, ds(['x', 'y']), ds([42, 43]))
103+
testing.assert_equal(ds1.updated(db).x.no_bag(), ds(42))
104+
testing.assert_equal(ds1.updated(db).y.no_bag(), ds(43))
105+
106+
with self.subTest('entity'):
107+
ds1 = kd.new(x=1)
108+
db = kd.core.attr(ds1, ds(['x', 'y']), ds([42, 43]))
109+
testing.assert_equal(ds1.updated(db).x.no_bag(), ds(42))
110+
testing.assert_equal(ds1.updated(db).y.no_bag(), ds(43))
111+
99112
def test_invalid_attr_name(self):
100113
o = bag().new(x=1)
101114
with self.assertRaisesRegex(

0 commit comments

Comments
 (0)