Skip to content

Commit 2b97517

Browse files
author
Corey Ostrove
committed
Address Feedback
Add a number of changes to address feedback on PR. Adds some try-excepts around stim imports, fixes a few bugs and renames a function.
1 parent 2db13a1 commit 2b97517

File tree

8 files changed

+47
-27
lines changed

8 files changed

+47
-27
lines changed

Diff for: pygsti/baseobjs/errorgenbasis.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ def elemgen_supports_and_matrices(self):
724724
"""
725725
return tuple(zip(self.elemgen_supports, self.elemgen_matrices))
726726

727-
def label_index(self, label, ok_if_missing=False):
727+
def label_index(self, label, ok_if_missing=False, identity_label='I'):
728728
"""
729729
Return the index of the specified elementary error generator label
730730
in this basis' `labels` list.
@@ -736,12 +736,13 @@ def label_index(self, label, ok_if_missing=False):
736736
737737
ok_if_missing : bool
738738
If True, then returns `None` instead of an integer when the given label is not present.
739+
740+
identity_label : str, optional (default 'I')
741+
An optional string specifying the label used to denote the identity in basis element labels.
739742
"""
740-
#CIO: I don't entirely understand the intention behind this method, so rather than trying to make it work
741-
#using `LocalElementaryErrorgenLabel` I'll just assert it is a global one for now...
742743
if isinstance(label, _LocalElementaryErrorgenLabel):
743-
raise NotImplementedError('This method is not currently implemented for `LocalElementaryErrorgenLabel` inputs.')
744-
744+
label = _GlobalElementaryErrorgenLabel.cast(label, self.sslbls, identity_label=identity_label)
745+
745746
support = label.sslbls
746747
eetype = label.errorgen_type
747748
bels = label.basis_element_labels

Diff for: pygsti/errorgenpropagation/errorpropagator.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66
# in compliance with the License. You may obtain a copy of the License at
77
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
88
#***************************************************************************************************
9-
10-
import stim
9+
import warnings
10+
try:
11+
import stim
12+
except ImportError:
13+
msg = "Stim is required for use of the error generator propagation module, " \
14+
"and it does not appear to be installed. If you intend to use this module please update" \
15+
" your environment."
16+
warnings.warn(msg)
1117
import numpy as _np
1218
import scipy.linalg as _spl
1319
from .localstimerrorgen import LocalStimErrorgenLabel as _LSE

Diff for: pygsti/errorgenpropagation/localstimerrorgen.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
from pygsti.baseobjs.errorgenlabel import ElementaryErrorgenLabel as _ElementaryErrorgenLabel, GlobalElementaryErrorgenLabel as _GEEL,\
1111
LocalElementaryErrorgenLabel as _LEEL
12-
import stim
12+
try:
13+
import stim
14+
except ImportError:
15+
pass
1316
import numpy as _np
1417
from pygsti.tools import change_basis
1518
from pygsti.tools.lindbladtools import create_elementary_errorgen

Diff for: pygsti/tools/errgenproptools.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
1111
#***************************************************************************************************
1212

13-
import stim
13+
import warnings
14+
try:
15+
import stim
16+
except ImportError:
17+
msg = "Stim is required for use of the error generator propagation tools module, " \
18+
"and it does not appear to be installed. If you intend to use this module please update" \
19+
" your environment."
20+
warnings.warn(msg)
21+
1422
import numpy as _np
1523
from pygsti.baseobjs.errorgenlabel import GlobalElementaryErrorgenLabel as _GEEL, LocalElementaryErrorgenLabel as _LEEL
1624
from pygsti.baseobjs import QubitSpace as _QubitSpace

Diff for: pygsti/tools/jamiolkowski.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def sums_of_negative_choi_eigenvalues(model):
327327
"""
328328
ret = []
329329
for (_, gate) in model.operations.items():
330-
J = fast_jamiolkowski_iso_std(gate.to_dense(), model.basis) # Choi mx basis doesn't matter
330+
J = fast_jamiolkowski_iso_std(gate.to_dense(on_space='HilbertSchmidt'), model.basis) # Choi mx basis doesn't matter
331331
evals = _np.linalg.eigvals(J) # could use eigvalsh, but wary of this since eigh can be wrong...
332332
sumOfNeg = 0.0
333333
for ev in evals:

