Skip to content

Commit cf23102

Browse files
committed
Switch Tensor variants to ndarray ArcArray for cheap clones
1 parent 005fe52 commit cf23102

2 files changed

Lines changed: 118 additions & 73 deletions

File tree

crates/providers/src/store.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,9 @@ mod tests {
8989
#[test]
9090
fn test_store_output_types_2d() {
9191
use ndarray::arr2;
92-
let data = DataTree::new_leaf(Tensor::F64(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn()));
92+
let data = DataTree::new_leaf(Tensor::F64(
93+
arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn().into_shared(),
94+
));
9395
let store = Store::new(data);
9496
let DataTree::Leaf(tt) = store.output_types() else {
9597
panic!("expected leaf output type");

crates/providers/src/tensor.rs

Lines changed: 115 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,14 @@
1010
// copyright notice, and modified files need to carry a notice indicating
1111
// that they have been altered from the originals.
1212

13-
use ndarray::{ArrayD, IxDyn, Zip};
13+
use ndarray::{ArcArray, ArrayD, IxDyn, Zip};
1414
use num_complex::{Complex32, Complex64};
1515
use std::fmt;
1616
use thiserror::Error;
1717

18+
/// Dynamic-dimensional [`ArcArray`]; the storage type for every [`Tensor`] variant.
19+
type ArcArrayD<T> = ArcArray<T, IxDyn>;
20+
1821
/// Errors returned by [`Tensor`] operations.
1922
#[derive(Debug, Clone, PartialEq, Eq, Error)]
2023
pub enum TensorError {
@@ -234,50 +237,67 @@ impl TensorType {
234237
}
235238

236239
/// A tensor of one of the supported dtypes.
240+
///
241+
/// Each variant wraps a reference-counted dynamic ndarray ([`ArcArray`]) so that
242+
/// [`Tensor::clone`] is a cheap atomic refcount bump rather than a deep buffer
243+
/// copy. Mutating the underlying buffer in place (via ndarray methods that
244+
/// require `DataMut`) triggers a copy-on-write clone when the buffer is shared.
237245
#[derive(Debug, Clone)]
238246
pub enum Tensor {
239-
C64(ArrayD<Complex32>), // complex
240-
C128(ArrayD<Complex64>),
241-
F32(ArrayD<f32>), // real
242-
F64(ArrayD<f64>),
243-
I8(ArrayD<i8>), // signed integer
244-
I16(ArrayD<i16>),
245-
I32(ArrayD<i32>),
246-
I64(ArrayD<i64>),
247-
U8(ArrayD<u8>), // unsigned integer
248-
U16(ArrayD<u16>),
249-
U32(ArrayD<u32>),
250-
U64(ArrayD<u64>),
251-
Bit(ArrayD<u8>), // bool
247+
C64(ArcArrayD<Complex32>), // complex
248+
C128(ArcArrayD<Complex64>),
249+
F32(ArcArrayD<f32>), // real
250+
F64(ArcArrayD<f64>),
251+
I8(ArcArrayD<i8>), // signed integer
252+
I16(ArcArrayD<i16>),
253+
I32(ArcArrayD<i32>),
254+
I64(ArcArrayD<i64>),
255+
U8(ArcArrayD<u8>), // unsigned integer
256+
U16(ArcArrayD<u16>),
257+
U32(ArcArrayD<u32>),
258+
U64(ArcArrayD<u64>),
259+
Bit(ArcArrayD<u8>), // bool
252260
}
253261

254-
/// Cast an `ArrayD` of a real numeric type to any supported dtype.
262+
/// Cast an array of a real numeric type to any supported dtype.
255263
macro_rules! cast_real {
256264
($arr:expr, $src:ty, $target:expr) => {
257265
match $target {
258-
DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8)),
259-
DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8)),
260-
DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16)),
261-
DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32)),
262-
DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64)),
263-
DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8)),
264-
DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16)),
265-
DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32)),
266-
DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64)),
267-
DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32)),
268-
DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64)),
269-
DType::C64 => Tensor::C64($arr.mapv(|x: $src| Complex32::new(x as f32, 0.0))),
270-
DType::C128 => Tensor::C128($arr.mapv(|x: $src| Complex64::new(x as f64, 0.0))),
266+
DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8).into_shared()),
267+
DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8).into_shared()),
268+
DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16).into_shared()),
269+
DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32).into_shared()),
270+
DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64).into_shared()),
271+
DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8).into_shared()),
272+
DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16).into_shared()),
273+
DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32).into_shared()),
274+
DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64).into_shared()),
275+
DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32).into_shared()),
276+
DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64).into_shared()),
277+
DType::C64 => Tensor::C64(
278+
$arr.mapv(|x: $src| Complex32::new(x as f32, 0.0))
279+
.into_shared(),
280+
),
281+
DType::C128 => Tensor::C128(
282+
$arr.mapv(|x: $src| Complex64::new(x as f64, 0.0))
283+
.into_shared(),
284+
),
271285
}
272286
};
273287
}
274288

