|
1 | 1 | use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; |
2 | 2 |
|
3 | | -use crate::{ |
4 | | - bail, |
5 | | - op::{BackpropOp, Op}, |
6 | | - shape::Dim, |
7 | | - tensor::from_storage, |
8 | | - DType, Error, Result, Tensor, |
9 | | -}; |
| 3 | +use crate::{bail, DType, Error, Result, Tensor}; |
10 | 4 |
|
11 | 5 | /// Specialization of `std::ops::RangeBounds` for `usize` to allow trait objects. |
12 | 6 | pub trait RangeBound { |
@@ -138,242 +132,4 @@ impl Tensor { |
138 | 132 | } |
139 | 133 | mask.where_cond(/* on_true= */ &src, /* on_false= */ self) |
140 | 134 | } |
141 | | - |
142 | | - pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> { |
143 | | - let dim = dim.to_index(self.shape(), "scatter-add")?; |
144 | | - let source_dims = source.dims(); |
145 | | - let self_dims = self.dims(); |
146 | | - let mismatch = if source_dims.len() != self_dims.len() { |
147 | | - true |
148 | | - } else { |
149 | | - let mut mismatch = false; |
150 | | - for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { |
151 | | - if i != dim && d1 != d2 { |
152 | | - mismatch = true; |
153 | | - break; |
154 | | - } |
155 | | - } |
156 | | - mismatch |
157 | | - }; |
158 | | - if mismatch { |
159 | | - Err(Error::ShapeMismatchBinaryOp { |
160 | | - op: "scatter-add (self, src)", |
161 | | - lhs: self.shape().clone(), |
162 | | - rhs: source.shape().clone(), |
163 | | - } |
164 | | - .bt())? |
165 | | - } |
166 | | - if indexes.dims() != source.dims() { |
167 | | - Err(Error::ShapeMismatchBinaryOp { |
168 | | - op: "scatter-add (indexes, src)", |
169 | | - lhs: indexes.shape().clone(), |
170 | | - rhs: source.shape().clone(), |
171 | | - } |
172 | | - .bt())? |
173 | | - } |
174 | | - let storage = self.storage().scatter_add( |
175 | | - self.layout(), |
176 | | - &indexes.storage(), |
177 | | - indexes.layout(), |
178 | | - &source.storage(), |
179 | | - source.layout(), |
180 | | - dim, |
181 | | - )?; |
182 | | - let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { |
183 | | - Op::ScatterAdd(t1, t2, t3, dim) |
184 | | - }); |
185 | | - Ok(from_storage(storage, self.shape(), op, false)) |
186 | | - } |
187 | | - |
188 | | - /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. |
189 | | - pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> { |
190 | | - let dim = dim.to_index(self.shape(), "slice-scatter")?; |
191 | | - if dim == 0 { |
192 | | - self.slice_scatter0(src, start) |
193 | | - } else { |
194 | | - // TODO: Maybe we want to add a more efficient implementation at some point. |
195 | | - self.transpose(0, dim)? |
196 | | - .slice_scatter0(&src.transpose(0, dim)?, start)? |
197 | | - .transpose(0, dim) |
198 | | - } |
199 | | - } |
200 | | - |
201 | | - /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension. |
202 | | - pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> { |
203 | | - if self.dtype() != src.dtype() { |
204 | | - Err(Error::DTypeMismatchBinaryOp { |
205 | | - lhs: self.dtype(), |
206 | | - rhs: src.dtype(), |
207 | | - op: "slice-scatter", |
208 | | - } |
209 | | - .bt())? |
210 | | - } |
211 | | - if self.device().location() != src.device().location() { |
212 | | - Err(Error::DeviceMismatchBinaryOp { |
213 | | - lhs: self.device().location(), |
214 | | - rhs: src.device().location(), |
215 | | - op: "slice-scatter", |
216 | | - } |
217 | | - .bt())? |
218 | | - } |
219 | | - if self.rank() != src.rank() { |
220 | | - Err(Error::UnexpectedNumberOfDims { |
221 | | - expected: self.rank(), |
222 | | - got: src.rank(), |
223 | | - shape: src.shape().clone(), |
224 | | - } |
225 | | - .bt())? |
226 | | - } |
227 | | - let shape_ok = |
228 | | - self.dims() |
229 | | - .iter() |
230 | | - .zip(src.dims().iter()) |
231 | | - .enumerate() |
232 | | - .all(|(dim_idx, (&d1, &d2))| { |
233 | | - if 0 == dim_idx { |
234 | | - d2 + start <= d1 |
235 | | - } else { |
236 | | - d1 == d2 |
237 | | - } |
238 | | - }); |
239 | | - if !shape_ok { |
240 | | - Err(Error::ShapeMismatchBinaryOp { |
241 | | - op: "slice-scatter (self, src)", |
242 | | - lhs: self.shape().clone(), |
243 | | - rhs: src.shape().clone(), |
244 | | - } |
245 | | - .bt())? |
246 | | - } |
247 | | - let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; |
248 | | - self.storage() |
249 | | - .copy_strided_src(&mut storage, 0, self.layout())?; |
250 | | - let offset = start * src.dims()[1..].iter().product::<usize>(); |
251 | | - src.storage() |
252 | | - .copy_strided_src(&mut storage, offset, src.layout())?; |
253 | | - let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start)); |
254 | | - Ok(from_storage(storage, self.shape(), op, false)) |
255 | | - } |
256 | | - |
257 | | - /// Accumulate element from `source` at indexes `indexes` and add them to `self`. |
258 | | - pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> { |
259 | | - let dim = dim.to_index(self.shape(), "index-add")?; |
260 | | - let source_dims = source.dims(); |
261 | | - let self_dims = self.dims(); |
262 | | - let mismatch = if source_dims.len() != self_dims.len() { |
263 | | - true |
264 | | - } else { |
265 | | - let mut mismatch = false; |
266 | | - for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { |
267 | | - if i != dim && d1 != d2 { |
268 | | - mismatch = true; |
269 | | - break; |
270 | | - } |
271 | | - } |
272 | | - mismatch |
273 | | - }; |
274 | | - if mismatch { |
275 | | - Err(Error::ShapeMismatchBinaryOp { |
276 | | - op: "index-add (self, source)", |
277 | | - lhs: self.shape().clone(), |
278 | | - rhs: source.shape().clone(), |
279 | | - } |
280 | | - .bt())? |
281 | | - } |
282 | | - // The number of element in indexes must match the dimension on which the add is |
283 | | - // performed on the source tensor (and the index values from `indexes` are taken from |
284 | | - // the target tensor self) |
285 | | - let indexes_len = indexes.dims1()?; |
286 | | - if source_dims[dim] != indexes_len { |
287 | | - Err(Error::ShapeMismatchBinaryOp { |
288 | | - op: "index-add (ids, source))", |
289 | | - lhs: indexes.shape().clone(), |
290 | | - rhs: source.shape().clone(), |
291 | | - } |
292 | | - .bt())? |
293 | | - } |
294 | | - let storage = self.storage().index_add( |
295 | | - self.layout(), |
296 | | - &indexes.storage(), |
297 | | - indexes.layout(), |
298 | | - &source.storage(), |
299 | | - source.layout(), |
300 | | - dim, |
301 | | - )?; |
302 | | - let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { |
303 | | - Op::IndexAdd(t1, t2, t3, dim) |
304 | | - }); |
305 | | - Ok(from_storage(storage, self.shape(), op, false)) |
306 | | - } |
307 | | - |
308 | | - /// Gather values across the target dimension. |
309 | | - /// |
310 | | - /// # Arguments |
311 | | - /// |
312 | | - /// * `self` - The input tensor. |
313 | | - /// * `indexes` - The indices of elements to gather, this should have the same shape as `self` |
314 | | - /// but can have a different number of elements on the target dimension. |
315 | | - /// * `dim` - the target dimension. |
316 | | - /// |
317 | | - /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on |
318 | | - /// dimension `dim` by the values in `indexes`. |
319 | | - pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> { |
320 | | - let dim = dim.to_index(self.shape(), "gather")?; |
321 | | - let self_dims = self.dims(); |
322 | | - let indexes_dims = indexes.dims(); |
323 | | - let mismatch = if indexes_dims.len() != self_dims.len() { |
324 | | - true |
325 | | - } else { |
326 | | - let mut mismatch = false; |
327 | | - for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { |
328 | | - if i != dim && d1 != d2 { |
329 | | - mismatch = true; |
330 | | - break; |
331 | | - } |
332 | | - } |
333 | | - mismatch |
334 | | - }; |
335 | | - if mismatch { |
336 | | - Err(Error::ShapeMismatchBinaryOp { |
337 | | - op: "gather", |
338 | | - lhs: self.shape().clone(), |
339 | | - rhs: indexes.shape().clone(), |
340 | | - } |
341 | | - .bt())? |
342 | | - } |
343 | | - let storage = |
344 | | - self.storage() |
345 | | - .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?; |
346 | | - let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim)); |
347 | | - Ok(from_storage(storage, indexes.shape(), op, false)) |
348 | | - } |
349 | | - |
350 | | - /// Select values for the input tensor at the target indexes across the specified dimension. |
351 | | - /// |
352 | | - /// The `indexes` is argument is an int tensor with a single dimension. |
353 | | - /// The output has the same number of dimension as the `self` input. The target dimension of |
354 | | - /// the output has length the length of `indexes` and the values are taken from `self` using |
355 | | - /// the index from `indexes`. Other dimensions have the same number of elements as the input |
356 | | - /// tensor. |
357 | | - pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> { |
358 | | - let dim = dim.to_index(self.shape(), "index-select")?; |
359 | | - let indexes_len = match indexes.dims() { |
360 | | - [l] => *l, |
361 | | - _ => Err(Error::ShapeMismatchBinaryOp { |
362 | | - lhs: self.shape().clone(), |
363 | | - rhs: indexes.shape().clone(), |
364 | | - op: "index-select", |
365 | | - } |
366 | | - .bt())?, |
367 | | - }; |
368 | | - let storage = self.storage().index_select( |
369 | | - &indexes.storage(), |
370 | | - self.layout(), |
371 | | - indexes.layout(), |
372 | | - dim, |
373 | | - )?; |
374 | | - let mut dims = self.dims().to_vec(); |
375 | | - dims[dim] = indexes_len; |
376 | | - let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim)); |
377 | | - Ok(from_storage(storage, dims, op, false)) |
378 | | - } |
379 | 135 | } |
0 commit comments