Skip to content

Commit c2dc479

Browse files
committed
[AVSLab#311] Refactor OpportunityProperties to use objects instead of dicts
1 parent 7fe35cc commit c2dc479

File tree

3 files changed

+145
-73
lines changed

3 files changed

+145
-73
lines changed

src/bsk_rl/obs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class MyObservationSatellite(Satellite):
3939
Eclipse,
4040
Observation,
4141
OpportunityProperties,
42+
TargetOpportunityProperty,
4243
ResourceRewardWeight,
4344
SatProperties,
4445
Time,

src/bsk_rl/obs/observations.py

Lines changed: 131 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
import logging
44
from abc import ABC, abstractmethod
55
from copy import deepcopy
6-
from typing import TYPE_CHECKING, Any, Union
6+
from typing import TYPE_CHECKING, Any, Union, Callable
77

88
import numpy as np
99
from gymnasium import spaces
1010

1111
from bsk_rl.utils.functional import Resetable, vectorize_nested_dict
1212
from bsk_rl.utils.orbital import rv2HN
1313

14+
from enum import Enum
15+
1416
if TYPE_CHECKING: # pragma: no cover
1517
from bsk_rl.sats import Satellite
1618
from bsk_rl.sim import Simulator
@@ -253,7 +255,8 @@ def reset_post_sim_init(self) -> None:
253255
obs_property["module"] = module
254256
break
255257
else:
256-
raise AttributeError(f"Property {obs_property['prop']} not found")
258+
raise AttributeError(
259+
f"Property {obs_property['prop']} not found")
257260

258261
def get_obs(self) -> dict[str, Any]:
259262
"""Return the observation.
@@ -324,7 +327,8 @@ def _target_angle_rate(sat, opp):
324327
omega_BP_P = sat.dynamics.omega_BP_P
325328
omega_CP_ref = (
326329
omega_BP_P
327-
- np.cross(v_BN_P, r_LP_P - r_BN_P) / np.linalg.norm(r_LP_P - r_BN_P) ** 2
330+
- np.cross(v_BN_P, r_LP_P - r_BN_P) /
331+
np.linalg.norm(r_LP_P - r_BN_P) ** 2
328332
)
329333
return np.linalg.norm(omega_CP_ref)
330334

@@ -337,60 +341,133 @@ def _r_LB_H(sat, opp):
337341
return HN @ r_TB_N
338342

339343

340-
class OpportunityProperties(Observation):
341-
_fn_map = {
342-
"priority": lambda sat, opp: opp["object"].priority,
343-
"r_LP_P": lambda sat, opp: opp["r_LP_P"],
344-
"r_LB_H": _r_LB_H,
345-
"opportunity_open": lambda sat, opp: opp["window"][0] - sat.simulator.sim_time,
346-
"opportunity_mid": lambda sat, opp: sum(opp["window"]) / 2
347-
- sat.simulator.sim_time,
348-
"opportunity_close": lambda sat, opp: opp["window"][1] - sat.simulator.sim_time,
349-
"target_angle": _target_angle,
350-
"target_angle_rate": _target_angle_rate,
351-
}
344+
class TargetOpportunityProperty():
345+
class Property(Enum):
346+
"""Enumeration of opportunity properties used in observations.
347+
348+
Each member stores a function that computes the property value
349+
given a satellite (`sat`) and an opportunity (`opp`).
350+
351+
Members:
352+
priority: Priority of the target.
353+
r_LP_P: Location of the target in the planet-fixed frame.
354+
r_LB_H: Location of the target in the Hill frame.
355+
opportunity_open: Time until the opportunity opens.
356+
opportunity_mid: Time until the opportunity midpoint.
357+
opportunity_close: Time until the opportunity closes.
358+
target_angle: Angle between the target and the satellite instrument direction.
359+
target_angle_rate: Rate difference between the target pointing frame and the body frame.
360+
"""
361+
priority = lambda sat, opp: opp["object"].priority
362+
r_LP_P = lambda sat, opp: opp["r_LP_P"]
363+
r_LB_H = _r_LB_H
364+
opportunity_open = lambda sat, opp: opp["window"][0] - sat.simulator.sim_time
365+
opportunity_mid = lambda sat, opp: sum(opp["window"]) / 2 - sat.simulator.sim_time
366+
opportunity_close = lambda sat, opp: opp["window"][1] - sat.simulator.sim_time
367+
target_angle = _target_angle
368+
target_angle_rate = _target_angle_rate
369+
370+
@classmethod
371+
def from_key(cls, key: str) -> "TargetOpportunityProperty.Property":
372+
try:
373+
return cls[key]
374+
except KeyError as e:
375+
raise KeyError(f"Unknown Property key: {key!r}") from e
376+
377+
378+
def __init__(
379+
self,
380+
name: str = None,
381+
fn: Callable[[Any, Any], Any] = None,
382+
prop: Union[Property, str] = None,
383+
norm: float = 1.0,
384+
i: int = 0,
385+
):
386+
"""Property derived from an access opportunity to append to the observation.
387+
388+
Property that is a function of the opportunity to be appended to the the
389+
observation. Properties are optionally normalized by some factor.
390+
391+
Args:
392+
name: Explicit name for this observation element.
393+
fn: Callable to compute the value, with signature fn(satellite, opportunity) -> Any``.
394+
prop: If fn is not provided, this key will be used to look up a preset function:
395+
norm: Scalar used to normalize the computed value. Defaults to ``1.0``.
396+
i: Index used only for auto-naming when neither name nor prop is provided.
397+
"""
398+
if isinstance(prop, str):
399+
prop = self.Property.from_key(prop) # from_key method handles errors
400+
401+
# Choose fn
402+
if fn is None:
403+
if prop is None:
404+
raise ValueError("Either `fn` or `prop` must be provided.")
405+
fn = prop.value
406+
else:
407+
if prop is not None:
408+
logger.warning(
409+
f"Ignoring default function for `{prop}` when `fn` is provided."
410+
)
411+
412+
# Determine best name
413+
if name is None:
414+
if isinstance(prop, TargetOpportunityProperty.Property):
415+
name = prop.name
416+
else:
417+
name = f"prop_{i}"
418+
if norm != 1.0:
419+
name += "_normd"
420+
421+
self.name = name
422+
self.fn = fn
423+
self.norm = norm
424+
425+
@classmethod
426+
def from_dict(cls, spec: dict[str, Any], i: int = 0) -> "TargetOpportunityProperty":
427+
"""Initialize from a legacy dict spec with keys {prop, fn, name, norm}."""
428+
for key in spec:
429+
if key not in ["fn", "norm", "name", "prop"]:
430+
raise ValueError(f"Invalid property key: {key}")
431+
432+
name = spec.get("name")
433+
fn = spec.get("fn")
434+
prop = spec.get("prop")
435+
norm = spec.get("norm", 1.0)
436+
437+
return cls(name=name, fn=fn, prop=prop, norm=norm, i=i)
438+
439+
def __getitem__(self, key):
440+
return getattr(self, key)
441+
442+
def __setitem__(self, key, value):
443+
setattr(self, key, value)
444+
_ = list(TargetOpportunityProperty.Property) # loader
352445

446+
class OpportunityProperties(Observation):
353447
def __init__(
354448
self,
355-
*target_properties: dict[str, Any],
449+
*target_properties: dict[str, Any] | TargetOpportunityProperty,
356450
n_ahead_observe: int,
357451
type="target",
358452
name=None,
359453
):
360-
"""Include information about upcoming access opportunities in the observation..
454+
"""Include information about upcoming access opportunities in the observation.
361455
362-
For each desired property, a dictionary specifying the property name and settings
456+
For each desired property, a TargetOpportunityProperty object specifying the property name and settings
363457
is passed. These can include preset properties or arbitrary functions of the satellite
364458
and opportunity.
365459
366460
.. code-block:: python
367461
368462
OpportunityProperties(
369-
dict(prop="r_LP_P", norm=REQ_EARTH * 1e3),
370-
dict(prop="double_priority", fn=lambda sat, opp: opp["target"].priority * 2.0),
463+
TargetOpportunityProperty(prop="r_LP_P", norm=REQ_EARTH * 1e3),
464+
TargetOpportunityProperty(name="double_priority",
465+
fn=lambda sat, opp: opp["object"].priority * 2.0),
371466
n_ahead_observe=16,
372467
)
373468
374469
Args:
375-
target_properties: Property that is a function of the opportunity to be appended
376-
to the the observation. Properties are optionally normalized by some factor.
377-
Each observation is a dictionary with the keys:
378-
379-
* ``name`` `optional`: Name of the observation element.
380-
* ``fn`` `optional`: Function to calculate property, in the form ``fn(satellite, opportunity)``.
381-
If not provided, the key ``prop`` will be used to look up a preset function:
382-
383-
* ``priority``: Priority of the target.
384-
* ``r_LP_P``: Location of the target in the planet-fixed frame.
385-
* ``r_LB_H``: Location of the target in the Hill frame.
386-
* ``opportunity_open``: Time until the opportunity opens.
387-
* ``opportunity_mid``: Time until the opportunity midpoint.
388-
* ``opportunity_close``: Time until the opportunity closes.
389-
* ``target_angle``: Angle between the target and the satellite instrument direction.
390-
* ``target_angle_rate``: Rate difference between the target pointing frame and the body frame.
391-
392-
* ``norm`` `optional`: Value to normalize property by. Defaults to 1.0.
393-
470+
target_properties: Property that is a function of the opportunity to be appended to the the observation.
394471
n_ahead_observe: Number of upcoming targets to consider.
395472
type: The type of opportunity to consider. Can be ``target``, ``ground_station``,
396473
or any other type of opportunity that has been added via
@@ -401,38 +478,19 @@ def __init__(
401478
name = type
402479
super().__init__(name=name)
403480
self.type = type
404-
self.target_properties = target_properties
405-
for i, prop_spec in enumerate(self.target_properties):
406-
for key in prop_spec:
407-
if key not in ["fn", "norm", "name", "prop"]:
408-
raise ValueError(f"Invalid property key: {key}")
409481

410-
if "norm" not in prop_spec:
411-
prop_spec["norm"] = 1.0
412-
413-
# Determine observation function
414-
if "fn" not in prop_spec:
415-
try:
416-
prop_spec["fn"] = self._fn_map[prop_spec["prop"]]
417-
except KeyError:
418-
raise ValueError(
419-
f"Property prop={prop_spec['prop']} is not predefined and no `fn` was provided."
420-
)
482+
normalized = []
483+
for i, prop in enumerate(target_properties):
484+
if isinstance(prop, TargetOpportunityProperty):
485+
normalized.append(prop)
486+
elif isinstance(prop, dict):
487+
normalized.append(TargetOpportunityProperty.from_dict(prop, i=i))
421488
else:
422-
if "prop" in prop_spec and prop_spec["prop"] in self._fn_map:
423-
logger.warning(
424-
f"Ignoring default function for `{prop_spec['prop']}` when `fn` is provided."
425-
)
426-
427-
# Determine best name
428-
if "name" not in prop_spec:
429-
if "prop" in prop_spec:
430-
prop_spec["name"] = prop_spec["prop"]
431-
else:
432-
prop_spec["name"] = f"prop_{i}"
433-
434-
if prop_spec["norm"] != 1.0:
435-
prop_spec["name"] += "_normd"
489+
raise TypeError(
490+
"Each target property must be a dict or a TargetOpportunityProperty"
491+
f"got {type(prop).__name__}."
492+
)
493+
self.target_properties = normalized
436494

437495
self.n_ahead_observe = int(n_ahead_observe)
438496

@@ -457,10 +515,10 @@ def get_obs(self):
457515
)
458516
):
459517
props = {}
460-
for prop_spec in self.target_properties:
461-
name = prop_spec["name"]
462-
norm = prop_spec["norm"]
463-
value = prop_spec["fn"](self.satellite, opportunity)
518+
for prop in self.target_properties:
519+
name = prop.name
520+
norm = prop.norm
521+
value = prop.fn(self.satellite, opportunity)
464522
props[name] = value / norm
465523
obs[f"{self.name}_{i}"] = props
466524
return obs

tests/unittest/obs/test_observations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,19 @@ def test_init(self):
192192
assert ob.target_properties[0]["name"] == "r_LP_P_normd"
193193
assert ob.target_properties[1]["norm"] == 1.0
194194

195+
def test_init_new_api(self):
196+
ob = obs.OpportunityProperties(
197+
obs.TargetOpportunityProperty(prop="r_LP_P", norm=2.0),
198+
obs.TargetOpportunityProperty(prop="r_LP_P", norm=2.0),
199+
obs.TargetOpportunityProperty(
200+
prop="double_priority", fn=lambda sat, opp: opp["target"].priority * 2.0
201+
),
202+
n_ahead_observe=2,
203+
)
204+
assert ob.target_properties[0].fn
205+
assert ob.target_properties[0].name == "r_LP_P_normd"
206+
assert ob.target_properties[1].norm == 1.0
207+
195208
def test_get_obs(self):
196209
ob = obs.OpportunityProperties(
197210
dict(prop="priority", norm=2.0),

0 commit comments

Comments
 (0)