Skip to content

Commit 768edb2

Browse files
committed
Add impls of ProgramNode for various math operations
1 parent 9cec0dc commit 768edb2

7 files changed

Lines changed: 1212 additions & 2 deletions

File tree

crates/providers/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
// that they have been altered from the originals.
1212

1313
mod data_tree;
14+
pub mod math_nodes;
1415
mod program_node;
1516
mod store;
1617
pub mod tensor;
1718

1819
pub use data_tree::{DataTree, PathEntry};
19-
pub use program_node::ProgramNode;
20+
pub use program_node::{ProgramNode, ProgramNodeError, require_leaf_arg, require_typed_leaf_arg};
2021
pub use store::Store;
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
// This code is part of Qiskit.
2+
//
3+
// (C) Copyright IBM 2026
4+
//
5+
// This code is licensed under the Apache License, Version 2.0. You may
6+
// obtain a copy of this license in the LICENSE.txt file in the root directory
7+
// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
8+
//
9+
// Any modifications or derivative works of this code must retain this
10+
// copyright notice, and modified files need to carry a notice indicating
11+
// that they have been altered from the originals.
12+
13+
use crate::data_tree::DataTree;
14+
use crate::program_node::{ProgramNode, require_leaf_arg};
15+
use crate::tensor::{DTypeLike, Tensor, TensorType, promotion};
16+
use std::sync::LazyLock;
17+
18+
/// Shared input type spec for all elementwise binary nodes: two broadcastable tensors `x` and `y`.
19+
static INPUT_TYPES: LazyLock<DataTree<TensorType>> = LazyLock::new(|| {
20+
let mut types = DataTree::with_capacity(2);
21+
types.insert_leaf(
22+
"x",
23+
TensorType {
24+
dtype: DTypeLike::Var("x".into()),
25+
shape: vec![],
26+
broadcastable: true,
27+
},
28+
);
29+
types.insert_leaf(
30+
"y",
31+
TensorType {
32+
dtype: DTypeLike::Var("y".into()),
33+
shape: vec![],
34+
broadcastable: true,
35+
},
36+
);
37+
types
38+
});
39+
40+
/// Shared output type spec for all elementwise binary nodes: a single tensor of the promoted dtype.
41+
static OUTPUT_TYPES: LazyLock<DataTree<TensorType>> = LazyLock::new(|| {
42+
DataTree::new_leaf(TensorType {
43+
dtype: DTypeLike::Promotion(
44+
vec![DTypeLike::Var("x".into()), DTypeLike::Var("y".into())].into(),
45+
),
46+
shape: vec![],
47+
broadcastable: true,
48+
})
49+
});
50+
51+
/// Generate a [`ProgramNode`] struct for an elementwise binary operation.
52+
macro_rules! elementwise_binary_node {
53+
($name:ident, $node_name:literal, $call_fn:expr) => {
54+
#[doc = concat!("Elementwise `", $node_name, "` of two broadcastable tensors.")]
55+
pub struct $name;
56+
57+
impl ProgramNode for $name {
58+
type CallError = super::MathNodeError;
59+
60+
fn name(&self) -> &'static str {
61+
$node_name
62+
}
63+
fn namespace(&self) -> &'static str {
64+
"math"
65+
}
66+
fn input_types(&self) -> &DataTree<TensorType> {
67+
&INPUT_TYPES
68+
}
69+
fn output_types(&self) -> &DataTree<TensorType> {
70+
&OUTPUT_TYPES
71+
}
72+
fn implements_call(&self) -> bool {
73+
true
74+
}
75+
fn call(&self, args: &DataTree<Tensor>) -> Result<DataTree<Tensor>, Self::CallError> {
76+
let x = require_leaf_arg(args, "x")?;
77+
let y = require_leaf_arg(args, "y")?;
78+
let out_dtype = promotion(x.dtype(), y.dtype());
79+
Ok(DataTree::new_leaf($call_fn(
80+
&x.cast_ref(out_dtype),
81+
&y.cast_ref(out_dtype),
82+
)?))
83+
}
84+
}
85+
};
86+
}
87+
88+
elementwise_binary_node!(Add, "add", Tensor::add_tensor);
89+
elementwise_binary_node!(Subtract, "subtract", Tensor::sub_tensor);
90+
elementwise_binary_node!(Multiply, "multiply", Tensor::mul_tensor);
91+
elementwise_binary_node!(Divide, "divide", Tensor::div_tensor);
92+
elementwise_binary_node!(Remainder, "remainder", Tensor::rem_tensor);
93+
elementwise_binary_node!(Power, "power", Tensor::pow);
94+
95+
#[cfg(test)]
96+
mod tests {
97+
use super::*;
98+
use crate::math_nodes::MathNodeError;
99+
use crate::program_node::ProgramNodeError;
100+
use crate::tensor::{DType, Tensor};
101+
102+
fn args(x: Tensor, y: Tensor) -> DataTree<Tensor> {
103+
let mut tree = DataTree::new();
104+
tree.insert_leaf("x", x);
105+
tree.insert_leaf("y", y);
106+
tree
107+
}
108+
109+
#[test]
110+
fn test_add_same_dtype() {
111+
let result = Add
112+
.call(&args(
113+
Tensor::from([1.0_f64, 2.0, 3.0]),
114+
Tensor::from([4.0_f64, 5.0, 6.0]),
115+
))
116+
.unwrap();
117+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
118+
panic!("expected f64 leaf")
119+
};
120+
assert_eq!(arr.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
121+
}
122+
123+
#[test]
124+
fn test_add_promotes_dtype() {
125+
let result = Add
126+
.call(&args(
127+
Tensor::from([1.0_f32, 2.0]),
128+
Tensor::from([3.0_f64, 4.0]),
129+
))
130+
.unwrap();
131+
let DataTree::Leaf(tensor) = result else {
132+
panic!("expected leaf")
133+
};
134+
assert_eq!(tensor.dtype(), DType::F64);
135+
let Tensor::F64(arr) = tensor else {
136+
panic!("expected f64")
137+
};
138+
assert_eq!(arr.as_slice().unwrap(), &[4.0, 6.0]);
139+
}
140+
141+
#[test]
142+
fn test_add_broadcasts_1d_scalar() {
143+
// shape [3] + shape [1] -> shape [3]
144+
let result = Add
145+
.call(&args(
146+
Tensor::from([1.0_f64, 2.0, 3.0]),
147+
Tensor::from([10.0_f64]),
148+
))
149+
.unwrap();
150+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
151+
panic!("expected f64 leaf")
152+
};
153+
assert_eq!(arr.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
154+
}
155+
156+
#[test]
157+
fn test_add_broadcasts_2d_with_1d() {
158+
// shape [2, 3] + shape [3] -> shape [2, 3]
159+
use ndarray::arr2;
160+
let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn());
161+
let y = Tensor::from([10.0_f64, 20.0, 30.0]);
162+
let result = Add.call(&args(x, y)).unwrap();
163+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
164+
panic!("expected f64 leaf")
165+
};
166+
let expected = arr2(&[[11.0_f64, 22.0, 33.0], [14.0, 25.0, 36.0]]).into_dyn();
167+
assert_eq!(arr, expected);
168+
}
169+
170+
#[test]
171+
fn test_subtract() {
172+
let result = Subtract
173+
.call(&args(
174+
Tensor::from([5.0_f64, 6.0, 7.0]),
175+
Tensor::from([1.0_f64, 2.0, 3.0]),
176+
))
177+
.unwrap();
178+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
179+
panic!()
180+
};
181+
assert_eq!(arr.as_slice().unwrap(), &[4.0, 4.0, 4.0]);
182+
}
183+
184+
#[test]
185+
fn test_multiply() {
186+
let result = Multiply
187+
.call(&args(
188+
Tensor::from([2.0_f64, 3.0, 4.0]),
189+
Tensor::from([10.0_f64, 10.0, 10.0]),
190+
))
191+
.unwrap();
192+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
193+
panic!()
194+
};
195+
assert_eq!(arr.as_slice().unwrap(), &[20.0, 30.0, 40.0]);
196+
}
197+
198+
#[test]
199+
fn test_divide() {
200+
let result = Divide
201+
.call(&args(
202+
Tensor::from([10.0_f64, 9.0, 8.0]),
203+
Tensor::from([2.0_f64, 3.0, 4.0]),
204+
))
205+
.unwrap();
206+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
207+
panic!()
208+
};
209+
assert_eq!(arr.as_slice().unwrap(), &[5.0, 3.0, 2.0]);
210+
}
211+
212+
#[test]
213+
fn test_remainder() {
214+
let result = Remainder
215+
.call(&args(
216+
Tensor::from([7.0_f64, 8.0, 9.0]),
217+
Tensor::from([3.0_f64, 3.0, 3.0]),
218+
))
219+
.unwrap();
220+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
221+
panic!()
222+
};
223+
assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 0.0]);
224+
}
225+
226+
#[test]
227+
fn test_power() {
228+
let result = Power
229+
.call(&args(
230+
Tensor::from([2.0_f64, 3.0, 4.0]),
231+
Tensor::from([3.0_f64, 2.0, 1.0]),
232+
))
233+
.unwrap();
234+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
235+
panic!()
236+
};
237+
assert_eq!(arr.as_slice().unwrap(), &[8.0, 9.0, 4.0]);
238+
}
239+
240+
#[test]
241+
fn test_missing_input_returns_error() {
242+
let mut tree = DataTree::new();
243+
tree.insert_leaf("x", Tensor::from([1.0_f64]));
244+
// No "y" input.
245+
let err = Add.call(&tree).unwrap_err();
246+
assert_eq!(
247+
err,
248+
MathNodeError::Input(ProgramNodeError::MissingInput {
249+
key: "y".to_string(),
250+
})
251+
);
252+
}
253+
254+
#[test]
255+
fn test_branch_where_leaf_expected_returns_error() {
256+
let mut tree = DataTree::new();
257+
tree.insert_leaf("x", Tensor::from([1.0_f64]));
258+
// "y" is a branch, not a leaf.
259+
tree.insert_branch("y", DataTree::new());
260+
let err = Add.call(&tree).unwrap_err();
261+
assert_eq!(
262+
err,
263+
MathNodeError::Input(ProgramNodeError::ExpectedLeaf {
264+
key: "y".to_string(),
265+
})
266+
);
267+
}
268+
269+
#[test]
270+
fn test_power_broadcasts() {
271+
// shape [3] ** shape [1] -> shape [3]
272+
let result = Power
273+
.call(&args(
274+
Tensor::from([2.0_f64, 3.0, 4.0]),
275+
Tensor::from([2.0_f64]),
276+
))
277+
.unwrap();
278+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
279+
panic!()
280+
};
281+
assert_eq!(arr.as_slice().unwrap(), &[4.0, 9.0, 16.0]);
282+
}
283+
}

0 commit comments

Comments
 (0)