Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions xla/backends/gpu/runtime/command.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ bool IsCollectiveCommand(CommandType type);
class Command {
public:
using BufferUses = Thunk::BufferUses;
using ResourceUseVector = absl::InlinedVector<ResourceUse, 1>;
using ResourceUses = absl::InlinedVector<ResourceUse, 1>;

public:
explicit Command(CommandType cmd_type,
se::StreamPriority priority = se::StreamPriority::Default)
: cmd_type_(cmd_type), priority_(priority) {
token_ = Resource::Create(Resource::kToken);
resources_.push_back(ResourceUse::Write(token_));
resource_uses_.push_back(ResourceUse::Write(token_));
}

virtual ~Command() = default;
Expand Down Expand Up @@ -213,29 +213,34 @@ class Command {
// they got lucky and got the same buffer allocations), it will lead to
// deadlocks. By forcing the command update at thunk initialization time, we
// ensure that all ranks execute NCCL command update.
virtual bool requires_initialization() { return false; }
virtual bool requires_initialization() const { return false; }

// Returns true if the command supports loop unrolling. A while loop can be
// unrolled only if it has a pre-known trip count and all commands in its body
// are unrollable.
virtual bool support_loop_unroll() { return true; }
virtual bool support_loop_unroll() const { return true; }

// This is only true for DynamicSliceCopyFusionCmd when the offset depends on
// the loop iteration. Because the slice operation accesses a memory region
// that varies across loop iterations, the command still requires an update
// even if the original buffer allocation is the same.
virtual bool force_update() { return false; }
virtual bool force_update() const { return false; }

// Returns all buffers used by the cmd. These will be used to track cmd
// updates, thus they need to be consistent across calls to the function.
// Returns buffers used by this command. Buffer uses do not include buffers
// that might be used by nested commands; they must be collected separately
// by walking the nested commands using the `Walk` API.
virtual BufferUses buffer_uses() const { return {}; }

std::shared_ptr<Resource> token() const { return token_; }

void add_resource_use(ResourceUse resource_use) {
resources_.push_back(resource_use);
resource_uses_.push_back(resource_use);
}
ResourceUseVector resources() const { return resources_; }

// Returns resources used by this command. Resource uses do not include
// resources that might be used by nested commands; they must be collected
// separately by walking the nested commands using the `Walk` API.
ResourceUses resource_uses() const { return resource_uses_; }

// Returns true if command implemented as a nested command buffer.
virtual bool IsNestedCommandBuffer() const { return false; }
Expand Down Expand Up @@ -275,7 +280,7 @@ class Command {
std::string profile_annotation_;
CommandType cmd_type_;

ResourceUseVector resources_;
ResourceUses resource_uses_;

// The token resource is used to specify additional dependency across
// commands, like control dependency across HLO operators, and LHS scheduling
Expand Down Expand Up @@ -378,6 +383,18 @@ class CommandSequence : public std::vector<std::unique_ptr<Command>> {
}
return absl::OkStatus();
}

// Calls `Walk` with `callback` on each command held by this sequence, in the
// order the commands are stored. This is the const overload: the callback type
// accepts `const Command*`.
// NOTE(review): `Command::Walk` is not visible here; presumably it invokes the
// callback for the command and any nested commands — confirm against its
// declaration.
void Walk(absl::FunctionRef<void(const Command*)> callback) const {
for (const std::unique_ptr<Command>& cmd : *this) {
cmd->Walk(callback);
}
}

// Calls `Walk` with `callback` on each command held by this sequence, in the
// order the commands are stored. This is the non-const overload: the callback
// type accepts a mutable `Command*`.
// NOTE(review): `Command::Walk` is not visible here; presumably it invokes the
// callback for the command and any nested commands — confirm against its
// declaration.
void Walk(absl::FunctionRef<void(Command*)> callback) {
for (std::unique_ptr<Command>& cmd : *this) {
cmd->Walk(callback);
}
}
};

} // namespace xla::gpu
Expand Down
90 changes: 17 additions & 73 deletions xla/backends/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ static absl::string_view ReductionKindString(ReductionKind kind) {

// Create a callback to create a command buffer from a command sequence.
static se::CommandBuffer::CreateCommands CreateCommands(
const CommandBufferCmdExecutor* commands,
const Thunk::ExecuteParams* execute_params,
const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params,
const Command::RecordParams* record_params) {
return [=](se::CommandBuffer* command_buffer,
absl::Span<const se::CommandBuffer::Command* const> dependencies) {
Expand All @@ -164,11 +163,11 @@ static se::CommandBuffer::CreateCommands CreateCommands(

// Create callbacks to create a command buffer from command sequences.
static std::vector<se::CommandBuffer::CreateCommands> CreateCommands(
absl::Span<const CommandBufferCmdExecutor> commands,
absl::Span<const CommandExecutor> commands,
const Thunk::ExecuteParams* execute_params,
const Command::RecordParams* record_params) {
std::vector<se::CommandBuffer::CreateCommands> create_commands;
for (const CommandBufferCmdExecutor& cmd : commands) {
for (const CommandExecutor& cmd : commands) {
create_commands.push_back(
CreateCommands(&cmd, execute_params, record_params));
}
Expand All @@ -177,8 +176,7 @@ static std::vector<se::CommandBuffer::CreateCommands> CreateCommands(

// Create a callback to update a command buffer with command sequence.
static se::CommandBuffer::UpdateCommands UpdateCommands(
const CommandBufferCmdExecutor* commands,
const Thunk::ExecuteParams* execute_params,
const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params,
const Command::RecordParams* record_params) {
return [=](se::CommandBuffer* command_buffer) {
return commands->RecordUpdate(*execute_params, *record_params,
Expand All @@ -188,11 +186,11 @@ static se::CommandBuffer::UpdateCommands UpdateCommands(

// Create callbacks to update a command buffer with command sequence.
static std::vector<se::CommandBuffer::UpdateCommands> UpdateCommands(
absl::Span<const CommandBufferCmdExecutor> commands,
absl::Span<const CommandExecutor> commands,
const Thunk::ExecuteParams* execute_params,
const Command::RecordParams* record_params) {
std::vector<se::CommandBuffer::UpdateCommands> update_commands;
for (const CommandBufferCmdExecutor& cmd : commands) {
for (const CommandExecutor& cmd : commands) {
update_commands.push_back(
UpdateCommands(&cmd, execute_params, record_params));
}
Expand Down Expand Up @@ -740,21 +738,10 @@ Command::BufferUses Memset32Cmd::buffer_uses() const {
// ChildCmd
//===----------------------------------------------------------------------===//

ChildCmd::ChildCmd(CommandBufferCmdExecutor child_commands)
ChildCmd::ChildCmd(CommandExecutor child_commands)
: Command(CommandType::kChildCmd),
child_commands_(std::move(child_commands)) {}

bool ChildCmd::requires_initialization() {
return child_commands_.requires_initialization();
}

bool ChildCmd::force_update() { return child_commands_.force_update(); }

Command::BufferUses ChildCmd::buffer_uses() const {
return {child_commands_.buffer_uses().begin(),
child_commands_.buffer_uses().end()};
}

absl::Status ChildCmd::Initialize(const Thunk::InitializeParams& params) {
TF_RETURN_IF_ERROR(child_commands_.Initialize(params));
return absl::OkStatus();
Expand Down Expand Up @@ -797,8 +784,7 @@ absl::Status ChildCmd::WalkNested(
// CaseCmd
//===----------------------------------------------------------------------===//

CaseCmd::CaseCmd(ShapedSlice index,
std::vector<CommandBufferCmdExecutor> branches)
CaseCmd::CaseCmd(ShapedSlice index, std::vector<CommandExecutor> branches)
: Command(CommandType::kCaseCmd),
index_(index),
index_is_bool_(index.shape.element_type() == PRED),
Expand Down Expand Up @@ -847,23 +833,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CaseCmd::Record(
});
}

bool CaseCmd::requires_initialization() {
return absl::c_any_of(
branches_, [](const auto& seq) { return seq.requires_initialization(); });
}

bool CaseCmd::force_update() {
return absl::c_any_of(branches_,
[](const auto& seq) { return seq.force_update(); });
}

Command::BufferUses CaseCmd::buffer_uses() const {
absl::flat_hash_set<BufferUse> buffers;
buffers.emplace(BufferUse::Read(index_.slice, index_.shape));
for (auto& branch : branches_) {
buffers.insert(branch.buffer_uses().begin(), branch.buffer_uses().end());
}
return {buffers.begin(), buffers.end()};
return {BufferUse::Read(index_.slice, index_.shape)};
}

absl::Status CaseCmd::WalkNested(
Expand All @@ -878,9 +849,8 @@ absl::Status CaseCmd::WalkNested(
// WhileCmd
//===----------------------------------------------------------------------===//

WhileCmd::WhileCmd(BufferAllocation::Slice pred,
CommandBufferCmdExecutor cond_commands,
CommandBufferCmdExecutor body_commands,
WhileCmd::WhileCmd(BufferAllocation::Slice pred, CommandExecutor cond_commands,
CommandExecutor body_commands,
std::optional<int64_t> trip_count, bool enable_loop_unroll)
: Command(CommandType::kWhileCmd),
pred_(pred),
Expand Down Expand Up @@ -1005,23 +975,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> WhileCmd::Record(
});
}

bool WhileCmd::requires_initialization() {
return (cond_commands_.requires_initialization() ||
body_commands_.requires_initialization());
}

bool WhileCmd::force_update() {
return cond_commands_.force_update() || body_commands_.force_update();
}

Command::BufferUses WhileCmd::buffer_uses() const {
absl::flat_hash_set<BufferUse> buffers;
buffers.emplace(BufferUse::Read(pred_, ShapeUtil::MakeShape(PRED, {})));
buffers.insert(cond_commands_.buffer_uses().begin(),
cond_commands_.buffer_uses().end());
buffers.insert(body_commands_.buffer_uses().begin(),
body_commands_.buffer_uses().end());
return {buffers.begin(), buffers.end()};
return {BufferUse::Read(pred_, ShapeUtil::MakeShape(PRED, {}))};
}

absl::Status WhileCmd::WalkNested(
Expand Down Expand Up @@ -1086,16 +1041,16 @@ absl::StatusOr<const se::CommandBuffer::Command*> GemmCmd::Record(
}

Command::BufferUses GemmCmd::buffer_uses() const {
Command::BufferUses res{
Command::BufferUses buffer_uses = {
BufferUse::Read(lhs_buffer_, config_.lhs_layout.ToShape()),
BufferUse::Read(rhs_buffer_, config_.rhs_layout.ToShape()),
BufferUse::Write(output_buffer_, config_.output_layout.ToShape()),
};
if (workspace_.has_value()) {
res.push_back(BufferUse::Write(
buffer_uses.push_back(BufferUse::Write(
*workspace_, ShapeUtil::MakeShape(S8, {workspace_->size()})));
}
return res;
return buffer_uses;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2156,7 +2111,7 @@ Command::BufferUses CollectivePermuteCmd::buffer_uses() const {
//===----------------------------------------------------------------------===//

DynamicSliceFusionCmd::DynamicSliceFusionCmd(
CommandBufferCmdExecutor embedded_commands,
CommandExecutor embedded_commands,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<BufferAllocation> fake_allocations,
std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>> offsets,
Expand Down Expand Up @@ -2203,7 +2158,7 @@ DynamicSliceFusionCmd::DynamicSliceFusionCmd(
// because the memory address might change if the offset is a loop
// iterator or an operator output, even if the parent command's memory
// pointers do not change.
bool DynamicSliceFusionCmd::requires_initialization() {
bool DynamicSliceFusionCmd::requires_initialization() const {
return !absl::c_all_of(slices_, [](const DynamicSliceThunk::SliceDef& slice) {
if (!slice.offsets.has_value()) {
return true;
Expand Down Expand Up @@ -2455,17 +2410,6 @@ absl::StatusOr<const se::CommandBuffer::Command*> DynamicSliceFusionCmd::Record(
});
}

Command::BufferUses DynamicSliceFusionCmd::buffer_uses() const {
Command::BufferUses buffers;
auto embed_buffers = embedded_commands_.buffer_uses();
for (const BufferUse& buffer_usage : embed_buffers) {
buffers.emplace_back(
*embedded_to_origin_slice_map_.at(buffer_usage.slice().index()),
buffer_usage.access(), buffer_usage.shape());
}
return buffers;
}

absl::Status DynamicSliceFusionCmd::WalkNested(
absl::FunctionRef<absl::Status(Command*)> callback) {
return embedded_commands_.Walk(callback);
Expand Down
40 changes: 4 additions & 36 deletions xla/backends/gpu/runtime/command_buffer_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ class EmptyCmd : public Command {
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) override;

BufferUses buffer_uses() const override { return {}; }
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -303,14 +301,6 @@ class ChildCmd : public Command {
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) override;

bool requires_initialization() override;

bool force_update() override;

bool support_loop_unroll() override { return false; }

BufferUses buffer_uses() const override;

absl::Status WalkNested(
absl::FunctionRef<absl::Status(Command*)> callback) override;

Expand All @@ -333,12 +323,6 @@ class CaseCmd : public Command {
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) override;

bool requires_initialization() override;

bool force_update() override;

bool support_loop_unroll() override { return false; }

BufferUses buffer_uses() const override;

absl::Status WalkNested(
Expand Down Expand Up @@ -370,14 +354,6 @@ class WhileCmd : public Command {
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) override;

bool requires_initialization() override;

bool force_update() override;

// We have not tried unrolling the loop inside another loop, so marking it
// unsupported for now.
bool support_loop_unroll() override { return false; }

BufferUses buffer_uses() const override;

absl::Status WalkNested(
Expand Down Expand Up @@ -566,7 +542,7 @@ class CollectiveCmd : public AsyncStartCommand {

absl::Status Prepare(const Thunk::PrepareParams& params) final;

bool requires_initialization() override { return true; }
bool requires_initialization() const final { return true; }

bool IsNestedCommandBuffer() const final { return true; }

Expand Down Expand Up @@ -604,8 +580,6 @@ class CollectiveDoneCmd : public AsyncDoneCommand {
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) override;

BufferUses buffer_uses() const override { return {}; }

std::shared_ptr<CollectiveThunk::AsyncEvents> async_events() const {
return async_events_;
}
Expand Down Expand Up @@ -818,13 +792,9 @@ class DynamicSliceFusionCmd : public Command {
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) override;

BufferUses buffer_uses() const override;

bool force_update() override { return true; }
bool force_update() const final { return true; }

bool requires_initialization() override;

bool support_loop_unroll() override { return true; }
bool requires_initialization() const final;

bool IsNestedCommandBuffer() const final { return true; }

Expand Down Expand Up @@ -879,9 +849,7 @@ class DynamicSliceCopyFusionCmd : public Command {
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) override;

bool force_update() override { return offsets_.depends_on_loop; }

bool support_loop_unroll() override { return true; }
bool force_update() const final { return offsets_.depends_on_loop; }

BufferUses buffer_uses() const override;

Expand Down
Loading
Loading