3636
3737use std:: sync:: Arc ;
3838
39+ use ahash:: HashMap ;
40+
3941use crate :: domains:: {
4042 float:: { Constructible , FloatLike , Real , SingleFloat } ,
4143 rational:: Rational ,
@@ -225,11 +227,33 @@ const fn is_single_derivative_component<const N: usize>(
225227
226228/// Get the size of the multiplication table.
227229pub const fn get_mult_table_size < const N : usize , const C : usize > ( r : & [ [ usize ; N ] ; C ] ) -> usize {
230+ let mut max_single_pow = [ 0 ; N ] ;
228231 let mut i = 0 ;
232+ while i < r. len ( ) {
233+ let mut j = 0 ;
234+ while j < N {
235+ if r[ i] [ j] > max_single_pow[ j] {
236+ max_single_pow[ j] = r[ i] [ j] ;
237+ }
238+ j += 1 ;
239+ }
240+ i += 1 ;
241+ }
242+
229243 let mut ri = 0 ;
244+ i = 0 ;
230245 while i < r. len ( ) {
231246 let mut j = 1 ; // skip first entry
232- while j < r. len ( ) {
247+ ' next_inner: while j < r. len ( ) {
248+ let mut k = 0 ;
249+ while k < N {
250+ if r[ i] [ k] + r[ j] [ k] > max_single_pow[ k] {
251+ j += 1 ;
252+ continue ' next_inner;
253+ }
254+ k += 1 ;
255+ }
256+
233257 if get_multiplication_index :: < N , C > ( r, i, j) . is_some ( ) {
234258 ri += 1 ;
235259 }
@@ -247,11 +271,33 @@ pub const fn get_mult_table<const N: usize, const C: usize, const T: usize>(
247271) -> [ ( usize , usize , usize ) ; T ] {
248272 let mut res = [ ( 0 , 0 , 0 ) ; T ] ;
249273
250- let mut ri = 0 ;
274+ let mut max_single_pow = [ 0 ; N ] ;
251275 let mut i = 0 ;
276+ while i < r. len ( ) {
277+ let mut j = 0 ;
278+ while j < N {
279+ if r[ i] [ j] > max_single_pow[ j] {
280+ max_single_pow[ j] = r[ i] [ j] ;
281+ }
282+ j += 1 ;
283+ }
284+ i += 1 ;
285+ }
286+
287+ let mut ri = 0 ;
288+ i = 0 ;
252289 while i < r. len ( ) {
253290 let mut j = 1 ; // skip first entry
254- while j < r. len ( ) {
291+ ' next_inner: while j < r. len ( ) {
292+ let mut k = 0 ;
293+ while k < N {
294+ if r[ i] [ k] + r[ j] [ k] > max_single_pow[ k] {
295+ j += 1 ;
296+ continue ' next_inner;
297+ }
298+ k += 1 ;
299+ }
300+
255301 if let Some ( index) = get_multiplication_index :: < N , C > ( r, i, j) {
256302 res[ ri] = ( i, j, index) ;
257303 ri += 1 ;
@@ -358,6 +404,10 @@ pub trait DualNumberStructure {
358404#[ macro_export]
359405macro_rules! create_hyperdual_from_components {
360406 ( $t: ident, $var: expr) => {
407+ #[ allow( unused_imports) ]
408+ use $crate:: domains:: float:: FloatLike as _;
409+
410+ #[ allow( long_running_const_eval) ]
361411 const _: ( ) = assert!(
362412 $crate:: domains:: dual:: is_dual_shape_ancestor_closed( & $var) ,
363413 "Dual shape is not ancestor-closed"
@@ -391,6 +441,7 @@ macro_rules! create_hyperdual_from_components {
391441 }
392442 max_pow
393443 } ;
444+ #[ allow( long_running_const_eval) ]
394445 const MULT_TABLE : [ ( usize , usize , usize ) ; {
395446 $crate:: domains:: dual:: get_mult_table_size( & $var)
396447 } ] = $crate:: domains:: dual:: get_mult_table( & $var) ;
@@ -900,7 +951,7 @@ macro_rules! create_hyperdual_from_components {
900951 }
901952
902953 #[ inline( always) ]
903- fn from_rational( & self , rat: & Rational ) -> Self {
954+ fn from_rational( & self , rat: & $crate :: domains :: rational :: Rational ) -> Self {
904955 let mut res = self . zero( ) ;
905956 res. values[ 0 ] = self . values[ 0 ] . from_rational( rat) ;
906957 res
@@ -1302,17 +1353,32 @@ impl<T> HyperDual<T> {
13021353 . max ( )
13031354 . unwrap_or ( 0 ) ;
13041355
1356+ let max_single_pow: Vec < _ > = ( 0 ..shape[ 0 ] . len ( ) )
1357+ . map ( |i| shape. iter ( ) . map ( |s| s[ i] ) . max ( ) . unwrap ( ) )
1358+ . collect ( ) ;
1359+
13051360 let mut mult_table = vec ! [ ] ;
13061361
1362+ let entries = shape
1363+ . iter ( )
1364+ . enumerate ( )
1365+ . map ( |( i, s) | ( s. clone ( ) , i) )
1366+ . collect :: < HashMap < _ , _ > > ( ) ;
1367+
13071368 let mut sum = vec ! [ 0 ; shape[ 0 ] . len( ) ] ;
13081369 for ( i, vi) in shape. iter ( ) . enumerate ( ) {
1309- for ( j, vj) in shape. iter ( ) . enumerate ( ) . skip ( 1 ) {
1310- for ( s, ( vii, vjj) ) in sum. iter_mut ( ) . zip ( vi. iter ( ) . zip ( vj. iter ( ) ) ) {
1370+ ' next_inner: for ( j, vj) in shape. iter ( ) . enumerate ( ) . skip ( 1 ) {
1371+ for ( r, ( s, ( vii, vjj) ) ) in sum. iter_mut ( ) . zip ( vi. iter ( ) . zip ( vj. iter ( ) ) ) . enumerate ( )
1372+ {
13111373 * s = vii + vjj;
1374+
1375+ if * s > max_single_pow[ r] {
1376+ continue ' next_inner;
1377+ }
13121378 }
13131379
1314- if let Some ( p) = shape . iter ( ) . position ( |s| s == & sum) {
1315- mult_table. push ( ( i, j, p) ) ;
1380+ if let Some ( p) = entries . get ( & sum) {
1381+ mult_table. push ( ( i, j, * p) ) ;
13161382 }
13171383 }
13181384 }
0 commit comments