Skip to content

Commit 40309d4

Browse files
authored
Merge pull request #598 from WeBankFinTech/dev-2.4.7
add whitelist validator to fix issue caused by using pickle serdes
2 parents 2fb85c2 + e2ff7fd commit 40309d4

7 files changed

Lines changed: 215 additions & 11 deletions

File tree

BUILD_INFO

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
eggroll.version=2.4.6
1+
eggroll.version=2.4.7

conf/whitelist.json

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
{
2+
"builtins": [
3+
"int",
4+
"list",
5+
"set"
6+
],
7+
"collections": [
8+
"OrderedDict",
9+
"defaultdict"
10+
],
11+
"eggroll.core.transfer_model": [
12+
"ErRollSiteHeader"
13+
],
14+
"eggroll.roll_pair.task.storage": [
15+
"BSS"
16+
],
17+
"federatedml.cipher_compressor.compressor": [
18+
"PackingCipherTensor",
19+
"NormalCipherPackage",
20+
"PackingCipherTensorPackage"
21+
],
22+
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.feature_histogram": [
23+
"FeatureHistogramWeights",
24+
"HistogramBag"
25+
],
26+
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.feature_importance": [
27+
"FeatureImportance"
28+
],
29+
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.g_h_optim": [
30+
"SplitInfoPackage"
31+
],
32+
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.node": [
33+
"Node"
34+
],
35+
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.splitter": [
36+
"SplitInfo"
37+
],
38+
"federatedml.evaluation.performance_recorder": [
39+
"PerformanceRecorder"
40+
],
41+
"federatedml.feature.binning.bin_result": [
42+
"BinColResults"
43+
],
44+
"federatedml.feature.binning.optimal_binning.bucket_info": [
45+
"Bucket"
46+
],
47+
"federatedml.feature.binning.quantile_summaries": [
48+
"QuantileSummaries",
49+
"Stats",
50+
"SparseQuantileSummaries"
51+
],
52+
"federatedml.feature.fate_element_type": [
53+
"NoneType"
54+
],
55+
"federatedml.feature.homo_feature_binning.homo_binning_base": [
56+
"SplitPointNode"
57+
],
58+
"federatedml.feature.instance": [
59+
"Instance"
60+
],
61+
"federatedml.feature.one_hot_encoder": [
62+
"TransferPair"
63+
],
64+
"federatedml.feature.sparse_vector": [
65+
"SparseVector"
66+
],
67+
"federatedml.framework.weights": [
68+
"TransferableWeights",
69+
"DictWeights",
70+
"NumpyWeights",
71+
"ListWeights",
72+
"OrderDictWeights",
73+
"NumericWeights"
74+
],
75+
"federatedml.linear_model.linear_model_weight": [
76+
"LinearModelWeights"
77+
],
78+
"federatedml.secureprotol.fate_paillier": [
79+
"PaillierPublicKey",
80+
"PaillierEncryptedNumber"
81+
],
82+
"federatedml.secureprotol.fixedpoint": [
83+
"FixedPointNumber"
84+
],
85+
"federatedml.secureprotol.number_theory.field.integers_modulo_prime_field": [
86+
"IntegersModuloPrimeElement"
87+
],
88+
"federatedml.secureprotol.number_theory.group.twisted_edwards_curve_group": [
89+
"TwistedEdwardsCurveElement"
90+
],
91+
"federatedml.secureprotol.symmetric_encryption.cryptor_executor": [
92+
"CryptoExecutor"
93+
],
94+
"federatedml.secureprotol.symmetric_encryption.pohlig_hellman_encryption": [
95+
"PohligHellmanCipherKey",
96+
"PohligHellmanCiphertext"
97+
],
98+
"federatedml.statistic.intersect.intersect_preprocess": [
99+
"BitArray"
100+
],
101+
"federatedml.statistic.statics": [
102+
"SummaryStatistics"
103+
],
104+
"gmpy2": [
105+
"from_binary"
106+
],
107+
"numpy": [
108+
"dtype",
109+
"ndarray"
110+
],
111+
"numpy.core.multiarray": [
112+
"_reconstruct",
113+
"scalar"
114+
],
115+
"numpy.core.numeric": [
116+
"_frombuffer"
117+
],
118+
"tensorflow.python.framework.ops": [
119+
"convert_to_tensor"
120+
],
121+
"torch._utils": [
122+
"_rebuild_tensor_v2"
123+
],
124+
"torch.storage": [
125+
"_load_from_bytes"
126+
]
127+
}

jvm/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@
201201
<modelVersion>4.0.0</modelVersion>
202202

