Skip to content

Commit ac4d3af

Browse files
authored
Merge pull request #702 from emtee14/enh/monte-carlo-callback
ENH: Callback function for collecting additional data from Monte Carlo sims
2 parents 5856353 + 6c477e3 commit ac4d3af

File tree

5 files changed

+226
-6
lines changed

5 files changed

+226
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Attention: The newest changes should be on top -->
3232

3333
### Added
3434

35+
- ENH: Callback function for collecting additional data from Monte Carlo sims [#702](https://github.com/RocketPy-Team/RocketPy/pull/702)
3536
- ENH: Implement optional plot saving [#597](https://github.com/RocketPy-Team/RocketPy/pull/597)
3637

3738
### Changed

docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,100 @@
11151115
" type=\"impact\",\n",
11161116
")"
11171117
]
1118+
},
1119+
{
1120+
"cell_type": "markdown",
1121+
"metadata": {},
1122+
"source": [
1123+
"## Custom exports using callback functions"
1124+
]
1125+
},
1126+
{
1127+
"cell_type": "markdown",
1128+
"metadata": {},
1129+
"source": [
1130+
"We have shown, so far, how to perform to use the `MonteCarlo` class and visualize its results. By default, some variables exported to the output files, such as *apogee* and *x_impact*. The `export_list` argument provides a simplified way for the user to export additional variables listed in the documentation, such as *inclination* and *heading*. \n",
1131+
"\n",
1132+
"There are applications in which you might need to extract more information in the results than the `export_list` argument can handle. To that end, the `MonteCarlo` class has a `data_collector` argument which allows you customize further the output of the simulation.\n",
1133+
"\n",
1134+
"To exemplify its use, we show how to export the *date* of the environment used in the simulation together with the *average reynolds number* along with the default variables."
1135+
]
1136+
},
1137+
{
1138+
"cell_type": "markdown",
1139+
"metadata": {},
1140+
"source": [
1141+
"We will use the `stochastic_env`, `stochastic_rocket` and `stochastic_flight` objects previously defined, and only change the `MonteCarlo` object. First, we need to define our customized data collector."
1142+
]
1143+
},
1144+
{
1145+
"cell_type": "code",
1146+
"execution_count": null,
1147+
"metadata": {},
1148+
"outputs": [],
1149+
"source": [
1150+
"import numpy as np\n",
1151+
"\n",
1152+
"\n",
1153+
"# Defining custom callback functions\n",
1154+
"def get_average_reynolds_number(flight):\n",
1155+
" reynold_number_list = flight.reynolds_number(flight.time)\n",
1156+
" average_reynolds_number = np.mean(reynold_number_list)\n",
1157+
" return average_reynolds_number\n",
1158+
"\n",
1159+
"\n",
1160+
"def get_date(flight):\n",
1161+
" return flight.env.date\n",
1162+
"\n",
1163+
"\n",
1164+
"custom_data_collector = {\n",
1165+
" \"average_reynolds_number\": get_average_reynolds_number,\n",
1166+
" \"date\": get_date,\n",
1167+
"}"
1168+
]
1169+
},
1170+
{
1171+
"cell_type": "markdown",
1172+
"metadata": {},
1173+
"source": [
1174+
"The `data_collector` must be a dictionary whose keys are the names of the variables we want to export and the values are callback functions (python callables) that compute these variable values. Notice how we can compute complex expressions in this function and just export the result. For instance, the *get_average_reynolds_number* calls the `flight.reynolds_number` method for each value in `flight.time` list and computes the average value using numpy's `mean`. The *date* variable is straightforward.\n",
1175+
"\n",
1176+
"After we define the data collector, we pass it as an argument to the `MonteCarlo` class."
1177+
]
1178+
},
1179+
{
1180+
"cell_type": "code",
1181+
"execution_count": null,
1182+
"metadata": {},
1183+
"outputs": [],
1184+
"source": [
1185+
"test_dispersion = MonteCarlo(\n",
1186+
" filename=\"monte_carlo_analysis_outputs/monte_carlo_class_example_customized\",\n",
1187+
" environment=stochastic_env,\n",
1188+
" rocket=stochastic_rocket,\n",
1189+
" flight=stochastic_flight,\n",
1190+
" export_list=[\"apogee\", \"apogee_time\", \"x_impact\"],\n",
1191+
" data_collector=custom_data_collector,\n",
1192+
")"
1193+
]
1194+
},
1195+
{
1196+
"cell_type": "code",
1197+
"execution_count": null,
1198+
"metadata": {},
1199+
"outputs": [],
1200+
"source": [
1201+
"test_dispersion.simulate(number_of_simulations=10, append=False)"
1202+
]
1203+
},
1204+
{
1205+
"cell_type": "code",
1206+
"execution_count": null,
1207+
"metadata": {},
1208+
"outputs": [],
1209+
"source": [
1210+
"test_dispersion.prints.all()"
1211+
]
11181212
}
11191213
],
11201214
"metadata": {
@@ -1134,7 +1228,7 @@
11341228
"name": "python",
11351229
"nbconvert_exporter": "python",
11361230
"pygments_lexer": "ipython3",
1137-
"version": "3.10.11"
1231+
"version": "3.11.2"
11381232
}
11391233
},
11401234
"nbformat": 4,

