2
2
3
3
#include < gtest/gtest.h>
4
4
5
+ #include < functional>
5
6
#include < memory>
7
+ #include < stdexcept>
6
8
#include < string>
7
9
#include < vector>
8
10
9
11
#include " absl/status/status.h"
10
12
#include " torch_xla/csrc/runtime/computation_client.h"
11
13
#include " torch_xla/csrc/runtime/pjrt_computation_client.h"
12
14
#include " torch_xla/csrc/runtime/tensor_source.h"
13
- #include " tsl/platform/env.h"
14
- #include " tsl/platform/logging.h"
15
- #include " tsl/platform/statusor.h"
16
- #include " tsl/platform/test.h"
17
15
#include " xla/hlo/builder/xla_builder.h"
18
16
#include " xla/hlo/builder/xla_computation.h"
19
17
#include " xla/literal.h"
20
18
#include " xla/literal_util.h"
21
19
#include " xla/tests/literal_test_util.h"
22
- #include " xla/tsl/lib/core/status_test_util.h"
23
20
24
21
namespace torch_xla {
25
22
namespace runtime {
26
23
27
- absl::StatusOr<xla::XlaComputation> MakeComputation () {
28
- xla::Shape input_shape =
24
+ class PjRtComputationClientTest : public ::testing::Test {
25
+ protected:
26
+ PjRtComputationClientTest () {
27
+ // Get a CPU client.
28
+ tsl::setenv (" PJRT_DEVICE" , " CPU" , true );
29
+ client_ = std::make_unique<PjRtComputationClient>();
30
+ device_ = client_->GetDefaultDevice ();
31
+ }
32
+
33
+ static void FakeXlaCompileForTesting (
34
+ PjRtComputationClient* client,
35
+ std::function<absl::Status()> fake_compile) {
36
+ client->FakeXlaCompileForTesting (std::move (fake_compile));
37
+ }
38
+
39
+ std::unique_ptr<PjRtComputationClient> client_;
40
+ std::string device_;
41
+ };
42
+
43
+ // Returns a computation to compute x + y where x and y are both F32[2,2]
44
+ // arrays.
45
+ absl::StatusOr<xla::XlaComputation> MakeAddComputation () {
46
+ const xla::Shape input_shape =
29
47
xla::ShapeUtil::MakeShape (xla::PrimitiveType::F32, {2 , 2 });
30
48
xla::XlaBuilder builder (" AddComputation" );
31
49
xla::XlaOp x = xla::Parameter (&builder, 0 , input_shape, " x" );
@@ -34,19 +52,51 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
34
52
return builder.Build ();
35
53
}
36
54
37
- TEST (PjRtComputationClientTest, Init) {
38
- // Get a CPU client.
39
- tsl::setenv (" PJRT_DEVICE" , " CPU" , true );
40
- auto client = std::make_unique<PjRtComputationClient>();
41
- std::string device = client->GetDefaultDevice ();
55
+ TEST_F (PjRtComputationClientTest, ThrowsExpectedExceptionWhenCompileFails) {
56
+ // Compose a computation to add two matrices.
57
+ xla::Shape out_shape (xla::F32, {2 , 2 },
58
+ /* dynamic_dimensions=*/ {});
59
+ std::vector<ComputationClient::CompileInstance> instances;
60
+ instances.push_back (ComputationClient::CompileInstance (
61
+ std::move (MakeAddComputation ().value ()), device_,
62
+ client_->GetCompilationDevices (device_, client_->GetLocalDevices ()),
63
+ &out_shape));
64
+
65
+ // Force XLA to fail with the given error when invoked by Compile() below.
66
+ FakeXlaCompileForTesting (
67
+ client_.get (), [] { return absl::InvalidArgumentError (" invalid arg" ); });
68
+
69
+ // Compiling the graph should fail, which should throw instead of crashing.
70
+ EXPECT_THROW (client_->Compile (std::move (instances)), std::invalid_argument);
71
+ }
72
+
73
+ TEST_F (PjRtComputationClientTest, ThrowsExpectedExceptionWhenCompileThrows) {
74
+ // Compose a computation to add two matrices.
75
+ xla::Shape out_shape (xla::F32, {2 , 2 },
76
+ /* dynamic_dimensions=*/ {});
77
+ std::vector<ComputationClient::CompileInstance> instances;
78
+ instances.push_back (ComputationClient::CompileInstance (
79
+ std::move (MakeAddComputation ().value ()), device_,
80
+ client_->GetCompilationDevices (device_, client_->GetLocalDevices ()),
81
+ &out_shape));
82
+
83
+ // Force XLA to throw with the given error when invoked by Compile() below.
84
+ FakeXlaCompileForTesting (client_.get (), []() -> absl::Status {
85
+ throw absl::BadStatusOrAccess (absl::InvalidArgumentError (" invalid arg" ));
86
+ });
87
+
88
+ // Compiling the graph should fail, which should throw instead of crashing.
89
+ EXPECT_THROW (client_->Compile (std::move (instances)), std::invalid_argument);
90
+ }
42
91
43
- // Compose a computation.
44
- auto shape = xla::ShapeUtil::MakeShape (xla::F32, {2 , 2 });
92
+ TEST_F (PjRtComputationClientTest, Init) {
93
+ // Compose a computation to add two 2x2 matrices.
94
+ auto out_shape = xla::ShapeUtil::MakeShape (xla::F32, {2 , 2 });
45
95
std::vector<ComputationClient::CompileInstance> instances;
46
96
instances.push_back (ComputationClient::CompileInstance (
47
- std::move (MakeComputation ().value ()), device ,
48
- client ->GetCompilationDevices (device, client ->GetLocalDevices ()),
49
- &shape ));
97
+ std::move (MakeAddComputation ().value ()), device_ ,
98
+ client_ ->GetCompilationDevices (device_, client_ ->GetLocalDevices ()),
99
+ &out_shape ));
50
100
51
101
// Prepare inputs.
52
102
xla::Literal literal_x =
@@ -56,22 +106,22 @@ TEST(PjRtComputationClientTest, Init) {
56
106
57
107
// Compile the graph.
58
108
std::vector<ComputationClient::ComputationPtr> computations =
59
- client ->Compile (std::move (instances));
109
+ client_ ->Compile (std::move (instances));
60
110
61
111
// Copy inputs to device.
62
112
ComputationClient::ExecuteComputationOptions options{};
63
113
std::vector<std::shared_ptr<const TensorSource>> args = {
64
- std::make_shared<LiteralSource>(std::move (literal_x), device ),
65
- std::make_shared<LiteralSource>(std::move (literal_y), device )};
114
+ std::make_shared<LiteralSource>(std::move (literal_x), device_ ),
115
+ std::make_shared<LiteralSource>(std::move (literal_y), device_ )};
66
116
67
117
// Execute the graph.
68
- std::vector<ComputationClient::DataPtr> results = client ->ExecuteComputation (
69
- *computations[0 ], client ->TransferToDevice (absl::MakeConstSpan (args)),
70
- device , options);
118
+ std::vector<ComputationClient::DataPtr> results = client_ ->ExecuteComputation (
119
+ *computations[0 ], client_ ->TransferToDevice (absl::MakeConstSpan (args)),
120
+ device_ , options);
71
121
72
- // Copy the output from device back to host and assert correctness..
122
+ // Copy the output from device back to host and assert correctness.
73
123
ASSERT_EQ (results.size (), 1 );
74
- auto result_literals = client ->TransferFromDevice (results);
124
+ auto result_literals = client_ ->TransferFromDevice (results);
75
125
ASSERT_THAT (result_literals, ::testing::SizeIs (1 ));
76
126
EXPECT_TRUE (xla::LiteralTestUtil::Equal (
77
127
xla::LiteralUtil::CreateR2<float >({{6 .0f , 8 .0f }, {10 .0f , 12 .0f }}),
0 commit comments