275-
/// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets).
289+
/// Cast an array of a complex type to a complex dtype (panics for real targets).
276290
macro_rules! cast_complex {
277291
($arr:expr, $target:expr) => {
278292
match $target {
279-
DType::C64 => Tensor::C64($arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32))),
280-
DType::C128 => Tensor::C128($arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64))),
293+
DType::C64 => Tensor::C64(
294+
$arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32))
295+
.into_shared(),
296+
),
297+
DType::C128 => Tensor::C128(
298+
$arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64))
299+
.into_shared(),
300+
),
281301
_ => panic!("cannot cast complex tensor to a real dtype"),
282302
}
283303
};
@@ -318,10 +338,10 @@ fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, TensorError>
318338
/// this helper is needed for operations without a Rust operator (e.g. `pow`). Returns
319339
/// [`TensorError::ShapeMismatch`] if the operand shapes are not broadcast-compatible.
320340
fn broadcast_elementwise<T, F>(
321-
a: &ArrayD<T>,
322-
b: &ArrayD<T>,
341+
a: &ArcArrayD<T>,
342+
b: &ArcArrayD<T>,
323343
op: F,
324-
) -> Result<ArrayD<T>, TensorError>
344+
) -> Result<ArcArrayD<T>, TensorError>
325345
where
326346
T: Clone,
327347
F: Fn(&T, &T) -> T,
@@ -330,7 +350,7 @@ where
330350
let out_ix = IxDyn(&out_shape);
331351
let a_bc = a.broadcast(out_ix.clone()).expect("broadcast failed");
332352
let b_bc = b.broadcast(out_ix).expect("broadcast failed");
333-
Ok(Zip::from(a_bc).and(b_bc).map_collect(op))
353+
Ok(Zip::from(a_bc).and(b_bc).map_collect(op).into_shared())
334354
}
335355

336356
impl Tensor {
@@ -455,21 +475,27 @@ impl Tensor {
455475
}
456476
}
457477

