Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions xla/backends/cpu/benchmarks/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,32 @@ xla_cc_test(
],
)

# Benchmarks a numerically stable softmax HLO on the CPU backend. Benchmark
# shapes, execution count, and AOT compilation are configurable via flags
# defined in softmax_benchmark_test.cc (--shapes, --num_executions,
# --aot_compiled_execution, --xla_flags).
xla_cc_test(
    name = "softmax_benchmark_test",
    srcs = ["softmax_benchmark_test.cc"],
    deps = [
        ":aot_benchmark_helper",
        ":hlo_benchmark_runner",
        "//xla:literal",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/parser:hlo_parser",
        "//xla/service:hlo_proto_cc",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@com_google_benchmark//:benchmark",
        "@tsl//tsl/platform:stacktrace_handler",
    ],
)

xla_cc_test(
name = "dynamic_update_slice_benchmark_test",
srcs = ["dynamic_update_slice_benchmark_test.cc"],
Expand Down
239 changes: 239 additions & 0 deletions xla/backends/cpu/benchmarks/softmax_benchmark_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/* Copyright 2026 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "benchmark/benchmark.h"
#include "xla/backends/cpu/benchmarks/aot_benchmark_helper.h"
#include "xla/backends/cpu/benchmarks/hlo_benchmark_runner.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/stacktrace_handler.h"

ABSL_FLAG(std::string, shapes,
"{(f32[1024,1024], s32[1]), (f32[1024,4096], s32[1]), "
"(f32[8,1024,1024], s32[2])}",
"List of shapes encoding Softmax (ab)using the shape parser. "
"The first shape is the input shape, the second shape's dimensions "
"are the dimensions to reduce over.");

ABSL_FLAG(int32_t, num_executions, 1,
"Number of times to execute the HLO within a single benchmark "
"iteration.");

ABSL_FLAG(bool, aot_compiled_execution, false,
"If true, when running the benchmark, the HLO will be compiled AOT.");

ABSL_FLAG(std::string, xla_flags, "", "Flags to append to XLA_FLAGS");

namespace xla::cpu {

namespace {

// Appends the value of --xla_flags to the XLA_FLAGS environment variable so
// this binary can forward extra options to the XLA compiler. Any inherited
// XLA_FLAGS content is preserved.
void Set_XLA_FLAGS() {
  const char* env_xla_flags = std::getenv("XLA_FLAGS");
  const bool has_env = env_xla_flags != nullptr && env_xla_flags[0] != '\0';
  const std::string extra_flags = absl::GetFlag(FLAGS_xla_flags);

  // XLA_FLAGS is a space-separated list. Insert a separator when both parts
  // are non-empty; otherwise the last inherited flag and the first appended
  // flag would be glued into a single (invalid) token.
  std::string xla_flags;
  if (has_env && !extra_flags.empty()) {
    xla_flags = absl::StrCat(env_xla_flags, " ", extra_flags);
  } else {
    xla_flags = absl::StrCat(has_env ? env_xla_flags : "", extra_flags);
  }
  tsl::setenv("XLA_FLAGS", xla_flags.c_str(), /*overwrite=*/1);
}

// Describes one softmax benchmark case: the input shape plus the set of
// dimensions the max/sum reductions run over.
struct Softmax {
  Shape input_shape;
  std::vector<int64_t> reduction_dims;

  // Element type of the benchmark input (may be narrower than f32).
  PrimitiveType GetDType() const { return input_shape.element_type(); }

  // Shape that results from reducing `input_shape` over `reduction_dims`.
  Shape GetReductionShape() const {
    Shape reduced = input_shape;
    reduced.DeleteDimensions(reduction_dims);
    return reduced;
  }

  // Human-readable benchmark name, e.g. "BM_Softmax/<shape>_{<dims>}".
  std::string GetBenchmarkName() const {
    const std::string dims = absl::StrJoin(reduction_dims, ",");
    return absl::StrCat("BM_Softmax/", input_shape.ToString(), "_{", dims,
                        "}");
  }
};

// Decodes a tuple shape `(input, dims)` into a Softmax description. The
// first tuple element is the benchmark input shape; the *dimensions* of the
// second element are reinterpreted as the axes to reduce over (this is the
// "(ab)use" of the shape parser described on --shapes).
Softmax ParseSoftmax(const Shape& s) {
  CHECK(s.IsTuple());
  CHECK_EQ(s.tuple_shapes().size(), 2);

  Softmax result;
  result.input_shape = s.tuple_shapes(0);
  for (int64_t dim : s.tuple_shapes(1).dimensions()) {
    result.reduction_dims.push_back(dim);
  }
  return result;
}

