Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/pip-build-lint-test-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ on:

permissions:
contents: read
pull-requests: read

concurrency:
group: ci-${{ github.ref }}
Expand Down Expand Up @@ -94,6 +95,7 @@ jobs:
needs: [ test ]
permissions:
contents: write
pull-requests: write
defaults:
run:
shell: bash -l {0}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"ray",
"spiceypy>=6.0.0",
"rebound>=4.4.10",
"timezonefinder==8.0.0",
]

[build-system]
Expand Down
202 changes: 177 additions & 25 deletions src/adam_assist/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp
"""
Propagate the orbits to the specified times.
"""
# OPTIMIZATION: Fast path for single orbits
if len(orbits) == 1:
return self._propagate_single_orbit_optimized(orbits, times)

# The coordinate frame is the equatorial International Celestial Reference Frame (ICRF).
# This is also the native coordinate system for the JPL binary files.
# For units we use solar masses, astronomical units, and days.
Expand Down Expand Up @@ -188,6 +192,144 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp

return results

def _propagate_single_orbit_optimized(
self, orbit: OrbitType, times: TimestampType
) -> OrbitType:
"""
Optimized propagation for a single orbit, bypassing grouping overhead.
"""
# Validate assumption
if len(orbit) != 1:
raise ValueError(f"Expected exactly 1 orbit, got {len(orbit)}")

# Transform coordinates directly without grouping
transformed_coords = transform_coordinates(
orbit.coordinates,
origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
frame_out="equatorial",
)
transformed_input_orbit_times = transformed_coords.time.rescale("tdb")
transformed_coords = transformed_coords.set_column(
"time", transformed_input_orbit_times
)
transformed_orbit = orbit.set_column("coordinates", transformed_coords)

return self._propagate_single_orbit_inner_optimized(transformed_orbit, times)

def _propagate_single_orbit_inner_optimized(
self, orbit: OrbitType, times: TimestampType
) -> OrbitType:
"""
Inner propagation optimized for exactly one orbit.
"""
# Setup ephemeris and simulation
ephem = assist.Ephem(planets_path=de440, asteroids_path=de441_n16)
sim = rebound.Simulation()

start_tdb_time = orbit.coordinates.time.jd().to_numpy()[0]
sim.t = start_tdb_time - ephem.jd_ref

# Handle particle ID creation (optimized for single orbit)
is_variant = isinstance(orbit, VariantOrbits)
if is_variant:
orbit_id = str(orbit.orbit_id.to_numpy(zero_copy_only=False)[0])
variant_id = str(orbit.variant_id.to_numpy(zero_copy_only=False)[0])
particle_hash = uint32_hash(f"{orbit_id}\x1f{variant_id}")
else:
orbit_id = str(orbit.orbit_id.to_numpy(zero_copy_only=False)[0])
particle_hash = uint32_hash(orbit_id)

assist.Extras(sim, ephem)

# Add single particle
coords = orbit.coordinates
position_arrays = coords.r
velocity_arrays = coords.v

sim.add(
x=position_arrays[0, 0],
y=position_arrays[0, 1],
z=position_arrays[0, 2],
vx=velocity_arrays[0, 0],
vy=velocity_arrays[0, 1],
vz=velocity_arrays[0, 2],
hash=particle_hash,
)

# Set integrator parameters
sim.dt = self.initial_dt
sim.ri_ias15.min_dt = self.min_dt
sim.ri_ias15.adaptive_mode = self.adaptive_mode
sim.ri_ias15.epsilon = self.epsilon

# Prepare integration times (numpy only)
integrator_times = times.rescale("tdb").jd().to_numpy()
integrator_times = integrator_times - ephem.jd_ref

# Integration loop (preallocate state array)
N = len(integrator_times)
if N == 0:
return VariantOrbits.empty() if is_variant else Orbits.empty()

xyzvxvyvz = np.zeros((N, 6), dtype="float64")
scratch = np.zeros((1, 6), dtype="float64")

for i in range(N):
sim.integrate(integrator_times[i])
scratch.fill(0.0)
sim.serialize_particle_data(xyzvxvyvz=scratch)
xyzvxvyvz[i, :] = scratch[0, :]

# Build results
jd_times = integrator_times + ephem.jd_ref
times_out = Timestamp.from_jd(jd_times, scale="tdb")
origin_codes = Origin.from_kwargs(
code=pa.repeat("SOLAR_SYSTEM_BARYCENTER", xyzvxvyvz.shape[0])
)

if is_variant:
orbit_ids_out = [orbit_id] * N
variant_ids_out = [variant_id] * N
object_id_out = np.tile(orbit.object_id.to_numpy(zero_copy_only=False), N)

