@@ -410,7 +410,6 @@ def reconstruct(
410410 num_iter : int = 1 ,
411411 store_iterations : bool = False ,
412412 store_initial_object : bool = True ,
413- store_error_per_step : bool = True ,
414413 reset : bool = True ,
415414 step_size : float = 0.5 ,
416415 zero_edges_real : bool = True ,
@@ -450,8 +449,6 @@ def reconstruct(
450449 if True, stores number of iterations
451450 store_initial_object: bool
452451 if True, keeps a copy of an initial object to reset without preprocessing
453- store_error_per_step: bool
454- if True, stores error for each tilt for each iteration and order of tilts
455452 reset: bool
456453 if True, resets object
457454 step_size: float
@@ -520,9 +517,8 @@ def reconstruct(
520517 else :
521518 self ._object = self ._object_initial
522519
523- if store_error_per_step :
524- self ._tilt_order = []
525- self ._error_per_step = []
520+ self ._tilt_order = []
521+ self ._error_per_step = []
526522
527523 for a0 in tqdmnd (
528524 num_iter ,
@@ -595,9 +591,8 @@ def f(args):
595591 else :
596592 raise ValueError (("distributed not implemented for gpu" ))
597593
598- if store_error_per_step :
599- self ._tilt_order .append (a1_shuffle )
600- self ._error_per_step .append (error_iteration_datacube )
594+ self ._tilt_order .append (a1_shuffle )
595+ self ._error_per_step .append (error_iteration_datacube )
601596
602597 error_iteration += error_iteration_datacube
603598
@@ -630,6 +625,22 @@ def f(args):
630625 if store_iterations :
631626 self .object_iterations .append (self ._object .copy ())
632627
628+ num_iter = len (self ._tilt_order ) // self ._num_datacubes
629+ iterations = np .repeat (np .arange (num_iter ), self ._num_datacubes )
630+ order = np .asarray (self ._tilt_order )
631+ ind = np .argsort (
632+ np .ravel_multi_index ((iterations , order ), (num_iter , self ._num_datacubes ))
633+ )
634+ error_per_step_sorted = (np .asarray (self ._error_per_step )[ind ]).reshape (
635+ (num_iter , self ._num_datacubes )
636+ )
637+
638+ tilts_order = self ._tilt_deg
639+ tilts_order = np .argsort (tilts_order )
640+
641+ error_per_step_sorted = error_per_step_sorted [:, tilts_order ]
642+ self .error_per_step_sorted = error_per_step_sorted
643+
633644 return self
634645
635646 def _reconstruct (
@@ -1199,6 +1210,9 @@ def _reshape_diffraction_patterns(
11991210 if datacube_number == 0 :
12001211 self ._make_diffraction_masks (q_max_inv_A = q_max_inv_A )
12011212
1213+ if normalize_scans :
1214+ datacube .data /= datacube .data .mean ()
1215+
12021216 diffraction_patterns_reshaped = self ._reshape_4D_array_to_2D (
12031217 data = datacube .data ,
12041218 qx0_fit = qx0_fit ,
@@ -1213,9 +1227,6 @@ def _reshape_diffraction_patterns(
12131227 mask_real_space .ravel ()
12141228 ]
12151229
1216- if normalize_scans :
1217- diffraction_patterns_reshaped /= diffraction_patterns_reshaped .mean ()
1218-
12191230 self ._diffraction_patterns_projected .append (diffraction_patterns_reshaped )
12201231
12211232 def _make_diffraction_masks (self , q_max_inv_A ):
@@ -2204,32 +2215,24 @@ def visualize(self, plot_convergence=True, figsize=(10, 10)):
22042215 return self
22052216
22062217 def show_error_per_iteration (self , ** kwargs ):
2207- num_iter = len (self ._tilt_order ) // self ._num_datacubes
2208- iterations = np .repeat (np .arange (num_iter ), self ._num_datacubes )
2209- order = np .asarray (self ._tilt_order )
2210- ind = np .argsort (
2211- np .ravel_multi_index ((iterations , order ), (num_iter , self ._num_datacubes ))
2212- )
2213- error = (np .asarray (self ._error_per_step )[ind ]).reshape (
2214- (num_iter , self ._num_datacubes )
2215- )
2216-
2217- tilts_order = self ._tilt_deg
2218- tilts_order = np .argsort (tilts_order )
2219-
2220- error = error [:, tilts_order ]
2221-
22222218 vmin = kwargs .pop ("vmin" , None )
22232219 vmax = kwargs .pop ("vmax" , None )
22242220
22252221 fig , ax = show (
2226- error , cmap = "magma" , returnfig = True , vmax = vmax , vmin = vmin , ** kwargs
2222+ self .error_per_step_sorted ,
2223+ cmap = "magma" ,
2224+ returnfig = True ,
2225+ vmax = vmax ,
2226+ vmin = vmin ,
2227+ ** kwargs ,
22272228 )
22282229
2230+ ax .tick_params (top = False , labeltop = False , bottom = True , labelbottom = True )
2231+
22292232 ax .set_title ("error" )
22302233 ax .set_ylabel ("iteration" )
22312234 ax .set_xlabel ("tilts (negative -> positive)" )
2232- ax .set_xticks ([])
2235+ # ax.set_xticks([])
22332236
22342237 return self
22352238
0 commit comments