Skip to content

Commit fde4cf5

Browse files
committed
Add test for discover_distribution_symmetries
1 parent 88057ab commit fde4cf5

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

test/test_symmetry.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import unittest
2+
import warnings
3+
import numpy as np
4+
5+
from inflation import InflationProblem
6+
from inflation.symmetry_utils import discover_distribution_symmetries
7+
8+
9+
class TestSymmetry(unittest.TestCase):
10+
@classmethod
11+
def setUpClass(cls):
12+
warnings.simplefilter("ignore", category=UserWarning)
13+
14+
PR_box = np.zeros((2, 2, 2, 2))
15+
for a,b,x,y in np.ndindex(*PR_box.shape):
16+
if np.bitwise_xor(a,b) == np.bitwise_and(x,y):
17+
PR_box[a,b,x,y] = 0.5
18+
19+
bellScenario = InflationProblem({"Lambda": ["A", "B"]},
20+
outcomes_per_party=[2, 2],
21+
settings_per_party=[2, 2],
22+
inflation_level_per_source=[1])
23+
24+
def test_discover(self):
25+
PRbox_symmetries = discover_distribution_symmetries(self.PR_box,
26+
self.bellScenario)
27+
# Order: (a=0,x=0), (a=1,x=0), (a=0,x=1), (a=1,x=1),
28+
# (b=0,y=0), (b=1,y=0), (b=0,y=1), (b=1,y=1)
29+
# There's five symmetries: flip parties, flip X, flip Y, and flip the
30+
# outcomes of each. Out of all these, the only valid ones are those that
31+
# do not change a+b+xy mod 2
32+
# Identity
33+
symmetries = [[0, 1, 2, 3, 4, 5, 6, 7]]
34+
# Flip outcomes in x=1, and flip y
35+
symmetries += [[0, 1, 3, 2, 6, 7, 4, 5]]
36+
# Flip outcomes in x=0, flip b and y
37+
symmetries += [[1, 0, 2, 3, 7, 6, 5, 4]]
38+
# Flip a and b
39+
symmetries += [[1, 0, 3, 2, 5, 4, 7, 6]]
40+
# Flip x, flip b outcomes in y=1
41+
symmetries += [[2, 3, 0, 1, 4, 5, 7, 6]]
42+
# Flip x, flip y, flip a in x=1, flip b in y=0
43+
symmetries += [[2, 3, 1, 0, 7, 6, 4, 5]]
44+
# Flip x, flip y, flip a in x=0, flip b in y=1
45+
symmetries += [[3, 2, 0, 1, 6, 7, 5, 4]]
46+
# Flip x, flip a, flip b in y=0
47+
symmetries += [[3, 2, 1, 0, 5, 4, 6, 7]]
48+
# All the above, but swapping A and B
49+
swapped = [symm[4:] + symm[:4] for symm in symmetries]
50+
symmetries = symmetries + swapped
51+
self.assertSetEqual(set(map(tuple, symmetries)),
52+
set(map(tuple, PRbox_symmetries)),
53+
"Failed to discover the symmetries of the PR box.")

0 commit comments

Comments
 (0)