// Benchmarks a softmax over `softmax.reduction_dims` of `softmax.input_shape`.
// The generated HLO upcasts the input to f32, applies the numerically stable
// max-subtraction formulation, and converts the result back to the input's
// element type.
void BM_Softmax(benchmark::State& state, const Softmax& softmax) {
  const std::string input_shape_str = softmax.input_shape.ToString();
  const std::string reduction_dims_str =
      absl::StrJoin(softmax.reduction_dims, ",");

  // All intermediate computation happens in f32 regardless of the input
  // dtype, so build f32 variants of the input and reduced shapes.
  const Shape input_shape_f32 =
      ShapeUtil::ChangeElementType(softmax.input_shape, F32);
  const std::string input_shape_f32_str = input_shape_f32.ToString();

  const Shape reduction_shape_f32 =
      ShapeUtil::ChangeElementType(softmax.GetReductionShape(), F32);
  const std::string reduction_shape_f32_str = reduction_shape_f32.ToString();

  // Dimensions that survive the reductions; these are the broadcast
  // dimensions used to expand the reduced values back to the input shape.
  std::vector<int64_t> kept_dims;
  const int64_t rank =
      static_cast<int64_t>(softmax.input_shape.dimensions().size());
  for (int64_t i = 0; i < rank; ++i) {
    const bool is_reduced =
        std::find(softmax.reduction_dims.begin(), softmax.reduction_dims.end(),
                  i) != softmax.reduction_dims.end();
    if (!is_reduced) {
      kept_dims.push_back(i);
    }
  }
  const std::string kept_dims_str = absl::StrJoin(kept_dims, ",");

  absl::string_view hlo_template = R"(
    HloModule softmax

    reducer_max {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT max = f32[] maximum(lhs, rhs)
    }

    reducer_add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT sum = f32[] add(lhs, rhs)
    }

    ENTRY main {
      input = $input_shape parameter(0)
      input_f32 = $input_shape_f32 convert(input)

      neg_inf = f32[] constant(-inf)
      max_val = $reduction_shape_f32 reduce(input_f32, neg_inf),
          dimensions={$reduction_dims}, to_apply=reducer_max
      max_br = $input_shape_f32 broadcast(max_val), dimensions={$kept_dims}

      input_centered = $input_shape_f32 subtract(input_f32, max_br)
      input_exp = $input_shape_f32 exponential(input_centered)

      zero = f32[] constant(0)
      sum_exp = $reduction_shape_f32 reduce(input_exp, zero),
          dimensions={$reduction_dims}, to_apply=reducer_add
      sum_exp_br = $input_shape_f32 broadcast(sum_exp), dimensions={$kept_dims}

      output_f32 = $input_shape_f32 divide(input_exp, sum_exp_br)
      ROOT output = $input_shape convert(output_f32)
    }
  )";

  // NOTE: `$input_shape` is a prefix of `$input_shape_f32`;
  // absl::StrReplaceAll prefers the longest pattern matching at a given
  // position, so `$input_shape_f32` is substituted correctly. A previous
  // `$dtype` rule was removed: the template contains no such placeholder.
  std::string hlo_data = absl::StrReplaceAll(
      hlo_template, {{"$input_shape", input_shape_str},
                     {"$input_shape_f32", input_shape_f32_str},
                     {"$reduction_shape_f32", reduction_shape_f32_str},
                     {"$reduction_dims", reduction_dims_str},
                     {"$kept_dims", kept_dims_str}});

  HloBenchmarkOptions benchmark_options;
  benchmark_options.num_executions = absl::GetFlag(FLAGS_num_executions);
  benchmark_options.aot_options = absl::GetFlag(FLAGS_aot_compiled_execution)
                                      ? GetAotCompilationOptions()
                                      : nullptr;

  TF_ASSERT_OK_AND_ASSIGN(
      auto module_and_iteration_literals,
      LoadHloModuleAndMaybeIterationLiteralsFromString(hlo_data));

  std::unique_ptr<HloModule> hlo_module =
      std::move(module_and_iteration_literals.first);

  // Materialize the per-iteration argument literals supplied by the loader.
  std::vector<Literal> args;
  args.reserve(module_and_iteration_literals.second->arguments_size());
  for (const auto& arg : module_and_iteration_literals.second->arguments()) {
    TF_ASSERT_OK_AND_ASSIGN(args.emplace_back(), Literal::CreateFromProto(arg));
  }

  // RunHloBenchmark takes non-owning pointers to the literals above.
  std::vector<Literal*> arg_ptrs;
  arg_ptrs.reserve(args.size());
  for (auto& arg : args) {
    arg_ptrs.push_back(&arg);
  }

  CHECK_OK(RunHloBenchmark(state, std::move(hlo_module), arg_ptrs,
                           benchmark_options));
}

void RegisterBenchmarks() {
std::vector<Shape> list = ParseShapeList(absl::GetFlag(FLAGS_shapes)).value();
for (const auto& s : list) {
Softmax softmax = ParseSoftmax(s);

benchmark::RegisterBenchmark(softmax.GetBenchmarkName(), BM_Softmax,
softmax)
->MeasureProcessCPUTime();
}
}

} // namespace

} // namespace xla::cpu

GTEST_API_ int main(int argc, char** argv) {
  // Benchmarks only run when --benchmark_filter is passed; without it this
  // binary is a no-op so plain test runners finish instantly.
  bool run_benchmarks = false;
  for (int i = 1; i < argc; ++i) {
    if (absl::StartsWith(argv[i], "--benchmark_filter=")) {
      run_benchmarks = true;
      break;
    }
  }
  if (!run_benchmarks) {
    return 0;
  }

  tsl::testing::InstallStacktraceHandler();
  ::benchmark::Initialize(&argc, argv);
  testing::InitGoogleTest(&argc, argv);
  xla::cpu::Set_XLA_FLAGS();
  xla::cpu::RegisterBenchmarks();
  ::benchmark::RunSpecifiedBenchmarks();
  return 0;
}
Loading