Skip to content

Commit 6f7104c

Browse files
committed
TST: implementing tests for CustomSampler and making small corrections
1 parent e75d25f commit 6f7104c

File tree

9 files changed

+144
-239
lines changed

9 files changed

+144
-239
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"Bigl",
4949
"Bigr",
5050
"bijective",
51+
"Bivariate",
5152
"bmatrix",
5253
"boldsymbol",
5354
"boxplot",

rocketpy/stochastic/custom_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class CustomSampler(ABC):
1010

1111
@abstractmethod
1212
def sample(self, n_samples=1):
13-
"""Generates n samples from the custom distribution
13+
"""Generates samples from the custom distribution
1414
1515
Parameters
1616
----------
@@ -19,7 +19,7 @@ def sample(self, n_samples=1):
1919
2020
Returns
2121
-------
22-
sample_list : list
22+
samples_list : list
2323
A list with n_samples elements, each of which is a valid sample
2424
"""
2525

rocketpy/stochastic/stochastic_model.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def _set_stochastic(self, seed=None):
8989
attr_value = None
9090
if input_value is not None:
9191
if "factor" in input_name:
92-
attr_value = self._validate_factors(input_name, input_value)
92+
attr_value = self._validate_factors(
93+
input_name, input_value, seed
94+
)
9395
elif input_name not in self.exception_list:
9496
if isinstance(input_value, tuple):
9597
attr_value = self._validate_tuple(input_name, input_value)
@@ -104,6 +106,7 @@ def _set_stochastic(self, seed=None):
104106
else:
105107
raise AssertionError(
106108
f"'{input_name}' must be a tuple, list, int, or float"
109+
"or a custom sampler"
107110
)
108111
else:
109112
attr_value = [getattr(self.obj, input_name)]
@@ -285,7 +288,7 @@ def _validate_scalar(self, input_name, input_value, getattr=getattr): # pylint:
285288
get_distribution("normal", self.__random_number_generator),
286289
)
287290

288-
def _validate_factors(self, input_name, input_value):
291+
def _validate_factors(self, input_name, input_value, seed):
289292
"""
290293
Validate factor arguments.
291294
@@ -313,8 +316,12 @@ def _validate_factors(self, input_name, input_value):
313316
return self._validate_tuple_factor(input_name, input_value)
314317
elif isinstance(input_value, list):
315318
return self._validate_list_factor(input_name, input_value)
319+
elif isinstance(input_value, CustomSampler):
320+
return self._validate_custom_sampler(input_name, input_value, seed)
316321
else:
317-
raise AssertionError(f"`{input_name}`: must be either a tuple or list")
322+
raise AssertionError(
323+
f"`{input_name}`: must be either a tuple or listor a custom sampler"
324+
)
318325

