Skip to content

Commit 7834947

Browse files
committed
MNT: Allow for encoding customization of MonteCarlo.
1 parent 715338e commit 7834947

File tree

11 files changed

+2206
-1797
lines changed

11 files changed

+2206
-1797
lines changed

docs/notebooks/monte_carlo_analysis/monte_carlo_analysis_outputs/monte_carlo_class_example.inputs.txt

Lines changed: 1000 additions & 819 deletions
Large diffs are not rendered by default.

docs/notebooks/monte_carlo_analysis/monte_carlo_analysis_outputs/monte_carlo_class_example.kml

Lines changed: 30 additions & 30 deletions
Large diffs are not rendered by default.

docs/notebooks/monte_carlo_analysis/monte_carlo_analysis_outputs/monte_carlo_class_example.outputs.txt

Lines changed: 1000 additions & 819 deletions
Large diffs are not rendered by default.

docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb

Lines changed: 90 additions & 77 deletions
Large diffs are not rendered by default.

rocketpy/_encoders.py

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
"""Defines a custom JSON encoder for RocketPy objects."""
22

3-
import base64
43
import json
54
from datetime import datetime
65
from importlib import import_module
76

8-
import dill
97
import numpy as np
108

9+
from rocketpy.mathutils.function import Function
10+
1111

1212
class RocketPyEncoder(json.JSONEncoder):
1313
"""Custom JSON encoder for RocketPy objects. It defines how to encode
1414
different types of objects to a JSON supported format."""
1515

1616
def __init__(self, *args, **kwargs):
1717
self.include_outputs = kwargs.pop("include_outputs", False)
18+
self.include_function_data = kwargs.pop("include_function_data", True)
1819
super().__init__(*args, **kwargs)
1920

2021
def default(self, o):
@@ -43,6 +44,13 @@ def default(self, o):
4344
return [o.year, o.month, o.day, o.hour]
4445
elif hasattr(o, "__iter__") and not isinstance(o, str):
4546
return list(o)
47+
elif isinstance(o, Function):
48+
if not self.include_function_data:
49+
return str(o)
50+
else:
51+
encoding = o.to_dict(self.include_outputs)
52+
encoding["signature"] = get_class_signature(o)
53+
return encoding
4654
elif hasattr(o, "to_dict"):
4755
encoding = o.to_dict(self.include_outputs)
4856
encoding = remove_circular_references(encoding)
@@ -155,39 +163,3 @@ def remove_circular_references(obj_dict):
155163
obj_dict.pop("plots", None)
156164

157165
return obj_dict
158-
159-
160-
def to_hex_encode(obj, encoder=base64.b85encode):
161-
"""Converts an object to hex representation using dill.
162-
163-
Parameters
164-
----------
165-
obj : object
166-
Object to be converted to hex.
167-
encoder : callable, optional
168-
Function to encode the bytes. Default is base64.b85encode.
169-
170-
Returns
171-
-------
172-
bytes
173-
Object converted to bytes.
174-
"""
175-
return encoder(dill.dumps(obj)).hex()
176-
177-
178-
def from_hex_decode(obj_bytes, decoder=base64.b85decode):
179-
"""Converts an object from hex representation using dill.
180-
181-
Parameters
182-
----------
183-
obj_bytes : str
184-
Hex string to be converted to object.
185-
decoder : callable, optional
186-
Function to decode the bytes. Default is base64.b85decode.
187-
188-
Returns
189-
-------
190-
object
191-
Object converted from bytes.
192-
"""
193-
return dill.loads(decoder(bytes.fromhex(obj_bytes)))

rocketpy/mathutils/function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
RBFInterpolator,
2323
)
2424

25-
from ..plots.plot_helpers import show_or_save_plot
25+
from rocketpy.tools import from_hex_decode, to_hex_encode
2626

27-
from rocketpy._encoders import from_hex_decode, to_hex_encode
27+
from ..plots.plot_helpers import show_or_save_plot
2828

2929
# Numpy 1.x compatibility,
3030
# TODO: remove these lines when all dependencies support numpy>=2.0.0

rocketpy/rocket/parachute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from rocketpy._encoders import from_hex_decode, to_hex_encode
5+
from rocketpy.tools import from_hex_decode, to_hex_encode
66

77
from ..mathutils.function import Function
88
from ..prints.parachute_prints import _ParachutePrints

