Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/backends/gpu/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ cc_library(
"//xla/stream_executor:stream",
"//xla/tools:hlo_extractor",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status_macros",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
Expand Down
11 changes: 5 additions & 6 deletions xla/backends/gpu/codegen/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ absl::StatusOr<FusionEmissionResult> MemcpyFusion::Emit(
return absl::OkStatus();
}));

FusionEmissionResult result;
ThunkSequence thunks;
for (int i = 0; i < src_buffers.size(); ++i) {
if (src_buffers[i] != dst_buffers[i]) {
result.thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(
&fusion, ir_emitter_context.GetNextThunkId()),
/*source_buffer=*/ShapedSlice{src_buffers[i], src_shapes[i]},
/*destination_buffer=*/ShapedSlice{dst_buffers[i], src_shapes[i]},
/*mem_size=*/src_buffers[i].size()));
}
}
return result;
return FusionEmissionResult{std::move(thunks)};
}

absl::StatusOr<FusionEmissionResult> DynamicMemcpyFusion::Emit(
Expand Down Expand Up @@ -185,8 +185,6 @@ absl::StatusOr<FusionEmissionResult> DynamicMemcpyFusion::Emit(
ir_emitter_context.buffer_assignment().GetShapeForUniqueSlice(&fusion,
{}));

FusionEmissionResult result;

ASSIGN_OR_RETURN(auto config, fusion.backend_config<GpuBackendConfig>());
const auto& memcpy_config =
config.fusion_backend_config().dynamic_memcpy_config();
Expand All @@ -197,7 +195,8 @@ absl::StatusOr<FusionEmissionResult> DynamicMemcpyFusion::Emit(
absl::c_copy(memcpy_config.dst_offset_bytes(),
std::back_inserter(offsets.dst_offsets));

result.thunks.emplace_back(std::make_unique<DynamicMemcpyThunk>(
FusionEmissionResult result;
result.thunks = ThunkSequence::Of(std::make_unique<DynamicMemcpyThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(
&fusion, ir_emitter_context.GetNextThunkId()),
/*source_buffer=*/ShapedSlice{src_buffer, src_shape},
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/codegen/cudnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ absl::StatusOr<FusionEmissionResult> CuDnnFusion::Emit(
emitters::KernelArguments::Create(ir_emitter_context.buffer_assignment(),
GetDefaultBufferAlignment(), &fusion));
FusionEmissionResult result;
result.thunks.emplace_back(std::make_unique<CuDnnThunk>(
result.thunks = ThunkSequence::Of(std::make_unique<CuDnnThunk>(
emitters::GetComputationFingerprint(
fusion.fused_instructions_computation(), {}),
Thunk::ThunkInfo::WithProfileAnnotation(
Expand Down
29 changes: 10 additions & 19 deletions xla/backends/gpu/codegen/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/tsl/platform/status_macros.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -732,9 +733,7 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
deterministic_ops);
}

FusionEmissionResult result;
result.thunks.push_back(std::move(thunk));
return result;
return FusionEmissionResult{ThunkSequence::Of(std::move(thunk))};
}

absl::StatusOr<FusionEmissionResult> EmitCustomCall(
Expand Down Expand Up @@ -1050,9 +1049,7 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
: legacy_thunk(std::move(operands), std::move(results)));
}

FusionEmissionResult result;
result.thunks.push_back(std::move(thunk));
return result;
return FusionEmissionResult{ThunkSequence::Of(std::move(thunk))};
}

using Slice = std::optional<BufferAllocation::Slice>;
Expand Down Expand Up @@ -1257,8 +1254,6 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(
instr, ir_emitter_context.GetNextThunkId());

FusionEmissionResult result;

// First we get the thunk sequence. This decides whether to generate a d2d
// copy thunk or collective thunk.
ThunkSequence seq;
Expand Down Expand Up @@ -1312,6 +1307,7 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
return implementable_status;
}

FusionEmissionResult result;
// Depending on whether this is a dynamic fusion or not, we wrap the
// thunk(s) within a dynamic-slice thunk.
if (slice_data.isDynamic) {
Expand All @@ -1333,11 +1329,9 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
std::move(slice_data.orig_shapes), std::move(slice_data.sliced_shapes),
std::move(slice_data.offset_primitive_types),
std::move(offset_modules_metadata));
result.thunks.push_back(std::move(thunk));
result.thunks = ThunkSequence::Of(std::move(thunk));
} else {
for (auto& thunk : seq) {
result.thunks.push_back(std::move(thunk));
}
result.thunks = std::move(seq);
}
return result;
}
Expand Down Expand Up @@ -1380,14 +1374,11 @@ absl::StatusOr<FusionEmissionResult> CustomFusion::Emit(
" returned empty custom kernels for a fused computation"));
}

