Skip to content

Commit f3e45bb

Browse files
committed
MNT: change pickle encoding name to allow_pickle and test it.
1 parent dac833e commit f3e45bb

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

rocketpy/_encoders.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, *args, **kwargs):
3636
- discretize: bool, whether to discretize Functions whose source
3737
are callables. If True, the accuracy of the decoding may be reduced.
3838
Default is False.
39-
- pickle_callables: bool, whether to pickle callable objects. If
39+
- allow_pickle: bool, whether to pickle callable objects. If
4040
False, callable sources (such as user-defined functions, parachute
4141
triggers or simulation callable outputs) will have their name
4242
stored instead of the function itself. This is useful for
@@ -47,7 +47,7 @@ def __init__(self, *args, **kwargs):
4747
self.include_outputs = kwargs.pop("include_outputs", False)
4848
self.include_function_data = kwargs.pop("include_function_data", True)
4949
self.discretize = kwargs.pop("discretize", False)
50-
self.pickle_callables = kwargs.pop("pickle_callables", True)
50+
self.allow_pickle = kwargs.pop("allow_pickle", True)
5151
super().__init__(*args, **kwargs)
5252

5353
def default(self, o):
@@ -66,15 +66,15 @@ def default(self, o):
6666
encoding = o.to_dict(
6767
include_outputs=self.include_outputs,
6868
discretize=self.discretize,
69-
pickle_callables=self.pickle_callables,
69+
allow_pickle=self.allow_pickle,
7070
)
7171
encoding["signature"] = get_class_signature(o)
7272
return encoding
7373
elif hasattr(o, "to_dict"):
7474
encoding = o.to_dict(
7575
include_outputs=self.include_outputs,
7676
discretize=self.discretize,
77-
pickle_callables=self.pickle_callables,
77+
allow_pickle=self.allow_pickle,
7878
)
7979
encoding = remove_circular_references(encoding)
8080

rocketpy/mathutils/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3578,7 +3578,7 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument
35783578
source = self.source
35793579

35803580
if callable(source):
3581-
if kwargs.get("pickle_callables", True):
3581+
if kwargs.get("allow_pickle", True):
35823582
source = to_hex_encode(source)
35833583
else:
35843584
source = source.__name__

rocketpy/rocket/parachute.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,11 @@ def all_info(self):
252252
# self.plots.all() # Parachutes still doesn't have plots
253253

254254
def to_dict(self, **kwargs):
255-
pickle_callables = kwargs.get("pickle_callables", True)
255+
allow_pickle = kwargs.get("allow_pickle", True)
256256
trigger = self.trigger
257257

258258
if callable(self.trigger) and not isinstance(self.trigger, Function):
259-
if pickle_callables:
259+
if allow_pickle:
260260
trigger = to_hex_encode(trigger)
261261
else:
262262
trigger = trigger.__name__
@@ -274,7 +274,7 @@ def to_dict(self, **kwargs):
274274
data["noise_signal"] = self.noise_signal
275275
data["noise_function"] = (
276276
to_hex_encode(self.noise_function)
277-
if pickle_callables
277+
if allow_pickle
278278
else self.noise_function.__name__
279279
)
280280
data["noisy_pressure_signal"] = self.noisy_pressure_signal

tests/integration/test_encoding.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from rocketpy._encoders import RocketPyDecoder, RocketPyEncoder
88

9+
from rocketpy.tools import from_hex_decode
10+
911

1012
@pytest.mark.parametrize(
1113
["flight_name", "include_outputs"],
@@ -232,7 +234,7 @@ def test_rocket_encoder(rocket_name, request):
232234

233235
@pytest.mark.parametrize("rocket_name", ["calisto_robust"])
234236
def test_encoder_discretize(rocket_name, request):
235-
"""Test encoding the total mass of``rocketpy.Rocket`` with
237+
"""Test encoding the total mass of ``rocketpy.Rocket`` with
236238
discretized encoding.
237239
238240
Parameters
@@ -261,3 +263,22 @@ def test_encoder_discretize(rocket_name, request):
261263
atol=1e-1,
262264
)
263265
assert isinstance(mass_loaded.source, np.ndarray)
266+
267+
268+
@pytest.mark.parametrize("parachute_name", ["calisto_main_chute"])
269+
def test_encoder_no_pickle(parachute_name, request):
270+
"""Test encoding of a ``rocketpy.Parachute`` disallowing
271+
pickle usage.
272+
"""
273+
parachute_to_encode = request.getfixturevalue(parachute_name)
274+
275+
json_encoded = json.dumps(
276+
parachute_to_encode,
277+
cls=RocketPyEncoder,
278+
allow_pickle=False,
279+
)
280+
281+
trigger_loaded = json.loads(json_encoded)["trigger"]
282+
283+
with pytest.raises(ValueError, match=r"non-hexadecimal number found"):
284+
from_hex_decode(trigger_loaded)

0 commit comments

Comments
 (0)