Skip to content

Commit 430a9c1

Browse files
ermilovmaxim and Google-ML-Automation
authored and committed
triton 1 2 3
PiperOrigin-RevId: 686225992
1 parent ea91cef commit 430a9c1

File tree

9 files changed

+39
-47
lines changed

9 files changed

+39
-47
lines changed

xla/backends/gpu/codegen/copy.cc

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -119,18 +119,18 @@ absl::StatusOr<FusionEmissionResult> MemcpyFusion::Emit(
119119
return absl::OkStatus();
120120
}));
121121

122-
FusionEmissionResult result;
122+
ThunkSequence thunks;
123123
for (int i = 0; i < src_buffers.size(); ++i) {
124124
if (src_buffers[i] != dst_buffers[i]) {
125-
result.thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
125+
thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
126126
Thunk::ThunkInfo::WithProfileAnnotation(
127127
&fusion, ir_emitter_context.GetNextThunkId()),
128128
/*source_buffer=*/ShapedSlice{src_buffers[i], src_shapes[i]},
129129
/*destination_buffer=*/ShapedSlice{dst_buffers[i], src_shapes[i]},
130130
/*mem_size=*/src_buffers[i].size()));
131131
}
132132
}
133-
return result;
133+
return FusionEmissionResult{std::move(thunks)};
134134
}
135135

136136
absl::StatusOr<FusionEmissionResult> DynamicMemcpyFusion::Emit(
@@ -185,8 +185,6 @@ absl::StatusOr<FusionEmissionResult> DynamicMemcpyFusion::Emit(
185185
ir_emitter_context.buffer_assignment().GetShapeForUniqueSlice(&fusion,
186186
{}));
187187

188-
FusionEmissionResult result;
189-
190188
ASSIGN_OR_RETURN(auto config, fusion.backend_config<GpuBackendConfig>());
191189
const auto& memcpy_config =
192190
config.fusion_backend_config().dynamic_memcpy_config();
@@ -197,7 +195,8 @@ absl::StatusOr<FusionEmissionResult> DynamicMemcpyFusion::Emit(
197195
absl::c_copy(memcpy_config.dst_offset_bytes(),
198196
std::back_inserter(offsets.dst_offsets));
199197

200-
result.thunks.emplace_back(std::make_unique<DynamicMemcpyThunk>(
198+
FusionEmissionResult result;
199+
result.thunks = ThunkSequence::Of(std::make_unique<DynamicMemcpyThunk>(
201200
Thunk::ThunkInfo::WithProfileAnnotation(
202201
&fusion, ir_emitter_context.GetNextThunkId()),
203202
/*source_buffer=*/ShapedSlice{src_buffer, src_shape},

xla/backends/gpu/codegen/cudnn.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ absl::StatusOr<FusionEmissionResult> CuDnnFusion::Emit(
4242
emitters::KernelArguments::Create(ir_emitter_context.buffer_assignment(),
4343
GetDefaultBufferAlignment(), &fusion));
4444
FusionEmissionResult result;
45-
result.thunks.emplace_back(std::make_unique<CuDnnThunk>(
45+
result.thunks = ThunkSequence::Of(std::make_unique<CuDnnThunk>(
4646
emitters::GetComputationFingerprint(
4747
fusion.fused_instructions_computation(), {}),
4848
Thunk::ThunkInfo::WithProfileAnnotation(

xla/backends/gpu/codegen/custom.cc

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -733,7 +733,7 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
733733
}
734734

735735
FusionEmissionResult result;
736-
result.thunks.push_back(std::move(thunk));
736+
result.thunks = ThunkSequence::Of(std::move(thunk));
737737
return result;
738738
}
739739

@@ -1051,7 +1051,7 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
10511051
}
10521052

10531053
FusionEmissionResult result;
1054-
result.thunks.push_back(std::move(thunk));
1054+
result.thunks = ThunkSequence::Of(std::move(thunk));
10551055
return result;
10561056
}
10571057

@@ -1257,8 +1257,6 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
12571257
Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(
12581258
instr, ir_emitter_context.GetNextThunkId());
12591259

1260-
FusionEmissionResult result;
1261-
12621260
// First we get the thunk sequence. This decides whether to generate a d2d
12631261
// copy thunk or collective thunk.
12641262
ThunkSequence seq;
@@ -1312,6 +1310,7 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
13121310
return implementable_status;
13131311
}
13141312

1313+
FusionEmissionResult result;
13151314
// Depending on whether this is a dynamic fusion or not, we wrap the
13161315
// thunk(s) within a dynamic-slice thunk.
13171316
if (slice_data.isDynamic) {
@@ -1333,11 +1332,9 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
13331332
std::move(slice_data.orig_shapes), std::move(slice_data.sliced_shapes),
13341333
std::move(slice_data.offset_primitive_types),
13351334
std::move(offset_modules_metadata));
1336-
result.thunks.push_back(std::move(thunk));
1335+
result.thunks = ThunkSequence::Of(std::move(thunk));
13371336
} else {
1338-
for (auto& thunk : seq) {
1339-
result.thunks.push_back(std::move(thunk));
1340-
}
1337+
result.thunks = std::move(seq);
13411338
}
13421339
return result;
13431340
}
@@ -1380,14 +1377,11 @@ absl::StatusOr<FusionEmissionResult> CustomFusion::Emit(
13801377
" returned empty custom kernels for a fused computation"));
13811378
}
13821379

1383-
TF_ASSIGN_OR_RETURN(auto thunk,
1384-
BuildCustomKernelThunkForFusion(
1385-
ir_emitter_context, fusion,
1386-
std::move(kernels[config.kernel_index()])));
1380+
ASSIGN_OR_RETURN(auto thunk, BuildCustomKernelThunkForFusion(
1381+
ir_emitter_context, fusion,
1382+
std::move(kernels[config.kernel_index()])));
13871383

1388-
FusionEmissionResult result;
1389-
result.thunks.push_back(std::move(thunk));
1390-
return result;
1384+
return FusionEmissionResult{ThunkSequence::Of(std::move(thunk))};
13911385
}
13921386

13931387
absl::StatusOr<FusionEmissionResult> DynamicSliceFusion::Emit(

xla/backends/gpu/codegen/emitters/mlir_kernel_emitter.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -354,15 +354,15 @@ absl::StatusOr<FusionEmissionResult> MlirKernelFusion::Emit(
354354
ir_emitter_context.gpu_device_info());
355355
return entry;
356356
});
357-
TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);
357+
ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);
358358

