@@ -24,6 +24,111 @@ def transpose_1d_kernel(mat: 'float[:, :]',
2424
2525 matT [p_out + i1_loc , d1 ] = mat [p_in + j1_loc , p_out + i1 - j1 ]
2626
27+ def transpose_2d_kernel (mat : 'float[:, :, :, :, :]' ,
28+ matT : 'float[:, :, :, :, :]' ,
29+ s_in : 'int[:]' , # refers to matT
30+ p_in : 'int[:]' ,
31+ add : 'int[:]' ,
32+ s_out : 'int[:]' ,
33+ e_out : 'int[:]' ,
34+ p_out : 'int[:]' ):
35+
36+ #####################################
37+ #####################################
38+ # without last row in 1st direction #
39+ #####################################
40+ #####################################
41+ for i1 in range (s_out [0 ], e_out [0 ]): # global row index of matT = global column index of mat
42+ i1_loc = i1 - s_out [0 ] # local row index of matT
43+
44+ #####################################
45+ # without last row in 2nd direction #
46+ #####################################
47+ for i2 in range (s_out [1 ], e_out [1 ]): # global row index of matT = global column index of mat
48+ i2_loc = i2 - s_out [1 ] # local row index of matT
49+
50+ for d1 in range (2 * p_in [0 ] + 1 ):
51+ j1 = i1 - p_in [0 ] + d1 # global column index of matT
52+ j1_loc = j1 - s_in [0 ] # local column index of matT = local row index of mat
53+ for d2 in range (2 * p_in [1 ] + 1 ):
54+ j2 = i2 - p_in [1 ] + d2 # global column index of matT
55+ j2_loc = j2 - s_in [1 ] # local column index of matT = local row index of mat
56+
57+ matT [p_out [0 ] + i1_loc ,
58+ p_out [1 ] + i2_loc ,
59+ d1 , d2 ] = mat [p_in [0 ] + j1_loc ,
60+ p_in [1 ] + j2_loc ,
61+ p_out [0 ] + i1 - j1 ,
62+ p_out [1 ] + i2 - j2 ]
63+
64+ ##############################################
65+ # treat last row in 2nd direction separately #
66+ ##############################################
67+ i2 = e_out [1 ]
68+ i2_loc = i2 - s_out [1 ] # local row index of matT
69+
70+ for d1 in range (2 * p_in [0 ] + 1 ):
71+ j1 = i1 - p_in [0 ] + d1 # global column index of matT
72+ j1_loc = j1 - s_in [0 ] # local column index of matT = local row index of mat
73+ for d2 in range (2 * p_in [1 ] + add [1 ]):
74+ j2 = i2 - p_in [1 ] + d2 # global column index of matT
75+ j2_loc = j2 - s_in [1 ] # local column index of matT = local row index of mat
76+
77+ matT [p_out [0 ] + i1_loc ,
78+ p_out [1 ] + i2_loc ,
79+ d1 , d2 ] = mat [p_in [0 ] + j1_loc ,
80+ p_in [1 ] + j2_loc ,
81+ p_out [0 ] + i1 - j1 ,
82+ p_out [1 ] + i2 - j2 ]
83+
84+ ##############################################
85+ ##############################################
86+ # treat last row in 1st direction separately #
87+ ##############################################
88+ ##############################################
89+ i1 = e_out [0 ]
90+ i1_loc = i1 - s_out [0 ] # local row index of matT
91+
92+ #####################################
93+ # without last row in 2nd direction #
94+ #####################################
95+ for i2 in range (s_out [1 ], e_out [1 ]):
96+ i2_loc = i2 - s_out [1 ] # local row index of matT
97+
98+ for d1 in range (2 * p_in [0 ] + add [0 ]):
99+ j1 = i1 - p_in [0 ] + d1 # global column index of matT
100+ j1_loc = j1 - s_in [0 ] # local column index of matT = local row index of mat
101+ for d2 in range (2 * p_in [1 ] + 1 ):
102+ j2 = i2 - p_in [1 ] + d2 # global column index of matT
103+ j2_loc = j2 - s_in [1 ] # local column index of matT = local row index of mat
104+
105+ matT [p_out [0 ] + i1_loc ,
106+ p_out [1 ] + i2_loc ,
107+ d1 , d2 ] = mat [p_in [0 ] + j1_loc ,
108+ p_in [1 ] + j2_loc ,
109+ p_out [0 ] + i1 - j1 ,
110+ p_out [1 ] + i2 - j2 ]
111+
112+ ##############################################
113+ # treat last row in 2nd direction separately #
114+ ##############################################
115+ i2 = e_out [1 ]
116+ i2_loc = i2 - s_out [1 ] # local row index of matT
117+
118+ for d1 in range (2 * p_in [0 ] + add [0 ]):
119+ j1 = i1 - p_in [0 ] + d1 # global column index of matT
120+ j1_loc = j1 - s_in [0 ] # local column index of matT = local row index of mat
121+ for d2 in range (2 * p_in [1 ] + add [1 ]):
122+ j2 = i2 - p_in [1 ] + d2 # global column index of matT
123+ j2_loc = j2 - s_in [1 ] # local column index of matT = local row index of mat
124+
125+ matT [p_out [0 ] + i1_loc ,
126+ p_out [1 ] + i2_loc ,
127+ d1 , d2 ] = mat [p_in [0 ] + j1_loc ,
128+ p_in [1 ] + j2_loc ,
129+ p_out [0 ] + i1 - j1 ,
130+ p_out [1 ] + i2 - j2 ]
131+
27132
28133def transpose_3d_kernel (mat : 'float[:, :, :, :, :, :]' ,
29134 matT : 'float[:, :, :, :, :, :]' ,
0 commit comments