Skip to content

Commit 8cd39c0

Browse files
authored
Merge pull request #164 from ecboghiu/pedro_symmetries
Finish merging the Pedro_Symmetries branch
2 parents e4df380 + 54face8 commit 8cd39c0

File tree

3 files changed

+62
-41
lines changed

3 files changed

+62
-41
lines changed

inflation/InflationProblem.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ def _lexrepr_to_names(self) -> np.ndarray:
543543
return np.asarray([self._interpretation_to_name(
544544
op_dict,
545545
include_copy_indices=self._any_inflation)
546-
for op_dict in self._lexrepr_to_dicts.flat])
546+
for op_dict in self._lexrepr_to_dicts.flat], dtype=object)
547547

548548
@cached_property
549549
def _original_event_names(self) -> np.ndarray:
@@ -558,7 +558,7 @@ def _original_event_names(self) -> np.ndarray:
558558
return np.asarray([self._interpretation_to_name(
559559
self._interpret_operator(event),
560560
include_copy_indices=False)
561-
for event in self.original_dag_events])
561+
for event in self.original_dag_events], dtype=object)
562562

563563
@cached_property
564564
def _lexrepr_to_copy_index_free_names(self) -> np.ndarray:
@@ -576,7 +576,7 @@ def _lexrepr_to_copy_index_free_names(self) -> np.ndarray:
576576
return np.asarray([self._interpretation_to_name(
577577
op_dict,
578578
include_copy_indices=False)
579-
for op_dict in self._lexrepr_to_dicts.flat])
579+
for op_dict in self._lexrepr_to_dicts.flat], dtype=object)
580580

581581
@cached_property
582582
def _lexrepr_to_all_names(self) -> np.ndarray:
@@ -599,7 +599,7 @@ def _lexrepr_to_all_names(self) -> np.ndarray:
599599
self._lexrepr_to_copy_index_free_names,
600600
old_names_v1,
601601
old_names_v2
602-
), 1)
602+
), 1).astype(object)
603603

604604
@cached_property
605605
def _lexrepr_to_symbols(self) -> np.ndarray:

inflation/lp/InflationLP.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ def __init__(self,
169169
self._boolvec_for_FR_eqs = self.blank_bool_vec
170170

171171
if self.verbose > 1:
172-
print("Number of single operator measurements per party:", end="")
172+
eprint("Number of single operator measurements per party:", end="")
173173
prefix = " "
174174
for i, measures in enumerate(inflationproblem.measurements):
175175
op_count = np.prod(measures.shape[:2])
176-
print(prefix + f"{self.names[i]}={op_count}", end="")
176+
eprint(prefix + f"{self.names[i]}={op_count}", end="")
177177
prefix = ", "
178-
print()
178+
eprint()
179179
self.use_lpi_constraints = False
180180

181181
self.identity_operator = np.empty((0, self._nr_properties),
@@ -972,7 +972,7 @@ def write_to_file(self, filename: str) -> None:
972972
# Write file according to the extension
973973
args = self._prepare_solver_arguments(separate_bounds=True)
974974
if self.verbose > 0:
975-
print("Writing the LP program to", filename)
975+
eprint("Writing the LP program to", filename)
976976
if extension == "lp":
977977
write_to_lp(args, filename)
978978
elif extension == "mps":
@@ -1142,7 +1142,7 @@ def _sanitise_monomial(self, mon: Any) -> CompoundMoment:
11421142
try:
11431143
return self.monomial_from_name[mon]
11441144
except KeyError:
1145-
print(f"As of now we only recognize \n{list(self.monomial_from_name.keys())}")
1145+
eprint(f"As of now we only recognize \n{list(self.monomial_from_name.keys())}")
11461146
return self._sanitise_monomial(self._interpret_name(mon))
11471147
elif isinstance(mon, Real):
11481148
if np.isclose(float(mon), 1):
@@ -1321,7 +1321,7 @@ def _generate_lp(self) -> None:
13211321
orbits_non_CG, return_index=True, return_inverse=True)
13221322
self.num_non_CG = len(old_reps_non_CG)
13231323
if self.verbose > 1:
1324-
print(f"Orbits discovered! {self.num_CG} unique monomials.")
1324+
eprint(f"Orbits discovered! {self.num_CG} unique monomials.")
13251325
# Obtain the real generating monomials after accounting for symmetry
13261326
else:
13271327
self.num_CG = self.raw_n_columns
@@ -1406,7 +1406,7 @@ def _generate_lp(self) -> None:
14061406

14071407
self._lp_has_been_generated = True
14081408
if self.verbose > 1:
1409-
print("LP initialization complete, ready to accept further specifics.")
1409+
eprint("LP initialization complete, ready to accept further specifics.")
14101410

14111411
def _template_to_event_boolarray(self, template: List[int], decompressor: List[np.ndarray]) -> np.ndarray:
14121412
if template:

inflation/sdp/InflationSDP.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
write_to_mat,
3636
write_to_sdpa)
3737
from ..lp.numbafied import nb_outer_bitwise_or
38-
from ..utils import clean_coefficients, partsextractor
38+
from ..utils import clean_coefficients, partsextractor, eprint
3939

