Skip to content

Commit 593271f

Browse files
apronchenkovcopybara-github
authored andcommitted
Deprecate py_boxing.LIST_TO_SLICE_BOXING_POLICY
PiperOrigin-RevId: 858110336 Change-Id: Idd3dbd6919682b75e930334d9d1a4c1d66967769
1 parent 0308362 commit 593271f

File tree

9 files changed

+67
-293
lines changed

9 files changed

+67
-293
lines changed

koladata/operators/shapes.cc

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "arolla/qexpr/operators/dense_array/edge_ops.h"
3636
#include "arolla/qtype/qtype.h"
3737
#include "arolla/qtype/qtype_traits.h"
38+
#include "arolla/qtype/tuple_qtype.h"
3839
#include "arolla/qtype/typed_slot.h"
3940
#include "arolla/util/repr.h"
4041
#include "koladata/arolla_utils.h"
@@ -102,25 +103,22 @@ absl::StatusOr<DataSlice::JaggedShape::Edge> GetEdgeFromSizes(
102103
// `i`. Only rank-0 or rank-1 int DataSlices are supported.
103104
class JaggedShapeCreateOperator : public arolla::QExprOperator {
104105
public:
105-
explicit JaggedShapeCreateOperator(absl::Span<const arolla::QTypePtr> types)
106-
: QExprOperator(types, koladata::GetJaggedShapeQType()) {
107-
for (const auto& input_type : types) {
108-
DCHECK(input_type == arolla::GetQType<DataSlice>() ||
109-
input_type == arolla::GetQType<DataSlice::JaggedShape::Edge>());
110-
}
111-
}
106+
using QExprOperator::QExprOperator;
112107

113108
private:
114109
absl::StatusOr<std::unique_ptr<arolla::BoundOperator>> DoBind(
115110
absl::Span<const arolla::TypedSlot> input_slots,
116111
arolla::TypedSlot output_slot) const override {
117-
Slot<DataSlice::JaggedShape> shape_slot =
118-
output_slot.UnsafeToSlot<DataSlice::JaggedShape>();
112+
int64_t edge_or_slice_count = input_slots[0].SubSlotCount();
113+
std::vector<arolla::TypedSlot> edge_or_slice_slots;
114+
edge_or_slice_slots.reserve(edge_or_slice_count);
115+
for (int64_t i = 0; i < edge_or_slice_count; ++i) {
116+
edge_or_slice_slots.emplace_back(input_slots[0].SubSlot(i));
117+
}
119118
return MakeBoundOperator(
120119
"kd.shapes.new",
121-
[edge_or_slice_slots =
122-
std::vector(input_slots.begin(), input_slots.end()),
123-
shape_slot = std::move(shape_slot)](
120+
[edge_or_slice_slots = std::move(edge_or_slice_slots),
121+
shape_slot = output_slot.UnsafeToSlot<DataSlice::JaggedShape>()](
124122
arolla::EvaluationContext* ctx,
125123
arolla::FramePtr frame) -> absl::Status {
126124
DataSlice::JaggedShape::EdgeVec edges;
@@ -284,16 +282,27 @@ absl::StatusOr<arolla::OperatorPtr>
284282
JaggedShapeCreateOperatorFamily::DoGetOperator(
285283
absl::Span<const arolla::QTypePtr> input_types,
286284
arolla::QTypePtr output_type) const {
287-
for (const auto& input_type : input_types) {
288-
if (input_type != arolla::GetQType<DataSlice>() &&
289-
input_type != arolla::GetQType<DataSlice::JaggedShape::Edge>()) {
285+
if (input_types.size() != 1) {
286+
return absl::InvalidArgumentError(
287+
absl::StrCat("expected exactly one input, got: ", input_types.size()));
288+
}
289+
if (!arolla::IsTupleQType(input_types[0])) {
290+
return absl::InvalidArgumentError(
291+
absl::StrCat("unsupported input type: ", input_types[0]->name()));
292+
}
293+
for (const auto& type_field : input_types[0]->type_fields()) {
294+
if (type_field.GetType() != arolla::GetQType<DataSlice>() &&
295+
type_field.GetType() !=
296+
arolla::GetQType<DataSlice::JaggedShape::Edge>()) {
290297
return absl::InvalidArgumentError(
291-
absl::StrCat("unsupported input type: ", input_type->name()));
298+
absl::StrCat("unsupported input type: ", input_types[0]->name()));
292299
}
293300
}
294-
return arolla::EnsureOutputQTypeMatches(
295-
std::make_shared<JaggedShapeCreateOperator>(input_types), input_types,
296-
output_type);
301+
if (output_type != koladata::GetJaggedShapeQType()) {
302+
return absl::InvalidArgumentError(absl::StrCat(
303+
"expected output type JaggedShape, got: ", output_type->name()));
304+
}
305+
return std::make_shared<JaggedShapeCreateOperator>(input_types, output_type);
297306
}
298307

299308
absl::StatusOr<arolla::OperatorPtr>

py/koladata/expr/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ py_library(
6565
"//py:python_path",
6666
"//py/koladata/fstring",
6767
"//py/koladata/types:data_slice",
68+
"//py/koladata/types:py_boxing",
6869
"//py/koladata/types:qtypes",
6970
"//py/koladata/types:schema_constants",
7071
"//py/koladata/util:kd_functools",

py/koladata/operators/jagged_shape.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,31 @@
3434
_reshape = arolla_bridge._reshape # pylint: disable=protected-access
3535

3636

37-
def _expect_slices_or_edges(value):
38-
"""Constrains `value` to be a tuple of DataSlices or Edges."""
39-
is_slice_or_edge = arolla.LambdaOperator(
37+
def _expect_data_slice_or_edge_args(param):
38+
"""Constrains `param` to be a tuple of DataSlices or Edges."""
39+
is_data_slice_or_edge = arolla.LambdaOperator(
4040
'x', (P.x == qtypes.DATA_SLICE) | (P.x == arolla.DENSE_ARRAY_EDGE)
4141
)
4242
return (
43-
M.seq.all(M.seq.map(is_slice_or_edge, M.qtype.get_field_qtypes(value))),
43+
M.qtype.is_tuple_qtype(param)
44+
& M.seq.all(
45+
M.seq.map(is_data_slice_or_edge, M.qtype.get_field_qtypes(param))
46+
),
4447
(
45-
'all arguments must be DataSlices or Edges, got:'
46-
f' {constraints.variadic_name_type_msg(value)}'
48+
'expected all arguments to be data slices or edges, got'
49+
f' {constraints.variadic_name_type_msg(param)}'
4750
),
4851
)
4952

5053

5154
@optools.add_to_registry(via_cc_operator_package=True)
52-
@arolla.optools.as_backend_operator(
55+
@optools.as_backend_operator(
5356
'kd.shapes.new',
54-
qtype_constraints=[_expect_slices_or_edges(P.dimensions)],
57+
qtype_constraints=[_expect_data_slice_or_edge_args(P.dimensions)],
5558
qtype_inference_expr=qtypes.JAGGED_SHAPE,
56-
experimental_aux_policy=py_boxing.LIST_TO_SLICE_BOXING_POLICY,
59+
custom_boxing_fn_name_per_parameter={
60+
'dimensions': py_boxing.WITH_LIST_TO_SLICE_SUPPORT,
61+
},
5762
)
5863
def new(*dimensions): # pylint: disable=unused-argument
5964
"""Returns a JaggedShape from the provided dimensions.
@@ -88,34 +93,13 @@ def new(*dimensions): # pylint: disable=unused-argument
8893
raise NotImplementedError('implemented in the backend')
8994

9095

91-
@optools.add_to_registry(via_cc_operator_package=True)
9296
@arolla.optools.as_backend_operator(
9397
'kd.shapes._new_with_size',
94-
qtype_constraints=[
95-
qtype_utils.expect_data_slice(P.result_size),
96-
_expect_slices_or_edges(P.dimensions),
97-
],
98+
qtype_constraints=[_expect_data_slice_or_edge_args(P.dimensions)],
9899
qtype_inference_expr=qtypes.JAGGED_SHAPE,
99-
experimental_aux_policy=py_boxing.LIST_TO_SLICE_BOXING_POLICY,
100100
)
101101
def _new_with_size(result_size, *dimensions): # pylint: disable=unused-argument
102-
"""Returns a JaggedShape from the provided dimensions and size.
103-
104-
It supports a single placeholder dimension argument denoted as `-1`, for which
105-
its true value is inferred from the provided `size` argument and remaining
106-
`dimensions`. The resulting dimension must be a uniform dimension, i.e. all
107-
parent elements must have the same child size.
108-
109-
Args:
110-
result_size: The size of the resulting JaggedShape.
111-
*dimensions: A combination of Edges and DataSlices representing the
112-
dimensions of the JaggedShape. Edges are used as is, while DataSlices are
113-
treated as sizes. DataItems (of ints) are interpreted as uniform
114-
dimensions which have the same child size for all parent elements.
115-
DataSlices (of ints) are interpreted as a list of sizes, where `ds[i]` is
116-
the child size of parent `i`. Only rank-0 or rank-1 int DataSlices are
117-
supported.
118-
"""
102+
"""(internal) Returns a JaggedShape from the provided dimensions and size."""
119103
raise NotImplementedError('implemented in the backend')
120104

121105

@@ -143,7 +127,19 @@ def get_shape(x): # pylint: disable=unused-argument
143127

144128

145129
@optools.add_to_registry(aliases=['kd.reshape'], via_cc_operator_package=True)
146-
@optools.as_lambda_operator('kd.shapes.reshape')
130+
@optools.as_lambda_operator(
131+
'kd.shapes.reshape',
132+
qtype_constraints=[
133+
qtype_utils.expect_data_slice(P.x),
134+
(
135+
M.qtype.is_tuple_qtype(P.shape) | (P.shape == qtypes.JAGGED_SHAPE),
136+
(
137+
'expected a tuple or a shape, got:'
138+
f' {constraints.name_type_msg(P.shape)}'
139+
),
140+
),
141+
],
142+
)
147143
def reshape(x, shape):
148144
"""Returns a DataSlice with the provided shape.
149145

py/koladata/operators/tests/BUILD

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6640,25 +6640,6 @@ py_test(
66406640
],
66416641
)
66426642

6643-
py_test(
6644-
name = "shapes_new_with_size_test",
6645-
srcs = ["shapes_new_with_size_test.py"],
6646-
deps = [
6647-
"//py:python_path",
6648-
"//py/koladata/expr:expr_eval",
6649-
"//py/koladata/expr:input_container",
6650-
"//py/koladata/expr:view",
6651-
"//py/koladata/operators:kde_operators",
6652-
"//py/koladata/testing",
6653-
"//py/koladata/types:data_slice",
6654-
"//py/koladata/types:jagged_shape",
6655-
"//py/koladata/types:literal_operator",
6656-
"@com_google_absl_py//absl/testing:absltest",
6657-
"@com_google_absl_py//absl/testing:parameterized",
6658-
"@com_google_arolla//py/arolla",
6659-
],
6660-
)
6661-
66626643
py_test(
66636644
name = "schema_is_dict_schema_test",
66646645
srcs = ["schema_is_dict_schema_test.py"],

py/koladata/operators/tests/shapes_new_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,18 @@ def test_boxing(self):
6868
kde.shapes.new([1]),
6969
arolla.abc.bind_op(
7070
kde.shapes.new,
71-
literal_operator.literal(data_slice.DataSlice.from_vals([1])),
71+
literal_operator.literal(
72+
arolla.tuple(data_slice.DataSlice.from_vals([1]))
73+
),
7274
),
7375
)
7476

7577
def test_unsupported_qtype_error(self):
7678
with self.assertRaisesRegex(
7779
ValueError,
7880
re.escape(
79-
'all arguments must be DataSlices or Edges, got: *dimensions:'
80-
' (FLOAT32)'
81+
'expected all arguments to be data slices or edges, got'
82+
' *dimensions: (FLOAT32)'
8183
),
8284
):
8385
kde.shapes.new(arolla.float32(1.0))

0 commit comments

Comments
 (0)