TF_ASSIGN_OR_RETURN(auto thunk,
BuildCustomKernelThunkForFusion(
ir_emitter_context, fusion,
std::move(kernels[config.kernel_index()])));
ASSIGN_OR_RETURN(auto thunk, BuildCustomKernelThunkForFusion(
ir_emitter_context, fusion,
std::move(kernels[config.kernel_index()])));

FusionEmissionResult result;
result.thunks.push_back(std::move(thunk));
return result;
return FusionEmissionResult{ThunkSequence::Of(std::move(thunk))};
}

absl::StatusOr<FusionEmissionResult> DynamicSliceFusion::Emit(
Expand Down
4 changes: 2 additions & 2 deletions xla/backends/gpu/codegen/emitters/mlir_kernel_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,15 +354,15 @@ absl::StatusOr<FusionEmissionResult> MlirKernelFusion::Emit(
ir_emitter_context.gpu_device_info());
return entry;
});
TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);
ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);

if (cached) {
VLOG(3) << "Reuse: " << fusion.name() << " -> " << entry->kernel_name;
}

FusionEmissionResult result;
result.module = std::move(module);
result.thunks.emplace_back(std::make_unique<KernelThunk>(
result.thunks = ThunkSequence::Of(std::make_unique<KernelThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(
&fusion, ir_emitter_context.GetNextThunkId()),
entry->kernel_name, args, launch_dims, entry->cluster_dim,
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/codegen/fusion_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ namespace xla {
namespace gpu {

struct FusionEmissionResult {
AsyncThunkSequence thunks;
std::unique_ptr<llvm::Module> module;
ThunkSequence thunks;
};

class FusionInterface {
Expand Down
21 changes: 11 additions & 10 deletions xla/backends/gpu/codegen/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/check.h"
Expand Down Expand Up @@ -88,25 +89,25 @@ absl::StatusOr<FusionEmissionResult> SortFusion::Emit(
}
}

FusionEmissionResult result;
ThunkSequence thunks;
for (int i = 0; i < src_buffers.size(); ++i) {
if (src_buffers[i] != dst_buffers[i]) {
result.thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(
&fusion, ir_emitter_context.GetNextThunkId()),
/*source_buffer=*/ShapedSlice{src_buffers[i], src_shapes[i]},
/*destination_buffer=*/ShapedSlice{dst_buffers[i], src_shapes[i]},
/*mem_size=*/src_buffers[i].size()));
}
}
std::string op_name(sort->name());
result.module = ir_emitter_context.CreateLLVMModule(op_name);
ASSIGN_OR_RETURN(ThunkSequence sort_thunks,
EmitBitonicSortLLVMIR(sort, &ir_emitter_context).Await());
result.thunks.insert(result.thunks.end(),
std::make_move_iterator(sort_thunks.begin()),
std::make_move_iterator(sort_thunks.end()));
return result;
return FusionEmissionResult{
EmitBitonicSortLLVMIR(sort, &ir_emitter_context)
.Map([thunks = std::move(thunks)](ThunkSequence sort_thunks) mutable {
thunks.insert(thunks.end(),
std::make_move_iterator(sort_thunks.begin()),
std::make_move_iterator(sort_thunks.end()));
return std::move(thunks);
})};
}

} // namespace gpu
Expand Down
11 changes: 5 additions & 6 deletions xla/backends/gpu/codegen/triton/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,11 @@ TritonFusion::GenerateTritonKernelAndWrapper(
absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
TF_ASSIGN_OR_RETURN(EmitResult kernel_and_module,
Emit(ir_emitter_context, fusion, nullptr, {}));
FusionEmissionResult result;
result.thunks.push_back(std::move(kernel_and_module.kernel_thunk));
result.module = std::move(kernel_and_module.llvm_module);
return result;
ASSIGN_OR_RETURN(EmitResult kernel_and_module,
Emit(ir_emitter_context, fusion, nullptr, {}));
return FusionEmissionResult{
ThunkSequence::Of(std::move(kernel_and_module.kernel_thunk)),
std::move(kernel_and_module.llvm_module)};
}

absl::StatusOr<TritonFusion::EmitResult> TritonFusion::Emit(
Expand Down
15 changes: 7 additions & 8 deletions xla/service/gpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1346,8 +1346,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitAsyncComputation(
return GetThunkSequence(std::move(start_thunk));
}

absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusion(
const HloFusionInstruction* instr) {
AsyncThunkSequence ThunkEmitter::EmitFusion(const HloFusionInstruction* instr) {
const se::DeviceDescription& device_info =
ir_emitter_context_->gpu_device_info();
const HloFusionAnalysis fusion_analysis =
Expand All @@ -1360,7 +1359,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusion(
&ir_emitter_context_->buffer_assignment(),
/*call_graph=*/*call_graph_),
ir_emitter_context_->mlir_context());
TF_ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr));
ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr));

