Commit 90c96fb

cumulative ops added back and topk fixed
1 parent 6788b98 commit 90c96fb

File tree

4 files changed: +377 -167 lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ pretty-duration = "0.1.1"
 [dev-dependencies]
 candle-core = "0.9.1"
 candle-nn = "0.9.1"
+ordered-float = "5.1.0"
 
 [workspace]
 members = [
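Note: ordered-float is pulled in only as a dev-dependency; it gives tests a total order over f32, which is handy when sorting or ranking float outputs (e.g. a top-k reference, per the commit message). A minimal sketch of that kind of helper, assuming nothing beyond the crate's OrderedFloat wrapper; the function below is illustrative and not taken from this commit:

use ordered_float::OrderedFloat;

// Hypothetical reference helper: indices of the k largest values.
fn topk_indices(xs: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..xs.len()).collect();
    // OrderedFloat implements Ord for floats, so sorting never panics on NaN.
    idx.sort_by_key(|&i| std::cmp::Reverse(OrderedFloat(xs[i])));
    idx.truncate(k);
    idx
}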

src/hl_ops/movement.rs

Lines changed: 214 additions & 26 deletions
@@ -53,12 +53,22 @@ impl GraphTensor {
 
     /// add a new dimension of size 1 at the specified place
     pub fn unsqueeze(mut self, dim: usize) -> GraphTensor {
-        // Insert contiguous call
        assert!(self.shape.len() < 10, "Shape is maxed out at 10 dimensions");
         self.shape.expand_dim(dim, 1);
         self
     }
 
+    /// remove a dimension of size 1
+    pub fn squeeze(mut self, axis: usize) -> GraphTensor {
+        assert_eq!(
+            self.dims()[axis],
+            Expression::from(1),
+            "Only dimensions of size 1 can be squeezed!"
+        );
+        self.shape.remove_dim(axis);
+        self
+    }
+
     pub fn gather(self, indexes: GraphTensor) -> GraphTensor {
         assert_eq!(
             indexes.dtype,
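Taken together, unsqueeze and the new squeeze are inverses on size-1 dimensions. A small usage sketch, assuming the Graph/GraphTensor test API used elsewhere in this file (shapes in the comments are illustrative):

let mut cx = Graph::new();
let t = cx.tensor((3, 4));            // shape (3, 4)
let expanded = t.unsqueeze(0);        // shape (1, 3, 4): size-1 dim inserted at the front
let restored = expanded.squeeze(0);   // back to (3, 4); only size-1 dims can be squeezed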
@@ -78,18 +88,32 @@ impl GraphTensor {
     /// x = [3, 2, 4, 1, 5, 0]
     /// inv_perm(x) = [5, 3, 1, 0, 2, 4]
     pub fn inverse_permutation(self, axis: usize) -> GraphTensor {
-        // TODO: this is very inefficient because we need to do O(n^2) comparisons and allocations for the one-hot and then sum reduce. Need a better way!
-        let ax_size = self.dims()[axis];
-        let mut arange = self.graph().arange(ax_size);
+        // TODO: this is super inefficient because it requires materializing a large (n^2) one-hot tensor
+        assert_eq!(self.dtype, DType::Int);
+        let dims = self.dims();
+        let ax_size = dims[axis];
+        let mut dims2 = dims.clone();
+        dims2.insert(axis, ax_size);
+        // candidate: varies along candidate dim (axis), broadcast elsewhere.
+        let mut candidate = self.graph().arange(ax_size);
         for i in 0..axis {
-            arange = arange.expand_dim(i, self.dims()[i]);
+            candidate = candidate.expand_dim(i, dims2[i]);
+        }
+        for i in axis + 1..dims2.len() {
+            candidate = candidate.expand_dim(i, dims2[i]);
         }
-        for i in axis + 1..self.dims().len() {
-            arange = arange.expand_dim(i, self.dims()[i]);
+        // position: varies along position dim (axis+1), broadcast elsewhere.
+        let mut position = self.graph().arange(ax_size);
+        for i in 0..(axis + 1) {
+            position = position.expand_dim(i, dims2[i]);
         }
-        arange = arange.expand_dim(axis + 1, ax_size);
-        let one_hot = self.expand_dim(axis + 1, ax_size).eq(arange);
-        (one_hot * arange).sum(axis + 1)
+        for i in (axis + 2)..dims2.len() {
+            position = position.expand_dim(i, dims2[i]);
+        }
+        // one_hot[candidate, ..., position, ...] = (self[position, ...] == candidate)
+        let one_hot = self.expand_dim(axis, ax_size).eq(candidate);
+        // inv[candidate, ...] = Σ_pos one_hot * position
+        (one_hot * position).sum(axis + 1)
     }
 
     /// Extracts sliding local windows from an input tensor.
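The one-hot construction above computes, for every candidate value along the axis, the position at which that value occurs. A plain scalar reference for the 1-D case, matching the doc example (this is a sketch, not part of the commit):

// inv[candidate] = position where `candidate` appears in the permutation.
fn inverse_permutation_ref(perm: &[usize]) -> Vec<usize> {
    let mut inv = vec![0usize; perm.len()];
    for (pos, &candidate) in perm.iter().enumerate() {
        inv[candidate] = pos;
    }
    inv
}

// inverse_permutation_ref(&[3, 2, 4, 1, 5, 0]) == [5, 3, 1, 0, 2, 4]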
@@ -101,6 +125,7 @@ impl GraphTensor {
     ) -> GraphTensor {
         let (kernel, strides, dilation) =
             (kernel.to_shape(), strides.to_shape(), dilation.to_shape());
+
         assert_eq!(
             self.shape.len(),
             kernel.len(),
@@ -117,7 +142,7 @@ impl GraphTensor {
             "Dilation must be same number of dimensions as tensor!"
         );
 
-        // Compute input strides
+        // Compute input strides (row-major contiguous)
         let dims = self.dims();
         let n = dims.len();
         let mut in_strides = vec![Expression::from(1); n];
@@ -128,25 +153,27 @@ impl GraphTensor {
         }
 
         // Per-dim window counts
-        let mut win = Vec::with_capacity(dims.len());
-        for (((dim, kernel), stride), dilation) in
-            dims.into_iter().zip(&kernel).zip(&strides).zip(&dilation)
-        {
-            let effective_window = *dilation * (*kernel - 1) + 1;
-            win.push(((dim - effective_window) / stride) + 1);
+        let mut win = Vec::with_capacity(n);
+        for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
+            let effective_window = *d * (*k - 1) + 1;
+            win.push(((*dim - effective_window) / s) + 1);
         }
 
-        // final_shape = [kernel..., win...]
-        let mut final_shape = kernel.clone();
-        final_shape.extend(win.into_iter().map(|e| e.simplify()));
+        // [win..., kernel...]
+        let mut final_shape: Vec<Expression> = win.into_iter().map(|e| e.simplify()).collect();
+        final_shape.extend(kernel.iter().copied());
 
-        // strides exprs over axes [k0..kN-1, w0..wN-1]
+        // Axis exprs must match final_shape axis order: first w axes, then k axes.
+        // idx = Σ_d (w_d * stride_d + k_d * dilation_d) * in_strides[d]
         let mut axis_exprs = Vec::with_capacity(2 * n);
+
+        // w axes
         for i in 0..n {
-            axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
+            axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
         }
+        // k axes
         for i in 0..n {
-            axis_exprs.push(Expression::from('z') * strides[i] * in_strides[i]);
+            axis_exprs.push(Expression::from('z') * dilation[i] * in_strides[i]);
         }
 
         let index_expression = flatten_strides(&final_shape, &axis_exprs).simplify();
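The axis expressions encode, per input dimension d, the flat offset (w_d * stride_d + k_d * dilation_d) * in_strides[d], which flatten_strides sums over d. The same arithmetic in plain usize form, as a sketch with illustrative names rather than graph Expressions:

// Windows along one dimension for a given kernel/stride/dilation.
fn num_windows(dim: usize, kernel: usize, stride: usize, dilation: usize) -> usize {
    let effective_window = dilation * (kernel - 1) + 1;
    (dim - effective_window) / stride + 1
}

// Flat input offset for window coordinates `w` and kernel coordinates `k`.
fn input_offset(w: &[usize], k: &[usize], stride: &[usize], dilation: &[usize], in_strides: &[usize]) -> usize {
    (0..w.len())
        .map(|d| (w[d] * stride[d] + k[d] * dilation[d]) * in_strides[d])
        .sum()
}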
@@ -266,6 +293,7 @@ impl GraphTensor {
         new_tensor * mask
     }
 
+    /// Pad along an existing dimension
     pub fn pad_along(
         self,
         left: impl Into<Expression>,
@@ -338,10 +366,170 @@ mod tests {
 
     #[test]
     fn test_unfold() {
-        let mut cx = Graph::new();
+        // Need all this code because candle doesnt do unfold
+        pub fn unfold_nd_f32(
+            x: &[f32],
+            shape: &[usize],
+            strides: &[usize],
+            kernel: &[usize],
+            step: &[usize],
+            dilation: &[usize],
+            pad_before: &[usize],
+            pad_after: &[usize],
+        ) -> Vec<f32> {
+            let n = shape.len();
+            assert!(n > 0);
+            assert_eq!(strides.len(), n);
+            assert_eq!(kernel.len(), n);
+            assert_eq!(step.len(), n);
+            assert_eq!(dilation.len(), n);
+            assert_eq!(pad_before.len(), n);
+            assert_eq!(pad_after.len(), n);
+
+            for d in 0..n {
+                assert!(kernel[d] > 0);
+                assert!(step[d] > 0);
+                assert!(dilation[d] > 0);
+                assert!(shape[d] > 0);
+            }
+
+            // Effective kernel size per dim: (K-1)*d + 1
+            let eff_kernel: Vec<usize> =
+                (0..n).map(|d| (kernel[d] - 1) * dilation[d] + 1).collect();
+
+            // Output spatial shape (number of windows) per dim
+            let mut out_shape = vec![0usize; n];
+            for d in 0..n {
+                let padded = shape[d] + pad_before[d] + pad_after[d];
+                if padded < eff_kernel[d] {
+                    return Vec::new();
+                }
+                out_shape[d] = (padded - eff_kernel[d]) / step[d] + 1;
+            }
+
+            let windows = prod(&out_shape);
+            let window_elems = prod(kernel);
+            let mut out = vec![0.0f32; windows * window_elems];
+
+            // Precompute helpers
+            let k_mul = row_major_multipliers(kernel);
+
+            // Current output window position (row-major)
+            let mut out_pos = vec![0usize; n];
+
+            for w in 0..windows {
+                if w > 0 {
+                    incr_row_major(&mut out_pos, &out_shape);
+                }
+
+                // Window start in padded coordinates
+                let start_padded: Vec<usize> = (0..n).map(|d| out_pos[d] * step[d]).collect();
+
+                let base_out = w * window_elems;
+
+                // Iterate kernel elements (flattened)
+                for ke in 0..window_elems {
+                    let k_idx = unravel_row_major(ke, kernel, &k_mul);
+
+                    let mut flat: isize = 0;
+                    let mut in_bounds = true;
 
-        let inp = cx.tensor((5,));
-        let _pooled = inp.unfold((3,), (1,), (1,));
+                    for d in 0..n {
+                        let p = start_padded[d] + k_idx[d] * dilation[d];
+                        let logical = p as isize - pad_before[d] as isize;
+
+                        if logical < 0 || logical >= shape[d] as isize {
+                            in_bounds = false;
+                            break;
+                        }
+                        flat += logical * strides[d] as isize;
+                    }
+
+                    let out_idx = base_out + ke;
+                    out[out_idx] = if in_bounds { x[flat as usize] } else { 0.0 };
+                }
+            }
+
+            out
+        }
+
+        // -------- helpers --------
+
+        fn prod(xs: &[usize]) -> usize {
+            xs.iter().copied().product()
+        }
+
+        fn row_major_multipliers(shape: &[usize]) -> Vec<usize> {
+            let n = shape.len();
+            let mut mul = vec![1usize; n];
+            let mut acc = 1usize;
+            for d in (0..n).rev() {
+                mul[d] = acc;
+                acc *= shape[d];
+            }
+            mul
+        }
+
+        fn unravel_row_major(mut idx: usize, shape: &[usize], mul: &[usize]) -> Vec<usize> {
+            let n = shape.len();
+            let mut coords = vec![0usize; n];
+            for d in 0..n {
+                coords[d] = idx / mul[d];
+                idx %= mul[d];
+            }
+            coords
+        }
+
+        fn incr_row_major(pos: &mut [usize], shape: &[usize]) {
+            for d in (0..pos.len()).rev() {
+                pos[d] += 1;
+                if pos[d] < shape[d] {
+                    return;
+                }
+                pos[d] = 0;
+            }
+        }
+
+        test_unary(
+            5,
+            |a| a.unfold(3, 1, 1),
+            |a| {
+                Tensor::new(
+                    unfold_nd_f32(
+                        &a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
+                        a.dims(),
+                        a.stride(),
+                        &[3],
+                        &[1],
+                        &[1],
+                        &[0],
+                        &[0],
+                    ),
+                    a.device(),
+                )
+                .unwrap()
+            },
+        );
+        test_unary(
+            (8, 10),
+            |a| a.pad(((0, 2), (4, 4))).unfold((2, 3), (1, 2), (2, 1)),
+            |a| {
+                Tensor::new(
+                    unfold_nd_f32(
+                        &a.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
+                        a.dims(),
+                        a.stride(),
+                        &[2, 3],
+                        &[1, 2],
+                        &[2, 1],
+                        &[0, 4],
+                        &[2, 3],
+                    ),
+                    a.device(),
+                )
+                .unwrap()
+            },
+        );
     }
 
     #[test]
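As a sanity check on the first test case in the hunk above: unfolding a length-5 input with kernel 3, step 1, dilation 1 yields three overlapping windows, flattened window-major by the candle-free reference. A small worked example against unfold_nd_f32 (values are illustrative, not from the commit):

// input [0, 1, 2, 3, 4] -> windows [0,1,2], [1,2,3], [2,3,4]
let out = unfold_nd_f32(&[0., 1., 2., 3., 4.], &[5], &[1], &[3], &[1], &[1], &[0], &[0]);
assert_eq!(out, vec![0., 1., 2., 1., 2., 3., 2., 3., 4.]);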
