Skip to content

Commit 5284fb5

Browse files
committed
ENH: support for air brakes, controller and sensors encoding.
1 parent 274d347 commit 5284fb5

File tree

15 files changed

+317
-25
lines changed

15 files changed

+317
-25
lines changed

rocketpy/_encoders.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,24 +107,32 @@ def object_hook(self, obj):
107107

108108
try:
109109
class_ = get_class_from_signature(signature)
110+
hash_ = signature.get("hash", None)
110111

111112
if class_.__name__ == "Flight" and not self.resimulate:
112113
new_flight = class_.__new__(class_)
113114
new_flight.prints = _FlightPrints(new_flight)
114115
new_flight.plots = _FlightPlots(new_flight)
115116
set_minimal_flight_attributes(new_flight, obj)
117+
if hash_ is not None:
118+
setattr(new_flight, "__rpy_hash", hash_)
116119
return new_flight
117120
elif hasattr(class_, "from_dict"):
118-
return class_.from_dict(obj)
121+
new_obj = class_.from_dict(obj)
122+
if hash_ is not None:
123+
setattr(new_obj, "__rpy_hash", hash_)
124+
return new_obj
119125
else:
120126
# Filter keyword arguments
121127
kwargs = {
122128
key: value
123129
for key, value in obj.items()
124130
if key in class_.__init__.__code__.co_varnames
125131
}
126-
127-
return class_(**kwargs)
132+
new_obj = class_(**kwargs)
133+
if hash_ is not None:
134+
setattr(new_obj, "__rpy_hash", hash_)
135+
return new_obj
128136
except (ImportError, AttributeError):
129137
return obj
130138
else:
@@ -157,7 +165,6 @@ def set_minimal_flight_attributes(flight, obj):
157165
"x_impact",
158166
"y_impact",
159167
"t_final",
160-
"flight_phases",
161168
"ax",
162169
"ay",
163170
"az",
@@ -207,7 +214,14 @@ def get_class_signature(obj):
207214
class_ = obj.__class__
208215
name = getattr(class_, "__qualname__", class_.__name__)
209216

210-
return {"module": class_.__module__, "name": name}
217+
signature = {"module": class_.__module__, "name": name}
218+
219+
try:
220+
signature.update({"hash": hash(obj)})
221+
except TypeError:
222+
pass
223+
224+
return signature
211225

212226

213227
def get_class_from_signature(signature):

rocketpy/control/controller.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from inspect import signature
2+
from typing import Iterable
3+
from rocketpy.tools import from_hex_decode, to_hex_encode
24

35
from ..prints.controller_prints import _ControllerPrints
46

@@ -181,3 +183,46 @@ def info(self):
181183
def all_info(self):
182184
"""Prints out all information about the controller."""
183185
self.info()
186+
187+
def to_dict(self, **kwargs):
188+
allow_pickle = kwargs.get("allow_pickle", True)
189+
190+
if allow_pickle:
191+
controller_function = to_hex_encode(self.controller_function)
192+
else:
193+
controller_function = self.controller_function.__name__
194+
195+
return {
196+
"controller_function": controller_function,
197+
"sampling_rate": self.sampling_rate,
198+
"initial_observed_variables": self.initial_observed_variables,
199+
"name": self.name,
200+
"_interactive_objects_hash": hash(self.interactive_objects)
201+
if not isinstance(self.interactive_objects, Iterable)
202+
else [hash(obj) for obj in self.interactive_objects],
203+
}
204+
205+
@classmethod
206+
def from_dict(cls, data):
207+
interactive_objects = data.get("interactive_objects", [])
208+
controller_function = data.get("controller_function")
209+
sampling_rate = data.get("sampling_rate")
210+
initial_observed_variables = data.get("initial_observed_variables")
211+
name = data.get("name", "Controller")
212+
213+
try:
214+
controller_function = from_hex_decode(controller_function)
215+
except (TypeError, ValueError):
216+
pass
217+
218+
obj = cls(
219+
interactive_objects=interactive_objects,
220+
controller_function=controller_function,
221+
sampling_rate=sampling_rate,
222+
initial_observed_variables=initial_observed_variables,
223+
name=name,
224+
)
225+
setattr(
226+
obj, "_interactive_objects_hash", data.get("_interactive_objects_hash", [])
227+
)
228+
return obj

