Skip to content

Commit e608e75

Browse files
committed
Improved warnings to the user when not all symmetries of the InflationProblem can be exploited by InflationSDP, resolving issue #163
1 parent f0b8c4e commit e608e75

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

inflation/sdp/InflationSDP.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
write_to_mat,
3636
write_to_sdpa)
3737
from ..lp.numbafied import nb_outer_bitwise_or
38-
from ..utils import clean_coefficients, partsextractor
38+
from ..utils import clean_coefficients, partsextractor, eprint
3939

4040

4141
class InflationSDP:
@@ -149,8 +149,23 @@ def __init__(self,
149149
self._lexorder = self._default_lexorder.copy()
150150
self.op_to_lexrepr_dict = {tuple(op): i for i, op in enumerate(self._lexorder)}
151151
self._lexorder_len = len(self._lexorder)
152-
self.lexorder_symmetries = \
152+
self.raw_lexorder_symmetries = \
153153
np.pad(inflationproblem.symmetries + 1, ((0, 0), (1, 0)))
154+
# self.lexorder_symmetries = self.raw_lexorder_symmetries.copy()
155+
CG_ops = []
156+
for boolarray in self._CG_limited_ortho_groups_as_boolarrays:
157+
CG_ops.extend(np.flatnonzero(boolarray)+1)
158+
CG_ops = np.sort(CG_ops).astype(int)
159+
self.lexorder_symmetries=np.array([
160+
perm for perm in self.raw_lexorder_symmetries
161+
if np.array_equal(np.sort(perm[CG_ops]), CG_ops)
162+
], dtype=int)
163+
if self.verbose > 0:
164+
old_group_size = len(self.raw_lexorder_symmetries)
165+
new_group_size = len(self.lexorder_symmetries)
166+
if new_group_size < old_group_size:
167+
eprint("Warning: The use of Collins-Gisin notation internally via the argument `include_all_outcomes=False`")
168+
eprint(f" means that not all symmetries of the problem can be exploited. Group size drop from {old_group_size} to {new_group_size}.")
154169

155170
self._lexrepr_to_names = \
156171
np.hstack((["0"], inflationproblem._lexrepr_to_names))
@@ -1959,19 +1974,23 @@ def _discover_columns_symmetries(self) -> np.ndarray:
19591974
permutation_failed = False
19601975
for inf_sym in self.lexorder_symmetries[1:]:
19611976
skip_this_one = False
1962-
try:
1963-
total_perm = np.empty(self.n_columns, dtype=int)
1964-
for i, lexmon in enumerate(self.generating_monomials_1d):
1965-
new_lexmon = inf_sym[lexmon]
1966-
new_lexmon_canon = self._to_canonical_memoized_1d(
1967-
new_lexmon,
1968-
apply_only_commutations=True)
1977+
total_perm = np.empty(self.n_columns, dtype=int)
1978+
for i, lexmon in enumerate(self.generating_monomials_1d):
1979+
new_lexmon = np.argsort(inf_sym)[lexmon]
1980+
new_lexmon_canon = self._to_canonical_memoized_1d(
1981+
new_lexmon,
1982+
apply_only_commutations=True)
1983+
try:
19691984
total_perm[i] \
1970-
= self.genmon_1d_to_index[tuple(new_lexmon_canon)]
1971-
except KeyError:
1972-
permutation_failed = True
1973-
permutations_failed += 1
1974-
skip_this_one = True
1985+
= self.genmon_1d_to_index[tuple(new_lexmon_canon)]
1986+
except KeyError:
1987+
eprint(f"Warning: generating monomial before symmetry becomes unrecognizable after symmetry!")
1988+
eprint(f" Generating monomial before symmetry: {self._lexrepr_to_names[lexmon]}")
1989+
eprint(f" Generating monomial after symmetry: {self._lexrepr_to_names[new_lexmon_canon]}")
1990+
permutation_failed = True
1991+
permutations_failed += 1
1992+
skip_this_one = True
1993+
break
19751994
if not skip_this_one:
19761995
discovered_symmetries.append(total_perm)
19771996
if permutation_failed and (self.verbose > 0):

0 commit comments

Comments
 (0)