Skip to content

Commit d96a7bf

Browse files
committed
Linting
1 parent c4e21a7 commit d96a7bf

File tree

1 file changed

+38
-25
lines changed

1 file changed

+38
-25
lines changed

src/adam_assist/propagator.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import hashlib
22
import random
33
from ctypes import c_uint32
4-
from typing import Any, Dict, List, Tuple, Union, Optional
4+
from typing import Any, Dict, List, Optional, Tuple, Union
55

66
import assist
77
import numpy as np
@@ -143,7 +143,6 @@ def __init__(
143143
self.adaptive_mode = adaptive_mode
144144
self.epsilon = epsilon
145145

146-
147146
def __getstate__(self) -> Dict[str, Any]:
148147
state = self.__dict__.copy()
149148
state.pop("_last_simulation", None)
@@ -159,7 +158,7 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp
159158
# OPTIMIZATION: Fast path for single orbits
160159
if len(orbits) == 1:
161160
return self._propagate_single_orbit_optimized(orbits, times)
162-
161+
163162
# The coordinate frame is the equatorial International Celestial Reference Frame (ICRF).
164163
# This is also the native coordinate system for the JPL binary files.
165164
# For units we use solar masses, astronomical units, and days.
@@ -193,37 +192,43 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp
193192

194193
return results
195194

196-
def _propagate_single_orbit_optimized(self, orbit: OrbitType, times: TimestampType) -> OrbitType:
195+
def _propagate_single_orbit_optimized(
196+
self, orbit: OrbitType, times: TimestampType
197+
) -> OrbitType:
197198
"""
198199
Optimized propagation for a single orbit, bypassing grouping overhead.
199200
"""
200201
# Validate assumption
201202
if len(orbit) != 1:
202203
raise ValueError(f"Expected exactly 1 orbit, got {len(orbit)}")
203-
204+
204205
# Transform coordinates directly without grouping
205206
transformed_coords = transform_coordinates(
206207
orbit.coordinates,
207208
origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
208209
frame_out="equatorial",
209210
)
210211
transformed_input_orbit_times = transformed_coords.time.rescale("tdb")
211-
transformed_coords = transformed_coords.set_column("time", transformed_input_orbit_times)
212+
transformed_coords = transformed_coords.set_column(
213+
"time", transformed_input_orbit_times
214+
)
212215
transformed_orbit = orbit.set_column("coordinates", transformed_coords)
213-
216+
214217
return self._propagate_single_orbit_inner_optimized(transformed_orbit, times)
215218

216-
def _propagate_single_orbit_inner_optimized(self, orbit: OrbitType, times: TimestampType) -> OrbitType:
219+
def _propagate_single_orbit_inner_optimized(
220+
self, orbit: OrbitType, times: TimestampType
221+
) -> OrbitType:
217222
"""
218223
Inner propagation optimized for exactly one orbit.
219224
"""
220225
# Setup ephemeris and simulation
221226
ephem = assist.Ephem(planets_path=de440, asteroids_path=de441_n16)
222227
sim = rebound.Simulation()
223-
228+
224229
start_tdb_time = orbit.coordinates.time.jd().to_numpy()[0]
225230
sim.t = start_tdb_time - ephem.jd_ref
226-
231+
227232
# Handle particle ID creation (optimized for single orbit)
228233
is_variant = isinstance(orbit, VariantOrbits)
229234
if is_variant:
@@ -233,60 +238,60 @@ def _propagate_single_orbit_inner_optimized(self, orbit: OrbitType, times: Times
233238
else:
234239
orbit_id = str(orbit.orbit_id.to_numpy(zero_copy_only=False)[0])
235240
particle_hash = uint32_hash(orbit_id)
236-
241+
237242
assist.Extras(sim, ephem)
238-
243+
239244
# Add single particle
240245
coords = orbit.coordinates
241246
position_arrays = coords.r
242247
velocity_arrays = coords.v
243-
248+
244249
sim.add(
245250
x=position_arrays[0, 0],
246-
y=position_arrays[0, 1],
251+
y=position_arrays[0, 1],
247252
z=position_arrays[0, 2],
248253
vx=velocity_arrays[0, 0],
249254
vy=velocity_arrays[0, 1],
250255
vz=velocity_arrays[0, 2],
251256
hash=particle_hash,
252257
)
253-
258+
254259
# Set integrator parameters
255260
sim.dt = self.initial_dt
256261
sim.ri_ias15.min_dt = self.min_dt
257262
sim.ri_ias15.adaptive_mode = self.adaptive_mode
258263
sim.ri_ias15.epsilon = self.epsilon
259-
264+
260265
# Prepare integration times (numpy only)
261266
integrator_times = times.rescale("tdb").jd().to_numpy()
262267
integrator_times = integrator_times - ephem.jd_ref
263-
268+
264269
# Integration loop (preallocate state array)
265270
N = len(integrator_times)
266271
if N == 0:
267272
return VariantOrbits.empty() if is_variant else Orbits.empty()
268-
273+
269274
xyzvxvyvz = np.zeros((N, 6), dtype="float64")
270275
scratch = np.zeros((1, 6), dtype="float64")
271-
276+
272277
for i in range(N):
273278
sim.integrate(integrator_times[i])
274279
scratch.fill(0.0)
275280
sim.serialize_particle_data(xyzvxvyvz=scratch)
276281
xyzvxvyvz[i, :] = scratch[0, :]
277-
282+
278283
# Build results
279284
jd_times = integrator_times + ephem.jd_ref
280285
times_out = Timestamp.from_jd(jd_times, scale="tdb")
281286
origin_codes = Origin.from_kwargs(
282287
code=pa.repeat("SOLAR_SYSTEM_BARYCENTER", xyzvxvyvz.shape[0])
283288
)
284-
289+
285290
if is_variant:
286291
orbit_ids_out = [orbit_id] * N
287292
variant_ids_out = [variant_id] * N
288293
object_id_out = np.tile(orbit.object_id.to_numpy(zero_copy_only=False), N)
289-
294+
290295
return VariantOrbits.from_kwargs(
291296
orbit_id=orbit_ids_out,
292297
variant_id=variant_ids_out,
@@ -308,7 +313,7 @@ def _propagate_single_orbit_inner_optimized(self, orbit: OrbitType, times: Times
308313
else:
309314
orbit_ids_out = [orbit_id] * N
310315
object_id_out = np.tile(orbit.object_id.to_numpy(zero_copy_only=False), N)
311-
316+
312317
return Orbits.from_kwargs(
313318
coordinates=CartesianCoordinates.from_kwargs(
314319
x=xyzvxvyvz[:, 0],
@@ -412,11 +417,19 @@ def _propagate_orbits_inner(
412417
step_states.append(step_xyzvxvyvz)
413418

414419
if is_variant:
415-
indices = np.fromiter((hash_to_index[h] for h in orbit_id_hashes), dtype=np.int64, count=sim.N)
420+
indices = np.fromiter(
421+
(hash_to_index[h] for h in orbit_id_hashes),
422+
dtype=np.int64,
423+
count=sim.N,
424+
)
416425
step_orbit_ids.append(orbit_ids[indices])
417426
step_variant_ids.append(variant_ids[indices])
418427
else:
419-
indices = np.fromiter((hash_to_index[h] for h in orbit_id_hashes), dtype=np.int64, count=sim.N)
428+
indices = np.fromiter(
429+
(hash_to_index[h] for h in orbit_id_hashes),
430+
dtype=np.int64,
431+
count=sim.N,
432+
)
420433
step_orbit_ids.append(particle_ids[indices])
421434

422435
# Build a single result table

0 commit comments

Comments
 (0)