-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
51 lines (41 loc) · 1.63 KB
/
train.py
File metadata and controls
51 lines (41 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/env/bin python3
# -*- coding: utf-8 -*-
import argparse
import numpy as np
import sys
import subprocess
import os
import yaml
import chainer
from chainer import cuda, optimizers, serializers
from chainer import training
from chainercv.links import PixelwiseSoftmaxClassifier
from erfnet.config_utils import *
chainer.cuda.set_max_workspace_size(1024 * 1024 * 1024)
os.environ["CHAINER_TYPE_CHECK"] = "0"
from collections import OrderedDict
yaml.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
lambda loader, node: OrderedDict(loader.construct_pairs(node)))
from erfnet.models import erfnet_paper
def train_erfnet():
"""Training ERFNet."""
chainer.config.debug = True
config = parse_args()
train_data, test_data = load_dataset(config["dataset"])
train_iter, test_iter = create_iterator(train_data, test_data, config['iterator'])
model = get_model(config["model"])
class_weight = get_class_weight(config)
model = PixelwiseSoftmaxClassifier(model, class_weight=class_weight)
optimizer = create_optimizer(config['optimizer'], model)
devices = parse_devices(config['gpus'])
updater = create_updater(train_iter, optimizer, config['updater'], devices)
trainer = training.Trainer(updater, config['end_trigger'], out=config['results'])
trainer = create_extension(trainer, test_iter, model.predictor,
config['extension'], devices=devices)
trainer.run()
chainer.serializers.save_npz(os.path.join(config['results'], 'model.npz'),
model.predictor)
def main():
train_erfnet()
if __name__ == '__main__':
main()