Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions xtrack/loss_location_refinement/loss_location_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,37 @@ def check_for_active_shifts_and_rotations(line, i_aper_0, i_aper_1):
break
return presence_shifts_rotations

def fields_equal(a, b, atol=1e-15):
# Check if exactly the same object
if a is b:
return True

# Check for type mismatch
if type(a) is not type(b):
return False

# Numpy array checks
if isinstance(a, np.ndarray):
if a.shape != b.shape:
return False
return np.allclose(a, b, rtol=0, atol=atol)

# Scalar check
if np.isscalar(a):
return abs(a - b) <= atol

# List/tuple check
if isinstance(a, (list, tuple)):
if len(a) != len(b):
return False
try:
return np.allclose(a, b, rtol=0, atol=atol)
except Exception:
return all(fields_equal(x, y, atol) for x, y in zip(a, b))

# Fallback check
return a == b


def apertures_are_identical(aper1, aper2, line):

Expand All @@ -231,9 +262,7 @@ def apertures_are_identical(aper1, aper2, line):

identical = True
for ff in aper1._fields:
tt = np.allclose(getattr(aper1, ff), getattr(aper2, ff),
rtol=0, atol=1e-15)
if not tt:
if not fields_equal(getattr(aper1, ff), getattr(aper2, ff)):
identical = False
break
return identical
Expand Down