@@ -57,6 +57,53 @@ import ../backend/[global_config, memory_optimization_hints],
  #         coord[k] = 0
  #         iter_pos -= backstrides[k]

+type TensorForm = object
+  shape: Metadata
+  strides: Metadata
+
+proc rank(t: TensorForm): range[0 .. LASER_MAXRANK] {.inline.} =
+  t.shape.len
+
+func size(t: TensorForm): int {.inline.} =
+  result = 1
+  for i in 0 ..< t.rank:
+    result *= t.shape[i]
+
+func reduceRank(t: TensorForm): TensorForm =
+  result = t
+
+  var i = 0
+  result.shape[0] = t.shape[0]
+  result.strides[0] = t.strides[0]
+  for j in 1 ..< t.rank:
+    # the incoming axis is spurious (size 1): skip it
+    if t.shape[j] == 1:
+      continue
+
+    # the current output axis is spurious: replace it
+    if result.shape[i] == 1:
+      result.shape[i] = t.shape[j]
+      result.strides[i] = t.strides[j]
+      continue
+
+    # the axes are contiguous and can be coalesced
+    if result.strides[i] == t.shape[j]*t.strides[j]:
+      result.shape[i] = result.shape[i]*t.shape[j]
+      result.strides[i] = t.strides[j]
+      continue
+
+    i += 1
+    result.shape[i] = t.shape[j]
+    result.strides[i] = t.strides[j]
+  result.shape.len = i + 1
+  result.strides.len = i + 1
+
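# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): how `reduceRank` collapses axes.
# `toForm` is a hypothetical helper for this example; it only uses Metadata
# operations (`len=`, `[]=`) already exercised by `reduceRank` above.
proc toForm(shape, strides: openArray[int]): TensorForm =
  result.shape.len = shape.len
  result.strides.len = strides.len
  for k in 0 ..< shape.len:
    result.shape[k] = shape[k]
    result.strides[k] = strides[k]

block:
  # C-contiguous [2, 1, 3]: the size-1 axis is spurious, and the remaining
  # axes coalesce because stride 3 == shape 3 * stride 1.
  let r = toForm([2, 1, 3], [3, 3, 1]).reduceRank()
  assert r.rank == 1 and r.shape[0] == 6 and r.strides[0] == 1

  # A uniformly strided view (every other element of each row of a [4, 8]
  # buffer) also collapses to rank 1, but keeps stride 2; this is why the
  # rank-1 fast path in `stridedIteration` below offsets by `i*s`.
  let v = toForm([4, 4], [8, 2]).reduceRank()
  assert v.rank == 1 and v.shape[0] == 16 and v.strides[0] == 2
# ---------------------------------------------------------------------------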
+func floor(x: int, divisor: int): int {.inline.} =
+  return divisor*(x div divisor)
+
+func ceil(x: int, divisor: int): int {.inline.} =
+  return divisor*(((x - 1) div divisor) + 1)
+
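# Editorial sketch (not part of the diff): what the rounding helpers compute.
# `floor` rounds down to a multiple of `divisor`, `ceil` rounds up. Note that
# the `(x - 1) div divisor` form of `ceil` assumes x >= 1; this holds at every
# call site below because `ceil` is only reached when iter_offset > 0.
block:
  assert floor(7, 3) == 6   # 3*(7 div 3)
  assert ceil(7, 3) == 9    # 3*(((7 - 1) div 3) + 1)
  assert ceil(6, 3) == 6    # exact multiples are unchanged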
proc getIndex*[T](t: Tensor[T], idx: varargs[int]): int {.noSideEffect,inline.} =
  ## Convert [i, j, k, l ...] to the proper index.
  when compileOption("boundChecks"):
@@ -166,25 +213,145 @@ template stridedIterationYield*(strider: IterKind, data, i, iter_pos: typed) =
  elif strider == IterKind.Iter_Values: yield (i, data[iter_pos])
  elif strider == IterKind.Offset_Values: yield (iter_pos, data[iter_pos]) ## TODO: remove workaround for C++ backend

+template stridedIterationLoop*(strider: IterKind, data, t, iter_offset, iter_size, prev_d, last_d: typed) =
+  ## We break the tensor up into 5 parts and iterate over each using for loops.
+  ## We do this because the loop ranges and nesting are different for each part.
+  ## The part boundaries are calculated and stored in the `bp1`, `bp2`, `bp3`
+  ## and `bp4` variables. The `(iter_offset, bp1)` segment is a rank-1 tensor
+  ## of size less than `last_d`. The `(bp1, bp2)` segment is a rank-2 tensor
+  ## with first axis smaller than `prev_d`. The `(bp2, bp3)` segment is the
+  ## main body, a rank-n tensor whose last axes have sizes `prev_d` and
+  ## `last_d`. The `(bp3, bp4)` segment is a rank-2 tensor, and the
+  ## `(bp4, iter_offset + iter_size)` segment is a rank-1 tensor.
+  assert t.rank > 1
+
+  let prev_s = t.strides[^2]
+  let last_s = t.strides[^1]
+  let rank = t.rank
+  let size = t.size
+
+  assert iter_offset >= 0
+  assert iter_size <= size - iter_offset
+  assert prev_d > 0 and last_d > 0
+  assert size mod prev_d*last_d == 0
+
+  initStridedIteration(coord, backstrides, iter_pos, t, iter_offset, iter_size)
+
+  let bp1 =
+    if iter_offset == 0:
+      0
+    else:
+      min(iter_offset + iter_size, ceil(iter_offset, last_d))
+  let bp2 =
+    if iter_offset == 0:
+      0
+    else:
+      max(bp1, min(floor(iter_offset + iter_size, prev_d*last_d), ceil(iter_offset, prev_d*last_d)))
+  let bp3 =
+    if iter_size == size:
+      size
+    else:
+      max(bp2, floor(iter_offset + iter_size, prev_d*last_d))
+  let bp4 =
+    if iter_size == size:
+      size
+    else:
+      max(bp3, floor(iter_offset + iter_size, last_d))
+
+  assert iter_offset <= bp1 and bp1 <= bp2 and bp2 <= bp3 and bp3 <= bp4 and bp4 <= iter_offset + iter_size
+  assert bp1 - iter_offset < last_d and (bp1 mod last_d == 0 or bp1 == iter_offset + iter_size)
+  assert bp2 == bp1 or (bp2 mod prev_d*last_d == 0 and bp2 - bp1 < prev_d*last_d)
+  assert bp3 == bp2 or bp3 mod prev_d*last_d == 0
+  assert bp4 == bp3 or (bp4 mod last_d == 0 and bp4 - bp3 < prev_d*last_d)
+  assert iter_offset + iter_size - bp4 < last_d
+
+  var i = iter_offset
+
+  # (iter_offset, bp1): rank-1 head, finishes the partial last axis
+  if bp1 > iter_offset:
+    coord[rank - 1] += bp1 - i - 1
+    while i < bp1:
+      stridedIterationYield(strider, data, i, iter_pos)
+      iter_pos += last_s
+      i += 1
+    iter_pos -= last_s
+    advanceStridedIteration(coord, backstrides, iter_pos, t, iter_offset, iter_size)
+
+  # (bp1, bp2): rank-2 head, finishes the partial prev_d*last_d block
+  if bp2 > bp1:
+    coord[rank - 2] += ((bp2 - i) div last_d) - 1
+    coord[rank - 1] = last_d - 1
+    while i < bp2:
+      for _ in 0 ..< last_d:
+        stridedIterationYield(strider, data, i, iter_pos)
+        iter_pos += last_s
+        i += 1
+      iter_pos += prev_s - last_s*last_d
+    iter_pos += last_s*(last_d - 1) - prev_s
+    advanceStridedIteration(coord, backstrides, iter_pos, t, iter_offset, iter_size)
+
+  # (bp2, bp3): main body, whole prev_d*last_d blocks
+  while i < bp3:
+    for _ in 0 ..< prev_d:
+      for _ in 0 ..< last_d:
+        stridedIterationYield(strider, data, i, iter_pos)
+        iter_pos += last_s
+        i += 1
+      iter_pos += prev_s - last_s*last_d
+    iter_pos -= prev_s*prev_d
+
+    for k in countdown(rank - 3, 0):
+      if coord[k] < t.shape[k] - 1:
+        coord[k] += 1
+        iter_pos += t.strides[k]
+        break
+      else:
+        coord[k] = 0
+        iter_pos -= backstrides[k]
+
+  # (bp3, bp4): rank-2 tail of whole last_d rows
+  if bp4 > bp3:
+    coord[rank - 2] += ((bp4 - i) div last_d) - 1
+    coord[rank - 1] = last_d - 1
+    while i < bp4:
+      for _ in 0 ..< last_d:
+        stridedIterationYield(strider, data, i, iter_pos)
+        iter_pos += last_s
+        i += 1
+      iter_pos += prev_s - last_s*last_d
+    iter_pos += last_s*(last_d - 1) - prev_s
+    advanceStridedIteration(coord, backstrides, iter_pos, t, iter_offset, iter_size)
+
+  # (bp4, iter_offset + iter_size): rank-1 tail
+  while i < iter_offset + iter_size:
+    stridedIterationYield(strider, data, i, iter_pos)
+    iter_pos += last_s
+    i += 1
+
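# Editorial sketch (not part of the diff): the part boundaries for a concrete
# slice. With last_d = 3, prev_d = 2 (blocks of 6) and a slice covering
# [4, 17) of a size-24 tensor, the five segments come out as:
#   [4, 6)  rank-1 head (tail of a row), [6, 6) empty rank-2 head,
#   [6, 12) one full 2x3 block, [12, 15) one full row, [15, 17) rank-1 tail.
block:
  let iter_offset = 4
  let iter_size = 13
  let prev_d = 2
  let last_d = 3
  # iter_offset > 0 and iter_size < size, so the general branches apply
  let bp1 = min(iter_offset + iter_size, ceil(iter_offset, last_d))
  let bp2 = max(bp1, min(floor(iter_offset + iter_size, prev_d*last_d),
                         ceil(iter_offset, prev_d*last_d)))
  let bp3 = max(bp2, floor(iter_offset + iter_size, prev_d*last_d))
  let bp4 = max(bp3, floor(iter_offset + iter_size, last_d))
  assert [bp1, bp2, bp3, bp4] == [6, 6, 12, 15]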
template stridedIteration*(strider: IterKind, t, iter_offset, iter_size: typed): untyped =
  ## Iterate over a Tensor, displaying data as in C order, whatever the strides.

  # Get tensor data address with offset builtin
  # only reading here, pointer access is safe even for ref types
+
  when getSubType(type(t)) is KnownSupportsCopyMem:
    let data = t.unsafe_raw_offset()
  else:
    template data: untyped {.gensym.} = t

-  # Optimize for loops in contiguous cases
-  if t.is_C_contiguous:
+  let tf = reduceRank(TensorForm(shape: t.shape, strides: t.strides))
+
+  assert tf.rank >= 1
+  if tf.rank == 1:
+    let s = tf.strides[^1]
    for i in iter_offset ..< (iter_offset + iter_size):
-      stridedIterationYield(strider, data, i, i)
+      stridedIterationYield(strider, data, i, i*s)
  else:
-    initStridedIteration(coord, backstrides, iter_pos, t, iter_offset, iter_size)
-    for i in iter_offset ..< (iter_offset + iter_size):
-      stridedIterationYield(strider, data, i, iter_pos)
-      advanceStridedIteration(coord, backstrides, iter_pos, t, iter_offset, iter_size)
+    let prev_d = tf.shape[^2]
+    let last_d = tf.shape[^1]
+    if prev_d == 2 and last_d == 2:
+      stridedIterationLoop(strider, data, tf, iter_offset, iter_size, 2, 2)
+    elif last_d == 2:
+      stridedIterationLoop(strider, data, tf, iter_offset, iter_size, prev_d, 2)
+    elif last_d == 3:
+      stridedIterationLoop(strider, data, tf, iter_offset, iter_size, prev_d, 3)
+    else:
+      stridedIterationLoop(strider, data, tf, iter_offset, iter_size, prev_d, last_d)
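# Editorial note (not part of the diff): the literal 2/2, prev_d/2 and
# prev_d/3 branches above look redundant since the generic fallback would also
# handle them, but they re-instantiate `stridedIterationLoop` with
# compile-time constant trailing dimensions, which can let the C compiler
# fully unroll the innermost `for _ in 0 ..< last_d` loops for the most common
# small last axes.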

template stridedCoordsIteration*(t, iter_offset, iter_size: typed): untyped =
  ## Iterate over a Tensor, displaying data as in C order, whatever the strides. (coords)