4040

4141
class InflationSDP:
@@ -101,17 +101,17 @@ def __init__(self,
101101

102102
self.measurements = self._generate_parties()
103103
if self.verbose > 1:
104-
print("Number of single operator measurements per party:", end="")
104+
eprint("Number of single operator measurements per party:", end="")
105105
prefix = " "
106106
for i, measures in enumerate(self.measurements):
107107
counter = count()
108108
deque(zip(chain.from_iterable(
109109
chain.from_iterable(measures)),
110110
counter),
111111
maxlen=0)
112-
print(prefix + f"{self.names[i]}={next(counter)}", end="")
112+
eprint(prefix + f"{self.names[i]}={next(counter)}", end="")
113113
prefix = ", "
114-
print()
114+
eprint()
115115
self.use_lpi_constraints = False
116116
self.network_scenario = inflationproblem.is_network
117117
self._is_knowable_q_non_networks = \
@@ -149,13 +149,30 @@ def __init__(self,
149149
self._lexorder = self._default_lexorder.copy()
150150
self.op_to_lexrepr_dict = {tuple(op): i for i, op in enumerate(self._lexorder)}
151151
self._lexorder_len = len(self._lexorder)
152-
self.lexorder_symmetries = \
152+
self.raw_lexorder_symmetries = \
153153
np.pad(inflationproblem.symmetries + 1, ((0, 0), (1, 0)))
154+
# self.lexorder_symmetries = self.raw_lexorder_symmetries.copy()
155+
CG_ops = []
156+
for boolarray in self._CG_limited_ortho_groups_as_boolarrays:
157+
CG_ops.extend(np.flatnonzero(boolarray)+1)
158+
CG_ops = np.sort(CG_ops).astype(int)
159+
self.lexorder_symmetries=np.array([
160+
perm for perm in self.raw_lexorder_symmetries
161+
if np.array_equal(np.sort(perm[CG_ops]), CG_ops)
162+
], dtype=int)
163+
if self.verbose > 0:
164+
old_group_size = len(self.raw_lexorder_symmetries)
165+
new_group_size = len(self.lexorder_symmetries)
166+
if new_group_size < old_group_size:
167+
eprint("Warning: The use of Collins-Gisin notation internally via the argument `include_all_outcomes=False`")
168+
eprint(f" means that not all symmetries of the problem can be exploited. Group size drop from {old_group_size} to {new_group_size}.")
154169

155170
self._lexrepr_to_names = \
156-
np.hstack((["0"], inflationproblem._lexrepr_to_names))
171+
np.hstack((["0"], inflationproblem._lexrepr_to_names)).astype(object)
172+
# eprint("CG stuff:", self._lexrepr_to_names[CG_ops])
173+
# eprint("else: ", np.setdiff1d(self._lexrepr_to_names, self._lexrepr_to_names[CG_ops]))
157174
self._lexrepr_to_copy_index_free_names = \
158-
np.hstack((["0"], inflationproblem._lexrepr_to_copy_index_free_names))
175+
np.hstack((["0"], inflationproblem._lexrepr_to_copy_index_free_names)).astype(object)
159176
self.op_from_name = {"0": 0}
160177
for i, op_names in enumerate(inflationproblem._lexrepr_to_all_names.tolist()):
161178
for op_name in op_names:
@@ -311,7 +328,7 @@ def generate_relaxation(self,
311328
self.build_columns(column_specification, **kwargs)
312329
collect()
313330
if self.verbose > 0:
314-
print("Number of columns in the moment matrix:", self.n_columns)
331+
eprint("Number of columns in the moment matrix:", self.n_columns)
315332

316333
# Calculate the moment matrix without the inflation symmetries
317334
unsymmetrized_mm, unsymmetrized_corresp = \
@@ -323,7 +340,7 @@ def generate_relaxation(self,
323340
else "")
324341
if 0 in unsymmetrized_mm.flat:
325342
additional_var = 1
326-
print("Number of variables" + extra_msg + ":",
343+
eprint("Number of variables" + extra_msg + ":",
327344
len(unsymmetrized_corresp) + additional_var)
328345

329346
# Calculate the inflation symmetries
@@ -340,7 +357,7 @@ def generate_relaxation(self,
340357
if self.verbose > 0:
341358
extra_msg = (" after symmetrization" if symmetrization_required
342359
else "")
343-
print(f"Number of variables{extra_msg}: "
360+
eprint(f"Number of variables{extra_msg}: "
344361
+ f"{len(self.symmetrized_corresp)+additional_var}")
345362
del unsymmetrized_mm, unsymmetrized_corresp, \
346363
symmetrization_required, additional_var
@@ -392,7 +409,7 @@ def generate_relaxation(self,
392409
self.first_free_idx = first_free_index
393410
if self.n_vars < old_num_vars:
394411
if self.verbose > 0:
395-
print("Further variable reduction has been made possible. Number of variables in the SDP:",
412+
eprint("Further variable reduction has been made possible. Number of variables in the SDP:",
396413
self.n_vars)
397414
# self.compmoment_from_idx = dict(zip(range(self.n_vars), monomials_as_list))
398415
# self.compmoment_to_idx = dict(zip(monomials_as_list, range(self.n_vars)))
@@ -418,7 +435,7 @@ def generate_relaxation(self,
418435
self.n_something_knowable = _counter["Semi"]
419436
self.n_unknowable = _counter["Unknowable"]
420437
if self.verbose > 1:
421-
print(f"The problem has {self.n_knowable} knowable moments, " +
438+
eprint(f"The problem has {self.n_knowable} knowable moments, " +
422439
f"{self.n_something_knowable} semi-knowable moments, " +
423440
f"and {self.n_unknowable} unknowable moments.")
424441

@@ -428,7 +445,7 @@ def generate_relaxation(self,
428445
self.hermitian_moments = [mon for mon in self.moments
429446
if mon.is_hermitian]
430447
if self.verbose > 1:
431-
print(f"The problem has {len(self.hermitian_moments)} " +
448+
eprint(f"The problem has {len(self.hermitian_moments)} " +
432449
"non-negative moments.")
433450

434451
# This dictionary useful for certificates_as_probs
@@ -446,7 +463,7 @@ def generate_relaxation(self,
446463
self.momentmatrix,
447464
self.verbose)
448465
if self.verbose > 1 and len(self.idx_level_equalities):
449-
print("Number of normalization equalities:",
466+
eprint("Number of normalization equalities:",
450467
len(self.idx_level_equalities))
451468
for (norm_idx, summation_idxs) in self.idx_level_equalities:
452469
eq_dict = {self.compmoment_from_idx[norm_idx]: 1}
@@ -1404,7 +1421,7 @@ def write_to_file(self, filename: str) -> None:
14041421

14051422
# Write file according to the extension
14061423
if self.verbose > 0:
1407-
print("Writing the SDP program to", filename)
1424+
eprint("Writing the SDP program to", filename)
14081425
if extension == "dat-s":
14091426
write_to_sdpa(self, filename)
14101427
elif extension == "csv":
@@ -1857,7 +1874,7 @@ def _build_cols_from_specs(self, col_specs: List[List[int]]) -> List:
18571874
for specs in col_specs:
18581875
to_print.append("1" if specs == []
18591876
else "".join([self.names[p] for p in specs]))
1860-
print("Column structure:", "+".join(to_print))
1877+
eprint("Column structure:", "+".join(to_print))
18611878

18621879
_zero_lexorder = np.array([0], dtype=np.intc)
18631880
columns = []
@@ -1969,19 +1986,23 @@ def _discover_columns_symmetries(self) -> np.ndarray:
19691986
permutation_failed = False
19701987
for inf_sym in self.lexorder_symmetries[1:]:
19711988
skip_this_one = False
1972-
try:
1973-
total_perm = np.empty(self.n_columns, dtype=int)
1974-
for i, lexmon in enumerate(self.generating_monomials_1d):
1975-
new_lexmon = inf_sym[lexmon]
1976-
new_lexmon_canon = self._to_canonical_memoized_1d(
1977-
new_lexmon,
1978-
apply_only_commutations=True)
1989+
total_perm = np.empty(self.n_columns, dtype=int)
1990+
for i, lexmon in enumerate(self.generating_monomials_1d):
1991+
new_lexmon = np.argsort(inf_sym)[lexmon]
1992+
new_lexmon_canon = self._to_canonical_memoized_1d(
1993+
new_lexmon,
1994+
apply_only_commutations=True)
1995+
try:
19791996
total_perm[i] \
1980-
= self.genmon_1d_to_index[tuple(new_lexmon_canon)]
1981-
except KeyError:
1982-
permutation_failed = True
1983-
permutations_failed += 1
1984-
skip_this_one = True
1997+
= self.genmon_1d_to_index[tuple(new_lexmon_canon)]
1998+
except KeyError:
1999+
eprint(f"Warning: generating monomial before symmetry becomes unrecognizable after symmetry!")
2000+
eprint(f" Generating monomial before symmetry: {self._lexrepr_to_names[lexmon]}")
2001+
eprint(f" Generating monomial after symmetry: {self._lexrepr_to_names[new_lexmon_canon]}")
2002+
permutation_failed = True
2003+
permutations_failed += 1
2004+
skip_this_one = True
2005+
break
19852006
if not skip_this_one:
19862007
discovered_symmetries.append(total_perm)
19872008
if permutation_failed and (self.verbose > 0):
@@ -2071,14 +2092,14 @@ def _cleanup_after_set_values(self) -> None:
20712092
if self.momentmatrix_has_a_one:
20722093
num_nontrivial_known -= 1
20732094
if self.verbose > 1 and num_nontrivial_known > 0:
2074-
print("Number of variables with fixed numeric value:",
2095+
eprint("Number of variables with fixed numeric value:",
20752096
len(self.known_moments))
20762097
if len(self.semiknown_moments):
20772098
for k in self.known_moments.keys():
20782099
self.semiknown_moments.pop(k, None)
20792100
num_semiknown = len(self.semiknown_moments)
20802101
if self.verbose > 1 and num_semiknown > 0:
2081-
print(f"Number of semiknown variables: {num_semiknown}")
2102+
eprint(f"Number of semiknown variables: {num_semiknown}")
20822103

20832104
def _reset_lowerbounds(self) -> None:
20842105
"""Reset the list of lower bounds."""

0 commit comments

Comments
 (0)