Skip to content

Commit 08e2bd7

Browse files
ezhulenevGoogle-ML-Automation
authored and committed
PR #40920: [xla] Add a separate pass to propagate metadata across kCall instructions
Imported from GitHub PR #40920

- New `PropagateCallMetadata` HLO pass that propagates metadata (op_name prefix and stack_frame_id) from `kCall` instructions into their called computations, recursing through nested control-flow (while, conditional) but not into embedded computations (reduce's to_apply, etc.)
- Removes metadata propagation from `CallInliner` — the `RecursivelyUpdateMetadata` helper and per-instruction metadata update during inlining are no longer needed since the new pass handles this as a separate concern
- Wired into GPU compiler pre-SPMD pipeline before `CallInliner`, so metadata is propagated while `kCall` ops still exist (including non-inlinable calls)
- Test coverage for op_name propagation, stack frame concatenation, overflow protection, redundant prefix detection, nested calls, and idempotency

#### Motivation

The `CallInliner` only updates metadata for calls it actually inlines. Non-inlinable calls (e.g. calls with inlineable="false") were skipped entirely, leaving their callee instructions with incomplete metadata context. Extracting this into a standalone pass ensures all calls get metadata propagation regardless of inlining decisions.

Copybara import of the project:

-- 3315b79 by Eugene Zhulenev <ezhulenev@openxla.org>:

[xla] Add a separate pass to propagate metadata across kCall instructions

Merging this change closes #40920

FUTURE_COPYBARA_INTEGRATE_REVIEW=#40920 from ezhulenev:propagate-metadata-pass 3315b79
PiperOrigin-RevId: 900707821
1 parent bf6c157 commit 08e2bd7

File tree

8 files changed

+607
-279
lines changed

8 files changed

+607
-279
lines changed

xla/hlo/transforms/BUILD

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,38 @@ xla_cc_test(
522522
],
523523
)
524524

