Skip to content

Commit cc47666

Browse files
authored
Merge pull request #622 from sandialabs/revise-gaugeopt-suites
Changes to definitions of named gauge optimization suites, resolve (unnecessary) warnings, robustify report generation
2 parents 60b6c19 + 4956cdf commit cc47666

File tree

5 files changed

+107
-101
lines changed

5 files changed

+107
-101
lines changed

pygsti/optimize/optimize.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,7 @@ def _basin_callback(x, f, accept):
159159
elif method == "L-BFGS-B": opts['gtol'] = opts['ftol'] = tol # gradient norm and fractional y-tolerance
160160
elif method == "Nelder-Mead": opts['maxfev'] = maxfev # max fn evals (note: ftol and xtol can also be set)
161161

162-
if method in ("BFGS", "CG", "Newton-CG", "L-BFGS-B", "TNC", "SLSQP", "dogleg", "trust-ncg"): # use jacobian
163-
solution = _spo.minimize(fn, x0, options=opts, method=method, tol=tol, callback=callback, jac=jac)
164-
else:
165-
solution = _spo.minimize(fn, x0, options=opts, method=method, tol=tol, callback=callback)
162+
solution = _spo.minimize(fn, x0, options=opts, method=method, tol=tol, callback=callback, jac=jac)
166163

167164
return solution
168165

pygsti/protocols/gst.py

