Skip to content

Commit 75960a9

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
PR #37903: [xla:gpu] Use Command::Walk APIs to collect buffer uses and command properties
Imported from GitHub PR #37903 In preparation for `Command` and `Thunk` unification make sure that `buffer_uses` has the same semantics, and use command walking API to collect buffer and resource uses for all nested commands. Copybara import of the project: -- de2b073 by Eugene Zhulenev <ezhulenev@openxla.org>: [xla:gpu] Use Command::Walk APIs to collect buffer uses and command properties Merging this change closes #37903 FUTURE_COPYBARA_INTEGRATE_REVIEW=#37903 from ezhulenev:command-walking-0 de2b073 PiperOrigin-RevId: 874228029
1 parent 06f2977 commit 75960a9

File tree

5 files changed

+108
-147
lines changed

5 files changed

+108
-147
lines changed

xla/backends/gpu/runtime/command.h

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,14 @@ bool IsCollectiveCommand(CommandType type);
127127
class Command {
128128
public:
129129
using BufferUses = Thunk::BufferUses;
130-
using ResourceUseVector = absl::InlinedVector<ResourceUse, 1>;
130+
using ResourceUses = absl::InlinedVector<ResourceUse, 1>;
131131

132132
public:
133133
explicit Command(CommandType cmd_type,
134134
se::StreamPriority priority = se::StreamPriority::Default)
135135
: cmd_type_(cmd_type), priority_(priority) {
136136
token_ = Resource::Create(Resource::kToken);
137-
resources_.push_back(ResourceUse::Write(token_));
137+
resource_uses_.push_back(ResourceUse::Write(token_));
138138
}
139139

140140
virtual ~Command() = default;
@@ -213,29 +213,34 @@ class Command {
213213
// they got lucky and got the same buffer allocations), it will lead to
214214
// deadlocks. By forcing the command update at thunk initialization time, we
215215
// ensure that all ranks execute NCCL command update.
216-
virtual bool requires_initialization() { return false; }
216+
virtual bool requires_initialization() const { return false; }
217217

218218
// Returns true if command supports loop unroll, the while loop can be
219219
// unrolled only if it has pre-known trip count and also all commands from the
220220
// body commands are unrollable.
221-
virtual bool support_loop_unroll() { return true; }
221+
virtual bool support_loop_unroll() const { return true; }
222222

223223
// This is only true for DynamicSliceCopyFusionCmd when offset is dependents
224224
// on loop iteration. As the command of slice operation is access the sliced
225225
// memory region that varies across loop iterations, so even the original
226226
// buffer allocation is the same, it still requires to do update.
227-
virtual bool force_update() { return false; }
227+
virtual bool force_update() const { return false; }
228228

229-
// Returns all buffers used by the cmd. These will be used to track cmd
230-
// updates, thus they need to be consistent across calls to the function.
229+
// Returns buffers used by this command. Buffer uses do not include buffers
230+
// that might be used by nested commands, they must be collected separately
231+
// by walking the nested commands using `Walk` API.
231232
virtual BufferUses buffer_uses() const { return {}; }
232233

233234
std::shared_ptr<Resource> token() const { return token_; }
234235

235236
void add_resource_use(ResourceUse resource_use) {
236-
resources_.push_back(resource_use);
237+
resource_uses_.push_back(resource_use);
237238
}
238-
ResourceUseVector resources() const { return resources_; }
239+
240+
// Returns resource used by this command. Resource uses do not include
241+
// resources that might be used by nested commands, they must be collected
242+
// separately by walking the nested commands using `Walk` API.
243+
ResourceUses resource_uses() const { return resource_uses_; }
239244

240245
// Returns true if command implemented as a nested command buffer.
241246
virtual bool IsNestedCommandBuffer() const { return false; }
@@ -275,7 +280,7 @@ class Command {
275280
std::string profile_annotation_;
276281
CommandType cmd_type_;
277282

278-
ResourceUseVector resources_;
283+
ResourceUses resource_uses_;
279284

280285
// The token resource is used to specify additional dependency across
281286
// commands, like control dependency across HLO operators, and LHS scheduling
@@ -378,6 +383,18 @@ class CommandSequence : public std::vector<std::unique_ptr<Command>> {
378383
}
379384
return absl::OkStatus();
380385
}
386+
387+
void Walk(absl::FunctionRef<void(const Command*)> callback) const {
388+
for (const std::unique_ptr<Command>& cmd : *this) {
389+
cmd->Walk(callback);
390+
}
391+
}
392+
393+
void Walk(absl::FunctionRef<void(Command*)> callback) {
394+
for (std::unique_ptr<Command>& cmd : *this) {
395+
cmd->Walk(callback);
396+
}
397+
}
381398
};
382399

383400
} // namespace xla::gpu

