Skip to content

Commit 80b023f

Browse files
author
Paul Kienzle
committed
nicer layout for sasmodels.compare plots
1 parent 8daeb6e commit 80b023f

File tree

1 file changed

+57
-23
lines changed

1 file changed

+57
-23
lines changed

sasmodels/compare.py

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,14 @@ def run_models(opts, verbose=False):
854854

855855
base_time = comp_time = None
856856
base_value = comp_value = resid = relerr = None
857+
base_name = comp_name = None
858+
def name_pair(base_test, comp_test):
859+
"""
860+
iterate over name pairs in reverse order, keeping the first that are different
861+
"""
862+
nonlocal base_name, comp_name
863+
if base_test != comp_test:
864+
base_name, comp_name = base_test, comp_test
857865

858866
# Base calculation
859867
try:
@@ -879,21 +887,35 @@ def run_models(opts, verbose=False):
879887
except ImportError:
880888
traceback.print_exc()
881889

890+
# Find a string pair that describes the difference between base and comp.
891+
# Go from least interesting to most interesting, updating if different.
892+
name_pair("base", "comp")
893+
name_pair(
894+
" ".join(f"{k}={v}" for k, v in base_pars.items() if comp_pars.get(k, v) != v),
895+
" ".join(f"{k}={v}" for k, v in comp_pars.items() if base_pars.get(k, v) != v),
896+
)
897+
name_pair(base.engine, comp.engine)
898+
name_pair(base.model.info.name, comp.model.info.name)
899+
else:
900+
name_pair(f"{base.model.info.name}:{base.engine}", None)
901+
882902
# Compare, but only if computing both forms
883903
if comparison:
884904
resid = (base_value - comp_value)
885905
relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
886906
if verbose:
887907
_print_stats("|%s-%s|"
888-
% (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
908+
% (base_name, comp_name) + (" "*(3+len(comp_name))),
889909
resid)
890910
_print_stats("|(%s-%s)/%s|"
891-
% (base.engine, comp.engine, comp.engine),
911+
% (base_name, comp_name, comp_name),
892912
relerr)
893913

894-
return dict(base_value=base_value, comp_value=comp_value,
895-
base_time=base_time, comp_time=comp_time,
896-
resid=resid, relerr=relerr)
914+
return dict(
915+
base_name=base_name, comp_name=comp_name,
916+
base_value=base_value, comp_value=comp_value,
917+
base_time=base_time, comp_time=comp_time,
918+
resid=resid, relerr=relerr)
897919

898920

899921
def _print_stats(label, err):
@@ -923,6 +945,7 @@ def plot_models(opts, result, limits=None, setnum=0):
923945
"""
924946
import matplotlib.pyplot as plt
925947

948+
base_name, comp_name = result['base_name'], result['comp_name']
926949
base_value, comp_value = result['base_value'], result['comp_value']
927950
base_time, comp_time = result['base_time'], result['comp_time']
928951
resid, relerr = result['resid'], result['relerr']
@@ -947,25 +970,30 @@ def plot_models(opts, result, limits=None, setnum=0):
947970

948971
if have_base:
949972
if have_comp:
950-
plt.subplot(131)
951-
plot_theory(base_data, base_value, view=view, use_data=use_data, limits=limits)
973+
plt.subplot(221)
974+
plot_theory(base_data, base_value, label=base_name, view=view, use_data=use_data, limits=limits)
952975
if setnum > 0:
953976
plt.legend([f"Set {k+1}" for k in range(setnum+1)], loc='best')
954-
plt.title("%s t=%.2f ms"%(base.engine, base_time))
977+
plt.title("%s t=%.2f ms"%(base.model.info.name, base_time))
978+
if have_comp:
979+
plt.gca().tick_params(labelbottom=False)
980+
plt.gca().set_xticks([])
981+
plt.xlabel('')
982+
955983
#cbar_title = "log I"
956984
if have_comp:
957985
if have_base:
958-
plt.subplot(132)
986+
plt.subplot(223)
959987
if not opts['is2d'] and have_base:
960-
plot_theory(comp_data, base_value, view=view, use_data=use_data, limits=limits)
961-
plot_theory(comp_data, comp_value, view=view, use_data=use_data, limits=limits)
962-
plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
988+
plot_theory(comp_data, base_value, label=base_name, view=view, use_data=use_data, limits=limits)
989+
plot_theory(comp_data, comp_value, label=comp_name, view=view, use_data=use_data, limits=limits)
990+
plt.title("%s t=%.2f ms"%(comp.model.info.name, comp_time))
963991
#plt.gca().tick_params(labelbottom=False, labelleft=False)
964-
plt.gca().set_yticks([])
965-
plt.ylabel('')
992+
#plt.gca().set_yticks([])
993+
#plt.ylabel('')
966994
#cbar_title = "log I"
967995
if have_base and have_comp:
968-
plt.subplot(133)
996+
plt.subplot(222)
969997
if not opts['rel_err']:
970998
err, errstr, errview = resid, "abs err", "linear"
971999
else:
@@ -980,17 +1008,23 @@ def plot_models(opts, result, limits=None, setnum=0):
9801008
# Note: base_data only since base and comp have same q values (though
9811009
# perhaps different resolution), and we are plotting the difference
9821010
# at each q
983-
plot_theory(base_data, err, view=errview, use_data=use_data)
1011+
plot_theory(base_data, err, view=errview, label=errstr, use_data=use_data)
9841012
plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
9851013
plt.title("max %s = %.3g"%(errstr, abs(err).max()))
986-
plt.ylabel(errstr)
1014+
if opts['is2d']:
1015+
plt.gca().tick_params(labelleft=False)
1016+
plt.gca().set_yticks([])
1017+
plt.ylabel('')
1018+
else:
1019+
plt.ylabel(errstr)
1020+
9871021
#cbar_title = errstr if errview=="linear" else "log "+errstr
988-
#if is2D:
989-
# h = plt.colorbar()
990-
# h.ax.set_title(cbar_title)
991-
fig = plt.gcf()
992-
extra_title = ' '+opts['title'] if opts['title'] else ''
993-
fig.suptitle(":".join(opts['name']) + extra_title)
1022+
# if is2D:
1023+
# h = plt.colorbar()
1024+
# h.ax.set_title(cbar_title)
1025+
# fig = plt.gcf()
1026+
# extra_title = ' '+opts['title'] if opts['title'] else ''
1027+
# fig.suptitle(":".join(opts['name']) + extra_title)
9941028

9951029
if have_base and have_comp and opts['show_hist']:
9961030
plt.figure()

0 commit comments

Comments
 (0)