Lines changed: 96 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,14 @@ class GSTGaugeOptSuite(_NicelySerializable):
857857
given by the target model, which are used as the default when
858858
`gaugeopt_target` is None.
859859
"""
860+
861+
STANDARD_SUITENAMES = ("stdgaugeopt", "stdgaugeopt-unreliable2Q", "stdgaugeopt-tt", "stdgaugeopt-safe",
862+
"stdgaugeopt-noconversion", "stdgaugeopt-noconversion-safe")
863+
864+
SPECIAL_SUITENAMES = ("varySpam", "varySpamWt", "varyValidSpamWt", "toggleValidSpam",
865+
"varySpam-unreliable2Q", "varySpamWt-unreliable2Q",
866+
"varyValidSpamWt-unreliable2Q", "toggleValidSpam-unreliable2Q")
867+
860868
@classmethod
861869
def cast(cls, obj):
862870
if obj is None:
@@ -872,14 +880,13 @@ def cast(cls, obj):
872880

873881
def __init__(self, gaugeopt_suite_names=None, gaugeopt_argument_dicts=None, gaugeopt_target=None):
874882
super().__init__()
875-
if gaugeopt_suite_names is not None:
876-
if gaugeopt_suite_names == 'none':
877-
self.gaugeopt_suite_names = None
878-
else:
879-
self.gaugeopt_suite_names = (gaugeopt_suite_names,) \
880-
if isinstance(gaugeopt_suite_names, str) else tuple(gaugeopt_suite_names)
881-
else:
883+
if gaugeopt_suite_names is None or gaugeopt_suite_names == 'none':
882884
self.gaugeopt_suite_names = None
885+
elif isinstance(gaugeopt_suite_names, str):
886+
self.gaugeopt_suite_names = (gaugeopt_suite_names,)
887+
else:
888+
self.gaugeopt_suite_names = tuple(gaugeopt_suite_names)
889+
883890

884891
if gaugeopt_argument_dicts is not None:
885892
self.gaugeopt_argument_dicts = gaugeopt_argument_dicts.copy()
@@ -985,92 +992,98 @@ def to_dictionary(self, model, unreliable_ops=(), verbosity=0):
985992

986993
return gaugeopt_suite_dict
987994

988-
def _update_gaugeopt_dict_from_suitename(self, gaugeopt_suite_dict, root_lbl, suite_name, model,
989-
unreliable_ops, printer):
990-
if suite_name in ("stdgaugeopt", "stdgaugeopt-unreliable2Q", "stdgaugeopt-tt", "stdgaugeopt-safe",
991-
"stdgaugeopt-noconversion", "stdgaugeopt-noconversion-safe"):
995+
@staticmethod
996+
def _update_gaugeopt_dict_from_suitename(gaugeopt_suite_dict, root_lbl, suite_name, model, unreliable_ops, printer):
997+
998+
if suite_name in GSTGaugeOptSuite.STANDARD_SUITENAMES:
992999

993-
stages = [] # multi-stage gauge opt
9941000
gg = model.default_gauge_group
995-
convert_to = {'to_type': "full TP", 'flatten_structure': True, 'set_default_gauge_group': True} \
1001+
1002+
if gg is None:
1003+
return
1004+
1005+
stages = [] # multi-stage gauge opt
1006+
1007+
from pygsti.models.gaugegroup import TrivialGaugeGroup, UnitaryGaugeGroup, \
1008+
SpamGaugeGroup, TPSpamGaugeGroup
1009+
1010+
convert_to = {'to_type': "full", 'flatten_structure': True, 'set_default_gauge_group': False} \
9961011
if ('noconversion' not in suite_name and gg.name not in ("Full", "TP")) else None
9971012

998-
if isinstance(gg, _models.gaugegroup.TrivialGaugeGroup) and convert_to is None:
1013+
if isinstance(gg, TrivialGaugeGroup) and convert_to is None:
9991014
if suite_name == "stdgaugeopt-unreliable2Q" and model.dim == 16:
1000-
if any([gl in model.operations.keys() for gl in unreliable_ops]):
1015+
if any([gl in model.operations for gl in unreliable_ops]):
10011016
gaugeopt_suite_dict[root_lbl] = {'verbosity': printer}
10021017
else:
10031018
#just do a single-stage "trivial" gauge opts using default group
10041019
gaugeopt_suite_dict[root_lbl] = {'verbosity': printer}
1005-
1006-
elif gg is not None:
1007-
metric = 'frobeniustt' if suite_name == 'stdgaugeopt-tt' else 'frobenius'
1008-
1009-
#Stage 1: plain vanilla gauge opt to get into "right ballpark"
1010-
if gg.name in ("Full", "TP"):
1011-
stages.append(
1012-
{
1013-
'gates_metric': metric, 'spam_metric': metric,
1014-
'item_weights': {'gates': 1.0, 'spam': 1.0},
1015-
'verbosity': printer
1016-
})
1017-
1018-
#Stage 2: unitary gauge opt that tries to nail down gates (at
1019-
# expense of spam if needed)
1020-
stages.append(
1021-
{
1022-
'convert_model_to': convert_to,
1023-
'gates_metric': metric, 'spam_metric': metric,
1024-
'item_weights': {'gates': 1.0, 'spam': 0.0},
1025-
'gauge_group': _models.gaugegroup.UnitaryGaugeGroup(model.state_space,
1026-
model.basis, model.evotype),
1027-
'oob_check_interval': 1 if ('-safe' in suite_name) else 0,
1028-
'verbosity': printer
1029-
})
1030-
1031-
#Stage 3: spam gauge opt that fixes spam scaling at expense of
1032-
# non-unital parts of gates (but shouldn't affect these
1033-
# elements much since they should be small from Stage 2).
1034-
s3gg = _models.gaugegroup.SpamGaugeGroup if (gg.name == "Full") else \
1035-
_models.gaugegroup.TPSpamGaugeGroup
1036-
stages.append(
1037-
{
1038-
'convert_model_to': convert_to,
1039-
'gates_metric': metric, 'spam_metric': metric,
1040-
'item_weights': {'gates': 0.0, 'spam': 1.0},
1041-
'spam_penalty_factor': 1.0,
1042-
'gauge_group': s3gg(model.state_space, model.evotype),
1043-
'oob_check_interval': 1,
1044-
'verbosity': printer
1045-
})
1046-
1047-
if suite_name == "stdgaugeopt-unreliable2Q" and model.dim == 16:
1048-
if any([gl in model.operations.keys() for gl in unreliable_ops]):
1049-
stage2_item_weights = {'gates': 1, 'spam': 0.0}
1050-
for gl in unreliable_ops:
1051-
if gl in model.operations.keys(): stage2_item_weights[gl] = 0.01
1052-
stages_2qubit_unreliable = [stage.copy() for stage in stages] # ~deep copy of stages
1053-
istage2 = 1 if gg.name in ("Full", "TP") else 0
1054-
stages_2qubit_unreliable[istage2]['item_weights'] = stage2_item_weights
1055-
gaugeopt_suite_dict[root_lbl] = stages_2qubit_unreliable # add additional gauge opt
1056-
else:
1057-
_warnings.warn(("`unreliable2Q` was given as a gauge opt suite, but none of the"
1058-
" gate names in 'unreliable_ops', i.e., %s,"
1059-
" are present in the target model. Omitting 'single-2QUR' gauge opt.")
1060-
% (", ".join(unreliable_ops)))
1020+
return
1021+
1022+
metric = 'frobeniustt' if suite_name == 'stdgaugeopt-tt' else 'frobenius'
1023+
ss = model.state_space
1024+
et = model.evotype
1025+
1026+
# Stage 1: plain vanilla gauge opt to get into "right ballpark"
1027+
if gg.name in ("Full", "TP"):
1028+
stages.append({
1029+
'gates_metric': metric, 'spam_metric': metric,
1030+
'item_weights': {'gates': 1.0, 'spam': 1.0},
1031+
'verbosity': printer
1032+
})
1033+
1034+
# Stage 2: unitary gauge opt that tries to nail down gates (at
1035+
# expense of spam if needed)
1036+
s2gg = UnitaryGaugeGroup(ss, model.basis, et)
1037+
stages.append({
1038+
'convert_model_to': convert_to,
1039+
'gates_metric': metric, 'spam_metric': metric,
1040+
'item_weights': {'gates': 1.0, 'spam': 0.0},
1041+
'gauge_group': s2gg,
1042+
'oob_check_interval': 1 if ('-safe' in suite_name) else 0,
1043+
'verbosity': printer
1044+
})
1045+
1046+
# Stage 3: spam gauge opt that fixes spam scaling at expense of
1047+
# non-unital parts of gates (but shouldn't affect these
1048+
# elements much since they should be small from Stage 2).
1049+
s3gg = SpamGaugeGroup(ss, et) if (gg.name == "Full") else TPSpamGaugeGroup(ss, et)
1050+
stages.append({
1051+
'convert_model_to': convert_to,
1052+
'gates_metric': metric, 'spam_metric': metric,
1053+
'item_weights': {'gates': 0.0, 'spam': 1.0},
1054+
'spam_penalty_factor': 1.0,
1055+
'gauge_group': s3gg,
1056+
'oob_check_interval': 1,
1057+
'verbosity': printer
1058+
})
1059+
1060+
if suite_name == "stdgaugeopt-unreliable2Q" and model.dim == 16:
1061+
if any([gl in model.operations for gl in unreliable_ops]):
1062+
stage2_item_weights = {'gates': 1.0, 'spam': 0.0}
1063+
for gl in unreliable_ops:
1064+
if gl in model.operations:
1065+
stage2_item_weights[gl] = 0.01
1066+
stages_2qubit_unreliable = [stage.copy() for stage in stages] # ~deep copy of stages
1067+
istage2 = 1 if gg.name in ("Full", "TP") else 0
1068+
stages_2qubit_unreliable[istage2]['item_weights'] = stage2_item_weights
1069+
gaugeopt_suite_dict[root_lbl] = stages_2qubit_unreliable # add additional gauge opt
10611070
else:
1062-
gaugeopt_suite_dict[root_lbl] = stages # can be a list of stage dictionaries
1071+
_warnings.warn(("`unreliable2Q` was given as a gauge opt suite, but none of the"
1072+
" gate names in 'unreliable_ops', i.e., %s,"
1073+
" are present in the target model. Omitting 'single-2QUR' gauge opt.")
1074+
% (", ".join(unreliable_ops)))
1075+
else:
1076+
gaugeopt_suite_dict[root_lbl] = stages # can be a list of stage dictionaries
10631077

1064-
elif suite_name in ("varySpam", "varySpamWt", "varyValidSpamWt", "toggleValidSpam") or \
1065-
suite_name in ("varySpam-unreliable2Q", "varySpamWt-unreliable2Q",
1066-
"varyValidSpamWt-unreliable2Q", "toggleValidSpam-unreliable2Q"):
1078+
elif suite_name in GSTGaugeOptSuite.SPECIAL_SUITENAMES:
10671079

1068-
base_wts = {'gates': 1}
1080+
base_wts = {'gates': 1.0}
10691081
if suite_name.endswith("unreliable2Q") and model.dim == 16:
1070-
if any([gl in model.operations.keys() for gl in unreliable_ops]):
1071-
base = {'gates': 1}
1082+
if any([gl in model.operations for gl in unreliable_ops]):
1083+
base = {'gates': 1.0}
10721084
for gl in unreliable_ops:
1073-
if gl in model.operations.keys(): base[gl] = 0.01
1085+
if gl in model.operations:
1086+
base[gl] = 0.01
10741087
base_wts = base
10751088

10761089
if suite_name == "varySpam":
@@ -1097,9 +1110,6 @@ def _update_gaugeopt_dict_from_suitename(self, gaugeopt_suite_dict, root_lbl, su
10971110
'item_weights': item_weights,
10981111
'spam_penalty_factor': valid_spam, 'verbosity': printer}
10991112

1100-
elif suite_name == "unreliable2Q":
1101-
raise ValueError(("unreliable2Q is no longer a separate 'suite'. You should precede it with the suite"
1102-
" name, e.g. 'stdgaugeopt-unreliable2Q' or 'varySpam-unreliable2Q'"))
11031113
elif suite_name == 'none':
11041114
gaugeopt_suite_dict[root_lbl] = None
11051115
else:
@@ -2091,9 +2101,9 @@ def _add_gauge_opt(results, base_est_label, gaugeopt_suite, starting_model,
20912101
"""
20922102
printer = _baseobjs.VerbosityPrinter.create_printer(verbosity, comm)
20932103

