Skip to content

Commit 6a4c0f2

Browse files
committed
ENH: allow for disallowing pickle on encoding.
1 parent d4e90fc commit 6a4c0f2

File tree

3 files changed

+51
-7
lines changed

3 files changed

+51
-7
lines changed

rocketpy/_encoders.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,41 @@
1313

1414
class RocketPyEncoder(json.JSONEncoder):
1515
"""Custom JSON encoder for RocketPy objects. It defines how to encode
16-
different types of objects to a JSON supported format."""
16+
different types of objects to a JSON supported format.
17+
"""
1718

1819
def __init__(self, *args, **kwargs):
20+
"""Initializes the encoder with parameter options.
21+
22+
Parameters
23+
----------
24+
*args : tuple
25+
Positional arguments to pass to the parent class.
26+
**kwargs : dict
27+
Keyword arguments to configure the encoder. The following
28+
options are available:
29+
- include_outputs: bool, whether to include simulation outputs.
30+
Default is False.
31+
- include_function_data: bool, whether to include Function
32+
data in the encoding. If False, Functions will be encoded by their
33+
``__repr__``. This is useful for reducing the size of the outputs,
34+
but it prevents full restoration of the object upon decoding.
35+
Default is True.
36+
- discretize: bool, whether to discretize Functions whose source
37+
are callables. If True, the accuracy of the decoding may be reduced.
38+
Default is False.
39+
- pickle_callables: bool, whether to pickle callable objects. If
40+
False, callable sources (such as user-defined functions, parachute
41+
triggers or simulation callable outputs) will have their name
42+
stored instead of the function itself. This is useful for
43+
reducing the size of the outputs, but it prevents full restoration
44+
of the object upon decoding.
45+
Default is True.
46+
"""
1947
self.include_outputs = kwargs.pop("include_outputs", False)
20-
self.discretize = kwargs.pop("discretize", False)
2148
self.include_function_data = kwargs.pop("include_function_data", True)
49+
self.discretize = kwargs.pop("discretize", False)
50+
self.pickle_callables = kwargs.pop("pickle_callables", True)
2251
super().__init__(*args, **kwargs)
2352

2453
def default(self, o):
@@ -35,13 +64,17 @@ def default(self, o):
3564
return str(o)
3665
else:
3766
encoding = o.to_dict(
38-
include_outputs=self.include_outputs, discretize=self.discretize
67+
include_outputs=self.include_outputs,
68+
discretize=self.discretize,
69+
pickle_callables=self.pickle_callables,
3970
)
4071
encoding["signature"] = get_class_signature(o)
4172
return encoding
4273
elif hasattr(o, "to_dict"):
4374
encoding = o.to_dict(
44-
include_outputs=self.include_outputs, discretize=self.discretize
75+
include_outputs=self.include_outputs,
76+
discretize=self.discretize,
77+
pickle_callables=self.pickle_callables,
4578
)
4679
encoding = remove_circular_references(encoding)
4780

rocketpy/mathutils/function.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3578,7 +3578,10 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument
35783578
source = self.source
35793579

35803580
if callable(source):
3581-
source = to_hex_encode(source)
3581+
if kwargs.get("pickle_callables", True):
3582+
source = to_hex_encode(source)
3583+
else:
3584+
source = source.__name__
35823585

35833586
return {
35843587
"source": source,

rocketpy/rocket/parachute.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,14 @@ def all_info(self):
252252
# self.plots.all() # Parachutes still doesn't have plots
253253

254254
def to_dict(self, **kwargs):
255+
pickle_callables = kwargs.get("pickle_callables", True)
255256
trigger = self.trigger
256257

257258
if callable(self.trigger) and not isinstance(self.trigger, Function):
258-
trigger = to_hex_encode(trigger)
259+
if pickle_callables:
260+
trigger = to_hex_encode(trigger)
261+
else:
262+
trigger = trigger.__name__
259263

260264
data = {
261265
"name": self.name,
@@ -268,7 +272,11 @@ def to_dict(self, **kwargs):
268272

269273
if kwargs.get("include_outputs", False):
270274
data["noise_signal"] = self.noise_signal
271-
data["noise_function"] = to_hex_encode(self.noise_function)
275+
data["noise_function"] = (
276+
to_hex_encode(self.noise_function)
277+
if pickle_callables
278+
else self.noise_function.__name__
279+
)
272280
data["noisy_pressure_signal"] = self.noisy_pressure_signal
273281
data["clean_pressure_signal"] = self.clean_pressure_signal
274282

0 commit comments

Comments
 (0)