@@ -54,6 +54,14 @@ def __init__(
5454 self .sec_amps = np .empty ((0 , self .nsec ))
5555 self .sec_amps_var = np .empty ((0 , self .nsec ))
5656
57+ # Keep some values around for cache lookups
58+ self ._obs_loc = None
59+ self ._T_obs_flat = None
60+ self ._pred_loc_B = None
61+ self ._T_pred_B = None
62+ self ._pred_loc_J = None
63+ self ._T_pred_J = None
64+
5765 @property
5866 def has_df (self ) -> bool :
5967 """Whether this system has any divergence free currents."""
@@ -69,6 +77,58 @@ def nsec(self) -> int:
6977 """The number of elementary currents in this system."""
7078 return len (self .sec_df_loc ) + len (self .sec_cf_loc )
7179
80+ @staticmethod
81+ def _compute_VWU (
82+ T_obs_flat : np .ndarray , std_flat : np .ndarray , epsilon : float , mode : str
83+ ) -> np .ndarray :
84+ """Compute the VWU matrix from the SVD of the transfer function.
85+
86+ This function computes the VWU matrix from the SVD of the transfer function
87+ and filters the singular values based on the specified mode. It is broken out
88+ to allow for easier branching logic in the fit() function.
89+
90+ Parameters
91+ ----------
92+ T_obs_flat : ndarray
93+ The flattened transfer function matrix.
94+ std_flat : ndarray
95+ The flattened standard deviation matrix.
96+ epsilon : float
97+ The threshold for filtering singular values.
98+ mode : str
99+ The mode for filtering singular values.
100+ Options are 'relative' or 'variance'.
101+
102+ Returns
103+ -------
104+ ndarray
105+ The VWU matrix.
106+ """
107+ # Weight the design matrix
108+ weighted_T = T_obs_flat / std_flat [:, np .newaxis ]
109+
110+ # SVD
111+ U , S , Vh = np .linalg .svd (weighted_T , full_matrices = False )
112+
113+ # Filter components
114+ if mode == "relative" :
115+ valid = S >= epsilon * S .max ()
116+ elif mode == "variance" :
117+ energy = np .cumsum (S ** 2 )
118+ total = energy [- 1 ]
119+ threshold = np .searchsorted (energy / total , 1 - epsilon ) + 1
120+ valid = np .arange (len (S )) < threshold
121+ else :
122+ raise ValueError (f"Unknown SVD filtering mode: '{ mode } '" )
123+
124+ # Truncate and build VWU
125+ U = U [:, valid ]
126+ S = S [valid ]
127+ Vh = Vh [valid , :]
128+ W = 1.0 / S
129+
130+ return Vh .T @ (np .diag (W ) @ U .T )
131+
72132 def fit (
73133 self ,
74134 obs_loc : np .ndarray ,
@@ -123,65 +183,50 @@ def fit(
123183
124184 # Assume unit standard error of all measurements
125185 if obs_std is None :
126- obs_std = np .ones (obs_B . shape )
186+ obs_std = np .ones_like (obs_B )
127187
128188 ntimes = len (obs_B )
189+ # Flatten the components to do the math with shape (ntimes, nvariables)
190+ obs_B_flat = obs_B .reshape (ntimes , - 1 )
191+ obs_std_flat = obs_std .reshape (ntimes , - 1 )
129192
130- # Calculate the transfer functions
131- T_obs = self ._calc_T (obs_loc )
193+ # Calculate the transfer functions, using cached values if possible
194+ if not np .array_equal (obs_loc , self ._obs_loc ):
195+ self ._T_obs_flat = self ._calc_T (obs_loc ).reshape (- 1 , self .nsec )
196+ self ._obs_loc = obs_loc
132197
133198 # Store the fit sec_amps in the object
134199 self .sec_amps = np .empty ((ntimes , self .nsec ))
135200 self .sec_amps_var = np .empty ((ntimes , self .nsec ))
136201
137- # Calculate the singular value decomposition (SVD)
138- # NOTE: T_obs has shape (nobs, 3, nsec), we reshape it
139- # to (nobs*3, nsec); obs_std has shape (ntimes, nobs, 3),
140- # we reshape it to (ntimes, nobs*3), then loop over ntimes
141- # to solve using (potentially) time-dependent observation
142- # standard errors to weight the observations
143- for i in range (ntimes ):
144- # Only (re-)calculate SVD when necessary
145- if i == 0 or not np .all (obs_std [i ] == obs_std [i - 1 ]):
146- # Weight T_obs with obs_std
147- svd_in = (
148- T_obs .reshape (- 1 , self .nsec ) / obs_std [i ].ravel ()[:, np .newaxis ]
149- )
150-
151- # Find singular value decompostion
152- U , S , Vh = np .linalg .svd (svd_in , full_matrices = False )
153-
154- if mode == "relative" :
155- valid = S >= epsilon * S .max ()
156- elif mode == "variance" :
157- cumulative_energy = np .cumsum (S ** 2 )
158- total_energy = cumulative_energy [- 1 ]
159- energy_ratio = cumulative_energy / total_energy
160- n_components = np .searchsorted (energy_ratio , 1 - epsilon ) + 1
161- valid = np .arange (len (S )) < n_components
162- else :
163- raise ValueError (f"Unknown SVD filtering mode: '{ mode } '" )
164-
165- # Apply truncation
166- U = U [:, valid ]
167- S = S [valid ]
168- Vh = Vh [valid , :]
169-
170- # Compute VWU
171- W = 1.0 / S
172- VWU = Vh .T @ (np .diag (W ) @ U .T )
173-
174- # Solve for SEC amplitudes and error variances
175- # shape: (ntimes, nsec)
176- self .sec_amps [i , :] = (VWU @ (obs_B [i ] / obs_std [i ]).reshape (- 1 ).T ).T
177-
178- # Maybe we want the variance of the predictions sometime later...?
179- # shape: (ntimes, nsec)
180- valid = np .isfinite (obs_std [i ].reshape (- 1 ))
181- self .sec_amps_var [i , :] = np .sum (
182- (VWU [:, valid ] * obs_std [i ].reshape (- 1 )[valid ]) ** 2 , axis = 1
183- )
202+ if np .allclose (obs_std_flat , obs_std_flat [0 ]):
203+ # The SVD is the same for all time steps, so we can calculate it once
204+ # and broadcast it to all time steps avoiding the for-loop below
205+ VWU = self ._compute_VWU (self ._T_obs_flat , obs_std_flat [0 ], epsilon , mode )
206+ self .sec_amps [:] = (obs_B_flat / obs_std_flat ) @ VWU .T
184207
208+ valid = np .isfinite (obs_std_flat [0 ])
209+ VWU_masked = VWU [:, valid ]
210+ std_masked = obs_std_flat [0 , valid ]
211+ self .sec_amps_var [:] = np .sum ((VWU_masked * std_masked ) ** 2 , axis = 1 )
212+ else :
213+ prev_std = None
214+ VWU = None
215+ for i in range (ntimes ):
216+ if prev_std is None or not np .allclose (
217+ obs_std_flat [i ], prev_std , atol = 1e-12 , rtol = 1e-12
218+ ):
219+ VWU = self ._compute_VWU (
220+ self ._T_obs_flat , obs_std_flat [i ], epsilon , mode
221+ )
222+ prev_std = obs_std_flat [i ]
223+
224+ self .sec_amps [i ] = VWU @ (obs_B_flat [i ] / obs_std_flat [i ])
225+
226+ valid = np .isfinite (obs_std_flat [i ])
227+ VWU_masked = VWU [:, valid ]
228+ std_masked = obs_std_flat [i , valid ]
229+ self .sec_amps_var [i ] = np .sum ((VWU_masked * std_masked ) ** 2 , axis = 1 )
185230 return self
186231
187232 def fit_unit_currents (self ) -> "SECS" :
@@ -225,16 +270,16 @@ def predict(self, pred_loc: np.ndarray, J: bool = False) -> np.ndarray:
225270 # sec_amps shape: (ntimes, nsec)
226271 if J :
227272 # Predicting currents
228- T_pred = self ._calc_J (pred_loc )
273+ if not np .array_equal (pred_loc , self ._pred_loc_J ):
274+ self ._T_pred_J = self ._calc_J (pred_loc )
275+ self ._pred_loc_J = pred_loc
276+ T_pred = self ._T_pred_J
229277 else :
230278 # Predicting magnetic fields
231- T_pred = self ._calc_T (pred_loc )
232-
233- # NOTE: dot product is slow on multi-dimensional arrays (i.e. > 2 dimensions)
234- # Therefore this is implemented as tensordot, and the arguments are
235- # arranged to eliminate needs of transposing things later.
236- # The dot product is done over the SEC locations, so the final output
237- # is of shape: (ntimes, npred, 3)
279+ if not np .array_equal (pred_loc , self ._pred_loc_B ):
280+ self ._T_pred_B = self ._calc_T (pred_loc )
281+ self ._pred_loc_B = pred_loc
282+ T_pred = self ._T_pred_B
238283
239284 return np .squeeze (np .tensordot (self .sec_amps , T_pred , (1 , 2 )))
240285
0 commit comments