Skip to content

Commit 51c4041

Browse files
[XLA:GPU]: Add some debugging information for triton collective kernels
PiperOrigin-RevId: 900676743
1 parent 147ce73 commit 51c4041

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

xla/service/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,6 +2318,7 @@ xla_cc_test(
23182318
"//xla/tests:xla_internal_test_main",
23192319
"//xla/tsl/platform:statusor",
23202320
"@com_google_absl//absl/log:check",
2321+
"@com_google_absl//absl/strings",
23212322
"@com_google_absl//absl/strings:string_view",
23222323
"@com_google_absl//absl/types:span",
23232324
"@com_google_googletest//:gtest",

xla/service/hlo_creation_utils.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -998,9 +998,11 @@ std::unique_ptr<HloModule> NewModuleWithFusion(
998998
HloComputation::Builder entry_builder("entry");
999999
std::vector<HloInstruction*> entry_parameters =
10001000
build_parameter_instructions(entry_builder);
1001-
HloInstruction* fusion_instruction = entry_builder.AddInstruction(
1002-
HloInstruction::CreateFusion(instruction->shape(), fusion_kind,
1003-
entry_parameters, fused_computation));
1001+
HloInstruction* fusion_instruction =
1002+
entry_builder.AddInstruction(HloInstruction::CreateFusion(
1003+
instruction->shape(), fusion_kind, entry_parameters,
1004+
fused_computation,
1005+
/*prefix=*/absl::StrCat(instruction->name(), "-")));
10041006

10051007
hlo_module->AddEntryComputation(entry_builder.Build(fusion_instruction));
10061008

xla/service/hlo_creation_utils_test.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include <gtest/gtest.h>
2222
#include "absl/log/check.h"
23+
#include "absl/strings/str_cat.h"
2324
#include "absl/strings/string_view.h"
2425
#include "absl/types/span.h"
2526
#include "xla/array2d.h"
@@ -602,6 +603,8 @@ TEST_F(HloCreationUtilsTest, NewModuleWithFusion) {
602603
EXPECT_EQ(to_apply->name(), "apply_op");
603604
EXPECT_EQ(to_apply->num_parameters(), 2);
604605
EXPECT_EQ(to_apply->root_instruction()->opcode(), HloOpcode::kAdd);
606+
EXPECT_EQ(fusion_module->entry_computation()->root_instruction()->name(),
607+
absl::StrCat(all_reduce_start->name(), "-", "fusion"));
605608
}
606609
} // namespace
607610
} // namespace xla

xla/stream_executor/cuda/cuda_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ absl::StatusOr<CUfunction> GetModuleFunction(Context* context, CUmodule module,
252252
TF_RETURN_IF_ERROR(cuda::ToStatus(
253253
cuModuleGetFunction(&function, module, kernel_name),
254254
absl::StrCat(xla::XlaFormatDevice(context->device_ordinal()),
255-
"Failed to get module function")));
255+
"Failed to get module function ", kernel_name)));
256256
return function;
257257
}
258258

0 commit comments

Comments
 (0)