
Commit c54dbd3

Use PyTorch's AMP for amp test (#1929)
1 parent 26bba57 commit c54dbd3

File tree: 1 file changed (+16, -33 lines)


examples/imagenet/main_amp.py

Lines changed: 16 additions & 33 deletions
@@ -17,13 +17,10 @@

 import numpy as np

-try:
-    from apex.parallel import DistributedDataParallel as DDP
-    from apex.fp16_utils import *
-    from apex import amp, optimizers
-    from apex.multi_tensor_apply import multi_tensor_applier
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+def to_python_float(scalar_tensor: torch.Tensor):
+    return scalar_tensor.float().item()

 def fast_collate(batch, memory_format):

@@ -152,24 +149,9 @@ def main():
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

-    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
-    # for convenient interoperation with argparse.
-    model, optimizer = amp.initialize(model, optimizer,
-                                      opt_level=args.opt_level,
-                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
-                                      loss_scale=args.loss_scale
-                                      )
-
-    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
-    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
-    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
-    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
     if args.distributed:
-        # By default, apex.parallel.DistributedDataParallel overlaps communication with
-        # computation in the backward pass.
-        # model = DDP(model)
-        # delay_allreduce delays all communication to the end of the backward pass.
-        model = DDP(model, delay_allreduce=True)
+        model = DDP(model)
+    scaler = torch.amp.GradScaler("cuda")

 # define loss function (criterion) and optimizer
 criterion = nn.CrossEntropyLoss().cuda()
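
For reference, the native setup that replaces apex's amp.initialize in this hunk amounts to: build the model and optimizer as usual, wrap the model with torch.nn.parallel.DistributedDataParallel when running distributed, and construct a GradScaler. Because native AMP does not rewrite the model's or optimizer's parameters, the ordering constraint described in the removed apex comments no longer applies. A minimal sketch, assuming the process group is already initialized (the helper name, toy model, and hyperparameters are illustrative, not taken from the script):

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def build_amp_training_state(distributed: bool):
        # Model and optimizer are built exactly as in FP32 training.
        model = nn.Linear(128, 10).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        if distributed:
            # Assumes torch.distributed.init_process_group(...) was called earlier.
            model = DDP(model)
        # GradScaler only tracks the loss scale; it does not touch the model or
        # the optimizer, so it can be created before or after the DDP wrap.
        scaler = torch.amp.GradScaler("cuda")
        return model, optimizer, scaler
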
@@ -245,7 +227,7 @@ def resume():
             train_sampler.set_epoch(epoch)

         # train for one epoch
-        train(train_loader, model, criterion, optimizer, epoch)
+        train(train_loader, model, criterion, optimizer, scaler, epoch)

         # evaluate on validation set
         prec1 = validate(val_loader, model, criterion)
@@ -317,7 +299,7 @@ def next(self):
         return input, target


-def train(train_loader, model, criterion, optimizer, epoch):
+def train(train_loader, model, criterion, optimizer, scaler, epoch):
     batch_time = AverageMeter()
     losses = AverageMeter()
     top1 = AverageMeter()
@@ -341,24 +323,25 @@ def train(train_loader, model, criterion, optimizer, epoch):
         adjust_learning_rate(optimizer, epoch, i, len(train_loader))

         # compute output
-        if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
-        output = model(input)
-        if args.prof >= 0: torch.cuda.nvtx.range_pop()
-        loss = criterion(output, target)
+        with torch.autocast(device_type="cuda"):
+            if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
+            output = model(input)
+            if args.prof >= 0: torch.cuda.nvtx.range_pop()
+            loss = criterion(output, target)

         # compute gradient and do SGD step
         optimizer.zero_grad()

         if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward()
+        scaler.scale(loss).backward()
         if args.prof >= 0: torch.cuda.nvtx.range_pop()

         # for param in model.parameters():
         #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

         if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
-        optimizer.step()
+        scaler.step(optimizer)
+        scaler.update()
         if args.prof >= 0: torch.cuda.nvtx.range_pop()

         if i%args.print_freq == 0:
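
Taken together, the edits swap apex.amp for PyTorch's built-in mixed-precision tools: torch.autocast wraps the forward pass and torch.amp.GradScaler handles loss scaling and the optimizer step. A self-contained sketch of that training-step pattern, assuming a CUDA device is available (the toy model, data, and loop below are illustrative, not part of the example script):

    import torch
    import torch.nn as nn

    model = nn.Linear(128, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    criterion = nn.CrossEntropyLoss().cuda()
    scaler = torch.amp.GradScaler("cuda")

    for _ in range(10):
        input = torch.randn(32, 128, device="cuda")
        target = torch.randint(0, 10, (32,), device="cuda")

        # Forward pass and loss run under autocast so eligible ops use reduced precision.
        with torch.autocast(device_type="cuda"):
            output = model(input)
            loss = criterion(output, target)

        optimizer.zero_grad()
        # Scale the loss before backward so small FP16 gradients do not underflow,
        # then let the scaler drive the optimizer step and adjust the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

Here scaler.step skips the parameter update if the scaled gradients contain infs or NaNs, and scaler.update then lowers the loss scale; this dynamic loss scaling is what apex's amp.scale_loss provided implicitly in the old version of the script.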
