Skip to content

Commit

Permalink
Extend mpc.random.shuffle() to lists of lists.
Browse files Browse the repository at this point in the history
Consistent with mpc.sorted(), mpc.if_else(), mpc.if_swap(), mpc.min(), mpc.argmax() etc., which also work for lists of (all same length) lists.
  • Loading branch information
lschoe committed Nov 18, 2022
1 parent 1dc3fe1 commit 9876610
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
30 changes: 24 additions & 6 deletions mpyc/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,35 @@ def shuffle(sectype, x):
"""Shuffle list x secretly in-place, and return None.
Given list x may contain public or secret elements.
Elements of x are all numbers or all lists (of the same length) of numbers.
"""
n = len(x)
if not isinstance(x[0], sectype): # assume same type for all elts of x
for i in range(len(x)):
x[i] = sectype(x[i])
# assume same type for all elts of x
x_i_is_list = isinstance(x[0], list)
if not x_i_is_list:
# elements of x are numbers
if not isinstance(x[0], sectype):
for i in range(n):
x[i] = sectype(x[i])
for i in range(n-1):
u = random_unit_vector(sectype, n - i)
x_u = runtime.in_prod(x[i:], u)
d = runtime.scalar_mul(x[i] - x_u, u)
x[i] = x_u
x[i:] = runtime.vector_add(x[i:], d)
return

# elements of x are lists of numbers
for j in range(len(x[0])):
if not isinstance(x[0][j], sectype):
for i in range(n):
x[i][j] = sectype(x[i][j])
for i in range(n-1):
u = random_unit_vector(sectype, n - i)
x_u = runtime.in_prod(x[i:], u)
d = runtime.scalar_mul(x[i] - x_u, u)
x_u = runtime.matrix_prod([u], x[i:])[0]
d = runtime.matrix_prod([[a] for a in u], [runtime.vector_sub(x[i], x_u)])
x[i] = x_u
x[i:] = runtime.vector_add(x[i:], d)
x[i:] = runtime.matrix_add(x[i:], d)


def random_permutation(sectype, x):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ def test_secint(self):

x = list(range(8))
shuffle(secint, x)
shuffle(secint, x)
x = mpc.run(mpc.output(x))
self.assertSetEqual(set(x), set(range(8)))
x = list(map(list, zip(range(8), range(0, -8, -1))))
shuffle(secint, x)
a = mpc.run(mpc.output(x[0]))
self.assertEqual(a[1], -a[0])
x = mpc.run(mpc.output(random_permutation(secint, 8)))
self.assertSetEqual(set(x), set(range(8)))
x = mpc.run(mpc.output(random_derangement(secint, 2)))
Expand Down

0 comments on commit 9876610

Please sign in to comment.