359359
if (cached) {
360360
VLOG(3) << "Reuse: " << fusion.name() << " -> " << entry->kernel_name;
361361
}
362362

363363
FusionEmissionResult result;
364364
result.module = std::move(module);
365-
result.thunks.emplace_back(std::make_unique<KernelThunk>(
365+
result.thunks = ThunkSequence::Of(std::make_unique<KernelThunk>(
366366
Thunk::ThunkInfo::WithProfileAnnotation(
367367
&fusion, ir_emitter_context.GetNextThunkId()),
368368
entry->kernel_name, args, launch_dims, entry->cluster_dim,

xla/backends/gpu/codegen/fusion_emitter.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,8 @@ namespace xla {
4545
namespace gpu {
4646

4747
struct FusionEmissionResult {
48+
AsyncThunkSequence thunks;
4849
std::unique_ptr<llvm::Module> module;
49-
ThunkSequence thunks;
5050
};
5151

5252
class FusionInterface {

xla/backends/gpu/codegen/sort.cc

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -88,25 +88,25 @@ absl::StatusOr<FusionEmissionResult> SortFusion::Emit(
8888
}
8989
}
9090

91-
FusionEmissionResult result;
91+
ThunkSequence thunks;
9292
for (int i = 0; i < src_buffers.size(); ++i) {
9393
if (src_buffers[i] != dst_buffers[i]) {
94-
result.thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
94+
thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
9595
Thunk::ThunkInfo::WithProfileAnnotation(
9696
&fusion, ir_emitter_context.GetNextThunkId()),
9797
/*source_buffer=*/ShapedSlice{src_buffers[i], src_shapes[i]},
9898
/*destination_buffer=*/ShapedSlice{dst_buffers[i], src_shapes[i]},
9999
/*mem_size=*/src_buffers[i].size()));
100100
}
101101
}
102-
std::string op_name(sort->name());
103-
result.module = ir_emitter_context.CreateLLVMModule(op_name);
104-
ASSIGN_OR_RETURN(ThunkSequence sort_thunks,
105-
EmitBitonicSortLLVMIR(sort, &ir_emitter_context).Await());
106-
result.thunks.insert(result.thunks.end(),
107-
std::make_move_iterator(sort_thunks.begin()),
108-
std::make_move_iterator(sort_thunks.end()));
109-
return result;
102+
return FusionEmissionResult{
103+
EmitBitonicSortLLVMIR(sort, &ir_emitter_context)
104+
.Map([thunks = std::move(thunks)](ThunkSequence sort_thunks) mutable {
105+
thunks.insert(thunks.end(),
106+
std::make_move_iterator(sort_thunks.begin()),
107+
std::make_move_iterator(sort_thunks.end()));
108+
return std::move(thunks);
109+
})};
110110
}
111111

112112
} // namespace gpu

xla/backends/gpu/codegen/triton/fusion.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -118,10 +118,10 @@ TritonFusion::GenerateTritonKernelAndWrapper(
118118
absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
119119
IrEmitterContext& ir_emitter_context,
120120
const HloFusionInstruction& fusion) const {
121-
TF_ASSIGN_OR_RETURN(EmitResult kernel_and_module,
122-
Emit(ir_emitter_context, fusion, nullptr, {}));
121+
ASSIGN_OR_RETURN(EmitResult kernel_and_module,
122+
Emit(ir_emitter_context, fusion, nullptr, {}));
123123
FusionEmissionResult result;
124-
result.thunks.push_back(std::move(kernel_and_module.kernel_thunk));
124+
result.thunks = ThunkSequence::Of(std::move(kernel_and_module.kernel_thunk));
125125
result.module = std::move(kernel_and_module.llvm_module);
126126
return result;
127127
}

xla/service/gpu/thunk_emitter.cc

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1346,8 +1346,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitAsyncComputation(
13461346
return GetThunkSequence(std::move(start_thunk));
13471347
}
13481348

1349-
absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusion(
1350-
const HloFusionInstruction* instr) {
1349+
AsyncThunkSequence ThunkEmitter::EmitFusion(const HloFusionInstruction* instr) {
13511350
const se::DeviceDescription& device_info =
13521351
ir_emitter_context_->gpu_device_info();
13531352
const HloFusionAnalysis fusion_analysis =
@@ -1360,7 +1359,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusion(
13601359
&ir_emitter_context_->buffer_assignment(),
13611360
/*call_graph=*/*call_graph_),
13621361
ir_emitter_context_->mlir_context());
1363-
TF_ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr));
1362+
ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr));
13641363

