Skip to content

Commit c963ede

Browse files
committed
Add certificate de-symmetrization
This has made me realize that lexorder is different in InflationLP and InflationSDP. We may want to correct this
1 parent 44b0abb commit c963ede

File tree

3 files changed

+108
-5
lines changed

3 files changed

+108
-5
lines changed

inflation/lp/InflationLP.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,49 @@ def evaluate_certificate(self, prob_array: np.ndarray) -> float:
873873
"valid for other distributions.")
874874
return self.evaluate_polynomial(self.certificate_as_dict(), prob_array)
875875

876+
def desymmetrize_certificate(self) -> dict:
877+
"""If the scenario contains symmetries other than the inflation
878+
symmetries, this function writes a certificate of infeasibility valid
879+
for non-symmetric distributions too.
880+
881+
Parameters
882+
----------
883+
replace : bool, optional
884+
Whether to replace the certificate in ``self.solution_object``.
885+
By default ``False``.
886+
887+
Returns
888+
-------
889+
dict
890+
The expression of the un-symmetrized certificate in terms of
891+
probabilities and marginals. The certificate of incompatibility is
892+
``cert < 0``.
893+
"""
894+
try:
895+
dual = self.solution_object["dual_certificate"]
896+
except AttributeError:
897+
raise Exception("For extracting a certificate you need to solve " +
898+
"a problem. Call \"InflationSDP.solve()\" first.")
899+
900+
desymmetrized = {}
901+
norm = len(self.InflationProblem.symmetries)
902+
lexmon_names = self.InflationProblem._lexrepr_to_copy_index_free_names
903+
for symm in self.InflationProblem.symmetries:
904+
for mon, coeff in dual.items():
905+
mon = self.monomial_from_name[mon]
906+
if not mon.is_zero:
907+
desymm_mon = lexmon_names[symm[mon.as_lexmon]]
908+
desymm_mon = sorted(desymm_mon, key=lambda x: x[0])
909+
if not mon.is_one:
910+
desymm_name = "P[" + " ".join(desymm_mon) + "]"
911+
else:
912+
desymm_name = "1"
913+
if desymm_name not in desymmetrized:
914+
desymmetrized[desymm_name] = coeff / norm
915+
else:
916+
desymmetrized[desymm_name] += coeff / norm
917+
return desymmetrized
918+
876919
###########################################################################
877920
# OTHER ROUTINES EXPOSED TO THE USER #
878921
##########################################################################

inflation/sdp/InflationSDP.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,43 @@ def evaluate_certificate(self, prob_array: np.ndarray) -> float:
11211121
"valid for other distributions.")
11221122
return self.evaluate_polynomial(self.certificate_as_dict(), prob_array)
11231123

1124+
def desymmetrize_certificate(self) -> dict:
1125+
"""If the scenario contains symmetries other than the inflation
1126+
symmetries, this function writes a certificate of infeasibility valid
1127+
for non-symmetric distributions too.
1128+
1129+
Returns
1130+
-------
1131+
dict
1132+
The expression of the un-symmetrized certificate in terms of
1133+
probabilities and marginals. The certificate of incompatibility is
1134+
``cert < 0``.
1135+
"""
1136+
try:
1137+
dual = self.solution_object["dual_certificate"]
1138+
except AttributeError:
1139+
raise Exception("For extracting a certificate you need to solve " +
1140+
"a problem. Call \"InflationSDP.solve()\" first.")
1141+
1142+
desymmetrized = {}
1143+
norm = len(self.InflationProblem.symmetries)
1144+
lexmon_names = self.InflationProblem._lexrepr_to_copy_index_free_names
1145+
for symm in self.InflationProblem.symmetries:
1146+
for mon, coeff in dual.items():
1147+
mon = self.monomial_from_name[mon]
1148+
if not mon.is_zero:
1149+
desymm_mon = lexmon_names[symm[mon.as_lexmon-1]]
1150+
desymm_mon = sorted(desymm_mon, key=lambda x: x[0])
1151+
if not mon.is_one:
1152+
desymm_name = "P[" + " ".join(desymm_mon) + "]"
1153+
else:
1154+
desymm_name = "1"
1155+
if desymm_name not in desymmetrized:
1156+
desymmetrized[desymm_name] = coeff / norm
1157+
else:
1158+
desymmetrized[desymm_name] += coeff / norm
1159+
return desymmetrized
1160+
11241161
###########################################################################
11251162
# OTHER ROUTINES EXPOSED TO THE USER #
11261163
###########################################################################

