@@ -30,6 +30,7 @@ limitations under the License.
3030#include " absl/strings/str_cat.h"
3131#include " absl/strings/string_view.h"
3232#include " absl/types/span.h"
33+ #include " xla/tsl/platform/status_macros.h"
3334#include " llvm/ADT/STLExtras.h"
3435#include " mlir/AsmParser/AsmParser.h"
3536#include " mlir/IR/Attributes.h"
@@ -95,7 +96,7 @@ constexpr unsigned kGEMMWorkspaceBufferIndex = 1;
9596absl::StatusOr<std::unique_ptr<Thunk>> BuildCustomKernelThunkForFusion (
9697 IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
9798 CustomKernel custom_kernel) {
98- TF_ASSIGN_OR_RETURN (
99+ ASSIGN_OR_RETURN (
99100 auto kernel_arguments,
100101 emitters::KernelArguments::Create (ir_emitter_context.buffer_assignment (),
101102 GetDefaultBufferAlignment (), &fusion));
@@ -154,7 +155,7 @@ absl::StatusOr<BufferAllocation::Slice> GetOperandSlice(
154155 const auto * param = Cast<HloParameterInstruction>(slice_instr->operand (0 ));
155156 // At this point we've walked through all `shape_idx`, `index` should be
156157 // empty.
157- TF_ASSIGN_OR_RETURN (
158+ ASSIGN_OR_RETURN (
158159 BufferAllocation::Slice orig_slice,
159160 GetAllocationSlice (buffer_assignment,
160161 fusion_instr.operand (param->parameter_number ()),
@@ -731,9 +732,7 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
731732 deterministic_ops);
732733 }
733734
734- FusionEmissionResult result;
735- result.thunks .push_back (std::move (thunk));
736- return result;
735+ return FusionEmissionResult{ThunkSequence::Of (std::move (thunk))};
737736}
738737
739738absl::StatusOr<FusionEmissionResult> EmitCustomCall (
@@ -1049,9 +1048,7 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
10491048 : legacy_thunk (std::move (operands), std::move (results)));
10501049 }
10511050
1052- FusionEmissionResult result;
1053- result.thunks .push_back (std::move (thunk));
1054- return result;
1051+ return FusionEmissionResult{ThunkSequence::Of (std::move (thunk))};
10551052}
10561053
10571054using Slice = std::optional<BufferAllocation::Slice>;
@@ -1256,8 +1253,6 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
12561253 Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo::WithProfileAnnotation (
12571254 instr, ir_emitter_context.GetNextThunkId ());
12581255
1259- FusionEmissionResult result;
1260-
12611256 // First we get the thunk sequence. This decides whether to generate a d2d
12621257 // copy thunk or collective thunk.
12631258 ThunkSequence seq;
@@ -1311,6 +1306,7 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
13111306 return implementable_status;
13121307 }
13131308
1309+ FusionEmissionResult result;
13141310 // Depending on whether this is a dynamic fusion or not, we wrap the
13151311 // thunk(s) within a dynamic-slice thunk.
13161312 if (slice_data.isDynamic ) {
@@ -1332,11 +1328,9 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
13321328 std::move (slice_data.orig_shapes ), std::move (slice_data.sliced_shapes ),
13331329 std::move (slice_data.offset_primitive_types ),
13341330 std::move (offset_modules_metadata));
1335- result.thunks . push_back (std::move (thunk));
1331+ result.thunks = ThunkSequence::Of (std::move (thunk));
13361332 } else {
1337- for (auto & thunk : seq) {
1338- result.thunks .push_back (std::move (thunk));
1339- }
1333+ result.thunks = std::move (seq);
13401334 }
13411335 return result;
13421336}
0 commit comments