Diff for: pygsti/tools/lindbladtools.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def create_lindbladian_term_errorgen(typ, Lm, Ln=None, sparse=False): # noqa N8
506506
return lind_errgen
507507

508508

509-
def random_error_generator_rates(num_qubits, errorgen_types=('H', 'S', 'C', 'A'), max_weights=None,
509+
def random_CPTP_error_generator_rates(num_qubits, errorgen_types=('H', 'S', 'C', 'A'), max_weights=None,
510510
H_params=(0.,.01), SCA_params=(0.,.01), error_metric=None, error_metric_value=None,
511511
relative_HS_contribution=None, fixed_errorgen_rates=None, sslbl_overlap=None,
512512
label_type='global', seed=None, qubit_labels=None):

Diff for: test/unit/objects/test_errorgenbasis.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,12 @@ def test_label_index(self):
107107
labels = self.complete_errorgen_basis_default_1Q.labels
108108

109109
test_eg = GlobalElementaryErrorgenLabel('C', ['X', 'Y'], (0,))
110+
test_eg_local = LocalElementaryErrorgenLabel('C', ['XI', 'YI'])
110111
test_eg_missing = GlobalElementaryErrorgenLabel('C', ['X', 'Y'], (1,))
111112

112113
lbl_idx = self.complete_errorgen_basis_default_1Q.label_index(test_eg)
113-
114+
lbl_idx_1 = self.complete_errorgen_basis_default_1Q.label_index(test_eg_local)
115+
assert lbl_idx == lbl_idx_1
114116
assert lbl_idx == labels.index(test_eg)
115117

116118
with self.assertRaises(KeyError):

Diff for: test/unit/tools/test_lindbladtools.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_elementary_errorgen_bases(self):
9494
class RandomErrorgenRatesTester(BaseCase):
9595

9696
def test_default_settings(self):
97-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, seed=1234, label_type='local')
97+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, seed=1234, label_type='local')
9898

9999
#make sure that we get the expected number of rates:
100100
self.assertEqual(len(random_errorgen_rates), 240)
@@ -105,31 +105,31 @@ def test_default_settings(self):
105105

106106
def test_sector_restrictions(self):
107107
#H-only:
108-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H',), seed=1234)
108+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H',), seed=1234)
109109
#make sure that we get the expected number of rates:
110110
self.assertEqual(len(random_errorgen_rates), 15)
111111
#also make sure this is CPTP, do so by constructing an error generator and confirming it doesn't fail
112112
#with CPTP parameterization. This should fail if the error generator dictionary is not CPTP.
113113
errorgen = LindbladErrorgen.from_elementary_errorgens(random_errorgen_rates, parameterization='CPTPLND', truncate=False, state_space=QubitSpace(2))
114114

115115
#S-only
116-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('S',), seed=1234)
116+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('S',), seed=1234)
117117
#make sure that we get the expected number of rates:
118118
self.assertEqual(len(random_errorgen_rates), 15)
119119
#also make sure this is CPTP, do so by constructing an error generator and confirming it doesn't fail
120120
#with CPTP parameterization. This should fail if the error generator dictionary is not CPTP.
121121
errorgen = LindbladErrorgen.from_elementary_errorgens(random_errorgen_rates, parameterization='CPTPLND', truncate=False, state_space=QubitSpace(2))
122122

123123
#H+S
124-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'), seed=1234)
124+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'), seed=1234)
125125
#make sure that we get the expected number of rates:
126126
self.assertEqual(len(random_errorgen_rates), 30)
127127
#also make sure this is CPTP, do so by constructing an error generator and confirming it doesn't fail
128128
#with CPTP parameterization. This should fail if the error generator dictionary is not CPTP.
129129
errorgen = LindbladErrorgen.from_elementary_errorgens(random_errorgen_rates, parameterization='CPTPLND', truncate=False, state_space=QubitSpace(2))
130130

