@@ -53,12 +53,22 @@ impl GraphTensor {
 
     /// Add a new dimension of size 1 at the specified place
     pub fn unsqueeze(mut self, dim: usize) -> GraphTensor {
-        // Insert contiguous call
         assert!(self.shape.len() < 10, "Shape is maxed out at 10 dimensions");
         self.shape.expand_dim(dim, 1);
         self
     }
 
+    /// Remove a dimension of size 1
+    pub fn squeeze(mut self, axis: usize) -> GraphTensor {
+        assert_eq!(
+            self.dims()[axis],
+            Expression::from(1),
+            "Only dimensions of size 1 can be squeezed!"
+        );
+        self.shape.remove_dim(axis);
+        self
+    }
+
     pub fn gather(self, indexes: GraphTensor) -> GraphTensor {
         assert_eq!(
             indexes.dtype,
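A quick illustration of how the new `squeeze` pairs with the existing `unsqueeze` (a sketch only, not part of the commit; it assumes the `Graph::new()` / `cx.tensor(shape)` test API visible in the test module below):

    let mut cx = Graph::new();
    let t = cx.tensor((2, 3)); // dims [2, 3]
    let t = t.unsqueeze(1);    // dims [2, 1, 3]
    let t = t.squeeze(1);      // dims [2, 3] again
    // t.squeeze(0) here would hit the assert: "Only dimensions of size 1 can be squeezed!"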
@@ -78,18 +88,32 @@ impl GraphTensor {
     /// x = [3, 2, 4, 1, 5, 0]
     /// inv_perm(x) = [5, 3, 1, 0, 2, 4]
     pub fn inverse_permutation(self, axis: usize) -> GraphTensor {
-        // TODO: this is very inefficient because we need to do O(n^2) comparisons and allocations for the one-hot and then sum reduce. Need a better way!
-        let ax_size = self.dims()[axis];
-        let mut arange = self.graph().arange(ax_size);
+        // TODO: this is super inefficient because it requires materializing a large (n^2) one-hot tensor
+        assert_eq!(self.dtype, DType::Int);
+        let dims = self.dims();
+        let ax_size = dims[axis];
+        let mut dims2 = dims.clone();
+        dims2.insert(axis, ax_size);
+        // candidate: varies along candidate dim (axis), broadcast elsewhere.
+        let mut candidate = self.graph().arange(ax_size);
         for i in 0..axis {
-            arange = arange.expand_dim(i, self.dims()[i]);
+            candidate = candidate.expand_dim(i, dims2[i]);
+        }
+        for i in axis + 1..dims2.len() {
+            candidate = candidate.expand_dim(i, dims2[i]);
         }
-        for i in axis + 1..self.dims().len() {
-            arange = arange.expand_dim(i, self.dims()[i]);
+        // position: varies along position dim (axis + 1), broadcast elsewhere.
+        let mut position = self.graph().arange(ax_size);
+        for i in 0..(axis + 1) {
+            position = position.expand_dim(i, dims2[i]);
         }
-        arange = arange.expand_dim(axis + 1, ax_size);
-        let one_hot = self.expand_dim(axis + 1, ax_size).eq(arange);
-        (one_hot * arange).sum(axis + 1)
+        for i in (axis + 2)..dims2.len() {
+            position = position.expand_dim(i, dims2[i]);
+        }
+        // one_hot[candidate, ..., position, ...] = (self[position, ...] == candidate)
+        let one_hot = self.expand_dim(axis, ax_size).eq(candidate);
+        // inv[candidate, ...] = Σ_pos one_hot * position
+        (one_hot * position).sum(axis + 1)
     }
 
     /// Extracts sliding local windows from an input tensor.
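To see why `(one_hot * position).sum(axis + 1)` recovers the inverse, here is the same computation spelled out in plain Rust over a slice of indices (an illustrative reference, not part of the commit; the real method stays symbolic on the graph):

    fn inverse_permutation_ref(x: &[usize]) -> Vec<usize> {
        let n = x.len();
        let mut inv = vec![0usize; n];
        for candidate in 0..n {
            // one_hot[candidate][pos] = (x[pos] == candidate); summing one_hot * pos
            // over pos picks out the unique position that holds `candidate`.
            inv[candidate] = (0..n).map(|pos| if x[pos] == candidate { pos } else { 0 }).sum();
        }
        inv
    }

This reproduces the doc example: `inverse_permutation_ref(&[3, 2, 4, 1, 5, 0]) == [5, 3, 1, 0, 2, 4]`.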
@@ -101,6 +125,7 @@ impl GraphTensor {
     ) -> GraphTensor {
         let (kernel, strides, dilation) =
             (kernel.to_shape(), strides.to_shape(), dilation.to_shape());
+
         assert_eq!(
             self.shape.len(),
             kernel.len(),
@@ -117,7 +142,7 @@ impl GraphTensor {
             "Dilation must be same number of dimensions as tensor!"
         );
 
-        // Compute input strides
+        // Compute input strides (row-major contiguous)
         let dims = self.dims();
         let n = dims.len();
         let mut in_strides = vec![Expression::from(1); n];
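The body of the stride loop falls outside this hunk, but the new `(row-major contiguous)` note pins down the convention: the last dim has stride 1 and each earlier dim multiplies by the size of the dim after it. A plain-`usize` sketch of that computation (the real code computes `Expression`s):

    fn row_major_strides(dims: &[usize]) -> Vec<usize> {
        let mut strides = vec![1usize; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * dims[i + 1];
        }
        strides // e.g. row_major_strides(&[8, 10]) == [10, 1]
    }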
@@ -128,25 +153,27 @@ impl GraphTensor {
         }
 
         // Per-dim window counts
-        let mut win = Vec::with_capacity(dims.len());
-        for (((dim, kernel), stride), dilation) in
-            dims.into_iter().zip(&kernel).zip(&strides).zip(&dilation)
-        {
-            let effective_window = *dilation * (*kernel - 1) + 1;
-            win.push(((dim - effective_window) / stride) + 1);
+        let mut win = Vec::with_capacity(n);
+        for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
+            let effective_window = *d * (*k - 1) + 1;
+            win.push(((*dim - effective_window) / s) + 1);
         }
 
-        // final_shape = [kernel ..., win ...]
-        let mut final_shape = kernel.clone();
-        final_shape.extend(win.into_iter().map(|e| e.simplify()));
+        // [win ..., kernel ...]
+        let mut final_shape: Vec<Expression> = win.into_iter().map(|e| e.simplify()).collect();
+        final_shape.extend(kernel.iter().copied());
 
-        // strides exprs over axes [k0..kN-1, w0..wN-1]
+        // Axis exprs must match final_shape axis order: first w axes, then k axes.
+        // idx = Σ_d (w_d * stride_d + k_d * dilation_d) * in_strides[d]
         let mut axis_exprs = Vec::with_capacity(2 * n);
+
+        // w axes
         for i in 0..n {
-            axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
+            axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
         }
+        // k axes
         for i in 0..n {
-            axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
+            axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
         }
 
         let index_expression = flatten_strides(&final_shape, &axis_exprs).simplify();
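The actual fix above swaps which factor multiplies which axis group: window axes advance by `stride`, kernel axes by `dilation`, so element (w, k) of the output reads input offset `w * stride + k * dilation` along each dim. A concrete 1-D check of that formula and of the window count (a plain-`usize` illustration, not part of the commit):

    fn unfold_1d_indices(size: usize, kernel: usize, stride: usize, dilation: usize) -> Vec<Vec<usize>> {
        let eff = dilation * (kernel - 1) + 1; // effective window
        let win = (size - eff) / stride + 1;   // number of windows
        (0..win)
            .map(|w| (0..kernel).map(|k| w * stride + k * dilation).collect())
            .collect()
    }

For size 10, kernel 3, stride 2, dilation 1: eff = 3, win = (10 - 3) / 2 + 1 = 4, and the windows read indices [0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8].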
@@ -266,6 +293,7 @@ impl GraphTensor {
         new_tensor * mask
     }
 
+    /// Pad along an existing dimension
     pub fn pad_along(
         self,
         left: impl Into<Expression>,
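`pad_along`'s remaining parameters sit outside the hunk, so nothing is assumed about them here; per dimension, the zero-padding behavior being exposed (zero fill is what the unfold reference test below checks against) looks like this on a plain buffer:

    // Illustration only, assuming zero padding; hypothetical helper, not in the commit.
    fn pad_1d(x: &[f32], left: usize, right: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; left];
        out.extend_from_slice(x);
        out.extend(std::iter::repeat(0.0).take(right));
        out // pad_1d(&[1.0, 2.0], 1, 2) == [0.0, 1.0, 2.0, 0.0, 0.0]
    }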
@@ -338,10 +366,170 @@ mod tests {
 
     #[test]
     fn test_unfold() {
-        let mut cx = Graph::new();
+        // Need all this code because candle doesn't do unfold
+        pub fn unfold_nd_f32(
+            x: &[f32],
+            shape: &[usize],
+            strides: &[usize],
+            kernel: &[usize],
+            step: &[usize],
+            dilation: &[usize],
+            pad_before: &[usize],
+            pad_after: &[usize],
+        ) -> Vec<f32> {
+            let n = shape.len();
+            assert!(n > 0);
+            assert_eq!(strides.len(), n);
+            assert_eq!(kernel.len(), n);
+            assert_eq!(step.len(), n);
+            assert_eq!(dilation.len(), n);
+            assert_eq!(pad_before.len(), n);
+            assert_eq!(pad_after.len(), n);
+
+            for d in 0..n {
+                assert!(kernel[d] > 0);
+                assert!(step[d] > 0);
+                assert!(dilation[d] > 0);
+                assert!(shape[d] > 0);
+            }
+
+            // Effective kernel size per dim: (K - 1) * d + 1
+            let eff_kernel: Vec<usize> =
+                (0..n).map(|d| (kernel[d] - 1) * dilation[d] + 1).collect();
+
+            // Output spatial shape (number of windows) per dim
+            let mut out_shape = vec![0usize; n];
+            for d in 0..n {
+                let padded = shape[d] + pad_before[d] + pad_after[d];
+                if padded < eff_kernel[d] {
+                    return Vec::new();
+                }
+                out_shape[d] = (padded - eff_kernel[d]) / step[d] + 1;
+            }
+
+            let windows = prod(&out_shape);
+            let window_elems = prod(kernel);
+            let mut out = vec![0.0f32; windows * window_elems];
+
+            // Precompute helpers
+            let k_mul = row_major_multipliers(kernel);
+
+            // Current output window position (row-major)
+            let mut out_pos = vec![0usize; n];
+
+            for w in 0..windows {
+                if w > 0 {
+                    incr_row_major(&mut out_pos, &out_shape);
+                }
+
+                // Window start in padded coordinates
+                let start_padded: Vec<usize> = (0..n).map(|d| out_pos[d] * step[d]).collect();
+
+                let base_out = w * window_elems;
+
+                // Iterate kernel elements (flattened)
+                for ke in 0..window_elems {
+                    let k_idx = unravel_row_major(ke, kernel, &k_mul);
+
+                    let mut flat: isize = 0;
+                    let mut in_bounds = true;
 
-        let inp = cx.tensor((5,));
-        let _pooled = inp.unfold((3,), (1,), (1,));
+                    for d in 0..n {
+                        let p = start_padded[d] + k_idx[d] * dilation[d];
+                        let logical = p as isize - pad_before[d] as isize;
+
+                        if logical < 0 || logical >= shape[d] as isize {
+                            in_bounds = false;
+                            break;
+                        }
+                        flat += logical * strides[d] as isize;
+                    }
+
+                    let out_idx = base_out + ke;
+                    out[out_idx] = if in_bounds { x[flat as usize] } else { 0.0 };
+                }
+            }
+
+            out
+        }
+
+        // -------- helpers --------
+
+        fn prod(xs: &[usize]) -> usize {
+            xs.iter().copied().product()
+        }
+
+        fn row_major_multipliers(shape: &[usize]) -> Vec<usize> {
+            let n = shape.len();
+            let mut mul = vec![1usize; n];
+            let mut acc = 1usize;
+            for d in (0..n).rev() {
+                mul[d] = acc;
+                acc *= shape[d];
+            }
+            mul
+        }
+
+        fn unravel_row_major(mut idx: usize, shape: &[usize], mul: &[usize]) -> Vec<usize> {
+            let n = shape.len();
+            let mut coords = vec![0usize; n];
+            for d in 0..n {
+                coords[d] = idx / mul[d];
+                idx %= mul[d];
+            }
+            coords
+        }
+
+        fn incr_row_major(pos: &mut [usize], shape: &[usize]) {
+            for d in (0..pos.len()).rev() {
+                pos[d] += 1;
+                if pos[d] < shape[d] {
+                    return;
+                }
+                pos[d] = 0;
+            }
+        }
+
+        test_unary(
+            5,
+            |a| a.unfold(3, 1, 1),
+            |a| {
+                Tensor::new(
+                    unfold_nd_f32(
+                        &a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
+                        a.dims(),
+                        a.stride(),
+                        &[3],
+                        &[1],
+                        &[1],
+                        &[0],
+                        &[0],
+                    ),
+                    a.device(),
+                )
+                .unwrap()
+            },
+        );
+        test_unary(
+            (8, 10),
+            |a| a.pad(((0, 2), (4, 4))).unfold((2, 3), (1, 2), (2, 1)),
+            |a| {
+                Tensor::new(
+                    unfold_nd_f32(
+                        &a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
+                        a.dims(),
+                        a.stride(),
+                        &[2, 3],
+                        &[1, 2],
+                        &[2, 1],
+                        &[0, 4],
+                        &[2, 3],
+                    ),
+                    a.device(),
+                )
+                .unwrap()
+            },
+        );
     }
 
     #[test]
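As a hand check of the `unfold_nd_f32` reference helper defined inside the test, here is the first `test_unary` case (1-D input of size 5, kernel 3, step 1, no dilation or padding) worked out; this snippet would run inside that test body:

    let out = unfold_nd_f32(
        &[0.0, 1.0, 2.0, 3.0, 4.0], // contiguous input of shape [5]
        &[5], &[1],       // shape, strides
        &[3], &[1], &[1], // kernel, step, dilation
        &[0], &[0],       // pad_before, pad_after
    );
    // Three windows, laid out [win..., kernel...]:
    assert_eq!(out, vec![0.0, 1.0, 2.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0]);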