Skip to content

Commit 6521e08

Browse files
Koladata Teamcopybara-github
authored andcommitted
support named schemas in traverser and deep_clone
PiperOrigin-RevId: 718827570 Change-Id: I6b085612bbd79212f700cd9dd8c69d172f7b4c19
1 parent bfc0776 commit 6521e08

File tree

4 files changed

+74
-7
lines changed

4 files changed

+74
-7
lines changed

koladata/internal/op_utils/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ cc_test(
207207
"@com_google_absl//absl/container:flat_hash_set",
208208
"@com_google_absl//absl/log:check",
209209
"@com_google_absl//absl/status",
210+
"@com_google_absl//absl/strings",
210211
"@com_google_absl//absl/strings:str_format",
211212
"@com_google_absl//absl/types:span",
212213
"@com_google_arolla//arolla/dense_array",

koladata/internal/op_utils/traverser.h

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ class Traverser {
285285
if (!status.ok()) {
286286
return;
287287
}
288+
if (attr_name == schema::kSchemaNameAttr) {
289+
return;
290+
}
288291
status = PrevisitAttribute(item, attr_name);
289292
});
290293
return status;
@@ -339,8 +342,11 @@ class Traverser {
339342
status = attr_schema_or.status();
340343
return;
341344
}
345+
auto schema_dtype = (attr_name == schema::kSchemaNameAttr)
346+
? schema::kString
347+
: schema::kSchema;
342348
status = Previsit(
343-
{.item = *attr_schema_or, .schema = DataItem(schema::kSchema)});
349+
{.item = *attr_schema_or, .schema = DataItem(schema_dtype)});
344350
});
345351
return status;
346352
}
@@ -459,12 +465,15 @@ class Traverser {
459465
return;
460466
}
461467
auto attr_schema = *attr_schema_or;
462-
if (!attr_schema.is_schema()) {
468+
DataItem attr_schema_schema = DataItem(schema::kSchema);
469+
if (attr_name == schema::kSchemaNameAttr) {
470+
attr_schema_schema = DataItem(schema::kString);
471+
} else if (!attr_schema.is_schema()) {
463472
status = absl::InvalidArgumentError(absl::StrFormat(
464473
"schema %v has unexpected attribute %s", schema, attr_name));
465474
return;
466475
}
467-
auto attr_value_or = GetValue(attr_schema, DataItem(schema::kSchema));
476+
auto attr_value_or = GetValue(attr_schema, attr_schema_schema);
468477
if (!attr_value_or.ok()) {
469478
status = attr_value_or.status();
470479
return;
@@ -538,13 +547,19 @@ class Traverser {
538547
return VisitDict(item, is_object);
539548
}
540549
arolla::DenseArrayBuilder<DataItem> attr_values(attr_names.size());
550+
arolla::DenseArrayBuilder<arolla::Text> actual_attr_names(
551+
attr_names.size());
541552
absl::Status status = absl::OkStatus();
553+
size_t attr_count = 0;
542554
attr_names.ForEach([&](int64_t id, bool presence,
543555
std::string_view attr_name) {
544556
DCHECK(presence);
545557
if (!status.ok()) {
546558
return;
547559
}
560+
if (attr_name == schema::kSchemaNameAttr) {
561+
return;
562+
}
548563
auto attr_item_or = databag_.GetAttr(item.item, attr_name, fallbacks_);
549564
if (!attr_item_or.ok()) {
550565
status = attr_item_or.status();
@@ -561,14 +576,17 @@ class Traverser {
561576
status = value_or.status();
562577
return;
563578
}
564-
attr_values.Set(id, *value_or);
579+
attr_values.Set(attr_count, *value_or);
580+
actual_attr_names.Set(attr_count, attr_name);
581+
++attr_count;
565582
});
566583
if (!status.ok()) {
567584
return status;
568585
}
569-
return visitor_->VisitorT::VisitObject(item.item, item.schema, is_object,
570-
attr_names,
571-
std::move(attr_values).Build());
586+
return visitor_->VisitorT::VisitObject(
587+
item.item, item.schema, is_object,
588+
std::move(actual_attr_names).Build(attr_count),
589+
std::move(attr_values).Build(attr_count));
572590
}
573591

574592
absl::Status VisitObject(const ItemWithSchema& item) {

koladata/internal/op_utils/traverser_test.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/container/flat_hash_set.h"
2828
#include "absl/log/check.h"
2929
#include "absl/status/status.h"
30+
#include "absl/strings/str_cat.h"
3031
#include "absl/strings/str_format.h"
3132
#include "absl/types/span.h"
3233
#include "koladata/internal/data_bag.h"
@@ -140,6 +141,10 @@ class NoOpVisitor : AbstractVisitor {
140141
: previsited_(), value_item_(DataItem("get value result")) {}
141142

142143
absl::Status Previsit(const DataItem& item, const DataItem& schema) override {
144+
if (!schema.is_schema()) {
145+
return absl::InvalidArgumentError(
146+
absl::StrFormat("%v is not a schema", schema));
147+
}
143148
previsited_.push_back({item, schema});
144149
return absl::OkStatus();
145150
}
@@ -540,6 +545,35 @@ TEST_P(NoOpTraverserTest, SchemaSlice) {
540545
{GetFallbackDb(db).get()}));
541546
}
542547

548+
TEST_P(NoOpTraverserTest, SliceWithNamedSchema) {
549+
auto db = DataBagImpl::CreateEmptyDatabag();
550+
auto obj_ids = DataSliceImpl::AllocateEmptyObjects(3);
551+
auto a0 = obj_ids[0];
552+
auto a1 = obj_ids[1];
553+
auto a2 = obj_ids[2];
554+
auto int_dtype = DataItem(schema::kInt32);
555+
auto schema = AllocateSchema();
556+
ASSERT_OK_AND_ASSIGN(auto named_schema,
557+
db->CreateUuSchemaFromFields(
558+
absl::StrCat("__named_schema__", "foo"), {}, {}));
559+
TriplesT schema_triples = {
560+
{schema, {{"x", int_dtype}, {"y", int_dtype}}},
561+
{named_schema,
562+
{{schema::kSchemaNameAttr, DataItem(arolla::Text("foo"))},
563+
{"x", DataItem(schema)},
564+
{"y", int_dtype}}}};
565+
TriplesT data_triples = {{a0, {{"x", a1}, {"y", DataItem(1)}}},
566+
{a1, {{"x", DataItem(2)}}}};
567+
SetDataTriples(*db, data_triples);
568+
SetSchemaTriples(*db, schema_triples);
569+
SetDataTriples(*db, GenNoiseDataTriples());
570+
SetSchemaTriples(*db, GenNoiseSchemaTriples());
571+
572+
auto ds = DataSliceImpl::Create(CreateDenseArray<DataItem>({a0}));
573+
EXPECT_OK(TraverseSlice(ds, DataItem(named_schema), *GetMainDb(db),
574+
{GetFallbackDb(db).get()}));
575+
}
576+
543577
TEST_P(ObjectsTraverserTest, ObjectsSlice) {
544578
auto db = DataBagImpl::CreateEmptyDatabag();
545579
auto obj_ids = DataSliceImpl::AllocateEmptyObjects(10);

py/koladata/operators/tests/core_deep_clone_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,20 @@ def test_entity(self, pass_schema):
154154
with self.assertRaisesRegex(AssertionError, 'not equal by fingerprint'):
155155
testing.assert_equal(result.ref(), o.ref())
156156

157+
@parameterized.product(
158+
pass_schema=[True, False],
159+
)
160+
def test_named_schema(self, pass_schema):
161+
db = data_bag.DataBag.empty()
162+
schema = db.named_schema('foo', x=schema_constants.INT32)
163+
o = db.new(x=ds([1, 2, 3]), schema=schema)
164+
if pass_schema:
165+
result = expr_eval.eval(kde.deep_clone(o, schema))
166+
else:
167+
result = expr_eval.eval(kde.deep_clone(o))
168+
testing.assert_equal(result.get_schema().no_bag(), schema.no_bag())
169+
testing.assert_equal(result.x.no_bag(), ds([1, 2, 3]))
170+
157171
@parameterized.product(
158172
pass_schema=[True, False],
159173
)

0 commit comments

Comments
 (0)