Skip to content

Commit 97f8146

Browse files
author
Morgan Thomas
committed
Inital implementation of callback function with some quick fixes of existing functions for compatability
1 parent 054f893 commit 97f8146

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

rocketpy/simulation/monte_carlo.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ class MonteCarlo:
8080
"""
8181

8282
def __init__(
83-
self, filename, environment, rocket, flight, export_list=None
83+
self, filename, environment, rocket, flight, export_list=None,
84+
export_function=None
8485
): # pylint: disable=too-many-statements
8586
"""
8687
Initialize a MonteCarlo object.
@@ -104,6 +105,11 @@ def __init__(
104105
`out_of_rail_stability_margin`, `out_of_rail_time`,
105106
`out_of_rail_velocity`, `max_mach_number`, `frontal_surface_wind`,
106107
`lateral_surface_wind`. Default is None.
108+
export_function : callable, optional
109+
A function which gets called at the end of a simulation to collect
110+
additional data to be exported that isn't pre-defined. Takes the
111+
Flight object as an argument and returns a dictionary. Default is None.
112+
107113
108114
Returns
109115
-------
@@ -132,6 +138,7 @@ def __init__(
132138
self._last_print_len = 0 # used to print on the same line
133139

134140
self.export_list = self.__check_export_list(export_list)
141+
self.export_function = export_function
135142

136143
try:
137144
self.import_inputs()
@@ -359,6 +366,13 @@ def __export_flight_data(
359366
for export_item in self.export_list
360367
}
361368

369+
if self.export_function is not None:
370+
additional_exports = self.export_function(flight)
371+
for key in additional_exports.keys():
372+
if key in self.export_list:
373+
raise ValueError(f"Invalid export function, returns dict which overwrites key, '{key}'")
374+
results = results | additional_exports
375+
362376
input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
363377
output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n")
364378

@@ -654,9 +668,12 @@ def set_processed_results(self):
654668
"""
655669
self.processed_results = {}
656670
for result, values in self.results.items():
657-
mean = np.mean(values)
658-
stdev = np.std(values)
659-
self.processed_results[result] = (mean, stdev)
671+
try:
672+
mean = np.mean(values)
673+
stdev = np.std(values)
674+
self.processed_results[result] = (mean, stdev)
675+
except TypeError:
676+
self.processed_results[result] = (None, None)
660677

661678
# Import methods
662679

0 commit comments

Comments
 (0)