Skip to content

Commit 7b97ea8

Browse files
committed
to clean up, version with hit miss particles
1 parent 92f2d1c commit 7b97ea8

File tree

1 file changed

+281
-15
lines changed

1 file changed

+281
-15
lines changed

xcoll/scattering_routines/fluka/track.py

Lines changed: 281 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def _drift(coll, particles, length):
2121
def track(coll, particles):
2222
import xcoll as xc
2323

24+
2425
# Initialize ionisation loss accumulation variable
2526
if coll._acc_ionisation_loss < 0:
2627
coll._acc_ionisation_loss = 0.
@@ -52,30 +53,214 @@ def track(coll, particles):
5253
+ "with a value large enough to accommodate secondaries outside of "
5354
+ "FLUKA.\nIn any case, please stop and restart the FlukaEngine now.")
5455

55-
_drift(coll, particles, -coll.length_front)
56+
import time
57+
58+
import xpart as xp
59+
60+
start = time.time()
61+
62+
coll_everest = xc.EverestCollimator(length=coll.length, material=xc.materials.MolybdenumGraphite)
63+
coll_everest.jaw = coll.jaw
64+
65+
part_everest = particles.copy()
66+
coll_everest.track(part_everest)
67+
68+
mask_hit_base = (
69+
(abs(particles.px - part_everest.px) > 1e-12) |
70+
(abs(particles.py - part_everest.py) > 1e-12) |
71+
(particles.pdg_id == -999999999)
72+
)
73+
74+
end = time.time()
75+
print(f"Tracking Everest collimator took {end - start:.4f} seconds")
76+
77+
78+
start = time.time()
79+
mask_lost = (particles.state != 1) & (particles.pdg_id != -999999999)
80+
81+
# Final hit mask includes both
82+
mask_hit = mask_hit_base # | mask_lost
83+
84+
85+
# mask_hit = ((abs(particles.px - part_everest.px) > 1e-12) | (abs(particles.py - part_everest.py) >1e-12) | (particles.pdg_id == -999999999))
86+
#
87+
# mask_lost = ((particles.state != 1) & (particles.pdg_id != -999999999))
88+
mask_miss_no_lost = ~((abs(particles.px - part_everest.px) > 1e-12) | (abs(particles.py - part_everest.py) > 1e-12)) & (particles.pdg_id != -999999999)
89+
mask_miss = mask_miss_no_lost | mask_lost
90+
91+
92+
hit_id = particles.particle_id[mask_hit]
93+
miss_id = particles.particle_id[mask_miss]
94+
is_hit = np.isin(particles.particle_id, hit_id)
95+
is_miss = np.isin(particles.particle_id, miss_id)
96+
97+
if mask_lost.any():
98+
particles_lost = particles.filter(mask_lost)
99+
else:
100+
particles_lost = None
101+
102+
particles_hit = 0
103+
particles_miss = 0
104+
particles_hit = particles.filter(is_hit)
105+
end = time.time()
106+
print(f"Tracking Everest collimator took {end - start:.4f} seconds")
107+
108+
print(f"Tracking {particles._num_active_particles} particles (FLUKA)... ", end='')
109+
print(f"Hit: {particles_hit._num_active_particles}")
110+
print(f"Collimator: {coll.name}")
111+
112+
113+
# if ((particles_hit._num_active_particles == 0) & (particles_miss == 0)):
114+
# # track_core(coll, particles)
115+
# _drift(coll, particles, coll.length)
116+
# return
117+
118+
119+
# check if ~is_hit has any true value
120+
if np.any(~is_hit):
121+
particles_miss = particles.filter(~is_hit)
122+
elif (particles_hit._num_active_particles == 0) and (particles._num_active_particles > 0):
123+
# import pdb; pdb.set_trace()
124+
print("Could not find any particles that hit the collimator, but some particles are still active. ")
125+
print("This is likely due to a bug in the collimator tracking code. Please report this issue.")
126+
127+
# particles_miss = particles.filter(~is_hit) if np.any(is_miss) else 0
128+
# particles_miss = particles.filter(is_miss) if np.any(is_miss) else 0
129+
# particles_miss = particles_miss.filter(particles_miss.state == 1) if np.any(is_miss) else 0
130+
131+
# missing_capacity = 0
132+
# if particles_miss:
133+
# miss_part = particles_miss._num_active_particles
134+
# more_capacity = particles.filter([False] * (particles._capacity - miss_part) + [True] * miss_part)
135+
start = time.time()
136+
137+
138+
if particles_miss:
139+
_drift(coll, particles_miss, coll.length)
140+
if particles_hit._num_active_particles > 0:
141+
_drift(coll, particles_hit, -coll.length_front)
56142
if coll.co is not None: # FLUKA collimators are centered; need to shift
57143
dx = coll.co[1][0]
58144
dy = coll.co[1][1]
59-
particles.x -= dx
60-
particles.y -= dy
61-
track_core(coll, particles)
62-
if coll.co is not None:
63-
particles.x += dx
64-
particles.y += dy
65-
_drift(coll, particles, -coll.length_back)
66-
145+
particles_hit.x -= dx
146+
particles_hit.y -= dy
147+
148+
if particles_hit._num_active_particles > 0:
149+
track_core(coll, particles_hit)
150+
else:
151+
#_drift(coll, particles_hit, coll.length_front+coll.length_back+coll.length)
152+
_drift(coll, particles, coll.length)
153+
end = time.time()
154+
print(f"No particles hit the collimator, skipping FLUKA tracking.")
155+
print(f"Time taken: {end - start:.4f} seconds")
156+
return
67157