test/test_symmetry.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
import numpy as np
44

5-
from inflation import InflationProblem
5+
from inflation import InflationProblem, InflationLP
66
from inflation.symmetry_utils import (discover_distribution_symmetries,
77
group_elements_from_generators)
88

@@ -24,17 +24,19 @@ def setUpClass(cls):
2424
bellScenario = InflationProblem({"Lambda": ["A", "B"]},
2525
outcomes_per_party=[2, 2],
2626
settings_per_party=[2, 2],
27-
inflation_level_per_source=[1])
27+
inflation_level_per_source=[1],
28+
classical_sources='all')
2829

2930
triangle = InflationProblem({"Lambda": ["A", "B"],
3031
"Mu": ["B", "C"],
3132
"Sigma": ["C", "A"]},
3233
outcomes_per_party=[2, 2, 2],
3334
inflation_level_per_source=[2, 1, 1])
3435

36+
PRbox_symmetries = discover_distribution_symmetries(PR_box,
37+
bellScenario)
38+
3539
def test_discover(self):
36-
PRbox_symmetries = discover_distribution_symmetries(self.PR_box,
37-
self.bellScenario)
3840
# Order: (a=0,x=0), (a=1,x=0), (a=0,x=1), (a=1,x=1),
3941
# (b=0,y=0), (b=1,y=0), (b=0,y=1), (b=1,y=1)
4042
# There's five symmetries: flip parties, flip X, flip Y, and flip the
@@ -60,7 +62,7 @@ def test_discover(self):
6062
swapped = [symm[4:] + symm[:4] for symm in symmetries]
6163
symmetries = symmetries + swapped
6264
self.assertSetEqual(set(map(tuple, symmetries)),
63-
set(map(tuple, PRbox_symmetries)),
65+
set(map(tuple, self.PRbox_symmetries)),
6466
"Failed to discover the symmetries of the PR box.")
6567

6668
def test_discover_inflation(self):
@@ -89,3 +91,24 @@ def test_group_elements_from_generators(self):
8991
self.assertSetEqual(set(map(tuple, elements)),
9092
set(map(tuple, truth)),
9193
"Failed to generate S3 from generators.")
94+
95+
def test_desymmetrized_certificate(self):
96+
self.bellScenario.add_symmetries(self.PRbox_symmetries)
97+
lp = InflationLP(self.bellScenario, verbose=0)
98+
lp.set_distribution(self.PR_box)
99+
lp.solve()
100+
certificate = lp.desymmetrize_certificate()
101+
truth = {
102+
'P[A_0=0]': 0.125, 'P[A_0=1]': 0.125, 'P[A_1=0]': 0.125, 'P[A_1=1]': 0.125,
103+
'P[B_0=0]': 0.125, 'P[B_0=1]': 0.125, 'P[B_1=0]': 0.125, 'P[B_1=1]': 0.125,
104+
'P[A_0=0 B_0=0]': -0.1875, 'P[A_0=0 B_0=1]': 0.0625,
105+
'P[A_0=0 B_1=0]': -0.1875, 'P[A_0=0 B_1=1]': 0.0625,
106+
'P[A_0=1 B_1=1]': -0.1875, 'P[A_0=1 B_1=0]': 0.0625,
107+
'P[A_0=1 B_0=1]': -0.1875, 'P[A_0=1 B_0=0]': 0.0625,
108+
'P[A_1=0 B_0=0]': -0.1875, 'P[A_1=0 B_0=1]': 0.0625,
109+
'P[A_1=0 B_1=1]': -0.1875, 'P[A_1=0 B_1=0]': 0.0625,
110+
'P[A_1=1 B_1=0]': -0.1875, 'P[A_1=1 B_1=1]': 0.0625,
111+
'P[A_1=1 B_0=1]': -0.1875, 'P[A_1=1 B_0=0]': 0.0625}
112+
self.assertDictEqual(certificate, truth,
113+
"Failed to desymmetrize the CHSH inequality.")
114+

0 commit comments

Comments
 (0)