Skip to content

Commit 1cb6205

Browse files
authored
Merge pull request #24 from greglucas/optimize-fit-predict
Optimize fit predict
2 parents 4950446 + 172a715 commit 1cb6205

File tree

2 files changed

+153
-58
lines changed

2 files changed

+153
-58
lines changed

pysecs/secs.py

Lines changed: 103 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def __init__(
5454
self.sec_amps = np.empty((0, self.nsec))
5555
self.sec_amps_var = np.empty((0, self.nsec))
5656

57+
# Keep some values around for cache lookups
58+
self._obs_loc = None
59+
self._T_obs_flat = None
60+
self._pred_loc_B = None
61+
self._T_pred_B = None
62+
self._pred_loc_J = None
63+
self._T_pred_J = None
64+
5765
@property
5866
def has_df(self) -> bool:
5967
"""Whether this system has any divergence free currents."""
@@ -69,6 +77,58 @@ def nsec(self) -> int:
6977
"""The number of elementary currents in this system."""
7078
return len(self.sec_df_loc) + len(self.sec_cf_loc)
7179

80+
@staticmethod
def _compute_VWU(
    T_obs_flat: np.ndarray, std_flat: np.ndarray, epsilon: float, mode: str
) -> np.ndarray:
    """Compute the VWU matrix from the SVD of the weighted transfer function.

    The transfer function is weighted row-wise by the observation standard
    deviations, decomposed with an SVD, and the singular values are filtered
    according to ``mode``. The returned matrix is ``V @ diag(1 / S) @ U.T``
    built from the retained components only, i.e. a truncated pseudo-inverse
    of the weighted transfer function.

    Parameters
    ----------
    T_obs_flat : ndarray
        The flattened transfer function matrix, one row per observed
        component and one column per elementary current.
    std_flat : ndarray
        The flattened standard deviations, one entry per row of
        ``T_obs_flat``.
    epsilon : float
        The threshold for filtering singular values.
    mode : str
        The mode for filtering singular values.
        ``'relative'`` keeps singular values that are at least
        ``epsilon * S.max()``. ``'variance'`` keeps the leading components
        whose cumulative squared singular values reach a fraction
        ``1 - epsilon`` of the total.

    Returns
    -------
    ndarray
        The VWU matrix with shape
        ``(T_obs_flat.shape[1], T_obs_flat.shape[0])``.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'relative'`` or ``'variance'``.
    """
    # Reject an unknown mode up front so no SVD is wasted on a bad argument
    if mode not in ("relative", "variance"):
        raise ValueError(f"Unknown SVD filtering mode: '{mode}'")

    # Weight the design matrix
    weighted_T = T_obs_flat / std_flat[:, np.newaxis]

    # SVD; numpy returns the singular values sorted in descending order,
    # which the cumulative-energy search below relies on
    U, S, Vh = np.linalg.svd(weighted_T, full_matrices=False)

    # Filter components
    if mode == "relative":
        valid = S >= epsilon * S.max()
    else:
        energy = np.cumsum(S**2)
        total = energy[-1]
        if total > 0:
            threshold = np.searchsorted(energy / total, 1 - epsilon) + 1
        else:
            # All singular values are zero: nothing carries any energy
            threshold = 0
        valid = np.arange(len(S)) < threshold

    # An exactly-zero singular value cannot be inverted; drop it so that
    # 1 / S stays finite instead of injecting inf/NaN into the result
    valid &= S > 0

    # Truncate and build VWU
    U = U[:, valid]
    S = S[valid]
    Vh = Vh[valid, :]
    W = 1.0 / S

    # Scale the rows of U.T by W via broadcasting; equivalent to
    # np.diag(W) @ U.T without materializing the dense diagonal matrix
    return Vh.T @ (W[:, np.newaxis] * U.T)
131+
72132
def fit(
73133
self,
74134
obs_loc: np.ndarray,
@@ -123,65 +183,50 @@ def fit(
123183

124184
# Assume unit standard error of all measurements
125185
if obs_std is None:
126-
obs_std = np.ones(obs_B.shape)
186+
obs_std = np.ones_like(obs_B)
127187

128188
ntimes = len(obs_B)
189+
# Flatten the components to do the math with shape (ntimes, nvariables)
190+
obs_B_flat = obs_B.reshape(ntimes, -1)
191+
obs_std_flat = obs_std.reshape(ntimes, -1)
129192

130-
# Calculate the transfer functions
131-
T_obs = self._calc_T(obs_loc)
193+
# Calculate the transfer functions, using cached values if possible
194+
if not np.array_equal(obs_loc, self._obs_loc):
195+
self._T_obs_flat = self._calc_T(obs_loc).reshape(-1, self.nsec)
196+
self._obs_loc = obs_loc
132197

133198
# Store the fit sec_amps in the object
134199
self.sec_amps = np.empty((ntimes, self.nsec))
135200
self.sec_amps_var = np.empty((ntimes, self.nsec))
136201

137-
# Calculate the singular value decomposition (SVD)
138-
# NOTE: T_obs has shape (nobs, 3, nsec), we reshape it
139-
# to (nobs*3, nsec); obs_std has shape (ntimes, nobs, 3),
140-
# we reshape it to (ntimes, nobs*3), then loop over ntimes
141-
# to solve using (potentially) time-dependent observation
142-
# standard errors to weight the observations
143-
for i in range(ntimes):
144-
# Only (re-)calculate SVD when necessary
145-
if i == 0 or not np.all(obs_std[i] == obs_std[i - 1]):
146-
# Weight T_obs with obs_std
147-
svd_in = (
148-
T_obs.reshape(-1, self.nsec) / obs_std[i].ravel()[:, np.newaxis]
149-
)
150-
151-
# Find singular value decompostion
152-
U, S, Vh = np.linalg.svd(svd_in, full_matrices=False)
153-
154-
if mode == "relative":
155-
valid = S >= epsilon * S.max()
156-
elif mode == "variance":
157-
cumulative_energy = np.cumsum(S**2)
158-
total_energy = cumulative_energy[-1]
159-
energy_ratio = cumulative_energy / total_energy
160-
n_components = np.searchsorted(energy_ratio, 1 - epsilon) + 1
161-
valid = np.arange(len(S)) < n_components
162-
else:
163-
raise ValueError(f"Unknown SVD filtering mode: '{mode}'")
164-
165-
# Apply truncation
166-
U = U[:, valid]
167-
S = S[valid]
168-
Vh = Vh[valid, :]
169-
170-
# Compute VWU
171-
W = 1.0 / S
172-
VWU = Vh.T @ (np.diag(W) @ U.T)
173-
174-
# Solve for SEC amplitudes and error variances
175-
# shape: (ntimes, nsec)
176-
self.sec_amps[i, :] = (VWU @ (obs_B[i] / obs_std[i]).reshape(-1).T).T
177-
178-
# Maybe we want the variance of the predictions sometime later...?
179-
# shape: (ntimes, nsec)
180-
valid = np.isfinite(obs_std[i].reshape(-1))
181-
self.sec_amps_var[i, :] = np.sum(
182-
(VWU[:, valid] * obs_std[i].reshape(-1)[valid]) ** 2, axis=1
183-
)
202+
if np.allclose(obs_std_flat, obs_std_flat[0]):
203+
# The SVD is the same for all time steps, so we can calculate it once
204+
# and broadcast it to all time steps avoiding the for-loop below
205+
VWU = self._compute_VWU(self._T_obs_flat, obs_std_flat[0], epsilon, mode)
206+
self.sec_amps[:] = (obs_B_flat / obs_std_flat) @ VWU.T
184207

208+
valid = np.isfinite(obs_std_flat[0])
209+
VWU_masked = VWU[:, valid]
210+
std_masked = obs_std_flat[0, valid]
211+
self.sec_amps_var[:] = np.sum((VWU_masked * std_masked) ** 2, axis=1)
212+
else:
213+
prev_std = None
214+
VWU = None
215+
for i in range(ntimes):
216+
if prev_std is None or not np.allclose(
217+
obs_std_flat[i], prev_std, atol=1e-12, rtol=1e-12
218+
):
219+
VWU = self._compute_VWU(
220+
self._T_obs_flat, obs_std_flat[i], epsilon, mode
221+
)
222+
prev_std = obs_std_flat[i]
223+
224+
self.sec_amps[i] = VWU @ (obs_B_flat[i] / obs_std_flat[i])
225+
226+
valid = np.isfinite(obs_std_flat[i])
227+
VWU_masked = VWU[:, valid]
228+
std_masked = obs_std_flat[i, valid]
229+
self.sec_amps_var[i] = np.sum((VWU_masked * std_masked) ** 2, axis=1)
185230
return self
186231

187232
def fit_unit_currents(self) -> "SECS":
@@ -225,16 +270,16 @@ def predict(self, pred_loc: np.ndarray, J: bool = False) -> np.ndarray:
225270
# sec_amps shape: (ntimes, nsec)
226271
if J:
227272
# Predicting currents
228-
T_pred = self._calc_J(pred_loc)
273+
if not np.array_equal(pred_loc, self._pred_loc_J):
274+
self._T_pred_J = self._calc_J(pred_loc)
275+
self._pred_loc_J = pred_loc
276+
T_pred = self._T_pred_J
229277
else:
230278
# Predicting magnetic fields
231-
T_pred = self._calc_T(pred_loc)
232-
233-
# NOTE: dot product is slow on multi-dimensional arrays (i.e. > 2 dimensions)
234-
# Therefore this is implemented as tensordot, and the arguments are
235-
# arranged to eliminate needs of transposing things later.
236-
# The dot product is done over the SEC locations, so the final output
237-
# is of shape: (ntimes, npred, 3)
279+
if not np.array_equal(pred_loc, self._pred_loc_B):
280+
self._T_pred_B = self._calc_T(pred_loc)
281+
self._pred_loc_B = pred_loc
282+
T_pred = self._T_pred_B
238283

239284
return np.squeeze(np.tensordot(self.sec_amps, T_pred, (1, 2)))
240285

tests/test_secs.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,3 +575,53 @@ def test_predictJ_cf_df():
575575

576576
# Use the predict_J function call directly
577577
assert_allclose(secs.predict_J(pred_loc), secs.predict(pred_loc, J=True))
578+
579+
580+
def test_multidim_shapes():
    """Check the array shapes produced by fit and predict over many times."""
    np.random.seed(0)
    nsec, nobs, ntimes, npred = 100, 10, 75, 133

    # The order of the random draws is kept fixed so the inputs are reproducible
    sec_locs = 100 * np.random.rand(nsec, 3)
    obs_locs = 100 * np.random.rand(nobs, 3)
    obs_B = 10000 * np.random.rand(ntimes, nobs, 3)
    pred_locs = 100 * np.random.rand(npred, 3)

    secs = pysecs.SECS(sec_df_loc=sec_locs)
    assert secs.nsec == nsec

    # Fitting stores one row of amplitudes (and variances) per time step
    secs.fit(obs_locs, obs_B=obs_B)
    expected_amp_shape = (ntimes, nsec)
    assert secs.sec_amps.shape == expected_amp_shape
    assert secs.sec_amps_var.shape == expected_amp_shape

    # Predictions carry three components per location and time step
    pred = secs.predict(pred_locs)
    assert pred.shape == (ntimes, npred, 3)
601+
602+
603+
def test_changing_obs_shape():
    """Refit and re-predict with differently shaped inputs.

    If the shape of the observation data changes, the previously cached
    data must not be reused; the results should be recomputed for the
    new shape.
    """
    np.random.seed(0)
    nsec, nobs, ntimes, npred = 100, 10, 75, 133

    # The order of the random draws is kept fixed so the inputs are reproducible
    sec_locs = 100 * np.random.rand(nsec, 3)
    obs_locs = 100 * np.random.rand(nobs, 3)
    obs_B = 10000 * np.random.rand(ntimes, nobs, 3)
    pred_locs = 100 * np.random.rand(npred, 3)

    secs = pysecs.SECS(sec_df_loc=sec_locs)

    secs.fit(obs_locs, obs_B=obs_B)
    # Fit again after dropping the last observation, changing the data shape
    secs.fit(obs_locs[:-1], obs_B=obs_B[:, :-1])
    expected_amp_shape = (ntimes, nsec)
    assert secs.sec_amps.shape == expected_amp_shape
    assert secs.sec_amps_var.shape == expected_amp_shape

    # Predict at the full set of locations, then at one location fewer
    for locs, count in ((pred_locs, npred), (pred_locs[:-1], npred - 1)):
        pred = secs.predict(locs)
        assert pred.shape == (ntimes, count, 3)

0 commit comments

Comments
 (0)