1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414//
15- #include " koladata/internal/op_utils/utils .h"
15+ #include " koladata/internal/op_utils/error .h"
1616
1717#include < optional>
18+ #include < utility>
1819
1920#include " gmock/gmock.h"
2021#include " gtest/gtest.h"
2122#include " absl/status/status.h"
2223#include " absl/status/status_matchers.h"
24+ #include " absl/status/statusor.h"
2325#include " koladata/internal/error.pb.h"
2426#include " koladata/internal/error_utils.h"
2527#include " arolla/util/status_macros_backport.h"
2628
2729namespace koladata ::internal {
2830namespace {
2931
32+ using ::absl_testing::IsOkAndHolds;
3033using ::absl_testing::StatusIs;
31- using ::testing::StrEq;
34+ using ::testing::Eq;
35+ using ::testing::Field;
3236
3337TEST (OperatorEvalError, NoCause) {
3438 absl::Status status = OperatorEvalError (" op_name" , " error_message" );
@@ -37,7 +41,7 @@ TEST(OperatorEvalError, NoCause) {
3741 std::optional<internal::Error> payload =
3842 internal::GetErrorPayload (status);
3943 EXPECT_TRUE (payload.has_value ());
40- EXPECT_THAT (payload->error_message (), StrEq (" op_name: error_message" ));
44+ EXPECT_THAT (payload->error_message (), Eq (" op_name: error_message" ));
4145 EXPECT_FALSE (payload->has_cause ());
4246}
4347
@@ -49,8 +53,8 @@ TEST(OperatorEvalError, WithStatus) {
4953 std::optional<internal::Error> payload =
5054 internal::GetErrorPayload (new_status);
5155 EXPECT_TRUE (payload.has_value ());
52- EXPECT_THAT (payload->error_message (), StrEq (" op_name: Test error" ));
53- EXPECT_THAT (payload->cause ().error_message (), StrEq (" " ));
56+ EXPECT_THAT (payload->error_message (), Eq (" op_name: Test error" ));
57+ EXPECT_THAT (payload->cause ().error_message (), Eq (" " ));
5458}
5559
5660TEST (OperatorEvalError, WithStatusAndErrorMessage) {
@@ -62,8 +66,8 @@ TEST(OperatorEvalError, WithStatusAndErrorMessage) {
6266 std::optional<internal::Error> payload =
6367 internal::GetErrorPayload (new_status);
6468 EXPECT_TRUE (payload.has_value ());
65- EXPECT_THAT (payload->error_message (), StrEq (" op_name: error_message" ));
66- EXPECT_THAT (payload->cause ().error_message (), StrEq (" Test error" ));
69+ EXPECT_THAT (payload->error_message (), Eq (" op_name: error_message" ));
70+ EXPECT_THAT (payload->cause ().error_message (), Eq (" Test error" ));
6771}
6872
6973TEST (OperatorEvalError, WithStatusContainingCause) {
@@ -79,8 +83,8 @@ TEST(OperatorEvalError, WithStatusContainingCause) {
7983 std::optional<internal::Error> payload =
8084 internal::GetErrorPayload (new_status);
8185 EXPECT_TRUE (payload.has_value ());
82- EXPECT_THAT (payload->error_message (), StrEq (" op_name: error_message" ));
83- EXPECT_THAT (payload->cause ().error_message (), StrEq (" cause" ));
86+ EXPECT_THAT (payload->error_message (), Eq (" op_name: error_message" ));
87+ EXPECT_THAT (payload->cause ().error_message (), Eq (" cause" ));
8488}
8589
8690TEST (OperatorEvalError, ToOperatorEvalError) {
@@ -93,9 +97,65 @@ TEST(OperatorEvalError, ToOperatorEvalError) {
9397 StatusIs (absl::StatusCode::kInvalidArgument , " Test error" ));
9498 std::optional<internal::Error> payload = internal::GetErrorPayload (status);
9599 EXPECT_TRUE (payload.has_value ());
96- EXPECT_THAT (payload->error_message (), StrEq (" op_name: Test error" ));
100+ EXPECT_THAT (payload->error_message (), Eq (" op_name: Test error" ));
97101 EXPECT_FALSE (payload->has_cause ());
98102}
99103
104+ TEST (OperatorEvalError, SubsequentCalls) {
105+ absl::Status status = OperatorEvalError (
106+ OperatorEvalError (" op_name" , " error_message" ), " op_name" );
107+ EXPECT_THAT (status,
108+ StatusIs (absl::StatusCode::kInvalidArgument , " error_message" ));
109+ std::optional<internal::Error> payload = internal::GetErrorPayload (status);
110+ EXPECT_TRUE (payload.has_value ());
111+ EXPECT_THAT (payload->error_message (), Eq (" op_name: error_message" ));
112+ EXPECT_FALSE (payload->has_cause ());
113+ }
114+
115+ absl::StatusOr<int > ReturnsError () {
116+ return absl::InvalidArgumentError (" test error" );
117+ };
118+
119+ TEST (ReturnsOperatorEvalError, WrapsErrors) {
120+ auto wrapped_fn = ReturnsOperatorEvalError (" op_name" , ReturnsError);
121+ auto status = wrapped_fn ().status ();
122+ EXPECT_THAT (status,
123+ StatusIs (absl::StatusCode::kInvalidArgument , " test error" ));
124+ std::optional<internal::Error> payload = internal::GetErrorPayload (status);
125+ EXPECT_TRUE (payload.has_value ());
126+ EXPECT_THAT (payload->error_message (), Eq (" op_name: test error" ));
127+ EXPECT_FALSE (payload->has_cause ());
128+ }
129+
130+ // Counts the number of times the object is copied.
131+ struct CopyCounter {
132+ public:
133+ CopyCounter () = default ;
134+ CopyCounter (CopyCounter&& other) = default ;
135+ CopyCounter& operator =(CopyCounter&& other) = default ;
136+ CopyCounter (const CopyCounter& other) : copy_count(other.copy_count + 1 ) {}
137+ CopyCounter& operator =(const CopyCounter& other) {
138+ copy_count = other.copy_count + 1 ;
139+ return *this ;
140+ }
141+ int copy_count = 0 ;
142+ };
143+
144+ absl::StatusOr<CopyCounter> ForwardsCopyCounter (CopyCounter counter) {
145+ return counter;
146+ };
147+
148+ TEST (ReturnsOperatorEvalError, NoExtraCopies) {
149+ CopyCounter counter;
150+
151+ // Test that CopyCounter actually counts the number of copies.
152+ EXPECT_THAT (counter, Field (&CopyCounter::copy_count, Eq (0 )));
153+ EXPECT_THAT (CopyCounter (counter), Field (&CopyCounter::copy_count, Eq (1 )));
154+
155+ auto wrapped_fn = ReturnsOperatorEvalError (" op_name" , ForwardsCopyCounter);
156+ EXPECT_THAT (wrapped_fn (std::move (counter)),
157+ IsOkAndHolds (Field (&CopyCounter::copy_count, Eq (1 ))));
158+ }
159+
100160} // namespace
101161} // namespace koladata::internal
0 commit comments