Skip to content

Commit 104137f

Browse files
committed
Added transpose_2d_kernel
1 parent 8723ca7 commit 104137f

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed

psydac/linalg/stencil_transpose_kernels.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,111 @@ def transpose_1d_kernel(mat: 'float[:, :]',
2424

2525
matT[p_out + i1_loc, d1] = mat[p_in + j1_loc, p_out + i1 - j1]
2626

27+
def transpose_2d_kernel(mat: 'float[:, :, :, :]',
                        matT: 'float[:, :, :, :]',
                        s_in: 'int[:]',  # refers to matT
                        p_in: 'int[:]',
                        add: 'int[:]',
                        s_out: 'int[:]',
                        e_out: 'int[:]',
                        p_out: 'int[:]'):
    """
    Fill the 2D stencil array `matT` with the transposed entries of `mat`.

    Both arrays are accessed with exactly four indices
    (row_1, row_2, diagonal_1, diagonal_2), hence the rank-4 annotations.

    For every global row index (i1, i2) of `matT` and every diagonal
    (d1, d2), the global column index is j = i - p_in + d and the kernel
    performs

        matT[p_out[0] + (i1 - s_out[0]), p_out[1] + (i2 - s_out[1]), d1, d2] =
            mat[p_in[0] + (j1 - s_in[0]), p_in[1] + (j2 - s_in[1]),
                p_out[0] + i1 - j1, p_out[1] + i2 - j2]

    Parameters
    ----------
    mat : 4D float array
        Source array; it is only read.
    matT : 4D float array
        Destination array; it is modified in place.
    s_in : 1D int array (length >= 2)
        Start indices subtracted from the global column indices of `matT`
        (= global row indices of `mat`) to obtain local ones.
    p_in : 1D int array (length >= 2)
        Row offsets used when indexing `mat`; they also set the number of
        diagonals visited per row (2 * p_in[k] + 1 for ordinary rows).
        Presumably the pads of the input space -- confirm with the caller.
    add : 1D int array (length >= 2)
        On the last row in direction k (i_k == e_out[k]) the number of
        diagonals visited is 2 * p_in[k] + add[k] instead of 2 * p_in[k] + 1.
    s_out : 1D int array (length >= 2)
        First global row index of `matT` in each direction.
    e_out : 1D int array (length >= 2)
        Last global row index of `matT` in each direction (inclusive).
    p_out : 1D int array (length >= 2)
        Row offsets used when indexing `matT`, and diagonal offsets used
        when indexing `mat`.

    Notes
    -----
    No bounds checks are made: the caller must size `mat` and `matT` so that
    all computed indices are valid.
    """
    # Number of rows visited per direction: the rows s_out[k], ..., e_out[k]-1
    # followed by the last row e_out[k]. The last row is ALWAYS visited, even
    # if the preceding range is empty, hence the clamp to zero.
    n_rows1 = max(e_out[0] - s_out[0], 0) + 1
    n_rows2 = max(e_out[1] - s_out[1], 0) + 1

    for k1 in range(n_rows1):

        # Global row index of matT (= global column index of mat) and number
        # of diagonals in the 1st direction; the last row is special.
        if k1 == n_rows1 - 1:
            i1 = e_out[0]
            n_diags1 = 2 * p_in[0] + add[0]
        else:
            i1 = s_out[0] + k1
            n_diags1 = 2 * p_in[0] + 1

        i1_loc = i1 - s_out[0]  # local row index of matT

        for k2 in range(n_rows2):

            # Same treatment in the 2nd direction
            if k2 == n_rows2 - 1:
                i2 = e_out[1]
                n_diags2 = 2 * p_in[1] + add[1]
            else:
                i2 = s_out[1] + k2
                n_diags2 = 2 * p_in[1] + 1

            i2_loc = i2 - s_out[1]  # local row index of matT

            for d1 in range(n_diags1):
                j1 = i1 - p_in[0] + d1  # global column index of matT
                j1_loc = j1 - s_in[0]   # local column index of matT = local row index of mat

                for d2 in range(n_diags2):
                    j2 = i2 - p_in[1] + d2  # global column index of matT
                    j2_loc = j2 - s_in[1]   # local column index of matT = local row index of mat

                    matT[p_out[0] + i1_loc,
                         p_out[1] + i2_loc,
                         d1, d2] = mat[p_in[0] + j1_loc,
                                       p_in[1] + j2_loc,
                                       p_out[0] + i1 - j1,
                                       p_out[1] + i2 - j2]
131+
27132

28133
def transpose_3d_kernel(mat: 'float[:, :, :, :, :, :]',
29134
matT: 'float[:, :, :, :, :, :]',

0 commit comments

Comments
 (0)