13651364
// Use override flag because libdevice functions can be present in both.
13661365
if (result.module) {
@@ -2516,14 +2515,14 @@ AsyncThunkSequence ThunkEmitter::EmitAsyncStart(const HloInstruction* instr) {
25162515
std::nullopt);
25172516
}
25182517
case HloOpcode::kFusion: {
2519-
TF_ASSIGN_OR_RETURN(ThunkSequence fusion_thunks,
2520-
EmitFusion(Cast<HloFusionInstruction>(wrapped)));
2518+
ASSIGN_OR_RETURN(ThunkSequence fusion_thunks,
2519+
EmitFusion(Cast<HloFusionInstruction>(wrapped)).Await());
25212520

25222521
auto* async_start = Cast<HloAsyncInstruction>(instr);
25232522
const ExecutionStreamAssignment& stream_assignment =
25242523
ir_emitter_context_->execution_stream_assignment();
2525-
TF_ASSIGN_OR_RETURN(ExecutionStreamId execution_stream_id,
2526-
stream_assignment.GetExecutionStreamId(async_start));
2524+
ASSIGN_OR_RETURN(ExecutionStreamId execution_stream_id,
2525+
stream_assignment.GetExecutionStreamId(async_start));
25272526

25282527
auto start_thunk = std::make_unique<AsyncStartThunk>(
25292528
Thunk::ThunkInfo::WithProfileAnnotation(
@@ -2537,7 +2536,7 @@ AsyncThunkSequence ThunkEmitter::EmitAsyncStart(const HloInstruction* instr) {
25372536
wrapped->ToString());
25382537
}
25392538

2540-
return GetThunkSequence(std::move(start_thunk));
2539+
return ThunkSequence::Of(std::move(start_thunk));
25412540
}
25422541
case HloOpcode::kCall: {
25432542
return EmitAsyncComputation(instr);

xla/service/gpu/thunk_emitter.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ class ThunkEmitter {
153153
std::vector<CollectiveThunk::Buffer>& buffers,
154154
const HloInstruction* async_start, const HloInstType* inst);
155155

156-
absl::StatusOr<ThunkSequence> EmitFusion(const HloFusionInstruction* hlo);
156+
AsyncThunkSequence EmitFusion(const HloFusionInstruction* instr);
157157

158158
absl::StatusOr<ThunkSequence> EmitFftThunk(const HloFftInstruction* hlo);
159159

0 commit comments

Comments
 (0)