319326
def _validate_tuple_factor(self, input_name, factor_tuple):
320327
"""
@@ -463,7 +470,7 @@ def _validate_custom_sampler(self, input_name, sampler, seed=None):
463470
sampler.reset_seed(seed)
464471
except RuntimeError as e:
465472
raise RuntimeError(
466-
f"An error occurred in the 'reset_seed' of {input_name} CustomSampler"
473+
f"An error occurred in the 'reset_seed' method of {input_name} CustomSampler"
467474
) from e
468475

469476
return sampler
@@ -531,7 +538,7 @@ def dict_generator(self):
531538
generated_dict[arg] = value.sample(n_samples=1)[0]
532539
except RuntimeError as e:
533540
raise RuntimeError(
534-
f"An error occurred in the 'sample' of {arg} CustomSampler"
541+
f"An error occurred in the 'sample' method of {arg} CustomSampler"
535542
) from e
536543
self.last_rnd_dict = generated_dict
537544
yield generated_dict

test_custom_sampler.ipynb

Lines changed: 0 additions & 232 deletions
This file was deleted.

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"tests.fixtures.units.numerical_fixtures",
1919
"tests.fixtures.monte_carlo.monte_carlo_fixtures",
2020
"tests.fixtures.monte_carlo.stochastic_fixtures",
21+
"tests.fixtures.monte_carlo.custom_sampler_fixtures",
2122
"tests.fixtures.monte_carlo.stochastic_motors_fixtures",
2223
"tests.fixtures.sensors.sensors_fixtures",
2324
"tests.fixtures.generic_surfaces.generic_surfaces_fixtures",
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""This file contains fixtures of CustomSampler used in stochastic classes."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from rocketpy import CustomSampler
7+
8+
9+
@pytest.fixture
10+
def elevation_sampler():
11+
"""Fixture to create mixture of two gaussian sampler"""
12+
means_tuple = [1400, 1500]
13+
sd_tuple = [40, 50]
14+
prob_tuple = [0.4, 0.6]
15+
return TwoGaussianMixture(means_tuple, sd_tuple, prob_tuple)
16+
17+
18+
class TwoGaussianMixture(CustomSampler):
19+
"""Class to sample from a mixture of two Gaussian distributions"""
20+
21+
def __init__(self, means_tuple, sd_tuple, prob_tuple, seed=None):
22+
"""Creates a sampler for a mixture of two Gaussian distributions
23+
24+
Parameters
25+
----------
26+
means_tuple : 2-tuple
27+
2-Tuple that contains the mean of each normal distribution of the
28+
mixture
29+
sd_tuple : 2-tuple
30+
2-Tuple that contains the sd of each normal distribution of the
31+
mixture
32+
prob_tuple : 2-tuple
33+
2-Tuple that contains the probability of each normal distribution of the
34+
mixture. Its entries should be non-negative and sum up to 1.
35+
"""
36+
np.random.default_rng(seed)
37+
self.means_tuple = means_tuple
38+
self.sd_tuple = sd_tuple
39+
self.prob_tuple = prob_tuple
40+
41+
def sample(self, n_samples=1):
42+
"""Samples from a mixture of two Gaussian
43+
44+
Parameters
45+
----------
46+
n_samples : int, optional
47+
Number of samples, by default 1
48+
49+
Returns
50+
-------
51+
samples_list
52+
List containing n_samples samples
53+
"""
54+
samples_list = [0] * n_samples
55+
mixture_id_list = np.random.binomial(1, self.prob_tuple[0], n_samples)
56+
for i, mixture_id in enumerate(mixture_id_list):
57+
if mixture_id:
58+
samples_list[i] = np.random.normal(
59+
self.means_tuple[0], self.sd_tuple[0]
60+
)
61+
else:
62+
samples_list[i] = np.random.normal(
63+
self.means_tuple[1], self.sd_tuple[1]
64+
)
65+
66+
return samples_list
67+
68+
def reset_seed(self, seed=None):
69+
"""Resets all associated random number generators
70+
71+
Parameters
72+
----------
73+
seed : int, optional
74+
Seed for the random number generator.
75+
"""
76+
np.random.default_rng(seed)

tests/fixtures/monte_carlo/stochastic_fixtures.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,36 @@ def stochastic_environment(example_spaceport_env):
4343
)
4444

4545

46+
@pytest.fixture
47+
def stochastic_environment_custom_sampler(example_spaceport_env, elevation_sampler):
48+
"""This fixture is used to create a stochastic environment object for the
49+
Calisto flight using a custom sampler for the elevation.
50+
51+
Parameters
52+
----------
53+
example_spaceport_env : Environment
54+
This is another fixture.
55+
56+
elevation_sampler: CustomSampler
57+
This is another fixture.
58+
59+
Returns
60+
-------
61+
StochasticEnvironment
62+
The stochastic environment object
63+
"""
64+
return StochasticEnvironment(
65+
environment=example_spaceport_env,
66+
elevation=elevation_sampler,
67+
gravity=None,
68+
latitude=None,
69+
longitude=None,
70+
ensemble_member=None,
71+
wind_velocity_x_factor=(1.0, 0.033, "normal"),
72+
wind_velocity_y_factor=(1.0, 0.033, "normal"),
73+
)
74+
75+
4676
@pytest.fixture
4777
def stochastic_nose_cone(calisto_nose_cone):
4878
"""This fixture is used to create a StochasticNoseCone object for the
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from rocketpy.environment.environment import Environment
2+
3+
4+
def test_create_object(stochastic_environment_custom_sampler):
5+
"""Test create object method of StochasticEnvironment class.
6+
7+
This test checks if the create_object method of the StochasticEnvironment
8+
class creates a StochasticEnvironment object from the randomly generated
9+
input arguments.
10+
11+
Parameters
12+
----------
13+
stochastic_environment : StochasticEnvironment
14+
StochasticEnvironment object to be tested.
15+
16+
Returns
17+
-------
18+
None
19+
"""
20+
obj = stochastic_environment_custom_sampler.create_object()
21+
assert isinstance(obj, Environment)

tests/unit/test_stochastic_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"stochastic_rail_buttons",
88
"stochastic_main_parachute",
99
"stochastic_environment",
10+
"stochastic_environment_custom_sampler",
1011
"stochastic_tail",
1112
"stochastic_calisto",
1213
],

0 commit comments

Comments
 (0)