diff --git a/CHANGELOG.md b/CHANGELOG.md index d4d083a77..90d73549b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ Attention: The newest changes should be on top --> ### Added +- ENH: Callback function for collecting additional data from Monte Carlo sims [#702](https://github.com/RocketPy-Team/RocketPy/pull/702) - ENH: Implement optional plot saving [#597](https://github.com/RocketPy-Team/RocketPy/pull/597) ### Changed diff --git a/docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb b/docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb index 3886f72c4..2f94a0b16 100644 --- a/docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb +++ b/docs/notebooks/monte_carlo_analysis/monte_carlo_class_usage.ipynb @@ -1115,6 +1115,100 @@ " type=\"impact\",\n", ")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom exports using callback functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have shown, so far, how to perform to use the `MonteCarlo` class and visualize its results. By default, some variables exported to the output files, such as *apogee* and *x_impact*. The `export_list` argument provides a simplified way for the user to export additional variables listed in the documentation, such as *inclination* and *heading*. \n", + "\n", + "There are applications in which you might need to extract more information in the results than the `export_list` argument can handle. To that end, the `MonteCarlo` class has a `data_collector` argument which allows you customize further the output of the simulation.\n", + "\n", + "To exemplify its use, we show how to export the *date* of the environment used in the simulation together with the *average reynolds number* along with the default variables." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use the `stochastic_env`, `stochastic_rocket` and `stochastic_flight` objects previously defined, and only change the `MonteCarlo` object. First, we need to define our customized data collector." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "\n", + "# Defining custom callback functions\n", + "def get_average_reynolds_number(flight):\n", + " reynold_number_list = flight.reynolds_number(flight.time)\n", + " average_reynolds_number = np.mean(reynold_number_list)\n", + " return average_reynolds_number\n", + "\n", + "\n", + "def get_date(flight):\n", + " return flight.env.date\n", + "\n", + "\n", + "custom_data_collector = {\n", + " \"average_reynolds_number\": get_average_reynolds_number,\n", + " \"date\": get_date,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `data_collector` must be a dictionary whose keys are the names of the variables we want to export and the values are callback functions (python callables) that compute these variable values. Notice how we can compute complex expressions in this function and just export the result. For instance, the *get_average_reynolds_number* calls the `flight.reynolds_number` method for each value in `flight.time` list and computes the average value using numpy's `mean`. The *date* variable is straightforward.\n", + "\n", + "After we define the data collector, we pass it as an argument to the `MonteCarlo` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_dispersion = MonteCarlo(\n", + " filename=\"monte_carlo_analysis_outputs/monte_carlo_class_example_customized\",\n", + " environment=stochastic_env,\n", + " rocket=stochastic_rocket,\n", + " flight=stochastic_flight,\n", + " export_list=[\"apogee\", \"apogee_time\", \"x_impact\"],\n", + " data_collector=custom_data_collector,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_dispersion.simulate(number_of_simulations=10, append=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_dispersion.prints.all()" + ] } ], "metadata": { @@ -1134,7 +1228,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.11.2" } }, "nbformat": 4, diff --git a/rocketpy/prints/monte_carlo_prints.py b/rocketpy/prints/monte_carlo_prints.py index 6249626ce..dc7cc1265 100644 --- a/rocketpy/prints/monte_carlo_prints.py +++ b/rocketpy/prints/monte_carlo_prints.py @@ -24,4 +24,7 @@ def all(self): print(f"{'Parameter':>25} {'Mean':>15} {'Std. Dev.':>15}") print("-" * 60) for key, value in self.monte_carlo.processed_results.items(): - print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}") + try: + print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}") + except TypeError: + print(f"{key:>25} {str(value[0]):>15} {str(value[1]):>15}") diff --git a/rocketpy/simulation/monte_carlo.py b/rocketpy/simulation/monte_carlo.py index 4ed051a4f..70584838a 100644 --- a/rocketpy/simulation/monte_carlo.py +++ b/rocketpy/simulation/monte_carlo.py @@ -48,6 +48,9 @@ class MonteCarlo: The stochastic flight object to be iterated over. export_list : list The list of variables to export at each simulation. + data_collector : dict + A dictionary whose keys are the names of the additional + exported variables and the values are callback functions. inputs_log : list List of dictionaries with the inputs used in each simulation. outputs_log : list @@ -80,7 +83,13 @@ class MonteCarlo: """ def __init__( - self, filename, environment, rocket, flight, export_list=None + self, + filename, + environment, + rocket, + flight, + export_list=None, + data_collector=None, ): # pylint: disable=too-many-statements """ Initialize a MonteCarlo object. @@ -104,6 +113,17 @@ def __init__( `out_of_rail_stability_margin`, `out_of_rail_time`, `out_of_rail_velocity`, `max_mach_number`, `frontal_surface_wind`, `lateral_surface_wind`. Default is None. + data_collector : dict, optional + A dictionary whose keys are the names of the exported variables + and the values are callback functions. The keys (variable names) must not + overwrite the default names on 'export_list'. The callback functions receive + a Flight object and returns a value of that variable. For instance + + .. code-block:: python + custom_data_collector = { + "max_acceleration": lambda flight: max(flight.acceleration(flight.time)), + "date": lambda flight: flight.env.date, + } Returns ------- @@ -132,6 +152,8 @@ def __init__( self._last_print_len = 0 # used to print on the same line self.export_list = self.__check_export_list(export_list) + self._check_data_collector(data_collector) + self.data_collector = data_collector try: self.import_inputs() @@ -359,6 +381,17 @@ def __export_flight_data( for export_item in self.export_list } + if self.data_collector is not None: + additional_exports = {} + for key, callback in self.data_collector.items(): + try: + additional_exports[key] = callback(flight) + except Exception as e: + raise ValueError( + f"An error was encountered running 'data_collector' callback {key}. " + ) from e + results = results | additional_exports + input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n") output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n") @@ -466,6 +499,37 @@ def __check_export_list(self, export_list): return export_list + def _check_data_collector(self, data_collector): + """Check if data collector provided is a valid + + Parameters + ---------- + data_collector : dict + A dictionary whose keys are the names of the exported variables + and the values are callback functions that receive a Flight object + and returns a value of that variable + """ + + if data_collector is not None: + + if not isinstance(data_collector, dict): + raise ValueError( + "Invalid 'data_collector' argument! " + "It must be a dict of callback functions." + ) + + for key, callback in data_collector.items(): + if key in self.export_list: + raise ValueError( + "Invalid 'data_collector' key! " + f"Variable names overwrites 'export_list' key '{key}'." + ) + if not callable(callback): + raise ValueError( + f"Invalid value in 'data_collector' for key '{key}'! " + "Values must be python callables (callback functions)." + ) + def __reprint(self, msg, end="\n", flush=False): """ Prints a message on the same line as the previous one and replaces the @@ -654,9 +718,12 @@ def set_processed_results(self): """ self.processed_results = {} for result, values in self.results.items(): - mean = np.mean(values) - stdev = np.std(values) - self.processed_results[result] = (mean, stdev) + try: + mean = np.mean(values) + stdev = np.std(values) + self.processed_results[result] = (mean, stdev) + except TypeError: + self.processed_results[result] = (None, None) # Import methods diff --git a/tests/integration/test_monte_carlo.py b/tests/integration/test_monte_carlo.py index b5caddbc8..51d8bfae9 100644 --- a/tests/integration/test_monte_carlo.py +++ b/tests/integration/test_monte_carlo.py @@ -111,3 +111,58 @@ def test_monte_carlo_export_ellipses_to_kml(monte_carlo_calisto_pre_loaded): ) os.remove("monte_carlo_class_example.kml") + + +@pytest.mark.slow +def test_monte_carlo_callback(monte_carlo_calisto): + """Tests the data_collector argument of the MonteCarlo class. + + Parameters + ---------- + monte_carlo_calisto : MonteCarlo + The MonteCarlo object, this is a pytest fixture. + """ + + # define valid data collector + valid_data_collector = { + "name": lambda flight: flight.name, + "density_t0": lambda flight: flight.env.density(0), + } + + monte_carlo_calisto.data_collector = valid_data_collector + # NOTE: this is really slow, it runs 10 flight simulations + monte_carlo_calisto.simulate(number_of_simulations=10, append=False) + + # tests if print works when we have None in summary + monte_carlo_calisto.info() + + ## tests if an error is raised for invalid data_collector definitions + # invalid type + def invalid_data_collector(flight): + return flight.name + + with pytest.raises(ValueError): + monte_carlo_calisto._check_data_collector(invalid_data_collector) + + # invalid key overwrite + invalid_data_collector = {"apogee": lambda flight: flight.apogee} + with pytest.raises(ValueError): + monte_carlo_calisto._check_data_collector(invalid_data_collector) + + # invalid callback definition + invalid_data_collector = {"name": "Calisto"} # callbacks must be callables! + with pytest.raises(ValueError): + monte_carlo_calisto._check_data_collector(invalid_data_collector) + + # invalid logic (division by zero) + invalid_data_collector = { + "density_t0": lambda flight: flight.env.density(0) / "0", + } + monte_carlo_calisto.data_collector = invalid_data_collector + # NOTE: this is really slow, it runs 10 flight simulations + with pytest.raises(ValueError): + monte_carlo_calisto.simulate(number_of_simulations=10, append=False) + + os.remove("monte_carlo_test.errors.txt") + os.remove("monte_carlo_test.outputs.txt") + os.remove("monte_carlo_test.inputs.txt")