Skip to content

Commit e5e336f

Browse files
authored
Merge pull request #34 from astroautomata/cleaner-operator-syntax
feat!: cleaner operator syntax
2 parents 813f73a + e17c8cc commit e5e336f

26 files changed

Lines changed: 922 additions & 1381 deletions

README.md

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ fn main() {
6464

6565
let dataset = Dataset::new(x, y);
6666

67-
let operators = Operators::<D>::from_names_by_arity::<BuiltinOpsF32>(&["cos", "exp", "sin"], &["+", "-", "*", "/"], &[])
68-
.unwrap();
67+
let operators = BuiltinOpsF32::from_names(["cos", "exp", "sin", "+", "sub", "*", "/"]).unwrap();
6968

7069
let options = Options::<f32, D> {
7170
operators,
@@ -99,42 +98,46 @@ fn main() {
9998

10099
## Custom operators
101100

102-
Define light-weight operator sets with inline evaluation and derivatives using the `custom_opset!` macro (re-exported from `symbolic_regression`):
101+
Define custom operators with `op!`, then build an operator set with `opset!`:
103102

104103
```rust
105-
use symbolic_regression::custom_opset;
106104
use symbolic_regression::prelude::*;
107105

108-
custom_opset! {
109-
pub struct CustomOps<f64> {
110-
1 {
111-
square {
112-
eval(args) { args[0] * args[0] },
113-
partial(args, _idx) { 2.0 * args[0] },
114-
}
115-
exp {
116-
eval(args) { args[0].exp() },
117-
partial(args, _idx) { args[0].exp() },
118-
}
119-
}
120-
2 {
121-
add {
122-
eval(args) { args[0] + args[1] },
123-
partial(_args, _idx) { 1.0 },
124-
}
125-
sub {
126-
infix: "-", // optional
127-
complexity: 2, // optional
128-
eval(args) { args[0] - args[1] },
129-
partial(_args, idx) {
130-
if idx == 0 { 1.0 } else { -1.0 }
131-
},
132-
}
133-
}
106+
op!(Square for f64 {
107+
eval: |[x]| { x * x },
108+
partial: |[x], _idx| { 2.0 * x },
109+
});
110+
111+
op!(Exp for f64 {
112+
eval: |[x]| { x.exp() },
113+
partial: |[x], _idx| { x.exp() },
114+
});
115+
116+
op!(Add for f64 {
117+
infix: "+",
118+
commutative: true,
119+
associative: true,
120+
eval: |[x, y]| { x + y },
121+
partial: |[_x, _y], _idx| { 1.0 },
122+
});
123+
124+
op!(Sub for f64 {
125+
infix: "-", // optional
126+
complexity: 2, // optional
127+
eval: |[x, y]| { x - y },
128+
partial: |[_x, _y], idx| { if idx == 0 { 1.0 } else { -1.0 } },
129+
});
130+
131+
opset! {
132+
pub CustomOps for f64 {
133+
Square,
134+
Exp,
135+
Add,
136+
Sub,
134137
}
135138
}
136139

137-
let operators = CustomOps::from_names_by_arity(&["square", "exp"], &["add", "sub"], &[]).unwrap();
140+
let operators = CustomOps::from_names(["square", "exp", "add", "sub"]).unwrap();
138141
let options = Options::<f64, _> { operators, ..Default::default() };
139142
```
140143

dynamic_expressions/benches/eval.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use criterion::{criterion_group, criterion_main};
22
use dynamic_expressions::evaluate::EvalOptions;
33
use dynamic_expressions::expression::PostfixExpr;
44
use dynamic_expressions::node::PNode;
5+
use dynamic_expressions::operator_enum::builtin::*;
56
use dynamic_expressions::{OperatorSet, opset};
67
use ndarray::Array2;
78
use num_traits::Float;
@@ -14,19 +15,11 @@ const N_TREES: usize = 100;
1415
const N_ROWS: usize = 1_000;
1516

1617
opset! {
17-
pub struct BenchOpsF32<f32>;
18-
ops {
19-
(1, UnaryF32) { Cos, Exp, }
20-
(2, BinaryF32) { Add, Sub, Mul, Div, }
21-
}
18+
pub BenchOpsF32 for f32 { Cos, Exp, Add, Sub, Mul, Div }
2219
}
2320

2421
opset! {
25-
pub struct BenchOpsF64<f64>;
26-
ops {
27-
(1, UnaryF64) { Cos, Exp, }
28-
(2, BinaryF64) { Add, Sub, Mul, Div, }
29-
}
22+
pub BenchOpsF64 for f64 { Cos, Exp, Add, Sub, Mul, Div }
3023
}
3124

3225
fn random_leaf<T: Float, R: Rng>(rng: &mut R, n_features: usize, consts: &mut Vec<T>) -> PNode {

dynamic_expressions/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub mod operator_enum;
1212
pub mod operators;
1313
#[cfg(feature = "proptest-utils")]
1414
pub mod proptest_utils;
15+
pub mod select;
1516
pub mod simplify;
1617
pub mod strings;
1718
pub mod traits;
@@ -31,6 +32,7 @@ pub use crate::node_utils::{
3132
count_constant_nodes, count_depth, count_nodes, has_constants, has_operators, subtree_range, subtree_sizes,
3233
tree_mapreduce,
3334
};
35+
pub use crate::select::{OperatorSelectError, Operators};
3436
pub use crate::simplify::{combine_operators_in_place, simplify_in_place, simplify_tree_in_place};
3537
pub use crate::strings::{StringTreeOptions, print_tree, string_tree};
3638
pub use crate::traits::{HasOp, LookupError, OpId, OpMeta, OpTag, Operator, OperatorSet};

0 commit comments

Comments
 (0)