Skip to content

Commit 72d58cd

Browse files
committed
Formatting
1 parent 6f12ebf commit 72d58cd

2 files changed

Lines changed: 2 additions & 246 deletions

File tree

candle-core/src/npy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ impl Header {
106106
let mut parts: Vec<String> = vec![];
107107
let mut start_index = 0usize;
108108
let mut cnt_parenthesis = 0i64;
109-
for (index, c) in header.chars().enumerate() {
109+
for (index, c) in header.char_indices() {
110110
match c {
111111
'(' => cnt_parenthesis += 1,
112112
')' => cnt_parenthesis -= 1,

candle-core/src/tensor_indexing.rs

Lines changed: 1 addition & 245 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
22

3-
use crate::{
4-
bail,
5-
op::{BackpropOp, Op},
6-
shape::Dim,
7-
tensor::from_storage,
8-
DType, Error, Result, Tensor,
9-
};
3+
use crate::{bail, DType, Error, Result, Tensor};
104

115
/// Specialization of `std::ops::RangeBounds` for `usize` to allow trait objects.
126
pub trait RangeBound {
@@ -138,242 +132,4 @@ impl Tensor {
138132
}
139133
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
140134
}
141-
142-
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
143-
let dim = dim.to_index(self.shape(), "scatter-add")?;
144-
let source_dims = source.dims();
145-
let self_dims = self.dims();
146-
let mismatch = if source_dims.len() != self_dims.len() {
147-
true
148-
} else {
149-
let mut mismatch = false;
150-
for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
151-
if i != dim && d1 != d2 {
152-
mismatch = true;
153-
break;
154-
}
155-
}
156-
mismatch
157-
};
158-
if mismatch {
159-
Err(Error::ShapeMismatchBinaryOp {
160-
op: "scatter-add (self, src)",
161-
lhs: self.shape().clone(),
162-
rhs: source.shape().clone(),
163-
}
164-
.bt())?
165-
}
166-
if indexes.dims() != source.dims() {
167-
Err(Error::ShapeMismatchBinaryOp {
168-
op: "scatter-add (indexes, src)",
169-
lhs: indexes.shape().clone(),
170-
rhs: source.shape().clone(),
171-
}
172-
.bt())?
173-
}
174-
let storage = self.storage().scatter_add(
175-
self.layout(),
176-
&indexes.storage(),
177-
indexes.layout(),
178-
&source.storage(),
179-
source.layout(),
180-
dim,
181-
)?;
182-
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
183-
Op::ScatterAdd(t1, t2, t3, dim)
184-
});
185-
Ok(from_storage(storage, self.shape(), op, false))
186-
}
187-
188-
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
189-
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
190-
let dim = dim.to_index(self.shape(), "slice-scatter")?;
191-
if dim == 0 {
192-
self.slice_scatter0(src, start)
193-
} else {
194-
// TODO: Maybe we want to add a more efficient implementation at some point.
195-
self.transpose(0, dim)?
196-
.slice_scatter0(&src.transpose(0, dim)?, start)?
197-
.transpose(0, dim)
198-
}
199-
}
200-
201-
/// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
202-
pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
203-
if self.dtype() != src.dtype() {
204-
Err(Error::DTypeMismatchBinaryOp {
205-
lhs: self.dtype(),
206-
rhs: src.dtype(),
207-
op: "slice-scatter",
208-
}
209-
.bt())?
210-
}
211-
if self.device().location() != src.device().location() {
212-
Err(Error::DeviceMismatchBinaryOp {
213-
lhs: self.device().location(),
214-
rhs: src.device().location(),
215-
op: "slice-scatter",
216-
}
217-
.bt())?
218-
}
219-
if self.rank() != src.rank() {
220-
Err(Error::UnexpectedNumberOfDims {
221-
expected: self.rank(),
222-
got: src.rank(),
223-
shape: src.shape().clone(),
224-
}
225-
.bt())?
226-
}
227-
let shape_ok =
228-
self.dims()
229-
.iter()
230-
.zip(src.dims().iter())
231-
.enumerate()
232-
.all(|(dim_idx, (&d1, &d2))| {
233-
if 0 == dim_idx {
234-
d2 + start <= d1
235-
} else {
236-
d1 == d2
237-
}
238-
});
239-
if !shape_ok {
240-
Err(Error::ShapeMismatchBinaryOp {
241-
op: "slice-scatter (self, src)",
242-
lhs: self.shape().clone(),
243-
rhs: src.shape().clone(),
244-
}
245-
.bt())?
246-
}
247-
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
248-
self.storage()
249-
.copy_strided_src(&mut storage, 0, self.layout())?;
250-
let offset = start * src.dims()[1..].iter().product::<usize>();
251-
src.storage()
252-
.copy_strided_src(&mut storage, offset, src.layout())?;
253-
let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
254-
Ok(from_storage(storage, self.shape(), op, false))
255-
}
256-
257-
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
258-
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
259-
let dim = dim.to_index(self.shape(), "index-add")?;
260-
let source_dims = source.dims();
261-
let self_dims = self.dims();
262-
let mismatch = if source_dims.len() != self_dims.len() {
263-
true
264-
} else {
265-
let mut mismatch = false;
266-
for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
267-
if i != dim && d1 != d2 {
268-
mismatch = true;
269-
break;
270-
}
271-
}
272-
mismatch
273-
};
274-
if mismatch {
275-
Err(Error::ShapeMismatchBinaryOp {
276-
op: "index-add (self, source)",
277-
lhs: self.shape().clone(),
278-
rhs: source.shape().clone(),
279-
}
280-
.bt())?
281-
}
282-
// The number of element in indexes must match the dimension on which the add is
283-
// performed on the source tensor (and the index values from `indexes` are taken from
284-
// the target tensor self)
285-
let indexes_len = indexes.dims1()?;
286-
if source_dims[dim] != indexes_len {
287-
Err(Error::ShapeMismatchBinaryOp {
288-
op: "index-add (ids, source))",
289-
lhs: indexes.shape().clone(),
290-
rhs: source.shape().clone(),
291-
}
292-
.bt())?
293-
}
294-
let storage = self.storage().index_add(
295-
self.layout(),
296-
&indexes.storage(),
297-
indexes.layout(),
298-
&source.storage(),
299-
source.layout(),
300-
dim,
301-
)?;
302-
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
303-
Op::IndexAdd(t1, t2, t3, dim)
304-
});
305-
Ok(from_storage(storage, self.shape(), op, false))
306-
}
307-
308-
/// Gather values across the target dimension.
309-
///
310-
/// # Arguments
311-
///
312-
/// * `self` - The input tensor.
313-
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
314-
/// but can have a different number of elements on the target dimension.
315-
/// * `dim` - the target dimension.
316-
///
317-
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
318-
/// dimension `dim` by the values in `indexes`.
319-
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
320-
let dim = dim.to_index(self.shape(), "gather")?;
321-
let self_dims = self.dims();
322-
let indexes_dims = indexes.dims();
323-
let mismatch = if indexes_dims.len() != self_dims.len() {
324-
true
325-
} else {
326-
let mut mismatch = false;
327-
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
328-
if i != dim && d1 != d2 {
329-
mismatch = true;
330-
break;
331-
}
332-
}
333-
mismatch
334-
};
335-
if mismatch {
336-
Err(Error::ShapeMismatchBinaryOp {
337-
op: "gather",
338-
lhs: self.shape().clone(),
339-
rhs: indexes.shape().clone(),
340-
}
341-
.bt())?
342-
}
343-
let storage =
344-
self.storage()
345-
.gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
346-
let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
347-
Ok(from_storage(storage, indexes.shape(), op, false))
348-
}
349-
350-
/// Select values for the input tensor at the target indexes across the specified dimension.
351-
///
352-
/// The `indexes` is argument is an int tensor with a single dimension.
353-
/// The output has the same number of dimension as the `self` input. The target dimension of
354-
/// the output has length the length of `indexes` and the values are taken from `self` using
355-
/// the index from `indexes`. Other dimensions have the same number of elements as the input
356-
/// tensor.
357-
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
358-
let dim = dim.to_index(self.shape(), "index-select")?;
359-
let indexes_len = match indexes.dims() {
360-
[l] => *l,
361-
_ => Err(Error::ShapeMismatchBinaryOp {
362-
lhs: self.shape().clone(),
363-
rhs: indexes.shape().clone(),
364-
op: "index-select",
365-
}
366-
.bt())?,
367-
};
368-
let storage = self.storage().index_select(
369-
&indexes.storage(),
370-
self.layout(),
371-
indexes.layout(),
372-
dim,
373-
)?;
374-
let mut dims = self.dims().to_vec();
375-
dims[dim] = indexes_len;
376-
let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
377-
Ok(from_storage(storage, dims, op, false))
378-
}
379135
}

0 commit comments

Comments
 (0)