// Use override flag because libdevice functions can be present in both.
if (result.module) {
Expand Down Expand Up @@ -2516,14 +2515,14 @@ AsyncThunkSequence ThunkEmitter::EmitAsyncStart(const HloInstruction* instr) {
std::nullopt);
}
case HloOpcode::kFusion: {
TF_ASSIGN_OR_RETURN(ThunkSequence fusion_thunks,
EmitFusion(Cast<HloFusionInstruction>(wrapped)));
ASSIGN_OR_RETURN(ThunkSequence fusion_thunks,
EmitFusion(Cast<HloFusionInstruction>(wrapped)).Await());

auto* async_start = Cast<HloAsyncInstruction>(instr);
const ExecutionStreamAssignment& stream_assignment =
ir_emitter_context_->execution_stream_assignment();
TF_ASSIGN_OR_RETURN(ExecutionStreamId execution_stream_id,
stream_assignment.GetExecutionStreamId(async_start));
ASSIGN_OR_RETURN(ExecutionStreamId execution_stream_id,
stream_assignment.GetExecutionStreamId(async_start));

auto start_thunk = std::make_unique<AsyncStartThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(
Expand All @@ -2537,7 +2536,7 @@ AsyncThunkSequence ThunkEmitter::EmitAsyncStart(const HloInstruction* instr) {
wrapped->ToString());
}

return GetThunkSequence(std::move(start_thunk));
return ThunkSequence::Of(std::move(start_thunk));
}
case HloOpcode::kCall: {
return EmitAsyncComputation(instr);
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/thunk_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class ThunkEmitter {
std::vector<CollectiveThunk::Buffer>& buffers,
const HloInstruction* async_start, const HloInstType* inst);

absl::StatusOr<ThunkSequence> EmitFusion(const HloFusionInstruction* hlo);
AsyncThunkSequence EmitFusion(const HloFusionInstruction* instr);

absl::StatusOr<ThunkSequence> EmitFftThunk(const HloFftInstruction* hlo);

Expand Down
Loading