203203
<properties>
204-
<eggroll.version>2.4.6</eggroll.version>
204+
<eggroll.version>2.4.7</eggroll.version>
205205

206206
<!-- Languages -->
207207
<code.cache.size>512m</code.cache.size>
@@ -396,4 +396,4 @@
396396
<module>roll_site</module>
397397
</modules>
398398

399-
</project>
399+
</project>

python/eggroll/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
# limitations under the License.
1515
#
1616

17-
__version__ = "2.4.6"
17+
__version__ = "2.4.7"

python/eggroll/core/serdes/eggroll_serdes.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,81 @@ def deserialize(_bytes):
9393
return _bytes
9494

9595

96-
deserialize_blacklist = [b'eval', b'execfile', b'compile', b'system', b'popen',
96+
class WhitelistPickleSerdes(ABCSerdes):
    """Pickle serdes whose deserialization goes through ``RestrictedUnpickler``.

    Serialization is delegated to ``p_dumps``. Deserialization runs
    ``bytes_security_check`` on the raw payload first and then unpickles it
    with ``RestrictedUnpickler`` rather than calling ``pickle.loads``.
    """

    @staticmethod
    def serialize(_obj):
        """Return ``_obj`` serialized by ``p_dumps``."""
        payload = p_dumps(_obj)
        return payload

    @staticmethod
    def deserialize(_bytes):
        """Check ``_bytes`` and return the object unpickled from it.

        Exceptions raised by ``bytes_security_check`` or by the unpickler
        are not caught here and propagate to the caller.
        """
        bytes_security_check(_bytes)
        stream = io.BytesIO(_bytes)
        unpickler = RestrictedUnpickler(stream)
        return unpickler.load()
105+
106+
class _DeserializeWhitelist:
    """Lazily loaded whitelist of globals that may be resolved when unpickling.

    The whitelist is read from a JSON object mapping a module name to a list
    of attribute names. A key ending in ``*`` is treated as a module-name
    prefix (the ``*`` is stripped) and its value is ignored. The file is read
    the first time either getter is called and the result is kept on the
    class.
    """

    # True once the whitelist file has been read and parsed successfully.
    loaded = False
    # Exact entries: module name -> set of allowed attribute names.
    deserialize_whitelist = {}
    # Module-name prefixes, taken from keys ending in "*" with the "*" removed.
    deserialize_glob_whitelist = set()

    @classmethod
    def get_whitelist_glob(cls):
        """Return the set of whitelisted module-name prefixes, loading on first use."""
        if not cls.loaded:
            cls.load_deserialize_whitelist()
        return cls.deserialize_glob_whitelist

    @classmethod
    def get_whitelist(cls):
        """Return the ``module -> {names}`` whitelist, loading on first use."""
        if not cls.loaded:
            cls.load_deserialize_whitelist()
        return cls.deserialize_whitelist

    @classmethod
    def get_whitelist_path(cls):
        """Return the absolute path of ``conf/whitelist.json``.

        The path is computed relative to this source file: the first parent
        step drops the file name and the remaining four climb four
        directories before descending into ``conf``.
        """
        import os.path

        parent_steps = [os.path.pardir] * 5
        return os.path.abspath(
            os.path.join(__file__, *parent_steps, "conf", "whitelist.json")
        )

    @classmethod
    def load_deserialize_whitelist(cls):
        """Read the whitelist file and publish its entries on the class.

        Raises whatever ``open`` / ``json.load`` raise if the file is missing
        or malformed; in that case ``loaded`` stays False and nothing is
        added to the whitelist.
        """
        import json

        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(cls.get_whitelist_path(), encoding="utf-8") as f:
            entries = json.load(f)

        # Parse everything into locals first so that an error halfway through
        # cannot leave a partially populated whitelist behind.
        exact = {}
        prefixes = set()
        for module, names in entries.items():
            if module.endswith("*"):
                prefixes.add(module[:-1])
            else:
                exact[module] = set(names)

        cls.deserialize_whitelist.update(exact)
        cls.deserialize_glob_whitelist.update(prefixes)
        cls.loaded = True
151+
152+
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals permitted by ``_DeserializeWhitelist``."""

    def _load(self, module, name):
        """Return the attribute ``name`` of ``module``.

        The standard ``pickle.Unpickler.find_class`` lookup is tried first;
        if it raises, the module is imported with ``importlib`` and the
        attribute is read directly. Errors from that fallback propagate.
        """
        try:
            return super().find_class(module, name)
        except Exception:
            # Best-effort fallback kept from the original code, but no longer
            # a bare ``except`` so KeyboardInterrupt / SystemExit are not
            # swallowed in the middle of deserialization.
            return getattr(importlib.import_module(module), name)

    def find_class(self, module, name):
        """Resolve a pickled global only if the whitelist allows it.

        A global is allowed when ``name`` is listed for exactly ``module``,
        or when ``module`` starts with one of the whitelisted prefixes.

        Raises:
            pickle.UnpicklingError: if neither rule matches.
        """
        allowed_names = _DeserializeWhitelist.get_whitelist().get(module, set())
        if name in allowed_names:
            return self._load(module, name)

        # Plain string-prefix match: a prefix admits every name in every
        # module whose dotted path begins with it.
        prefixes = tuple(_DeserializeWhitelist.get_whitelist_glob())
        if module.startswith(prefixes):
            return self._load(module, name)

        raise pickle.UnpicklingError(f"forbidden unpickle class {module} {name}")
169+
170+
deserialize_blacklist = {b'eval', b'execfile', b'compile', b'system', b'popen',
97171
b'popen2', b'popen3',
98172
b'popen4', b'fdopen', b'tmpfile', b'fchmod', b'fchown',
99173
b'openpty',
@@ -116,7 +190,8 @@ def deserialize(_bytes):
116190
b'listdir', b'opendir', b'timeit', b'repeat',
117191
b'call_tracing', b'interact', b'compile_command',
118192
b'spawn',
119-
b'fileopen']
193+
b'fileopen',
194+
b'getattr'}
120195

121196
future_blacklist = [b'read', b'dup', b'fork', b'walk', b'file', b'move',
122197
b'link', b'kill', b'open', b'pipe']

python/eggroll/roll_pair/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from eggroll.core.pair_store import create_pair_adapter
44
import cloudpickle
55
from eggroll.core.serdes.eggroll_serdes import PickleSerdes, \
6-
CloudPickleSerdes, EmptySerdes, eggroll_pickle_loads
6+
CloudPickleSerdes, EmptySerdes, eggroll_pickle_loads, WhitelistPickleSerdes
77
from eggroll.roll_pair.utils.pair_utils import get_db_path
88

99

@@ -15,16 +15,18 @@ def create_adapter(er_partition: ErPartition, options: dict = None):
1515
options['er_partition'] = er_partition
1616
return create_pair_adapter(options=options)
1717

18-
1918
def create_serdes(serdes_type: SerdesTypes = SerdesTypes.CLOUD_PICKLE):
    """Return the serdes class to use for ``serdes_type``.

    ``CLOUD_PICKLE``, ``PROTOBUF``, ``PICKLE`` and any falsy value all map to
    ``WhitelistPickleSerdes``; every other value maps to ``EmptySerdes``.
    """
    # The older CloudPickleSerdes / PickleSerdes dispatch used to follow an
    # if/else that returned on both branches, so it could never run; it has
    # been removed instead of being left as unreachable code.
    whitelisted_types = (
        SerdesTypes.CLOUD_PICKLE,
        SerdesTypes.PROTOBUF,
        SerdesTypes.PICKLE,
    )
    if not serdes_type or serdes_type in whitelisted_types:
        return WhitelistPickleSerdes
    return EmptySerdes
2629

27-
2830
def create_functor(func_bin):
2931
try:
3032
return cloudpickle.loads(func_bin)

python/eggroll/roll_site/roll_site.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
L = log_utils.get_logger()
4040
P = log_utils.get_logger('profile')
41-
_serdes = eggroll_serdes.PickleSerdes
41+
_serdes = eggroll_serdes.WhitelistPickleSerdes
4242
RS_KEY_DELIM = "#"
4343
STATUS_TABLE_NAME = "__rs_status"
4444

@@ -625,7 +625,7 @@ def clear_status(task):
625625

626626
clear_future = self._receive_executor_pool.submit(rp.with_stores, clear_status, options={"__op": "clear_status"})
627627
if data_type == "object":
628-
result = pickle.loads(b''.join(map(lambda t: t[1], sorted(rp.get_all(), key=lambda x: int.from_bytes(x[0], "big")))))
628+
result = _serdes.deserialize(b''.join(map(lambda t: t[1], sorted(rp.get_all(), key=lambda x: int.from_bytes(x[0], "big")))))
629629
rp.destroy()
630630
L.debug(f"pulled object: rs_key={rs_key}, rs_header={rs_header}, is_none={result is None}, "
631631
f"elapsed={time.time() - start_time}")

0 commit comments

Comments
 (0)