11import hashlib
22import random
33from ctypes import c_uint32
4- from typing import Any , Dict , List , Tuple , Union , Optional
4+ from typing import Any , Dict , List , Optional , Tuple , Union
55
66import assist
77import numpy as np
@@ -143,7 +143,6 @@ def __init__(
143143 self .adaptive_mode = adaptive_mode
144144 self .epsilon = epsilon
145145
146-
147146 def __getstate__ (self ) -> Dict [str , Any ]:
148147 state = self .__dict__ .copy ()
149148 state .pop ("_last_simulation" , None )
@@ -159,7 +158,7 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp
159158 # OPTIMIZATION: Fast path for single orbits
160159 if len (orbits ) == 1 :
161160 return self ._propagate_single_orbit_optimized (orbits , times )
162-
161+
163162 # The coordinate frame is the equatorial International Celestial Reference Frame (ICRF).
164163 # This is also the native coordinate system for the JPL binary files.
165164 # For units we use solar masses, astronomical units, and days.
@@ -193,37 +192,43 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp
193192
194193 return results
195194
196- def _propagate_single_orbit_optimized (self , orbit : OrbitType , times : TimestampType ) -> OrbitType :
195+ def _propagate_single_orbit_optimized (
196+ self , orbit : OrbitType , times : TimestampType
197+ ) -> OrbitType :
197198 """
198199 Optimized propagation for a single orbit, bypassing grouping overhead.
199200 """
200201 # Validate assumption
201202 if len (orbit ) != 1 :
202203 raise ValueError (f"Expected exactly 1 orbit, got { len (orbit )} " )
203-
204+
204205 # Transform coordinates directly without grouping
205206 transformed_coords = transform_coordinates (
206207 orbit .coordinates ,
207208 origin_out = OriginCodes .SOLAR_SYSTEM_BARYCENTER ,
208209 frame_out = "equatorial" ,
209210 )
210211 transformed_input_orbit_times = transformed_coords .time .rescale ("tdb" )
211- transformed_coords = transformed_coords .set_column ("time" , transformed_input_orbit_times )
212+ transformed_coords = transformed_coords .set_column (
213+ "time" , transformed_input_orbit_times
214+ )
212215 transformed_orbit = orbit .set_column ("coordinates" , transformed_coords )
213-
216+
214217 return self ._propagate_single_orbit_inner_optimized (transformed_orbit , times )
215218
216- def _propagate_single_orbit_inner_optimized (self , orbit : OrbitType , times : TimestampType ) -> OrbitType :
219+ def _propagate_single_orbit_inner_optimized (
220+ self , orbit : OrbitType , times : TimestampType
221+ ) -> OrbitType :
217222 """
218223 Inner propagation optimized for exactly one orbit.
219224 """
220225 # Setup ephemeris and simulation
221226 ephem = assist .Ephem (planets_path = de440 , asteroids_path = de441_n16 )
222227 sim = rebound .Simulation ()
223-
228+
224229 start_tdb_time = orbit .coordinates .time .jd ().to_numpy ()[0 ]
225230 sim .t = start_tdb_time - ephem .jd_ref
226-
231+
227232 # Handle particle ID creation (optimized for single orbit)
228233 is_variant = isinstance (orbit , VariantOrbits )
229234 if is_variant :
@@ -233,60 +238,60 @@ def _propagate_single_orbit_inner_optimized(self, orbit: OrbitType, times: Times
233238 else :
234239 orbit_id = str (orbit .orbit_id .to_numpy (zero_copy_only = False )[0 ])
235240 particle_hash = uint32_hash (orbit_id )
236-
241+
237242 assist .Extras (sim , ephem )
238-
243+
239244 # Add single particle
240245 coords = orbit .coordinates
241246 position_arrays = coords .r
242247 velocity_arrays = coords .v
243-
248+
244249 sim .add (
245250 x = position_arrays [0 , 0 ],
246- y = position_arrays [0 , 1 ],
251+ y = position_arrays [0 , 1 ],
247252 z = position_arrays [0 , 2 ],
248253 vx = velocity_arrays [0 , 0 ],
249254 vy = velocity_arrays [0 , 1 ],
250255 vz = velocity_arrays [0 , 2 ],
251256 hash = particle_hash ,
252257 )
253-
258+
254259 # Set integrator parameters
255260 sim .dt = self .initial_dt
256261 sim .ri_ias15 .min_dt = self .min_dt
257262 sim .ri_ias15 .adaptive_mode = self .adaptive_mode
258263 sim .ri_ias15 .epsilon = self .epsilon
259-
264+
260265 # Prepare integration times (numpy only)
261266 integrator_times = times .rescale ("tdb" ).jd ().to_numpy ()
262267 integrator_times = integrator_times - ephem .jd_ref
263-
268+
264269 # Integration loop (preallocate state array)
265270 N = len (integrator_times )
266271 if N == 0 :
267272 return VariantOrbits .empty () if is_variant else Orbits .empty ()
268-
273+
269274 xyzvxvyvz = np .zeros ((N , 6 ), dtype = "float64" )
270275 scratch = np .zeros ((1 , 6 ), dtype = "float64" )
271-
276+
272277 for i in range (N ):
273278 sim .integrate (integrator_times [i ])
274279 scratch .fill (0.0 )
275280 sim .serialize_particle_data (xyzvxvyvz = scratch )
276281 xyzvxvyvz [i , :] = scratch [0 , :]
277-
282+
278283 # Build results
279284 jd_times = integrator_times + ephem .jd_ref
280285 times_out = Timestamp .from_jd (jd_times , scale = "tdb" )
281286 origin_codes = Origin .from_kwargs (
282287 code = pa .repeat ("SOLAR_SYSTEM_BARYCENTER" , xyzvxvyvz .shape [0 ])
283288 )
284-
289+
285290 if is_variant :
286291 orbit_ids_out = [orbit_id ] * N
287292 variant_ids_out = [variant_id ] * N
288293 object_id_out = np .tile (orbit .object_id .to_numpy (zero_copy_only = False ), N )
289-
294+
290295 return VariantOrbits .from_kwargs (
291296 orbit_id = orbit_ids_out ,
292297 variant_id = variant_ids_out ,
@@ -308,7 +313,7 @@ def _propagate_single_orbit_inner_optimized(self, orbit: OrbitType, times: Times
308313 else :
309314 orbit_ids_out = [orbit_id ] * N
310315 object_id_out = np .tile (orbit .object_id .to_numpy (zero_copy_only = False ), N )
311-
316+
312317 return Orbits .from_kwargs (
313318 coordinates = CartesianCoordinates .from_kwargs (
314319 x = xyzvxvyvz [:, 0 ],
@@ -412,11 +417,19 @@ def _propagate_orbits_inner(
412417 step_states .append (step_xyzvxvyvz )
413418
414419 if is_variant :
415- indices = np .fromiter ((hash_to_index [h ] for h in orbit_id_hashes ), dtype = np .int64 , count = sim .N )
420+ indices = np .fromiter (
421+ (hash_to_index [h ] for h in orbit_id_hashes ),
422+ dtype = np .int64 ,
423+ count = sim .N ,
424+ )
416425 step_orbit_ids .append (orbit_ids [indices ])
417426 step_variant_ids .append (variant_ids [indices ])
418427 else :
419- indices = np .fromiter ((hash_to_index [h ] for h in orbit_id_hashes ), dtype = np .int64 , count = sim .N )
428+ indices = np .fromiter (
429+ (hash_to_index [h ] for h in orbit_id_hashes ),
430+ dtype = np .int64 ,
431+ count = sim .N ,
432+ )
420433 step_orbit_ids .append (particle_ids [indices ])
421434
422435 # Build a single result table
0 commit comments