6
6
from mrt import conf , utils
7
7
8
8
import numpy as np
9
+ import argparse
10
+
11
+ parser = argparse .ArgumentParser (description = 'Mnist Traning' )
12
+ parser .add_argument ('--cpu' , default = False , action = 'store_true' ,
13
+ help = 'whether enable cpu (default use gpu)' )
14
+ parser .add_argument ('--gpu-id' , type = int , default = 0 ,
15
+ help = 'gpu device id' )
16
+ parser .add_argument ('--net' , type = str , default = '' ,
17
+ help = 'choose available networks, optional: lenet, mlp' )
18
+
19
+ args = parser .parse_args ()
9
20
10
21
def load_fname (version , suffix = None , with_ext = False ):
11
22
suffix = "." + suffix if suffix is not None else ""
12
- prefix = "{}/mnist_{}{}" .format (conf .MRT_MODEL_ROOT , version , suffix )
23
+ version = "_" + version if version is not None else ""
24
+ prefix = "{}/mnist{}{}" .format (conf .MRT_MODEL_ROOT , version , suffix )
13
25
return utils .extend_fname (prefix , with_ext )
14
26
15
27
def data_xform (data ):
@@ -25,10 +37,12 @@ def data_xform(data):
25
37
train_loader = mx .gluon .data .DataLoader (train_data , shuffle = True , batch_size = batch_size )
26
38
val_loader = mx .gluon .data .DataLoader (val_data , shuffle = False , batch_size = batch_size )
27
39
28
- version = ''
40
+ version = args .net
41
+ print ("Training {} Mnist" .format (version ))
29
42
30
43
# Set the gpu device id
31
- ctx = mx .gpu (0 )
44
+ ctx = mx .cpu () if args .cpu else mx .gpu (args .gpu_id )
45
+ print ("Using device: {}" .format (ctx ))
32
46
33
47
def train_mnist ():
34
48
# Select a fixed random seed for reproducibility
@@ -70,6 +84,8 @@ def train_mnist():
70
84
nn .Dense (64 , activation = 'relu' ),
71
85
nn .Dense (10 , activation = None ) # loss function includes softmax already, see below
72
86
)
87
+ else :
88
+ assert False
73
89
74
90
# Random initialize all the mnist model parameters
75
91
net .initialize (mx .init .Xavier (), ctx = ctx )
@@ -118,5 +134,4 @@ def train_mnist():
118
134
fout .write (sym .tojson ())
119
135
net .collect_params ().save (param_file )
120
136
121
- print ("Test mnist" , version )
122
137
train_mnist ()
0 commit comments