1010// copyright notice, and modified files need to carry a notice indicating
1111// that they have been altered from the originals.
1212
13- use ndarray:: { ArrayD , IxDyn , Zip } ;
13+ use ndarray:: { ArcArray , ArrayD , IxDyn , Zip } ;
1414use num_complex:: { Complex32 , Complex64 } ;
1515use std:: fmt;
1616use thiserror:: Error ;
1717
18+ /// Dynamic-dimensional [`ArcArray`]; the storage type for every [`Tensor`] variant.
19+ type ArcArrayD < T > = ArcArray < T , IxDyn > ;
20+
1821/// Errors returned by [`Tensor`] operations.
1922#[ derive( Debug , Clone , PartialEq , Eq , Error ) ]
2023pub enum TensorError {
@@ -234,50 +237,67 @@ impl TensorType {
234237}
235238
236239/// A tensor of one of the supported dtypes.
240+ ///
241+ /// Each variant wraps a reference-counted dynamic ndarray ([`ArcArray`]) so that
242+ /// [`Tensor::clone`] is a cheap atomic refcount bump rather than a deep buffer
243+ /// copy. Mutating the underlying buffer in place (via ndarray methods that
244+ /// require `DataMut`) performs a copy-on-write clone when the buffer is shared.
237245#[ derive( Debug , Clone ) ]
238246pub enum Tensor {
239- C64 ( ArrayD < Complex32 > ) , // complex
240- C128 ( ArrayD < Complex64 > ) ,
241- F32 ( ArrayD < f32 > ) , // real
242- F64 ( ArrayD < f64 > ) ,
243- I8 ( ArrayD < i8 > ) , // signed integer
244- I16 ( ArrayD < i16 > ) ,
245- I32 ( ArrayD < i32 > ) ,
246- I64 ( ArrayD < i64 > ) ,
247- U8 ( ArrayD < u8 > ) , // unsigned integer
248- U16 ( ArrayD < u16 > ) ,
249- U32 ( ArrayD < u32 > ) ,
250- U64 ( ArrayD < u64 > ) ,
251- Bit ( ArrayD < u8 > ) , // bool
247+ C64 ( ArcArrayD < Complex32 > ) , // complex
248+ C128 ( ArcArrayD < Complex64 > ) ,
249+ F32 ( ArcArrayD < f32 > ) , // real
250+ F64 ( ArcArrayD < f64 > ) ,
251+ I8 ( ArcArrayD < i8 > ) , // signed integer
252+ I16 ( ArcArrayD < i16 > ) ,
253+ I32 ( ArcArrayD < i32 > ) ,
254+ I64 ( ArcArrayD < i64 > ) ,
255+ U8 ( ArcArrayD < u8 > ) , // unsigned integer
256+ U16 ( ArcArrayD < u16 > ) ,
257+ U32 ( ArcArrayD < u32 > ) ,
258+ U64 ( ArcArrayD < u64 > ) ,
259+ Bit ( ArcArrayD < u8 > ) , // bool
252260}
253261
254- /// Cast an `ArrayD` of a real numeric type to any supported dtype.
262+ /// Cast an array of a real numeric type to any supported dtype.
255263macro_rules! cast_real {
256264 ( $arr: expr, $src: ty, $target: expr) => {
257265 match $target {
258- DType :: Bit => Tensor :: Bit ( $arr. mapv( |x: $src| x as u8 ) ) ,
259- DType :: U8 => Tensor :: U8 ( $arr. mapv( |x: $src| x as u8 ) ) ,
260- DType :: U16 => Tensor :: U16 ( $arr. mapv( |x: $src| x as u16 ) ) ,
261- DType :: U32 => Tensor :: U32 ( $arr. mapv( |x: $src| x as u32 ) ) ,
262- DType :: U64 => Tensor :: U64 ( $arr. mapv( |x: $src| x as u64 ) ) ,
263- DType :: I8 => Tensor :: I8 ( $arr. mapv( |x: $src| x as i8 ) ) ,
264- DType :: I16 => Tensor :: I16 ( $arr. mapv( |x: $src| x as i16 ) ) ,
265- DType :: I32 => Tensor :: I32 ( $arr. mapv( |x: $src| x as i32 ) ) ,
266- DType :: I64 => Tensor :: I64 ( $arr. mapv( |x: $src| x as i64 ) ) ,
267- DType :: F32 => Tensor :: F32 ( $arr. mapv( |x: $src| x as f32 ) ) ,
268- DType :: F64 => Tensor :: F64 ( $arr. mapv( |x: $src| x as f64 ) ) ,
269- DType :: C64 => Tensor :: C64 ( $arr. mapv( |x: $src| Complex32 :: new( x as f32 , 0.0 ) ) ) ,
270- DType :: C128 => Tensor :: C128 ( $arr. mapv( |x: $src| Complex64 :: new( x as f64 , 0.0 ) ) ) ,
266+ DType :: Bit => Tensor :: Bit ( $arr. mapv( |x: $src| x as u8 ) . into_shared( ) ) ,
267+ DType :: U8 => Tensor :: U8 ( $arr. mapv( |x: $src| x as u8 ) . into_shared( ) ) ,
268+ DType :: U16 => Tensor :: U16 ( $arr. mapv( |x: $src| x as u16 ) . into_shared( ) ) ,
269+ DType :: U32 => Tensor :: U32 ( $arr. mapv( |x: $src| x as u32 ) . into_shared( ) ) ,
270+ DType :: U64 => Tensor :: U64 ( $arr. mapv( |x: $src| x as u64 ) . into_shared( ) ) ,
271+ DType :: I8 => Tensor :: I8 ( $arr. mapv( |x: $src| x as i8 ) . into_shared( ) ) ,
272+ DType :: I16 => Tensor :: I16 ( $arr. mapv( |x: $src| x as i16 ) . into_shared( ) ) ,
273+ DType :: I32 => Tensor :: I32 ( $arr. mapv( |x: $src| x as i32 ) . into_shared( ) ) ,
274+ DType :: I64 => Tensor :: I64 ( $arr. mapv( |x: $src| x as i64 ) . into_shared( ) ) ,
275+ DType :: F32 => Tensor :: F32 ( $arr. mapv( |x: $src| x as f32 ) . into_shared( ) ) ,
276+ DType :: F64 => Tensor :: F64 ( $arr. mapv( |x: $src| x as f64 ) . into_shared( ) ) ,
277+ DType :: C64 => Tensor :: C64 (
278+ $arr. mapv( |x: $src| Complex32 :: new( x as f32 , 0.0 ) )
279+ . into_shared( ) ,
280+ ) ,
281+ DType :: C128 => Tensor :: C128 (
282+ $arr. mapv( |x: $src| Complex64 :: new( x as f64 , 0.0 ) )
283+ . into_shared( ) ,
284+ ) ,
271285 }
272286 } ;
273287}
274288
275- /// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets).
289+ /// Cast an array of a complex type to a complex dtype (panics for real targets).
276290macro_rules! cast_complex {
277291 ( $arr: expr, $target: expr) => {
278292 match $target {
279- DType :: C64 => Tensor :: C64 ( $arr. mapv( |x| Complex32 :: new( x. re as f32 , x. im as f32 ) ) ) ,
280- DType :: C128 => Tensor :: C128 ( $arr. mapv( |x| Complex64 :: new( x. re as f64 , x. im as f64 ) ) ) ,
293+ DType :: C64 => Tensor :: C64 (
294+ $arr. mapv( |x| Complex32 :: new( x. re as f32 , x. im as f32 ) )
295+ . into_shared( ) ,
296+ ) ,
297+ DType :: C128 => Tensor :: C128 (
298+ $arr. mapv( |x| Complex64 :: new( x. re as f64 , x. im as f64 ) )
299+ . into_shared( ) ,
300+ ) ,
281301 _ => panic!( "cannot cast complex tensor to a real dtype" ) ,
282302 }
283303 } ;
@@ -318,10 +338,10 @@ fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, TensorError>
318338/// this helper is needed for operations without a Rust operator (e.g. `pow`). Returns
319339/// [`TensorError::ShapeMismatch`] if the operand shapes are not broadcast-compatible.
320340fn broadcast_elementwise < T , F > (
321- a : & ArrayD < T > ,
322- b : & ArrayD < T > ,
341+ a : & ArcArrayD < T > ,
342+ b : & ArcArrayD < T > ,
323343 op : F ,
324- ) -> Result < ArrayD < T > , TensorError >
344+ ) -> Result < ArcArrayD < T > , TensorError >
325345where
326346 T : Clone ,
327347 F : Fn ( & T , & T ) -> T ,
@@ -330,7 +350,7 @@ where
330350 let out_ix = IxDyn ( & out_shape) ;
331351 let a_bc = a. broadcast ( out_ix. clone ( ) ) . expect ( "broadcast failed" ) ;
332352 let b_bc = b. broadcast ( out_ix) . expect ( "broadcast failed" ) ;
333- Ok ( Zip :: from ( a_bc) . and ( b_bc) . map_collect ( op) )
353+ Ok ( Zip :: from ( a_bc) . and ( b_bc) . map_collect ( op) . into_shared ( ) )
334354}
335355
336356impl Tensor {
@@ -455,21 +475,27 @@ impl Tensor {
455475 }
456476}
457477
458- /// Implement `From<&[T]>`, `From<&[T; N]>`, and `From<ArrayD<T>>` for a given `Tensor` variant.
478+ /// Implement `From<&[T]>`, `From<[T; N]>`, `From<ArrayD<T>>`, and
479+ /// `From<ArcArrayD<T>>` for a given `Tensor` variant.
459480macro_rules! impl_tensor_from {
460481 ( $variant: ident, $t: ty) => {
461482 impl From <& [ $t] > for Tensor {
462483 fn from( data: & [ $t] ) -> Self {
463- Tensor :: $variant( ndarray:: arr1( data) . into_dyn( ) )
484+ Tensor :: $variant( ndarray:: arr1( data) . into_dyn( ) . into_shared ( ) )
464485 }
465486 }
466487 impl <const N : usize > From <[ $t; N ] > for Tensor {
467488 fn from( data: [ $t; N ] ) -> Self {
468- Tensor :: $variant( ndarray:: arr1( & data) . into_dyn( ) )
489+ Tensor :: $variant( ndarray:: arr1( & data) . into_dyn( ) . into_shared ( ) )
469490 }
470491 }
471492 impl From <ArrayD <$t>> for Tensor {
472493 fn from( data: ArrayD <$t>) -> Self {
494+ Tensor :: $variant( data. into_shared( ) )
495+ }
496+ }
497+ impl From <ArcArrayD <$t>> for Tensor {
498+ fn from( data: ArcArrayD <$t>) -> Self {
473499 Tensor :: $variant( data)
474500 }
475501 }
@@ -508,18 +534,18 @@ macro_rules! impl_tensor_binop {
508534 pub fn $tensor_method( & self , rhs: & Tensor ) -> Result <Tensor , TensorError > {
509535 broadcast_shape( self . shape( ) , rhs. shape( ) ) ?;
510536 match ( self , rhs) {
511- ( Tensor :: C128 ( a) , Tensor :: C128 ( b) ) => Ok ( Tensor :: C128 ( a $op b) ) ,
512- ( Tensor :: C64 ( a) , Tensor :: C64 ( b) ) => Ok ( Tensor :: C64 ( a $op b) ) ,
513- ( Tensor :: F64 ( a) , Tensor :: F64 ( b) ) => Ok ( Tensor :: F64 ( a $op b) ) ,
514- ( Tensor :: F32 ( a) , Tensor :: F32 ( b) ) => Ok ( Tensor :: F32 ( a $op b) ) ,
515- ( Tensor :: I64 ( a) , Tensor :: I64 ( b) ) => Ok ( Tensor :: I64 ( a $op b) ) ,
516- ( Tensor :: I32 ( a) , Tensor :: I32 ( b) ) => Ok ( Tensor :: I32 ( a $op b) ) ,
517- ( Tensor :: I16 ( a) , Tensor :: I16 ( b) ) => Ok ( Tensor :: I16 ( a $op b) ) ,
518- ( Tensor :: I8 ( a) , Tensor :: I8 ( b) ) => Ok ( Tensor :: I8 ( a $op b) ) ,
519- ( Tensor :: U64 ( a) , Tensor :: U64 ( b) ) => Ok ( Tensor :: U64 ( a $op b) ) ,
520- ( Tensor :: U32 ( a) , Tensor :: U32 ( b) ) => Ok ( Tensor :: U32 ( a $op b) ) ,
521- ( Tensor :: U16 ( a) , Tensor :: U16 ( b) ) => Ok ( Tensor :: U16 ( a $op b) ) ,
522- ( Tensor :: U8 ( a) , Tensor :: U8 ( b) ) => Ok ( Tensor :: U8 ( a $op b) ) ,
537+ ( Tensor :: C128 ( a) , Tensor :: C128 ( b) ) => Ok ( Tensor :: C128 ( ( a $op b) . into_shared ( ) ) ) ,
538+ ( Tensor :: C64 ( a) , Tensor :: C64 ( b) ) => Ok ( Tensor :: C64 ( ( a $op b) . into_shared ( ) ) ) ,
539+ ( Tensor :: F64 ( a) , Tensor :: F64 ( b) ) => Ok ( Tensor :: F64 ( ( a $op b) . into_shared ( ) ) ) ,
540+ ( Tensor :: F32 ( a) , Tensor :: F32 ( b) ) => Ok ( Tensor :: F32 ( ( a $op b) . into_shared ( ) ) ) ,
541+ ( Tensor :: I64 ( a) , Tensor :: I64 ( b) ) => Ok ( Tensor :: I64 ( ( a $op b) . into_shared ( ) ) ) ,
542+ ( Tensor :: I32 ( a) , Tensor :: I32 ( b) ) => Ok ( Tensor :: I32 ( ( a $op b) . into_shared ( ) ) ) ,
543+ ( Tensor :: I16 ( a) , Tensor :: I16 ( b) ) => Ok ( Tensor :: I16 ( ( a $op b) . into_shared ( ) ) ) ,
544+ ( Tensor :: I8 ( a) , Tensor :: I8 ( b) ) => Ok ( Tensor :: I8 ( ( a $op b) . into_shared ( ) ) ) ,
545+ ( Tensor :: U64 ( a) , Tensor :: U64 ( b) ) => Ok ( Tensor :: U64 ( ( a $op b) . into_shared ( ) ) ) ,
546+ ( Tensor :: U32 ( a) , Tensor :: U32 ( b) ) => Ok ( Tensor :: U32 ( ( a $op b) . into_shared ( ) ) ) ,
547+ ( Tensor :: U16 ( a) , Tensor :: U16 ( b) ) => Ok ( Tensor :: U16 ( ( a $op b) . into_shared ( ) ) ) ,
548+ ( Tensor :: U8 ( a) , Tensor :: U8 ( b) ) => Ok ( Tensor :: U8 ( ( a $op b) . into_shared ( ) ) ) ,
523549 _ => Err ( TensorError :: DTypeMismatch {
524550 op: $op_name,
525551 lhs: self . dtype( ) ,
@@ -557,16 +583,16 @@ impl Tensor {
557583 pub fn rem_tensor ( & self , rhs : & Tensor ) -> Result < Tensor , TensorError > {
558584 broadcast_shape ( self . shape ( ) , rhs. shape ( ) ) ?;
559585 match ( self , rhs) {
560- ( Tensor :: F64 ( a) , Tensor :: F64 ( b) ) => Ok ( Tensor :: F64 ( a % b) ) ,
561- ( Tensor :: F32 ( a) , Tensor :: F32 ( b) ) => Ok ( Tensor :: F32 ( a % b) ) ,
562- ( Tensor :: I64 ( a) , Tensor :: I64 ( b) ) => Ok ( Tensor :: I64 ( a % b) ) ,
563- ( Tensor :: I32 ( a) , Tensor :: I32 ( b) ) => Ok ( Tensor :: I32 ( a % b) ) ,
564- ( Tensor :: I16 ( a) , Tensor :: I16 ( b) ) => Ok ( Tensor :: I16 ( a % b) ) ,
565- ( Tensor :: I8 ( a) , Tensor :: I8 ( b) ) => Ok ( Tensor :: I8 ( a % b) ) ,
566- ( Tensor :: U64 ( a) , Tensor :: U64 ( b) ) => Ok ( Tensor :: U64 ( a % b) ) ,
567- ( Tensor :: U32 ( a) , Tensor :: U32 ( b) ) => Ok ( Tensor :: U32 ( a % b) ) ,
568- ( Tensor :: U16 ( a) , Tensor :: U16 ( b) ) => Ok ( Tensor :: U16 ( a % b) ) ,
569- ( Tensor :: U8 ( a) , Tensor :: U8 ( b) ) => Ok ( Tensor :: U8 ( a % b) ) ,
586+ ( Tensor :: F64 ( a) , Tensor :: F64 ( b) ) => Ok ( Tensor :: F64 ( ( a % b) . into_shared ( ) ) ) ,
587+ ( Tensor :: F32 ( a) , Tensor :: F32 ( b) ) => Ok ( Tensor :: F32 ( ( a % b) . into_shared ( ) ) ) ,
588+ ( Tensor :: I64 ( a) , Tensor :: I64 ( b) ) => Ok ( Tensor :: I64 ( ( a % b) . into_shared ( ) ) ) ,
589+ ( Tensor :: I32 ( a) , Tensor :: I32 ( b) ) => Ok ( Tensor :: I32 ( ( a % b) . into_shared ( ) ) ) ,
590+ ( Tensor :: I16 ( a) , Tensor :: I16 ( b) ) => Ok ( Tensor :: I16 ( ( a % b) . into_shared ( ) ) ) ,
591+ ( Tensor :: I8 ( a) , Tensor :: I8 ( b) ) => Ok ( Tensor :: I8 ( ( a % b) . into_shared ( ) ) ) ,
592+ ( Tensor :: U64 ( a) , Tensor :: U64 ( b) ) => Ok ( Tensor :: U64 ( ( a % b) . into_shared ( ) ) ) ,
593+ ( Tensor :: U32 ( a) , Tensor :: U32 ( b) ) => Ok ( Tensor :: U32 ( ( a % b) . into_shared ( ) ) ) ,
594+ ( Tensor :: U16 ( a) , Tensor :: U16 ( b) ) => Ok ( Tensor :: U16 ( ( a % b) . into_shared ( ) ) ) ,
595+ ( Tensor :: U8 ( a) , Tensor :: U8 ( b) ) => Ok ( Tensor :: U8 ( ( a % b) . into_shared ( ) ) ) ,
570596 _ => Err ( TensorError :: DTypeMismatch {
571597 op : "rem" ,
572598 lhs : self . dtype ( ) ,
@@ -770,6 +796,22 @@ mod test {
770796 assert_eq ! ( t. shape( ) , & [ 4 ] ) ;
771797 }
772798
799+ #[ test]
800+ fn test_clone_shares_buffer ( ) {
801+ // ArcArray storage means Tensor::clone() is a refcount bump, not a deep
802+ // copy. Verify by comparing the underlying buffer pointer between the
803+ // original and a clone.
804+ let t = Tensor :: from ( [ 1.0_f64 , 2.0 , 3.0 ] ) ;
805+ let cloned = t. clone ( ) ;
806+ let Tensor :: F64 ( orig) = & t else {
807+ panic ! ( "expected F64 tensor" )
808+ } ;
809+ let Tensor :: F64 ( copy) = & cloned else {
810+ panic ! ( "expected F64 tensor" )
811+ } ;
812+ assert_eq ! ( orig. as_ptr( ) , copy. as_ptr( ) ) ;
813+ }
814+
773815 #[ test]
774816 fn test_from_arrayd ( ) {
775817 let arr = ndarray:: Array :: from_shape_vec ( IxDyn ( & [ 2 , 3 ] ) , vec ! [ 1.0f64 ; 6 ] ) . unwrap ( ) ;
@@ -1390,17 +1432,17 @@ mod test {
13901432 DType :: C128 ,
13911433 ] ;
13921434 let sources = [
1393- Tensor :: Bit ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u8 ) ) ,
1394- Tensor :: U8 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u8 ) ) ,
1395- Tensor :: U16 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u16 ) ) ,
1396- Tensor :: U32 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u32 ) ) ,
1397- Tensor :: U64 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u64 ) ) ,
1398- Tensor :: I8 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i8 ) ) ,
1399- Tensor :: I16 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i16 ) ) ,
1400- Tensor :: I32 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i32 ) ) ,
1401- Tensor :: I64 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i64 ) ) ,
1402- Tensor :: F32 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1.0f32 ) ) ,
1403- Tensor :: F64 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1.0f64 ) ) ,
1435+ Tensor :: Bit ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u8 ) . into_shared ( ) ) ,
1436+ Tensor :: U8 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u8 ) . into_shared ( ) ) ,
1437+ Tensor :: U16 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u16 ) . into_shared ( ) ) ,
1438+ Tensor :: U32 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u32 ) . into_shared ( ) ) ,
1439+ Tensor :: U64 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u64 ) . into_shared ( ) ) ,
1440+ Tensor :: I8 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i8 ) . into_shared ( ) ) ,
1441+ Tensor :: I16 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i16 ) . into_shared ( ) ) ,
1442+ Tensor :: I32 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i32 ) . into_shared ( ) ) ,
1443+ Tensor :: I64 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1i64 ) . into_shared ( ) ) ,
1444+ Tensor :: F32 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1.0f32 ) . into_shared ( ) ) ,
1445+ Tensor :: F64 ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1.0f64 ) . into_shared ( ) ) ,
14041446 ] ;
14051447 for src in sources {
14061448 let src_dtype = src. dtype ( ) ;
@@ -1425,7 +1467,8 @@ mod test {
14251467 }
14261468
14271469 // Spot-check a numeric value (Bit(1) -> F64 -> 1.0).
1428- let bit_to_f64 = Tensor :: Bit ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u8 ) ) . cast ( DType :: F64 ) ;
1470+ let bit_to_f64 = Tensor :: Bit ( ndarray:: ArrayD :: from_elem ( IxDyn ( & [ 2 ] ) , 1u8 ) . into_shared ( ) )
1471+ . cast ( DType :: F64 ) ;
14291472 if let Tensor :: F64 ( arr) = bit_to_f64 {
14301473 assert_eq ! ( arr. as_slice( ) . unwrap( ) , & [ 1.0_f64 , 1.0 ] ) ;
14311474 } else {
0 commit comments