return VariantOrbits.from_kwargs(
orbit_id=orbit_ids_out,
variant_id=variant_ids_out,
object_id=object_id_out,
weights=orbit.weights,
weights_cov=orbit.weights_cov,
coordinates=CartesianCoordinates.from_kwargs(
x=xyzvxvyvz[:, 0],
y=xyzvxvyvz[:, 1],
z=xyzvxvyvz[:, 2],
vx=xyzvxvyvz[:, 3],
vy=xyzvxvyvz[:, 4],
vz=xyzvxvyvz[:, 5],
time=times_out,
origin=origin_codes,
frame="equatorial",
),
)
else:
orbit_ids_out = [orbit_id] * N
object_id_out = np.tile(orbit.object_id.to_numpy(zero_copy_only=False), N)

return Orbits.from_kwargs(
coordinates=CartesianCoordinates.from_kwargs(
x=xyzvxvyvz[:, 0],
y=xyzvxvyvz[:, 1],
z=xyzvxvyvz[:, 2],
vx=xyzvxvyvz[:, 3],
vy=xyzvxvyvz[:, 4],
vz=xyzvxvyvz[:, 5],
time=times_out,
origin=origin_codes,
frame="equatorial",
),
orbit_id=orbit_ids_out,
object_id=object_id_out,
)

def _propagate_orbits_inner(
self, orbits: OrbitType, times: TimestampType
) -> OrbitType:
Expand Down Expand Up @@ -225,20 +367,24 @@ def _propagate_orbits_inner(
particle_ids = np.array(particle_ids, dtype="object")

orbit_id_mapping, uint_orbit_ids = hash_orbit_ids_to_uint32(particle_ids)
hash_to_index = {uint_orbit_ids[i].value: i for i in range(len(uint_orbit_ids))}

# Add the orbits as particles to the simulation
coords_df = orbits.coordinates.to_dataframe()
# OPTIMIZED: Use direct array access instead of DataFrame conversion
coords = orbits.coordinates
position_arrays = coords.r # x, y, z columns
velocity_arrays = coords.v # vx, vy, vz columns

assist.Extras(sim, ephem)

for i in range(len(coords_df)):
for i in range(len(position_arrays)):
sim.add(
x=coords_df.x[i],
y=coords_df.y[i],
z=coords_df.z[i],
vx=coords_df.vx[i],
vy=coords_df.vy[i],
vz=coords_df.vz[i],
x=position_arrays[i, 0],
y=position_arrays[i, 1],
z=position_arrays[i, 2],
vx=velocity_arrays[i, 0],
vy=velocity_arrays[i, 1],
vz=velocity_arrays[i, 2],
hash=uint_orbit_ids[i],
)

Expand Down Expand Up @@ -271,18 +417,20 @@ def _propagate_orbits_inner(
step_states.append(step_xyzvxvyvz)

if is_variant:
particle_ids = [orbit_id_mapping[h] for h in orbit_id_hashes]
orbit_ids, variant_ids = zip(
*[pid.split(separator) for pid in particle_ids]
indices = np.fromiter(
(hash_to_index[h] for h in orbit_id_hashes),
dtype=np.int64,
count=sim.N,
)
step_orbit_ids.append(np.asarray(orbit_ids, dtype=object))
step_variant_ids.append(np.asarray(variant_ids, dtype=object))
step_orbit_ids.append(orbit_ids[indices])
step_variant_ids.append(variant_ids[indices])
else:
step_orbit_ids.append(
np.asarray(
[orbit_id_mapping[h] for h in orbit_id_hashes], dtype=object
)
indices = np.fromiter(
(hash_to_index[h] for h in orbit_id_hashes),
dtype=np.int64,
count=sim.N,
)
step_orbit_ids.append(particle_ids[indices])

# Build a single result table
if len(step_states) == 0:
Expand Down Expand Up @@ -409,20 +557,24 @@ def _detect_collisions(
particle_ids = np.array(particle_ids, dtype="object")

orbit_id_mapping, uint_orbit_ids = hash_orbit_ids_to_uint32(particle_ids)
{uint_orbit_ids[i].value: i for i in range(len(uint_orbit_ids))}

# Add the orbits as particles to the simulation
coords_df = orbits.coordinates.to_dataframe()
# OPTIMIZED: Use direct array access instead of DataFrame conversion
coords = orbits.coordinates
position_arrays = coords.r # x, y, z columns
velocity_arrays = coords.v # vx, vy, vz columns

assist.Extras(sim, ephem)

for i in range(len(coords_df)):
for i in range(len(position_arrays)):
sim.add(
x=coords_df.x[i],
y=coords_df.y[i],
z=coords_df.z[i],
vx=coords_df.vx[i],
vy=coords_df.vy[i],
vz=coords_df.vz[i],
x=position_arrays[i, 0],
y=position_arrays[i, 1],
z=position_arrays[i, 2],
vx=velocity_arrays[i, 0],
vy=velocity_arrays[i, 1],
vz=velocity_arrays[i, 2],
hash=uint_orbit_ids[i],
)

Expand Down