diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 0312449be42b..a1d36395d417 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -11,6 +11,7 @@ // that they have been altered from the originals. mod data_tree; +pub mod math_nodes; mod program_node; mod store; pub mod tensor; diff --git a/crates/providers/src/math_nodes/binary.rs b/crates/providers/src/math_nodes/binary.rs new file mode 100644 index 000000000000..b2a347e0eb24 --- /dev/null +++ b/crates/providers/src/math_nodes/binary.rs @@ -0,0 +1,247 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DTypeLike, Tensor, TensorType, promotion}; +use std::sync::LazyLock; + +/// Shared input type spec for all elementwise binary nodes: two broadcastable tensors `x` and `y`. +static INPUT_TYPES: LazyLock> = LazyLock::new(|| { + let mut types = DataTree::with_capacity(2); + types.insert_leaf( + "x", + TensorType { + dtype: DTypeLike::Var("x".into()), + shape: vec![], + broadcastable: true, + }, + ); + types.insert_leaf( + "y", + TensorType { + dtype: DTypeLike::Var("y".into()), + shape: vec![], + broadcastable: true, + }, + ); + types +}); + +/// Shared output type spec for all elementwise binary nodes: a single tensor of the promoted dtype. +static OUTPUT_TYPES: LazyLock> = LazyLock::new(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Promotion( + vec![DTypeLike::Var("x".into()), DTypeLike::Var("y".into())].into(), + ), + shape: vec![], + broadcastable: true, + }) +}); + +/// Generate a [`ProgramNode`] struct for an elementwise binary operation. +macro_rules! elementwise_binary_node { + ($name:ident, $node_name:literal, $call_fn:expr) => { + #[doc = concat!("Elementwise `", $node_name, "` of two broadcastable tensors.")] + pub struct $name; + + impl ProgramNode for $name { + type CallError = super::MathNodeError; + + fn name(&self) -> &'static str { + $node_name + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + &INPUT_TYPES + } + fn output_types(&self) -> &DataTree { + &OUTPUT_TYPES + } + fn implements_call(&self) -> bool { + true + } + fn call_flat(&self, args: &[Tensor]) -> Result, Self::CallError> { + let [x, y] = args else { + unreachable!("input arity is fixed by input_types"); + }; + let out_dtype = promotion(x.dtype(), y.dtype()); + Ok(vec![$call_fn( + &x.cast_ref(out_dtype), + &y.cast_ref(out_dtype), + )?]) + } + } + }; +} + +elementwise_binary_node!(Add, "add", Tensor::add_tensor); +elementwise_binary_node!(Subtract, "subtract", Tensor::sub_tensor); +elementwise_binary_node!(Multiply, "multiply", Tensor::mul_tensor); +elementwise_binary_node!(Divide, "divide", Tensor::div_tensor); +elementwise_binary_node!(Remainder, "remainder", Tensor::rem_tensor); +elementwise_binary_node!(Power, "power", Tensor::pow); + +#[cfg(test)] +mod tests { + use super::*; + use crate::math_nodes::MathNodeError; + use crate::program_node::{CallError, CallInputError, ProgramNodeExt}; + use crate::tensor::{DType, Tensor}; + + #[test] + fn test_add_same_dtype() { + let result = Add + .call_flat(&[ + Tensor::from([1.0_f64, 2.0, 3.0]), + Tensor::from([4.0_f64, 5.0, 6.0]), + ]) + .unwrap(); + assert_eq!(result.len(), 1); + let Tensor::F64(arr) = &result[0] else { + panic!("expected f64") + }; + assert_eq!(arr.as_slice().unwrap(), &[5.0, 7.0, 9.0]); + } + + #[test] + fn test_add_promotes_dtype() { + let result = Add + .call_flat(&[Tensor::from([1.0_f32, 2.0]), Tensor::from([3.0_f64, 4.0])]) + .unwrap(); + assert_eq!(result[0].dtype(), DType::F64); + let Tensor::F64(arr) = &result[0] else { + panic!("expected f64") + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 6.0]); + } + + #[test] + fn test_add_broadcasts_2d_with_1d() { + use ndarray::arr2; + let x = Tensor::F64( + arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]) + .into_dyn() + .into_shared(), + ); + let y = Tensor::from([10.0_f64, 20.0, 30.0]); + let result = Add.call_flat(&[x, y]).unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!("expected f64") + }; + let expected = arr2(&[[11.0_f64, 22.0, 33.0], [14.0, 25.0, 36.0]]) + .into_dyn() + .into_shared(); + assert_eq!(arr, &expected); + } + + #[test] + fn test_subtract() { + let result = Subtract + .call_flat(&[ + Tensor::from([5.0_f64, 6.0, 7.0]), + Tensor::from([1.0_f64, 2.0, 3.0]), + ]) + .unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 4.0, 4.0]); + } + + #[test] + fn test_multiply() { + let result = Multiply + .call_flat(&[ + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([10.0_f64, 10.0, 10.0]), + ]) + .unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[20.0, 30.0, 40.0]); + } + + #[test] + fn test_divide() { + let result = Divide + .call_flat(&[ + Tensor::from([10.0_f64, 9.0, 8.0]), + Tensor::from([2.0_f64, 3.0, 4.0]), + ]) + .unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[5.0, 3.0, 2.0]); + } + + #[test] + fn test_remainder() { + let result = Remainder + .call_flat(&[ + Tensor::from([7.0_f64, 8.0, 9.0]), + Tensor::from([3.0_f64, 3.0, 3.0]), + ]) + .unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 0.0]); + } + + #[test] + fn test_power() { + let result = Power + .call_flat(&[ + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([3.0_f64, 2.0, 1.0]), + ]) + .unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + for (a, b) in arr.as_slice().unwrap().iter().zip(&[8.0_f64, 9.0, 4.0]) { + assert!(approx::abs_diff_eq!(a, b, epsilon = 1e-12)); + } + } + + #[test] + fn test_call_missing_input_errors() { + let mut tree = DataTree::new(); + tree.insert_leaf("x", Tensor::from([1.0_f64])); + let err = Add.call(&tree).unwrap_err(); + assert!(matches!( + err, + CallError::::Input(CallInputError::MissingInput { + ref key, + }) if key == "y" + )); + } + + #[test] + fn test_call_branch_where_leaf_expected_errors() { + let mut tree = DataTree::new(); + tree.insert_leaf("x", Tensor::from([1.0_f64])); + tree.insert_branch("y", DataTree::new()); + let err = Add.call(&tree).unwrap_err(); + assert!(matches!( + err, + CallError::::Input(CallInputError::ExpectedLeaf { + ref key, + }) if key == "y" + )); + } +} diff --git a/crates/providers/src/math_nodes/bitwise.rs b/crates/providers/src/math_nodes/bitwise.rs new file mode 100644 index 000000000000..1a2e1f93c9a8 --- /dev/null +++ b/crates/providers/src/math_nodes/bitwise.rs @@ -0,0 +1,291 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::{CallInputError, ProgramNode}; +use crate::tensor::{DType, DTypeLike, Tensor, TensorType}; +use ndarray::Axis; +use std::sync::LazyLock; + +/// Shared input type spec for binary bitwise nodes +static INPUT_TYPES: LazyLock> = LazyLock::new(|| { + let mut types = DataTree::with_capacity(2); + types.insert_leaf( + "x", + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }, + ); + types.insert_leaf( + "y", + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }, + ); + types +}); + +/// A single broadcastable `Bit` leaf — used for unary inputs and all bitwise outputs. +static LEAF_TYPE: LazyLock> = LazyLock::new(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }) +}); + +/// Construct an `UnexpectedDType` error for a slice element that did not match +/// the schema's required dtype. +fn unexpected_dtype(key: &str, actual: &Tensor) -> CallInputError { + CallInputError::UnexpectedDType { + key: key.into(), + expected: DType::Bit.to_string(), + actual: actual.dtype(), + } +} + +/// Generate a [`ProgramNode`] struct for an elementwise binary bitwise operation on `Bit` tensors. +macro_rules! bitwise_binary_node { + ($name:ident, $node_name:literal, $call_fn:expr) => { + #[doc = concat!("Elementwise `", $node_name, "` of two broadcastable `Bit` tensors.")] + pub struct $name; + + impl ProgramNode for $name { + type CallError = super::MathNodeError; + + fn name(&self) -> &'static str { + $node_name + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + &INPUT_TYPES + } + fn output_types(&self) -> &DataTree { + &LEAF_TYPE + } + fn implements_call(&self) -> bool { + true + } + fn call_flat(&self, args: &[Tensor]) -> Result, Self::CallError> { + let [x, y] = args else { + unreachable!("input arity is fixed by input_types"); + }; + let Tensor::Bit(x_arr) = x else { + return Err(unexpected_dtype("x", x).into()); + }; + let Tensor::Bit(y_arr) = y else { + return Err(unexpected_dtype("y", y).into()); + }; + // TODO: I think this call will panic on bad broadcast? we need an error + Ok(vec![Tensor::Bit($call_fn(x_arr, y_arr).into_shared())]) + } + } + }; +} + +bitwise_binary_node!(BitwiseAnd, "bitwise_and", |x, y| x & y); +bitwise_binary_node!(BitwiseOr, "bitwise_or", |x, y| x | y); +bitwise_binary_node!(BitwiseXor, "bitwise_xor", |x, y| x ^ y); + +/// Elementwise bitwise NOT of a broadcastable `Bit` tensor. +pub struct BitwiseNot; + +impl ProgramNode for BitwiseNot { + type CallError = super::MathNodeError; + + fn name(&self) -> &'static str { + "bitwise_not" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + &LEAF_TYPE + } + fn output_types(&self) -> &DataTree { + &LEAF_TYPE + } + fn implements_call(&self) -> bool { + true + } + fn call_flat(&self, args: &[Tensor]) -> Result, Self::CallError> { + let [x] = args else { + unreachable!("input arity is fixed by input_types"); + }; + let Tensor::Bit(arr) = x else { + return Err(unexpected_dtype("", x).into()); + }; + // TODO: I think this call will panic on bad broadcast? we need an error + Ok(vec![Tensor::Bit(arr.mapv(|b| b ^ 1).into_shared())]) + } +} + +/// XOR-reduction of a `Bit` tensor along a specified axis, removing that axis. +/// +/// The parity of a sequence of bits is 1 if an odd number of bits are 1, and 0 otherwise, +/// which is equivalent to XOR-folding the sequence. The output has one fewer dimension than +/// the input, with the reduction axis removed. +pub struct Parity { + axis: usize, +} + +impl Parity { + /// Construct a `Parity` node that reduces along `axis`. + pub fn new(axis: usize) -> Self { + Self { axis } + } +} + +impl ProgramNode for Parity { + type CallError = super::MathNodeError; + + fn name(&self) -> &'static str { + "parity" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + &LEAF_TYPE + } + fn output_types(&self) -> &DataTree { + &LEAF_TYPE + } + fn implements_call(&self) -> bool { + true + } + fn call_flat(&self, args: &[Tensor]) -> Result, Self::CallError> { + let [x] = args else { + unreachable!("input arity is fixed by input_types"); + }; + let Tensor::Bit(arr) = x else { + return Err(unexpected_dtype("", x).into()); + }; + Ok(vec![Tensor::Bit( + arr.fold_axis(Axis(self.axis), 0u8, |&acc, &b| acc ^ b) + .into_shared(), + )]) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::math_nodes::MathNodeError; + use crate::program_node::{CallError, CallInputError, ProgramNodeExt}; + use ndarray::{arr1, arr2}; + + fn bit(data: &[u8]) -> Tensor { + Tensor::Bit(arr1(data).into_dyn().into_shared()) + } + + #[test] + fn test_bitwise_and() { + let result = BitwiseAnd + .call_flat(&[bit(&[1, 0, 1, 1]), bit(&[1, 1, 0, 1])]) + .unwrap(); + let Tensor::Bit(arr) = &result[0] else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 0, 0, 1]); + } + + #[test] + fn test_bitwise_or() { + let result = BitwiseOr + .call_flat(&[bit(&[1, 0, 1, 0]), bit(&[0, 1, 0, 1])]) + .unwrap(); + let Tensor::Bit(arr) = &result[0] else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 1, 1, 1]); + } + + #[test] + fn test_bitwise_xor() { + let result = BitwiseXor + .call_flat(&[bit(&[1, 0, 1, 1]), bit(&[1, 1, 0, 1])]) + .unwrap(); + let Tensor::Bit(arr) = &result[0] else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 1, 1, 0]); + } + + #[test] + fn test_bitwise_and_broadcasts() { + // shape [3] & shape [1] -> shape [3] + let result = BitwiseAnd.call_flat(&[bit(&[1, 0, 1]), bit(&[1])]).unwrap(); + let Tensor::Bit(arr) = &result[0] else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 0, 1]); + } + + #[test] + fn test_bitwise_not() { + let result = BitwiseNot.call_flat(&[bit(&[1, 0, 1, 0])]).unwrap(); + let Tensor::Bit(arr) = &result[0] else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 1, 0, 1]); + } + + #[test] + fn test_parity_axis0() { + // [[1,0,1],[0,1,1],[0,0,0]] axis 0 → [1, 1, 0] + let x = Tensor::Bit( + arr2(&[[1u8, 0, 1], [0, 1, 1], [0, 0, 0]]) + .into_dyn() + .into_shared(), + ); + let result = Parity::new(0).call_flat(&[x]).unwrap(); + let Tensor::Bit(arr) = &result[0] else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 1, 0]); + } + + #[test] + fn test_bitwise_and_wrong_dtype_errors() { + let err = BitwiseAnd + .call_flat(&[Tensor::from([1.0_f64]), bit(&[1])]) + .unwrap_err(); + assert_eq!( + err, + MathNodeError::Input(CallInputError::UnexpectedDType { + key: "x".to_string(), + expected: "Bit".to_string(), + actual: DType::F64, + }) + ); + } + + #[test] + fn test_call_branch_where_leaf_expected_errors() { + let mut tree = DataTree::new(); + tree.insert_leaf("x", bit(&[1, 0])); + let err = BitwiseNot.call(&tree).unwrap_err(); + assert!(matches!( + err, + CallError::::Input(CallInputError::ExpectedLeaf { + ref key, + }) if key.is_empty() + )); + } +} diff --git a/crates/providers/src/math_nodes/mod.rs b/crates/providers/src/math_nodes/mod.rs new file mode 100644 index 000000000000..e721a5703e7d --- /dev/null +++ b/crates/providers/src/math_nodes/mod.rs @@ -0,0 +1,30 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +pub mod binary; +pub mod bitwise; +pub mod reduction; + +use crate::program_node::CallInputError; +use crate::tensor::TensorError; +use thiserror::Error; + +/// Errors returned by [`crate::program_node::ProgramNode`] implementations in this module. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum MathNodeError { + /// The input tree did not match the contract declared by `input_types`. + #[error(transparent)] + Input(#[from] CallInputError), + /// A tensor operation failed (dtype or shape mismatch). + #[error(transparent)] + Tensor(#[from] TensorError), +} diff --git a/crates/providers/src/math_nodes/reduction.rs b/crates/providers/src/math_nodes/reduction.rs new file mode 100644 index 000000000000..67cad596f3dd --- /dev/null +++ b/crates/providers/src/math_nodes/reduction.rs @@ -0,0 +1,406 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DType, DTypeLike, Tensor, TensorType}; +use ndarray::Axis; +use num_complex::Complex; +use std::sync::LazyLock; + +/// Shared input type spec for reduction nodes: a single broadcastable tensor of any dtype. +static INPUT_TYPES: LazyLock> = LazyLock::new(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Var("x".into()), + shape: vec![], + broadcastable: true, + }) +}); + +/// Shared output type spec for reduction nodes: a single broadcastable tensor of any dtype. +static OUTPUT_TYPES: LazyLock> = LazyLock::new(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Var("out".into()), + shape: vec![], + broadcastable: true, + }) +}); + +/// Mean of a tensor along a specified axis, removing that axis. +/// +/// Integer inputs are cast to `F64` before computing the mean. `F32` inputs +/// produce `F32` output; all other float and integer types produce `F64`. +/// Complex inputs (`C64`, `C128`) preserve their complex dtype. +pub struct Mean { + axis: usize, +} + +impl Mean { + /// Construct a `Mean` node that reduces along `axis`. + pub fn new(axis: usize) -> Self { + Self { axis } + } +} + +impl ProgramNode for Mean { + type CallError = super::MathNodeError; + + fn name(&self) -> &'static str { + "mean" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + &INPUT_TYPES + } + fn output_types(&self) -> &DataTree { + &OUTPUT_TYPES + } + fn implements_call(&self) -> bool { + true + } + fn call_flat(&self, args: &[Tensor]) -> Result, Self::CallError> { + let [x] = args else { + unreachable!("input arity is fixed by input_types"); + }; + let result = match x { + Tensor::F32(a) => Tensor::F32(a.mean_axis(Axis(self.axis)).unwrap().into_shared()), + Tensor::F64(a) => Tensor::F64(a.mean_axis(Axis(self.axis)).unwrap().into_shared()), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + Tensor::C64((a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)).into_shared()) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + Tensor::C128((a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)).into_shared()) + } + other => { + let Tensor::F64(a) = other.cast_ref(DType::F64).into_owned() else { + unreachable!() + }; + Tensor::F64(a.mean_axis(Axis(self.axis)).unwrap().into_shared()) + } + }; + Ok(vec![result]) + } +} + +/// Variance of a tensor along a specified axis, removing that axis. +/// +/// The `ddof` (delta degrees of freedom) parameter adjusts the divisor: the result +/// is divided by `n - ddof` where `n` is the number of elements along the axis. +/// Use `ddof=0` for population variance and `ddof=1` for sample variance. +/// +/// Integer inputs are cast to `F64`. `F32` produces `F32`; all other real types +/// produce `F64`. Complex inputs (`C64`, `C128`) produce real output (`F32`, `F64` +/// respectively), computed as the mean squared modulus of the deviations. +pub struct Variance { + axis: usize, + ddof: f64, +} + +impl Variance { + /// Construct a `Variance` node that reduces along `axis` with degrees-of-freedom + /// correction `ddof`. + pub fn new(axis: usize, ddof: f64) -> Self { + Self { axis, ddof } + } +} + +impl ProgramNode for Variance { + type CallError = super::MathNodeError; + + fn name(&self) -> &'static str { + "variance" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + &INPUT_TYPES + } + fn output_types(&self) -> &DataTree { + &OUTPUT_TYPES + } + fn implements_call(&self) -> bool { + true + } + fn call_flat(&self, args: &[Tensor]) -> Result, Self::CallError> { + let [x] = args else { + unreachable!("input arity is fixed by input_types"); + }; + let result = match x { + Tensor::F32(a) => { + Tensor::F32(a.var_axis(Axis(self.axis), self.ddof as f32).into_shared()) + } + Tensor::F64(a) => Tensor::F64(a.var_axis(Axis(self.axis), self.ddof).into_shared()), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F32( + (sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof as f32)).into_shared(), + ) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F64((sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof)).into_shared()) + } + other => { + let Tensor::F64(a) = other.cast_ref(DType::F64).into_owned() else { + unreachable!() + }; + Tensor::F64(a.var_axis(Axis(self.axis), self.ddof).into_shared()) + } + }; + Ok(vec![result]) + } +} + +/// Standard deviation of a tensor along a specified axis, removing that axis. +/// +/// This is the square root of [`Variance`]. See that type for details on `ddof`, +/// output dtypes, and complex handling. +pub struct Std { + axis: usize, + ddof: f64, +} + +impl Std { + /// Construct a `Std` node that reduces along `axis` with degrees-of-freedom + /// correction `ddof`. + pub fn new(axis: usize, ddof: f64) -> Self { + Self { axis, ddof } + } +} + +impl ProgramNode for Std { + type CallError = super::MathNodeError; + + fn name(&self) -> &'static str { + "std" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + &INPUT_TYPES + } + fn output_types(&self) -> &DataTree { + &OUTPUT_TYPES + } + fn implements_call(&self) -> bool { + true + } + fn call_flat(&self, args: &[Tensor]) -> Result, Self::CallError> { + let [x] = args else { + unreachable!("input arity is fixed by input_types"); + }; + let result = match x { + Tensor::F32(a) => { + Tensor::F32(a.std_axis(Axis(self.axis), self.ddof as f32).into_shared()) + } + Tensor::F64(a) => Tensor::F64(a.std_axis(Axis(self.axis), self.ddof).into_shared()), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F32( + (sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof as f32)) + .mapv(f32::sqrt) + .into_shared(), + ) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F64( + (sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof)) + .mapv(f64::sqrt) + .into_shared(), + ) + } + other => { + let Tensor::F64(a) = other.cast_ref(DType::F64).into_owned() else { + unreachable!() + }; + Tensor::F64(a.std_axis(Axis(self.axis), self.ddof).into_shared()) + } + }; + Ok(vec![result]) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::math_nodes::MathNodeError; + use crate::program_node::{CallError, CallInputError, ProgramNodeExt}; + use crate::tensor::{DType, Tensor}; + use ndarray::arr2; + + fn approx_eq_slice(a: &[f64], b: &[f64]) { + assert_eq!(a.len(), b.len(), "slice lengths differ"); + for (x, y) in a.iter().zip(b.iter()) { + assert!((x - y).abs() < 1e-10, "{x} != {y}"); + } + } + + // --- Mean tests --- + + #[test] + fn test_mean_f64_axis0() { + // [[1,2,3],[4,5,6]] along axis 0 → [2.5, 3.5, 4.5] + let x = Tensor::F64( + arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]) + .into_dyn() + .into_shared(), + ); + let result = Mean::new(0).call_flat(&[x]).unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.5, 3.5, 4.5]); + } + + #[test] + fn test_mean_i32_casts_to_f64() { + let x = Tensor::from([1_i32, 2, 3, 4]); + let result = Mean::new(0).call_flat(&[x]).unwrap(); + assert_eq!( + result[0].dtype(), + DType::F64, + "integer input should produce F64 mean" + ); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.5]); + } + + #[test] + fn test_mean_c128() { + use num_complex::Complex; + let data: Vec> = vec![ + Complex::new(1.0, 2.0), + Complex::new(3.0, 4.0), + Complex::new(5.0, 6.0), + ]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn().into_shared()); + let result = Mean::new(0).call_flat(&[x]).unwrap(); + let Tensor::C128(arr) = &result[0] else { + panic!("expected C128 leaf"); + }; + let v = arr.as_slice().unwrap()[0]; + assert!((v.re - 3.0).abs() < 1e-10); + assert!((v.im - 4.0).abs() < 1e-10); + } + + // --- Variance tests --- + + #[test] + fn test_variance_f64_ddof0() { + // [2, 4, 4, 4, 5, 5, 7, 9] — classic example, population variance = 4.0 + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 0.0).call_flat(&[x]).unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[4.0]); + } + + #[test] + fn test_variance_f64_ddof1() { + // Sample variance (ddof=1) of the same sequence + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 1.0).call_flat(&[x]).unwrap(); + let Tensor::F64(arr) = &result[0] else { + panic!("expected F64 leaf"); + }; + // sample variance = population variance * n / (n-1) = 4.0 * 8/7 + approx_eq_slice(arr.as_slice().unwrap(), &[4.0 * 8.0 / 7.0]); + } + + #[test] + fn test_variance_c128_returns_real() { + use num_complex::Complex; + // [1+1i, 3+3i] — mean = 2+2i, deviations = [−1−i, 1+i], |.|^2 = [2, 2], var = 2.0 + let data: Vec> = vec![Complex::new(1.0, 1.0), Complex::new(3.0, 3.0)]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn().into_shared()); + let result = Variance::new(0, 0.0).call_flat(&[x]).unwrap(); + assert_eq!( + result[0].dtype(), + DType::F64, + "C128 variance should return F64" + ); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.0]); + } + + // --- Std tests --- + + #[test] + fn test_std_matches_sqrt_of_variance() { + // Verify std = sqrt(variance) numerically + let x = Tensor::from([1.0_f64, 3.0, 5.0, 7.0, 9.0]); + let var_result = Variance::new(0, 0.0).call_flat(&[x.clone()]).unwrap(); + let std_result = Std::new(0, 0.0).call_flat(&[x]).unwrap(); + + let Tensor::F64(var_arr) = &var_result[0] else { + panic!() + }; + let Tensor::F64(std_arr) = &std_result[0] else { + panic!() + }; + + let var_val = var_arr.as_slice().unwrap()[0]; + let std_val = std_arr.as_slice().unwrap()[0]; + assert!((std_val - var_val.sqrt()).abs() < 1e-10); + } + + #[test] + fn test_std_c128_returns_real() { + use num_complex::Complex; + let data: Vec> = vec![Complex::new(1.0, 1.0), Complex::new(3.0, 3.0)]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn().into_shared()); + let result = Std::new(0, 0.0).call_flat(&[x]).unwrap(); + assert_eq!(result[0].dtype(), DType::F64, "C128 std should return F64"); + let Tensor::F64(arr) = &result[0] else { + panic!() + }; + // std = sqrt(2.0) + approx_eq_slice(arr.as_slice().unwrap(), &[2.0_f64.sqrt()]); + } + + #[test] + fn test_call_branch_where_leaf_expected_errors() { + let mut tree = DataTree::new(); + tree.insert_leaf("x", Tensor::from([1.0_f64, 2.0])); + let err = Mean::new(0).call(&tree).unwrap_err(); + assert!(matches!( + err, + CallError::::Input(CallInputError::ExpectedLeaf { + ref key, + }) if key.is_empty() + )); + } +} diff --git a/crates/providers/src/store.rs b/crates/providers/src/store.rs index b115efc12f29..57ea160a4a6a 100644 --- a/crates/providers/src/store.rs +++ b/crates/providers/src/store.rs @@ -89,7 +89,9 @@ mod tests { #[test] fn test_store_output_types_2d() { use ndarray::arr2; - let data = DataTree::new_leaf(Tensor::F64(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn())); + let data = DataTree::new_leaf(Tensor::F64( + arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn().into_shared(), + )); let store = Store::new(data); let DataTree::Leaf(tt) = store.output_types() else { panic!("expected leaf output type"); diff --git a/crates/providers/src/tensor.rs b/crates/providers/src/tensor.rs index e631bac1607b..5480dfd34cf4 100644 --- a/crates/providers/src/tensor.rs +++ b/crates/providers/src/tensor.rs @@ -10,11 +10,15 @@ // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. -use ndarray::{ArrayD, IxDyn, Zip}; +use ndarray::{ArcArray, ArrayD, IxDyn, Zip}; use num_complex::{Complex32, Complex64}; +use std::borrow::Cow; use std::fmt; use thiserror::Error; +/// Dynamic-dimensional [`ArcArray`]; the storage type for every [`Tensor`] variant. +type ArcArrayD = ArcArray; + /// Errors returned by [`Tensor`] operations. #[derive(Debug, Clone, PartialEq, Eq, Error)] pub enum TensorError { @@ -234,50 +238,68 @@ impl TensorType { } /// A tensor of one of the supported dtypes. +/// +/// Each variant wraps a reference-counted dynamic ndarray ([`ArcArray`]). +/// +/// This allows [`Tensor::clone`] to cause a refcount bump rather than a copy of +/// underlying data. Note that mutating the underlying buffer in place (via ndarray +/// methods that require `DataMut`) clones-on-write when the buffer is shared. #[derive(Debug, Clone)] pub enum Tensor { - C64(ArrayD), // complex - C128(ArrayD), - F32(ArrayD), // real - F64(ArrayD), - I8(ArrayD), // signed integer - I16(ArrayD), - I32(ArrayD), - I64(ArrayD), - U8(ArrayD), // unsigned integer - U16(ArrayD), - U32(ArrayD), - U64(ArrayD), - Bit(ArrayD), // bool + C64(ArcArrayD), // complex + C128(ArcArrayD), + F32(ArcArrayD), // real + F64(ArcArrayD), + I8(ArcArrayD), // signed integer + I16(ArcArrayD), + I32(ArcArrayD), + I64(ArcArrayD), + U8(ArcArrayD), // unsigned integer + U16(ArcArrayD), + U32(ArcArrayD), + U64(ArcArrayD), + Bit(ArcArrayD), // bool } -/// Cast an `ArrayD` of a real numeric type to any supported dtype. +/// Cast an array of a real numeric type to any supported dtype. macro_rules! cast_real { ($arr:expr, $src:ty, $target:expr) => { match $target { - DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8)), - DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8)), - DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16)), - DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32)), - DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64)), - DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8)), - DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16)), - DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32)), - DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64)), - DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32)), - DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64)), - DType::C64 => Tensor::C64($arr.mapv(|x: $src| Complex32::new(x as f32, 0.0))), - DType::C128 => Tensor::C128($arr.mapv(|x: $src| Complex64::new(x as f64, 0.0))), + DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8).into_shared()), + DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8).into_shared()), + DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16).into_shared()), + DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32).into_shared()), + DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64).into_shared()), + DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8).into_shared()), + DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16).into_shared()), + DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32).into_shared()), + DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64).into_shared()), + DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32).into_shared()), + DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64).into_shared()), + DType::C64 => Tensor::C64( + $arr.mapv(|x: $src| Complex32::new(x as f32, 0.0)) + .into_shared(), + ), + DType::C128 => Tensor::C128( + $arr.mapv(|x: $src| Complex64::new(x as f64, 0.0)) + .into_shared(), + ), } }; } -/// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets). +/// Cast an array of a complex type to a complex dtype (panics for real targets). macro_rules! cast_complex { ($arr:expr, $target:expr) => { match $target { - DType::C64 => Tensor::C64($arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32))), - DType::C128 => Tensor::C128($arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64))), + DType::C64 => Tensor::C64( + $arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32)) + .into_shared(), + ), + DType::C128 => Tensor::C128( + $arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64)) + .into_shared(), + ), _ => panic!("cannot cast complex tensor to a real dtype"), } }; @@ -318,10 +340,10 @@ fn broadcast_shape(a: &[usize], b: &[usize]) -> Result, TensorError> /// this helper is needed for operations without a Rust operator (e.g. `pow`). Returns /// [`TensorError::ShapeMismatch`] if the operand shapes are not broadcast-compatible. fn broadcast_elementwise( - a: &ArrayD, - b: &ArrayD, + a: &ArcArrayD, + b: &ArcArrayD, op: F, -) -> Result, TensorError> +) -> Result, TensorError> where T: Clone, F: Fn(&T, &T) -> T, @@ -330,7 +352,7 @@ where let out_ix = IxDyn(&out_shape); let a_bc = a.broadcast(out_ix.clone()).expect("broadcast failed"); let b_bc = b.broadcast(out_ix).expect("broadcast failed"); - Ok(Zip::from(a_bc).and(b_bc).map_collect(op)) + Ok(Zip::from(a_bc).and(b_bc).map_collect(op).into_shared()) } impl Tensor { @@ -453,23 +475,42 @@ impl Tensor { Tensor::C128(a) => cast_complex!(a, target), } } + + /// Cast this tensor to the `target` dtype, borrowing if it is already that dtype. + /// + /// Returns `Cow::Borrowed(self)` when no conversion is needed, otherwise + /// `Cow::Owned` of the cast result. Useful when promoting a `&Tensor` into + /// a common dtype without paying for a clone in the common no-op case. + pub fn cast_ref(&self, target: DType) -> Cow<'_, Tensor> { + if self.dtype() == target { + Cow::Borrowed(self) + } else { + Cow::Owned(self.clone().cast(target)) + } + } } -/// Implement `From<&[T]>`, `From<&[T; N]>`, and `From>` for a given `Tensor` variant. +/// Implement `From<&[T]>`, `From<&[T; N]>`, `From>`, and +/// `From>` for a given `Tensor` variant. macro_rules! impl_tensor_from { ($variant:ident, $t:ty) => { impl From<&[$t]> for Tensor { fn from(data: &[$t]) -> Self { - Tensor::$variant(ndarray::arr1(data).into_dyn()) + Tensor::$variant(ndarray::arr1(data).into_dyn().into_shared()) } } impl From<[$t; N]> for Tensor { fn from(data: [$t; N]) -> Self { - Tensor::$variant(ndarray::arr1(&data).into_dyn()) + Tensor::$variant(ndarray::arr1(&data).into_dyn().into_shared()) } } impl From> for Tensor { fn from(data: ArrayD<$t>) -> Self { + Tensor::$variant(data.into_shared()) + } + } + impl From> for Tensor { + fn from(data: ArcArrayD<$t>) -> Self { Tensor::$variant(data) } } @@ -508,18 +549,18 @@ macro_rules! impl_tensor_binop { pub fn $tensor_method(&self, rhs: &Tensor) -> Result { broadcast_shape(self.shape(), rhs.shape())?; match (self, rhs) { - (Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128(a $op b)), - (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(a $op b)), - (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a $op b)), - (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a $op b)), - (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a $op b)), - (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a $op b)), - (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a $op b)), - (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a $op b)), - (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a $op b)), - (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a $op b)), - (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a $op b)), - (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a $op b)), + (Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128((a $op b).into_shared())), + (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64((a $op b).into_shared())), + (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a $op b).into_shared())), + (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a $op b).into_shared())), + (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a $op b).into_shared())), + (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a $op b).into_shared())), + (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a $op b).into_shared())), + (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a $op b).into_shared())), + (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a $op b).into_shared())), + (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a $op b).into_shared())), + (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a $op b).into_shared())), + (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a $op b).into_shared())), _ => Err(TensorError::DTypeMismatch { op: $op_name, lhs: self.dtype(), @@ -557,16 +598,16 @@ impl Tensor { pub fn rem_tensor(&self, rhs: &Tensor) -> Result { broadcast_shape(self.shape(), rhs.shape())?; match (self, rhs) { - (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a % b)), - (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a % b)), - (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a % b)), - (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a % b)), - (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a % b)), - (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a % b)), - (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a % b)), - (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a % b)), - (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a % b)), - (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a % b)), + (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a % b).into_shared())), + (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a % b).into_shared())), + (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a % b).into_shared())), + (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a % b).into_shared())), + (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a % b).into_shared())), + (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a % b).into_shared())), + (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a % b).into_shared())), + (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a % b).into_shared())), + (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a % b).into_shared())), + (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a % b).into_shared())), _ => Err(TensorError::DTypeMismatch { op: "rem", lhs: self.dtype(), @@ -770,6 +811,22 @@ mod test { assert_eq!(t.shape(), &[4]); } + #[test] + fn test_clone_shares_buffer() { + // ArcArray storage means Tensor::clone() is a refcount bump, not a deep + // copy. Verify by comparing the underlying buffer pointer between the + // original and a clone. + let t = Tensor::from([1.0_f64, 2.0, 3.0]); + let cloned = t.clone(); + let Tensor::F64(orig) = &t else { + panic!("expected F64 tensor") + }; + let Tensor::F64(copy) = &cloned else { + panic!("expected F64 tensor") + }; + assert_eq!(orig.as_ptr(), copy.as_ptr()); + } + #[test] fn test_from_arrayd() { let arr = ndarray::Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0f64; 6]).unwrap(); @@ -1390,17 +1447,17 @@ mod test { DType::C128, ]; let sources = [ - Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)), - Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)), - Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16)), - Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32)), - Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64)), - Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8)), - Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16)), - Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32)), - Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64)), - Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32)), - Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64)), + Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()), + Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()), + Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16).into_shared()), + Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32).into_shared()), + Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64).into_shared()), + Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8).into_shared()), + Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16).into_shared()), + Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32).into_shared()), + Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64).into_shared()), + Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32).into_shared()), + Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64).into_shared()), ]; for src in sources { let src_dtype = src.dtype(); @@ -1425,7 +1482,8 @@ mod test { } // Spot-check a numeric value (Bit(1) -> F64 -> 1.0). - let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)).cast(DType::F64); + let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()) + .cast(DType::F64); if let Tensor::F64(arr) = bit_to_f64 { assert_eq!(arr.as_slice().unwrap(), &[1.0_f64, 1.0]); } else { @@ -1434,4 +1492,35 @@ mod test { assert_eq!(fails, Vec::::new(), "cast failures: {fails:?}"); } + + #[test] + fn test_cast_ref_borrows_when_dtype_matches() { + let t = Tensor::from([1.0_f64, 2.0, 3.0]); + let cow = t.cast_ref(DType::F64); + assert!( + matches!(cow, Cow::Borrowed(_)), + "expected Cow::Borrowed when dtype matches" + ); + // The borrowed tensor still points at the original data. + assert_eq!(cow.dtype(), DType::F64); + let Tensor::F64(arr) = cow.as_ref() else { + panic!("expected F64"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1.0_f64, 2.0, 3.0]); + } + + #[test] + fn test_cast_ref_owns_when_dtype_differs() { + let t = Tensor::from([1.0_f32, 2.0, 3.0]); + let cow = t.cast_ref(DType::F64); + assert!( + matches!(cow, Cow::Owned(_)), + "expected Cow::Owned when dtype differs" + ); + assert_eq!(cow.dtype(), DType::F64); + let Tensor::F64(arr) = cow.into_owned() else { + panic!("expected F64"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1.0_f64, 2.0, 3.0]); + } }