Skip to content

Commit f519940

Browse files
committed
torch
1 parent f453748 commit f519940

File tree

23 files changed

+147
-55
lines changed

23 files changed

+147
-55
lines changed

crates/burn-autodiff/src/ops/bool_tensor.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,12 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
108108
B::bool_repeat_dim(tensor, dim, times)
109109
}
110110

111-
fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
111+
fn bool_unfold(
112+
tensor: BoolTensor<Self>,
113+
dim: usize,
114+
size: usize,
115+
step: usize,
116+
) -> BoolTensor<Self> {
112117
B::bool_unfold(tensor, dim, size, step)
113118
}
114119
}

crates/burn-autodiff/src/ops/int_tensor.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
378378
B::int_cast(tensor, dtype)
379379
}
380380

381-
fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
381+
fn int_unfold(
382+
tensor: IntTensor<Self>,
383+
dim: usize,
384+
size: usize,
385+
step: usize,
386+
) -> IntTensor<Self> {
382387
B::int_unfold(tensor, dim, size, step)
383388
}
384389
}

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2593,7 +2593,12 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
25932593
// TODO: Implement float_prod and float_sum
25942594
// https://github.com/tracel-ai/burn/issues/1458
25952595

2596-
fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
2596+
fn float_unfold(
2597+
tensor: FloatTensor<Self>,
2598+
dim: usize,
2599+
size: usize,
2600+
step: usize,
2601+
) -> FloatTensor<Self> {
25972602
AutodiffTensor::new(B::float_unfold(tensor.primitive, dim, size, step))
25982603
}
25992604
}

crates/burn-cubecl/src/ops/base.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
use std::cmp::max;
21
use crate::{CubeRuntime, element::CubeElement, kernel, tensor::CubeTensor};
32
use burn_common::tensor::{ReshapeAction, reshape_action};
4-
use burn_tensor::{Shape, TensorData, quantization::{QTensorPrimitive, QuantLevel}};
3+
use burn_tensor::{
4+
Shape, TensorData,
5+
quantization::{QTensorPrimitive, QuantLevel},
6+
};
57
use cubecl::{server::CopyDescriptor, tensor_vectorization_factor};
8+
use std::cmp::max;
69

710
pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
811
let shape: Shape = (&data.shape).into();

crates/burn-cubecl/src/ops/bool_ops.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,12 @@ where
142142
kernel::flip::<R, BT, BT>(tensor, axes)
143143
}
144144

145-
fn bool_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
145+
fn bool_unfold(
146+
tensor: FloatTensor<Self>,
147+
dim: usize,
148+
size: usize,
149+
step: usize,
150+
) -> FloatTensor<Self> {
146151
unfold(tensor, dim, size, step)
147152
}
148153
}

crates/burn-cubecl/src/ops/float_ops.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,12 @@ where
703703
}
704704
}
705705

706-
fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
706+
fn float_unfold(
707+
tensor: FloatTensor<Self>,
708+
dim: usize,
709+
size: usize,
710+
step: usize,
711+
) -> FloatTensor<Self> {
707712
unfold(tensor, dim, size, step)
708713
}
709714
}

crates/burn-cubecl/src/ops/int_ops.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,12 @@ where
685685
)
686686
}
687687

688-
fn int_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
688+
fn int_unfold(
689+
tensor: FloatTensor<Self>,
690+
dim: usize,
691+
size: usize,
692+
step: usize,
693+
) -> FloatTensor<Self> {
689694
unfold(tensor, dim, size, step)
690695
}
691696
}

crates/burn-fusion/src/ops/boolean.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
use std::cmp::max;
2-
use burn_ir::{BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer, InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr};
1+
use burn_ir::{
2+
BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
3+
InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
4+
SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
5+
};
36
use burn_tensor::{
47
Device, Element, Shape, Slice, TensorData, TensorMetadata,
58
ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape},
69
};
10+
use std::cmp::max;
711
use std::marker::PhantomData;
812

913
use crate::{
@@ -742,7 +746,12 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
742746
out
743747
}
744748

745-
fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
749+
fn bool_unfold(
750+
tensor: BoolTensor<Self>,
751+
dim: usize,
752+
size: usize,
753+
step: usize,
754+
) -> BoolTensor<Self> {
746755
#[derive(new, Debug)]
747756
struct UnfoldOps<B: FusionBackend> {
748757
desc: UnfoldOpIr,
@@ -752,11 +761,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
752761
impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
753762
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
754763
let input = handles.get_bool_tensor::<B>(&self.desc.input);
755-
let output = B::bool_unfold(
756-
input,
757-
self.desc.dim,
758-
self.desc.size,
759-
self.desc.step);
764+
let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
760765

761766
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
762767
}

crates/burn-fusion/src/ops/float.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::NoOp;
12
use crate::{
23
Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops,
34
client::FusionClient,
@@ -12,9 +13,8 @@ use burn_tensor::{
1213
Device, Distribution, Element, FloatDType, Shape, Slice, TensorData, TensorMetadata,
1314
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor, binary_ops_shape},
1415
};
15-
use std::{marker::PhantomData, ops::Range};
1616
use std::cmp::max;
17-
use super::NoOp;
17+
use std::{marker::PhantomData, ops::Range};
1818

1919
impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
2020
fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
@@ -2260,7 +2260,12 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
22602260
out
22612261
}
22622262

2263-
fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
2263+
fn float_unfold(
2264+
tensor: FloatTensor<Self>,
2265+
dim: usize,
2266+
size: usize,
2267+
step: usize,
2268+
) -> FloatTensor<Self> {
22642269
#[derive(new, Debug)]
22652270
struct UnfoldOps<B: FusionBackend> {
22662271
desc: UnfoldOpIr,
@@ -2270,11 +2275,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
22702275
impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
22712276
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
22722277
let input = handles.get_float_tensor::<B>(&self.desc.input);
2273-
let output = B::float_unfold(
2274-
input,
2275-
self.desc.dim,
2276-
self.desc.size,
2277-
self.desc.step);
2278+
let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
22782279

22792280
handles.register_float_tensor::<B>(&self.desc.out.id, output);
22802281
}

crates/burn-fusion/src/ops/int.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,7 +2173,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
21732173
out
21742174
}
21752175

2176-
fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
2176+
fn int_unfold(
2177+
tensor: IntTensor<Self>,
2178+
dim: usize,
2179+
size: usize,
2180+
step: usize,
2181+
) -> IntTensor<Self> {
21772182
#[derive(new, Debug)]
21782183
struct UnfoldOps<B: FusionBackend> {
21792184
desc: UnfoldOpIr,
@@ -2183,11 +2188,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
21832188
impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
21842189
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
21852190
let input = handles.get_int_tensor::<B>(&self.desc.input);
2186-
let output = B::int_unfold(
2187-
input,
2188-
self.desc.dim,
2189-
self.desc.size,
2190-
self.desc.step);
2191+
let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
21912192

21922193
handles.register_int_tensor::<B>(&self.desc.out.id, output);
21932194
}

0 commit comments

Comments
 (0)