Skip to content

Commit 7992878

Browse files
authored
Merge pull request #25 from greglucas/optimize-T-df
PERF: Optimize T_df computation
2 parents 1cb6205 + 69315ba commit 7992878

File tree

1 file changed

+68
-52
lines changed

1 file changed

+68
-52
lines changed

pysecs/secs.py

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -386,70 +386,44 @@ def T_df(obs_loc: np.ndarray, sec_loc: np.ndarray) -> np.ndarray:
386386
nobs = len(obs_loc)
387387
nsec = len(sec_loc)
388388

389-
obs_r = obs_loc[:, 2][:, np.newaxis]
390-
sec_r = sec_loc[:, 2][np.newaxis, :]
391-
392389
theta, alpha = _calc_angular_distance_and_bearing(obs_loc[:, :2], sec_loc[:, :2])
393390

394-
# magnetic permeability
395-
mu0 = 4 * np.pi * 1e-7
396-
397-
# simplify calculations by storing this ratio
398-
x = obs_r / sec_r
399-
400391
sin_theta = np.sin(theta)
401392
cos_theta = np.cos(theta)
402-
factor = 1.0 / np.sqrt(1 - 2 * x * cos_theta + x**2)
403393

404-
# Amm & Viljanen: Equation 9
405-
Br = mu0 / (4 * np.pi * obs_r) * (factor - 1)
394+
Br = np.empty((nobs, nsec))
395+
Btheta = np.empty((nobs, nsec))
396+
397+
# Over locations: obs_r > sec_r
398+
over_locs = obs_loc[:, 2][:, np.newaxis] > sec_loc[:, 2][np.newaxis, :]
399+
if np.any(over_locs):
400+
# We use np.where because we are broadcasting 1d arrays
401+
# over_locs is a 2d array of booleans
402+
over_indices = np.where(over_locs)
403+
obs_r = obs_loc[over_indices[0], 2]
404+
sec_r = sec_loc[over_indices[1], 2]
405+
Br[over_locs], Btheta[over_locs] = _calc_T_df_over(
406+
obs_r, sec_r, cos_theta[over_locs]
407+
)
408+
409+
# Under locations: obs_r <= sec_r
410+
under_locs = ~over_locs
411+
if np.any(under_locs):
412+
# We use np.where because we are broadcasting 1d arrays
413+
# over_locs is a 2d array of booleans
414+
under_indices = np.where(under_locs)
415+
obs_r = obs_loc[under_indices[0], 2]
416+
sec_r = sec_loc[under_indices[1], 2]
417+
Br[under_locs], Btheta[under_locs] = _calc_T_df_under(
418+
obs_r, sec_r, cos_theta[under_locs]
419+
)
406420

407-
# Amm & Viljanen: Equation 10 (transformed to try and eliminate trig operations and
408-
# divide by zeros)
409-
Btheta = -mu0 / (4 * np.pi * obs_r) * (factor * (x - cos_theta) + cos_theta)
410421
# If sin(theta) == 0: Btheta = 0
411422
# There is a possible 0/0 in the expansion when sec_loc == obs_loc
412423
Btheta = np.divide(
413424
Btheta, sin_theta, out=np.zeros_like(sin_theta), where=sin_theta != 0
414425
)
415426

416-
# When observation points radii are outside of the sec locations
417-
under_locs = sec_r < obs_r
418-
419-
# NOTE: If any SECs are below observations the math will be done on all points.
420-
# This could be updated to only work on the locations where this condition
421-
# occurs, but would make the code messier, with minimal performance gain
422-
# except for very large matrices.
423-
if np.any(under_locs):
424-
# Flipped from previous case
425-
x = sec_r / obs_r
426-
427-
# Amm & Viljanen: Equation A.7
428-
Br2 = (
429-
mu0
430-
* x
431-
/ (4 * np.pi * obs_r)
432-
* (1.0 / np.sqrt(1 - 2 * x * cos_theta + x**2) - 1)
433-
)
434-
435-
# Amm & Viljanen: Equation A.8
436-
Btheta2 = (
437-
-mu0
438-
/ (4 * np.pi * obs_r)
439-
* (
440-
(obs_r - sec_r * cos_theta)
441-
/ np.sqrt(obs_r**2 - 2 * obs_r * sec_r * cos_theta + sec_r**2)
442-
- 1
443-
)
444-
)
445-
Btheta2 = np.divide(
446-
Btheta2, sin_theta, out=np.zeros_like(sin_theta), where=sin_theta != 0
447-
)
448-
449-
# Update only the locations where secs are under observations
450-
Btheta[under_locs] = Btheta2[under_locs]
451-
Br[under_locs] = Br2[under_locs]
452-
453427
# Transform back to Bx, By, Bz at each local point
454428
T = np.empty((nobs, 3, nsec))
455429
# alpha == angle (from cartesian x-axis (By), going towards y-axis (Bx))
@@ -460,6 +434,48 @@ def T_df(obs_loc: np.ndarray, sec_loc: np.ndarray) -> np.ndarray:
460434
return T
461435

462436

437+
def _calc_T_df_under(
438+
obs_r: np.ndarray, sec_r: np.ndarray, cos_theta: np.ndarray
439+
) -> tuple[np.ndarray, np.ndarray]:
440+
"""T matrix for over locations (obs_r <= sec_r)."""
441+
mu0_over_4pi = 1e-7
442+
x = obs_r / sec_r
443+
factor = 1.0 / np.sqrt(1 - 2 * x * cos_theta + x**2)
444+
445+
# Amm & Viljanen: Equation 9
446+
Br = mu0_over_4pi / obs_r * (factor - 1)
447+
448+
# Amm & Viljanen: Equation 10
449+
Btheta = -mu0_over_4pi / obs_r * (factor * (x - cos_theta) + cos_theta)
450+
451+
return Br, Btheta
452+
453+
454+
def _calc_T_df_over(
455+
obs_r: np.ndarray, sec_r: np.ndarray, cos_theta: np.ndarray
456+
) -> tuple[np.ndarray, np.ndarray]:
457+
"""T matrix for over locations (obs_r > sec_r)."""
458+
mu0_over_4pi = 1e-7
459+
x = sec_r / obs_r
460+
factor = 1.0 / np.sqrt(1 - 2 * x * cos_theta + x**2)
461+
462+
# Amm & Viljanen: Equation A.7
463+
Br = mu0_over_4pi * x / obs_r * (factor - 1)
464+
465+
# Amm & Viljanen: Equation A.8
466+
Btheta = (
467+
-mu0_over_4pi
468+
/ obs_r
469+
* (
470+
(obs_r - sec_r * cos_theta)
471+
/ np.sqrt(obs_r**2 - 2 * obs_r * sec_r * cos_theta + sec_r**2)
472+
- 1
473+
)
474+
)
475+
476+
return Br, Btheta
477+
478+
463479
def T_cf(obs_loc: np.ndarray, sec_loc: np.ndarray) -> np.ndarray:
464480
"""Calculate the curl free magnetic field transfer function.
465481

0 commit comments

Comments
 (0)