@@ -21,6 +21,7 @@ def _drift(coll, particles, length):
 def track(coll, particles):
     import xcoll as xc
 
+
     # Initialize ionisation loss accumulation variable
     if coll._acc_ionisation_loss < 0:
         coll._acc_ionisation_loss = 0.
@@ -52,30 +53,214 @@ def track(coll, particles):
5253 + "with a value large enough to accommodate secondaries outside of "
5354 + "FLUKA.\n In any case, please stop and restart the FlukaEngine now." )
5455
-    _drift(coll, particles, -coll.length_front)
+    import time
+
+    import xpart as xp
+
+    start = time.time()
+
+    coll_everest = xc.EverestCollimator(length=coll.length, material=xc.materials.MolybdenumGraphite)
+    coll_everest.jaw = coll.jaw
+
+    part_everest = particles.copy()
+    coll_everest.track(part_everest)
+
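+    # Pre-selection via an Everest surrogate: a particle is assumed to have
+    # hit the jaw if the surrogate tracking changed its transverse momentum
+    # by more than the 1e-12 tolerance below, or if it carries the sentinel
+    # pdg_id -999999999 used for secondaries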
+    mask_hit_base = (
+        (abs(particles.px - part_everest.px) > 1e-12) |
+        (abs(particles.py - part_everest.py) > 1e-12) |
+        (particles.pdg_id == -999999999)
+    )
+
+    end = time.time()
+    print(f"Tracking Everest collimator took {end - start:.4f} seconds")
+
+    start = time.time()
+    mask_lost = (particles.state != 1) & (particles.pdg_id != -999999999)
+
+    # Final hit mask (the mask_lost contribution is currently disabled)
+    mask_hit = mask_hit_base  # | mask_lost
+
+    mask_miss_no_lost = (~((abs(particles.px - part_everest.px) > 1e-12)
+                           | (abs(particles.py - part_everest.py) > 1e-12))
+                         & (particles.pdg_id != -999999999))
+    mask_miss = mask_miss_no_lost | mask_lost
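+    # By De Morgan's laws mask_miss_no_lost is exactly ~mask_hit_base, so a
+    # particle counts as a miss iff it did not scatter in the surrogate, or
+    # was already lost on entry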
+
+    hit_id = particles.particle_id[mask_hit]
+    miss_id = particles.particle_id[mask_miss]
+    is_hit = np.isin(particles.particle_id, hit_id)
+    is_miss = np.isin(particles.particle_id, miss_id)
+
+    if mask_lost.any():
+        particles_lost = particles.filter(mask_lost)
+    else:
+        particles_lost = None
+
+    particles_miss = 0  # sentinel, replaced by a Particles object below if any particle misses
+    particles_hit = particles.filter(is_hit)
+    end = time.time()
+    print(f"Splitting particles into hit/miss took {end - start:.4f} seconds")
+
+    print(f"Tracking {particles._num_active_particles} particles (FLUKA)... ", end='')
+    print(f"Hit: {particles_hit._num_active_particles}")
+    print(f"Collimator: {coll.name}")
+
+    # if ((particles_hit._num_active_particles == 0) & (particles_miss == 0)):
+    #     # track_core(coll, particles)
+    #     _drift(coll, particles, coll.length)
+    #     return
+
+    # If any particle missed the collimator, collect it into particles_miss
+    if np.any(~is_hit):
+        particles_miss = particles.filter(~is_hit)
+    elif (particles_hit._num_active_particles == 0) and (particles._num_active_particles > 0):
+        print("Could not find any particles that hit the collimator, but some particles are still active.")
+        print("This is likely due to a bug in the collimator tracking code. Please report this issue.")
+
+    # particles_miss = particles.filter(~is_hit) if np.any(is_miss) else 0
+    # particles_miss = particles.filter(is_miss) if np.any(is_miss) else 0
+    # particles_miss = particles_miss.filter(particles_miss.state == 1) if np.any(is_miss) else 0
+
+    # missing_capacity = 0
+    # if particles_miss:
+    #     miss_part = particles_miss._num_active_particles
+    #     more_capacity = particles.filter([False] * (particles._capacity - miss_part) + [True] * miss_part)
+    start = time.time()
+
+    if particles_miss:
+        _drift(coll, particles_miss, coll.length)
+    if particles_hit._num_active_particles > 0:
+        _drift(coll, particles_hit, -coll.length_front)
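+    # The backwards drift compensates for FLUKA tracking the full insertion
+    # (length_front + length + length_back) in track_core below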
     if coll.co is not None:  # FLUKA collimators are centered; need to shift
         dx = coll.co[1][0]
         dy = coll.co[1][1]
-        particles.x -= dx
-        particles.y -= dy
-    track_core(coll, particles)
-    if coll.co is not None:
-        particles.x += dx
-        particles.y += dy
-    _drift(coll, particles, -coll.length_back)
-
+        particles_hit.x -= dx
+        particles_hit.y -= dy
+
+    if particles_hit._num_active_particles > 0:
+        track_core(coll, particles_hit)
+    else:
+        # _drift(coll, particles_hit, coll.length_front+coll.length_back+coll.length)
+        _drift(coll, particles, coll.length)
+        end = time.time()
+        print("No particles hit the collimator, skipping FLUKA tracking.")
+        print(f"Time taken: {end - start:.4f} seconds")
+        return
 
-def _expand(arr, dtype=float):
+    if coll.co is not None:
+        particles_hit.x += dx
+        particles_hit.y += dy
+
+    _drift(coll, particles_hit, -coll.length_back)
+
+    end = time.time()
+    print(f"Tracking FLUKA collimator took {end - start:.4f} seconds")
+    print("")
+
+
+    particles_to_merge = []  # currently unused; left over from the merge experiments below
+
+    # particles_to_merge.append(particles_hit)
+    # if particles_miss:
+    #     particles_to_merge.append(particles_miss)
+    # new_particles._capacity = particles._capacity
+
+    # new_particles.add_particles(particles_hit)
+    # if particles_lost:
+    #     new_particles.add_particles(particles_lost)
+    # if particles_miss:
+    #     particles_hit.add_particles(particles_miss)
+    #     # particles_lost.add_particles(particles_miss)
+    #     # particles_lost.add_particles(particles_hit)
+    #     particles_merged = xp.Particles.merge([particles_hit, more_capacity])
+    # else:
+    #     # particles_lost.add_particles(particles_hit)
+    #     particles_merged = particles_hit
+
+    start = time.time()
+
+    # total_new is only used by the commented-out slice assignment further down
+    if particles_miss:
+        total_new = particles_hit._capacity + particles_miss._capacity
+    else:
+        total_new = particles_hit._capacity
+
+    # for field in particles_merged._fields:
+    #     if field == "_capacity": continue
+    #     val_new = getattr(particles_merged, field)
+    #     setattr(particles, field, val_new.copy())
+
+    for field in particles._fields:
+        if field == "_capacity":
+            continue
+        if field == "_num_active_particles":
+            n_active = particles_hit._num_active_particles
+            if particles_miss:
+                n_active += particles_miss._num_active_particles
+            setattr(particles, field, n_active)
+
+        # Only per-particle array fields are copied
+        if not isinstance(getattr(particles, field), np.ndarray):
+            continue
+
+        if hasattr(particles_hit, field):
+            f_hit = getattr(particles_hit, field)
+            if particles_miss:
+                f_miss = getattr(particles_miss, field)
+                combined = np.concatenate([f_hit, f_miss])
+            else:
+                combined = f_hit
+            setattr(particles, field, combined)
+            # getattr(particles, field)[:total_new] = combined
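+    # The concatenation above assumes that the hit and miss filters are copies
+    # whose per-field arrays together add up to the original capacity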
+
+    end = time.time()
+
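+    # Reorganize so the active particles are contiguous again, and bump the
+    # engine's max particle id so ids assigned later don't collide with the
+    # ids FLUKA gave to new secondaries (assumed intent of the line below)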
+    particles.reorganize()
+    xc.fluka.engine._max_particle_id = particles.particle_id.max()
+    print(f"Copying fields took {end - start:.4f} seconds")
+
+
+def _expand(arr, dtype=np.float64):
     import xcoll as xc
     max_part = xc.fluka.engine.capacity
-    return np.concatenate((arr, np.zeros(max_part-arr.size, dtype=dtype)))
+    arr = np.asarray(arr, dtype=dtype, order='F')  # Fortran order
+
+    if arr.size == max_part:
+        return arr
+    expanded = np.zeros(max_part, dtype=dtype, order='F')
+    expanded[:arr.size] = arr
+    return expanded
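+# Example with hypothetical numbers: if the engine capacity were 8,
+#   _expand(np.array([1., 2., 3.]))
+# would return array([1., 2., 3., 0., 0., 0., 0., 0.])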
 
 
 def track_core(coll, part):
     import xcoll as xc
     npart = part._num_active_particles
     try:
-        from pyflukaf import track_fluka
+        from pyflukaf import track_fluka, track_fluka_batch
     except (ModuleNotFoundError, ImportError) as error:
         xc.fluka.engine._warn_pyfluka(error)
         return
@@ -136,7 +321,41 @@ def track_core(coll, part):
     turn_in = part.at_turn[0]
     start = part.start_tracking_at_element  # TODO: is this needed?
 
+    # def combine_data_to_2d(data):
+    #     keys = ['x', 'xp', 'y', 'yp', 'zeta', 'e', 'm', 'q', 'A', 'Z', 'pdg_id',
+    #             'pid', 'ppid', 'weight', 'spin_x', 'spin_y', 'spin_z']
+
+    #     arrays = []
+    #     for k in keys:
+    #         arr = data[k]
+    #         # Ensure integer arrays are cast to float64 (Fortran compatibility)
+    #         if np.issubdtype(arr.dtype, np.integer):
+    #             arr = arr.astype(np.float64)
+    #         arrays.append(arr)
+
+    #     combined = np.vstack(arrays).T  # shape (npart, nfields)
+    #     # Make sure it's Fortran contiguous for f2py
+    #     combined = np.asfortranarray(combined)
+    #     return combined
+
+    # combined = combine_data_to_2d(data)
+
+    # import time
+    # start = time.time()
+    # track_fluka_batch(turn=turn_in+1,
+    #                   fluka_id=coll.fluka_id,
+    #                   length=coll.length + coll.length_front + coll.length_back,
+    #                   alive_part=npart,
+    #                   max_part=max_part,
+    #                   particle_data=combined)
+    # end = time.time()
+    # print(f"2D combined version: {end - start:.4f} seconds")
+
     # send to fluka
+    import time
+    # for key in ['x', 'xp', 'y', 'yp', 'zeta', 'e', 'm', 'q', 'A', 'Z', 'pdg_id', 'pid', 'ppid', 'weight', 'spin_x', 'spin_y', 'spin_z']:
+    #     data[key] = np.asfortranarray(data[key])
+
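+    # Note: 1-D contiguous arrays are already Fortran-contiguous, so the
+    # commented asfortranarray pass above would likely be a no-op for these
+    # per-particle fields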
     track_fluka(turn=turn_in+1,  # Turn indexing starts from 1 with FLUKA IO (from 0 with xpart)
                 fluka_id=coll.fluka_id,
                 length=coll.length + coll.length_front + coll.length_back,
@@ -161,6 +380,8 @@ def track_core(coll, part):
                 spin_z_part=data['spin_z']
                 )
 
+
     # Careful with all the masking!
     # Double-mask assignment does not work, e.g. part.state[mask1][mask2] = 1 will do nothing...
 
@@ -187,10 +408,50 @@ def track_core(coll, part):
     # ===============================================================================================
     mask_existing = new_pid <= max_id
 
+    def original_lookup(part_particle_id, new_pid, mask_existing):
+        idx_old = np.array([np.where(part_particle_id == pid)[0][0]
+                            for pid in new_pid[mask_existing]])
+        return idx_old
+
+    def safe_lookup(part_particle_id, new_pid, mask_existing):
+        idx_old = []
+        target_pids = new_pid[mask_existing]
+
+        for pid in target_pids:
+            match = np.where(part_particle_id == pid)[0]
+            if match.size > 0:
+                idx_old.append(match[0])
+            # Optional: log or warn if a match is not found
+            # else:
+            #     print(f"[warning] pid {pid} not found in part_particle_id")
+
+        return np.array(idx_old, dtype=int)
+
+
+    def dict_lookup(part_particle_id, new_pid, mask_existing):
+        id_to_index = {pid: i for i, pid in enumerate(part_particle_id)}
+        idx_old = np.array([id_to_index[pid] for pid in new_pid[mask_existing]])
+        return idx_old
+
+    def searchsorted_lookup(part_particle_id, new_pid, mask_existing):
+        # Caution: this returns positions in the *sorted* copy, not indices
+        # into part_particle_id, so it is only correct if the ids are sorted
+        sorted_ids = np.sort(part_particle_id)
+        idx_old = np.searchsorted(sorted_ids, new_pid[mask_existing])
+        return idx_old
+
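+    # A vectorized alternative that does return indices into the original
+    # array (a sketch; assumes every looked-up pid is present and unique):
+    def argsort_lookup(part_particle_id, new_pid, mask_existing):
+        order = np.argsort(part_particle_id)
+        pos = np.searchsorted(part_particle_id[order], new_pid[mask_existing])
+        return order[pos]
+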
     if np.any(mask_existing):
         # TODO: this is slooooow
-        idx_old = np.array([np.where(part.particle_id[alive_at_entry]==idx)[0][0]
-                            for idx in new_pid[mask_existing]])  # list of indices
+        import time
+        # for fn in [original_lookup, dict_lookup, searchsorted_lookup]:
+        #     start = time.time()
+        #     idxs = fn(part.particle_id, new_pid, mask_existing)
+        #     end = time.time()
+        #     print(f"{fn.__name__} took {end - start:.4f} seconds, found {len(idxs)} indices")
+
+        idx_old = original_lookup(part.particle_id[alive_at_entry], new_pid, mask_existing)
+        # idx_old = safe_lookup(part.particle_id[alive_at_entry], new_pid, mask_existing)
 
         # Sanity check
         assert np.all(part.particle_id[idx_old] == new_pid[mask_existing])
@@ -201,7 +462,12 @@ def track_core(coll, part):
         E_diff = np.zeros(len(part.x))
         E_diff[idx_old] = part.energy[idx_old] - data['e'][:npart][mask_existing] * 1.e6
         if np.any(E_diff < -precision):
-            raise ValueError("FLUKA returned particle with energy higher than incoming particle!")
+            # Downgraded from an exception to a warning for now
+            print(f"[warning] FLUKA returned particle with energy higher than incoming particle! "
+                  + f"Energy difference: {E_diff[idx_old][E_diff[idx_old] < -precision]}")
+
         E_diff[E_diff < precision] = 0.  # Lower cut on energy loss
         part.add_to_energy(-E_diff)
         part.weight[idx_old] = data['weight'][:npart][mask_existing]