@@ -238,7 +238,7 @@ def balance(SC, response_matrix, reference=None, cm_ords=None, bpm_ords=None, ep
238238 raise RuntimeError ('Balancing two turns: FAIL (maxsteps reached, unstable)' )
239239
240240
241- def correct (SC , response_matrix , reference = None , cm_ords = None , bpm_ords = None , eps = 1e-4 , target = 0 , maxsteps = 30 , scaleDisp = 0 , ** pinv_params ):
241+ def correct (SC , response_matrix , reference = None , cm_ords = None , bpm_ords = None , eps = 1e-4 , target = 0 , maxsteps = 30 , scaleDisp = 0 , Mplus = None , ** pinv_params ):
242242 """
243243 Iterative orbit/trajectory correction
244244
@@ -287,7 +287,19 @@ def correct(SC, response_matrix, reference=None, cm_ords=None, bpm_ords=None, ep
287287 bpm_ords , cm_ords , reference = _check_ords (SC , response_matrix [:, :- 1 ] if scaleDisp else response_matrix ,
288288 reference , bpm_ords , cm_ords )
289289 bpm_readings , transmission_history , rms_orbit_history = _bpm_reading_and_logging (SC , bpm_ords = bpm_ords ) # Inject ...
290- Mplus = sc_tools .pinv (response_matrix , ** pinv_params )
290+
291+ # Initial validation
292+ if Mplus is not None :
293+ expected_shape = (response_matrix .shape [1 ], response_matrix .shape [0 ])
294+ if Mplus .shape != expected_shape :
295+ raise ValueError (f"Invalid Mplus shape. Expected { expected_shape } , got { Mplus .shape } " )
296+
297+ # Pseudoinverse handling
298+ if Mplus is None :
299+ Mplus = sc_tools .pinv (response_matrix , ** pinv_params )
300+ LOGGER .debug (f'Computed pseudoinverse with params: { pinv_params } ' )
301+ else :
302+ LOGGER .debug ('Using user-provided pseudoinverse matrix' )
291303
292304 # Main loop
293305 for steps in range (maxsteps ):
0 commit comments