rocketpy/prints/monte_carlo_prints.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@ def all(self):
2424
print(f"{'Parameter':>25} {'Mean':>15} {'Std. Dev.':>15}")
2525
print("-" * 60)
2626
for key, value in self.monte_carlo.processed_results.items():
27-
print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}")
27+
try:
28+
print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}")
29+
except TypeError:
30+
print(f"{key:>25} {str(value[0]):>15} {str(value[1]):>15}")

rocketpy/simulation/monte_carlo.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class MonteCarlo:
4848
The stochastic flight object to be iterated over.
4949
export_list : list
5050
The list of variables to export at each simulation.
51+
data_collector : dict
52+
A dictionary whose keys are the names of the additional
53+
exported variables and the values are callback functions.
5154
inputs_log : list
5255
List of dictionaries with the inputs used in each simulation.
5356
outputs_log : list
@@ -80,7 +83,13 @@ class MonteCarlo:
8083
"""
8184

8285
def __init__(
83-
self, filename, environment, rocket, flight, export_list=None
86+
self,
87+
filename,
88+
environment,
89+
rocket,
90+
flight,
91+
export_list=None,
92+
data_collector=None,
8493
): # pylint: disable=too-many-statements
8594
"""
8695
Initialize a MonteCarlo object.
@@ -104,6 +113,17 @@ def __init__(
104113
`out_of_rail_stability_margin`, `out_of_rail_time`,
105114
`out_of_rail_velocity`, `max_mach_number`, `frontal_surface_wind`,
106115
`lateral_surface_wind`. Default is None.
116+
data_collector : dict, optional
117+
A dictionary whose keys are the names of the exported variables
118+
and the values are callback functions. The keys (variable names) must not
119+
overwrite the default names on 'export_list'. The callback functions receive
120+
a Flight object and returns a value of that variable. For instance
121+
122+
.. code-block:: python
123+
custom_data_collector = {
124+
"max_acceleration": lambda flight: max(flight.acceleration(flight.time)),
125+
"date": lambda flight: flight.env.date,
126+
}
107127
108128
Returns
109129
-------
@@ -132,6 +152,8 @@ def __init__(
132152
self._last_print_len = 0 # used to print on the same line
133153

134154
self.export_list = self.__check_export_list(export_list)
155+
self._check_data_collector(data_collector)
156+
self.data_collector = data_collector
135157

136158
try:
137159
self.import_inputs()
@@ -359,6 +381,17 @@ def __export_flight_data(
359381
for export_item in self.export_list
360382
}
361383

384+
if self.data_collector is not None:
385+
additional_exports = {}
386+
for key, callback in self.data_collector.items():
387+
try:
388+
additional_exports[key] = callback(flight)
389+
except Exception as e:
390+
raise ValueError(
391+
f"An error was encountered running 'data_collector' callback {key}. "
392+
) from e
393+
results = results | additional_exports
394+
362395
input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
363396
output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n")
364397

@@ -466,6 +499,37 @@ def __check_export_list(self, export_list):
466499

467500
return export_list
468501

502+
def _check_data_collector(self, data_collector):
503+
"""Check if data collector provided is a valid
504+
505+
Parameters
506+
----------
507+
data_collector : dict
508+
A dictionary whose keys are the names of the exported variables
509+
and the values are callback functions that receive a Flight object
510+
and returns a value of that variable
511+
"""
512+
513+
if data_collector is not None:
514+
515+
if not isinstance(data_collector, dict):
516+
raise ValueError(
517+
"Invalid 'data_collector' argument! "
518+
"It must be a dict of callback functions."
519+
)
520+
521+
for key, callback in data_collector.items():
522+
if key in self.export_list:
523+
raise ValueError(
524+
"Invalid 'data_collector' key! "
525+
f"Variable names overwrites 'export_list' key '{key}'."
526+
)
527+
if not callable(callback):
528+
raise ValueError(
529+
f"Invalid value in 'data_collector' for key '{key}'! "
530+
"Values must be python callables (callback functions)."
531+
)
532+
469533
def __reprint(self, msg, end="\n", flush=False):
470534
"""
471535
Prints a message on the same line as the previous one and replaces the
@@ -654,9 +718,12 @@ def set_processed_results(self):
654718
"""
655719
self.processed_results = {}
656720
for result, values in self.results.items():
657-
mean = np.mean(values)
658-
stdev = np.std(values)
659-
self.processed_results[result] = (mean, stdev)
721+
try:
722+
mean = np.mean(values)
723+
stdev = np.std(values)
724+
self.processed_results[result] = (mean, stdev)
725+
except TypeError:
726+
self.processed_results[result] = (None, None)
660727

661728
# Import methods
662729

tests/integration/test_monte_carlo.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,58 @@ def test_monte_carlo_export_ellipses_to_kml(monte_carlo_calisto_pre_loaded):
111111
)
112112

113113
os.remove("monte_carlo_class_example.kml")
114+
115+
116+
@pytest.mark.slow
117+
def test_monte_carlo_callback(monte_carlo_calisto):
118+
"""Tests the data_collector argument of the MonteCarlo class.
119+
120+
Parameters
121+
----------
122+
monte_carlo_calisto : MonteCarlo
123+
The MonteCarlo object, this is a pytest fixture.
124+
"""
125+
126+
# define valid data collector
127+
valid_data_collector = {
128+
"name": lambda flight: flight.name,
129+
"density_t0": lambda flight: flight.env.density(0),
130+
}
131+
132+
monte_carlo_calisto.data_collector = valid_data_collector
133+
# NOTE: this is really slow, it runs 10 flight simulations
134+
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)
135+
136+
# tests if print works when we have None in summary
137+
monte_carlo_calisto.info()
138+
139+
## tests if an error is raised for invalid data_collector definitions
140+
# invalid type
141+
def invalid_data_collector(flight):
142+
return flight.name
143+
144+
with pytest.raises(ValueError):
145+
monte_carlo_calisto._check_data_collector(invalid_data_collector)
146+
147+
# invalid key overwrite
148+
invalid_data_collector = {"apogee": lambda flight: flight.apogee}
149+
with pytest.raises(ValueError):
150+
monte_carlo_calisto._check_data_collector(invalid_data_collector)
151+
152+
# invalid callback definition
153+
invalid_data_collector = {"name": "Calisto"} # callbacks must be callables!
154+
with pytest.raises(ValueError):
155+
monte_carlo_calisto._check_data_collector(invalid_data_collector)
156+
157+
# invalid logic (division by zero)
158+
invalid_data_collector = {
159+
"density_t0": lambda flight: flight.env.density(0) / "0",
160+
}
161+
monte_carlo_calisto.data_collector = invalid_data_collector
162+
# NOTE: this is really slow, it runs 10 flight simulations
163+
with pytest.raises(ValueError):
164+
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)
165+
166+
os.remove("monte_carlo_test.errors.txt")
167+
os.remove("monte_carlo_test.outputs.txt")
168+
os.remove("monte_carlo_test.inputs.txt")

0 commit comments

Comments
 (0)