Skip to content

Commit 642b9fb

Browse files
committed
Improve multiplication table construction for duals
1 parent 271760e commit 642b9fb

1 file changed

Lines changed: 74 additions & 8 deletions

File tree

lib/numerica/src/domains/dual.rs

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
3737
use std::sync::Arc;
3838

39+
use ahash::HashMap;
40+
3941
use crate::domains::{
4042
float::{Constructible, FloatLike, Real, SingleFloat},
4143
rational::Rational,
@@ -225,11 +227,33 @@ const fn is_single_derivative_component<const N: usize>(
225227

226228
/// Get the size of the multiplication table.
227229
pub const fn get_mult_table_size<const N: usize, const C: usize>(r: &[[usize; N]; C]) -> usize {
230+
let mut max_single_pow = [0; N];
228231
let mut i = 0;
232+
while i < r.len() {
233+
let mut j = 0;
234+
while j < N {
235+
if r[i][j] > max_single_pow[j] {
236+
max_single_pow[j] = r[i][j];
237+
}
238+
j += 1;
239+
}
240+
i += 1;
241+
}
242+
229243
let mut ri = 0;
244+
i = 0;
230245
while i < r.len() {
231246
let mut j = 1; // skip first entry
232-
while j < r.len() {
247+
'next_inner: while j < r.len() {
248+
let mut k = 0;
249+
while k < N {
250+
if r[i][k] + r[j][k] > max_single_pow[k] {
251+
j += 1;
252+
continue 'next_inner;
253+
}
254+
k += 1;
255+
}
256+
233257
if get_multiplication_index::<N, C>(r, i, j).is_some() {
234258
ri += 1;
235259
}
@@ -247,11 +271,33 @@ pub const fn get_mult_table<const N: usize, const C: usize, const T: usize>(
247271
) -> [(usize, usize, usize); T] {
248272
let mut res = [(0, 0, 0); T];
249273

250-
let mut ri = 0;
274+
let mut max_single_pow = [0; N];
251275
let mut i = 0;
276+
while i < r.len() {
277+
let mut j = 0;
278+
while j < N {
279+
if r[i][j] > max_single_pow[j] {
280+
max_single_pow[j] = r[i][j];
281+
}
282+
j += 1;
283+
}
284+
i += 1;
285+
}
286+
287+
let mut ri = 0;
288+
i = 0;
252289
while i < r.len() {
253290
let mut j = 1; // skip first entry
254-
while j < r.len() {
291+
'next_inner: while j < r.len() {
292+
let mut k = 0;
293+
while k < N {
294+
if r[i][k] + r[j][k] > max_single_pow[k] {
295+
j += 1;
296+
continue 'next_inner;
297+
}
298+
k += 1;
299+
}
300+
255301
if let Some(index) = get_multiplication_index::<N, C>(r, i, j) {
256302
res[ri] = (i, j, index);
257303
ri += 1;
@@ -358,6 +404,10 @@ pub trait DualNumberStructure {
358404
#[macro_export]
359405
macro_rules! create_hyperdual_from_components {
360406
($t: ident, $var: expr) => {
407+
#[allow(unused_imports)]
408+
use $crate::domains::float::FloatLike as _;
409+
410+
#[allow(long_running_const_eval)]
361411
const _: () = assert!(
362412
$crate::domains::dual::is_dual_shape_ancestor_closed(&$var),
363413
"Dual shape is not ancestor-closed"
@@ -391,6 +441,7 @@ macro_rules! create_hyperdual_from_components {
391441
}
392442
max_pow
393443
};
444+
#[allow(long_running_const_eval)]
394445
const MULT_TABLE: [(usize, usize, usize); {
395446
$crate::domains::dual::get_mult_table_size(&$var)
396447
}] = $crate::domains::dual::get_mult_table(&$var);
@@ -900,7 +951,7 @@ macro_rules! create_hyperdual_from_components {
900951
}
901952

902953
#[inline(always)]
903-
fn from_rational(&self, rat: &Rational) -> Self {
954+
fn from_rational(&self, rat: &$crate::domains::rational::Rational) -> Self {
904955
let mut res = self.zero();
905956
res.values[0] = self.values[0].from_rational(rat);
906957
res
@@ -1302,17 +1353,32 @@ impl<T> HyperDual<T> {
13021353
.max()
13031354
.unwrap_or(0);
13041355

1356+
let max_single_pow: Vec<_> = (0..shape[0].len())
1357+
.map(|i| shape.iter().map(|s| s[i]).max().unwrap())
1358+
.collect();
1359+
13051360
let mut mult_table = vec![];
13061361

1362+
let entries = shape
1363+
.iter()
1364+
.enumerate()
1365+
.map(|(i, s)| (s.clone(), i))
1366+
.collect::<HashMap<_, _>>();
1367+
13071368
let mut sum = vec![0; shape[0].len()];
13081369
for (i, vi) in shape.iter().enumerate() {
1309-
for (j, vj) in shape.iter().enumerate().skip(1) {
1310-
for (s, (vii, vjj)) in sum.iter_mut().zip(vi.iter().zip(vj.iter())) {
1370+
'next_inner: for (j, vj) in shape.iter().enumerate().skip(1) {
1371+
for (r, (s, (vii, vjj))) in sum.iter_mut().zip(vi.iter().zip(vj.iter())).enumerate()
1372+
{
13111373
*s = vii + vjj;
1374+
1375+
if *s > max_single_pow[r] {
1376+
continue 'next_inner;
1377+
}
13121378
}
13131379

1314-
if let Some(p) = shape.iter().position(|s| s == &sum) {
1315-
mult_table.push((i, j, p));
1380+
if let Some(p) = entries.get(&sum) {
1381+
mult_table.push((i, j, *p));
13161382
}
13171383
}
13181384
}

0 commit comments

Comments
 (0)