2 changes: 1 addition & 1 deletion candle-core/examples/cuda_sum_benchmark.rs
@@ -10,7 +10,7 @@ use anyhow::Result;
use candle_core::{Device, Tensor};

fn cos_sin(n: usize, device: &Device) -> Result<Tensor> {
let thetas: Vec<_> = (0..n).map(|i| (i as f32 / n as f32)).collect();
let thetas: Vec<_> = (0..n).map(|i| i as f32 / n as f32).collect();
let xs: Vec<_> = thetas.iter().map(|t| t.cos().abs()).collect();
let ys: Vec<_> = thetas.iter().map(|t| t.sin().abs()).collect();
let xs = Tensor::from_vec(xs, (n, 1), device)?;
28 changes: 28 additions & 0 deletions candle-core/src/backprop.rs
@@ -122,6 +122,7 @@
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Unfold(node, _, _, _)
| Op::Cmp(node, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDevice(node)
@@ -495,6 +496,33 @@
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
}
&Op::Unfold(ref arg, dim, size, step) => {
// Scatter each window's gradient slice back into the input positions it was read from.
let arg_dims = arg.dims();
let node_dims = node.dims();

let sum_grad = grads.or_insert(arg)?;
let extra_dim = arg_dims.len();

let windows = node_dims[dim];
for widx in 0..windows {
let window_slice = grad
.get_on_dim(dim, widx)?
.unsqueeze(dim)?
.transpose(dim, extra_dim)?
.squeeze(extra_dim)?;

let start = widx * step;
let end = start + size;

// `index_add` expects integer indexes, so build them as u32.
let indexes = Tensor::arange(start as u32, end as u32, self.device())?;
*sum_grad = sum_grad.index_add(&indexes, &window_slice, dim)?;
}
}
Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
let grad = broadcast_back(arg, &grad, reduced_dims)?;
let sum_grad = grads.or_insert(arg)?;
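For intuition, the accumulation performed by the new `Op::Unfold` arm reduces, in the 1-D case, to adding each window's upstream gradient back into the input positions that window was read from. A minimal standalone sketch in plain Rust (not using candle; the helper name is hypothetical):

```rust
// `grad_windows[w][k]` is the upstream gradient for element `k` of window `w`;
// each input position receives the sum over every window that read it.
fn unfold_grad_1d(grad_windows: &[Vec<f32>], len: usize, size: usize, step: usize) -> Vec<f32> {
    let mut grad_input = vec![0f32; len];
    for (w, window) in grad_windows.iter().enumerate() {
        let start = w * step;
        for k in 0..size {
            grad_input[start + k] += window[k];
        }
    }
    grad_input
}

fn main() {
    // len = 5, size = 3, step = 1 -> 3 windows; use an all-ones upstream gradient.
    let ones = vec![vec![1f32; 3]; 3];
    // Positions 0..5 are covered by 1, 2, 3, 2, 1 windows respectively.
    assert_eq!(unfold_grad_1d(&ones, 5, 3, 1), vec![1., 2., 3., 2., 1.]);
}
```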
1 change: 1 addition & 0 deletions candle-core/src/op.rs
@@ -158,6 +158,7 @@ pub enum Op {
ToDType(Tensor),
Copy(Tensor),
Broadcast(Tensor),
Unfold(Tensor, usize, usize, usize),
Narrow(Tensor, usize, usize, usize),
SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor),
51 changes: 51 additions & 0 deletions candle-core/src/tensor.rs
@@ -1,5 +1,6 @@
//! Tensors are N-dimensional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)]

use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
@@ -2269,6 +2270,56 @@ impl Tensor {
self.broadcast_as(shape)
}

/// Unfold windows along a dimension.
///
/// Returns a view of the tensor containing all complete windows of size `size` in dimension
/// `dim`, where consecutive windows are advanced by `step` elements.
///
/// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
/// `shape[dim] >= size`, and `0` otherwise.
///
/// The new view replaces the unfolded dimension with two dimensions: one in the position of
/// the original dimension, with size equal to the number of windows, and one appended as the
/// right-most dimension, with size equal to `size`.
///
/// # Arguments
///
/// * `dim` - the dimension to unfold.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between the starts of consecutive windows.
///
/// # Returns
///
/// A tensor view with shape `[pre..., windows, post..., size]`.
pub fn unfold(&self, dim: usize, size: usize, step: usize) -> Result<Self> {
let mut shape = self.layout.shape().dims().to_vec();
let mut strides = self.layout.stride().to_vec();

let d_shape = shape[dim];
let d_stride = strides[dim];

let tmp = d_shape + step;
let windows = if tmp < size { 0 } else { (tmp - size) / step };

shape[dim] = windows;
shape.push(size);

strides[dim] = step * d_stride;
strides.push(d_stride);

let unfold_layout = Layout::new(shape.into(), strides, self.layout.start_offset());

let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: unfold_layout,
op: BackpropOp::new1(self, |arg| Op::Unfold(arg, dim, size, step)),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}

/// Casts the input tensor to the target `dtype`.
///
/// ```rust
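The non-obvious part of `unfold` above is the layout arithmetic: the unfolded dimension keeps a stride of `step * old_stride`, and a trailing dimension of length `size` with the original stride is appended, so the result is a copy-free view. A standalone sketch of that computation (hypothetical helper name, mirroring the code above):

```rust
fn unfold_layout(
    shape: &[usize],
    strides: &[usize],
    dim: usize,
    size: usize,
    step: usize,
) -> (Vec<usize>, Vec<usize>) {
    let mut shape = shape.to_vec();
    let mut strides = strides.to_vec();
    let d_stride = strides[dim];
    // Complete windows: (shape[dim] - size) / step + 1, or 0 when no window fits.
    let tmp = shape[dim] + step;
    shape[dim] = if tmp < size { 0 } else { (tmp - size) / step };
    shape.push(size);
    // The window index advances by `step` original elements; the trailing window
    // dimension keeps the original stride.
    strides[dim] = step * d_stride;
    strides.push(d_stride);
    (shape, strides)
}

fn main() {
    // A contiguous (2, 5) tensor has strides (5, 1); unfolding dim 1 with
    // size 3 and step 2 yields shape (2, 2, 3) and strides (5, 2, 1).
    assert_eq!(
        unfold_layout(&[2, 5], &[5, 1], 1, 3, 2),
        (vec![2, 2, 3], vec![5, 2, 1])
    );
}
```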
16 changes: 16 additions & 0 deletions candle-core/tests/grad_tests.rs
@@ -505,6 +505,21 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}

fn unfold_grad(device: &Device) -> Result<()> {
let data = &[[0f32, 1., 2., 3., 4.], [5f32, 6., 7., 8., 9.]];
let x = Tensor::new(data, device)?;
let unf = x.unfold(1, 3, 1)?;
let y = (&unf + 1.)?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for tensor")?;
// Each input position gets one contribution per window that covers it: [1, 2, 3, 2, 1].
assert_eq!(
grad_x.to_vec2::<f32>()?,
&[[1f32, 2., 3., 2., 1.], [1f32, 2., 3., 2., 1.]]
);

Ok(())
}

#[test]
fn test_flip_backprop() -> Result<()> {
let device = &Device::Cpu;
@@ -555,6 +570,7 @@ test_device!(
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(unfold_grad, unfold_grad_cpu, unfold_grad_gpu, unfold_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
15 changes: 15 additions & 0 deletions candle-core/tests/tensor_tests.rs
@@ -767,6 +767,20 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}

fn unfold(device: &Device) -> Result<()> {
let data = &[[0f32, 1., 2., 3., 4.], [5f32, 6., 7., 8., 9.]];
let tensor = Tensor::new(data, device)?;
let actual = tensor.unfold(1, 3, 2)?;
assert_eq!(
actual.to_vec3::<f32>()?,
&[
[[0f32, 1., 2.], [2f32, 3., 4.]],
[[5f32, 6., 7.], [7f32, 8., 9.]],
]
);
Ok(())
}

fn slice_set(device: &Device) -> Result<()> {
let (b, h, max_t, d) = (2, 4, 7, 3);
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
@@ -1655,6 +1669,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(unfold, unfold_cpu, unfold_gpu, unfold_metal);
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
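For completeness, a small usage sketch of the new API with `step = 1` (the configuration used by `unfold_grad` above), assuming only the `unfold` signature added in this PR plus the existing `Tensor::new`/`to_vec3` helpers:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let data = &[[0f32, 1., 2., 3., 4.], [5f32, 6., 7., 8., 9.]];
    let x = Tensor::new(data, &device)?;
    // Unfold dim 1 (length 5) with size 3 and step 1 -> (5 - 3) / 1 + 1 = 3 windows.
    let w = x.unfold(1, 3, 1)?;
    assert_eq!(w.dims(), &[2usize, 3, 3]);
    assert_eq!(
        w.to_vec3::<f32>()?,
        &[
            [[0f32, 1., 2.], [1., 2., 3.], [2., 3., 4.]],
            [[5f32, 6., 7.], [6., 7., 8.], [7., 8., 9.]],
        ]
    );
    Ok(())
}
```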