Skip to content
114 changes: 81 additions & 33 deletions inflation/sdp/InflationSDP.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
write_to_mat,
write_to_sdpa)
from ..lp.numbafied import nb_outer_bitwise_or
from ..utils import clean_coefficients, partsextractor
from ..utils import clean_coefficients, partsextractor, eprint


class InflationSDP:
Expand All @@ -49,6 +49,7 @@ def __init__(self,
supports_problem: bool = False,
include_all_outcomes: bool = False,
commuting: bool = False,
real_qt: bool = False,
verbose: int = None) -> None:
"""
Class for generating and solving an SDP relaxation for quantum inflation.
Expand All @@ -60,6 +61,9 @@ def __init__(self,
supports_problem : bool, optional
Whether to consider feasibility problems with distributions, or just
with the distribution's support. By default ``False``.
real_qt : bool, optional
Whether to assume real quantum theory instead of traditional (complex)
quantum theory. By default ``False``.
verbose : int, optional
Optional parameter for level of verbose:

Expand Down Expand Up @@ -102,17 +106,17 @@ def __init__(self,

self.measurements = self._generate_parties()
if self.verbose > 1:
print("Number of single operator measurements per party:", end="")
eprint("Number of single operator measurements per party:", end="")
prefix = " "
for i, measures in enumerate(self.measurements):
counter = count()
deque(zip(chain.from_iterable(
chain.from_iterable(measures)),
counter),
maxlen=0)
print(prefix + f"{self.names[i]}={next(counter)}", end="")
eprint(prefix + f"{self.names[i]}={next(counter)}", end="")
prefix = ", "
print()
eprint()
self.use_lpi_constraints = False
self.network_scenario = inflationproblem.is_network
self._is_knowable_q_non_networks = \
Expand Down Expand Up @@ -195,6 +199,13 @@ def __init__(self,
"You appear to be requesting commuting (classical)" \
+ " inflation, \nbut have not specified classical_sources=`all`." \
+ "\nNote that the `commuting` keyword argument has been deprecated as of release 2.0.0"
if real_qt:
assert not self.all_operators_commute, \
"You appear to be requesting inflation assuming real quantum theory," \
+ " but this is meaningless without noncommuting operators."
self.real_qt = True
else:
self.real_qt = False
if self.all_operators_commute:
self.all_commuting_q_2d = lambda mon: True
self.all_commuting_q_1d = lambda lexmon: True
Expand Down Expand Up @@ -224,8 +235,7 @@ def generate_relaxation(self,
column_specification:
Union[str,
List[List[int]],
List[sp.core.symbol.Symbol]] = "npa1",
**kwargs
List[sp.core.symbol.Symbol]] = "npa1"
) -> None:
r"""Creates the SDP relaxation of the quantum inflation problem using
the `NPA hierarchy <https://www.arxiv.org/abs/quant-ph/0607119>`_ and
Expand Down Expand Up @@ -308,10 +318,10 @@ def generate_relaxation(self,
self.Constant_Term.name = self.constant_term_name
self.monomial_from_name[self.constant_term_name] = self.Constant_Term

self.build_columns(column_specification, **kwargs)
self.build_columns(column_specification)
collect()
if self.verbose > 0:
print("Number of columns in the moment matrix:", self.n_columns)
eprint("Number of columns in the moment matrix:", self.n_columns)

# Calculate the moment matrix without the inflation symmetries
unsymmetrized_mm, unsymmetrized_corresp = \
Expand All @@ -323,7 +333,7 @@ def generate_relaxation(self,
else "")
if 0 in unsymmetrized_mm.flat:
additional_var = 1
print("Number of variables" + extra_msg + ":",
eprint("Number of variables" + extra_msg + ":",
len(unsymmetrized_corresp) + additional_var)

# Calculate the inflation symmetries
Expand All @@ -340,7 +350,7 @@ def generate_relaxation(self,
if self.verbose > 0:
extra_msg = (" after symmetrization" if symmetrization_required
else "")
print(f"Number of variables{extra_msg}: "
eprint(f"Number of variables{extra_msg}: "
+ f"{len(self.symmetrized_corresp)+additional_var}")
del unsymmetrized_mm, unsymmetrized_corresp, \
symmetrization_required, additional_var
Expand All @@ -351,26 +361,57 @@ def generate_relaxation(self,

# Associate Monomials to the remaining entries. The zero monomial is
# not stored during calculate_momentmatrix
first_free_index = 0
self.compmoment_from_idx = dict()
if self.momentmatrix_has_a_zero:
self.compmoment_from_idx[0] = self.Zero
_compmonomial_to_idx = dict()
self.n_vars = len(self.symmetrized_corresp)
_compmonomial_to_idx[self.Zero] = 0
self.compmoment_from_idx[0] = self.Zero
first_free_index += 1
self.n_vars += 1
self.extra_inverse = np.arange(self.n_vars, dtype=int)
self.old_indices_associated_with_new_index = defaultdict(list)
self.old_indices_associated_with_monomial = defaultdict(list)
self.old_indices_associated_with_new_index[0]=[0]
self.old_indices_associated_with_monomial[self.Zero] = [0]
for (idx, lexmon) in tqdm(self.symmetrized_corresp.items(),
disable=not self.verbose,
desc="Initializing monomials "):
self.compmoment_from_idx[idx] = self.Moment_1d(lexmon, idx)
self.first_free_idx = max(self.compmoment_from_idx.keys()) + 1
self.moments = list(self.compmoment_from_idx.values())
self.monomials = list(self.compmoment_from_idx.values())

assert all(v == 1 for v in Counter(self.monomials).values()), \
"Multiple indices are being associated to the same monomial"
disable=not self.verbose,
desc="Initializing monomials ",
total=self.n_vars):
mon = self.Moment_1d(lexmon, first_free_index)
self.compmoment_from_idx[idx] = mon # Critical for normalization equations and other functions that use old indices
try:
current_index = _compmonomial_to_idx[mon]
mon.idx = current_index
except KeyError:
current_index = first_free_index
_compmonomial_to_idx[mon] = current_index
first_free_index += 1
self.old_indices_associated_with_new_index[current_index].append(idx)
self.old_indices_associated_with_monomial[mon].append(idx)
self.extra_inverse[idx] = current_index

self.monomials = list(_compmonomial_to_idx.keys())
if not self.momentmatrix_has_a_zero:
self.monomials = self.monomials[1:]
self.n_vars -= 1
self.moments = self.monomials
old_num_vars = self.n_vars
self.n_vars = len(self.monomials)
self.first_free_idx = first_free_index
if self.n_vars < old_num_vars:
if self.verbose > 0:
eprint("Further variable reduction has been made possible. Number of variables in the SDP:",
self.n_vars)
del _compmonomial_to_idx
collect(generation=2)

_counter = Counter([mon.knowability_status for mon in self.moments])
self.n_knowable = _counter["Knowable"]
self.n_something_knowable = _counter["Semi"]
self.n_unknowable = _counter["Unknowable"]
if self.verbose > 1:
print(f"The problem has {self.n_knowable} knowable moments, " +
eprint(f"The problem has {self.n_knowable} knowable moments, " +
f"{self.n_something_knowable} semi-knowable moments, " +
f"and {self.n_unknowable} unknowable moments.")

Expand All @@ -380,7 +421,7 @@ def generate_relaxation(self,
self.hermitian_moments = [mon for mon in self.moments
if mon.is_hermitian]
if self.verbose > 1:
print(f"The problem has {len(self.hermitian_moments)} " +
eprint(f"The problem has {len(self.hermitian_moments)} " +
"non-negative moments.")

# This dictionary useful for certificates_as_probs
Expand All @@ -398,7 +439,7 @@ def generate_relaxation(self,
self.momentmatrix,
self.verbose)
if self.verbose > 1 and len(self.idx_level_equalities):
print("Number of normalization equalities:",
eprint("Number of normalization equalities:",
len(self.idx_level_equalities))
for (norm_idx, summation_idxs) in self.idx_level_equalities:
eq_dict = {self.compmoment_from_idx[norm_idx]: 1}
Expand Down Expand Up @@ -1315,7 +1356,7 @@ def write_to_file(self, filename: str) -> None:

# Write file according to the extension
if self.verbose > 0:
print("Writing the SDP program to", filename)
eprint("Writing the SDP program to", filename)
if extension == "dat-s":
write_to_sdpa(self, filename)
elif extension == "csv":
Expand Down Expand Up @@ -1361,6 +1402,9 @@ def _AtomicMonomial(self,
return mon
except KeyError:
mon = InternalAtomicMonomialSDP(self, repr_lexmon)
if self.real_qt:
conj = mon.dagger
mon = min(mon, conj)
self.atomic_monomial_from_hash[key] = mon
self.atomic_monomial_from_hash[new_key] = mon
return mon
Expand Down Expand Up @@ -1467,8 +1511,8 @@ def _construct_mask_matrices(self) -> None:
if self._relaxation_has_been_generated:
if self.n_columns > 0:
self.maskmatrices = {
mon: coo_array(self.momentmatrix == mon.idx)
for mon in tqdm(self.moments,
mon: sum(coo_array(self.momentmatrix == oldidx) for oldidx in oldidxs)
for mon, oldidxs in tqdm(self.old_indices_associated_with_monomial.items(),
disable=not self.verbose,
desc="Assigning mask matrices ")
}
Expand Down Expand Up @@ -1509,9 +1553,13 @@ def _monomial_from_atoms(self,
else:
pass
atoms = tuple(sorted(list_of_atoms))
conjugate = tuple(sorted(factor.dagger for factor in atoms))
atoms = min(atoms, conjugate)
del conjugate
if not self.all_operators_commute:
conjugate = [factor.dagger for factor in atoms]
if not self.real_qt:
atoms = min(atoms, tuple(sorted(conjugate)))
else:
atoms = min(tuple(sorted(candidate)) for candidate in product(*zip(atoms, conjugate)))
del conjugate
try:
mon = self.monomial_from_atoms[atoms]
return mon
Expand Down Expand Up @@ -1768,7 +1816,7 @@ def _build_cols_from_specs(self, col_specs: List[List[int]]) -> List:
for specs in col_specs:
to_print.append("1" if specs == []
else "".join([self.names[p] for p in specs]))
print("Column structure:", "+".join(to_print))
eprint("Column structure:", "+".join(to_print))

_zero_lexorder = np.array([0], dtype=np.intc)
columns = []
Expand Down Expand Up @@ -1981,14 +2029,14 @@ def _cleanup_after_set_values(self) -> None:
if self.momentmatrix_has_a_one:
num_nontrivial_known -= 1
if self.verbose > 1 and num_nontrivial_known > 0:
print("Number of variables with fixed numeric value:",
eprint("Number of variables with fixed numeric value:",
len(self.known_moments))
if len(self.semiknown_moments):
for k in self.known_moments.keys():
self.semiknown_moments.pop(k, None)
num_semiknown = len(self.semiknown_moments)
if self.verbose > 1 and num_semiknown > 0:
print(f"Number of semiknown variables: {num_semiknown}")
eprint(f"Number of semiknown variables: {num_semiknown}")

def _reset_lowerbounds(self) -> None:
"""Reset the list of lower bounds."""
Expand Down
10 changes: 6 additions & 4 deletions inflation/sdp/sdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,20 +302,22 @@ def constraint_dicts_to_sparse(constraints: List[dict]) -> (coo_array, coo_array
I = M.variable("I",
len(var_inequalities),
Domain.greaterThan(0))
I_reshaped = I.reshape(I.getShape()[0], 1)
# It seems MOSEK Fusion API does not allow to pick index i
# of an expression (A^T I)_i, so we do it manually row by row.
AtI = [] # \sum_j I_j A_ji as i-th entry of AtI
for var in variables:
slice_ = coo_getcol(A, var2index[var])
sparse_slice = scipy_to_mosek(slice_)
AtI.append(Expr.dot(sparse_slice, I))
AtI.append(Expr.dot(sparse_slice, I_reshaped))
if var_equalities:
E = M.variable("E", len(var_equalities), Domain.unbounded())
E_reshaped = E.reshape(E.getShape()[0], 1)
CtI = [] # \sum_j E_j C_ji as i-th entry of CtI
for var in variables:
slice_ = coo_getcol(C, var2index[var])
sparse_slice = scipy_to_mosek(slice_)
CtI.append(Expr.dot(sparse_slice, E))
CtI.append(Expr.dot(sparse_slice, E_reshaped))

# Define and set objective function
# c0 + Tr Z F0 + I·b + E·d
Expand All @@ -328,11 +330,11 @@ def constraint_dicts_to_sparse(constraints: List[dict]) -> (coo_array, coo_array
del F0_mosek
if var_inequalities:
b_mosek = scipy_to_mosek(b)
obj_mosek = Expr.add(obj_mosek, Expr.dot(I, b_mosek))
obj_mosek = Expr.add(obj_mosek, Expr.dot(b_mosek, I_reshaped))
del b_mosek
if var_equalities:
d_mosek = scipy_to_mosek(d)
obj_mosek = Expr.add(obj_mosek, Expr.dot(E, d_mosek))
obj_mosek = Expr.add(obj_mosek, Expr.dot(d_mosek, E_reshaped))
del d_mosek

M.objective(ObjectiveSense.Minimize, obj_mosek)
Expand Down