Skip to content

Commit 4956cdf

Browse files
committed
improve readability with material reformatting. Check membership of a key in dicts by "k in d" rather than the redundant "k in d.keys()".
1 parent 0f77d16 commit 4956cdf

File tree

1 file changed

+71
-63
lines changed

1 file changed

+71
-63
lines changed

pygsti/protocols/gst.py

Lines changed: 71 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -997,85 +997,93 @@ def _update_gaugeopt_dict_from_suitename(gaugeopt_suite_dict, root_lbl, suite_na
997997

998998
if suite_name in GSTGaugeOptSuite.STANDARD_SUITENAMES:
999999

1000-
stages = [] # multi-stage gauge opt
10011000
gg = model.default_gauge_group
1002-
convert_to = {'to_type': "full", 'flatten_structure': True, 'set_default_gauge_group': True} \
1001+
1002+
if gg is None:
1003+
return
1004+
1005+
stages = [] # multi-stage gauge opt
1006+
1007+
from pygsti.models.gaugegroup import TrivialGaugeGroup, UnitaryGaugeGroup, \
1008+
SpamGaugeGroup, TPSpamGaugeGroup
1009+
1010+
convert_to = {'to_type': "full", 'flatten_structure': True, 'set_default_gauge_group': False} \
10031011
if ('noconversion' not in suite_name and gg.name not in ("Full", "TP")) else None
10041012

1005-
if isinstance(gg, _models.gaugegroup.TrivialGaugeGroup) and convert_to is None:
1013+
if isinstance(gg, TrivialGaugeGroup) and convert_to is None:
10061014
if suite_name == "stdgaugeopt-unreliable2Q" and model.dim == 16:
1007-
if any([gl in model.operations.keys() for gl in unreliable_ops]):
1015+
if any([gl in model.operations for gl in unreliable_ops]):
10081016
gaugeopt_suite_dict[root_lbl] = {'verbosity': printer}
10091017
else:
10101018
#just do a single-stage "trivial" gauge opts using default group
10111019
gaugeopt_suite_dict[root_lbl] = {'verbosity': printer}
1012-
1013-
elif gg is not None:
1014-
metric = 'frobeniustt' if suite_name == 'stdgaugeopt-tt' else 'frobenius'
1015-
1016-
#Stage 1: plain vanilla gauge opt to get into "right ballpark"
1017-
if gg.name in ("Full", "TP"):
1018-
stages.append(
1019-
{
1020-
'gates_metric': metric, 'spam_metric': metric,
1021-
'item_weights': {'gates': 1.0, 'spam': 1.0},
1022-
'verbosity': printer
1023-
})
1024-
1025-
#Stage 2: unitary gauge opt that tries to nail down gates (at
1026-
# expense of spam if needed)
1027-
stages.append(
1028-
{
1029-
'convert_model_to': convert_to,
1030-
'gates_metric': metric, 'spam_metric': metric,
1031-
'item_weights': {'gates': 1.0, 'spam': 0.0},
1032-
'gauge_group': _models.gaugegroup.UnitaryGaugeGroup(model.state_space,
1033-
model.basis, model.evotype),
1034-
'oob_check_interval': 1 if ('-safe' in suite_name) else 0,
1035-
'verbosity': printer
1036-
})
1037-
1038-
#Stage 3: spam gauge opt that fixes spam scaling at expense of
1039-
# non-unital parts of gates (but shouldn't affect these
1040-
# elements much since they should be small from Stage 2).
1041-
s3gg = _models.gaugegroup.SpamGaugeGroup if (gg.name == "Full") else \
1042-
_models.gaugegroup.TPSpamGaugeGroup
1043-
stages.append(
1044-
{
1045-
'convert_model_to': convert_to,
1046-
'gates_metric': metric, 'spam_metric': metric,
1047-
'item_weights': {'gates': 0.0, 'spam': 1.0},
1048-
'spam_penalty_factor': 1.0,
1049-
'gauge_group': s3gg(model.state_space, model.evotype),
1050-
'oob_check_interval': 1,
1051-
'verbosity': printer
1052-
})
1053-
1054-
if suite_name == "stdgaugeopt-unreliable2Q" and model.dim == 16:
1055-
if any([gl in model.operations.keys() for gl in unreliable_ops]):
1056-
stage2_item_weights = {'gates': 1, 'spam': 0.0}
1057-
for gl in unreliable_ops:
1058-
if gl in model.operations.keys(): stage2_item_weights[gl] = 0.01
1059-
stages_2qubit_unreliable = [stage.copy() for stage in stages] # ~deep copy of stages
1060-
istage2 = 1 if gg.name in ("Full", "TP") else 0
1061-
stages_2qubit_unreliable[istage2]['item_weights'] = stage2_item_weights
1062-
gaugeopt_suite_dict[root_lbl] = stages_2qubit_unreliable # add additional gauge opt
1063-
else:
1064-
_warnings.warn(("`unreliable2Q` was given as a gauge opt suite, but none of the"
1065-
" gate names in 'unreliable_ops', i.e., %s,"
1066-
" are present in the target model. Omitting 'single-2QUR' gauge opt.")
1067-
% (", ".join(unreliable_ops)))
1020+
return
1021+
1022+
metric = 'frobeniustt' if suite_name == 'stdgaugeopt-tt' else 'frobenius'
1023+
ss = model.state_space
1024+
et = model.evotype
1025+
1026+
# Stage 1: plain vanilla gauge opt to get into "right ballpark"
1027+
if gg.name in ("Full", "TP"):
1028+
stages.append({
1029+
'gates_metric': metric, 'spam_metric': metric,
1030+
'item_weights': {'gates': 1.0, 'spam': 1.0},
1031+
'verbosity': printer
1032+
})
1033+
1034+
# Stage 2: unitary gauge opt that tries to nail down gates (at
1035+
# expense of spam if needed)
1036+
s2gg = UnitaryGaugeGroup(ss, model.basis, et)
1037+
stages.append({
1038+
'convert_model_to': convert_to,
1039+
'gates_metric': metric, 'spam_metric': metric,
1040+
'item_weights': {'gates': 1.0, 'spam': 0.0},
1041+
'gauge_group': s2gg,
1042+
'oob_check_interval': 1 if ('-safe' in suite_name) else 0,
1043+
'verbosity': printer
1044+
})
1045+
1046+
# Stage 3: spam gauge opt that fixes spam scaling at expense of
1047+
# non-unital parts of gates (but shouldn't affect these
1048+
# elements much since they should be small from Stage 2).
1049+
s3gg = SpamGaugeGroup(ss, et) if (gg.name == "Full") else TPSpamGaugeGroup(ss, et)
1050+
stages.append({
1051+
'convert_model_to': convert_to,
1052+
'gates_metric': metric, 'spam_metric': metric,
1053+
'item_weights': {'gates': 0.0, 'spam': 1.0},
1054+
'spam_penalty_factor': 1.0,
1055+
'gauge_group': s3gg,
1056+
'oob_check_interval': 1,
1057+
'verbosity': printer
1058+
})
1059+
1060+
if suite_name == "stdgaugeopt-unreliable2Q" and model.dim == 16:
1061+
if any([gl in model.operations for gl in unreliable_ops]):
1062+
stage2_item_weights = {'gates': 1.0, 'spam': 0.0}
1063+
for gl in unreliable_ops:
1064+
if gl in model.operations:
1065+
stage2_item_weights[gl] = 0.01
1066+
stages_2qubit_unreliable = [stage.copy() for stage in stages] # ~deep copy of stages
1067+
istage2 = 1 if gg.name in ("Full", "TP") else 0
1068+
stages_2qubit_unreliable[istage2]['item_weights'] = stage2_item_weights
1069+
gaugeopt_suite_dict[root_lbl] = stages_2qubit_unreliable # add additional gauge opt
10681070
else:
1069-
gaugeopt_suite_dict[root_lbl] = stages # can be a list of stage dictionaries
1071+
_warnings.warn(("`unreliable2Q` was given as a gauge opt suite, but none of the"
1072+
" gate names in 'unreliable_ops', i.e., %s,"
1073+
" are present in the target model. Omitting 'single-2QUR' gauge opt.")
1074+
% (", ".join(unreliable_ops)))
1075+
else:
1076+
gaugeopt_suite_dict[root_lbl] = stages # can be a list of stage dictionaries
10701077

10711078
elif suite_name in GSTGaugeOptSuite.SPECIAL_SUITENAMES:
10721079

10731080
base_wts = {'gates': 1.0}
10741081
if suite_name.endswith("unreliable2Q") and model.dim == 16:
1075-
if any([gl in model.operations.keys() for gl in unreliable_ops]):
1082+
if any([gl in model.operations for gl in unreliable_ops]):
10761083
base = {'gates': 1.0}
10771084
for gl in unreliable_ops:
1078-
if gl in model.operations.keys(): base[gl] = 0.01
1085+
if gl in model.operations:
1086+
base[gl] = 0.01
10791087
base_wts = base
10801088

10811089
if suite_name == "varySpam":

0 commit comments

Comments
 (0)