rocketpy/simulation/monte_carlo.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def __init__(
173173

174174
# pylint: disable=consider-using-with
175175
def simulate(
176-
self, number_of_simulations, append=False
176+
self, number_of_simulations, append=False, **kwargs
177177
): # pylint: disable=too-many-statements
178178
"""
179179
Runs the Monte Carlo simulation and saves all data.
@@ -185,6 +185,17 @@ def simulate(
185185
append : bool, optional
186186
If True, the results will be appended to the existing files. If
187187
False, the files will be overwritten. Default is False.
188+
kwargs : dict
189+
Custom arguments for simulation export of the ``inputs`` file. Options
190+
are:
191+
192+
* ``include_outputs``: whether to also include outputs data of the
193+
simulation. Default is ``False``.
194+
195+
* ``include_function_data``: whether to include ``rocketpy.Function``
196+
results into the export. Default is ``True``.
197+
198+
See ``rocketpy._encoders.RocketPyEncoder`` for more information.
188199
189200
Returns
190201
-------
@@ -204,6 +215,7 @@ def simulate(
204215
overwritten. Make sure to save the files with the results before
205216
running the simulation again with `append=False`.
206217
"""
218+
self._export_config = kwargs
207219
# Create data files for inputs, outputs and error logging
208220
open_mode = "a" if append else "w"
209221
input_file = open(self._input_file, open_mode, encoding="utf-8")
@@ -224,11 +236,21 @@ def simulate(
224236
self.__run_single_simulation(input_file, output_file)
225237
except KeyboardInterrupt:
226238
print("Keyboard Interrupt, files saved.")
227-
error_file.write(json.dumps(self._inputs_dict, cls=RocketPyEncoder) + "\n")
239+
error_file.write(
240+
json.dumps(
241+
self._inputs_dict, cls=RocketPyEncoder, **self._export_config
242+
)
243+
+ "\n"
244+
)
228245
self.__close_files(input_file, output_file, error_file)
229246
except Exception as error:
230247
print(f"Error on iteration {self.__iteration_count}: {error}")
231-
error_file.write(json.dumps(self._inputs_dict, cls=RocketPyEncoder) + "\n")
248+
error_file.write(
249+
json.dumps(
250+
self._inputs_dict, cls=RocketPyEncoder, **self._export_config
251+
)
252+
+ "\n"
253+
)
232254
self.__close_files(input_file, output_file, error_file)
233255
raise error
234256
finally:
@@ -393,8 +415,12 @@ def __export_flight_data(
393415
) from e
394416
results = results | additional_exports
395417

396-
input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
397-
output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n")
418+
input_file.write(
419+
json.dumps(inputs_dict, cls=RocketPyEncoder, **self._export_config) + "\n"
420+
)
421+
output_file.write(
422+
json.dumps(results, cls=RocketPyEncoder, **self._export_config) + "\n"
423+
)
398424

399425
def __check_export_list(self, export_list):
400426
"""

rocketpy/tools.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
between minor versions if necessary, although this will be always avoided.
77
"""
88

9+
import base64
910
import functools
1011
import importlib
1112
import importlib.metadata
@@ -15,6 +16,7 @@
1516
import time
1617
from bisect import bisect_left
1718

19+
import dill
1820
import matplotlib.pyplot as plt
1921
import numpy as np
2022
import pytz
@@ -1167,6 +1169,42 @@ def get_matplotlib_supported_file_endings():
11671169
return filetypes
11681170

11691171

1172+
def to_hex_encode(obj, encoder=base64.b85encode):
1173+
"""Converts an object to hex representation using dill.
1174+
1175+
Parameters
1176+
----------
1177+
obj : object
1178+
Object to be converted to hex.
1179+
encoder : callable, optional
1180+
Function to encode the bytes. Default is base64.b85encode.
1181+
1182+
Returns
1183+
-------
1184+
bytes
1185+
Object converted to bytes.
1186+
"""
1187+
return encoder(dill.dumps(obj)).hex()
1188+
1189+
1190+
def from_hex_decode(obj_bytes, decoder=base64.b85decode):
1191+
"""Converts an object from hex representation using dill.
1192+
1193+
Parameters
1194+
----------
1195+
obj_bytes : str
1196+
Hex string to be converted to object.
1197+
decoder : callable, optional
1198+
Function to decode the bytes. Default is base64.b85decode.
1199+
1200+
Returns
1201+
-------
1202+
object
1203+
Object converted from bytes.
1204+
"""
1205+
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
1206+
1207+
11701208
if __name__ == "__main__":
11711209
import doctest
11721210

tests/fixtures/motor/tanks_fixtures.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
LevelBasedTank,
99
MassBasedTank,
1010
MassFlowRateBasedTank,
11+
SphericalTank,
1112
TankGeometry,
1213
UllageBasedTank,
1314
)
@@ -430,9 +431,7 @@ def oxidizer_tank(oxidizer_fluid, oxidizer_pressurant, propellant_tank_geometry)
430431

431432

432433
@pytest.fixture
433-
def spherical_oxidizer_tank(
434-
oxidizer_fluid, oxidizer_pressurant, spherical_oxidizer_geometry
435-
):
434+
def spherical_oxidizer_tank(oxidizer_fluid, oxidizer_pressurant):
436435
"""An example of a oxidizer spherical tank.
437436
438437
Parameters
@@ -447,12 +446,11 @@ def spherical_oxidizer_tank(
447446
-------
448447
rocketpy.LevelBasedTank
449448
"""
450-
geometry = SphericalTank(0.051)
451449
liquid_level = Function(lambda t: 0.1 * np.exp(-t / 2) - 0.05)
452450
oxidizer_tank = LevelBasedTank(
453451
name="Lox Tank",
454452
flux_time=10,
455-
geometry=spherical_oxidizer_geometry,
453+
geometry=SphericalTank(0.0501),
456454
liquid=oxidizer_fluid,
457455
gas=oxidizer_pressurant,
458456
liquid_height=liquid_level,

0 commit comments

Comments
 (0)