Skip to content

Commit e896c77

Browse files
timofey-stepanovcopybara-github
authored andcommitted
ReturnsOperatorEvalError wrapper.
Wraps the given function, so all its errors are converted into `OperatorEvalError`. I also split `OperatorEvalError` into two versions, because the logic whether to use cause was a bit tangled. I intend to remove `ToOperatorEvalError` wrapper and all its usages in subsequent CLs. PiperOrigin-RevId: 716169561 Change-Id: I68919074cf21ad01be480d6a44218e1e3cb23d25
1 parent 97a6999 commit e896c77

File tree

17 files changed

+152
-55
lines changed

17 files changed

+152
-55
lines changed

koladata/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ cc_library(
503503
"//koladata/internal:object_id",
504504
"//koladata/internal:schema_utils",
505505
"//koladata/internal:uuid_object",
506+
"//koladata/internal/op_utils:error",
506507
"//koladata/internal/op_utils:has",
507-
"//koladata/internal/op_utils:utils",
508508
"@com_google_absl//absl/base:nullability",
509509
"@com_google_absl//absl/log:check",
510510
"@com_google_absl//absl/status",

koladata/internal/op_utils/BUILD

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ cc_library(
386386
srcs = ["select.cc"],
387387
hdrs = ["select.h"],
388388
deps = [
389-
":utils",
389+
":error",
390390
"//koladata/internal:data_item",
391391
"//koladata/internal:data_slice",
392392
"@com_google_absl//absl/log:check",
@@ -635,7 +635,7 @@ cc_library(
635635
name = "inverse_select",
636636
hdrs = ["inverse_select.h"],
637637
deps = [
638-
":utils",
638+
":error",
639639
"//koladata/internal:data_item",
640640
"//koladata/internal:data_slice",
641641
"@com_google_absl//absl/status",
@@ -891,27 +891,29 @@ cc_test(
891891
)
892892

893893
cc_library(
894-
name = "utils",
895-
srcs = ["utils.cc"],
896-
hdrs = ["utils.h"],
894+
name = "error",
895+
srcs = ["error.cc"],
896+
hdrs = ["error.h"],
897897
deps = [
898898
"//koladata/internal:error_cc_proto",
899899
"//koladata/internal:error_utils",
900900
"@com_google_absl//absl/status",
901901
"@com_google_absl//absl/strings",
902902
"@com_google_absl//absl/strings:str_format",
903+
"@com_google_arolla//arolla/util",
903904
],
904905
)
905906

906907
cc_test(
907-
name = "utils_test",
908-
srcs = ["utils_test.cc"],
908+
name = "error_test",
909+
srcs = ["error_test.cc"],
909910
deps = [
910-
":utils",
911+
":error",
911912
"//koladata/internal:error_cc_proto",
912913
"//koladata/internal:error_utils",
913914
"@com_google_absl//absl/status",
914915
"@com_google_absl//absl/status:status_matchers",
916+
"@com_google_absl//absl/status:statusor",
915917
"@com_google_arolla//arolla/util:status_backport",
916918
"@com_google_googletest//:gtest_main",
917919
],
Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
//
15-
#include "koladata/internal/op_utils/utils.h"
15+
#include "koladata/internal/op_utils/error.h"
1616

1717
#include <optional>
1818
#include <utility>
1919

2020
#include "absl/status/status.h"
21+
#include "absl/strings/match.h"
2122
#include "absl/strings/str_format.h"
2223
#include "absl/strings/string_view.h"
2324
#include "koladata/internal/error.pb.h"
@@ -26,31 +27,30 @@
2627
namespace koladata::internal {
2728

2829
absl::Status OperatorEvalError(absl::Status status,
29-
absl::string_view operator_name,
30-
absl::string_view error_message) {
31-
internal::Error error;
32-
std::optional<internal::Error> cause = internal::GetErrorPayload(status);
33-
if (cause) {
34-
*error.mutable_cause() = *std::move(cause);
35-
} else if (!error_message.empty()) {
36-
error.mutable_cause()->set_error_message(status.message());
37-
} else {
38-
error_message = status.message();
30+
absl::string_view operator_name) {
31+
internal::Error error =
32+
internal::GetErrorPayload(status).value_or(internal::Error());
33+
if (error.error_message().empty()) {
34+
error.set_error_message(status.message());
35+
}
36+
if (!absl::StartsWith(error.error_message(), operator_name)) {
37+
error.set_error_message(
38+
absl::StrFormat("%s: %s", operator_name, error.error_message()));
3939
}
40-
error.set_error_message(
41-
absl::StrFormat("%s: %s", operator_name, error_message));
4240
return internal::WithErrorPayload(std::move(status), error);
4341
}
4442

43+
absl::Status OperatorEvalError(absl::Status status,
44+
absl::string_view operator_name,
45+
absl::string_view error_message) {
46+
return OperatorEvalError(KodaErrorFromCause(error_message, std::move(status)),
47+
operator_name);
48+
}
49+
4550
absl::Status OperatorEvalError(absl::string_view operator_name,
4651
absl::string_view error_message) {
47-
internal::Error error;
48-
error.set_error_message(
49-
absl::StrFormat("%s: %s", operator_name, error_message));
50-
// Note error_message inside the status is not used by the error handling
51-
// logic in the CPython layer.
52-
return internal::WithErrorPayload(absl::InvalidArgumentError(error_message),
53-
error);
52+
return OperatorEvalError(absl::InvalidArgumentError(error_message),
53+
operator_name);
5454
}
5555

5656
} // namespace koladata::internal
Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
//
15-
#ifndef KOLADATA_INTERNAL_OP_UTILS_UTILS_H_
16-
#define KOLADATA_INTERNAL_OP_UTILS_UTILS_H_
15+
#ifndef KOLADATA_INTERNAL_OP_UTILS_ERROR_H_
16+
#define KOLADATA_INTERNAL_OP_UTILS_ERROR_H_
1717

18+
#include <string>
1819
#include <utility>
1920

2021
#include "absl/status/status.h"
2122
#include "absl/strings/string_view.h"
23+
#include "arolla/util/status.h"
2224

2325
namespace koladata::internal {
2426

@@ -27,12 +29,16 @@ namespace koladata::internal {
2729
//
2830
// By default, the error message is taken from the status, but can be overridden
2931
// by passing a custom error_message.
32+
absl::Status OperatorEvalError(absl::Status status,
33+
absl::string_view operator_name);
3034
absl::Status OperatorEvalError(absl::Status status,
3135
absl::string_view operator_name,
32-
absl::string_view error_message = "");
36+
absl::string_view error_message);
3337

3438
// absl::Status adaptor that attaches operator name and wraps the given status
3539
// into an OperatorEvalError.
40+
// TODO: b/389032294 - Remove this adaptor once ReturnsOperatorEvalError is used
41+
// automatically.
3642
inline auto ToOperatorEvalError(absl::string_view operator_name) {
3743
return [operator_name](absl::Status status) {
3844
return OperatorEvalError(std::move(status), operator_name);
@@ -44,6 +50,35 @@ inline auto ToOperatorEvalError(absl::string_view operator_name) {
4450
absl::Status OperatorEvalError(absl::string_view operator_name,
4551
absl::string_view error_message);
4652

53+
// Wraps the given function, so all its errors are converted into
54+
// OperatorEvalError.
55+
template <typename Ret, typename... Args>
56+
class ReturnsOperatorEvalError {
57+
public:
58+
ReturnsOperatorEvalError(std::string name, Ret (*func)(Args...))
59+
: name_(std::move(name)), func_(func) {}
60+
61+
Ret operator()(const Args&... args) const {
62+
if constexpr (arolla::IsStatusOrT<Ret>::value) {
63+
auto result = func_(args...);
64+
if (!result.ok()) {
65+
return OperatorEvalError(result.status(), name_);
66+
}
67+
return result;
68+
} else {
69+
return func_(args...);
70+
}
71+
}
72+
73+
private:
74+
std::string name_;
75+
Ret (*func_)(Args...);
76+
};
77+
78+
template <typename Ret, typename... Args>
79+
ReturnsOperatorEvalError(std::string name, Ret (*func)(Args...))
80+
-> ReturnsOperatorEvalError<Ret, Args...>;
81+
4782
} // namespace koladata::internal
4883

49-
#endif // KOLADATA_INTERNAL_OP_UTILS_UTILS_H_
84+
#endif // KOLADATA_INTERNAL_OP_UTILS_ERROR_H_
Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,27 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
//
15-
#include "koladata/internal/op_utils/utils.h"
15+
#include "koladata/internal/op_utils/error.h"
1616

1717
#include <optional>
18+
#include <utility>
1819

1920
#include "gmock/gmock.h"
2021
#include "gtest/gtest.h"
2122
#include "absl/status/status.h"
2223
#include "absl/status/status_matchers.h"
24+
#include "absl/status/statusor.h"
2325
#include "koladata/internal/error.pb.h"
2426
#include "koladata/internal/error_utils.h"
2527
#include "arolla/util/status_macros_backport.h"
2628

2729
namespace koladata::internal {
2830
namespace {
2931

32+
using ::absl_testing::IsOkAndHolds;
3033
using ::absl_testing::StatusIs;
31-
using ::testing::StrEq;
34+
using ::testing::Eq;
35+
using ::testing::Field;
3236

3337
TEST(OperatorEvalError, NoCause) {
3438
absl::Status status = OperatorEvalError("op_name", "error_message");
@@ -37,7 +41,7 @@ TEST(OperatorEvalError, NoCause) {
3741
std::optional<internal::Error> payload =
3842
internal::GetErrorPayload(status);
3943
EXPECT_TRUE(payload.has_value());
40-
EXPECT_THAT(payload->error_message(), StrEq("op_name: error_message"));
44+
EXPECT_THAT(payload->error_message(), Eq("op_name: error_message"));
4145
EXPECT_FALSE(payload->has_cause());
4246
}
4347

@@ -49,8 +53,8 @@ TEST(OperatorEvalError, WithStatus) {
4953
std::optional<internal::Error> payload =
5054
internal::GetErrorPayload(new_status);
5155
EXPECT_TRUE(payload.has_value());
52-
EXPECT_THAT(payload->error_message(), StrEq("op_name: Test error"));
53-
EXPECT_THAT(payload->cause().error_message(), StrEq(""));
56+
EXPECT_THAT(payload->error_message(), Eq("op_name: Test error"));
57+
EXPECT_THAT(payload->cause().error_message(), Eq(""));
5458
}
5559

5660
TEST(OperatorEvalError, WithStatusAndErrorMessage) {
@@ -62,8 +66,8 @@ TEST(OperatorEvalError, WithStatusAndErrorMessage) {
6266
std::optional<internal::Error> payload =
6367
internal::GetErrorPayload(new_status);
6468
EXPECT_TRUE(payload.has_value());
65-
EXPECT_THAT(payload->error_message(), StrEq("op_name: error_message"));
66-
EXPECT_THAT(payload->cause().error_message(), StrEq("Test error"));
69+
EXPECT_THAT(payload->error_message(), Eq("op_name: error_message"));
70+
EXPECT_THAT(payload->cause().error_message(), Eq("Test error"));
6771
}
6872

6973
TEST(OperatorEvalError, WithStatusContainingCause) {
@@ -79,8 +83,8 @@ TEST(OperatorEvalError, WithStatusContainingCause) {
7983
std::optional<internal::Error> payload =
8084
internal::GetErrorPayload(new_status);
8185
EXPECT_TRUE(payload.has_value());
82-
EXPECT_THAT(payload->error_message(), StrEq("op_name: error_message"));
83-
EXPECT_THAT(payload->cause().error_message(), StrEq("cause"));
86+
EXPECT_THAT(payload->error_message(), Eq("op_name: error_message"));
87+
EXPECT_THAT(payload->cause().error_message(), Eq("cause"));
8488
}
8589

8690
TEST(OperatorEvalError, ToOperatorEvalError) {
@@ -93,9 +97,65 @@ TEST(OperatorEvalError, ToOperatorEvalError) {
9397
StatusIs(absl::StatusCode::kInvalidArgument, "Test error"));
9498
std::optional<internal::Error> payload = internal::GetErrorPayload(status);
9599
EXPECT_TRUE(payload.has_value());
96-
EXPECT_THAT(payload->error_message(), StrEq("op_name: Test error"));
100+
EXPECT_THAT(payload->error_message(), Eq("op_name: Test error"));
97101
EXPECT_FALSE(payload->has_cause());
98102
}
99103

104+
TEST(OperatorEvalError, SubsequentCalls) {
105+
absl::Status status = OperatorEvalError(
106+
OperatorEvalError("op_name", "error_message"), "op_name");
107+
EXPECT_THAT(status,
108+
StatusIs(absl::StatusCode::kInvalidArgument, "error_message"));
109+
std::optional<internal::Error> payload = internal::GetErrorPayload(status);
110+
EXPECT_TRUE(payload.has_value());
111+
EXPECT_THAT(payload->error_message(), Eq("op_name: error_message"));
112+
EXPECT_FALSE(payload->has_cause());
113+
}
114+
115+
absl::StatusOr<int> ReturnsError() {
116+
return absl::InvalidArgumentError("test error");
117+
};
118+
119+
TEST(ReturnsOperatorEvalError, WrapsErrors) {
120+
auto wrapped_fn = ReturnsOperatorEvalError("op_name", ReturnsError);
121+
auto status = wrapped_fn().status();
122+
EXPECT_THAT(status,
123+
StatusIs(absl::StatusCode::kInvalidArgument, "test error"));
124+
std::optional<internal::Error> payload = internal::GetErrorPayload(status);
125+
EXPECT_TRUE(payload.has_value());
126+
EXPECT_THAT(payload->error_message(), Eq("op_name: test error"));
127+
EXPECT_FALSE(payload->has_cause());
128+
}
129+
130+
// Counts the number of times the object is copied.
131+
struct CopyCounter {
132+
public:
133+
CopyCounter() = default;
134+
CopyCounter(CopyCounter&& other) = default;
135+
CopyCounter& operator=(CopyCounter&& other) = default;
136+
CopyCounter(const CopyCounter& other) : copy_count(other.copy_count + 1) {}
137+
CopyCounter& operator=(const CopyCounter& other) {
138+
copy_count = other.copy_count + 1;
139+
return *this;
140+
}
141+
int copy_count = 0;
142+
};
143+
144+
absl::StatusOr<CopyCounter> ForwardsCopyCounter(CopyCounter counter) {
145+
return counter;
146+
};
147+
148+
TEST(ReturnsOperatorEvalError, NoExtraCopies) {
149+
CopyCounter counter;
150+
151+
// Test that CopyCounter actually counts the number of copies.
152+
EXPECT_THAT(counter, Field(&CopyCounter::copy_count, Eq(0)));
153+
EXPECT_THAT(CopyCounter(counter), Field(&CopyCounter::copy_count, Eq(1)));
154+
155+
auto wrapped_fn = ReturnsOperatorEvalError("op_name", ForwardsCopyCounter);
156+
EXPECT_THAT(wrapped_fn(std::move(counter)),
157+
IsOkAndHolds(Field(&CopyCounter::copy_count, Eq(1))));
158+
}
159+
100160
} // namespace
101161
} // namespace koladata::internal

koladata/internal/op_utils/inverse_select.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#include "absl/strings/string_view.h"
2727
#include "koladata/internal/data_item.h"
2828
#include "koladata/internal/data_slice.h"
29-
#include "koladata/internal/op_utils/utils.h"
29+
#include "koladata/internal/op_utils/error.h"
3030
#include "koladata/internal/slice_builder.h"
3131
#include "arolla/dense_array/dense_array.h"
3232
#include "arolla/dense_array/edge.h"

koladata/internal/op_utils/select.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#include "absl/types/span.h"
2828
#include "koladata/internal/data_item.h"
2929
#include "koladata/internal/data_slice.h"
30-
#include "koladata/internal/op_utils/utils.h"
30+
#include "koladata/internal/op_utils/error.h"
3131
#include "koladata/internal/slice_builder.h"
3232
#include "arolla/dense_array/dense_array.h"
3333
#include "arolla/dense_array/edge.h"

koladata/object_factories.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
#include "koladata/internal/error.pb.h"
4444
#include "koladata/internal/missing_value.h"
4545
#include "koladata/internal/object_id.h"
46+
#include "koladata/internal/op_utils/error.h"
4647
#include "koladata/internal/op_utils/has.h"
47-
#include "koladata/internal/op_utils/utils.h"
4848
#include "koladata/internal/schema_utils.h"
4949
#include "koladata/internal/uuid_object.h"
5050
#include "koladata/repr_utils.h"

koladata/operators/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ cc_library(
134134
"//koladata/internal/op_utils:deep_clone",
135135
"//koladata/internal/op_utils:deep_uuid",
136136
"//koladata/internal/op_utils:equal",
137+
"//koladata/internal/op_utils:error",
137138
"//koladata/internal/op_utils:extract",
138139
"//koladata/internal/op_utils:has",
139140
"//koladata/internal/op_utils:inverse_select",
@@ -143,7 +144,6 @@ cc_library(
143144
"//koladata/internal/op_utils:presence_or",
144145
"//koladata/internal/op_utils:reverse",
145146
"//koladata/internal/op_utils:select",
146-
"//koladata/internal/op_utils:utils",
147147
"@com_google_absl//absl/algorithm:container",
148148
"@com_google_absl//absl/base:core_headers",
149149
"@com_google_absl//absl/base:no_destructor",

koladata/operators/arolla_bridge.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
#include "koladata/internal/data_item.h"
4242
#include "koladata/internal/dtype.h"
4343
#include "koladata/internal/error_utils.h"
44-
#include "koladata/internal/op_utils/utils.h"
44+
#include "koladata/internal/op_utils/error.h"
4545
#include "koladata/internal/schema_utils.h"
4646
#include "koladata/schema_utils.h"
4747
#include "koladata/shape_utils.h"

0 commit comments

Comments
 (0)