@@ -26,6 +26,7 @@ limitations under the License.
2626#include " riegeli/bytes/writer.h"
2727#include " riegeli/records/record_writer.h"
2828#include " xla/service/hlo.pb.h"
29+ #include " xla/service/hlo_proto_util.h"
2930#include " xla/sort_json.h"
3031#include " xla/tsl/platform/errors.h"
3132#include " xla/util/split_proto/split_proto.pb.h"
@@ -69,21 +70,46 @@ SplitProtoManifest BuildManifest(int32_t num_of_contants) {
6970// deterministic because the order of keys in json is not guaranteed.
7071//
7172// If the backend config is not a json string, it is not modified.
72- void NormalizeBackendConfig (gpu::GpuExecutableProto& executable) {
73+ absl::Status NormalizeBackendConfig (gpu::GpuExecutableProto& executable) {
7374 for (HloComputationProto& computation :
7475 *executable.mutable_hlo_module_with_config ()
7576 ->mutable_hlo_module ()
7677 ->mutable_computations ()) {
7778 for (HloInstructionProto& instruction :
7879 *computation.mutable_instructions ()) {
79- absl::StatusOr<std::string> normalized_backend_config =
80- SortJson (instruction.backend_config ());
81- if (normalized_backend_config.ok ()) {
82- instruction.set_backend_config (*normalized_backend_config);
80+ TF_ASSIGN_OR_RETURN (
81+ std::string backend_config_str,
82+ GetBackendConfigString (
83+ instruction, &executable.hlo_module_with_config ().hlo_module ()));
84+ auto normalized_or = SortJson (backend_config_str);
85+ if (!normalized_or.ok ()) {
86+ continue ;
87+ }
88+ std::string normalized = std::move (normalized_or).value ();
89+ if (normalized == backend_config_str) {
90+ continue ;
91+ }
92+ if (instruction.has_backend_config_payload ()) {
93+ Payload* payload = instruction.mutable_backend_config_payload ();
94+ switch (payload->payload_source_case ()) {
95+ case Payload::kId : {
96+ int id = static_cast <int >(payload->id ());
97+ auto * module = executable.mutable_hlo_module_with_config ()
98+ ->mutable_hlo_module ();
99+ *module ->mutable_payloads (id) = normalized;
100+ break ;
101+ }
102+ case Payload::kValue :
103+ case Payload::PAYLOAD_SOURCE_NOT_SET:
104+ payload->set_value (normalized);
105+ break ;
106+ }
107+ } else {
108+ instruction.set_backend_config (normalized);
83109 }
84- // If the backend config is not a json string, then do nothing.
85110 }
86111 }
112+ return absl::OkStatus ();
87113}
88114
89115} // namespace
@@ -115,7 +141,7 @@ absl::Status WriteSplitGpuExecutable(gpu::GpuExecutableProto executable,
115141 executable.clear_constants ();
116142
117143 // The rest of the fields (i.e. the non-offloaded fields)
118- NormalizeBackendConfig (executable);
144+ TF_RETURN_IF_ERROR ( NormalizeBackendConfig (executable) );
119145 // Module IDs are created via a static counter when deserializing, and they
120146 // can cause non-determinism, so we don't preserve them.
121147 executable.mutable_hlo_module_with_config ()->mutable_hlo_module ()->clear_id ();
0 commit comments