diff --git a/TrainingCodes/dncnn_pytorch/data_generator.py b/TrainingCodes/dncnn_pytorch/data_generator.py index f5402e6f..c54f2127 100644 --- a/TrainingCodes/dncnn_pytorch/data_generator.py +++ b/TrainingCodes/dncnn_pytorch/data_generator.py @@ -91,7 +91,7 @@ def gen_patches(file_name): patches = [] for s in scales: h_scaled, w_scaled = int(h*s), int(w*s) - img_scaled = cv2.resize(img, (h_scaled, w_scaled), interpolation=cv2.INTER_CUBIC) + img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC) # extract patches for i in range(0, h_scaled-patch_size+1, stride): for j in range(0, w_scaled-patch_size+1, stride): @@ -132,4 +132,4 @@ def datagenerator(data_dir='data/Train400', verbose=False): # if not os.path.exists(save_dir): # os.mkdir(save_dir) # np.save(save_dir+'clean_patches.npy', res) -# print('Done.') \ No newline at end of file +# print('Done.')