1+ void X (destroy_sparse )(X (sparse ) * A ) {
2+ free (A -> p );
3+ free (A -> q );
4+ free (A -> v );
5+ free (A );
6+ }
7+
18void X (destroy_banded )(X (banded ) * A ) {
29 free (A -> data );
310 free (A );
@@ -23,6 +30,7 @@ void X(destroy_tb_eigen_FMM)(X(tb_eigen_FMM) * F) {
2330 X (destroy_hierarchicalmatrix )(F -> F0 );
2431 X (destroy_tb_eigen_FMM )(F -> F1 );
2532 X (destroy_tb_eigen_FMM )(F -> F2 );
33+ X (destroy_sparse )(F -> S );
2634 free (F -> X );
2735 free (F -> Y );
2836 free (F -> t1 );
@@ -75,6 +83,28 @@ size_t X(summary_size_tb_eigen_ADI)(X(tb_eigen_ADI) * F) {
7583 return S ;
7684}
7785
86+ X (sparse ) * X (malloc_sparse )(const int m , const int n , const int nnz ) {
87+ X (sparse ) * A = malloc (sizeof (X (sparse )));
88+ A -> p = malloc (nnz * sizeof (int ));
89+ A -> q = malloc (nnz * sizeof (int ));
90+ A -> v = malloc (nnz * sizeof (FLT ));
91+ A -> m = m ;
92+ A -> n = n ;
93+ A -> nnz = nnz ;
94+ return A ;
95+ }
96+
97+ X (sparse ) * X (calloc_sparse )(const int m , const int n , const int nnz ) {
98+ X (sparse ) * A = malloc (sizeof (X (sparse )));
99+ A -> p = calloc (nnz , sizeof (int ));
100+ A -> q = calloc (nnz , sizeof (int ));
101+ A -> v = calloc (nnz , sizeof (FLT ));
102+ A -> m = m ;
103+ A -> n = n ;
104+ A -> nnz = nnz ;
105+ return A ;
106+ }
107+
78108X (banded ) * X (malloc_banded )(const int m , const int n , const int l , const int u ) {
79109 FLT * data = malloc (n * (l + u + 1 )* sizeof (FLT ));
80110 X (banded ) * A = malloc (sizeof (X (banded )));
@@ -488,7 +518,7 @@ void X(triangular_banded_eigenvectors)(X(triangular_banded) * A, X(triangular_ba
488518 }
489519 d = lam * X (get_triangular_banded_index )(B , i , i ) - X (get_triangular_banded_index )(A , i , i );
490520 kd = Y (fabs )(lam * X (get_triangular_banded_index )(B , i , i )) + Y (fabs )(X (get_triangular_banded_index )(A , i , i ));
491- if (Y (fabs )(d ) < 4 * kd * Y (eps )() && Y (fabs )(t ) < 4 * kt * Y (eps )())
521+ if (Y (fabs )(d ) < 4 * kd * Y (eps )() || Y (fabs )(t ) < 4 * kt * Y (eps )())
492522 V [i + j * n ] = 0 ;
493523 else
494524 V [i + j * n ] = t /d ;
@@ -533,6 +563,83 @@ void X(triangular_banded_quadratic_eigenvectors)(X(triangular_banded) * A, X(tri
533563 }
534564}
535565
566+ // Assumptions: x, y are non-decreasing.
567+ static inline int X (count_intersections )(const int m , const FLT * x , const int n , const FLT * y , const FLT epsilon ) {
568+ int istart = 0 , idx = 0 ;
569+ for (int j = 0 ; j < n ; j ++ ) {
570+ int i = istart ;
571+ int thefirst = 1 ;
572+ while (i < m ) {
573+ if (Y (fabs )(x [i ] - y [j ]) < epsilon * MAX (Y (fabs )(x [i ]), Y (fabs )(y [j ]))) {
574+ idx ++ ;
575+ if (thefirst ) {
576+ istart = i ;
577+ thefirst -- ;
578+ }
579+ }
580+ else if (x [i ] > y [j ])
581+ break ;
582+ i ++ ;
583+ }
584+ }
585+ return idx ;
586+ }
587+
588+ // Assumptions: p and q have been malloc'ed with `idx` integers.
589+ static inline void X (produce_intersection_indices )(const int m , const FLT * x , const int n , const FLT * y , const FLT epsilon , int * p , int * q ) {
590+ int istart = 0 , idx = 0 ;
591+ for (int j = 0 ; j < n ; j ++ ) {
592+ int i = istart ;
593+ int thefirst = 1 ;
594+ while (i < m ) {
595+ if (Y (fabs )(x [i ] - y [j ]) < epsilon * MAX (Y (fabs )(x [i ]), Y (fabs )(y [j ]))) {
596+ p [idx ] = i ;
597+ q [idx ] = j ;
598+ idx ++ ;
599+ if (thefirst ) {
600+ istart = i ;
601+ thefirst -- ;
602+ }
603+ }
604+ else if (x [i ] > y [j ])
605+ break ;
606+ i ++ ;
607+ }
608+ }
609+ }
610+
611+ static inline X (sparse ) * X (get_sparse_from_eigenvectors )(X (tb_eigen_FMM ) * F1 , X (triangular_banded ) * A , X (triangular_banded ) * B , FLT * D , int * p1 , int * p2 , int * p3 , int * p4 , int n , int s , int b , int idx ) {
612+ X (sparse ) * S = X (malloc_sparse )(s , n - s , idx );
613+ FLT * V = calloc (n , sizeof (FLT ));
614+ for (int l = 0 ; l < idx ; l ++ ) {
615+ int j = p2 [p4 [l ]]+ s ;
616+ for (int i = 0 ; i < n ; i ++ )
617+ V [i ] = 0 ;
618+ V [j ] = D [j ];
619+ FLT t , kt , d , kd , lam ;
620+ lam = X (get_triangular_banded_index )(A , j , j )/X (get_triangular_banded_index )(B , j , j );
621+ for (int i = j - 1 ; i >= 0 ; i -- ) {
622+ t = kt = 0 ;
623+ for (int k = i + 1 ; k < MIN (i + b + 1 , n ); k ++ ) {
624+ t += (X (get_triangular_banded_index )(A , i , k ) - lam * X (get_triangular_banded_index )(B , i , k ))* V [k ];
625+ kt += (Y (fabs )(X (get_triangular_banded_index )(A , i , k )) + Y (fabs )(lam * X (get_triangular_banded_index )(B , i , k )))* Y (fabs )(V [k ]);
626+ }
627+ d = lam * X (get_triangular_banded_index )(B , i , i ) - X (get_triangular_banded_index )(A , i , i );
628+ kd = Y (fabs )(lam * X (get_triangular_banded_index )(B , i , i )) + Y (fabs )(X (get_triangular_banded_index )(A , i , i ));
629+ if (Y (fabs )(d ) < 4 * kd * Y (eps )() || Y (fabs )(t ) < 4 * kt * Y (eps )())
630+ V [i ] = 0 ;
631+ else
632+ V [i ] = t /d ;
633+ }
634+ X (bfsv )('N' , F1 , V );
635+ S -> p [l ] = p1 [p3 [l ]];
636+ S -> q [l ] = p2 [p4 [l ]];
637+ S -> v [l ] = V [p1 [p3 [l ]]];
638+ }
639+ free (V );
640+ return S ;
641+ }
642+
536643X (tb_eigen_FMM ) * X (tb_eig_FMM )(X (triangular_banded ) * A , X (triangular_banded ) * B , FLT * D ) {
537644 int n = A -> n , b1 = A -> b , b2 = B -> b ;
538645 int b = MAX (b1 , b2 );
@@ -599,9 +706,18 @@ X(tb_eigen_FMM) * X(tb_eig_FMM)(X(triangular_banded) * A, X(triangular_banded) *
599706 p2 [i ] = i ;
600707 X (quicksort_1arg )(lambda + s , p2 , 0 , n - s - 1 , X (lt ));
601708
602- F -> F0 = X (sample_hierarchicalmatrix )(X (cauchykernel ), lambda , lambda , i , j , 'G' );
709+ int idx = X (count_intersections )(s , lambda , n - s , lambda + s , 16 * Y (sqrt )(Y (eps )()));
710+ int * p3 = malloc (idx * sizeof (int ));
711+ int * p4 = malloc (idx * sizeof (int ));
712+ X (produce_intersection_indices )(s , lambda , n - s , lambda + s , 16 * Y (sqrt )(Y (eps )()), p3 , p4 );
713+ X (sparse ) * S = X (get_sparse_from_eigenvectors )(F -> F1 , A , B , D , p1 , p2 , p3 , p4 , n , s , b , idx );
714+ free (p3 );
715+ free (p4 );
716+
717+ F -> F0 = X (sample_hierarchicalmatrix )(X (thresholded_cauchykernel ), lambda , lambda , i , j , 'G' );
603718 F -> X = X ;
604719 F -> Y = Y ;
720+ F -> S = S ;
605721 F -> t1 = calloc (s * FT_GET_MAX_THREADS (), sizeof (FLT ));
606722 F -> t2 = calloc ((n - s )* FT_GET_MAX_THREADS (), sizeof (FLT ));
607723 X (perm )('T' , lambda , p1 , s );
@@ -860,6 +976,7 @@ void X(bfmv)(char TRANS, X(tb_eigen_FMM) * F, FLT * x) {
860976 int s = n >>1 , b = F -> b ;
861977 FLT * t1 = F -> t1 + s * FT_GET_THREAD_NUM (), * t2 = F -> t2 + (n - s )* FT_GET_THREAD_NUM ();
862978 int * p1 = F -> p1 , * p2 = F -> p2 ;
979+ X (sparse ) * S = F -> S ;
863980 if (TRANS == 'N' ) {
864981 // C(Λ₁, Λ₂) ∘ (-XYᵀ)
865982 for (int k = 0 ; k < b ; k ++ ) {
@@ -869,6 +986,8 @@ void X(bfmv)(char TRANS, X(tb_eigen_FMM) * F, FLT * x) {
869986 for (int i = 0 ; i < s ; i ++ )
870987 x [p1 [i ]] += t1 [i ]* F -> X [p1 [i ]+ k * s ];
871988 }
989+ for (int l = 0 ; l < S -> nnz ; l ++ )
990+ x [S -> p [l ]] += S -> v [l ]* x [S -> q [l ]+ s ];
872991 X (bfmv )(TRANS , F -> F1 , x );
873992 X (bfmv )(TRANS , F -> F2 , x + s );
874993 }
@@ -883,6 +1002,8 @@ void X(bfmv)(char TRANS, X(tb_eigen_FMM) * F, FLT * x) {
8831002 for (int i = 0 ; i < n - s ; i ++ )
8841003 x [p2 [i ]+ s ] += t2 [i ]* F -> Y [p2 [i ]+ k * (n - s )];
8851004 }
1005+ for (int l = 0 ; l < S -> nnz ; l ++ )
1006+ x [S -> q [l ]+ s ] += S -> v [l ]* x [S -> p [l ]];
8861007 }
8871008 }
8881009}
@@ -916,6 +1037,7 @@ void X(bfsv)(char TRANS, X(tb_eigen_FMM) * F, FLT * x) {
9161037 int s = n >>1 , b = F -> b ;
9171038 FLT * t1 = F -> t1 + s * FT_GET_THREAD_NUM (), * t2 = F -> t2 + (n - s )* FT_GET_THREAD_NUM ();
9181039 int * p1 = F -> p1 , * p2 = F -> p2 ;
1040+ X (sparse ) * S = F -> S ;
9191041 if (TRANS == 'N' ) {
9201042 X (bfsv )(TRANS , F -> F1 , x );
9211043 X (bfsv )(TRANS , F -> F2 , x + s );
@@ -927,6 +1049,8 @@ void X(bfsv)(char TRANS, X(tb_eigen_FMM) * F, FLT * x) {
9271049 for (int i = 0 ; i < s ; i ++ )
9281050 x [p1 [i ]] += t1 [i ]* F -> X [p1 [i ]+ k * s ];
9291051 }
1052+ for (int l = 0 ; l < S -> nnz ; l ++ )
1053+ x [S -> p [l ]] -= S -> v [l ]* x [S -> q [l ]+ s ];
9301054 }
9311055 else if (TRANS == 'T' ) {
9321056 // C(Λ₁, Λ₂) ∘ (-XYᵀ)
@@ -937,6 +1061,8 @@ void X(bfsv)(char TRANS, X(tb_eigen_FMM) * F, FLT * x) {
9371061 for (int i = 0 ; i < n - s ; i ++ )
9381062 x [p2 [i ]+ s ] += t2 [i ]* F -> Y [p2 [i ]+ k * (n - s )];
9391063 }
1064+ for (int l = 0 ; l < S -> nnz ; l ++ )
1065+ x [S -> q [l ]+ s ] -= S -> v [l ]* x [S -> p [l ]];
9401066 X (bfsv )(TRANS , F -> F1 , x );
9411067 X (bfsv )(TRANS , F -> F2 , x + s );
9421068 }
0 commit comments