@@ -16,6 +16,7 @@ limitations under the License.
1616#include < cstdint>
1717#include < cstdlib>
1818#include < memory>
19+ #include < random>
1920#include < string>
2021#include < utility>
2122#include < vector>
@@ -35,6 +36,7 @@ limitations under the License.
3536#include " xla/hlo/ir/hlo_module.h"
3637#include " xla/hlo/parser/hlo_parser.h"
3738#include " xla/literal.h"
39+ #include " xla/literal_util.h"
3840#include " xla/primitive_util.h"
3941#include " xla/service/hlo.pb.h"
4042#include " xla/shape.h"
@@ -72,7 +74,7 @@ void Set_XLA_FLAGS() {
7274 tsl::setenv (" XLA_FLAGS" , xla_flags.data (), /* overwrite=*/ 1 );
7375}
7476
75- struct RmsNorm {
77+ struct NormShape {
7678 Shape input_shape;
7779 std::vector<int64_t > reduction_dims;
7880
@@ -83,53 +85,63 @@ struct RmsNorm {
8385 reduction_shape.DeleteDimensions (reduction_dims);
8486 return reduction_shape;
8587 }
86-
87- std::string GetBenchmarkName () const {
88- return absl::StrCat (" BM_RmsNorm/" , input_shape.ToString (), " _{" ,
89- absl::StrJoin (reduction_dims, " ," ), " }" );
90- }
9188};
9289
93- RmsNorm ParseRmsNorm (const Shape& s) {
94- RmsNorm rms_norm ;
90+ NormShape ParseShape (const Shape& s) {
91+ NormShape result ;
9592 CHECK (s.IsTuple ());
9693 CHECK_EQ (s.tuple_shapes ().size (), 2 );
9794
98- rms_norm .input_shape = s.tuple_shapes (0 );
95+ result .input_shape = s.tuple_shapes (0 );
9996
10097 const Shape& dims_shape = s.tuple_shapes (1 );
10198 absl::Span<const int64_t > dims = dims_shape.dimensions ();
102- rms_norm.reduction_dims .assign (dims.begin (), dims.end ());
99+ result.reduction_dims .assign (dims.begin (), dims.end ());
100+
101+ return result;
102+ }
103103
104- return rms_norm;
104+ Literal GetRandomLiteral (const Shape& shape) {
105+ double mean = 1 .0f ;
106+ double stddev = 0 .1f ;
107+ std::minstd_rand0 engine;
108+ PrimitiveType dtype = shape.element_type ();
109+ switch (dtype) {
110+ case F32:
111+ return *LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, mean,
112+ stddev);
113+ case BF16:
114+ return *LiteralUtil::CreateRandomLiteral<BF16>(shape, &engine, mean,
115+ stddev);
116+ default :
117+ LOG (FATAL) << " Add dtype to the if-else block before use: " << dtype;
118+ }
105119}
106120
107- void BM_RmsNorm (benchmark::State& state, const RmsNorm& rms_norm ) {
108- const std::string input_shape_str = rms_norm .input_shape .ToString ();
121+ void BM_RmsNorm (benchmark::State& state, const NormShape& shape ) {
122+ const std::string input_shape_str = shape .input_shape .ToString ();
109123 const std::string reduction_dims_str =
110- absl::StrJoin (rms_norm .reduction_dims , " ," );
124+ absl::StrJoin (shape .reduction_dims , " ," );
111125 const std::string dtype_str =
112- primitive_util::LowercasePrimitiveTypeName (rms_norm.GetDType ());
113- const std::string reduction_shape_str =
114- rms_norm.GetReductionShape ().ToString ();
126+ primitive_util::LowercasePrimitiveTypeName (shape.GetDType ());
127+ const std::string reduction_shape_str = shape.GetReductionShape ().ToString ();
115128
116- Shape input_shape_f32 =
117- ShapeUtil::ChangeElementType (rms_norm.input_shape , F32);
129+ Shape input_shape_f32 = ShapeUtil::ChangeElementType (shape.input_shape , F32);
118130 const std::string input_shape_f32_str = input_shape_f32.ToString ();
119131
120132 Shape reduction_shape_f32 =
121- ShapeUtil::ChangeElementType (rms_norm .GetReductionShape (), F32);
133+ ShapeUtil::ChangeElementType (shape .GetReductionShape (), F32);
122134 const std::string reduction_shape_f32_str = reduction_shape_f32.ToString ();
123135
124136 int64_t reduction_size = 1 ;
125- for (int64_t d : rms_norm .reduction_dims ) {
126- reduction_size *= rms_norm .input_shape .dimensions (d);
137+ for (int64_t d : shape .reduction_dims ) {
138+ reduction_size *= shape .input_shape .dimensions (d);
127139 }
128140
129141 std::vector<int64_t > kept_dims;
130- for (int64_t i = 0 ; i < rms_norm .input_shape .dimensions ().size (); ++i) {
142+ for (int64_t i = 0 ; i < shape .input_shape .dimensions ().size (); ++i) {
131143 bool is_reduced = false ;
132- for (int64_t d : rms_norm .reduction_dims ) {
144+ for (int64_t d : shape .reduction_dims ) {
133145 if (i == d) {
134146 is_reduced = true ;
135147 break ;
@@ -141,7 +153,7 @@ void BM_RmsNorm(benchmark::State& state, const RmsNorm& rms_norm) {
141153 }
142154 const std::string kept_dims_str = absl::StrJoin (kept_dims, " ," );
143155
144- absl::string_view hlo_template = R"(
156+ absl::string_view hlo = R"(
145157 reducer_add {
146158 lhs = f32[] parameter(0)
147159 rhs = f32[] parameter(1)
@@ -175,51 +187,122 @@ void BM_RmsNorm(benchmark::State& state, const RmsNorm& rms_norm) {
175187 }
176188 )" ;
177189
178- std::string hlo_data = absl::StrReplaceAll (
179- hlo_template, {{" $input_shape" , input_shape_str},
180- {" $input_shape_f32" , input_shape_f32_str},
181- {" $reduction_shape_f32" , reduction_shape_f32_str},
182- {" $reduction_dims" , reduction_dims_str},
183- {" $reduction_size" , absl::StrCat (reduction_size)},
184- {" $kept_dims" , kept_dims_str},
185- {" $dtype" , dtype_str}});
186-
187190 HloBenchmarkOptions benchmark_options;
188191 benchmark_options.num_executions = absl::GetFlag (FLAGS_num_executions);
189192 benchmark_options.aot_options = absl::GetFlag (FLAGS_aot_compiled_execution)
190193 ? GetAotCompilationOptions ()
191194 : nullptr ;
192195
193- TF_ASSERT_OK_AND_ASSIGN (
194- auto module_and_iteration_literals,
195- LoadHloModuleAndMaybeIterationLiteralsFromString (hlo_data));
196+ Literal input = GetRandomLiteral (shape.input_shape );
197+
198+ CHECK_OK (RunHloBenchmark (state, hlo, {&input},
199+ {{" $input_shape" , input_shape_str},
200+ {" $input_shape_f32" , input_shape_f32_str},
201+ {" $reduction_shape_f32" , reduction_shape_f32_str},
202+ {" $reduction_dims" , reduction_dims_str},
203+ {" $reduction_size" , absl::StrCat (reduction_size)},
204+ {" $kept_dims" , kept_dims_str},
205+ {" $dtype" , dtype_str}},
206+ benchmark_options));
207+ }
208+
209+ void BM_Softmax (benchmark::State& state, const NormShape& shape) {
210+ const std::string input_shape_str = shape.input_shape .ToString ();
211+ const std::string reduction_dims_str =
212+ absl::StrJoin (shape.reduction_dims , " ," );
213+ const std::string dtype_str =
214+ primitive_util::LowercasePrimitiveTypeName (shape.GetDType ());
215+
216+ Shape input_shape_f32 = ShapeUtil::ChangeElementType (shape.input_shape , F32);
217+ const std::string input_shape_f32_str = input_shape_f32.ToString ();
218+
219+ Shape reduction_shape_f32 =
220+ ShapeUtil::ChangeElementType (shape.GetReductionShape (), F32);
221+ const std::string reduction_shape_f32_str = reduction_shape_f32.ToString ();
222+
223+ std::vector<int64_t > kept_dims;
224+ for (int i = 0 ; i < shape.input_shape .dimensions ().size (); ++i) {
225+ bool is_reduced = false ;
226+ for (int64_t d : shape.reduction_dims ) {
227+ if (i == d) {
228+ is_reduced = true ;
229+ break ;
230+ }
231+ }
232+ if (!is_reduced) {
233+ kept_dims.push_back (i);
234+ }
235+ }
236+ const std::string kept_dims_str = absl::StrJoin (kept_dims, " ," );
237+
238+ absl::string_view hlo = R"(
239+ HloModule softmax
196240
197- std::unique_ptr<HloModule> hlo_module =
198- std::move (module_and_iteration_literals.first );
241+ reducer_max {
242+ lhs = f32[] parameter(0)
243+ rhs = f32[] parameter(1)
244+ ROOT max = f32[] maximum(lhs, rhs)
245+ }
199246
200- std::vector<Literal> args;
201- args. reserve (module_and_iteration_literals. second -> arguments_size ());
202- for ( const auto & arg : module_and_iteration_literals. second -> arguments ()) {
203- TF_ASSERT_OK_AND_ASSIGN (args. emplace_back (), Literal::CreateFromProto (arg));
247+ reducer_add {
248+ lhs = f32[] parameter(0)
249+ rhs = f32[] parameter(1)
250+ ROOT sum = f32[] add(lhs, rhs)
204251 }
205252
206- std::vector<Literal*> arg_ptrs;
207- arg_ptrs.reserve (args.size ());
208- for (auto & arg : args) {
209- arg_ptrs.push_back (&arg);
253+ ENTRY main {
254+ input = $input_shape parameter(0)
255+ input_f32 = $input_shape_f32 convert(input)
256+
257+ neg_inf = f32[] constant(-inf)
258+ max_val = $reduction_shape_f32 reduce(input_f32, neg_inf),
259+ dimensions={$reduction_dims}, to_apply=reducer_max
260+ max_br = $input_shape_f32 broadcast(max_val), dimensions={$kept_dims}
261+
262+ input_centered = $input_shape_f32 subtract(input_f32, max_br)
263+ input_exp = $input_shape_f32 exponential(input_centered)
264+
265+ zero = f32[] constant(0)
266+ sum_exp = $reduction_shape_f32 reduce(input_exp, zero),
267+ dimensions={$reduction_dims}, to_apply=reducer_add
268+ sum_exp_br = $input_shape_f32 broadcast(sum_exp), dimensions={$kept_dims}
269+
270+ output_f32 = $input_shape_f32 divide(input_exp, sum_exp_br)
271+ ROOT output = $input_shape convert(output_f32)
210272 }
273+ )" ;
274+
275+ HloBenchmarkOptions benchmark_options;
276+ benchmark_options.num_executions = absl::GetFlag (FLAGS_num_executions);
277+ benchmark_options.aot_options = absl::GetFlag (FLAGS_aot_compiled_execution)
278+ ? GetAotCompilationOptions ()
279+ : nullptr ;
280+
281+ Literal input = GetRandomLiteral (shape.input_shape );
211282
212- CHECK_OK (RunHloBenchmark (state, std::move (hlo_module), arg_ptrs,
283+ CHECK_OK (RunHloBenchmark (state, hlo, {&input},
284+ {{" $input_shape" , input_shape_str},
285+ {" $input_shape_f32" , input_shape_f32_str},
286+ {" $reduction_shape_f32" , reduction_shape_f32_str},
287+ {" $reduction_dims" , reduction_dims_str},
288+ {" $kept_dims" , kept_dims_str},
289+ {" $dtype" , dtype_str}},
213290 benchmark_options));
214291}
215292
216293void RegisterBenchmarks () {
217294 std::vector<Shape> list = ParseShapeList (absl::GetFlag (FLAGS_shapes)).value ();
218295 for (const auto & s : list) {
219- RmsNorm rms_norm = ParseRmsNorm (s);
296+ NormShape shape = ParseShape (s);
297+
298+ std::string shape_str =
299+ absl::StrCat (shape.input_shape .ToString (), " _{" ,
300+ absl::StrJoin (shape.reduction_dims , " ," ), " }" );
301+
302+ benchmark::RegisterBenchmark (" BM_RmsNorm/" + shape_str, BM_RmsNorm, shape)
303+ ->MeasureProcessCPUTime ();
220304
221- benchmark::RegisterBenchmark (rms_norm.GetBenchmarkName (), BM_RmsNorm,
222- rms_norm)
305+ benchmark::RegisterBenchmark (" BM_Softmax/" + shape_str, BM_Softmax, shape)
223306 ->MeasureProcessCPUTime ();
224307 }
225308}
0 commit comments