Skip to content

Commit 4799294

Browse files
committed
updates
1 parent 1c598fa commit 4799294

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

py4DSTEM/tomography/tomography.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,6 @@ def reconstruct(
410410
num_iter: int = 1,
411411
store_iterations: bool = False,
412412
store_initial_object: bool = True,
413-
store_error_per_step: bool = True,
414413
reset: bool = True,
415414
step_size: float = 0.5,
416415
zero_edges_real: bool = True,
@@ -450,8 +449,6 @@ def reconstruct(
450449
if True, stores number of iterations
451450
store_initial_object: bool
452451
if True, keeps a copy of an initial object to reset without preprocessing
453-
store_error_per_step: bool
454-
if True, stores error for each tilt for each iteration and order of tilts
455452
reset: bool
456453
if True, resets object
457454
step_size: float
@@ -520,9 +517,8 @@ def reconstruct(
520517
else:
521518
self._object = self._object_initial
522519

523-
if store_error_per_step:
524-
self._tilt_order = []
525-
self._error_per_step = []
520+
self._tilt_order = []
521+
self._error_per_step = []
526522

527523
for a0 in tqdmnd(
528524
num_iter,
@@ -595,9 +591,8 @@ def f(args):
595591
else:
596592
raise ValueError(("distributed not implemented for gpu"))
597593

598-
if store_error_per_step:
599-
self._tilt_order.append(a1_shuffle)
600-
self._error_per_step.append(error_iteration_datacube)
594+
self._tilt_order.append(a1_shuffle)
595+
self._error_per_step.append(error_iteration_datacube)
601596

602597
error_iteration += error_iteration_datacube
603598

@@ -630,6 +625,22 @@ def f(args):
630625
if store_iterations:
631626
self.object_iterations.append(self._object.copy())
632627

628+
num_iter = len(self._tilt_order) // self._num_datacubes
629+
iterations = np.repeat(np.arange(num_iter), self._num_datacubes)
630+
order = np.asarray(self._tilt_order)
631+
ind = np.argsort(
632+
np.ravel_multi_index((iterations, order), (num_iter, self._num_datacubes))
633+
)
634+
error_per_step_sorted = (np.asarray(self._error_per_step)[ind]).reshape(
635+
(num_iter, self._num_datacubes)
636+
)
637+
638+
tilts_order = self._tilt_deg
639+
tilts_order = np.argsort(tilts_order)
640+
641+
error_per_step_sorted = error_per_step_sorted[:, tilts_order]
642+
self.error_per_step_sorted = error_per_step_sorted
643+
633644
return self
634645

635646
def _reconstruct(
@@ -1199,6 +1210,9 @@ def _reshape_diffraction_patterns(
11991210
if datacube_number == 0:
12001211
self._make_diffraction_masks(q_max_inv_A=q_max_inv_A)
12011212

1213+
if normalize_scans:
1214+
datacube.data /= datacube.data.mean()
1215+
12021216
diffraction_patterns_reshaped = self._reshape_4D_array_to_2D(
12031217
data=datacube.data,
12041218
qx0_fit=qx0_fit,
@@ -1213,9 +1227,6 @@ def _reshape_diffraction_patterns(
12131227
mask_real_space.ravel()
12141228
]
12151229

1216-
if normalize_scans:
1217-
diffraction_patterns_reshaped /= diffraction_patterns_reshaped.mean()
1218-
12191230
self._diffraction_patterns_projected.append(diffraction_patterns_reshaped)
12201231

12211232
def _make_diffraction_masks(self, q_max_inv_A):
@@ -2204,32 +2215,24 @@ def visualize(self, plot_convergence=True, figsize=(10, 10)):
22042215
return self
22052216

22062217
def show_error_per_iteration(self, **kwargs):
2207-
num_iter = len(self._tilt_order) // self._num_datacubes
2208-
iterations = np.repeat(np.arange(num_iter), self._num_datacubes)
2209-
order = np.asarray(self._tilt_order)
2210-
ind = np.argsort(
2211-
np.ravel_multi_index((iterations, order), (num_iter, self._num_datacubes))
2212-
)
2213-
error = (np.asarray(self._error_per_step)[ind]).reshape(
2214-
(num_iter, self._num_datacubes)
2215-
)
2216-
2217-
tilts_order = self._tilt_deg
2218-
tilts_order = np.argsort(tilts_order)
2219-
2220-
error = error[:, tilts_order]
2221-
22222218
vmin = kwargs.pop("vmin", None)
22232219
vmax = kwargs.pop("vmax", None)
22242220

22252221
fig, ax = show(
2226-
error, cmap="magma", returnfig=True, vmax=vmax, vmin=vmin, **kwargs
2222+
self.error_per_step_sorted,
2223+
cmap="magma",
2224+
returnfig=True,
2225+
vmax=vmax,
2226+
vmin=vmin,
2227+
**kwargs,
22272228
)
22282229

2230+
ax.tick_params(top=False, labeltop=False, bottom=True, labelbottom=True)
2231+
22292232
ax.set_title("error")
22302233
ax.set_ylabel("iteration")
22312234
ax.set_xlabel("tilts (negative -> positive)")
2232-
ax.set_xticks([])
2235+
# ax.set_xticks([])
22332236

22342237
return self
22352238

0 commit comments

Comments
 (0)