@@ -2374,7 +2374,8 @@ inline CUDA_CALLABLE void scalar_cholesky_forward_substitution(TileL& L, TileX&
23742374 for (int j=0 ; j < i; ++j)
23752375 s -= L.data (tile_coord (i,j)) * X.data (tile_coord (j));
23762376
2377- X.data (tile_coord (i)) = s / L.data (tile_coord (i, i));
2377+ T diag = L.data (tile_coord (i, i));
2378+ X.data (tile_coord (i)) = (diag != T (0 .0f )) ? s / diag : s;
23782379 }
23792380 }
23802381 else if constexpr (TileY::Layout::Shape::N == 2 )
@@ -2391,7 +2392,8 @@ inline CUDA_CALLABLE void scalar_cholesky_forward_substitution(TileL& L, TileX&
23912392 for (int j=0 ; j < i; ++j)
23922393 s -= L.data (tile_coord (i,j)) * X.data (tile_coord (j,k));
23932394
2394- X.data (tile_coord (i,k)) = s / L.data (tile_coord (i, i));
2395+ T diag = L.data (tile_coord (i, i));
2396+ X.data (tile_coord (i,k)) = (diag != T (0 .0f )) ? s / diag : s;
23952397 }
23962398 }
23972399 }
@@ -2414,7 +2416,8 @@ inline CUDA_CALLABLE void scalar_cholesky_back_substitution(TileL& L, TileX& X)
24142416 for (int j=i+1 ; j < n; ++j)
24152417 s -= L.data (tile_coord (j, i)) * X.data (tile_coord (j));
24162418
2417- X.data (tile_coord (i)) = s / L.data (tile_coord (i, i));
2419+ T diag = L.data (tile_coord (i, i));
2420+ X.data (tile_coord (i)) = (diag != T (0 .0f )) ? s / diag : s;
24182421 }
24192422 }
24202423 else if constexpr (TileX::Layout::Shape::N == 2 )
@@ -2431,7 +2434,8 @@ inline CUDA_CALLABLE void scalar_cholesky_back_substitution(TileL& L, TileX& X)
24312434 for (int j=i+1 ; j < n; ++j)
24322435 s -= L.data (tile_coord (j, i)) * X.data (tile_coord (j,k));
24332436
2434- X.data (tile_coord (i,k)) = s / L.data (tile_coord (i, i));
2437+ T diag = L.data (tile_coord (i, i));
2438+ X.data (tile_coord (i,k)) = (diag != T (0 .0f )) ? s / diag : s;
24352439 }
24362440 }
24372441 }
@@ -2713,7 +2717,9 @@ TileZ& tile_lower_solve(TileL& L, TileY& y, TileZ& z)
27132717 // Divide the diagonal element (only one batch)
27142718 if (WP_TILE_THREAD_IDX == 0 )
27152719 {
2716- z.data (tile_coord (i)) /= L.data (tile_coord (i, i));
2720+ T diag = L.data (tile_coord (i, i));
2721+ if (diag != T (0 .0f ))
2722+ z.data (tile_coord (i)) /= diag;
27172723 }
27182724 WP_TILE_SYNC ();
27192725
@@ -2753,7 +2759,9 @@ TileZ& tile_lower_solve(TileL& L, TileY& y, TileZ& z)
27532759 // Divide the diagonal element for all batches in parallel
27542760 for (int batchId = WP_TILE_THREAD_IDX ; batchId < m; batchId += num_threads)
27552761 {
2756- z.data (tile_coord (i, batchId)) /= L.data (tile_coord (i, i));
2762+ T diag = L.data (tile_coord (i, i));
2763+ if (diag != T (0 .0f ))
2764+ z.data (tile_coord (i, batchId)) /= diag;
27572765 }
27582766 WP_TILE_SYNC ();
27592767
@@ -2851,7 +2859,9 @@ TileX& tile_upper_solve(TileU& U, TileZ& z, TileX& x)
28512859 // Divide the diagonal element (only one batch)
28522860 if (WP_TILE_THREAD_IDX == 0 )
28532861 {
2854- x.data (tile_coord (i)) /= U.data (tile_coord (i, i));
2862+ T diag = U.data (tile_coord (i, i));
2863+ if (diag != T (0 .0f ))
2864+ x.data (tile_coord (i)) /= diag;
28552865 }
28562866 WP_TILE_SYNC ();
28572867
@@ -2891,7 +2901,9 @@ TileX& tile_upper_solve(TileU& U, TileZ& z, TileX& x)
28912901 // Divide the diagonal element for all batches in parallel
28922902 for (int batchId = WP_TILE_THREAD_IDX ; batchId < m; batchId += num_threads)
28932903 {
2894- x.data (tile_coord (i, batchId)) /= U.data (tile_coord (i, i));
2904+ T diag = U.data (tile_coord (i, i));
2905+ if (diag != T (0 .0f ))
2906+ x.data (tile_coord (i, batchId)) /= diag;
28952907 }
28962908 WP_TILE_SYNC ();
28972909
0 commit comments