Skip to content

Commit 3208718

Browse files
Enable serialization in daal4py algorithm classes (uxlfoundation#2067)
* enable serialization for algorithm classes * prevent accidental reorder of dict * add test for serialization of kmeans
1 parent e5fe55e commit 3208718

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

generator/wrapper_gen.py

+13
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,8 @@ def __cinit__(self):
10031003
10041004
# this is our actual algorithm class for Python
10051005
cdef class {{algo}}{{'('+iface[0]|lower+'__iface__)' if iface[0] else ''}}:
1006+
cdef tuple _params
1007+
10061008
'''
10071009
{{algo}}
10081010
{{params_all|fmt('{}', 'sphinx', sep='\n')|indent(4)}}
@@ -1017,6 +1019,17 @@ def __cinit__(self,
10171019
self.c_ptr = mk_{{algo}}(
10181020
{{params_all|fmt('{}', 'arg_cyext', sep=',\n')|indent(25+(algo|length))}}
10191021
)
1022+
current_locals = locals()
1023+
ordered_input_args = '''
1024+
{{params_all|fmt('{}', 'name', sep=' ')|indent(0)}}
1025+
'''.strip().split()
1026+
self._params = tuple(
1027+
current_locals[arg]
1028+
for arg in ordered_input_args
1029+
)
1030+
1031+
def __reduce__(self):
1032+
return (self.__class__, self._params)
10201033
10211034
{% if not iface[0] %}
10221035
# the C++ manager__iface__ (de-templatized)

tests/test_daal4py_serialization.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# ==============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
import pickle
18+
import unittest
19+
20+
import numpy as np
21+
22+
import daal4py
23+
24+
25+
class Test(unittest.TestCase):
26+
def test_serialization_of_qr(self):
27+
obj_original = daal4py.qr(fptype="float")
28+
obj_deserialized = pickle.loads(pickle.dumps(obj_original))
29+
30+
rng = np.random.default_rng(seed=123)
31+
X = rng.standard_normal(size=(10, 5))
32+
33+
Q_orig = obj_original.compute(X).matrixQ
34+
Q_deserialized = obj_deserialized.compute(X).matrixQ
35+
np.testing.assert_almost_equal(Q_orig, Q_deserialized)
36+
assert Q_orig.dtype == Q_deserialized.dtype
37+
38+
def test_serialization_of_kmeans(self):
39+
obj_original = daal4py.kmeans_init(nClusters=4)
40+
obj_deserialized = pickle.loads(pickle.dumps(obj_original))
41+
42+
rng = np.random.default_rng(seed=123)
43+
X = rng.standard_normal(size=(100, 20))
44+
45+
np.testing.assert_almost_equal(
46+
obj_original.compute(X).centroids,
47+
obj_deserialized.compute(X).centroids,
48+
)

0 commit comments

Comments
 (0)