diff --git a/trainer/dataset.py b/trainer/dataset.py index a6485eea3a9..34e90f14345 100644 --- a/trainer/dataset.py +++ b/trainer/dataset.py @@ -10,7 +10,6 @@ from PIL import Image import numpy as np from torch.utils.data import Dataset, ConcatDataset, Subset -from torch._utils import _accumulate import torchvision.transforms as transforms def contrast_grey(img): @@ -27,6 +26,19 @@ def adjust_contrast_grey(img, target = 0.4): img = np.maximum(np.full(img.shape, 0) ,np.minimum(np.full(img.shape, 255), img)).astype(np.uint8) return img +def _accumulate(iterable, fn=lambda x, y: x + y): + "Return running totals" + # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 + # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + yield total + for element in it: + total = fn(total, element) + yield total class Batch_Balanced_Dataset(object): @@ -98,12 +110,12 @@ def get_batch(self): for i, data_loader_iter in enumerate(self.dataloader_iter_list): try: - image, text = data_loader_iter.next() + image, text = next(data_loader_iter) balanced_batch_images.append(image) balanced_batch_texts += text except StopIteration: self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) - image, text = self.dataloader_iter_list[i].next() + image, text = next(self.dataloader_iter_list[i]) balanced_batch_images.append(image) balanced_batch_texts += text except ValueError: diff --git a/trainer/trainer.py b/trainer/trainer.py new file mode 100644 index 00000000000..b3c89a6441c --- /dev/null +++ b/trainer/trainer.py @@ -0,0 +1,38 @@ +import os + +import pandas as pd +import torch.backends.cudnn as cudnn +import yaml +from train import train +from utils import AttrDict + +cudnn.benchmark = True +cudnn.deterministic = False + + +def get_config(file_path): + with open(file_path, 'r', encoding="utf8") as stream: + opt = yaml.safe_load(stream) + opt = AttrDict(opt) + if opt.lang_char == 'None': + characters = '' + for data in opt['select_data'].split('-'): + csv_path = os.path.join(opt['train_data'], data, 'labels.csv') + df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', + usecols=['filename', 'words'], keep_default_na=False) + all_char = ''.join(df['words']) + characters += ''.join(set(all_char)) + characters = sorted(set(characters)) + opt.character = ''.join(characters) + else: + opt.character = opt.number + opt.symbol + opt.lang_char + os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) + return opt + + +if __name__ == "__main__": + opt = get_config("config_files/en_filtered_config.yaml") + for item in opt.items(): + print(item) + print("Training started...") + train(opt, amp=False)