@@ -94,60 +94,75 @@ def _init_test_cases(self):
   def test_gmm(self):
     met.clear_all()
     jax.config.update('jax_default_matmul_precision', "highest")
+    compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
     gmm_funcs = [
-        gmm, torch.ops.xla.gmm,
-        torch.compile(torch.ops.xla.gmm, backend="openxla")
+        gmm,
+        torch.ops.xla.gmm,
+        compiled_gmm,
     ]

     self._init_test_cases()
-    for gmm_func in gmm_funcs:
-      for test_case in self.tests_cases:
-        num_groups = test_case['num_groups']
-        k = test_case['k']
-        m = test_case['m']
-        n = test_case['n']
-        lhs_dtype = rhs_dtype = test_case['dtype']
-
-        lhs = torch.rand(m, k, dtype=lhs_dtype)
-        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for test_cache in [False, True]:
+      for gmm_func in gmm_funcs:
+        for test_case in self.tests_cases:
+          num_groups = test_case['num_groups']
+          k = test_case['k']
+          m = test_case['m']
+          n = test_case['n']
+          lhs_dtype = rhs_dtype = test_case['dtype']
+
+          lhs = torch.rand(m, k, dtype=lhs_dtype)
+          rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+          group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+          ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+          out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          # The torch.compile'd version of gmm caches the payload in the dynamo
+          # layer, hence it won't trigger the trace_pallas cache.
+          if test_cache and gmm_func != compiled_gmm:
+            met.clear_counters()
+            # Execute the same gmm func; it is expected to hit the cache.
+            out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+            self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+          self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
     jax.config.update('jax_default_matmul_precision', "default")

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_bf16(self):
     met.clear_all()

-    gmm_funcs = [
-        gmm, torch.ops.xla.gmm,
-        torch.compile(torch.ops.xla.gmm, backend="openxla")
-    ]
+    compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
+    gmm_funcs = [gmm, torch.ops.xla.gmm, compiled_gmm]

     self._init_test_cases()
-    for gmm_func in gmm_funcs:
-      for test_case in self.tests_cases:
-        num_groups = test_case['num_groups']
-        k = test_case['k']
-        m = test_case['m']
-        n = test_case['n']
-        lhs_dtype = rhs_dtype = torch.bfloat16
-
-        lhs = torch.rand(m, k, dtype=lhs_dtype)
-        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-
-        self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for test_cache in [False, True]:
+      for gmm_func in gmm_funcs:
+        for test_case in self.tests_cases:
+          num_groups = test_case['num_groups']
+          k = test_case['k']
+          m = test_case['m']
+          n = test_case['n']
+          lhs_dtype = rhs_dtype = torch.bfloat16
+
+          lhs = torch.rand(m, k, dtype=lhs_dtype)
+          rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+          group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+          ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+          out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          # The torch.compile'd version of gmm caches the payload in the dynamo
+          # layer, hence it won't trigger the trace_pallas cache.
+          if test_cache and gmm_func != compiled_gmm:
+            met.clear_counters()
+            # Execute the same gmm func; it is expected to hit the cache.
+            out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+            self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+          self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_make_group_metadata(self):
@@ -313,47 +328,59 @@ def test_tgmm(self):
     jax.config.update('jax_default_matmul_precision', "highest")

     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = test_case['dtype']
-
-      lhs = torch.rand(k, m, dtype=lhs_dtype)
-      rhs = torch.rand(m, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+    for test_cache in [False, True]:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = test_case['dtype']

-      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+        lhs = torch.rand(k, m, dtype=lhs_dtype)
+        rhs = torch.rand(m, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+        out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        if test_cache:
+          met.clear_counters()
+          # Execute the same tgmm func; it is expected to hit the cache.
+          out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure tgmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
     jax.config.update('jax_default_matmul_precision', "default")

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tgmm_bf16(self):
     met.clear_all()

     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = torch.bfloat16
-
-      lhs = torch.rand(k, m, dtype=lhs_dtype)
-      rhs = torch.rand(m, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+    for test_cache in [False, True]:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = torch.bfloat16

-      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+        lhs = torch.rand(k, m, dtype=lhs_dtype)
+        rhs = torch.rand(m, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+        out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        if test_cache:
+          met.clear_counters()
+          # Execute the same tgmm func; it is expected to hit the cache.
+          out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure tgmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_backward(self):
@@ -365,25 +392,31 @@ def test_gmm_backward(self):
       n = test_case['n']
       lhs_dtype = rhs_dtype = torch.bfloat16

-      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      lhs.retain_grad()
-      rhs.retain_grad()
+      for test_cache in [False, True]:
+        met.clear_all()
+        lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        lhs.retain_grad()
+        rhs.retain_grad()

-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-      ref_out.sum().backward()
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+        ref_out.sum().backward()

-      ref_out_backward = torch.ones_like(ref_out)
-      grad_lhs, grad_rhs = gmm_backward(
-          ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
-          group_sizes.to("xla"))
+        ref_out_backward = torch.ones_like(ref_out)
+        grad_lhs, grad_rhs = gmm_backward(
+            ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
+            group_sizes.to("xla"))
+        # same gmm/tgmm was run for the `test_cache=False` case so the
+        # cache should be populated now
+        if test_cache:
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 2)

-      self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
-      self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
+        self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
+        self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))

     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_backward_2(self):
@@ -420,7 +453,7 @@ def test_gmm_backward_2(self):
     self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_backward_3(self):
@@ -458,7 +491,32 @@ def test_gmm_backward_3(self):
     self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

     # Make sure gmm doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
+    self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_cache_miss(self):
+    met.clear_all()
+    jax.config.update('jax_default_matmul_precision', "highest")
+
+    self._init_test_cases()
+    test_case = self.tests_cases[-1]
+    # Make sure the cache misses for different input shapes, dtypes and tilings.
+    met.clear_all()
+    for mul_factor in [[2, 1, 1, 1], [1, 2, 1, 1], [2, 1, 2, 1], [2, 1, 1, 2]]:
+      for dtype in [torch.float32, torch.bfloat16]:
+        for tiling in [(128, 128, 128), (256, 256, 256)]:
+          num_groups = test_case['num_groups'] * mul_factor[0]
+          k = test_case['k'] * mul_factor[1]
+          m = test_case['m'] * mul_factor[2]
+          n = test_case['n'] * mul_factor[3]
+          lhs_dtype = rhs_dtype = dtype
+
+          lhs = torch.rand(m, k, dtype=lhs_dtype)
+          rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+          group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+
+          out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"), tiling)
+          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), None)


 if __name__ == '__main__':