458-
/// Implement `From<&[T]>`, `From<&[T; N]>`, and `From<ArrayD<T>>` for a given `Tensor` variant.
478+
/// Implement `From<&[T]>`, `From<[T; N]>`, `From<ArrayD<T>>`, and
479+
/// `From<ArcArrayD<T>>` for a given `Tensor` variant.
459480
macro_rules! impl_tensor_from {
460481
($variant:ident, $t:ty) => {
461482
impl From<&[$t]> for Tensor {
462483
fn from(data: &[$t]) -> Self {
463-
Tensor::$variant(ndarray::arr1(data).into_dyn())
484+
Tensor::$variant(ndarray::arr1(data).into_dyn().into_shared())
464485
}
465486
}
466487
impl<const N: usize> From<[$t; N]> for Tensor {
467488
fn from(data: [$t; N]) -> Self {
468-
Tensor::$variant(ndarray::arr1(&data).into_dyn())
489+
Tensor::$variant(ndarray::arr1(&data).into_dyn().into_shared())
469490
}
470491
}
471492
impl From<ArrayD<$t>> for Tensor {
472493
fn from(data: ArrayD<$t>) -> Self {
494+
Tensor::$variant(data.into_shared())
495+
}
496+
}
497+
impl From<ArcArrayD<$t>> for Tensor {
498+
fn from(data: ArcArrayD<$t>) -> Self {
473499
Tensor::$variant(data)
474500
}
475501
}
@@ -508,18 +534,18 @@ macro_rules! impl_tensor_binop {
508534
pub fn $tensor_method(&self, rhs: &Tensor) -> Result<Tensor, TensorError> {
509535
broadcast_shape(self.shape(), rhs.shape())?;
510536
match (self, rhs) {
511-
(Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128(a $op b)),
512-
(Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(a $op b)),
513-
(Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a $op b)),
514-
(Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a $op b)),
515-
(Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a $op b)),
516-
(Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a $op b)),
517-
(Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a $op b)),
518-
(Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a $op b)),
519-
(Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a $op b)),
520-
(Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a $op b)),
521-
(Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a $op b)),
522-
(Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a $op b)),
537+
(Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128((a $op b).into_shared())),
538+
(Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64((a $op b).into_shared())),
539+
(Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a $op b).into_shared())),
540+
(Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a $op b).into_shared())),
541+
(Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a $op b).into_shared())),
542+
(Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a $op b).into_shared())),
543+
(Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a $op b).into_shared())),
544+
(Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a $op b).into_shared())),
545+
(Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a $op b).into_shared())),
546+
(Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a $op b).into_shared())),
547+
(Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a $op b).into_shared())),
548+
(Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a $op b).into_shared())),
523549
_ => Err(TensorError::DTypeMismatch {
524550
op: $op_name,
525551
lhs: self.dtype(),
@@ -557,16 +583,16 @@ impl Tensor {
557583
pub fn rem_tensor(&self, rhs: &Tensor) -> Result<Tensor, TensorError> {
558584
broadcast_shape(self.shape(), rhs.shape())?;
559585
match (self, rhs) {
560-
(Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a % b)),
561-
(Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a % b)),
562-
(Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a % b)),
563-
(Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a % b)),
564-
(Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a % b)),
565-
(Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a % b)),
566-
(Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a % b)),
567-
(Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a % b)),
568-
(Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a % b)),
569-
(Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a % b)),
586+
(Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a % b).into_shared())),
587+
(Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a % b).into_shared())),
588+
(Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a % b).into_shared())),
589+
(Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a % b).into_shared())),
590+
(Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a % b).into_shared())),
591+
(Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a % b).into_shared())),
592+
(Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a % b).into_shared())),
593+
(Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a % b).into_shared())),
594+
(Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a % b).into_shared())),
595+
(Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a % b).into_shared())),
570596
_ => Err(TensorError::DTypeMismatch {
571597
op: "rem",
572598
lhs: self.dtype(),
@@ -770,6 +796,22 @@ mod test {
770796
assert_eq!(t.shape(), &[4]);
771797
}
772798

799+
#[test]
800+
fn test_clone_shares_buffer() {
801+
// ArcArray storage means Tensor::clone() is a refcount bump, not a deep
802+
// copy. Verify by comparing the underlying buffer pointer between the
803+
// original and a clone.
804+
let t = Tensor::from([1.0_f64, 2.0, 3.0]);
805+
let cloned = t.clone();
806+
let Tensor::F64(orig) = &t else {
807+
panic!("expected F64 tensor")
808+
};
809+
let Tensor::F64(copy) = &cloned else {
810+
panic!("expected F64 tensor")
811+
};
812+
assert_eq!(orig.as_ptr(), copy.as_ptr());
813+
}
814+
773815
#[test]
774816
fn test_from_arrayd() {
775817
let arr = ndarray::Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0f64; 6]).unwrap();
@@ -1390,17 +1432,17 @@ mod test {
13901432
DType::C128,
13911433
];
13921434
let sources = [
1393-
Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)),
1394-
Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)),
1395-
Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16)),
1396-
Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32)),
1397-
Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64)),
1398-
Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8)),
1399-
Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16)),
1400-
Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32)),
1401-
Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64)),
1402-
Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32)),
1403-
Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64)),
1435+
Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()),
1436+
Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()),
1437+
Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16).into_shared()),
1438+
Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32).into_shared()),
1439+
Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64).into_shared()),
1440+
Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8).into_shared()),
1441+
Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16).into_shared()),
1442+
Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32).into_shared()),
1443+
Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64).into_shared()),
1444+
Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32).into_shared()),
1445+
Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64).into_shared()),
14041446
];
14051447
for src in sources {
14061448
let src_dtype = src.dtype();
@@ -1425,7 +1467,8 @@ mod test {
14251467
}
14261468

14271469
// Spot-check a numeric value (Bit(1) -> F64 -> 1.0).
1428-
let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)).cast(DType::F64);
1470+
let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared())
1471+
.cast(DType::F64);
14291472
if let Tensor::F64(arr) = bit_to_f64 {
14301473
assert_eq!(arr.as_slice().unwrap(), &[1.0_f64, 1.0]);
14311474
} else {

0 commit comments

Comments
 (0)