Commit d1ca3c4

In pjrt runtime client, raise a Python exception if XLA compilation fails. (#9138)
Co-authored-by: Zhanyong Wan <[email protected]>
1 parent 2196d0c commit d1ca3c4
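
The fix routes compile errors through pybind11's built-in exception translation: a `std::invalid_argument` thrown from C++ surfaces in Python as a `ValueError`. A minimal, self-contained sketch of that translation follows; the `demo` module and `compile_or_raise` function are hypothetical and exist only to illustrate the mechanism, not part of this commit.

// Minimal pybind11 sketch; `demo` and `compile_or_raise` are hypothetical.
#include <stdexcept>

#include <pybind11/pybind11.h>

int compile_or_raise(bool fail) {
  if (fail) {
    // pybind11 translates std::invalid_argument into Python's ValueError.
    throw std::invalid_argument("XLA compilation failed");
  }
  return 0;
}

PYBIND11_MODULE(demo, m) {
  m.def("compile_or_raise", &compile_or_raise, pybind11::arg("fail"));
}

From Python, `demo.compile_or_raise(fail=True)` raises `ValueError: XLA compilation failed` instead of crashing the process.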

6 files changed, +156 -48 lines

torch_xla/csrc/BUILD
Lines changed: 1 addition & 0 deletions

@@ -126,6 +126,7 @@ ptxla_cc_library(
         ":shape_builder",
         ":shape_helper",
         ":version",
+        "//torch_xla/csrc:thread_pool",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:stablehlo_helper",
         "//torch_xla/csrc/runtime:xla_util",

torch_xla/csrc/runtime/BUILD
Lines changed: 2 additions & 8 deletions

@@ -114,12 +114,10 @@ cc_library(
         ":env_vars",
         ":operation_manager",
         ":pjrt_registry",
-        ":profiler",
         ":stablehlo_helper",
         ":tensor_source",
         ":tf_logging",
         ":xla_coordinator",
-        "//torch_xla/csrc:thread_pool",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
@@ -485,15 +483,10 @@ ptxla_cc_test(
     deps = [
         ":computation_client",
         ":pjrt_computation_client",
+        ":operation_manager",
         ":tensor_source",
         "@com_google_absl//absl/status",
-        "@xla//xla/tsl/lib/core:status_test_util",
-        "@tsl//tsl/platform:env",
-        "@tsl//tsl/platform:errors",
-        "@tsl//tsl/platform:logging",
-        "@tsl//tsl/platform:test",
         "@tsl//tsl/platform:test_main",
-        "@tsl//tsl/platform:statusor",
         "@xla//xla:literal",
         "@xla//xla:literal_util",
         "@xla//xla:shape_util",
@@ -502,6 +495,7 @@ ptxla_cc_test(
         "@xla//xla/tests:literal_test_util",
         "@xla//xla/tools:hlo_module_loader",
     ],
+    timeout = "short",
 )

 # ptxla_cc_test(

torch_xla/csrc/runtime/pjrt_computation_client.cpp
Lines changed: 16 additions & 14 deletions

@@ -1,8 +1,7 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"

 #include <algorithm>
-#include <future>
-#include <unordered_set>
+#include <stdexcept>
 #include <vector>

 #include "absl/status/status.h"
@@ -13,34 +12,30 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/env_hash.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
-#include "torch_xla/csrc/runtime/operation_manager.h"
 #include "torch_xla/csrc/runtime/pjrt_registry.h"
-#include "torch_xla/csrc/runtime/profiler.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/runtime/util.h"
 #include "torch_xla/csrc/runtime/xla_coordinator.h"
-#include "torch_xla/csrc/thread_pool.h"
 #include "tsl/profiler/lib/traceme.h"
 #include "xla/hlo/builder/xla_builder.h"
 #include "xla/hlo/builder/xla_computation.h"
-#include "xla/layout_util.h"
 #include "xla/literal.h"
 #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
 #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
 #include "xla/pjrt/pjrt_api.h"
 #include "xla/pjrt/pjrt_c_api_client.h"
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_executable.h"
-#include "xla/protobuf_util.h"
 #include "xla/service/custom_call_target_registry.h"
 #include "xla/shape.h"

-using xla::internal::XlaBuilderFriend;
-
 namespace torch_xla {
 namespace runtime {

+using xla::internal::XlaBuilderFriend;
+
 namespace {

 // Builds a map from the device's global ordinal to its index in the `devices`
@@ -625,20 +620,27 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
                                       device_assignment);
   }

+  // Compile the computation to an executable. For a better user experience,
+  // if the XLA compiler fails for any reason, we raise a Python exception.
   std::unique_ptr<xla::PjRtLoadedExecutable> executable;
   if (runtime::sys_util::GetEnvBool("XLA_STABLEHLO_COMPILE", false)) {
     // Convert HLO to StableHLO for PjRt client compilation.
     mlir::MLIRContext context;
     mlir::ModuleOp mlir_module =
         mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
     ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module);
-    executable =
-        client_->CompileAndLoad(mlir_module, compile_options).value();
+    executable = util::RaisePythonValueErrorOnFailure([&] {
+      return fake_xla_compile_
+                 ? fake_xla_compile_()
+                 : client_->CompileAndLoad(mlir_module, compile_options);
+    });
     StableHloCompileCounter()->AddValue(1);
   } else {
-    executable =
-        client_->CompileAndLoad(instance.computation, compile_options)
-            .value();
+    executable = util::RaisePythonValueErrorOnFailure([&] {
+      return fake_xla_compile_ ? fake_xla_compile_()
+                               : client_->CompileAndLoad(instance.computation,
+                                                         compile_options);
+    });
   }

   auto memory_stats_status_or = executable->GetCompiledMemoryStats();
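
Two details make the new Compile() path work: the ternary type-checks because a bare error `absl::Status` (returned by the test hook `fake_xla_compile_`) converts implicitly to the `absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>>` returned by `CompileAndLoad`, and `util::RaisePythonValueErrorOnFailure` then either unwraps the value or throws. A standalone sketch of the same pattern, with hypothetical stand-ins (`Executable`, `RealCompile`) for the PjRt types:

// Sketch of the fault-injection branch; `Executable` and `RealCompile`
// are hypothetical stand-ins, not the real PjRt API.
#include <functional>
#include <iostream>
#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct Executable {};

// Stand-in for client_->CompileAndLoad(...).
absl::StatusOr<std::unique_ptr<Executable>> RealCompile() {
  return std::make_unique<Executable>();
}

int main() {
  // When the hook is set, its error Status converts implicitly to the
  // StatusOr type of the other branch, so the ternary type-checks.
  std::function<absl::Status()> fake_compile = [] {
    return absl::InvalidArgumentError("injected failure");
  };

  absl::StatusOr<std::unique_ptr<Executable>> result =
      fake_compile ? fake_compile() : RealCompile();
  std::cout << result.status() << "\n";  // INVALID_ARGUMENT: injected failure
}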

torch_xla/csrc/runtime/pjrt_computation_client.h
Lines changed: 13 additions & 1 deletion

@@ -27,7 +27,7 @@ namespace runtime {
 class PjRtComputationClient : public ComputationClient {
  public:
   PjRtComputationClient();
-  ~PjRtComputationClient();
+  ~PjRtComputationClient() override;

   DataPtr CreateDataPlaceholder(
       std::string device, xla::Shape shape,
@@ -162,6 +162,14 @@ class PjRtComputationClient : public ComputationClient {
       const std::function<void()>& callback) override;

  private:
+  friend class PjRtComputationClientTest;
+
+  // If `function` is not nullptr, makes the client call it instead of the
+  // real XLA compiler when compiling. Used for injecting faults in tests.
+  void FakeXlaCompileForTesting(std::function<absl::Status()> function) {
+    fake_xla_compile_ = std::move(function);
+  }
+
   std::unique_ptr<xla::PjRtClient> client_;
   std::unique_ptr<XlaCoordinator> coordinator_;
   // global_ordinals_ tracks a map from PjRtDeviceId to the device's
@@ -174,6 +182,10 @@ class PjRtComputationClient : public ComputationClient {
       tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
   torch::lazy::hash_t comp_env_hash_;

+  // If not nullptr, invoked instead of the actual XLA compilation. Used
+  // only for testing.
+  std::function<absl::Status()> fake_xla_compile_ = nullptr;
+
   xla::PjRtDevice* StringToPjRtDevice(const std::string& device);

   struct PjRtData : public Data {
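
A note on the design: `FakeXlaCompileForTesting` is private, and tests reach it only through the `friend class PjRtComputationClientTest` declaration, so the fault-injection hook never becomes part of the client's public API. Checking `fake_xla_compile_` at the call site, rather than swapping out `client_` entirely, keeps the production code path identical except for that one branch.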

torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
Lines changed: 75 additions & 25 deletions

@@ -2,30 +2,48 @@

 #include <gtest/gtest.h>

+#include <functional>
 #include <memory>
+#include <stdexcept>
 #include <string>
 #include <vector>

 #include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
 #include "xla/hlo/builder/xla_builder.h"
 #include "xla/hlo/builder/xla_computation.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/tests/literal_test_util.h"
-#include "xla/tsl/lib/core/status_test_util.h"

 namespace torch_xla {
 namespace runtime {

-absl::StatusOr<xla::XlaComputation> MakeComputation() {
-  xla::Shape input_shape =
+class PjRtComputationClientTest : public ::testing::Test {
+ protected:
+  PjRtComputationClientTest() {
+    // Get a CPU client.
+    tsl::setenv("PJRT_DEVICE", "CPU", true);
+    client_ = std::make_unique<PjRtComputationClient>();
+    device_ = client_->GetDefaultDevice();
+  }
+
+  static void FakeXlaCompileForTesting(
+      PjRtComputationClient* client,
+      std::function<absl::Status()> fake_compile) {
+    client->FakeXlaCompileForTesting(std::move(fake_compile));
+  }
+
+  std::unique_ptr<PjRtComputationClient> client_;
+  std::string device_;
+};
+
+// Returns a computation to compute x + y where x and y are both F32[2,2]
+// arrays.
+absl::StatusOr<xla::XlaComputation> MakeAddComputation() {
+  const xla::Shape input_shape =
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
   xla::XlaBuilder builder("AddComputation");
   xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x");
@@ -34,19 +52,51 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
   return builder.Build();
 }

-TEST(PjRtComputationClientTest, Init) {
-  // Get a CPU client.
-  tsl::setenv("PJRT_DEVICE", "CPU", true);
-  auto client = std::make_unique<PjRtComputationClient>();
-  std::string device = client->GetDefaultDevice();
+TEST_F(PjRtComputationClientTest, ThrowsExpectedExceptionWhenCompileFails) {
+  // Compose a computation to add two matrices.
+  xla::Shape out_shape(xla::F32, {2, 2},
+                       /*dynamic_dimensions=*/{});
+  std::vector<ComputationClient::CompileInstance> instances;
+  instances.push_back(ComputationClient::CompileInstance(
+      std::move(MakeAddComputation().value()), device_,
+      client_->GetCompilationDevices(device_, client_->GetLocalDevices()),
+      &out_shape));
+
+  // Force XLA to fail with the given error when invoked by Compile() below.
+  FakeXlaCompileForTesting(
+      client_.get(), [] { return absl::InvalidArgumentError("invalid arg"); });
+
+  // Compiling the graph should fail, which should throw instead of crashing.
+  EXPECT_THROW(client_->Compile(std::move(instances)), std::invalid_argument);
+}
+
+TEST_F(PjRtComputationClientTest, ThrowsExpectedExceptionWhenCompileThrows) {
+  // Compose a computation to add two matrices.
+  xla::Shape out_shape(xla::F32, {2, 2},
+                       /*dynamic_dimensions=*/{});
+  std::vector<ComputationClient::CompileInstance> instances;
+  instances.push_back(ComputationClient::CompileInstance(
+      std::move(MakeAddComputation().value()), device_,
+      client_->GetCompilationDevices(device_, client_->GetLocalDevices()),
+      &out_shape));
+
+  // Force XLA to throw with the given error when invoked by Compile() below.
+  FakeXlaCompileForTesting(client_.get(), []() -> absl::Status {
+    throw absl::BadStatusOrAccess(absl::InvalidArgumentError("invalid arg"));
+  });
+
+  // Compiling the graph should fail, which should throw instead of crashing.
+  EXPECT_THROW(client_->Compile(std::move(instances)), std::invalid_argument);
+}

-  // Compose a computation.
-  auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
+TEST_F(PjRtComputationClientTest, Init) {
+  // Compose a computation to add two 2x2 matrices.
+  auto out_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
   std::vector<ComputationClient::CompileInstance> instances;
   instances.push_back(ComputationClient::CompileInstance(
-      std::move(MakeComputation().value()), device,
-      client->GetCompilationDevices(device, client->GetLocalDevices()),
-      &shape));
+      std::move(MakeAddComputation().value()), device_,
+      client_->GetCompilationDevices(device_, client_->GetLocalDevices()),
+      &out_shape));

   // Prepare inputs.
   xla::Literal literal_x =
@@ -56,22 +106,22 @@ TEST(PjRtComputationClientTest, Init) {

   // Compile the graph.
   std::vector<ComputationClient::ComputationPtr> computations =
-      client->Compile(std::move(instances));
+      client_->Compile(std::move(instances));

   // Copy inputs to device.
   ComputationClient::ExecuteComputationOptions options{};
   std::vector<std::shared_ptr<const TensorSource>> args = {
-      std::make_shared<LiteralSource>(std::move(literal_x), device),
-      std::make_shared<LiteralSource>(std::move(literal_y), device)};
+      std::make_shared<LiteralSource>(std::move(literal_x), device_),
+      std::make_shared<LiteralSource>(std::move(literal_y), device_)};

   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results = client->ExecuteComputation(
-      *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
-      device, options);
+  std::vector<ComputationClient::DataPtr> results = client_->ExecuteComputation(
+      *computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
+      device_, options);

-  // Copy the output from device back to host and assert correctness..
+  // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = client->TransferFromDevice(results);
+  auto result_literals = client_->TransferFromDevice(results);
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),
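
The second test passes because `absl::BadStatusOrAccess` derives from `std::exception`, so `RaisePythonValueErrorOnFailure` (see util.h below) catches it in its `catch (const std::exception&)` clause and rethrows it as `std::invalid_argument`, which is exactly what `EXPECT_THROW` asserts.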

torch_xla/csrc/runtime/util.h
Lines changed: 49 additions & 0 deletions

@@ -9,10 +9,12 @@
 #include <memory>
 #include <numeric>
 #include <set>
+#include <stdexcept>
 #include <string>
 #include <type_traits>
 #include <vector>

+#include "absl/status/statusor.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/runtime/types.h"
@@ -128,6 +130,53 @@ T Multiply(const S& input) {
                          std::multiplies<T>());
 }

+namespace internal {
+
+// ExtractStatusOrValue<U>::type is T if U is absl::StatusOr<T>, and is
+// undefined otherwise.
+template <typename U>
+struct ExtractStatusOrValue;
+template <typename T>
+struct ExtractStatusOrValue<absl::StatusOr<T>> {
+  using type = T;
+};
+
+}  // namespace internal
+
+// RaisePythonValueErrorOnFailure(func) requires `func` to be a functor that
+// takes no argument and returns an absl::StatusOr<T>. It's a wrapper of
+// `func()` that translates any failure in `func()` to a Python ValueError
+// exception. In particular:
+//
+// - if `func()` returns an error, throws an std::invalid_argument, which is
+//   translated to a Python ValueError exception
+//   (https://pybind11.readthedocs.io/en/stable/advanced/exceptions.html);
+// - if `func()` throws any exception, rethrows it as an
+//   std::invalid_argument so that we get a Python ValueError;
+// - if `func()` successfully returns a value of type T, returns the value;
+// - however, if `func()` crashes (e.g. due to a CHECK), we cannot
+//   catch it; therefore we should ensure that `func()` never
+//   crashes (and fix any crash as a bug).
+template <typename Func>
+typename internal::ExtractStatusOrValue<decltype(std::declval<Func>()())>::type
+RaisePythonValueErrorOnFailure(const Func& func) {
+  decltype(std::declval<Func>()()) result;
+  try {
+    result = func();
+  } catch (const std::exception& e) {
+    throw std::invalid_argument(e.what());
+  } catch (...) {
+    throw std::invalid_argument(
+        "Function threw an unknown exception. Please file a bug at "
+        "https://github.com/pytorch/xla/issues with details on how to "
+        "reproduce the error.");
+  }
+  if (result.ok()) {
+    return *std::move(result);
+  }
+  throw std::invalid_argument(std::string(result.status().message()));
+}
+
 }  // namespace util
 }  // namespace runtime
 }  // namespace torch_xla
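
A minimal usage sketch of the new helper, assuming a hypothetical `MightFail` callee (the real call sites are in pjrt_computation_client.cpp above):

// Standalone sketch; MightFail is hypothetical, used only for illustration.
#include <iostream>
#include <stdexcept>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "torch_xla/csrc/runtime/util.h"

absl::StatusOr<std::string> MightFail(bool fail) {
  if (fail) return absl::InvalidArgumentError("bad HLO");
  return std::string("ok");
}

int main() {
  namespace util = torch_xla::runtime::util;

  // Success: the StatusOr is unwrapped and the contained value returned.
  std::string ok =
      util::RaisePythonValueErrorOnFailure([] { return MightFail(false); });
  std::cout << ok << "\n";  // prints "ok"

  // Failure: the error Status is rethrown as std::invalid_argument, which
  // pybind11 later surfaces to Python as a ValueError.
  try {
    util::RaisePythonValueErrorOnFailure([] { return MightFail(true); });
  } catch (const std::invalid_argument& e) {
    std::cout << e.what() << "\n";  // prints "bad HLO"
  }
}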
