Skip to content

Commit e75d25f

Browse files
committed
ENH: first implementation of CustomSampler
1 parent c674725 commit e75d25f

File tree

8 files changed

+335
-9
lines changed

8 files changed

+335
-9
lines changed

rocketpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from .sensors import Accelerometer, Barometer, GnssReceiver, Gyroscope
4545
from .simulation import Flight, MonteCarlo, MultivariateRejectionSampler
4646
from .stochastic import (
47+
CustomSampler,
4748
StochasticAirBrakes,
4849
StochasticEllipticalFins,
4950
StochasticEnvironment,

rocketpy/_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import numpy as np
88

99
from rocketpy.mathutils.function import Function
10-
from rocketpy.prints.flight_prints import _FlightPrints
1110
from rocketpy.plots.flight_plots import _FlightPlots
11+
from rocketpy.prints.flight_prints import _FlightPrints
1212

1313

1414
class RocketPyEncoder(json.JSONEncoder):

rocketpy/stochastic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
associated with each input parameter.
66
"""
77

8+
from .custom_sampler import CustomSampler
89
from .stochastic_aero_surfaces import (
910
StochasticAirBrakes,
1011
StochasticEllipticalFins,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
Provides an abstract class so that users can build custom samplers upon
3+
"""
4+
5+
from abc import ABC, abstractmethod
6+
7+
8+
class CustomSampler(ABC):
9+
"""Abstract subclass for user defined samplers"""
10+
11+
@abstractmethod
12+
def sample(self, n_samples=1):
13+
"""Generates n samples from the custom distribution
14+
15+
Parameters
16+
----------
17+
n_samples : int, optional
18+
Numbers of samples to be generated
19+
20+
Returns
21+
-------
22+
sample_list : list
23+
A list with n_samples elements, each of which is a valid sample
24+
"""
25+
26+
@abstractmethod
27+
def reset_seed(self, seed=None):
28+
"""Resets the seeds of all associated stochastic generators
29+
30+
Parameters
31+
----------
32+
seed : int, optional
33+
Seed for the random number generator. The default is None
34+
35+
Returns
36+
-------
37+
None
38+
"""

rocketpy/stochastic/stochastic_model.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99

1010
from rocketpy.mathutils.function import Function
11+
from rocketpy.stochastic.custom_sampler import CustomSampler
1112

1213
from ..tools import get_distribution
1314

@@ -96,6 +97,10 @@ def _set_stochastic(self, seed=None):
9697
attr_value = self._validate_list(input_name, input_value)
9798
elif isinstance(input_value, (int, float)):
9899
attr_value = self._validate_scalar(input_name, input_value)
100+
elif isinstance(input_value, CustomSampler):
101+
attr_value = self._validate_custom_sampler(
102+
input_name, input_value, seed
103+
)
99104
else:
100105
raise AssertionError(
101106
f"'{input_name}' must be a tuple, list, int, or float"
@@ -436,6 +441,33 @@ def _validate_positive_int_list(self, input_name, input_value):
436441
isinstance(member, int) and member >= 0 for member in input_value
437442
), f"`{input_name}` must be a list of positive integers"
438443

444+
def _validate_custom_sampler(self, input_name, sampler, seed=None):
445+
"""
446+
Validate a custom sampler.
447+
448+
Parameters
449+
----------
450+
input_name : str
451+
Name of the input argument.
452+
sampler : CustomSampler object
453+
Custom sampler provided by the user
454+
seed : int, optional
455+
Seed for the random number generator. The default is None
456+
457+
Raises
458+
------
459+
AssertionError
460+
If the input is not in a valid format.
461+
"""
462+
try:
463+
sampler.reset_seed(seed)
464+
except RuntimeError as e:
465+
raise RuntimeError(
466+
f"An error occurred in the 'reset_seed' of {input_name} CustomSampler"
467+
) from e
468+
469+
return sampler
470+
439471
def _validate_airfoil(self, airfoil):
440472
"""
441473
Validate airfoil input.
@@ -490,9 +522,17 @@ def dict_generator(self):
490522
generated_dict = {}
491523
for arg, value in self.__dict__.items():
492524
if isinstance(value, tuple):
493-
generated_dict[arg] = value[-1](value[0], value[1])
525+
dist_sampler = value[-1]
526+
generated_dict[arg] = dist_sampler(value[0], value[1])
494527
elif isinstance(value, list):
495528
generated_dict[arg] = choice(value) if value else value
529+
elif isinstance(value, CustomSampler):
530+
try:
531+
generated_dict[arg] = value.sample(n_samples=1)[0]
532+
except RuntimeError as e:
533+
raise RuntimeError(
534+
f"An error occurred in the 'sample' of {arg} CustomSampler"
535+
) from e
496536
self.last_rnd_dict = generated_dict
497537
yield generated_dict
498538

@@ -527,6 +567,12 @@ def format_attribute(attr, value):
527567
f"{nominal_value:.5f} ± "
528568
f"{std_dev:.5f} ({dist_func.__name__})"
529569
)
570+
elif isinstance(value, CustomSampler):
571+
sampler_name = type(value).__name__
572+
return (
573+
f"\t{attr.ljust(max_str_length)} "
574+
f"\t{sampler_name.ljust(max_str_length)} "
575+
)
530576
return None
531577

532578
attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
@@ -550,6 +596,9 @@ def format_attribute(attr, value):
550596
list_attributes = [
551597
attr for attr, val in items if isinstance(val, list) and len(val) > 1
552598
]
599+
custom_attributes = [
600+
attr for attr, val in items if isinstance(val, CustomSampler)
601+
]
553602

554603
if constant_attributes:
555604
report.append("\nConstant Attributes:")
@@ -568,5 +617,10 @@ def format_attribute(attr, value):
568617
report.extend(
569618
format_attribute(attr, attributes[attr]) for attr in list_attributes
570619
)
620+
if custom_attributes:
621+
report.append("\nStochastic Attributes with Custom user samplers:")
622+
report.extend(
623+
format_attribute(attr, attributes[attr]) for attr in custom_attributes
624+
)
571625

572626
print("\n".join(filter(None, report)))

rocketpy/utilities.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
import ast
22
import inspect
3-
import traceback
4-
import warnings
53
import json
64
import os
7-
8-
from pathlib import Path
9-
from importlib.metadata import version
5+
import traceback
6+
import warnings
107
from datetime import date
8+
from importlib.metadata import version
9+
from pathlib import Path
10+
1111
import matplotlib.pyplot as plt
1212
import numpy as np
1313
from scipy.integrate import solve_ivp
1414

15+
from ._encoders import RocketPyDecoder, RocketPyEncoder
1516
from .environment.environment import Environment
1617
from .mathutils.function import Function
1718
from .plots.plot_helpers import show_or_save_plot
1819
from .rocket.aero_surface import TrapezoidalFins
1920
from .simulation.flight import Flight
20-
from ._encoders import RocketPyEncoder, RocketPyDecoder
2121

2222

2323
def compute_cd_s_from_drop_test(

0 commit comments

Comments
 (0)