17
17
18
18
import numpy as np
19
19
20
- try :
21
- from apex .parallel import DistributedDataParallel as DDP
22
- from apex .fp16_utils import *
23
- from apex import amp , optimizers
24
- from apex .multi_tensor_apply import multi_tensor_applier
25
- except ImportError :
26
- raise ImportError ("Please install apex from https://www.github.com/nvidia/apex to run this example." )
20
+ from torch .nn .parallel import DistributedDataParallel as DDP
21
+
22
+ def to_python_float (scalar_tensor : torch .Tensor ):  # Convert a single-element tensor into a plain Python float.
23
+ return scalar_tensor .float ().item ()  # .float() casts to float32; .item() then extracts the value as a Python number
27
24
28
25
def fast_collate (batch , memory_format ):
29
26
@@ -152,24 +149,9 @@ def main():
152
149
momentum = args .momentum ,
153
150
weight_decay = args .weight_decay )
154
151
155
- # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
156
- # for convenient interoperation with argparse.
157
- model , optimizer = amp .initialize (model , optimizer ,
158
- opt_level = args .opt_level ,
159
- keep_batchnorm_fp32 = args .keep_batchnorm_fp32 ,
160
- loss_scale = args .loss_scale
161
- )
162
-
163
- # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
164
- # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
165
- # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
166
- # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
167
152
if args .distributed :
168
- # By default, apex.parallel.DistributedDataParallel overlaps communication with
169
- # computation in the backward pass.
170
- # model = DDP(model)
171
- # delay_allreduce delays all communication to the end of the backward pass.
172
- model = DDP (model , delay_allreduce = True )
153
+ model = DDP (model )
154
+ scaler = torch .amp .GradScaler ("cuda" )
173
155
174
156
# define loss function (criterion) and optimizer
175
157
criterion = nn .CrossEntropyLoss ().cuda ()
@@ -245,7 +227,7 @@ def resume():
245
227
train_sampler .set_epoch (epoch )
246
228
247
229
# train for one epoch
248
- train (train_loader , model , criterion , optimizer , epoch )
230
+ train (train_loader , model , criterion , optimizer , scaler , epoch )
249
231
250
232
# evaluate on validation set
251
233
prec1 = validate (val_loader , model , criterion )
@@ -317,7 +299,7 @@ def next(self):
317
299
return input , target
318
300
319
301
320
- def train (train_loader , model , criterion , optimizer , epoch ):
302
+ def train (train_loader , model , criterion , optimizer , scaler , epoch ):
321
303
batch_time = AverageMeter ()
322
304
losses = AverageMeter ()
323
305
top1 = AverageMeter ()
@@ -341,24 +323,25 @@ def train(train_loader, model, criterion, optimizer, epoch):
341
323
adjust_learning_rate (optimizer , epoch , i , len (train_loader ))
342
324
343
325
# compute output
344
- if args .prof >= 0 : torch .cuda .nvtx .range_push ("forward" )
345
- output = model (input )
346
- if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
347
- loss = criterion (output , target )
326
+ with torch .autocast (device_type = "cuda" ):
327
+ if args .prof >= 0 : torch .cuda .nvtx .range_push ("forward" )
328
+ output = model (input )
329
+ if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
330
+ loss = criterion (output , target )
348
331
349
332
# compute gradient and do SGD step
350
333
optimizer .zero_grad ()
351
334
352
335
if args .prof >= 0 : torch .cuda .nvtx .range_push ("backward" )
353
- with amp .scale_loss (loss , optimizer ) as scaled_loss :
354
- scaled_loss .backward ()
336
+ scaler .scale (loss ).backward ()
355
337
if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
356
338
357
339
# for param in model.parameters():
358
340
# print(param.data.double().sum().item(), param.grad.data.double().sum().item())
359
341
360
342
if args .prof >= 0 : torch .cuda .nvtx .range_push ("optimizer.step()" )
361
- optimizer .step ()
343
+ scaler .step (optimizer )
344
+ scaler .update ()
362
345
if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
363
346
364
347
if i % args .print_freq == 0 :
0 commit comments