Skip to content

Commit 1f10433

Browse files
thcmbs authored and
Google-ML-Automation committed
[XLA:GPU] Support proto payload in split gpu executable
PiperOrigin-RevId: 901065155
1 parent 8977857 commit 1f10433

File tree

3 files changed

+92
-8
lines changed

3 files changed

+92
-8
lines changed

xla/util/split_proto/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,7 @@ cc_library(
110110
"//xla:sort_json",
111111
"//xla:xla_proto_cc",
112112
"//xla/service:hlo_proto_cc",
113+
"//xla/service:hlo_proto_util",
113114
"//xla/service/gpu:gpu_executable_proto_cc",
114115
"//xla/tsl/platform:errors",
115116
"@com_google_absl//absl/status",

xla/util/split_proto/split_gpu_executable_writer.cc

Lines changed: 33 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "riegeli/bytes/writer.h"
2727
#include "riegeli/records/record_writer.h"
2828
#include "xla/service/hlo.pb.h"
29+
#include "xla/service/hlo_proto_util.h"
2930
#include "xla/sort_json.h"
3031
#include "xla/tsl/platform/errors.h"
3132
#include "xla/util/split_proto/split_proto.pb.h"
@@ -69,21 +70,46 @@ SplitProtoManifest BuildManifest(int32_t num_of_contants) {
6970
// deterministic because the order of keys in json is not guaranteed.
7071
//
7172
// If the backend config is not a json string, it is not modified.
72-
void NormalizeBackendConfig(gpu::GpuExecutableProto& executable) {
73+
absl::Status NormalizeBackendConfig(gpu::GpuExecutableProto& executable) {
7374
for (HloComputationProto& computation :
7475
*executable.mutable_hlo_module_with_config()
7576
->mutable_hlo_module()
7677
->mutable_computations()) {
7778
for (HloInstructionProto& instruction :
7879
*computation.mutable_instructions()) {
79-
absl::StatusOr<std::string> normalized_backend_config =
80-
SortJson(instruction.backend_config());
81-
if (normalized_backend_config.ok()) {
82-
instruction.set_backend_config(*normalized_backend_config);
80+
TF_ASSIGN_OR_RETURN(
81+
std::string backend_config_str,
82+
GetBackendConfigString(
83+
instruction, &executable.hlo_module_with_config().hlo_module()));
84+
auto normalized_or = SortJson(backend_config_str);
85+
if (!normalized_or.ok()) {
86+
continue;
87+
}
88+
std::string normalized = std::move(normalized_or).value();
89+
if (normalized == backend_config_str) {
90+
continue;
91+
}
92+
if (instruction.has_backend_config_payload()) {
93+
Payload* payload = instruction.mutable_backend_config_payload();
94+
switch (payload->payload_source_case()) {
95+
case Payload::kId: {
96+
int id = static_cast<int>(payload->id());
97+
auto* module = executable.mutable_hlo_module_with_config()
98+
->mutable_hlo_module();
99+
*module->mutable_payloads(id) = normalized;
100+
break;
101+
}
102+
case Payload::kValue:
103+
case Payload::PAYLOAD_SOURCE_NOT_SET:
104+
payload->set_value(normalized);
105+
break;
106+
}
107+
} else {
108+
instruction.set_backend_config(normalized);
83109
}
84-
// If the backend config is not a json string, then do nothing.
85110
}
86111
}
112+
return absl::OkStatus();
87113
}
88114

89115
} // namespace
@@ -115,7 +141,7 @@ absl::Status WriteSplitGpuExecutable(gpu::GpuExecutableProto executable,
115141
executable.clear_constants();
116142

117143
// The rest of the fields (i.e. the non-offloaded fields)
118-
NormalizeBackendConfig(executable);
144+
TF_RETURN_IF_ERROR(NormalizeBackendConfig(executable));
119145
// Module IDs are created via a static counter when deserializing, and they
120146
// can cause non-determinism, so we don't preserve them.
121147
executable.mutable_hlo_module_with_config()->mutable_hlo_module()->clear_id();

xla/util/split_proto/split_gpu_executable_writer_test.cc

Lines changed: 58 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -113,12 +113,69 @@ TEST(SplitGpuExecutableWriterTest, NonJsonBackendConfigIsAccepted) {
113113
->mutable_hlo_module()
114114
->add_computations()
115115
->add_instructions()
116-
->mutable_backend_config() = "not-json";
116+
->mutable_backend_config() = "x-json";
117117

118118
std::string serialized1;
119119
ASSERT_OK(WriteSplitGpuExecutable(
120120
proto1, std::make_unique<riegeli::StringWriter<>>(&serialized1)));
121121
}
122122

123+
TEST(SplitGpuExecutableWriterTest, JsonBackendConfigPayloadIsNormalized) {
124+
GpuExecutableProto proto1;
125+
proto1.mutable_hlo_module_with_config()
126+
->mutable_hlo_module()
127+
->add_computations()
128+
->add_instructions()
129+
->mutable_backend_config_payload()
130+
->set_value(R"json({"a": 1, "b": 2, "c": 3})json");
131+
132+
GpuExecutableProto proto2;
133+
proto2.mutable_hlo_module_with_config()
134+
->mutable_hlo_module()
135+
->add_computations()
136+
->add_instructions()
137+
->mutable_backend_config_payload()
138+
->set_value(R"json({"c": 3, "b": 2, "a": 1})json");
139+
140+
std::string serialized1;
141+
ASSERT_OK(WriteSplitGpuExecutable(
142+
proto1, std::make_unique<riegeli::StringWriter<>>(&serialized1)));
143+
144+
std::string serialized2;
145+
ASSERT_OK(WriteSplitGpuExecutable(
146+
proto2, std::make_unique<riegeli::StringWriter<>>(&serialized2)));
147+
148+
EXPECT_EQ(serialized1, serialized2);
149+
}
150+
151+
TEST(SplitGpuExecutableWriterTest,
152+
JsonBackendConfigExternalPayloadIsNormalized) {
153+
GpuExecutableProto proto1;
154+
auto* module1 = proto1.mutable_hlo_module_with_config()->mutable_hlo_module();
155+
module1->add_payloads(R"json({"a": 1, "b": 2, "c": 3})json");
156+
module1->add_computations()
157+
->add_instructions()
158+
->mutable_backend_config_payload()
159+
->set_id(0);
160+
161+
GpuExecutableProto proto2;
162+
auto* module2 = proto2.mutable_hlo_module_with_config()->mutable_hlo_module();
163+
module2->add_payloads(R"json({"c": 3, "b": 2, "a": 1})json");
164+
module2->add_computations()
165+
->add_instructions()
166+
->mutable_backend_config_payload()
167+
->set_id(0);
168+
169+
std::string serialized1;
170+
ASSERT_OK(WriteSplitGpuExecutable(
171+
proto1, std::make_unique<riegeli::StringWriter<>>(&serialized1)));
172+
173+
std::string serialized2;
174+
ASSERT_OK(WriteSplitGpuExecutable(
175+
proto2, std::make_unique<riegeli::StringWriter<>>(&serialized2)));
176+
177+
EXPECT_EQ(serialized1, serialized2);
178+
}
179+
123180
} // namespace
124181
} // namespace xla

0 commit comments

Comments
 (0)