rocketpy/mathutils/function.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,7 +1551,7 @@ def short_time_fft(
15511551
... ylabel=f"Freq. $f$ in Hz)",
15521552
... xlim=(t_lo, t_hi)
15531553
... )
1554-
>>> _ = ax1.plot(t_x, f_i, 'r--', alpha=.5, label='$f_i(t)$')
1554+
# >>> _ = ax1.plot(t_x, f_i, 'r--', alpha=.5, label='$f_i(t)$')
15551555
>>> _ = fig1.colorbar(im1, label="Magnitude $|S_x(t, f)|$")
15561556
>>> # Shade areas where window slices stick out to the side
15571557
>>> for t0_, t1_ in [(t_lo, 1), (49, t_hi)]:
@@ -1856,8 +1856,7 @@ def plot_1d( # pylint: disable=too-many-statements
18561856
None
18571857
"""
18581858
# Define a mesh and y values at mesh nodes for plotting
1859-
fig = plt.figure()
1860-
ax = fig.axes
1859+
fig, ax = plt.subplots()
18611860
if self._source_type is SourceType.CALLABLE:
18621861
# Determine boundaries
18631862
domain = [0, 10]
@@ -1895,9 +1894,9 @@ def plot_1d( # pylint: disable=too-many-statements
18951894
plt.title(self.title)
18961895
plt.xlabel(self.__inputs__[0].title())
18971896
plt.ylabel(self.__outputs__[0].title())
1898-
show_or_save_plot(filename)
18991897
if return_object:
19001898
return fig, ax
1899+
show_or_save_plot(filename)
19011900

19021901
@deprecated(
19031902
reason="The `Function.plot2D` method is set to be deprecated and fully "

rocketpy/rocket/aero_surface/air_brakes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,24 @@ def all_info(self):
206206
"""
207207
self.info()
208208
self.plots.drag_coefficient_curve()
209+
210+
def to_dict(self, **kwargs):
211+
return {
212+
"drag_coefficient_curve": self.drag_coefficient,
213+
"reference_area": self.reference_area,
214+
"clamp": self.clamp,
215+
"override_rocket_drag": self.override_rocket_drag,
216+
"deployment_level": self.initial_deployment_level,
217+
"name": self.name,
218+
}
219+
220+
@classmethod
221+
def from_dict(cls, data):
222+
return cls(
223+
drag_coefficient_curve=data.get("drag_coefficient_curve"),
224+
reference_area=data.get("reference_area"),
225+
clamp=data.get("clamp"),
226+
override_rocket_drag=data.get("override_rocket_drag"),
227+
deployment_level=data.get("deployment_level"),
228+
name=data.get("name"),
229+
)

rocketpy/rocket/aero_surface/fins/fins.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,13 +427,25 @@ def compute_forces_and_moments(
427427
return R1, R2, R3, M1, M2, M3
428428

429429
def to_dict(self, **kwargs):
430+
if self.airfoil:
431+
if kwargs.get("discretize", False):
432+
lower = -np.pi / 6 if self.airfoil[1] == "radians" else -30
433+
upper = np.pi / 6 if self.airfoil[1] == "radians" else 30
434+
airfoil = (
435+
self.airfoil_cl.set_discrete(lower, upper, 50, mutate_self=False),
436+
self.airfoil[1],
437+
)
438+
else:
439+
airfoil = (self.airfoil_cl, self.airfoil[1]) if self.airfoil else None
440+
else:
441+
airfoil = None
430442
data = {
431443
"n": self.n,
432444
"root_chord": self.root_chord,
433445
"span": self.span,
434446
"rocket_radius": self.rocket_radius,
435447
"cant_angle": self.cant_angle,
436-
"airfoil": self.airfoil,
448+
"airfoil": airfoil,
437449
"name": self.name,
438450
}
439451

rocketpy/rocket/rocket.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import math
22

33
import numpy as np
4+
from rocketpy.tools import find_obj_from_hash
5+
from typing import Iterable
46

57
from rocketpy.control.controller import _Controller
68
from rocketpy.mathutils.function import Function
@@ -2070,17 +2072,32 @@ def from_dict(cls, data):
20702072
for parachute in data["parachutes"]:
20712073
rocket.parachutes.append(parachute)
20722074

2073-
for air_brakes in data["air_brakes"]:
2074-
rocket.add_air_brakes(
2075-
drag_coefficient_curve=air_brakes["drag_coefficient_curve"],
2076-
controller_function=air_brakes["controller_function"],
2077-
sampling_rate=air_brakes["sampling_rate"],
2078-
clamp=air_brakes["clamp"],
2079-
reference_area=air_brakes["reference_area"],
2080-
initial_observed_variables=air_brakes["initial_observed_variables"],
2081-
override_rocket_drag=air_brakes["override_rocket_drag"],
2082-
name=air_brakes["name"],
2083-
controller_name=air_brakes["controller_name"],
2084-
)
2075+
for sensor, position in data["sensors"]:
2076+
rocket.add_sensor(sensor, position)
2077+
2078+
for air_brake in data["air_brakes"]:
2079+
rocket.air_brakes.append(air_brake)
2080+
2081+
for controller in data["_controllers"]:
2082+
if (
2083+
interactive_objects_hash := getattr(
2084+
controller, "_interactive_objects_hash"
2085+
)
2086+
) is not None:
2087+
is_iterable = isinstance(interactive_objects_hash, Iterable)
2088+
if not is_iterable:
2089+
interactive_objects_hash = [interactive_objects_hash]
2090+
for hash_ in interactive_objects_hash:
2091+
if (hashed_obj := find_obj_from_hash(data, hash_)) is not None:
2092+
if not is_iterable:
2093+
controller.interactive_objects = hashed_obj
2094+
else:
2095+
controller.interactive_objects.append(hashed_obj)
2096+
else:
2097+
warnings.warn(
2098+
"Could not find controller interactive objects."
2099+
"Deserialization will proceed, results may not be accurate."
2100+
)
2101+
rocket._add_controllers(controller)
20852102

20862103
return rocket

rocketpy/sensors/accelerometer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,28 @@ def export_measured_data(self, filename, file_format="csv"):
275275
file_format=file_format,
276276
data_labels=("t", "ax", "ay", "az"),
277277
)
278+
279+
def to_dict(self, **kwargs):
280+
data = super().to_dict(**kwargs)
281+
data.update({"consider_gravity": self.consider_gravity})
282+
return data
283+
284+
@classmethod
285+
def from_dict(cls, data):
286+
return cls(
287+
sampling_rate=data["sampling_rate"],
288+
orientation=data["orientation"],
289+
measurement_range=data["measurement_range"],
290+
resolution=data["resolution"],
291+
noise_density=data["noise_density"],
292+
noise_variance=data["noise_variance"],
293+
random_walk_density=data["random_walk_density"],
294+
random_walk_variance=data["random_walk_variance"],
295+
constant_bias=data["constant_bias"],
296+
operating_temperature=data["operating_temperature"],
297+
temperature_bias=data["temperature_bias"],
298+
temperature_scale_factor=data["temperature_scale_factor"],
299+
cross_axis_sensitivity=data["cross_axis_sensitivity"],
300+
consider_gravity=data["consider_gravity"],
301+
name=data["name"],
302+
)

rocketpy/sensors/barometer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,20 @@ def export_measured_data(self, filename, file_format="csv"):
190190
file_format=file_format,
191191
data_labels=("t", "pressure"),
192192
)
193+
194+
@classmethod
195+
def from_dict(cls, data):
196+
return cls(
197+
sampling_rate=data["sampling_rate"],
198+
measurement_range=data["measurement_range"],
199+
resolution=data["resolution"],
200+
noise_density=data["noise_density"],
201+
noise_variance=data["noise_variance"],
202+
random_walk_density=data["random_walk_density"],
203+
random_walk_variance=data["random_walk_variance"],
204+
constant_bias=data["constant_bias"],
205+
operating_temperature=data["operating_temperature"],
206+
temperature_bias=data["temperature_bias"],
207+
temperature_scale_factor=data["temperature_scale_factor"],
208+
name=data["name"],
209+
)

