Skip to content

Commit 7d48c32

Browse files
author
The paxml Authors
committed
Update the flops computation by counting add and multiply separately.
PiperOrigin-RevId: 534193569
1 parent cacfd9f commit 7d48c32

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paxml/tools/model_analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
3131
##### bert.BertAdamL4H128 #####
3232
33-
GFLOPS = 23.79
33+
GFLOPS = 47.58
3434
3535
3636
**************
@@ -174,7 +174,7 @@ def get_flops(self, fprop_func: str = 'compute_predictions'):
174174
client = jax.lib.xla_bridge.get_backend()
175175
m = jax.xla_computation(model_fprop)(datum).as_hlo_module()
176176
analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, m)
177-
flops = analysis['flops'] / 2.0
177+
flops = analysis['flops']
178178
gflops = flops / 1e9
179179

180180
print('\n' + '#' * 5 + ' ' + self.exp_name + ' ' + '#' * 5)

0 commit comments

Comments
 (0)