@@ -854,6 +854,14 @@ def run_models(opts, verbose=False):
854854
855855 base_time = comp_time = None
856856 base_value = comp_value = resid = relerr = None
857+ base_name = comp_name = None
858+ def name_pair (base_test , comp_test ):
859+ """
860+ iterate over name pairs in reverse order, keeping the first that are different
861+ """
862+ nonlocal base_name , comp_name
863+ if base_test != comp_test :
864+ base_name , comp_name = base_test , comp_test
857865
858866 # Base calculation
859867 try :
@@ -879,21 +887,35 @@ def run_models(opts, verbose=False):
879887 except ImportError :
880888 traceback .print_exc ()
881889
890+ # Find a string pair that describes the difference between base and comp.
891+ # Go from least interesting to most interesting, updating if different.
892+ name_pair ("base" , "comp" )
893+ name_pair (
894+ " " .join (f"{ k } ={ v } " for k , v in base_pars .items () if comp_pars .get (k , v ) != v ),
895+ " " .join (f"{ k } ={ v } " for k , v in comp_pars .items () if base_pars .get (k , v ) != v ),
896+ )
897+ name_pair (base .engine , comp .engine )
898+ name_pair (base .model .info .name , comp .model .info .name )
899+ else :
900+ name_pair (f"{ base .model .info .name } :{ base .engine } " , None )
901+
882902 # Compare, but only if computing both forms
883903 if comparison :
884904 resid = (base_value - comp_value )
885905 relerr = resid / np .where (comp_value != 0. , abs (comp_value ), 1.0 )
886906 if verbose :
887907 _print_stats ("|%s-%s|"
888- % (base . engine , comp . engine ) + (" " * (3 + len (comp . engine ))),
908+ % (base_name , comp_name ) + (" " * (3 + len (comp_name ))),
889909 resid )
890910 _print_stats ("|(%s-%s)/%s|"
891- % (base . engine , comp . engine , comp . engine ),
911+ % (base_name , comp_name , comp_name ),
892912 relerr )
893913
894- return dict (base_value = base_value , comp_value = comp_value ,
895- base_time = base_time , comp_time = comp_time ,
896- resid = resid , relerr = relerr )
914+ return dict (
915+ base_name = base_name , comp_name = comp_name ,
916+ base_value = base_value , comp_value = comp_value ,
917+ base_time = base_time , comp_time = comp_time ,
918+ resid = resid , relerr = relerr )
897919
898920
899921def _print_stats (label , err ):
@@ -923,6 +945,7 @@ def plot_models(opts, result, limits=None, setnum=0):
923945 """
924946 import matplotlib .pyplot as plt
925947
948+ base_name , comp_name = result ['base_name' ], result ['comp_name' ]
926949 base_value , comp_value = result ['base_value' ], result ['comp_value' ]
927950 base_time , comp_time = result ['base_time' ], result ['comp_time' ]
928951 resid , relerr = result ['resid' ], result ['relerr' ]
@@ -947,25 +970,30 @@ def plot_models(opts, result, limits=None, setnum=0):
947970
948971 if have_base :
949972 if have_comp :
950- plt .subplot (131 )
951- plot_theory (base_data , base_value , view = view , use_data = use_data , limits = limits )
973+ plt .subplot (221 )
974+ plot_theory (base_data , base_value , label = base_name , view = view , use_data = use_data , limits = limits )
952975 if setnum > 0 :
953976 plt .legend ([f"Set { k + 1 } " for k in range (setnum + 1 )], loc = 'best' )
954- plt .title ("%s t=%.2f ms" % (base .engine , base_time ))
977+ plt .title ("%s t=%.2f ms" % (base .model .info .name , base_time ))
978+ if have_comp :
979+ plt .gca ().tick_params (labelbottom = False )
980+ plt .gca ().set_xticks ([])
981+ plt .xlabel ('' )
982+
955983 #cbar_title = "log I"
956984 if have_comp :
957985 if have_base :
958- plt .subplot (132 )
986+ plt .subplot (223 )
959987 if not opts ['is2d' ] and have_base :
960- plot_theory (comp_data , base_value , view = view , use_data = use_data , limits = limits )
961- plot_theory (comp_data , comp_value , view = view , use_data = use_data , limits = limits )
962- plt .title ("%s t=%.2f ms" % (comp .engine , comp_time ))
988+ plot_theory (comp_data , base_value , label = base_name , view = view , use_data = use_data , limits = limits )
989+ plot_theory (comp_data , comp_value , label = comp_name , view = view , use_data = use_data , limits = limits )
990+ plt .title ("%s t=%.2f ms" % (comp .model . info . name , comp_time ))
963991 #plt.gca().tick_params(labelbottom=False, labelleft=False)
964- plt .gca ().set_yticks ([])
965- plt .ylabel ('' )
992+ # plt.gca().set_yticks([])
993+ # plt.ylabel('')
966994 #cbar_title = "log I"
967995 if have_base and have_comp :
968- plt .subplot (133 )
996+ plt .subplot (222 )
969997 if not opts ['rel_err' ]:
970998 err , errstr , errview = resid , "abs err" , "linear"
971999 else :
@@ -980,17 +1008,23 @@ def plot_models(opts, result, limits=None, setnum=0):
9801008 # Note: base_data only since base and comp have same q values (though
9811009 # perhaps different resolution), and we are plotting the difference
9821010 # at each q
983- plot_theory (base_data , err , view = errview , use_data = use_data )
1011+ plot_theory (base_data , err , view = errview , label = errstr , use_data = use_data )
9841012 plt .xscale ('log' if view == 'log' and not opts ['is2d' ] else 'linear' )
9851013 plt .title ("max %s = %.3g" % (errstr , abs (err ).max ()))
986- plt .ylabel (errstr )
1014+ if opts ['is2d' ]:
1015+ plt .gca ().tick_params (labelleft = False )
1016+ plt .gca ().set_yticks ([])
1017+ plt .ylabel ('' )
1018+ else :
1019+ plt .ylabel (errstr )
1020+
9871021 #cbar_title = errstr if errview=="linear" else "log "+errstr
988- #if is2D:
989- # h = plt.colorbar()
990- # h.ax.set_title(cbar_title)
991- fig = plt .gcf ()
992- extra_title = ' ' + opts ['title' ] if opts ['title' ] else ''
993- fig .suptitle (":" .join (opts ['name' ]) + extra_title )
1022+ # if is2D:
1023+ # h = plt.colorbar()
1024+ # h.ax.set_title(cbar_title)
1025+ # fig = plt.gcf()
1026+ # extra_title = ' '+opts['title'] if opts['title'] else ''
1027+ # fig.suptitle(":".join(opts['name']) + extra_title)
9941028
9951029 if have_base and have_comp and opts ['show_hist' ]:
9961030 plt .figure ()
0 commit comments