Skip to content

Commit 954613a

Browse files
committed
Never passed torchcompile-mode arg through to compile call
1 parent 055a600 commit 954613a

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,11 @@ def main():
648648
if args.resume:
649649
load_checkpoint(model_ema.module, args.resume, use_ema=True)
650650
if args.torchcompile:
651-
model_ema = torch.compile(model_ema, backend=args.torchcompile)
651+
model_ema = torch.compile(
652+
model_ema,
653+
backend=args.torchcompile,
654+
mode=args.torchcompile_mode,
655+
)
652656

653657
# setup distributed training
654658
if args.distributed:

0 commit comments

Comments
 (0)