rocketpy/sensors/gnss_receiver.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,20 @@ def export_measured_data(self, filename, file_format="csv"):
124124
file_format=file_format,
125125
data_labels=("t", "latitude", "longitude", "altitude"),
126126
)
127+
128+
def to_dict(self, **kwargs):
129+
return {
130+
"sampling_rate": self.sampling_rate,
131+
"position_accuracy": self.position_accuracy,
132+
"altitude_accuracy": self.altitude_accuracy,
133+
"name": self.name,
134+
}
135+
136+
@classmethod
137+
def from_dict(cls, data):
138+
return cls(
139+
sampling_rate=data["sampling_rate"],
140+
position_accuracy=data["position_accuracy"],
141+
altitude_accuracy=data["altitude_accuracy"],
142+
name=data["name"],
143+
)

rocketpy/sensors/gyroscope.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,28 @@ def export_measured_data(self, filename, file_format="csv"):
296296
file_format=file_format,
297297
data_labels=("t", "wx", "wy", "wz"),
298298
)
299+
300+
def to_dict(self, **kwargs):
301+
data = super().to_dict(**kwargs)
302+
data.update({"acceleration_sensitivity": self.acceleration_sensitivity})
303+
return data
304+
305+
@classmethod
306+
def from_dict(cls, data):
307+
return cls(
308+
sampling_rate=data["sampling_rate"],
309+
orientation=data["orientation"],
310+
measurement_range=data["measurement_range"],
311+
resolution=data["resolution"],
312+
noise_density=data["noise_density"],
313+
noise_variance=data["noise_variance"],
314+
random_walk_density=data["random_walk_density"],
315+
random_walk_variance=data["random_walk_variance"],
316+
constant_bias=data["constant_bias"],
317+
operating_temperature=data["operating_temperature"],
318+
temperature_bias=data["temperature_bias"],
319+
temperature_scale_factor=data["temperature_scale_factor"],
320+
cross_axis_sensitivity=data["cross_axis_sensitivity"],
321+
acceleration_sensitivity=data["acceleration_sensitivity"],
322+
name=data["name"],
323+
)

0 commit comments

Comments
 (0)