Skip to content

Commit ee66fa2

Browse files
committed
Allow custom NPA levels to be specified wherein each column is a string denoting a sequence of operators (instead of symbols).
1 parent 9a8277b commit ee66fa2

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

inflation/sdp/InflationSDP.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def generate_relaxation(self,
250250
column_specification:
251251
Union[str,
252252
List[List[int]],
253-
List[sp.core.symbol.Symbol]] = "npa1",
253+
List[Union[sp.core.symbol.Symbol, str]]] = "npa1",
254254
**kwargs
255255
) -> None:
256256
r"""Creates the SDP relaxation of the quantum inflation problem using
@@ -314,11 +314,11 @@ def generate_relaxation(self,
314314
is the same as ``npa2`` for three parties. ``[[]]`` encodes the
315315
identity element.
316316
317-
* `List[sympy.core.symbol.Symbol]`: one can also fully specify the
318-
generating set by giving a list of symbolic operators built from
319-
the measurement operators in ``InflationSDP.measurements``. This
320-
list needs to have the identity ``sympy.S.One`` as the first
321-
element.
317+
* `List[Union[sp.core.symbol.Symbol, str]]`: one can also fully specify the
318+
generating set by giving a list of operator sequence names or symbolic
319+
operators built from the measurement operators in ``InflationSDP.measurements``.
320+
This list needs to have the identity (string or integer 1, or ``sympy.S.One``)
321+
as the first element.
322322
323323
kwargs :
324324
Additional arguments that will be passed to
@@ -1200,7 +1200,7 @@ def build_columns(self,
12001200
12011201
Parameters
12021202
----------
1203-
column_specification : Union[str, List[List[int]], List[sympy.core.symbol.Symbol]]
1203+
column_specification : Union[str, List[List[int]], List[Union[sympy.core.symbol.Symbol, str]]]
12041204
See description in the ``self.generate_relaxation()`` method.
12051205
max_monomial_length : int, optional
12061206
Maximum number of letters in a monomial in the generating set,
@@ -1236,10 +1236,12 @@ def build_columns(self,
12361236
else:
12371237
raise Exception("The generating columns are not specified "
12381238
+ "in a valid format.")
1239-
elif type(column_specification[0]) in [int, sp.core.symbol.Symbol,
1239+
elif type(column_specification[0]) in [int,
1240+
sp.core.symbol.Symbol,
12401241
sp.core.power.Pow,
12411242
sp.core.mul.Mul,
1242-
sp.core.numbers.One]:
1243+
sp.core.numbers.One,
1244+
str]:
12431245
columns = []
12441246
for col in column_specification:
12451247
if type(col) in [int, sp.core.numbers.One]:
@@ -1250,7 +1252,8 @@ def build_columns(self,
12501252
columns.append(np.array([], dtype=np.intc))
12511253
elif type(col) in [sp.core.symbol.Symbol,
12521254
sp.core.power.Pow,
1253-
sp.core.mul.Mul]:
1255+
sp.core.mul.Mul,
1256+
str]:
12541257
columns.append(self.mon_to_lexrepr(
12551258
self._interpret_name(col)))
12561259
else:

0 commit comments

Comments
 (0)