Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 150 additions & 8 deletions rocketpy/_encoders.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
"""Defines a custom JSON encoder for RocketPy objects."""

import base64
import json
from datetime import datetime
from importlib import import_module

import dill
import numpy as np


class RocketPyEncoder(json.JSONEncoder):
"""Custom JSON encoder for RocketPy objects. It defines how to encode
different types of objects to a JSON supported format."""

def __init__(self, *args, **kwargs):
self.include_outputs = kwargs.pop("include_outputs", True)
super().__init__(*args, **kwargs)

def default(self, o):
if isinstance(
o,
Expand All @@ -33,17 +40,152 @@
elif isinstance(o, np.ndarray):
return o.tolist()
elif isinstance(o, datetime):
return o.isoformat()
return [o.year, o.month, o.day, o.hour]

Check warning on line 43 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L43

Added line #L43 was not covered by tests
elif hasattr(o, "__iter__") and not isinstance(o, str):
return list(o)
elif hasattr(o, "to_dict"):
return o.to_dict()
encoding = o.to_dict(self.include_outputs)
encoding = remove_circular_references(encoding)

encoding["signature"] = get_class_signature(o)

return encoding

elif hasattr(o, "__dict__"):
exception_set = {"prints", "plots"}
return {
key: value
for key, value in o.__dict__.items()
if key not in exception_set
}
encoding = remove_circular_references(o.__dict__)

Check warning on line 55 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L55

Added line #L55 was not covered by tests

if "rocketpy" in o.__class__.__module__:
encoding["signature"] = get_class_signature(o)

Check warning on line 58 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L57-L58

Added lines #L57 - L58 were not covered by tests

return encoding

Check warning on line 60 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L60

Added line #L60 was not covered by tests
else:
return super().default(o)


class RocketPyDecoder(json.JSONDecoder):
"""Custom JSON decoder for RocketPy objects. It defines how to decode
different types of objects from a JSON supported format."""

def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)

def object_hook(self, obj):
if "signature" in obj:
signature = obj.pop("signature")

try:
class_ = get_class_from_signature(signature)

if hasattr(class_, "from_dict"):
return class_.from_dict(obj)
else:
# Filter keyword arguments
kwargs = {

Check warning on line 83 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L83

Added line #L83 was not covered by tests
key: value
for key, value in obj.items()
if key in class_.__init__.__code__.co_varnames
}

return class_(**kwargs)
except (ImportError, AttributeError):
return obj

Check warning on line 91 in rocketpy/_encoders.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/_encoders.py#L89-L91

Added lines #L89 - L91 were not covered by tests
else:
return obj


def get_class_signature(obj):
"""Returns the signature of a class in the form of a string.
The signature is an importable string that can be used to import
the class by its module.

Parameters
----------
obj : object
Object to get the signature from.

Returns
-------
str
Signature of the class.
"""
class_ = obj.__class__
name = getattr(class_, '__qualname__', class_.__name__)

return {"module": class_.__module__, "name": name}


def get_class_from_signature(signature):
"""Returns the class by importing its signature.

Parameters
----------
signature : str
Signature of the class.

Returns
-------
type
Class defined by the signature.
"""
module = import_module(signature["module"])
inner_class = None

for class_ in signature["name"].split("."):
inner_class = getattr(module, class_)

return inner_class


def remove_circular_references(obj_dict):
"""Removes circular references from a dictionary.

Parameters
----------
obj_dict : dict
Dictionary to remove circular references from.

Returns
-------
dict
Dictionary without circular references.
"""
obj_dict.pop("prints", None)
obj_dict.pop("plots", None)

return obj_dict


def to_hex_encode(obj, encoder=base64.b85encode):
"""Converts an object to hex representation using dill.

Parameters
----------
obj : object
Object to be converted to hex.
encoder : callable, optional
Function to encode the bytes. Default is base64.b85encode.

Returns
-------
bytes
Object converted to bytes.
"""
return encoder(dill.dumps(obj)).hex()


def from_hex_decode(obj_bytes, decoder=base64.b85decode):
"""Converts an object from hex representation using dill.

Parameters
----------
obj_bytes : str
Hex string to be converted to object.
decoder : callable, optional
Function to decode the bytes. Default is base64.b85decode.

Returns
-------
object
Object converted from bytes.
"""
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
90 changes: 90 additions & 0 deletions rocketpy/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2743,6 +2743,96 @@
arc_seconds = (remainder * 60 - arc_minutes) * 60
return degrees, arc_minutes, arc_seconds

def to_dict(self, include_outputs=True):
env_dict = {
"gravity": self.gravity,
"date": self.date,
"latitude": self.latitude,
"longitude": self.longitude,
"elevation": self.elevation,
"datum": self.datum,
"timezone": self.timezone,
"_max_expected_height": self.max_expected_height,
"atmospheric_model_type": self.atmospheric_model_type,
"pressure": self.pressure,
"barometric_height": self.barometric_height,
"temperature": self.temperature,
"wind_velocity_x": self.wind_velocity_x,
"wind_velocity_y": self.wind_velocity_y,
"wind_heading": self.wind_heading,
"wind_direction": self.wind_direction,
"wind_speed": self.wind_speed,
}

if include_outputs:
env_dict["density"] = self.density
env_dict["speed_of_sound"] = self.speed_of_sound
env_dict["dynamic_viscosity"] = self.dynamic_viscosity

return env_dict

@classmethod
def from_dict(cls, data): # pylint: disable=too-many-statements
env = cls(
gravity=data["gravity"],
date=data["date"],
latitude=data["latitude"],
longitude=data["longitude"],
elevation=data["elevation"],
datum=data["datum"],
timezone=data["timezone"],
max_expected_height=data["_max_expected_height"],
)
atmospheric_model = data["atmospheric_model_type"]

