Skip to content

Commit aed5d53

Browse files
committed
Merge branch 'avoid_division_by_zero' into 'main'
Guard against division by zero in forward and back-substitution. See merge request omniverse/warp!1332
2 parents e8e7733 + 5b861be commit aed5d53

1 file changed

Lines changed: 20 additions & 8 deletions

File tree

warp/native/tile.h

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2374,7 +2374,8 @@ inline CUDA_CALLABLE void scalar_cholesky_forward_substitution(TileL& L, TileX&
23742374
for (int j=0; j < i; ++j)
23752375
s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
23762376

2377-
X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
2377+
T diag = L.data(tile_coord(i, i));
2378+
X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
23782379
}
23792380
}
23802381
else if constexpr (TileY::Layout::Shape::N == 2)
@@ -2391,7 +2392,8 @@ inline CUDA_CALLABLE void scalar_cholesky_forward_substitution(TileL& L, TileX&
23912392
for (int j=0; j < i; ++j)
23922393
s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j,k));
23932394

2394-
X.data(tile_coord(i,k)) = s / L.data(tile_coord(i, i));
2395+
T diag = L.data(tile_coord(i, i));
2396+
X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
23952397
}
23962398
}
23972399
}
@@ -2414,7 +2416,8 @@ inline CUDA_CALLABLE void scalar_cholesky_back_substitution(TileL& L, TileX& X)
24142416
for (int j=i+1; j < n; ++j)
24152417
s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
24162418

2417-
X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
2419+
T diag = L.data(tile_coord(i, i));
2420+
X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
24182421
}
24192422
}
24202423
else if constexpr (TileX::Layout::Shape::N == 2)
@@ -2431,7 +2434,8 @@ inline CUDA_CALLABLE void scalar_cholesky_back_substitution(TileL& L, TileX& X)
24312434
for (int j=i+1; j < n; ++j)
24322435
s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j,k));
24332436

2434-
X.data(tile_coord(i,k)) = s / L.data(tile_coord(i, i));
2437+
T diag = L.data(tile_coord(i, i));
2438+
X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
24352439
}
24362440
}
24372441
}
@@ -2713,7 +2717,9 @@ TileZ& tile_lower_solve(TileL& L, TileY& y, TileZ& z)
27132717
// Divide the diagonal element (only one batch)
27142718
if (WP_TILE_THREAD_IDX == 0)
27152719
{
2716-
z.data(tile_coord(i)) /= L.data(tile_coord(i, i));
2720+
T diag = L.data(tile_coord(i, i));
2721+
if (diag != T(0.0f))
2722+
z.data(tile_coord(i)) /= diag;
27172723
}
27182724
WP_TILE_SYNC();
27192725

@@ -2753,7 +2759,9 @@ TileZ& tile_lower_solve(TileL& L, TileY& y, TileZ& z)
27532759
// Divide the diagonal element for all batches in parallel
27542760
for (int batchId = WP_TILE_THREAD_IDX; batchId < m; batchId += num_threads)
27552761
{
2756-
z.data(tile_coord(i, batchId)) /= L.data(tile_coord(i, i));
2762+
T diag = L.data(tile_coord(i, i));
2763+
if (diag != T(0.0f))
2764+
z.data(tile_coord(i, batchId)) /= diag;
27572765
}
27582766
WP_TILE_SYNC();
27592767

@@ -2851,7 +2859,9 @@ TileX& tile_upper_solve(TileU& U, TileZ& z, TileX& x)
28512859
// Divide the diagonal element for all batches in parallel (only one batch)
28522860
if (WP_TILE_THREAD_IDX == 0)
28532861
{
2854-
x.data(tile_coord(i)) /= U.data(tile_coord(i, i));
2862+
T diag = U.data(tile_coord(i, i));
2863+
if (diag != T(0.0f))
2864+
x.data(tile_coord(i)) /= diag;
28552865
}
28562866
WP_TILE_SYNC();
28572867

@@ -2891,7 +2901,9 @@ TileX& tile_upper_solve(TileU& U, TileZ& z, TileX& x)
28912901
// Divide the diagonal element for all batches in parallel
28922902
for (int batchId = WP_TILE_THREAD_IDX; batchId < m; batchId += num_threads)
28932903
{
2894-
x.data(tile_coord(i, batchId)) /= U.data(tile_coord(i, i));
2904+
T diag = U.data(tile_coord(i, i));
2905+
if (diag != T(0.0f))
2906+
x.data(tile_coord(i, batchId)) /= diag;
28952907
}
28962908
WP_TILE_SYNC();
28972909

0 commit comments

Comments
 (0)