1515import pytest
1616import tilus
1717import torch
18- from tilus import int32
18+ from tilus import boolean , int32
1919from tilus .ir .layout import RegisterLayout , register_layout
2020from tilus .ir .layout .ops import spatial
2121
@@ -41,6 +41,19 @@ def __call__(self, out_ptr: ~int32):
4141 self .store_global (g_out , b , offsets = [0 , 0 ], dims = [0 , 1 ])
4242
4343
class AnyAllInstExample(tilus.Script):
    """Script exercising the ``any`` and ``all`` instructions.

    Reads a 32x32 int32 view behind ``x_ptr`` and writes two booleans to the
    2-element view behind ``y_ptr``: index 0 receives ``any(x != 0)`` and
    index 1 receives ``all(x != 0)``.
    """

    def __call__(self, x_ptr: ~int32, y_ptr: ~boolean):
        # Both launch attributes are set to 1 for this example.
        self.attrs.blocks = 1
        self.attrs.warps = 1

        # Typed views over the raw input and output pointers.
        x_view = self.global_view(ptr=x_ptr, dtype=int32, shape=(32, 32))
        y_view = self.global_view(ptr=y_ptr, dtype=boolean, shape=[2])

        # Load the whole 32x32 input starting at the origin.
        x_tile = self.load_global(x_view, offsets=[0, 0], shape=[32, 32])

        # The `any` result goes to y[0].
        # NOTE(review): dims=[] presumably marks the source as a scalar-like
        # value with no mapped dimensions -- confirm against store_global.
        any_nonzero = self.any(x_tile != 0)
        self.store_global(y_view, src=any_nonzero, offsets=[0], dims=[])

        # The `all` result goes to y[1].
        all_nonzero = self.all(x_tile != 0)
        self.store_global(y_view, src=all_nonzero, offsets=[1], dims=[])
55+
56+
4457@pytest .mark .parametrize ("dim" , [0 , 1 ])
4558@pytest .mark .parametrize (
4659 "layout" ,
@@ -69,3 +82,19 @@ def test_reduce_instruction(dim: int, layout: RegisterLayout):
6982 demo = ReduceKernelExample (layout , dim = dim )
7083 demo (actual )
7184 assert torch .allclose (actual , expected ), f"Failed for layout { layout } and dim { dim } "
85+
86+
def test_any_all_reduce_instruction():
    """Check the two outputs of ``AnyAllInstExample`` on three inputs.

    The inputs are an all-zero tensor, an all-one tensor, and a random 0/1
    tensor pinned to contain at least one nonzero and at least one zero
    entry. Each expected output is a 2-element bool tensor compared against
    what the kernel writes.
    """
    kernel = AnyAllInstExample()

    def make_expected(first, second):
        # Expected 2-element boolean output, placed on the GPU.
        return torch.asarray([first, second], dtype=torch.bool).cuda()

    all_zeros = torch.zeros((32, 32), dtype=torch.int32).cuda()
    expect_zeros = make_expected(False, False)

    all_ones = torch.ones((32, 32), dtype=torch.int32).cuda()
    expect_ones = make_expected(True, True)

    mixed = torch.randint(0, 2, size=(32, 32), dtype=torch.int32).cuda()
    # Pin two entries so the random tensor is never all-zero or all-nonzero.
    mixed[0, 0] = 1
    mixed[0, 1] = 0
    expect_mixed = make_expected(True, False)

    cases = [
        (all_zeros, expect_zeros),
        (all_ones, expect_ones),
        (mixed, expect_mixed),
    ]
    for x, y in cases:
        # The kernel is expected to overwrite both elements of the buffer.
        y_actual = torch.empty_like(y)
        kernel(x, y_actual)
        assert torch.allclose(y_actual, y), f"Failed for x={ x } and y={ y } , y_actual={ y_actual } "
0 commit comments