525+
# Library target for the PropagateCallMetadata HLO pass
# (propagate_call_metadata.cc / propagate_call_metadata.h).
cc_library(
526+
name = "propagate_call_metadata",
527+
srcs = ["propagate_call_metadata.cc"],
528+
hdrs = ["propagate_call_metadata.h"],
529+
deps = [
530+
"//xla:xla_data_proto_cc",
531+
"//xla/hlo/ir:hlo",
532+
"//xla/hlo/pass:hlo_pass",
533+
"@com_google_absl//absl/container:flat_hash_set",
534+
"@com_google_absl//absl/status:statusor",
535+
"@com_google_absl//absl/strings",
536+
"@com_google_absl//absl/strings:string_view",
537+
],
538+
)
539+
540+
# Small unit-test target for the :propagate_call_metadata library
# (propagate_call_metadata_test.cc).
xla_cc_test(
541+
name = "propagate_call_metadata_test",
542+
size = "small",
543+
srcs = ["propagate_call_metadata_test.cc"],
544+
deps = [
545+
":propagate_call_metadata",
546+
"//xla:xla_data_proto_cc",
547+
"//xla/hlo/ir:hlo",
548+
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
549+
"//xla/tsl/lib/core:status_test_util",
550+
"//xla/tsl/platform:statusor",
551+
"@com_google_absl//absl/strings:string_view",
552+
"@com_google_googletest//:gtest",
553+
"@com_google_googletest//:gtest_main",
554+
],
555+
)
556+
525557
cc_library(
526558
name = "add_original_value",
527559
srcs = ["add_original_value.cc"],
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/* Copyright 2026 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/hlo/transforms/propagate_call_metadata.h"
17+
18+
#include <algorithm>
19+
#include <string>
20+
#include <utility>
21+
22+
#include "absl/container/flat_hash_set.h"
23+
#include "absl/status/statusor.h"
24+
#include "absl/strings/match.h"
25+
#include "absl/strings/str_cat.h"
26+
#include "absl/strings/string_view.h"
27+
#include "absl/strings/strip.h"
28+
#include "xla/hlo/ir/hlo_computation.h"
29+
#include "xla/hlo/ir/hlo_instruction.h"
30+
#include "xla/hlo/ir/hlo_module.h"
31+
#include "xla/hlo/ir/hlo_module_metadata.h"
32+
#include "xla/hlo/ir/stack_frames.h"
33+
#include "xla/xla_data.pb.h"
34+
35+
namespace xla {
36+
namespace {
37+
38+
// Limit on op_name length to prevent unbounded growth from deeply nested calls.
39+
constexpr int kMaxOpNameSize = 1024;
40+
41+
// Sanitize and prepend the prefix to the instruction's op_name.
42+
bool UpdateOpName(OpMetadata& metadata, absl::string_view prefix) {
43+
if (prefix.empty()) {
44+
return false;
45+
}
46+
// Strip trailing '/' from prefix.
47+
absl::string_view clean_prefix = absl::StripSuffix(prefix, "/");
48+
if (clean_prefix.empty()) {
49+
return false;
50+
}
51+
52+
std::string op_name = metadata.op_name();
53+
// Strip leading/trailing '/' from existing op_name.
54+
absl::string_view clean_name = absl::StripPrefix(op_name, "/");
55+
clean_name = absl::StripSuffix(clean_name, "/");
56+
57+
// Already has the prefix.
58+
if (absl::StartsWith(clean_name, clean_prefix)) {
59+
return false;
60+
}
61+
// op_name is a substring of prefix — already captured.
62+
if (!clean_name.empty() && absl::StrContains(clean_prefix, clean_name)) {
63+
return false;
64+
}
65+
std::string result;
66+
if (clean_name.empty()) {
67+
result = std::string(clean_prefix);
68+
} else {
69+
result = absl::StrCat(clean_prefix, "/", clean_name);
70+
}
71+
// Cap at kMaxOpNameSize to avoid unbounded growth from deeply nested calls.
72+
if (result.size() > kMaxOpNameSize) {
73+
result.resize(kMaxOpNameSize);
74+
}
75+
metadata.set_op_name(std::move(result));
76+
return true;
77+
}
78+
79+
// Update stack frame: concatenate parent_frame_id as ancestor.
80+
bool UpdateStackFrame(HloInstruction* hlo, StackFrameId parent_frame_id) {
81+
if (!parent_frame_id.valid()) {
82+
return false;
83+
}
84+
HloModule* module = hlo->GetModule();
85+
OpMetadata metadata = hlo->metadata();
86+
if (module->stack_frames().IsPrefix(
87+
parent_frame_id, StackFrameId{metadata.stack_frame_id()})) {
88+
return false;
89+
}
90+
metadata.set_stack_frame_id(
91+
module->mutable_stack_frames()
92+
.Concatenate(parent_frame_id, StackFrameId{metadata.stack_frame_id()})
93+
.value);
94+
hlo->set_metadata(metadata);
95+
return true;
96+
}
97+
98+
// Propagate metadata into all instructions in a computation.
99+
// Recurses into control-flow sub-computations (while, conditional) with the
100+
// same prefix. Does NOT recurse into kCall — nested calls are handled by
101+
// the top-level loop which processes computations in reverse post-order.
102+
bool PropagateIntoComputation(HloComputation* computation,
103+
absl::string_view prefix,
104+
StackFrameId parent_frame_id) {
105+
bool changed = false;
106+
for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
107+
OpMetadata metadata = instr->metadata();
108+
if (UpdateOpName(metadata, prefix)) {
109+
instr->set_metadata(metadata);
110+
changed = true;
111+
}
112+
if (UpdateStackFrame(instr, parent_frame_id)) {
113+
changed = true;
114+
}
115+
116+
// Recurse into while/conditional sub-computations with same prefix.
117+
if (GetInstructionCallContext(instr->opcode()) ==
118+
CallContext::kControlFlow &&
119+
instr->opcode() != HloOpcode::kCall) {
120+
for (HloComputation* sub : instr->called_computations()) {
121+
changed |= PropagateIntoComputation(sub, prefix, parent_frame_id);
122+
}
123+
}
124+
}
125+
return changed;
126+
}
127+
128+
} // namespace
129+
130+
// Pushes the metadata of every kCall instruction in `module` (restricted to
// `execution_threads`) into the computations it calls. Returns true if any
// instruction metadata was modified.
absl::StatusOr<bool> PropagateCallMetadata::RunImpl(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // Visit computations in reverse post-order (callers before callees): a
  // nested call instruction must receive its own updated metadata before that
  // metadata is propagated into its callee.
  auto ordered = module->MakeNonfusionComputations(execution_threads);
  std::reverse(ordered.begin(), ordered.end());

  bool any_change = false;
  for (HloComputation* caller : ordered) {
    for (HloInstruction* call : caller->MakeInstructionPostOrder()) {
      if (call->opcode() != HloOpcode::kCall) {
        continue;
      }

      const OpMetadata& call_md = call->metadata();
      const absl::string_view op_name_prefix = call_md.op_name();
      StackFrameId call_frame_id{call_md.stack_frame_id()};

      // A call carrying neither an op_name nor a valid frame has nothing to
      // hand down.
      const bool has_context =
          !op_name_prefix.empty() || call_frame_id.valid();
      if (!has_context) {
        continue;
      }

      for (HloComputation* callee : call->called_computations()) {
        if (PropagateIntoComputation(callee, op_name_prefix, call_frame_id)) {
          any_change = true;
        }
      }
    }
  }

  return any_change;
}
160+
161+
} // namespace xla
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2026 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_HLO_TRANSFORMS_PROPAGATE_CALL_METADATA_H_
17+
#define XLA_HLO_TRANSFORMS_PROPAGATE_CALL_METADATA_H_
18+
19+
#include "absl/container/flat_hash_set.h"
20+
#include "absl/status/statusor.h"
21+
#include "absl/strings/string_view.h"
22+
#include "xla/hlo/ir/hlo_module.h"
23+
#include "xla/hlo/pass/hlo_pass_interface.h"
24+
25+
namespace xla {
26+
27+
// Propagates metadata (op_name prefix and stack_frame_id) from kCall
28+
// instructions into their called computations, recursing through nested
29+
// calls and control-flow structures.
30+
//
31+
// This pass should run before call inlining, while kCall ops still exist.
32+
class PropagateCallMetadata : public HloModulePass {
33+
public:
34+
PropagateCallMetadata() = default;
35+
~PropagateCallMetadata() override = default;
36+
37+
absl::string_view name() const override { return "propagate-call-metadata"; }
38+
39+
protected:
40+
absl::StatusOr<bool> RunImpl(
41+
HloModule* module,
42+
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
43+
};
44+
45+
} // namespace xla
46+
47+
#endif // XLA_HLO_TRANSFORMS_PROPAGATE_CALL_METADATA_H_

0 commit comments

Comments
 (0)