68-
def _expand(arr, dtype=float):
158+
if coll.co is not None:
159+
particles_hit.x += dx
160+
particles_hit.y += dy
161+
162+
_drift(coll, particles_hit, -coll.length_back)
163+
164+
end = time.time()
165+
print(f"Tracking FLUKA collimator took {end - start:.4f} seconds")
166+
print("")
167+
168+
169+
# new_particles=particles_hit
170+
# _drift(coll, particles, -coll.length_front)
171+
# if coll.co is not None: # FLUKA collimators are centered; need to shift
172+
# dx = coll.co[1][0]
173+
# dy = coll.co[1][1]
174+
# particles.x -= dx
175+
# particles.y -= dy
176+
# track_core(coll, particles)
177+
# if coll.co is not None:
178+
# particles.x += dx
179+
# particles.y += dy
180+
# _drift(coll, particles, -coll.length_back)
181+
182+
183+
particles_to_merge = []
184+
185+
# particles_to_merge.append(particles_hit)
186+
# if particles_miss:
187+
# particles_to_merge.append(particles_miss)
188+
# new_particles._capacity = particles._capacity
189+
190+
# new_particles.add_particles(particles_hit)
191+
# if particles_lost:
192+
# new_particles.add_particles(particles_lost)
193+
# if particles_miss:
194+
# particles_hit.add_particles(particles_miss)
195+
# # particles_lost.add_particles(particles_miss)
196+
# # particles_lost.add_particles(particles_hit)
197+
# particles_merged = xp.Particles.merge([particles_hit, more_capacity])
198+
# else:
199+
# # particles_lost.add_particles(particles_hit)
200+
# particles_merged = particles_hit
201+
202+
203+
start = time.time()
204+
205+
if particles_miss:
206+
total_new = particles_hit._capacity + particles_miss._capacity
207+
else:
208+
total_new = particles_hit._capacity
209+
210+
# for field in particles_merged._fields:
211+
# if field == "_capacity": continue
212+
# val_new = getattr(particles_merged, field)
213+
# setattr(particles, field, val_new.copy())
214+
215+
for field in particles._fields:
216+
# check if field is array, if not continue
217+
if field == "_capacity": continue
218+
if field == "_num_active_particles":
219+
setattr(particles, field, particles_hit._num_active_particles + particles_miss._num_active_particles if particles_miss else particles_hit._num_active_particles)
220+
221+
if not isinstance(getattr(particles, field), np.ndarray):
222+
continue
223+
224+
if hasattr(particles_hit, field):
225+
f_hit = getattr(particles_hit, field)
226+
if particles_miss:
227+
f_miss = getattr(particles_miss, field)
228+
combined = np.concatenate([f_hit, f_miss])
229+
else:
230+
combined = f_hit
231+
setattr(particles, field, combined)
232+
# getattr(particles, field)[:total_new] = combined
233+
234+
end = time.time()
235+
236+
237+
particles.reorganize()
238+
xc.fluka.engine._max_particle_id = particles.particle_id.max()
239+
print(f"Copying fields took {end - start:.4f} seconds")
240+
241+
242+
# def _expand(arr, dtype=float):
243+
# import xcoll as xc
244+
# max_part = xc.fluka.engine.capacity
245+
# return np.concatenate((arr, np.zeros(max_part-arr.size, dtype=dtype)))
246+
247+
def _expand(arr, dtype=np.float64):
69248
import xcoll as xc
70249
max_part = xc.fluka.engine.capacity
71-
return np.concatenate((arr, np.zeros(max_part-arr.size, dtype=dtype)))
250+
arr = np.asarray(arr, dtype=dtype, order='F') # Fortran order
251+
252+
if arr.size == max_part:
253+
return arr
254+
expanded = np.zeros(max_part, dtype=dtype, order='F')
255+
expanded[:arr.size] = arr
256+
return expanded
72257

