Skip to content

Commit a5480f9

Browse files
committed
Add real_if_args_real option to complex evaluator
1 parent 5c5a6d3 commit a5480f9

3 files changed

Lines changed: 35 additions & 13 deletions

File tree

src/api/python/evaluator.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -872,26 +872,28 @@ impl PythonExpressionEvaluator {
872872
/// assembly output that uses real arithmetic instead of complex arithmetic
873873
/// where possible.
874874
///
875-
/// You can also set if all encountered sqrt, log, and powf operations with real
876-
/// arguments are expected to yield real results.
875+
/// You can also set if all encountered sqrt, log, powf, and custom evaluator
876+
/// operations with real arguments are expected to yield real results.
877877
///
878878
/// Must be called after all optimization functions and merging are performed
879879
/// on the evaluator and before the first call to `evaluate`, or the registration will be lost.
880-
#[pyo3(signature = (real_params, sqrt_real = false, log_real = false, powf_real = false, verbose = false))]
880+
#[pyo3(signature = (real_params, sqrt_real = false, log_real = false, powf_real = false, real_if_args_real = false, verbose = false))]
881881
fn set_real_params(
882882
&mut self,
883883
real_params: Vec<usize>,
884884
sqrt_real: bool,
885885
log_real: bool,
886886
powf_real: bool,
887+
real_if_args_real: bool,
887888
verbose: bool,
888889
) -> PyResult<()> {
889890
self.jit_complex = None; // force a recompilation
891+
let mut settings = ComplexEvaluatorSettings::new(sqrt_real, log_real, powf_real, verbose);
892+
if real_if_args_real {
893+
settings = settings.real_if_args_real();
894+
}
890895
self.eval_complex
891-
.set_real_params(
892-
&real_params,
893-
ComplexEvaluatorSettings::new(sqrt_real, log_real, powf_real, verbose),
894-
)
896+
.set_real_params(&real_params, settings)
895897
.map_err(|e| exceptions::PyValueError::new_err(e.to_string()))
896898
}
897899

src/evaluate/optimize.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ pub struct ComplexEvaluatorSettings {
99
pub(crate) log_real: bool,
1010
/// Whether powf with real arguments yields real results.
1111
pub(crate) powf_real: bool,
12+
/// Whether custom evaluator functions with real arguments yield real results.
13+
pub(crate) real_if_args_real: bool,
1214
/// Report on the number of converted operations.
1315
pub(crate) verbose: bool,
1416
}
@@ -20,6 +22,7 @@ impl ComplexEvaluatorSettings {
2022
sqrt_real,
2123
log_real,
2224
powf_real,
25+
real_if_args_real: false,
2326
verbose,
2427
}
2528
}
@@ -42,6 +45,12 @@ impl ComplexEvaluatorSettings {
4245
self
4346
}
4447

48+
/// Set that all custom evaluator functions with real arguments yield real results.
49+
pub fn real_if_args_real(mut self) -> Self {
50+
self.real_if_args_real = true;
51+
self
52+
}
53+
4554
/// Set verbose reporting.
4655
pub fn verbose(mut self) -> Self {
4756
self.verbose = true;
@@ -56,6 +65,7 @@ impl Default for ComplexEvaluatorSettings {
5665
sqrt_real: false,
5766
log_real: false,
5867
powf_real: false,
68+
real_if_args_real: false,
5969
verbose: false,
6070
}
6171
}
@@ -66,8 +76,8 @@ impl<T: Default + PartialEq> ExpressionEvaluator<Complex<T>> {
6676
/// assembly output that uses real arithmetic instead of complex arithmetic
6777
/// where possible.
6878
///
69-
/// You can also set if all encountered sqrt, log, and powf operations with real
70-
/// arguments are expected to yield real results.
79+
/// You can also set if all encountered sqrt, log, powf, and custom evaluator
80+
/// operations with real arguments are expected to yield real results.
7181
///
7282
/// Must be called after all optimization functions and merging are performed
7383
/// on the evaluator, or the registration will be lost.
@@ -192,8 +202,15 @@ impl<T: Default + PartialEq> ExpressionEvaluator<Complex<T>> {
192202
}
193203
subcomponents[*r] = *sc;
194204
}
195-
Instr::ExternalFun(r, ..) => {
196-
*sc = ComplexPhase::Any;
205+
Instr::ExternalFun(r, _, a) => {
206+
if settings.real_if_args_real
207+
&& !a.is_empty()
208+
&& a.iter().all(|x| subcomponents[*x] == ComplexPhase::Real)
209+
{
210+
*sc = ComplexPhase::Real;
211+
} else {
212+
*sc = ComplexPhase::Any;
213+
}
197214
subcomponents[*r] = *sc;
198215
}
199216
Instr::IfElse(..) | Instr::Goto(..) | Instr::Label(..) => {

symbolica.pyi

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8576,15 +8576,16 @@ class Evaluator:
85768576
sqrt_real=False,
85778577
log_real=False,
85788578
powf_real=False,
8579+
real_if_args_real=False,
85798580
verbose=False,
85808581
) -> None:
85818582
"""
85828583
Set which parameters are fully real. This allows for more optimal
85838584
assembly output that uses real arithmetic instead of complex arithmetic
85848585
where possible.
85858586
8586-
You can also set if all encountered sqrt, log, and powf operations with real
8587-
arguments are expected to yield real results.
8587+
You can also set if all encountered sqrt, log, powf, and custom evaluator
8588+
operations with real arguments are expected to yield real results.
85888589
85898590
Must be called after all optimization functions and merging are performed
85908591
on the evaluator, or the registration will be lost.
@@ -8599,6 +8600,8 @@ class Evaluator:
85998600
Whether logarithms should be assumed real.
86008601
powf_real: Any
86018602
Whether fractional powers should be assumed real.
8603+
real_if_args_real: Any
8604+
Whether custom evaluators should yield real results for real arguments.
86028605
verbose: Any
86038606
Whether verbose output should be enabled.
86048607
"""

0 commit comments

Comments
 (0)