Commit d01d7f1

dsharletg authored and Google-ML-Automation committed
Rename rms_norm_benchmark_test to norm_benchmark_test
And add a softmax benchmark to it.

PiperOrigin-RevId: 901328747
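(Note: the added BM_Softmax exercises the numerically stable softmax formulation, softmax(x)_i = exp(x_i - max(x)) / sum_k exp(x_k - max(x)); the new HLO spells this out as reduce-max, subtract, exponential, reduce-add, divide.)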
1 parent 0dff2b0 commit d01d7f1

2 files changed

Lines changed: 137 additions & 53 deletions


xla/backends/cpu/benchmarks/BUILD

Lines changed: 3 additions & 2 deletions
@@ -183,12 +183,13 @@ xla_cc_test(
 )
 
 xla_cc_test(
-    name = "rms_norm_benchmark_test",
-    srcs = ["rms_norm_benchmark_test.cc"],
+    name = "norm_benchmark_test",
+    srcs = ["norm_benchmark_test.cc"],
     deps = [
         ":aot_benchmark_helper",
         ":hlo_benchmark_runner",
         "//xla:literal",
+        "//xla:literal_util",
         "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",

xla/backends/cpu/benchmarks/rms_norm_benchmark_test.cc renamed to xla/backends/cpu/benchmarks/norm_benchmark_test.cc

Lines changed: 134 additions & 51 deletions
@@ -16,6 +16,7 @@ limitations under the License.
 #include <cstdint>
 #include <cstdlib>
 #include <memory>
+#include <random>
 #include <string>
 #include <utility>
 #include <vector>
@@ -35,6 +36,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/parser/hlo_parser.h"
 #include "xla/literal.h"
+#include "xla/literal_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/shape.h"
@@ -72,7 +74,7 @@ void Set_XLA_FLAGS() {
   tsl::setenv("XLA_FLAGS", xla_flags.data(), /*overwrite=*/1);
 }
 
-struct RmsNorm {
+struct NormShape {
   Shape input_shape;
   std::vector<int64_t> reduction_dims;
 
@@ -83,53 +85,63 @@ struct RmsNorm {
     reduction_shape.DeleteDimensions(reduction_dims);
     return reduction_shape;
   }
-
-  std::string GetBenchmarkName() const {
-    return absl::StrCat("BM_RmsNorm/", input_shape.ToString(), "_{",
-                        absl::StrJoin(reduction_dims, ","), "}");
-  }
 };
 
-RmsNorm ParseRmsNorm(const Shape& s) {
-  RmsNorm rms_norm;
+NormShape ParseShape(const Shape& s) {
+  NormShape result;
   CHECK(s.IsTuple());
   CHECK_EQ(s.tuple_shapes().size(), 2);
 
-  rms_norm.input_shape = s.tuple_shapes(0);
+  result.input_shape = s.tuple_shapes(0);
 
   const Shape& dims_shape = s.tuple_shapes(1);
   absl::Span<const int64_t> dims = dims_shape.dimensions();
-  rms_norm.reduction_dims.assign(dims.begin(), dims.end());
+  result.reduction_dims.assign(dims.begin(), dims.end());
+
+  return result;
+}
 
-  return rms_norm;
+Literal GetRandomLiteral(const Shape& shape) {
+  double mean = 1.0f;
+  double stddev = 0.1f;
+  std::minstd_rand0 engine;
+  PrimitiveType dtype = shape.element_type();
+  switch (dtype) {
+    case F32:
+      return *LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, mean,
+                                                    stddev);
+    case BF16:
+      return *LiteralUtil::CreateRandomLiteral<BF16>(shape, &engine, mean,
+                                                     stddev);
+    default:
+      LOG(FATAL) << "Add dtype to the if-else block before use: " << dtype;
+  }
 }
 
-void BM_RmsNorm(benchmark::State& state, const RmsNorm& rms_norm) {
-  const std::string input_shape_str = rms_norm.input_shape.ToString();
+void BM_RmsNorm(benchmark::State& state, const NormShape& shape) {
+  const std::string input_shape_str = shape.input_shape.ToString();
   const std::string reduction_dims_str =
-      absl::StrJoin(rms_norm.reduction_dims, ",");
+      absl::StrJoin(shape.reduction_dims, ",");
   const std::string dtype_str =
-      primitive_util::LowercasePrimitiveTypeName(rms_norm.GetDType());
-  const std::string reduction_shape_str =
-      rms_norm.GetReductionShape().ToString();
+      primitive_util::LowercasePrimitiveTypeName(shape.GetDType());
+  const std::string reduction_shape_str = shape.GetReductionShape().ToString();
 
-  Shape input_shape_f32 =
-      ShapeUtil::ChangeElementType(rms_norm.input_shape, F32);
+  Shape input_shape_f32 = ShapeUtil::ChangeElementType(shape.input_shape, F32);
   const std::string input_shape_f32_str = input_shape_f32.ToString();
 
   Shape reduction_shape_f32 =
-      ShapeUtil::ChangeElementType(rms_norm.GetReductionShape(), F32);
+      ShapeUtil::ChangeElementType(shape.GetReductionShape(), F32);
   const std::string reduction_shape_f32_str = reduction_shape_f32.ToString();
 
   int64_t reduction_size = 1;
-  for (int64_t d : rms_norm.reduction_dims) {
-    reduction_size *= rms_norm.input_shape.dimensions(d);
+  for (int64_t d : shape.reduction_dims) {
+    reduction_size *= shape.input_shape.dimensions(d);
   }
 
   std::vector<int64_t> kept_dims;
-  for (int64_t i = 0; i < rms_norm.input_shape.dimensions().size(); ++i) {
+  for (int64_t i = 0; i < shape.input_shape.dimensions().size(); ++i) {
     bool is_reduced = false;
-    for (int64_t d : rms_norm.reduction_dims) {
+    for (int64_t d : shape.reduction_dims) {
       if (i == d) {
         is_reduced = true;
         break;
@@ -141,7 +153,7 @@ void BM_RmsNorm(benchmark::State& state, const RmsNorm& rms_norm) {
   }
   const std::string kept_dims_str = absl::StrJoin(kept_dims, ",");
 
-  absl::string_view hlo_template = R"(
+  absl::string_view hlo = R"(
 reducer_add {
   lhs = f32[] parameter(0)
   rhs = f32[] parameter(1)
@@ -175,51 +187,122 @@ void BM_RmsNorm(benchmark::State& state, const RmsNorm& rms_norm) {
 }
 )";
 
-  std::string hlo_data = absl::StrReplaceAll(
-      hlo_template, {{"$input_shape", input_shape_str},
-                     {"$input_shape_f32", input_shape_f32_str},
-                     {"$reduction_shape_f32", reduction_shape_f32_str},
-                     {"$reduction_dims", reduction_dims_str},
-                     {"$reduction_size", absl::StrCat(reduction_size)},
-                     {"$kept_dims", kept_dims_str},
-                     {"$dtype", dtype_str}});
-
   HloBenchmarkOptions benchmark_options;
   benchmark_options.num_executions = absl::GetFlag(FLAGS_num_executions);
   benchmark_options.aot_options = absl::GetFlag(FLAGS_aot_compiled_execution)
                                       ? GetAotCompilationOptions()
                                       : nullptr;
 
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto module_and_iteration_literals,
-      LoadHloModuleAndMaybeIterationLiteralsFromString(hlo_data));
+  Literal input = GetRandomLiteral(shape.input_shape);
+
+  CHECK_OK(RunHloBenchmark(state, hlo, {&input},
+                           {{"$input_shape", input_shape_str},
+                            {"$input_shape_f32", input_shape_f32_str},
+                            {"$reduction_shape_f32", reduction_shape_f32_str},
+                            {"$reduction_dims", reduction_dims_str},
+                            {"$reduction_size", absl::StrCat(reduction_size)},
+                            {"$kept_dims", kept_dims_str},
+                            {"$dtype", dtype_str}},
+                           benchmark_options));
+}
+
+void BM_Softmax(benchmark::State& state, const NormShape& shape) {
+  const std::string input_shape_str = shape.input_shape.ToString();
+  const std::string reduction_dims_str =
+      absl::StrJoin(shape.reduction_dims, ",");
+  const std::string dtype_str =
+      primitive_util::LowercasePrimitiveTypeName(shape.GetDType());
+
+  Shape input_shape_f32 = ShapeUtil::ChangeElementType(shape.input_shape, F32);
+  const std::string input_shape_f32_str = input_shape_f32.ToString();
+
+  Shape reduction_shape_f32 =
+      ShapeUtil::ChangeElementType(shape.GetReductionShape(), F32);
+  const std::string reduction_shape_f32_str = reduction_shape_f32.ToString();
+
+  std::vector<int64_t> kept_dims;
+  for (int i = 0; i < shape.input_shape.dimensions().size(); ++i) {
+    bool is_reduced = false;
+    for (int64_t d : shape.reduction_dims) {
+      if (i == d) {
+        is_reduced = true;
+        break;
+      }
+    }
+    if (!is_reduced) {
+      kept_dims.push_back(i);
+    }
+  }
+  const std::string kept_dims_str = absl::StrJoin(kept_dims, ",");
+
+  absl::string_view hlo = R"(
+HloModule softmax
 
-  std::unique_ptr<HloModule> hlo_module =
-      std::move(module_and_iteration_literals.first);
+reducer_max {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT max = f32[] maximum(lhs, rhs)
+}
 
-  std::vector<Literal> args;
-  args.reserve(module_and_iteration_literals.second->arguments_size());
-  for (const auto& arg : module_and_iteration_literals.second->arguments()) {
-    TF_ASSERT_OK_AND_ASSIGN(args.emplace_back(), Literal::CreateFromProto(arg));
+reducer_add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
 }
 
-  std::vector<Literal*> arg_ptrs;
-  arg_ptrs.reserve(args.size());
-  for (auto& arg : args) {
-    arg_ptrs.push_back(&arg);
+ENTRY main {
+  input = $input_shape parameter(0)
+  input_f32 = $input_shape_f32 convert(input)
+
+  neg_inf = f32[] constant(-inf)
+  max_val = $reduction_shape_f32 reduce(input_f32, neg_inf),
+      dimensions={$reduction_dims}, to_apply=reducer_max
+  max_br = $input_shape_f32 broadcast(max_val), dimensions={$kept_dims}
+
+  input_centered = $input_shape_f32 subtract(input_f32, max_br)
+  input_exp = $input_shape_f32 exponential(input_centered)
+
+  zero = f32[] constant(0)
+  sum_exp = $reduction_shape_f32 reduce(input_exp, zero),
+      dimensions={$reduction_dims}, to_apply=reducer_add
+  sum_exp_br = $input_shape_f32 broadcast(sum_exp), dimensions={$kept_dims}
+
+  output_f32 = $input_shape_f32 divide(input_exp, sum_exp_br)
+  ROOT output = $input_shape convert(output_f32)
 }
+)";
+
+  HloBenchmarkOptions benchmark_options;
+  benchmark_options.num_executions = absl::GetFlag(FLAGS_num_executions);
+  benchmark_options.aot_options = absl::GetFlag(FLAGS_aot_compiled_execution)
+                                      ? GetAotCompilationOptions()
+                                      : nullptr;
+
+  Literal input = GetRandomLiteral(shape.input_shape);
 
-  CHECK_OK(RunHloBenchmark(state, std::move(hlo_module), arg_ptrs,
+  CHECK_OK(RunHloBenchmark(state, hlo, {&input},
+                           {{"$input_shape", input_shape_str},
+                            {"$input_shape_f32", input_shape_f32_str},
+                            {"$reduction_shape_f32", reduction_shape_f32_str},
+                            {"$reduction_dims", reduction_dims_str},
+                            {"$kept_dims", kept_dims_str},
+                            {"$dtype", dtype_str}},
                            benchmark_options));
 }
 
 void RegisterBenchmarks() {
   std::vector<Shape> list = ParseShapeList(absl::GetFlag(FLAGS_shapes)).value();
   for (const auto& s : list) {
-    RmsNorm rms_norm = ParseRmsNorm(s);
+    NormShape shape = ParseShape(s);
+
+    std::string shape_str =
+        absl::StrCat(shape.input_shape.ToString(), "_{",
+                     absl::StrJoin(shape.reduction_dims, ","), "}");
+
+    benchmark::RegisterBenchmark("BM_RmsNorm/" + shape_str, BM_RmsNorm, shape)
+        ->MeasureProcessCPUTime();
 
-    benchmark::RegisterBenchmark(rms_norm.GetBenchmarkName(), BM_RmsNorm,
-                                 rms_norm)
+    benchmark::RegisterBenchmark("BM_Softmax/" + shape_str, BM_Softmax, shape)
         ->MeasureProcessCPUTime();
   }
 }
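For readers skimming the diff, here is a minimal standalone C++ sketch of the computation the softmax HLO above expresses, assuming a row-major [rows x cols] buffer reduced over the last axis (the helper name and the 2-D layout are illustrative, not part of this commit):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Numerically stable softmax over the last axis, mirroring the HLO graph:
// reduce-max, subtract, exponential, reduce-add, divide.
void SoftmaxLastAxis(const std::vector<float>& in, std::vector<float>& out,
                     size_t rows, size_t cols) {
  out.resize(in.size());
  for (size_t r = 0; r < rows; ++r) {
    const float* x = in.data() + r * cols;
    float* y = out.data() + r * cols;
    // max_val = reduce(input_f32, -inf), to_apply=reducer_max
    float max_val = -std::numeric_limits<float>::infinity();
    for (size_t c = 0; c < cols; ++c) max_val = std::max(max_val, x[c]);
    // input_exp = exponential(subtract(input_f32, max_br));
    // sum_exp = reduce(input_exp, 0), to_apply=reducer_add
    float sum_exp = 0.0f;
    for (size_t c = 0; c < cols; ++c) {
      y[c] = std::exp(x[c] - max_val);
      sum_exp += y[c];
    }
    // output = divide(input_exp, sum_exp_br)
    for (size_t c = 0; c < cols; ++c) y[c] /= sum_exp;
  }
}

To try just the new benchmark, an invocation along the lines of bazel run -c opt //xla/backends/cpu/benchmarks:norm_benchmark_test -- --benchmark_filter=BM_Softmax should work (exact flags unverified); the harness also reads --shapes, --num_executions, and --aot_compiled_execution, as seen in the test above.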