if atmospheric_model == "standard_atmosphere":
env.set_atmospheric_model("standard_atmosphere")
elif atmospheric_model == "custom_atmosphere":
env.set_atmospheric_model(

Check warning on line 2791 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2791

Added line #L2791 was not covered by tests
type="custom_atmosphere",
pressure=data["pressure"],
temperature=data["temperature"],
wind_u=data["wind_velocity_x"],
wind_v=data["wind_velocity_y"],
)
else:
env.__set_pressure_function(data["pressure"])
env.__set_barometric_height_function(data["barometric_height"])
env.__set_temperature_function(data["temperature"])
env.__set_wind_velocity_x_function(data["wind_velocity_x"])
env.__set_wind_velocity_y_function(data["wind_velocity_y"])
env.__set_wind_heading_function(data["wind_heading"])
env.__set_wind_direction_function(data["wind_direction"])
env.__set_wind_speed_function(data["wind_speed"])
env.elevation = data["elevation"]
env.max_expected_height = data["_max_expected_height"]

if atmospheric_model in ["windy", "forecast", "reanalysis", "ensemble"]:
env.atmospheric_model_init_date = data["atmospheric_model_init_date"]
env.atmospheric_model_end_date = data["atmospheric_model_end_date"]
env.atmospheric_model_interval = data["atmospheric_model_interval"]
env.atmospheric_model_init_lat = data["atmospheric_model_init_lat"]
env.atmospheric_model_end_lat = data["atmospheric_model_end_lat"]
env.atmospheric_model_init_lon = data["atmospheric_model_init_lon"]
env.atmospheric_model_end_lon = data["atmospheric_model_end_lon"]

Check warning on line 2817 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2811-L2817

Added lines #L2811 - L2817 were not covered by tests

if atmospheric_model == "ensemble":
env.level_ensemble = data["level_ensemble"]
env.height_ensemble = data["height_ensemble"]
env.temperature_ensemble = data["temperature_ensemble"]
env.wind_u_ensemble = data["wind_u_ensemble"]
env.wind_v_ensemble = data["wind_v_ensemble"]
env.wind_heading_ensemble = data["wind_heading_ensemble"]
env.wind_direction_ensemble = data["wind_direction_ensemble"]
env.wind_speed_ensemble = data["wind_speed_ensemble"]
env.num_ensemble_members = data["num_ensemble_members"]

Check warning on line 2828 in rocketpy/environment/environment.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/environment/environment.py#L2820-L2828

Added lines #L2820 - L2828 were not covered by tests

env.calculate_density_profile()
env.calculate_speed_of_sound_profile()
env.calculate_dynamic_viscosity()

return env


if __name__ == "__main__":
import doctest
Expand Down
5 changes: 4 additions & 1 deletion rocketpy/environment/environment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,10 @@ def __check_coordinates_inside_grid(
or lat_index > len(lat_array) - 1
):
raise ValueError(
f"Latitude and longitude pair {(self.latitude, self.longitude)} is outside the grid available in the given file, which is defined by {(lat_array[0], lon_array[0])} and {(lat_array[-1], lon_array[-1])}."
f"Latitude and longitude pair {(self.latitude, self.longitude)} "
"is outside the grid available in the given file, which "
f"is defined by {(lat_array[0], lon_array[0])} and "
f"{(lat_array[-1], lon_array[-1])}."
)

def __localize_input_dates(self):
Expand Down
19 changes: 8 additions & 11 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
carefully as it may impact all the rest of the project.
"""

import base64
import warnings
import zlib
from bisect import bisect_left
from collections.abc import Iterable
from copy import deepcopy
from functools import cached_property
from inspect import signature
from pathlib import Path

import dill
import matplotlib.pyplot as plt
import numpy as np
from scipy import integrate, linalg, optimize
Expand All @@ -25,6 +22,8 @@
RBFInterpolator,
)

from rocketpy._encoders import from_hex_decode, to_hex_encode

# Numpy 1.x compatibility,
# TODO: remove these lines when all dependencies support numpy>=2.0.0
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
Expand Down Expand Up @@ -712,9 +711,9 @@ def set_discrete(
if func.__dom_dim__ == 1:
xs = np.linspace(lower, upper, samples)
ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs)
func.set_source(np.concatenate(([xs], [ys])).transpose())
func.set_interpolation(interpolation)
func.set_extrapolation(extrapolation)
func.__interpolation__ = interpolation
func.__extrapolation__ = extrapolation
func.set_source(np.column_stack((xs, ys)))
elif func.__dom_dim__ == 2:
lower = 2 * [lower] if isinstance(lower, NUMERICAL_TYPES) else lower
upper = 2 * [upper] if isinstance(upper, NUMERICAL_TYPES) else upper
Expand Down Expand Up @@ -3390,7 +3389,7 @@ def __validate_extrapolation(self, extrapolation):
extrapolation = "natural"
return extrapolation

def to_dict(self):
def to_dict(self, _):
"""Serializes the Function instance to a dictionary.

Returns
Expand All @@ -3401,7 +3400,7 @@ def to_dict(self):
source = self.source

if callable(source):
source = zlib.compress(base64.b85encode(dill.dumps(source))).hex()
source = to_hex_encode(source)

return {
"source": source,
Expand All @@ -3423,9 +3422,7 @@ def from_dict(cls, func_dict):
"""
source = func_dict["source"]
if func_dict["interpolation"] is None and func_dict["extrapolation"] is None:
source = dill.loads(
base64.b85decode(zlib.decompress(bytes.fromhex(source)))
)
source = from_hex_decode(source)

return cls(
source=source,
Expand Down
Loading
Loading