@@ -28,7 +28,7 @@ def unpack_indices(dual_object):
2828
2929@flow .unittest .skip_unless_1n1d () 
3030class  TestMaxPooling (flow .unittest .TestCase ):
31-     @autotest (n = 100 ,  auto_backward = False ) 
31+     @autotest (auto_backward = False ) 
3232    def  test_maxpool1d_with_random_data (test_case ):
3333        return_indices  =  random ().to (bool ).value ()
3434        m  =  torch .nn .MaxPool1d (
@@ -50,7 +50,7 @@ def test_maxpool1d_with_random_data(test_case):
5050        else :
5151            return  y , y .sum ().backward ()
5252
53-     @autotest (n = 100 ,  auto_backward = False ) 
53+     @autotest (auto_backward = False ) 
5454    def  test_maxpool2d_with_random_data (test_case ):
5555        return_indices  =  random ().to (bool ).value ()
5656        m  =  torch .nn .MaxPool2d (
@@ -74,7 +74,7 @@ def test_maxpool2d_with_random_data(test_case):
7474        else :
7575            return  y , y .sum ().backward ()
7676
77-     @autotest (n = 100 ,  auto_backward = False ) 
77+     @autotest (auto_backward = False ) 
7878    def  test_maxpool3d_with_random_data (test_case ):
7979        return_indices  =  random ().to (bool ).value ()
8080        m  =  torch .nn .MaxPool3d (
@@ -99,5 +99,72 @@ def test_maxpool3d_with_random_data(test_case):
9999            return  y , y .sum ().backward ()
100100
101101
102+ @flow .unittest .skip_unless_1n1d () 
103+ class  TestMaxPoolingFunctional (flow .unittest .TestCase ):
104+     @autotest (auto_backward = False ) 
105+     def  test_maxpool1d_with_random_data (test_case ):
106+         return_indices  =  random ().to (bool ).value ()
107+         device  =  random_device ()
108+         x  =  random_pytorch_tensor (ndim = 3 , dim2 = random (20 , 22 )).to (device )
109+         y  =  torch .nn .functional .max_pool1d (
110+             x ,
111+             kernel_size = random (4 , 6 ).to (int ),
112+             stride = random (1 , 3 ).to (int ) |  nothing (),
113+             padding = random (1 , 3 ).to (int ) |  nothing (),
114+             dilation = random (2 , 4 ).to (int ) |  nothing (),
115+             ceil_mode = random ().to (bool ),
116+             return_indices = return_indices ,
117+         )
118+ 
119+         if  return_indices :
120+             return  unpack_indices (y )
121+         else :
122+             return  y , y .sum ().backward ()
123+ 
124+     @autotest (auto_backward = False ) 
125+     def  test_maxpool2d_with_random_data (test_case ):
126+         return_indices  =  random ().to (bool ).value ()
127+         device  =  random_device ()
128+         x  =  random_pytorch_tensor (ndim = 4 , dim2 = random (20 , 22 ), dim3 = random (20 , 22 )).to (
129+             device 
130+         )
131+         y  =  torch .nn .functional .max_pool2d (
132+             x ,
133+             kernel_size = random (4 , 6 ).to (int ),
134+             stride = random (1 , 3 ).to (int ) |  nothing (),
135+             padding = random (1 , 3 ).to (int ) |  nothing (),
136+             dilation = random (2 , 4 ).to (int ) |  nothing (),
137+             ceil_mode = random ().to (bool ),
138+             return_indices = return_indices ,
139+         )
140+ 
141+         if  return_indices :
142+             return  unpack_indices (y )
143+         else :
144+             return  y , y .sum ().backward ()
145+ 
146+     @autotest (auto_backward = False ) 
147+     def  test_maxpool3d_with_random_data (test_case ):
148+         return_indices  =  random ().to (bool ).value ()
149+         device  =  random_device ()
150+         x  =  random_pytorch_tensor (
151+             ndim = 5 , dim2 = random (20 , 22 ), dim3 = random (20 , 22 ), dim4 = random (20 , 22 )
152+         ).to (device )
153+         y  =  torch .nn .functional .max_pool3d (
154+             x ,
155+             kernel_size = random (4 , 6 ).to (int ),
156+             stride = random (1 , 3 ).to (int ) |  nothing (),
157+             padding = random (1 , 3 ).to (int ) |  nothing (),
158+             dilation = random (2 , 4 ).to (int ) |  nothing (),
159+             ceil_mode = random ().to (bool ),
160+             return_indices = return_indices ,
161+         )
162+ 
163+         if  return_indices :
164+             return  unpack_indices (y )
165+         else :
166+             return  y , y .sum ().backward ()
167+ 
168+ 
102169if  __name__  ==  "__main__" :
103170    unittest .main ()
0 commit comments