2094-
#Get gauge optimization dictionary
2095-
gaugeopt_suite_dict = gaugeopt_suite.to_dictionary(starting_model,
2096-
unreliable_ops, printer - 1)
2104+
gaugeopt_suite_dict = gaugeopt_suite.to_dictionary(
2105+
starting_model, unreliable_ops, printer - 1
2106+
)
20972107

20982108
#Gauge optimize to list of gauge optimization parameters
20992109
for go_label, goparams in gaugeopt_suite_dict.items():
@@ -2505,6 +2515,7 @@ def _compute_1d_reference_values_and_name(estimate, badfit_options, gaugeopt_sui
25052515
spamdd[key] = 0.5 * _tools.optools.povm_diamonddist(gaugeopt_model, target_model, key)
25062516

25072517
dd[lbl]['SPAM'] = sum(spamdd.values())
2518+
25082519
return dd, 'diamond distance'
25092520
else:
25102521
raise ValueError("Invalid wildcard1d_reference value (%s) in bad-fit options!"

pygsti/tools/basistools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import numpy as _np
1616

1717
from pygsti.baseobjs.basisconstructors import _basis_constructor_dict
18-
# from ..baseobjs.basis import Basis, BuiltinBasis, DirectSumBasis
1918
from pygsti.baseobjs import basis as _basis
2019

20+
2121
@lru_cache(maxsize=1)
2222
def basis_matrices(name_or_basis, dim, sparse=False):
2323
"""

pygsti/tools/optools.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,6 @@ def psd_square_root(mat):
163163
"""
164164
_warnings.warn(message)
165165
evals[evals < 0] = 0.0
166-
tr = _np.sum(evals)
167-
if abs(tr - 1) > __VECTOR_TOL__:
168-
message = f"""
169-
The PSD part of the input matrix is not trace-1 up to tolerance {__VECTOR_TOL__}.
170-
Beware result!
171-
"""
172-
_warnings.warn(message)
173166
sqrt_mat = U @ (_np.sqrt(evals).reshape((-1, 1)) * U.T.conj())
174167
return sqrt_mat
175168

@@ -1031,9 +1024,14 @@ def povm_diamonddist(model, target_model, povmlbl):
10311024
-------
10321025
float
10331026
"""
1034-
povm_mx = compute_povm_map(model, povmlbl)
1035-
target_povm_mx = compute_povm_map(target_model, povmlbl)
1036-
return diamonddist(povm_mx, target_povm_mx, target_model.basis)
1027+
try:
1028+
povm_mx = compute_povm_map(model, povmlbl)
1029+
target_povm_mx = compute_povm_map(target_model, povmlbl)
1030+
return diamonddist(povm_mx, target_povm_mx, target_model.basis)
1031+
except AssertionError as e:
1032+
assert '`dim` must be a perfect square' in str(e)
1033+
return _np.NaN
1034+
10371035

10381036
def instrument_infidelity(a, b, mx_basis):
10391037
"""

test/unit/tools/test_optools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_unitary_to_pauligate(self):
4848
# U_2Q is 4x4 unitary matrix operating on isolated two-qubit space (CX(pi) rotation)
4949

5050
op_2Q = ot.unitary_to_pauligate(U_2Q)
51-
op_2Q_inv = ot.process_mx_to_unitary(bt.change_basis(op_2Q, 'pp', 'std'))
51+
op_2Q_inv = ot.std_process_mx_to_unitary(bt.change_basis(op_2Q, 'pp', 'std'))
5252
self.assertArraysAlmostEqual(U_2Q, op_2Q_inv)
5353

5454
def test_decompose_gate_matrix(self):

0 commit comments

Comments
 (0)