Skip to content

Commit 468df6d

Browse files
mrdroplet authored and copybara-github committed
Add kde.lists.concat_lists
PiperOrigin-RevId: 716271121 Change-Id: I6e94d0e4d0381e52f6dc17cc6ee8c07c23d42819
1 parent 33b0c85 commit 468df6d

File tree

8 files changed

+237
-2
lines changed

8 files changed

+237
-2
lines changed

koladata/casting.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//
1515
#include "koladata/casting.h"
1616

17+
#include <algorithm>
1718
#include <cstdint>
1819
#include <utility>
1920
#include <vector>
@@ -24,6 +25,7 @@
2425
#include "absl/status/statusor.h"
2526
#include "absl/strings/str_cat.h"
2627
#include "absl/strings/str_format.h"
28+
#include "koladata/data_bag.h"
2729
#include "koladata/data_slice.h"
2830
#include "koladata/internal/casting.h"
2931
#include "koladata/internal/data_bag.h"
@@ -366,7 +368,16 @@ absl::StatusOr<SchemaAlignedSlices> AlignSchemas(
366368
return std::move(schema_agg).Get();
367369
};
368370

369-
ASSIGN_OR_RETURN(auto common_schema, get_common_schema());
371+
auto get_fallback_db = [&slices] {
372+
std::vector<DataBagPtr> ret(slices.size());
373+
std::transform(slices.begin(), slices.end(), ret.begin(),
374+
[](const DataSlice& ds) { return ds.GetBag(); });
375+
return ret;
376+
};
377+
ASSIGN_OR_RETURN(
378+
auto common_schema, get_common_schema(),
379+
AssembleErrorMessage(
380+
_, {.db = DataBag::ImmutableEmptyWithFallbacks(get_fallback_db())}));
370381
for (auto& slice : slices) {
371382
// Since we cast to a common schema, we don't need to validate implicit
372383
// compatibility or validate schema (during casting to OBJECT) as no

koladata/operators/lists.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <optional>
1919
#include <type_traits>
2020
#include <utility>
21+
#include <vector>
2122

2223
#include "absl/log/check.h"
2324
#include "absl/log/log.h"
@@ -35,7 +36,6 @@
3536
#include "koladata/internal/dtype.h"
3637
#include "koladata/internal/non_deterministic_token.h"
3738
#include "koladata/object_factories.h"
38-
#include "koladata/operators/utils.h"
3939
#include "koladata/uuid_utils.h"
4040
#include "arolla/dense_array/qtype/types.h"
4141
#include "arolla/jagged_shape/dense_array/util/concat.h"
@@ -165,4 +165,11 @@ absl::StatusOr<DataSlice> ListShaped(
165165
return result;
166166
}
167167

168+
// Backend entry point registered as "kd.lists._concat_lists".
//
// Delegates to the DataBag-taking ConcatLists overload using a freshly created
// empty DataBag, then marks that bag immutable before the result is returned.
// Any error from the delegated call is propagated unchanged (the bag is left
// untouched in that case).
absl::StatusOr<DataSlice> ConcatLists(std::vector<DataSlice> lists) {
  const DataBagPtr result_bag = DataBag::Empty();
  ASSIGN_OR_RETURN(auto concatenated,
                   ConcatLists(result_bag, std::move(lists)));
  // All writes into the bag are finished at this point, so it can be frozen.
  result_bag->UnsafeMakeImmutable();
  return concatenated;
}
174+
168175
} // namespace koladata::ops

koladata/operators/lists.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
// List operator implementations.
1919

2020
#include <cstdint>
21+
#include <vector>
2122

2223
#include "absl/status/statusor.h"
2324
#include "koladata/data_slice.h"
@@ -63,6 +64,9 @@ absl::StatusOr<DataSlice> ListShaped(
6364
const DataSlice& itemid,
6465
internal::NonDeterministicToken);
6566

67+
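// kd.lists._concat_lists operator (backend of kde.lists.concat_lists).
// The result is stored in a newly created DataBag that is made immutable.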
68+
absl::StatusOr<DataSlice> ConcatLists(std::vector<DataSlice> lists);
69+
6670
} // namespace koladata::ops
6771

6872
#endif // KOLADATA_OPERATORS_LISTS_H_

koladata/operators/operators.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ OPERATOR_FAMILY("kd.ids.uuid_for_list",
158158
OPERATOR("kd.ids.uuids_with_allocation_size", UuidsWithAllocationSize);
159159
//
160160
OPERATOR("kd.json.to_json", ToJson);
161+
OPERATOR_FAMILY("kd.lists._concat_lists",
162+
arolla::MakeVariadicInputOperatorFamily(ConcatLists));
161163
//
162164
OPERATOR("kd.lists._explode", Explode, "kd.lists.explode");
163165
OPERATOR("kd.lists._implode", Implode, "kd.lists.implode");

py/koladata/operators/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,15 @@ py_library(
366366
deps = [
367367
":arolla_bridge",
368368
":jagged_shape",
369+
":koda_internal",
369370
":optools",
370371
":qtype_utils",
371372
":slices",
372373
":view_overloads",
373374
"//koladata/operators",
374375
"//py/koladata/types:data_slice",
375376
"//py/koladata/types:py_boxing",
377+
"//py/koladata/types:qtypes",
376378
"//py/koladata/types:schema_constants",
377379
"@com_google_arolla//py/arolla",
378380
"@com_google_arolla//py/arolla/jagged_shape",

py/koladata/operators/lists.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
from arolla.jagged_shape import jagged_shape
2121
from koladata.operators import arolla_bridge
2222
from koladata.operators import jagged_shape as jagged_shape_ops
23+
from koladata.operators import koda_internal as _
2324
from koladata.operators import optools
2425
from koladata.operators import qtype_utils
2526
from koladata.operators import slices as slice_ops
2627
from koladata.operators import view_overloads as _
2728
from koladata.types import data_slice
2829
from koladata.types import py_boxing
30+
from koladata.types import qtypes
2931
from koladata.types import schema_constants
3032

3133

@@ -266,6 +268,34 @@ def implode(
266268
return _implode(x, arolla_bridge.to_arolla_int64(ndim), itemid)
267269

268270

271+
# Python-side declaration of the backend operator 'kd.lists._concat_lists'.
# The body is never meant to run: evaluation is provided by the backend, and
# the output qtype is always DATA_SLICE.
@arolla.optools.as_backend_operator(
    'kd.lists._concat_lists', qtype_inference_expr=qtypes.DATA_SLICE
)
def _concat_lists(*args):  # pylint: disable=unused-argument
  raise NotImplementedError('implemented in the backend')
277+
278+
279+
@optools.add_to_registry(aliases=['kd.concat_lists'])
@optools.as_lambda_operator(
    'kd.lists.concat_lists',
    qtype_constraints=[
        qtype_utils.expect_data_slice(P.arg0),
        qtype_utils.expect_data_slice_args(P.args),
    ],
    deterministic=False,
)
def concat_lists(arg0, *args):
  """Returns the concatenation of the given lists as new list(s).

  At least one argument is required. The operator is non-deterministic: the
  result is created anew on every evaluation.

  Args:
    arg0: First DataSlice of lists to concatenate.
    *args: Additional DataSlices of lists, appended in order after `arg0`.

  Returns:
    A DataSlice produced by the `kd.lists._concat_lists` backend operator.
  """
  # TODO: Support 0 args.
  remaining = arolla.optools.fix_trace_args(args)
  # Route the first argument through the non-deterministic identity operator
  # before handing everything to the backend operator.
  first = arolla.abc.aux_bind_op(
      'koda_internal.non_deterministic_identity', arg0
  )
  return arolla.M.core.apply_varargs(_concat_lists, first, remaining)
297+
298+
269299
@optools.add_to_registry(aliases=['kd.list_size'])
270300
@optools.as_backend_operator(
271301
'kd.lists.size',

py/koladata/operators/tests/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,6 +3246,26 @@ py_test(
32463246
],
32473247
)
32483248

3249+
py_test(
3250+
name = "lists_concat_lists_test",
3251+
srcs = ["lists_concat_lists_test.py"],
3252+
deps = [
3253+
"//py/koladata/exceptions",
3254+
"//py/koladata/expr:expr_eval",
3255+
"//py/koladata/expr:input_container",
3256+
"//py/koladata/operators:kde_operators",
3257+
"//py/koladata/operators:optools",
3258+
"//py/koladata/operators/tests/util:qtypes",
3259+
"//py/koladata/testing",
3260+
"//py/koladata/types:data_bag",
3261+
"//py/koladata/types:data_slice",
3262+
"//py/koladata/types:qtypes",
3263+
"@com_google_absl_py//absl/testing:absltest",
3264+
"@com_google_absl_py//absl/testing:parameterized",
3265+
"@com_google_arolla//py/arolla",
3266+
],
3267+
)
3268+
32493269
py_test(
32503270
name = "slices_subslice_test",
32513271
srcs = ["slices_subslice_test.py"],

py/koladata/operators/tests/lists_concat_lists_test.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
from absl.testing import parameterized
17+
from arolla import arolla
18+
from koladata.exceptions import exceptions
19+
from koladata.expr import expr_eval
20+
from koladata.expr import input_container
21+
from koladata.operators import kde_operators
22+
from koladata.operators import optools
23+
from koladata.operators.tests.util import qtypes as test_qtypes
24+
from koladata.testing import testing
25+
from koladata.types import data_bag
26+
from koladata.types import data_slice
27+
from koladata.types import qtypes
28+
29+
kde = kde_operators.kde
DATA_SLICE = qtypes.DATA_SLICE
NON_DETERMINISTIC_TOKEN = qtypes.NON_DETERMINISTIC_TOKEN
bag = data_bag.DataBag.empty
I = input_container.InputContainer('I')


# Signatures expected from detect_qtype_signatures with max_arity=3: a
# DATA_SLICE, a tuple of zero, one or two DATA_SLICEs, the non-deterministic
# token, and finally DATA_SLICE.
QTYPES = frozenset(
    (
        DATA_SLICE,
        arolla.make_tuple_qtype(*([DATA_SLICE] * num_extra)),
        NON_DETERMINISTIC_TOKEN,
        DATA_SLICE,
    )
    for num_extra in range(3)
)

# Bag shared by every fixture in this module.
db = data_bag.DataBag.empty()


def ds(vals):
  """Builds a DataSlice from `vals` and attaches the shared bag `db` to it."""
  return data_slice.DataSlice.from_vals(vals).with_bag(db)


OBJ1 = db.obj()
OBJ2 = db.obj()
62+
63+
64+
class ListsConcatListsTest(parameterized.TestCase):
  """Tests for the kde.lists.concat_lists operator."""

  def test_mutability(self):
    # The evaluated result must report itself as not mutable.
    expr = kde.lists.concat_lists(db.list([1, 2]), db.list([3]))
    result = expr_eval.eval(expr)
    self.assertFalse(result.is_mutable())

  @parameterized.parameters(
      # A single argument.
      ((db.list([1, 2, 3]),), db.list([1, 2, 3])),
      # A single argument holding nested lists.
      ((db.list([[1], [2, 3]]),), db.list([[1], [2, 3]])),
      # Several list items are joined in argument order.
      (
          (db.list([1]), db.list([2, 3]), db.list([4, 5, 6])),
          db.list([1, 2, 3, 4, 5, 6]),
      ),
      # Slices of lists are joined element-wise.
      (
          (
              db.implode(ds([[0], [1]])),
              db.implode(ds([[2], [3]])),
              db.implode(ds([[4, 5], [6]])),
          ),
          db.implode(ds([[0, 2, 4, 5], [1, 3, 6]])),
      ),
      (
          # Compatible primitive types follow type promotion.
          (db.list([1, 2, 3]), db.list([4.5, 5.5, 6.5])),
          db.list([1.0, 2.0, 3.0, 4.5, 5.5, 6.5]),
      ),
      # Mixed item types, including missing values and objects.
      (
          (db.list([1]), db.list([None, 2.5]), db.list(['a', OBJ1, b'b'])),
          db.list([1, None, 2.5, 'a', OBJ1, b'b']),
      ),
      (
          (
              db.implode(ds([[0], [None]])),
              db.implode(ds([['a'], [b'b']])),
              db.implode(ds([[4.5, OBJ1], [OBJ2]])),
          ),
          db.implode(ds([[0, 'a', 4.5, OBJ1], [None, b'b', OBJ2]])),
      ),
  )
  def test_eval(self, lists, expected):
    actual = expr_eval.eval(kde.lists.concat_lists(*lists))
    testing.assert_nested_lists_equal(actual, expected)

  def test_non_deterministic_token(self):
    # Two separately built but structurally identical expressions must yield
    # different bags while holding equal items.
    first = expr_eval.eval(
        kde.lists.concat_lists(db.list([1, 2]), db.list([3]))
    )
    second = expr_eval.eval(
        kde.lists.concat_lists(db.list([1, 2]), db.list([3]))
    )
    self.assertNotEqual(first.db.fingerprint, second.db.fingerprint)
    testing.assert_equal(first[:].no_bag(), second[:].no_bag())

    # The same holds when one expression object is evaluated twice.
    expr = kde.lists.concat_lists(db.list([1, 2]), db.list([3]))
    first = expr_eval.eval(expr)
    second = expr_eval.eval(expr)
    self.assertNotEqual(first.db.fingerprint, second.db.fingerprint)
    testing.assert_equal(first[:].no_bag(), second[:].no_bag())

  def test_qtype_signatures(self):
    detected = arolla.testing.detect_qtype_signatures(
        kde.lists.concat_lists,
        possible_qtypes=test_qtypes.DETECT_SIGNATURES_QTYPES,
        max_arity=3,
    )
    self.assertCountEqual(detected, QTYPES)

  def test_alias(self):
    self.assertTrue(
        optools.equiv_to_op(kde.lists.concat_lists, kde.concat_lists)
    )

  def test_concat_failure(self):
    # A list of INT32 and a list of lists have no common item schema.
    a = db.list([1, 2, 3])
    b = db.list([[1, 2, 3], [4, 5, 6]])
    # Line breaks are spelled as `\n` and leading whitespace of each message
    # line is matched with `\s*`, so the pattern does not depend on how the
    # continuation lines of a multi-line literal happen to be indented.
    expected_error = (
        r'cannot find a common schema for provided schemas\n\n'
        r'\s*the common schema\(s\) INT32: INT32\n'
        r'\s*the first conflicting schema [0-9a-f]{32}:0: LIST\[INT32\]'
    )
    with self.assertRaisesRegex(exceptions.KodaError, expected_error):
      expr_eval.eval(kde.lists.concat_lists(a, b))
156+
157+
158+
if __name__ == '__main__':
159+
absltest.main()

0 commit comments

Comments
 (0)