Skip to content

Commit 3470799

Browse files
authored
Merge pull request #41 from CortexFoundation/wlt
update train mnist
2 parents bfb239d + 08dde7a commit 3470799

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

.github/workflows/ccpp.yml

Lines changed: 0 additions & 1 deletion
This file was deleted.

tests/mrt/train_mnist.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,22 @@
66
from mrt import conf, utils
77

88
import numpy as np
9+
import argparse
10+
11+
parser = argparse.ArgumentParser(description='Mnist Traning')
12+
parser.add_argument('--cpu', default=False, action='store_true',
13+
help='whether enable cpu (default use gpu)')
14+
parser.add_argument('--gpu-id', type=int, default=0,
15+
help='gpu device id')
16+
parser.add_argument('--net', type=str, default='',
17+
help='choose available networks, optional: lenet, mlp')
18+
19+
args = parser.parse_args()
920

1021
def load_fname(version, suffix=None, with_ext=False):
1122
suffix = "."+suffix if suffix is not None else ""
12-
prefix = "{}/mnist_{}{}".format(conf.MRT_MODEL_ROOT, version, suffix)
23+
version = "_"+version if version is not None else ""
24+
prefix = "{}/mnist{}{}".format(conf.MRT_MODEL_ROOT, version, suffix)
1325
return utils.extend_fname(prefix, with_ext)
1426

1527
def data_xform(data):
@@ -25,10 +37,12 @@ def data_xform(data):
2537
train_loader = mx.gluon.data.DataLoader(train_data, shuffle=True, batch_size=batch_size)
2638
val_loader = mx.gluon.data.DataLoader(val_data, shuffle=False, batch_size=batch_size)
2739

28-
version = ''
40+
version = args.net
41+
print ("Training {} Mnist".format(version))
2942

3043
# Set the gpu device id
31-
ctx = mx.gpu(0)
44+
ctx = mx.cpu() if args.cpu else mx.gpu(args.gpu_id)
45+
print ("Using device: {}".format(ctx))
3246

3347
def train_mnist():
3448
# Select a fixed random seed for reproducibility
@@ -70,6 +84,8 @@ def train_mnist():
7084
nn.Dense(64, activation='relu'),
7185
nn.Dense(10, activation=None) # loss function includes softmax already, see below
7286
)
87+
else:
88+
assert False
7389

7490
# Random initialize all the mnist model parameters
7591
net.initialize(mx.init.Xavier(), ctx=ctx)
@@ -118,5 +134,4 @@ def train_mnist():
118134
fout.write(sym.tojson())
119135
net.collect_params().save(param_file)
120136

121-
print ("Test mnist", version)
122137
train_mnist()

0 commit comments

Comments
 (0)