73258

74259
def track_core(coll, part):
75260
import xcoll as xc
76261
npart = part._num_active_particles
77262
try:
78-
from pyflukaf import track_fluka
263+
from pyflukaf import track_fluka, track_fluka_batch
79264
except (ModuleNotFoundError, ImportError) as error:
80265
xc.fluka.engine._warn_pyfluka(error)
81266
return
@@ -136,7 +321,41 @@ def track_core(coll, part):
136321
turn_in = part.at_turn[0]
137322
start = part.start_tracking_at_element # TODO: is this needed?
138323

324+
# def combine_data_to_2d(data):
325+
# keys = ['x', 'xp', 'y', 'yp', 'zeta', 'e', 'm', 'q', 'A', 'Z', 'pdg_id',
326+
# 'pid', 'ppid', 'weight', 'spin_x', 'spin_y', 'spin_z']
327+
328+
# arrays = []
329+
# for k in keys:
330+
# arr = data[k]
331+
# # Ensure integer arrays are cast to float64 (fortran compatibility)
332+
# if np.issubdtype(arr.dtype, np.integer):
333+
# arr = arr.astype(np.float64)
334+
# arrays.append(arr)
335+
336+
# combined = np.vstack(arrays).T # shape (npart, nfields)
337+
# # Make sure it's Fortran contiguous for f2py
338+
# combined = np.asfortranarray(combined)
339+
# return combined
340+
341+
# combined = combine_data_to_2d(data)
342+
343+
# import time
344+
# start = time.time()
345+
# track_fluka_batch(turn=turn_in+1,
346+
# fluka_id=coll.fluka_id,
347+
# length=coll.length + coll.length_front + coll.length_back,
348+
# alive_part=npart,
349+
# max_part=max_part,
350+
# particle_data=combined)
351+
# end = time.time()
352+
# print(f"2D combined version: {end - start:.4f} seconds")
353+
139354
# send to fluka
355+
import time
356+
# for key in ['x', 'xp', 'y', 'yp', 'zeta', 'e', 'm', 'q', 'A', 'Z', 'pdg_id', 'pid', 'ppid', 'weight', 'spin_x', 'spin_y', 'spin_z']:
357+
# data[key] = np.asfortranarray(data[key])
358+
140359
track_fluka(turn=turn_in+1, # Turn indexing start from 1 with FLUKA IO (start from 0 with xpart)
141360
fluka_id=coll.fluka_id,
142361
length=coll.length + coll.length_front + coll.length_back,
@@ -161,6 +380,8 @@ def track_core(coll, part):
161380
spin_z_part=data['spin_z']
162381
)
163382

