@@ -2,8 +2,10 @@
 
 import click
 import torch
+from torch.utils.data import DataLoader
+
 from pathlib import Path
-from wirecell.util.cli import context, log, jsonnet_loader
+from wirecell.util.cli import context, log, jsonnet_loader, anyconfig_file
 from wirecell.util.paths import unglob, listify
 
 
@@ -17,16 +19,23 @@ def cli(ctx):
     '''
     pass
 
+@cli.command('dump-config')
+@anyconfig_file("wirecelldnn")
+@click.pass_context
+def dump_config(ctx, config):
+    print(config)
+
+    return
+
+
+train_defaults = dict(epochs=1, batch=1, device='cpu', name='dnnroi', train_ratio=0.8)
 @cli.command('train')
-@click.option("-c", "--config",
-              type=click.Path(),
-              help="Set configuration file")
-@click.option("-e", "--epochs", default=1,
+@click.option("-e", "--epochs", default=None, type=int,
               help="Number of epochs over which to train. "
               "This is a relative count if the training starts with a -l/--load'ed state.")
-@click.option("-b", "--batch", default=1,
+@click.option("-b", "--batch", default=None, type=int,
               help="Batch size")
-@click.option("-d", "--device", default='cpu',
+@click.option("-d", "--device", default=None, type=str,
               help="The compute device")
 @click.option("--cache/--no-cache", is_flag=True, default=False,
               help="Cache data in memory")
@@ -38,33 +47,40 @@ def cli(ctx):
 @click.option("--checkpoint-modulus", default=1,
               help="Checkpoint modulus. "
               "If checkpoint path is given, the training is checkpointed ever this many epochs..")
-@click.option("-n", "--name", default='dnnroi',
-              help="The application name (def=dnnroi)")
+@click.option("-a", "--app", default=None, type=str,
+              help="The application name")
 @click.option("-l", "--load", default=None,
               help="File name providing the initial model state dict (def=None - construct fresh)")
 @click.option("-s", "--save", default=None,
               help="File name to save model state dict after training (def=None - results not saved)")
-@click.option("--eval-files", multiple=True, type=str,  # fixme: remove this in favor of a single file set and a train/eval partitioning
-              help="File path or globs as comma separated list to use for evaluation dataset")
-@click.argument("train_files", nargs=-1)
+@click.option("--train-ratio", default=None, type=float,
+              help="Fraction of samples to use for training (default=1.0, no evaluation loss calculated)")
+@anyconfig_file("wirecelldnn", section='train', defaults=train_defaults)
+@click.argument("files", nargs=-1)
 @click.pass_context
 def train(ctx, config, epochs, batch, device, cache, debug_torch,
           checkpoint_save, checkpoint_modulus,
-          name, load, save, eval_files, train_files):
+          app, load, save, train_ratio, files):
     '''
     Train a model.
     '''
-    if not train_files:
+
+    if not files:               # args not processed by anyconfig_files
+        try:
+            files = config['train']['files']
+        except KeyError:
+            files = None
+    if not files:
         raise click.BadArgumentUsage("no training files given")
-    train_files = unglob(listify(train_files))
-    log.info(f'training files: {train_files}')
+    files = unglob(listify(files))
+    log.info(f'training files: {files}')
 
     if device == 'gpu': device = 'cuda'
-    log.info(f'using device {device}')
 
     if debug_torch:
         torch.autograd.set_detect_anomaly(True)
 
+    name = app
     app = getattr(dnn.apps, name)
 
     net = app.Network()
@@ -78,24 +94,17 @@ def train(ctx, config, epochs, batch, device, cache, debug_torch,
         raise click.FileError(load, 'warning: DNN module load file does not exist')
     history = dnn.io.load_checkpoint(load, net, opt)
 
-    train_ds = app.Dataset(train_files, cache=cache)
-    ntrain = len(train_ds)
-    if ntrain == 0:
-        raise click.BadArgumentUsage(f'no samples from {len(train_files)} files')
-
-    from torch.utils.data import DataLoader
-    train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True, pin_memory=True)
-
-    neval = 0
-    eval_dl = None
-    if eval_files:
-        eval_files = unglob(listify(eval_files, delim=","))
-        log.info(f'eval files: {eval_files}')
-        eval_ds = app.Dataset(eval_files, cache=cache)
-        neval = len(eval_ds)
-        eval_dl = DataLoader(train_ds, batch_size=batch, shuffle=False, pin_memory=True)
-    else:
-        log.info("no eval files")
+    ds = app.Dataset(files, cache=cache, config=config.get("dataset", None))
+    if len(ds) == 0:
+        raise click.BadArgumentUsage(f'no samples from {len(files)} files')
+
+    tbatch, ebatch = batch, 1
+
+    dses = dnn.data.train_eval_split(ds, train_ratio)
+    dles = [DataLoader(one, batch_size=bb, shuffle=True, pin_memory=True) for one, bb in zip(dses, [tbatch, ebatch])]
+
+    ntrain = len(dses[0])
+    neval = len(dses[1])
 
     # History
     run_history = history.get("runs", dict())
@@ -104,9 +113,8 @@ def train(ctx, config, epochs, batch, device, cache, debug_torch,
         this_run_number = max(run_history.keys()) + 1
     this_run = dict(
         run=this_run_number,
-        train_files=train_files,
+        data_files=files,
         ntrain=ntrain,
-        eval_files=eval_files or [],
         neval=neval,
         nepochs=epochs,
         batch=batch,
@@ -128,13 +136,17 @@ def saveit(path):
         dnn.io.save_checkpoint(path, net, opt, runs=run_history, epochs=epoch_history)
 
     for this_epoch_number in range(first_epoch_number, first_epoch_number + epochs):
-        train_losses = trainer.epoch(train_dl)
-        train_loss = sum(train_losses)/ntrain
 
-        eval_losses = []
+        train_loss = 0
+        train_losses = []
+        if ntrain:
+            train_losses = trainer.epoch(dles[0])
+            train_loss = sum(train_losses)/ntrain
+
         eval_loss = 0
-        if eval_dl:
-            eval_losses = trainer.evaluate(eval_dl)
+        eval_losses = []
+        if neval:
+            eval_losses = trainer.evaluate(dles[1])
             eval_loss = sum(eval_losses) / neval
 
         this_epoch = dict(
@@ -146,7 +158,7 @@ def saveit(path):
             eval_loss=eval_loss)
         epoch_history[this_epoch_number] = this_epoch
 
-        log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss} eval: {eval_loss}')
+        log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss:.4e} [b={tbatch},n={ntrain}] eval: {eval_loss:.4e} [b={ebatch},n={neval}]')
 
         if checkpoint_save:
             if this_epoch_number % checkpoint_modulus == 0:
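
For reference, below is one plausible shape of the configuration dict that the anyconfig_file decorator is assumed to hand to train(). It is pieced together from train_defaults and from the config['train']['files'] and config.get("dataset") look-ups in the diff; the key names mirror those, the concrete values and the file glob are purely illustrative, and the on-disk format is whatever the decorator's loader accepts.

    example_config = {
        "train": {
            "epochs": 10,                # relative epoch count, as for -e/--epochs
            "batch": 4,                  # training batch size, as for -b/--batch
            "device": "cuda",            # as for -d/--device ('gpu' is mapped to 'cuda')
            "name": "dnnroi",            # application name, as for -a/--app
            "train_ratio": 0.8,          # as for --train-ratio
            "files": ["samples/*.npz"],  # used only when no FILES arguments are given
        },
        "dataset": {},                   # passed through to app.Dataset(..., config=...)
    }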
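The diff also leans on dnn.data.train_eval_split(ds, train_ratio) to replace the old --eval-files option. The real implementation lives in the wirecell.dnn package and is not shown here; the sketch below is only one plausible behaviour, assuming it returns a (train, eval) pair of datasets partitioned by the given fraction.

    import torch
    from torch.utils.data import Dataset, random_split

    def train_eval_split(ds: Dataset, train_ratio: float):
        '''Split a map-style dataset into (train, eval) subsets (sketch, not the real code).'''
        ntrain = int(round(len(ds) * train_ratio))
        lengths = [ntrain, len(ds) - ntrain]
        # A fixed generator keeps the split reproducible between runs.
        return random_split(ds, lengths, generator=torch.Generator().manual_seed(0))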