Skip to content

Commit 4c02c87

Browse files
committed
Add secure np.cum{ulative_}sum().
1 parent 6ab5abb commit 4c02c87

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

mpyc/runtime.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3143,6 +3143,48 @@ async def np_sum(self, a, axis=None, keepdims=False, initial=0):
31433143
return np.sum(a, axis=axis, keepdims=keepdims, initial=initial.value)
31443144
# TODO: handle switch from initial (field elt) to initial.value inside finfields.py
31453145

3146+
@asyncoro.mpc_coro_no_pc
3147+
async def np_cumsum(self, a, axis=None):
3148+
"""Secure cumulative sum of array a along given axis.
3149+
3150+
If axis is None, array a is flattened first.
3151+
"""
3152+
shape = (a.size,) if a.shape == () or axis is None else a.shape
3153+
if isinstance(a, self.SecureFixedPointArray):
3154+
rettype = (type(a), a.integral, shape)
3155+
else:
3156+
rettype = (type(a), shape)
3157+
await self.returnType(rettype)
3158+
a = await self.gather(a)
3159+
return np.cumsum(a, axis=axis)
3160+
3161+
@asyncoro.mpc_coro_no_pc
3162+
async def np_cumulative_sum(self, a, axis=None, include_initial=False):
3163+
"""Secure cumulative sum of array a along given axis.
3164+
3165+
Only for 0D and 1D arrays, axis is allowed to be None.
3166+
If include_initial holds, the initial zero(s) are included in the output.
3167+
"""
3168+
if axis is None:
3169+
if a.ndim < 2:
3170+
axis = 0
3171+
else:
3172+
raise ValueError('For arrays which have more than one dimension '
3173+
'``axis`` argument is required.')
3174+
3175+
shape = (a.size,) if a.shape == () else a.shape
3176+
if include_initial:
3177+
shape = list(shape)
3178+
shape[axis] += 1
3179+
shape = tuple(shape)
3180+
if isinstance(a, self.SecureFixedPointArray):
3181+
rettype = (type(a), a.integral, shape)
3182+
else:
3183+
rettype = (type(a), shape)
3184+
await self.returnType(rettype)
3185+
a = await self.gather(a)
3186+
return np.cumulative_sum(a, axis=axis, include_initial=include_initial)
3187+
31463188
@asyncoro.mpc_coro_no_pc
31473189
async def np_negative(self, a):
31483190
"""Secure elementwise negation -a (additive inverse) of a."""

tests/test_runtime.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ def __lt__(self, other):
246246
np.assertEqual(mpc.run(mpc.output(c.sum(keepdims=True))), a.sum(keepdims=True))
247247
np.assertEqual(mpc.run(mpc.output(c.sum(axis=(0, 2), keepdims=True))),
248248
a.sum(axis=(0, 2), keepdims=True))
249+
np.assertEqual(mpc.run(mpc.output(np.cumsum(c))), np.cumsum(a))
250+
np.assertEqual(mpc.run(mpc.output(np.cumulative_sum(c, axis=1, include_initial=True))),
251+
np.cumulative_sum(a, axis=1, include_initial=True))
249252
np.assertEqual(mpc.run(mpc.output(c**254 * c**0 * c**-253)), a)
250253

251254
# TODO: c //= 2 secure int __floordiv__() etc.
@@ -317,6 +320,9 @@ def test_secfxp_array(self):
317320
np.assertEqual(mpc.run(mpc.output(np.equal(c, 0))), False)
318321
np.assertEqual(mpc.run(mpc.output(np.sum(c, axis=(-2, 1)))), np.sum(a, axis=(-2, 1)))
319322
np.assertEqual(mpc.run(mpc.output(c.sum(axis=1, initial=1.5))), a.sum(axis=1, initial=1.5))
323+
np.assertEqual(mpc.run(mpc.output(np.cumsum(c, axis=0))), np.cumsum(a, axis=0))
324+
np.assertEqual(mpc.run(mpc.output(np.cumulative_sum(c, axis=1, include_initial=True))),
325+
np.cumulative_sum(a, axis=1, include_initial=True))
320326
self.assertEqual(np.prod(c, axis=(-2, 1)).integral, False)
321327
np.assertEqual(mpc.run(mpc.output(np.prod(c, axis=(-2, 1)))), np.prod(a, axis=(-2, 1)))
322328
a = a.flatten()[:3]
@@ -393,6 +399,16 @@ def test_secfxp_array(self):
393399
self.assertEqual(np.sum(c1).integral, False)
394400
self.assertEqual(np.sum(c1, axis=0).integral, False)
395401

402+
self.assertEqual(np.cumsum(c2).integral, True)
403+
self.assertEqual(np.cumsum(c2, axis=0).integral, True)
404+
self.assertEqual(np.cumsum(c1).integral, False)
405+
self.assertEqual(np.cumsum(c1, axis=0).integral, False)
406+
407+
self.assertEqual(np.cumulative_sum(c2, axis=1).integral, True)
408+
self.assertEqual(np.cumulative_sum(c2, axis=0).integral, True)
409+
self.assertEqual(np.cumulative_sum(c1, axis=1).integral, False)
410+
self.assertEqual(np.cumulative_sum(c1, axis=0).integral, False)
411+
396412
self.assertEqual(mpc.np_sgn(c1).integral, True)
397413
self.assertEqual(mpc.np_sgn(c2).integral, True)
398414
self.assertEqual(np.absolute(c1).integral, False)
@@ -505,6 +521,10 @@ def test_secfld_array(self):
505521
np.assertEqual(mpc.run(mpc.output(np.outer(a, c))), np.outer(a, a))
506522
np.assertEqual(mpc.run(mpc.output(np.convolve(a[0][0], c[0][0]))),
507523
np.convolve(a[0][0], a[0][0]))
524+
np.assertEqual(mpc.run(mpc.output(np.sum(c))), np.sum(a))
525+
np.assertEqual(mpc.run(mpc.output(np.cumsum(c))), np.cumsum(a))
526+
np.assertEqual(mpc.run(mpc.output(np.cumulative_sum(c, axis=0, include_initial=True))),
527+
np.cumulative_sum(a, axis=0, include_initial=True))
508528
np.assertEqual(mpc.run(mpc.output(np.roll(c, 1))), np.roll(a, 1))
509529
self.assertEqual(len(c), 1)
510530
self.assertEqual(len(c.T), 2)
@@ -525,6 +545,7 @@ def test_np_errors(self):
525545
self.assertRaises(ValueError, np.convolve, c, c)
526546
self.assertRaises(ValueError, np.convolve, c[0], c[0][:0])
527547
self.assertRaises(ValueError, np.convolve, c[0][:0], c[0])
548+
self.assertRaises(ValueError, np.cumulative_sum, c)
528549

529550
def test_async(self):
530551
mpc.options.no_async = False

0 commit comments

Comments
 (0)