383+
384+
164385
# Careful with all the masking!
165386
# Double-mask assignment does not work, e.g. part.state[mask1][mask2] = 1 will do nothing...
166387

@@ -187,10 +408,50 @@ def track_core(coll, part):
187408
# ===============================================================================================
188409
mask_existing = new_pid <= max_id
189410

411+
def original_lookup(part_particle_id, new_pid, mask_existing):
412+
idx_old = np.array([np.where(part_particle_id == pid)[0][0]
413+
for pid in new_pid[mask_existing]])
414+
return idx_old
415+
416+
def safe_lookup(part_particle_id, new_pid, mask_existing):
417+
418+
idx_old = []
419+
target_pids = new_pid[mask_existing]
420+
421+
for pid in target_pids:
422+
match = np.where(part_particle_id == pid)[0]
423+
if match.size > 0:
424+
idx_old.append(match[0])
425+
# Optional: log or warn if a match is not found
426+
# else:
427+
# print(f"[warning] pid {pid} not found in part_particle_id")
428+
429+
return np.array(idx_old, dtype=int)
430+
431+
432+
def dict_lookup(part_particle_id, new_pid, mask_existing):
433+
id_to_index = {pid: i for i, pid in enumerate(part_particle_id)}
434+
idx_old = np.array([id_to_index[pid] for pid in new_pid[mask_existing]])
435+
return idx_old
436+
437+
def searchsorted_lookup(part_particle_id, new_pid, mask_existing):
438+
sorted_ids = np.sort(part_particle_id)
439+
idx_old = np.searchsorted(sorted_ids, new_pid[mask_existing])
440+
return idx_old
441+
190442
if np.any(mask_existing):
191443
# TODO: this is slooooow
192-
idx_old = np.array([np.where(part.particle_id[alive_at_entry]==idx)[0][0]
193-
for idx in new_pid[mask_existing]]) # list of indices
444+
import time
445+
# for fn in [original_lookup, dict_lookup, searchsorted_lookup]:
446+
# start = time.time()
447+
# idxs = fn(part.particle_id, new_pid, mask_existing)
448+
# end = time.time()
449+
# print(f"{fn.__name__} took {end - start:.4f} seconds, found {len(idxs)} indices")
450+
451+
# idx_old = np.array([np.where(part.particle_id[alive_at_entry]==idx)[0][0]
452+
# for idx in new_pid[mask_existing]]) # list of indices
453+
idx_old = original_lookup(part.particle_id[alive_at_entry], new_pid, mask_existing)
454+
# idx_old = safe_lookup(part.particle_id[alive_at_entry], new_pid, mask_existing)
194455

195456
# Sanity check
196457
assert np.all(part.particle_id[idx_old] == new_pid[mask_existing])
@@ -201,7 +462,12 @@ def track_core(coll, part):
201462
E_diff = np.zeros(len(part.x))
202463
E_diff[idx_old] = part.energy[idx_old] - data['e'][:npart][mask_existing] * 1.e6
203464
if np.any(E_diff < -precision):
204-
raise ValueError(f"FLUKA returned particle with energy higher than incoming particle!")
465+
# import pdb; pdb.set_trace()
466+
# raise ValueError(f"FLUKA returned particle with energy higher than incoming particle!")
467+
# just a warning for now
468+
print(f"[warning] FLUKA returned particle with energy higher than incoming particle! "
469+
+ f"Energy difference: {E_diff[idx_old][E_diff[idx_old] < -precision]}")
470+
205471
E_diff[E_diff < precision] = 0. # Lower cut on energy loss
206472
part.add_to_energy(-E_diff)
207473
part.weight[idx_old] = data['weight'][:npart][mask_existing]

0 commit comments

Comments
 (0)