Skip to content

Commit 8912573

Browse files
CopilotSierd
andcommitted
Simplify FFT shear edge case handling
Simplified the masked computation approach to a cleaner implementation: - Use np.where for safe division (replace zeros with 1.0 temporarily) - Compute formulas normally with safe arrays - Apply invalid_mask at the end to zero out problematic regions This achieves the same result with much simpler, more readable code. Co-authored-by: Sierd <[email protected]>
1 parent b38b349 commit 8912573

File tree

1 file changed

+15
-28
lines changed

1 file changed

+15
-28
lines changed

aeolis/shear.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -582,34 +582,21 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
582582
time_start_perturbation = time.time()
583583

584584
# Shear stress perturbation
585-
# Avoid division by zero and invalid values
586-
# When kx=0 or ky=0 or k=0, set perturbations to zero
587-
588-
# Create masks for valid computations
589-
valid_mask = (k > 0) & (np.abs(kx) > 0)
590-
591-
# Initialize perturbation arrays with zeros
592-
dtaux_t = np.zeros_like(hs, dtype=complex)
593-
dtauy_t = np.zeros_like(hs, dtype=complex)
594-
595-
# Only compute where we have valid frequencies
596-
if np.any(valid_mask):
597-
# Safe division for valid regions only
598-
k_valid = k[valid_mask]
599-
kx_valid = kx[valid_mask]
600-
ky_valid = ky[valid_mask]
601-
hs_valid = hs[valid_mask]
602-
sigma_valid = sigma[valid_mask]
603-
ul2_valid = ul**2
604-
605-
# Compute dtaux for valid regions
606-
dtaux_t[valid_mask] = hs_valid * kx_valid**2 / k_valid * 2 / ul2_valid * \
607-
(-1. + (2. * np.log(l/z0new) + k_valid**2/kx_valid**2) * sigma_valid * \
608-
sc_kv(1., 2. * sigma_valid) / sc_kv(0., 2. * sigma_valid))
609-
610-
# Compute dtauy for valid regions
611-
dtauy_t[valid_mask] = hs_valid * kx_valid * ky_valid / k_valid * 2 / ul2_valid * \
612-
2. * np.sqrt(2.) * sigma_valid * sc_kv(1., 2. * np.sqrt(2.) * sigma_valid) / sc_kv(0., 2. * np.sqrt(2.) * sigma_valid)
585+
# Use safe division to avoid zero/invalid values at kx=0 or k=0
586+
k_safe = np.where(k == 0, 1.0, k)
587+
kx_safe = np.where(kx == 0, 1.0, kx)
588+
589+
dtaux_t = hs * kx**2 / k_safe * 2 / ul**2 * \
590+
(-1. + (2. * np.log(l/z0new) + k**2/kx_safe**2) * sigma * \
591+
sc_kv(1., 2. * sigma) / sc_kv(0., 2. * sigma))
592+
593+
dtauy_t = hs * kx * ky / k_safe * 2 / ul**2 * \
594+
2. * np.sqrt(2.) * sigma * sc_kv(1., 2. * np.sqrt(2.) * sigma) / sc_kv(0., 2. * np.sqrt(2.) * sigma)
595+
596+
# Zero out invalid regions (kx=0 or k=0) where formulation is not valid
597+
invalid_mask = (k == 0) | (kx == 0)
598+
dtaux_t[invalid_mask] = 0.
599+
dtauy_t[invalid_mask] = 0.
613600

614601
gc['dtaux'] = np.real(np.fft.ifft2(dtaux_t))
615602
gc['dtauy'] = np.real(np.fft.ifft2(dtauy_t))

0 commit comments

Comments
 (0)