xla/backends/gpu/runtime/command_buffer_cmd.cc

Lines changed: 17 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ static absl::string_view ReductionKindString(ReductionKind kind) {
152152

153153
// Create a callback to create a command buffer from a command sequence.
154154
static se::CommandBuffer::CreateCommands CreateCommands(
155-
const CommandBufferCmdExecutor* commands,
156-
const Thunk::ExecuteParams* execute_params,
155+
const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params,
157156
const Command::RecordParams* record_params) {
158157
return [=](se::CommandBuffer* command_buffer,
159158
absl::Span<const se::CommandBuffer::Command* const> dependencies) {
@@ -164,11 +163,11 @@ static se::CommandBuffer::CreateCommands CreateCommands(
164163

165164
// Create callbacks to create a command buffer from command sequences.
166165
static std::vector<se::CommandBuffer::CreateCommands> CreateCommands(
167-
absl::Span<const CommandBufferCmdExecutor> commands,
166+
absl::Span<const CommandExecutor> commands,
168167
const Thunk::ExecuteParams* execute_params,
169168
const Command::RecordParams* record_params) {
170169
std::vector<se::CommandBuffer::CreateCommands> create_commands;
171-
for (const CommandBufferCmdExecutor& cmd : commands) {
170+
for (const CommandExecutor& cmd : commands) {
172171
create_commands.push_back(
173172
CreateCommands(&cmd, execute_params, record_params));
174173
}
@@ -177,8 +176,7 @@ static std::vector<se::CommandBuffer::CreateCommands> CreateCommands(
177176

178177
// Create a callback to update a command buffer with command sequence.
179178
static se::CommandBuffer::UpdateCommands UpdateCommands(
180-
const CommandBufferCmdExecutor* commands,
181-
const Thunk::ExecuteParams* execute_params,
179+
const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params,
182180
const Command::RecordParams* record_params) {
183181
return [=](se::CommandBuffer* command_buffer) {
184182
return commands->RecordUpdate(*execute_params, *record_params,
@@ -188,11 +186,11 @@ static se::CommandBuffer::UpdateCommands UpdateCommands(
188186

189187
// Create callbacks to update a command buffer with command sequence.
190188
static std::vector<se::CommandBuffer::UpdateCommands> UpdateCommands(
191-
absl::Span<const CommandBufferCmdExecutor> commands,
189+
absl::Span<const CommandExecutor> commands,
192190
const Thunk::ExecuteParams* execute_params,
193191
const Command::RecordParams* record_params) {
194192
std::vector<se::CommandBuffer::UpdateCommands> update_commands;
195-
for (const CommandBufferCmdExecutor& cmd : commands) {
193+
for (const CommandExecutor& cmd : commands) {
196194
update_commands.push_back(
197195
UpdateCommands(&cmd, execute_params, record_params));
198196
}
@@ -740,21 +738,10 @@ Command::BufferUses Memset32Cmd::buffer_uses() const {
740738
// ChildCmd
741739
//===----------------------------------------------------------------------===//
742740

743-
ChildCmd::ChildCmd(CommandBufferCmdExecutor child_commands)
741+
ChildCmd::ChildCmd(CommandExecutor child_commands)
744742
: Command(CommandType::kChildCmd),
745743
child_commands_(std::move(child_commands)) {}
746744

747-
bool ChildCmd::requires_initialization() {
748-
return child_commands_.requires_initialization();
749-
}
750-
751-
bool ChildCmd::force_update() { return child_commands_.force_update(); }
752-
753-
Command::BufferUses ChildCmd::buffer_uses() const {
754-
return {child_commands_.buffer_uses().begin(),
755-
child_commands_.buffer_uses().end()};
756-
}
757-
758745
absl::Status ChildCmd::Initialize(const Thunk::InitializeParams& params) {
759746
TF_RETURN_IF_ERROR(child_commands_.Initialize(params));
760747
return absl::OkStatus();
@@ -797,8 +784,7 @@ absl::Status ChildCmd::WalkNested(
797784
// CaseCmd
798785
//===----------------------------------------------------------------------===//
799786

800-
CaseCmd::CaseCmd(ShapedSlice index,
801-
std::vector<CommandBufferCmdExecutor> branches)
787+
CaseCmd::CaseCmd(ShapedSlice index, std::vector<CommandExecutor> branches)
802788
: Command(CommandType::kCaseCmd),
803789
index_(index),
804790
index_is_bool_(index.shape.element_type() == PRED),
@@ -847,23 +833,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CaseCmd::Record(
847833
});
848834
}
849835

850-
bool CaseCmd::requires_initialization() {
851-
return absl::c_any_of(
852-
branches_, [](const auto& seq) { return seq.requires_initialization(); });
853-
}
854-
855-
bool CaseCmd::force_update() {
856-
return absl::c_any_of(branches_,
857-
[](const auto& seq) { return seq.force_update(); });
858-
}
859-
860836
Command::BufferUses CaseCmd::buffer_uses() const {
861-
absl::flat_hash_set<BufferUse> buffers;
862-
buffers.emplace(BufferUse::Read(index_.slice, index_.shape));
863-
for (auto& branch : branches_) {
864-
buffers.insert(branch.buffer_uses().begin(), branch.buffer_uses().end());
865-
}
866-
return {buffers.begin(), buffers.end()};
837+
return {BufferUse::Read(index_.slice, index_.shape)};
867838
}
868839

869840
absl::Status CaseCmd::WalkNested(
@@ -878,9 +849,8 @@ absl::Status CaseCmd::WalkNested(
878849
// WhileCmd
879850
//===----------------------------------------------------------------------===//
880851

881-
WhileCmd::WhileCmd(BufferAllocation::Slice pred,
882-
CommandBufferCmdExecutor cond_commands,
883-
CommandBufferCmdExecutor body_commands,
852+
WhileCmd::WhileCmd(BufferAllocation::Slice pred, CommandExecutor cond_commands,
853+
CommandExecutor body_commands,
884854
std::optional<int64_t> trip_count, bool enable_loop_unroll)
885855
: Command(CommandType::kWhileCmd),
886856
pred_(pred),
@@ -1005,23 +975,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> WhileCmd::Record(
1005975
});
1006976
}
1007977

1008-
bool WhileCmd::requires_initialization() {
1009-
return (cond_commands_.requires_initialization() ||
1010-
body_commands_.requires_initialization());
1011-
}
1012-
1013-
bool WhileCmd::force_update() {
1014-
return cond_commands_.force_update() || body_commands_.force_update();
1015-
}
1016-
1017978
Command::BufferUses WhileCmd::buffer_uses() const {
1018-
absl::flat_hash_set<BufferUse> buffers;
1019-
buffers.emplace(BufferUse::Read(pred_, ShapeUtil::MakeShape(PRED, {})));
1020-
buffers.insert(cond_commands_.buffer_uses().begin(),
1021-
cond_commands_.buffer_uses().end());
1022-
buffers.insert(body_commands_.buffer_uses().begin(),
1023-
body_commands_.buffer_uses().end());
1024-
return {buffers.begin(), buffers.end()};
979+
return {BufferUse::Read(pred_, ShapeUtil::MakeShape(PRED, {}))};
1025980
}
1026981

1027982
absl::Status WhileCmd::WalkNested(
@@ -1086,16 +1041,16 @@ absl::StatusOr<const se::CommandBuffer::Command*> GemmCmd::Record(
10861041
}
10871042

10881043
Command::BufferUses GemmCmd::buffer_uses() const {
1089-
Command::BufferUses res{
1044+
Command::BufferUses buffer_uses = {
10901045
BufferUse::Read(lhs_buffer_, config_.lhs_layout.ToShape()),
10911046
BufferUse::Read(rhs_buffer_, config_.rhs_layout.ToShape()),
10921047
BufferUse::Write(output_buffer_, config_.output_layout.ToShape()),
10931048
};
10941049
if (workspace_.has_value()) {
1095-
res.push_back(BufferUse::Write(
1050+
buffer_uses.push_back(BufferUse::Write(
10961051
*workspace_, ShapeUtil::MakeShape(S8, {workspace_->size()})));
10971052
}
1098-
return res;
1053+
return buffer_uses;
10991054
}
11001055

11011056
//===----------------------------------------------------------------------===//
@@ -2156,7 +2111,7 @@ Command::BufferUses CollectivePermuteCmd::buffer_uses() const {
21562111
//===----------------------------------------------------------------------===//
21572112

21582113
DynamicSliceFusionCmd::DynamicSliceFusionCmd(
2159-
CommandBufferCmdExecutor embedded_commands,
2114+
CommandExecutor embedded_commands,
21602115
std::vector<std::optional<BufferAllocation::Slice>> arguments,
21612116
std::vector<BufferAllocation> fake_allocations,
21622117
std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>> offsets,
@@ -2203,7 +2158,7 @@ DynamicSliceFusionCmd::DynamicSliceFusionCmd(
22032158
// because the memory address might changed if the offset is loop
22042159
// iterator or operator outputs even if the parent command's memory pointers
22052160
// do not change.
2206-
bool DynamicSliceFusionCmd::requires_initialization() {
2161+
bool DynamicSliceFusionCmd::requires_initialization() const {
22072162
return !absl::c_all_of(slices_, [](const DynamicSliceThunk::SliceDef& slice) {
22082163
if (!slice.offsets.has_value()) {
22092164
return true;
@@ -2455,17 +2410,6 @@ absl::StatusOr<const se::CommandBuffer::Command*> DynamicSliceFusionCmd::Record(
24552410
});
24562411
}
24572412

2458-
Command::BufferUses DynamicSliceFusionCmd::buffer_uses() const {
2459-
Command::BufferUses buffers;
2460-
auto embed_buffers = embedded_commands_.buffer_uses();
2461-
for (const BufferUse& buffer_usage : embed_buffers) {
2462-
buffers.emplace_back(
2463-
*embedded_to_origin_slice_map_.at(buffer_usage.slice().index()),
2464-
buffer_usage.access(), buffer_usage.shape());
2465-
}
2466-
return buffers;
2467-
}
2468-
24692413
absl::Status DynamicSliceFusionCmd::WalkNested(
24702414
absl::FunctionRef<absl::Status(Command*)> callback) {
24712415
return embedded_commands_.Walk(callback);

xla/backends/gpu/runtime/command_buffer_cmd.h

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,6 @@ class EmptyCmd : public Command {
135135
const Thunk::ExecuteParams& execute_params,
136136
const RecordParams& record_params, RecordAction record_action,
137137
se::CommandBuffer* command_buffer) override;
138-
139-
BufferUses buffer_uses() const override { return {}; }
140138
};
141139

142140
//===----------------------------------------------------------------------===//
@@ -303,14 +301,6 @@ class ChildCmd : public Command {
303301
const RecordParams& record_params, RecordAction record_action,
304302
se::CommandBuffer* command_buffer) override;
305303

306-
bool requires_initialization() override;
307-
308-
bool force_update() override;
309-
310-
bool support_loop_unroll() override { return false; }
311-
312-
BufferUses buffer_uses() const override;
313-
314304
absl::Status WalkNested(
315305
absl::FunctionRef<absl::Status(Command*)> callback) override;
316306

@@ -333,12 +323,6 @@ class CaseCmd : public Command {
333323
const RecordParams& record_params, RecordAction record_action,
334324
se::CommandBuffer* command_buffer) override;
335325

336-
bool requires_initialization() override;
337-
338-
bool force_update() override;
339-
340-
bool support_loop_unroll() override { return false; }
341-
342326
BufferUses buffer_uses() const override;
343327

344328
absl::Status WalkNested(
@@ -370,14 +354,6 @@ class WhileCmd : public Command {
370354
const RecordParams& record_params, RecordAction record_action,
371355
se::CommandBuffer* command_buffer) override;
372356

373-
bool requires_initialization() override;
374-
375-
bool force_update() override;
376-
377-
// We have not tried unrolling the loop inside another loop, so marking it
378-
// unsupported for now.
379-
bool support_loop_unroll() override { return false; }
380-
381357
BufferUses buffer_uses() const override;
382358

383359
absl::Status WalkNested(
@@ -566,7 +542,7 @@ class CollectiveCmd : public AsyncStartCommand {
566542

567543
absl::Status Prepare(const Thunk::PrepareParams& params) final;
568544

569-
bool requires_initialization() override { return true; }
545+
bool requires_initialization() const final { return true; }
570546

571547
bool IsNestedCommandBuffer() const final { return true; }
572548

@@ -604,8 +580,6 @@ class CollectiveDoneCmd : public AsyncDoneCommand {
604580
const RecordParams& record_params, RecordAction record_action,
605581
se::CommandBuffer* command_buffer) override;
606582

607-
BufferUses buffer_uses() const override { return {}; }
608-
609583
std::shared_ptr<CollectiveThunk::AsyncEvents> async_events() const {
610584
return async_events_;
611585
}
@@ -818,13 +792,9 @@ class DynamicSliceFusionCmd : public Command {
818792
const RecordParams& record_params, RecordAction record_action,
819793
se::CommandBuffer* command_buffer) override;
820794

821-
BufferUses buffer_uses() const override;
822-
823-
bool force_update() override { return true; }
795+
bool force_update() const final { return true; }
824796

825-
bool requires_initialization() override;
826-
827-
bool support_loop_unroll() override { return true; }
797+
bool requires_initialization() const final;
828798

829799
bool IsNestedCommandBuffer() const final { return true; }
830800

@@ -879,9 +849,7 @@ class DynamicSliceCopyFusionCmd : public Command {
879849
const RecordParams& record_params, RecordAction record_action,
880850
se::CommandBuffer* command_buffer) override;
881851

882-
bool force_update() override { return offsets_.depends_on_loop; }
883-
884-
bool support_loop_unroll() override { return true; }
852+
bool force_update() const final { return offsets_.depends_on_loop; }
885853

886854
BufferUses buffer_uses() const override;
887855

0 commit comments

Comments
 (0)