131131
#H+S+A
132-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S','A'), seed=1234)
132+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S','A'), seed=1234)
133133
#make sure that we get the expected number of rates:
134134
self.assertEqual(len(random_errorgen_rates), 135)
135135
#also make sure this is CPTP, do so by constructing an error generator and confirming it doesn't fail
@@ -138,7 +138,7 @@ def test_sector_restrictions(self):
138138

139139
def test_error_metric_restrictions(self):
140140
#test generator_infidelity
141-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
141+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
142142
error_metric= 'generator_infidelity',
143143
error_metric_value=.99, seed=1234)
144144
#confirm this has the correct generator infidelity.
@@ -152,7 +152,7 @@ def test_error_metric_restrictions(self):
152152
assert abs(gen_infdl-.99)<1e-5
153153

154154
#test generator_error
155-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
155+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
156156
error_metric= 'total_generator_error',
157157
error_metric_value=.99, seed=1234)
158158
#confirm this has the correct generator infidelity.
@@ -166,7 +166,7 @@ def test_error_metric_restrictions(self):
166166
assert abs(gen_error-.99)<1e-5
167167

168168
#test relative_HS_contribution:
169-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
169+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
170170
error_metric= 'generator_infidelity',
171171
error_metric_value=.99,
172172
relative_HS_contribution=(.5, .5), seed=1234)
@@ -181,7 +181,7 @@ def test_error_metric_restrictions(self):
181181

182182
assert abs(gen_infdl_S - gen_infdl_H)<1e-5
183183

184-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
184+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
185185
error_metric= 'total_generator_error',
186186
error_metric_value=.99,
187187
relative_HS_contribution=(.5, .5), seed=1234)
@@ -198,41 +198,41 @@ def test_error_metric_restrictions(self):
198198

199199
def test_fixed_errorgen_rates(self):
200200
fixed_rates_dict = {GlobalElementaryErrorgenLabel('H', ('X',), (0,)): 1}
201-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
201+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
202202
fixed_errorgen_rates=fixed_rates_dict,
203203
seed=1234)
204204

205205
self.assertEqual(random_errorgen_rates[GlobalElementaryErrorgenLabel('H', ('X',), (0,))], 1)
206206

207207
def test_label_type(self):
208208

209-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
209+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
210210
label_type='local', seed=1234)
211211
assert isinstance(next(iter(random_errorgen_rates)), LocalElementaryErrorgenLabel)
212212

213213
def test_sslbl_overlap(self):
214-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
214+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S'),
215215
sslbl_overlap=(0,),
216216
seed=1234)
217217
for coeff in random_errorgen_rates:
218218
assert 0 in coeff.sslbls
219219

220220
def test_weight_restrictions(self):
221-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S','C','A'),
221+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S','C','A'),
222222
label_type='local', seed=1234,
223223
max_weights={'H':1, 'S':1, 'C':1, 'A':1})
224224
assert len(random_errorgen_rates) == 24
225225
#confirm still CPTP
226226
errorgen = LindbladErrorgen.from_elementary_errorgens(random_errorgen_rates, parameterization='CPTPLND', truncate=False, state_space=QubitSpace(2))
227227

228-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, errorgen_types=('H','S','C','A'),
228+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, errorgen_types=('H','S','C','A'),
229229
label_type='local', seed=1234,
230230
max_weights={'H':2, 'S':2, 'C':1, 'A':1})
231231
assert len(random_errorgen_rates) == 42
232232
errorgen = LindbladErrorgen.from_elementary_errorgens(random_errorgen_rates, parameterization='CPTPLND', truncate=False, state_space=QubitSpace(2))
233233

234234
def test_global_labels(self):
235-
random_errorgen_rates = lt.random_error_generator_rates(num_qubits=2, seed=1234, label_type='global')
235+
random_errorgen_rates = lt.random_CPTP_error_generator_rates(num_qubits=2, seed=1234, label_type='global')
236236

237237
#make sure that we get the expected number of rates:
238238
self.assertEqual(len(random_errorgen_rates), 240)

0 commit comments

Comments
 (0)