@@ -152,8 +152,7 @@ static absl::string_view ReductionKindString(ReductionKind kind) {
152152
153153// Create a callback to create a command buffer from a command sequence.
154154static se::CommandBuffer::CreateCommands CreateCommands (
155- const CommandBufferCmdExecutor* commands,
156- const Thunk::ExecuteParams* execute_params,
155+ const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params,
157156 const Command::RecordParams* record_params) {
158157 return [=](se::CommandBuffer* command_buffer,
159158 absl::Span<const se::CommandBuffer::Command* const > dependencies) {
@@ -164,11 +163,11 @@ static se::CommandBuffer::CreateCommands CreateCommands(
164163
165164// Create callbacks to create a command buffer from command sequences.
166165static std::vector<se::CommandBuffer::CreateCommands> CreateCommands (
167- absl::Span<const CommandBufferCmdExecutor > commands,
166+ absl::Span<const CommandExecutor > commands,
168167 const Thunk::ExecuteParams* execute_params,
169168 const Command::RecordParams* record_params) {
170169 std::vector<se::CommandBuffer::CreateCommands> create_commands;
171- for (const CommandBufferCmdExecutor & cmd : commands) {
170+ for (const CommandExecutor & cmd : commands) {
172171 create_commands.push_back (
173172 CreateCommands (&cmd, execute_params, record_params));
174173 }
@@ -177,8 +176,7 @@ static std::vector<se::CommandBuffer::CreateCommands> CreateCommands(
177176
178177// Create a callback to update a command buffer with command sequence.
179178static se::CommandBuffer::UpdateCommands UpdateCommands (
180- const CommandBufferCmdExecutor* commands,
181- const Thunk::ExecuteParams* execute_params,
179+ const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params,
182180 const Command::RecordParams* record_params) {
183181 return [=](se::CommandBuffer* command_buffer) {
184182 return commands->RecordUpdate (*execute_params, *record_params,
@@ -188,11 +186,11 @@ static se::CommandBuffer::UpdateCommands UpdateCommands(
188186
189187// Create callbacks to update a command buffer with command sequence.
190188static std::vector<se::CommandBuffer::UpdateCommands> UpdateCommands (
191- absl::Span<const CommandBufferCmdExecutor > commands,
189+ absl::Span<const CommandExecutor > commands,
192190 const Thunk::ExecuteParams* execute_params,
193191 const Command::RecordParams* record_params) {
194192 std::vector<se::CommandBuffer::UpdateCommands> update_commands;
195- for (const CommandBufferCmdExecutor & cmd : commands) {
193+ for (const CommandExecutor & cmd : commands) {
196194 update_commands.push_back (
197195 UpdateCommands (&cmd, execute_params, record_params));
198196 }
@@ -740,21 +738,10 @@ Command::BufferUses Memset32Cmd::buffer_uses() const {
740738// ChildCmd
741739// ===----------------------------------------------------------------------===//
742740
743- ChildCmd::ChildCmd (CommandBufferCmdExecutor child_commands)
741+ ChildCmd::ChildCmd (CommandExecutor child_commands)
744742 : Command(CommandType::kChildCmd ),
745743 child_commands_(std::move(child_commands)) {}
746744
747- bool ChildCmd::requires_initialization () {
748- return child_commands_.requires_initialization ();
749- }
750-
751- bool ChildCmd::force_update () { return child_commands_.force_update (); }
752-
753- Command::BufferUses ChildCmd::buffer_uses () const {
754- return {child_commands_.buffer_uses ().begin (),
755- child_commands_.buffer_uses ().end ()};
756- }
757-
758745absl::Status ChildCmd::Initialize (const Thunk::InitializeParams& params) {
759746 TF_RETURN_IF_ERROR (child_commands_.Initialize (params));
760747 return absl::OkStatus ();
@@ -797,8 +784,7 @@ absl::Status ChildCmd::WalkNested(
797784// CaseCmd
798785// ===----------------------------------------------------------------------===//
799786
800- CaseCmd::CaseCmd (ShapedSlice index,
801- std::vector<CommandBufferCmdExecutor> branches)
787+ CaseCmd::CaseCmd (ShapedSlice index, std::vector<CommandExecutor> branches)
802788 : Command(CommandType::kCaseCmd ),
803789 index_(index),
804790 index_is_bool_(index.shape.element_type() == PRED),
@@ -847,23 +833,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CaseCmd::Record(
847833 });
848834}
849835
850- bool CaseCmd::requires_initialization () {
851- return absl::c_any_of (
852- branches_, [](const auto & seq) { return seq.requires_initialization (); });
853- }
854-
855- bool CaseCmd::force_update () {
856- return absl::c_any_of (branches_,
857- [](const auto & seq) { return seq.force_update (); });
858- }
859-
860836Command::BufferUses CaseCmd::buffer_uses () const {
861- absl::flat_hash_set<BufferUse> buffers;
862- buffers.emplace (BufferUse::Read (index_.slice , index_.shape ));
863- for (auto & branch : branches_) {
864- buffers.insert (branch.buffer_uses ().begin (), branch.buffer_uses ().end ());
865- }
866- return {buffers.begin (), buffers.end ()};
837+ return {BufferUse::Read (index_.slice , index_.shape )};
867838}
868839
869840absl::Status CaseCmd::WalkNested (
@@ -878,9 +849,8 @@ absl::Status CaseCmd::WalkNested(
878849// WhileCmd
879850// ===----------------------------------------------------------------------===//
880851
881- WhileCmd::WhileCmd (BufferAllocation::Slice pred,
882- CommandBufferCmdExecutor cond_commands,
883- CommandBufferCmdExecutor body_commands,
852+ WhileCmd::WhileCmd (BufferAllocation::Slice pred, CommandExecutor cond_commands,
853+ CommandExecutor body_commands,
884854 std::optional<int64_t > trip_count, bool enable_loop_unroll)
885855 : Command(CommandType::kWhileCmd ),
886856 pred_(pred),
@@ -1005,23 +975,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> WhileCmd::Record(
1005975 });
1006976}
1007977
1008- bool WhileCmd::requires_initialization () {
1009- return (cond_commands_.requires_initialization () ||
1010- body_commands_.requires_initialization ());
1011- }
1012-
1013- bool WhileCmd::force_update () {
1014- return cond_commands_.force_update () || body_commands_.force_update ();
1015- }
1016-
1017978Command::BufferUses WhileCmd::buffer_uses () const {
1018- absl::flat_hash_set<BufferUse> buffers;
1019- buffers.emplace (BufferUse::Read (pred_, ShapeUtil::MakeShape (PRED, {})));
1020- buffers.insert (cond_commands_.buffer_uses ().begin (),
1021- cond_commands_.buffer_uses ().end ());
1022- buffers.insert (body_commands_.buffer_uses ().begin (),
1023- body_commands_.buffer_uses ().end ());
1024- return {buffers.begin (), buffers.end ()};
979+ return {BufferUse::Read (pred_, ShapeUtil::MakeShape (PRED, {}))};
1025980}
1026981
1027982absl::Status WhileCmd::WalkNested (
@@ -1086,16 +1041,16 @@ absl::StatusOr<const se::CommandBuffer::Command*> GemmCmd::Record(
10861041}
10871042
10881043Command::BufferUses GemmCmd::buffer_uses () const {
1089- Command::BufferUses res {
1044+ Command::BufferUses buffer_uses = {
10901045 BufferUse::Read (lhs_buffer_, config_.lhs_layout .ToShape ()),
10911046 BufferUse::Read (rhs_buffer_, config_.rhs_layout .ToShape ()),
10921047 BufferUse::Write (output_buffer_, config_.output_layout .ToShape ()),
10931048 };
10941049 if (workspace_.has_value ()) {
1095- res .push_back (BufferUse::Write (
1050+ buffer_uses .push_back (BufferUse::Write (
10961051 *workspace_, ShapeUtil::MakeShape (S8, {workspace_->size ()})));
10971052 }
1098- return res ;
1053+ return buffer_uses ;
10991054}
11001055
11011056// ===----------------------------------------------------------------------===//
@@ -2156,7 +2111,7 @@ Command::BufferUses CollectivePermuteCmd::buffer_uses() const {
21562111// ===----------------------------------------------------------------------===//
21572112
21582113DynamicSliceFusionCmd::DynamicSliceFusionCmd (
2159- CommandBufferCmdExecutor embedded_commands,
2114+ CommandExecutor embedded_commands,
21602115 std::vector<std::optional<BufferAllocation::Slice>> arguments,
21612116 std::vector<BufferAllocation> fake_allocations,
21622117 std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>> offsets,
@@ -2203,7 +2158,7 @@ DynamicSliceFusionCmd::DynamicSliceFusionCmd(
22032158// because the memory address might changed if the offset is loop
22042159// iterator or operator outputs even if the parent command's memory pointers
22052160// do not change.
2206- bool DynamicSliceFusionCmd::requires_initialization () {
2161+ bool DynamicSliceFusionCmd::requires_initialization () const {
22072162 return !absl::c_all_of (slices_, [](const DynamicSliceThunk::SliceDef& slice) {
22082163 if (!slice.offsets .has_value ()) {
22092164 return true ;
@@ -2455,17 +2410,6 @@ absl::StatusOr<const se::CommandBuffer::Command*> DynamicSliceFusionCmd::Record(
24552410 });
24562411}
24572412
2458- Command::BufferUses DynamicSliceFusionCmd::buffer_uses () const {
2459- Command::BufferUses buffers;
2460- auto embed_buffers = embedded_commands_.buffer_uses ();
2461- for (const BufferUse& buffer_usage : embed_buffers) {
2462- buffers.emplace_back (
2463- *embedded_to_origin_slice_map_.at (buffer_usage.slice ().index ()),
2464- buffer_usage.access (), buffer_usage.shape ());
2465- }
2466- return buffers;
2467- }
2468-
24692413absl::Status DynamicSliceFusionCmd::WalkNested (
24702414 absl::FunctionRef<absl::Status(Command*)> callback) {
24712415 return embedded_commands_.Walk (callback);
0 commit comments