Skip to content

Commit 38c34d6

Browse files
committed
Add more general function to test if too fields are equal
1 parent 3f77462 commit 38c34d6

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

xtrack/loss_location_refinement/loss_location_refinement.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,37 @@ def check_for_active_shifts_and_rotations(line, i_aper_0, i_aper_1):
217217
break
218218
return presence_shifts_rotations
219219

220+
def fields_equal(a, b, atol=1e-15):
221+
# Check if exactly the same object
222+
if a is b:
223+
return True
224+
225+
# Check for type mismatch
226+
if type(a) is not type(b):
227+
return False
228+
229+
# Numpy array checks
230+
if isinstance(a, np.ndarray):
231+
if a.shape != b.shape:
232+
return False
233+
return np.allclose(a, b, rtol=0, atol=atol)
234+
235+
# Scalar check
236+
if np.isscalar(a):
237+
return abs(a - b) <= atol
238+
239+
# List/tuple check
240+
if isinstance(a, (list, tuple)):
241+
if len(a) != len(b):
242+
return False
243+
try:
244+
return np.allclose(a, b, rtol=0, atol=atol)
245+
except Exception:
246+
return all(fields_equal(x, y, atol) for x, y in zip(a, b))
247+
248+
# Fallback check
249+
return a == b
250+
220251

221252
def apertures_are_identical(aper1, aper2, line):
222253

@@ -231,9 +262,7 @@ def apertures_are_identical(aper1, aper2, line):
231262

232263
identical = True
233264
for ff in aper1._fields:
234-
tt = np.allclose(getattr(aper1, ff), getattr(aper2, ff),
235-
rtol=0, atol=1e-15)
236-
if not tt:
265+
if not fields_equal(getattr(aper1, ff), getattr(aper2, ff)):
237266
identical = False
238267
break
239268
return identical

0 commit comments

Comments
 (0)