Skip to content

Commit 7256fd5

Browse files
authored
Python wrappers (#110)
* wrap C++ objects in python
1 parent 9cb5580 commit 7256fd5

File tree

4 files changed

+160
-23
lines changed

4 files changed

+160
-23
lines changed

private_set_intersection/python/BUILD

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,14 @@ pybind_extension(
1919

2020
py_library(
2121
name = "openmined_psi",
22-
srcs = ["__init__.py"],
22+
srcs = [
23+
"__init__.py",
24+
],
2325
data = ["//private_set_intersection/python:_openmined_psi.so"],
2426
srcs_version = "PY3",
2527
visibility = ["//visibility:public"],
2628
)
2729

28-
py_binary(
29-
name = "openmined_psi_bin",
30-
srcs = ["__init__.py"],
31-
data = ["//private_set_intersection/python:_openmined_psi.so"],
32-
main = "__init__.py",
33-
srcs_version = "PY3",
34-
)
35-
3630
py_test(
3731
name = "tests",
3832
srcs = ["tests.py"],

private_set_intersection/python/__init__.py

Lines changed: 142 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,156 @@
11
"""Private Set Intersection protocol based on ECDH and Bloom
22
Filters.
33
"""
4+
from typing import List
5+
46
try:
57
# Used in Bazel envs
68
from private_set_intersection.python import _openmined_psi as psi
79
except ImportError:
810
# Default package
9-
import openmined_psi as psi
11+
import _openmined_psi as psi
1012

11-
client = psi.client
12-
server = psi.server
13+
__version__ = psi.__version__
1314

14-
proto_server_setup = psi.proto_server_setup
15-
proto_request = psi.proto_request
16-
proto_response = psi.proto_response
1715

18-
__version__ = psi.__version__
16+
proto_server_setup = psi.cpp_proto_server_setup
17+
proto_request = psi.cpp_proto_request
18+
proto_response = psi.cpp_proto_response
19+
20+
21+
class client:
22+
def __init__(self, data: psi.cpp_client):
23+
"""Constructor method for the client object.
24+
Args:
25+
data: cpp_client object.
26+
Returns:
27+
client object.
28+
"""
29+
self.data = data
30+
31+
@classmethod
32+
def CreateWithNewKey(cls, reveal_intersection: bool):
33+
"""Constructor method for the client object.
34+
Args:
35+
reveal_intersection: indicates whether the client wants to learn the elements in the intersection or only its size.
36+
Returns:
37+
client object.
38+
"""
39+
return cls(psi.cpp_client.CreateWithNewKey(reveal_intersection))
40+
41+
@classmethod
42+
def CreateFromKey(cls, key_bytes: bytes, reveal_intersection: bool):
43+
"""Constructor method for the client object.
44+
Args:
45+
reveal_intersection: indicates whether the client wants to learn the elements in the intersection or only its size.
46+
key_bytes: existing encryption key.
47+
Returns:
48+
client object.
49+
"""
50+
return cls(psi.cpp_client.CreateFromKey(key_bytes, reveal_intersection))
51+
52+
def CreateRequest(self, data: List[str]) -> proto_request:
53+
"""Create a request protobuf to be serialized and sent to the server.
54+
Args:
55+
data: client items.
56+
Returns:
57+
A Protobuffer with the request.
58+
"""
59+
return self.data.CreateRequest(data)
60+
61+
def GetIntersection(
62+
self, server_setup: proto_server_setup, server_response: proto_response
63+
) -> List[int]:
64+
"""Process the server's response and return the intersection of the client and server inputs.
65+
Args:
66+
server_setup: A protobuffer with the setup message.
67+
server_response: A protobuffer with server's response.
68+
Returns:
69+
A list of indices in clients set.
70+
"""
71+
return self.data.GetIntersection(server_setup, server_response)
72+
73+
def GetIntersectionSize(
74+
self, server_setup: proto_server_setup, server_response: proto_response
75+
) -> int:
76+
"""Process the server's response and return the size of the intersection.
77+
Args:
78+
server_setup: A protobuffer with the setup message.
79+
server_response: A protobuffer with server's response.
80+
Returns:
81+
The intersection size.
82+
"""
83+
return self.data.GetIntersectionSize(server_setup, server_response)
84+
85+
def GetPrivateKeyBytes(self) -> bytes:
86+
"""Returns this instance's private key. This key should only be used to create other client instances. DO NOT SEND THIS KEY TO ANY OTHER PARTY!
87+
Returns:
88+
Bytes containing the key.
89+
"""
90+
return self.data.GetPrivateKeyBytes()
91+
92+
93+
class server:
94+
def __init__(self, data: psi.cpp_server):
95+
"""Constructor method for the server object.
96+
Args:
97+
data: cpp_server object.
98+
Returns:
99+
server object.
100+
"""
101+
self.data = data
102+
103+
@classmethod
104+
def CreateWithNewKey(cls, reveal_intersection: bool):
105+
"""Constructor method for the server object.
106+
Args:
107+
reveal_intersection: indicates whether the server supports to return the elements in the intersection or only its size.
108+
Returns:
109+
server object.
110+
"""
111+
return cls(psi.cpp_server.CreateWithNewKey(reveal_intersection))
112+
113+
@classmethod
114+
def CreateFromKey(cls, key_bytes: bytes, reveal_intersection: bool):
115+
"""Constructor method for the server object.
116+
Args:
117+
reveal_intersection: indicates whether the server supports to return the elements in the intersection or only its size.
118+
key_bytes: existing encryption key.
119+
Returns:
120+
Returns:
121+
server object.
122+
"""
123+
return cls(psi.cpp_server.CreateFromKey(key_bytes, reveal_intersection))
124+
125+
def CreateSetupMessage(
126+
self, fpr: float, num_client_inputs: int, inputs: List[str]
127+
) -> proto_server_setup:
128+
"""Create a setup message from the server's dataset to be sent to the client.
129+
Args:
130+
fpr: the probability that any query of size `num_client_inputs` will result in a false positive.
131+
num_client_inputs: Client set size.
132+
inputs: Server items.
133+
Returns:
134+
A Protobuf with the setup message.
135+
"""
136+
return self.data.CreateSetupMessage(fpr, num_client_inputs, inputs)
137+
138+
def ProcessRequest(self, client_request: proto_request) -> proto_response:
139+
"""Process a client query and returns the corresponding server response to be sent to the client.
140+
Args:
141+
client_request: A Protobuf containing the client request
142+
Returns:
143+
A Protobuf with the server response.
144+
"""
145+
return self.data.ProcessRequest(client_request)
146+
147+
def GetPrivateKeyBytes(self) -> bytes:
148+
"""Return this instance's private key. This key should only be used to create other server instances. DO NOT SEND THIS KEY TO ANY OTHER PARTY!
149+
Returns:
150+
Bytes containing the key.
151+
"""
152+
return self.data.GetPrivateKeyBytes()
153+
19154

20155
__all__ = [
21156
"client",

private_set_intersection/python/psi_bindings.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void bind(pybind11::module& m) {
4040
"Filters";
4141

4242
m.attr("__version__") = ::private_set_intersection::Package::kVersion;
43-
py::class_<psi_proto::ServerSetup>(m, "proto_server_setup")
43+
py::class_<psi_proto::ServerSetup>(m, "cpp_proto_server_setup")
4444
.def(py::init<>())
4545
.def(
4646
"bits",
@@ -64,7 +64,7 @@ void bind(pybind11::module& m) {
6464
loadProto(obj, data);
6565
return obj;
6666
});
67-
py::class_<psi_proto::Request>(m, "proto_request")
67+
py::class_<psi_proto::Request>(m, "cpp_proto_request")
6868
.def(py::init<>())
6969
.def("encrypted_elements_size",
7070
&psi_proto::Request::encrypted_elements_size)
@@ -99,7 +99,7 @@ void bind(pybind11::module& m) {
9999
loadProto(obj, data);
100100
return obj;
101101
});
102-
py::class_<psi_proto::Response>(m, "proto_response")
102+
py::class_<psi_proto::Response>(m, "cpp_proto_response")
103103
.def(py::init<>())
104104
.def("encrypted_elements_size",
105105
&psi_proto::Response::encrypted_elements_size)
@@ -131,7 +131,7 @@ void bind(pybind11::module& m) {
131131
return obj;
132132
});
133133

134-
py::class_<psi::PsiClient>(m, "client")
134+
py::class_<psi::PsiClient>(m, "cpp_client")
135135
.def_static(
136136
"CreateWithNewKey",
137137
[](bool reveal_intersection) {
@@ -183,7 +183,7 @@ void bind(pybind11::module& m) {
183183
},
184184
py::call_guard<py::gil_scoped_release>());
185185

186-
py::class_<psi::PsiServer>(m, "server")
186+
py::class_<psi::PsiServer>(m, "cpp_server")
187187
.def_static(
188188
"CreateWithNewKey",
189189
[](bool reveal_intersection) {

setup.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,25 @@ def bazel_build(self, ext):
8484
]
8585
self.spawn(bazel_argv)
8686

87+
ext.name = "_" + ext.name
8788
shared_lib_ext = ".so"
88-
shared_lib = "_" + ext.name + shared_lib_ext
89+
shared_lib = ext.name + shared_lib_ext
8990
ext_bazel_bin_path = os.path.join(self.build_temp, "bazel-bin", ext.relpath, shared_lib)
9091

9192
ext_dest_path = self.get_ext_fullpath(ext.name)
9293
ext_dest_dir = os.path.dirname(ext_dest_path)
94+
9395
if not os.path.exists(ext_dest_dir):
9496
os.makedirs(ext_dest_dir)
9597
shutil.copyfile(ext_bazel_bin_path, ext_dest_path)
9698

99+
package_dir = os.path.join(ext_dest_dir, "openmined_psi")
100+
if not os.path.exists(package_dir):
101+
os.makedirs(package_dir)
102+
shutil.copyfile(
103+
"private_set_intersection/python/__init__.py", os.path.join(package_dir, "__init__.py")
104+
)
105+
97106

98107
setuptools.setup(
99108
name="openmined.psi",
@@ -103,7 +112,6 @@ def bazel_build(self, ext):
103112
url="https://github.com/OpenMined/PSI",
104113
python_requires=">=3.6",
105114
package_dir={"": "private_set_intersection/python"},
106-
packages=setuptools.find_packages("private_set_intersection/python"),
107115
cmdclass=dict(build_ext=BuildBazelExtension),
108116
ext_modules=[
109117
BazelExtension("openmined_psi", "//private_set_intersection/python:openmined_psi",)

0 commit comments

Comments
 (0)