
Commit a6d6af0

Automatically compute train_iters when train_epochs is specified. (#1283)
* preliminary epoch setting
* first working iteration
* train_epochs_special_case
* handle flags
* fix bugs
* working single path case
* working multi-path without eval
* remove unused files
* additional checks
* remove print statement
* apply precommit
* add lr_decay_fraction
* spelling

---------

Co-authored-by: Quentin Anthony <[email protected]>
1 parent b0b490d commit a6d6af0

8 files changed (+247, -75 lines)


configs/README.md

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ These can be set to any integer between `0` and `num_gpus`, and `num_gpus` must
 # this should provide some speedup but takes a while to build, set to true if desired
 "scaled_upper_triang_masked_softmax_fusion": false,
 "train_iters": 320000,
+# alternatively, use train_epochs to automatically determine the number of training iterations
+#"train_epochs": 1,
 ```
 An example of some basic settings used to configure your model's architecture and number of training steps.
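For orientation, the relationship the new option relies on can be sketched as follows. This is an illustrative back-of-the-envelope helper, not code from this commit; the function name and the simple floor-division rounding are assumptions.

```python
def infer_train_iters(tokens_per_epoch, seq_length, train_batch_size, train_epochs):
    """Illustrative only: estimate optimizer steps for a given number of epochs."""
    samples_per_epoch = tokens_per_epoch // seq_length        # fixed-length samples per pass
    iters_per_epoch = samples_per_epoch // train_batch_size   # batches (steps) per pass
    return iters_per_epoch * train_epochs

# e.g. 1B tokens, seq_length=2048, train_batch_size=32, one epoch
print(infer_train_iters(1_000_000_000, 2048, 32, 1))  # -> 15258
```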

configs/neox_arguments.md

Lines changed: 16 additions & 2 deletions
@@ -14,14 +14,19 @@ LR Scheduler Arguments
     Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'.
 
 
-
 - **lr_decay_iters**: int
 
     Default = None
 
-    Number of iterations to decay learning rate over, If None defaults to --train-iters
+    Number of iterations to decay learning rate over. If None, defaults to
+    --train-iters or the equivalent inferred value from train_epochs.
+
+- **lr_decay_fraction**: float
 
+    Default = None
 
+    Effective fraction of training over which to decay lr. Overrides lr_decay_iters.
+    Useful when specifying train_epochs.
 
 
 - **min_lr**: float
 
@@ -1928,6 +1933,15 @@ Training Arguments
 
 
 
+- **train_epochs**: int
+
+    Default = None
+
+    Number of epochs to run for training. Do not specify both train_epochs and train_iters.
+    Not currently compatible with data reweighing, pairwise datasets, and packing other than 'packed'
+
+
+
 - **eval_iters**: int
 
     Default = 100
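A rough sketch of how lr_decay_fraction is meant to interact with the other two settings, based only on the description above; the helper name and exact rounding are illustrative, not the repo's code.

```python
def resolve_lr_decay_iters(train_iters, lr_decay_iters=None, lr_decay_fraction=None):
    """Illustrative only: lr_decay_fraction overrides lr_decay_iters; otherwise
    fall back to lr_decay_iters, and finally to the full run length."""
    if lr_decay_fraction is not None:
        return int(train_iters * lr_decay_fraction)
    return lr_decay_iters if lr_decay_iters is not None else train_iters

# decay over the first 90% of a 15,258-iteration run
print(resolve_lr_decay_iters(15_258, lr_decay_fraction=0.9))  # -> 13732
```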

megatron/data/data_utils.py

Lines changed: 118 additions & 31 deletions
@@ -16,7 +16,7 @@
 import torch
 import numpy as np
 from typing import List, Tuple
-from itertools import zip_longest
+from itertools import zip_longest, cycle
 from functools import partial
 
 from megatron import mpu, print_rank_0
@@ -62,6 +62,7 @@ def build_the_dataset(
     dataset_impl,
     allow_chopped,
     num_samples,
+    num_epochs,
     seq_length,
     seed,
     skip_warmup,
@@ -141,6 +142,7 @@ def build_the_dataset(
         documents,
         indexed_dataset,
         num_samples,
+        num_epochs,
         seq_length,
         seed,
         pack_impl=pack_impl,
@@ -179,6 +181,7 @@ def build_train_valid_test_datasets(
     allow_chopped,
     splits_string,
     train_valid_test_num_samples,
+    train_valid_test_epochs,
     seq_length,
     seed,
     skip_warmup,
@@ -219,6 +222,7 @@ def build_dataset(index, name):
                 documents,
                 indexed_dataset,
                 train_valid_test_num_samples[index],
+                train_valid_test_epochs[index],
                 seq_length,
                 seed,
                 pack_impl=pack_impl,
@@ -268,12 +272,15 @@ def get_normalized_weights_and_num_samples(
     weight_sum = sum(weights)
     assert weight_sum > 0.0
     weights = [weight / weight_sum for weight in weights]
-    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
-    # not uniformly distribute the number of samples, we still have
-    # samples left to feed to the network.
-    weighted_num_samples = []
-    for weight in weights:
-        weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
+    if num_samples is not None:
+        # Add 0.5% (the 1.005 factor) so in case the blending dataset does
+        # not uniformly distribute the number of samples, we still have
+        # samples left to feed to the network.
+        weighted_num_samples = []
+        for weight in weights:
+            weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
+    else:
+        weighted_num_samples = [None for _ in weights]
     return weights, weighted_num_samples
 
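Condensed, the patched helper now behaves like this standalone version (same logic as the diff above, restated for readability):

```python
import math

def get_normalized_weights_and_num_samples(weights, num_samples):
    weight_sum = sum(weights)
    assert weight_sum > 0.0
    weights = [w / weight_sum for w in weights]
    if num_samples is not None:
        # 0.5% head-room so a non-uniform blend still has samples left to draw
        return weights, [int(math.ceil(num_samples * w * 1.005)) for w in weights]
    # epoch-driven path: sample counts are decided later, per dataset
    return weights, [None for _ in weights]

print(get_normalized_weights_and_num_samples([2.0, 1.0, 1.0], 1000))
# ([0.5, 0.25, 0.25], [503, 252, 252])
print(get_normalized_weights_and_num_samples([2.0, 1.0, 1.0], None))
# ([0.5, 0.25, 0.25], [None, None, None])
```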

@@ -282,9 +289,9 @@ def build_weighted_datasets(
     train_num_samples,
     valid_num_samples,
     test_num_samples,
-    train_weights,
-    valid_weights,
-    test_weights,
+    train_epochs,
+    valid_epochs,
+    test_epochs,
     build_index_mappings=True,
 ):
     # build individual datasets
@@ -367,6 +374,7 @@ def build_weighted_datasets(
                     pack_impl=neox_args.pack_impl,
                     allow_chopped=neox_args.allow_chopped,
                     num_samples=train_num_samples[i],
+                    num_epochs=train_epochs,
                     seq_length=neox_args.seq_length,
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
@@ -391,6 +399,7 @@ def build_weighted_datasets(
                     pack_impl=neox_args.pack_impl,
                     allow_chopped=neox_args.allow_chopped,
                     num_samples=valid_num_samples[i],
+                    num_epochs=valid_epochs,
                     seq_length=neox_args.seq_length,
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
@@ -415,6 +424,7 @@ def build_weighted_datasets(
                     pack_impl=neox_args.pack_impl,
                     allow_chopped=neox_args.allow_chopped,
                     num_samples=test_num_samples[i],
+                    num_epochs=test_epochs,
                     seq_length=neox_args.seq_length,
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
@@ -469,9 +479,44 @@ def weights_by_num_docs(l: list, alpha=0.3):
     return weights
 
 
-def build_train_valid_test_data_iterators(neox_args):
+def validate_train_epochs(neox_args):
+    """Check for unsupported neox_args when using train_epochs instead of train_iters"""
+    if neox_args.train_epochs is None:
+        return
+
+    if neox_args.train_epochs and neox_args.train_iters:
+        raise ValueError(
+            "Cannot specify both train epochs and train iters simultaneously"
+        )
+
+    if neox_args.pack_impl != "packed":
+        raise ValueError(
+            "Packing implementations other than 'packed' are currently unsupported with train_epochs"
+        )
+
+    if neox_args.weight_by_num_documents:
+        raise ValueError(
+            "Weighting by number of documents is currently unsupported with train_epochs"
+        )
+
+    if neox_args.train_data_weights and (
+        not all(weight == 1.0 for weight in neox_args.train_data_weights)
+    ):
+        raise ValueError(
+            "train_data_weights != None is currently unsupported with train_epochs"
+        )
+
+    if neox_args.dataset_impl != "gpt2":
+        raise ValueError(
+            "non gpt2 datasets are currently unsupported with train_epochs"
+        )
+
+
+def build_train_valid_test_data_loaders(neox_args):
     """XXX"""
 
+    validate_train_epochs(neox_args)
+
     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
 
     print_rank_0("> building train, validation, and test datasets ...")
@@ -489,14 +534,21 @@ def build_train_valid_test_data_iterators(neox_args):
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0 and pipe_load:
         # Number of train/valid/test samples.
-        train_iters = neox_args.train_iters
-        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
-        test_iters = neox_args.eval_iters
-        train_val_test_num_samples = [
-            train_iters * neox_args.train_batch_size,
-            eval_iters * neox_args.train_batch_size,
-            test_iters * neox_args.train_batch_size,
-        ]
+        if neox_args.train_iters is not None:
+            train_iters = neox_args.train_iters
+            eval_iters = (
+                train_iters // neox_args.eval_interval + 1
+            ) * neox_args.eval_iters
+            test_iters = neox_args.eval_iters
+            train_val_test_num_samples = [
+                train_iters * neox_args.train_batch_size,
+                eval_iters * neox_args.train_batch_size,
+                test_iters * neox_args.train_batch_size,
+            ]
+            train_val_test_epochs = [None, None, None]
+        elif neox_args.train_epochs is not None:
+            train_val_test_num_samples = [None, None, None]
+            train_val_test_epochs = [1, 1, 1]
 
         if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths):
             # when individual train / valid / test data paths are provided
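A worked example of the iteration-driven branch above, with illustrative values; in the epoch-driven branch all three sample counts stay None and each split defaults to a single epoch.

```python
train_iters = 1000
eval_interval = 100
eval_iters = 10
train_batch_size = 32

eval_total = (train_iters // eval_interval + 1) * eval_iters  # 11 eval points * 10 = 110
test_iters = eval_iters                                       # 10
train_val_test_num_samples = [
    train_iters * train_batch_size,  # 32000 training samples
    eval_total * train_batch_size,   # 3520 validation samples
    test_iters * train_batch_size,   # 320 test samples
]
print(train_val_test_num_samples)  # [32000, 3520, 320]
```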
@@ -517,9 +569,9 @@ def build_train_valid_test_data_iterators(neox_args):
                 train_num_samples,
                 valid_num_samples,
                 test_num_samples,
-                train_weights,
-                valid_weights,
-                test_weights,
+                train_val_test_epochs[0],
+                train_val_test_epochs[1],
+                train_val_test_epochs[2],
                 build_index_mappings=not neox_args.weight_by_num_documents,
             )
 
@@ -565,9 +617,9 @@ def build_train_valid_test_data_iterators(neox_args):
                 train_num_samples,
                 valid_num_samples,
                 test_num_samples,
-                train_weights,
-                valid_weights,
-                test_weights,
+                train_val_test_epochs[0],
+                train_val_test_epochs[1],
+                train_val_test_epochs[2],
             )
 
             if train_datasets:
@@ -585,6 +637,7 @@ def build_train_valid_test_data_iterators(neox_args):
                 data_impl=neox_args.data_impl,
                 splits_string=neox_args.split,
                 train_valid_test_num_samples=train_val_test_num_samples,
+                train_valid_test_epochs=train_val_test_epochs,
                 seq_length=neox_args.seq_length,
                 seed=neox_args.seed,
                 skip_warmup=(not neox_args.mmap_warmup),
@@ -598,9 +651,15 @@ def build_train_valid_test_data_iterators(neox_args):
         test_dataloader = make_data_loader(test_ds, neox_args=neox_args)
 
         # Flags to know if we need to do training/validation/testing.
-        do_train = train_dataloader is not None and neox_args.train_iters > 0
-        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
-        do_test = test_dataloader is not None and neox_args.eval_iters > 0
+        if neox_args.train_epochs:
+            do_train = train_dataloader is not None
+            do_valid = valid_dataloader is not None
+            do_test = test_dataloader is not None
+        else:
+            do_train = train_dataloader is not None and neox_args.train_iters > 0
+            do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
+            do_test = test_dataloader is not None and neox_args.eval_iters > 0
+
         # Need to broadcast num_tokens and num_type_tokens.
         flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
     else:
@@ -620,6 +679,19 @@ def build_train_valid_test_data_iterators(neox_args):
     neox_args.do_train = flags[0].item()
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()
+    data_loaders = {
+        "train": train_dataloader,
+        "valid": valid_dataloader,
+        "test": test_dataloader,
+    }
+    return data_loaders
+
+
+def shift_and_wrap_data_loaders(neox_args, data_loaders, loop=True):
+    """Shift start iteration and wrap data_loaders in iterators"""
+    train_dataloader = data_loaders["train"]
+    valid_dataloader = data_loaders["valid"]
+    test_dataloader = data_loaders["test"]
 
     # Shift the start iterations.
     if train_dataloader is not None:
@@ -645,19 +717,34 @@ def build_train_valid_test_data_iterators(neox_args):
             )
         )
 
+    def loop_iterator(data_loader):
+        while True:
+            for x in data_loader:
+                yield x
+            data_loader.start_iter = 0
+
     # Build iterators.
     if train_dataloader is not None:
-        train_data_iterator = iter(train_dataloader)
+        if loop:
+            train_data_iterator = cycle(train_dataloader)
+        else:
+            train_data_iterator = iter(train_dataloader)
     else:
         train_data_iterator = None
 
     if valid_dataloader is not None:
-        valid_data_iterator = iter(valid_dataloader)
+        if loop:
+            valid_data_iterator = cycle(valid_dataloader)
+        else:
+            valid_data_iterator = iter(valid_dataloader)
     else:
         valid_data_iterator = None
 
     if test_dataloader is not None:
-        test_data_iterator = iter(test_dataloader)
+        if loop:
+            test_data_iterator = cycle(test_dataloader)
+        else:
+            test_data_iterator = iter(test_dataloader)
     else:
         test_data_iterator = None
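The switch from iter() to itertools.cycle matters once train_epochs, rather than a fixed train_iters, bounds the run: iter() raises StopIteration after one pass, while cycle replays the loader indefinitely. Note that cycle keeps a copy of every item it has yielded, which is why a re-iterating generator such as the loop_iterator helper in the diff can be the lighter-weight choice for large loaders. A minimal sketch of the difference:

```python
from itertools import cycle

def loop_iterator(data_loader):
    """Re-iterate the loader each pass (mirrors the diff's helper, minus the start_iter reset)."""
    while True:
        for batch in data_loader:
            yield batch

loader = [[0, 1], [2, 3], [4, 5]]   # stand-in for a DataLoader yielding batches
finite = iter(loader)               # exhausted after one epoch
endless = cycle(loader)             # replays cached batches forever

print([next(endless) for _ in range(5)])
# [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3]]
```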

megatron/data/gpt2_dataset.py

Lines changed: 5 additions & 1 deletion
@@ -34,6 +34,7 @@ def __init__(
         documents,
         indexed_dataset,
         num_samples,
+        num_epochs,
         seq_length,
         seed,
         pack_impl="packed",
@@ -70,6 +71,7 @@ def __init__(
                 self.indexed_dataset.sizes,
                 self.label_dataset,
                 num_samples,
+                num_epochs,
                 seq_length,
                 seed,
                 self.pack_impl,
@@ -203,6 +205,7 @@ def _build_index_mappings(
     sizes,
     label_dataset,
     num_samples,
+    num_epochs,
     seq_length,
     seed,
     packing_impl,
@@ -217,7 +220,8 @@ def _build_index_mappings(
     """
     # Number of tokens in each epoch and number of required epochs.
     tokens_per_epoch = _num_tokens(documents, sizes)
-    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
+    if not num_epochs:
+        num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
     # rng state
     np_rng = np.random.RandomState(seed=seed)
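For context, when num_epochs is not supplied, the existing helper derives it from the requested sample count. Its logic is roughly the following (paraphrased, not copied verbatim from the repo):

```python
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    """Smallest number of passes over the data that can supply num_samples
    sequences of length seq_length (each sample consumes seq_length tokens,
    plus one extra token shared with the next sample)."""
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        if (total_tokens - 1) // seq_length >= num_samples:
            return num_epochs

print(_num_epochs(tokens_per_epoch=10_000, seq_length=2048, num_samples=12))
# 3 passes: (30_000 - 1) // 2048 = 14 >= 12
```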

megatron/data/samplers.py

Lines changed: 5 additions & 1 deletion
@@ -100,7 +100,11 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     specifying True will result in the following samples for each gpu:
     GPU0: [0,2,4,6] GPU1: [1,3,5,7]
     specifying False will result in the following samples:
-    GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
+    GPU0: [0,1,2,3] GPU1: [4,5,6,7]
+
+    The `infinite_loop` parameter allows the sampler to yield batches indefinitely,
+    restarting from the beginning of the dataset when all samples have been iterated over.
+    """
 
     def __init__(
         self,
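A minimal, hypothetical stand-in showing what an infinite_loop-style batch sampler does; this is not the DistributedBatchSampler implementation itself, just the looping behavior in isolation:

```python
from torch.utils.data import BatchSampler, SequentialSampler

class LoopingBatchSampler(BatchSampler):
    """Yield batches forever, restarting from the first sample once exhausted."""

    def __init__(self, sampler, batch_size, drop_last, infinite_loop=False):
        super().__init__(sampler, batch_size, drop_last)
        self.infinite_loop = infinite_loop

    def __iter__(self):
        while True:
            yield from super().__iter__()
            if not self.infinite_loop:
                return

batches = LoopingBatchSampler(
    SequentialSampler(range(6)), batch_size=2, drop_last=False, infinite_loop=True
)
it = iter(batches)
print([next(it) for _ in range(5)])
# [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3]]
```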

0 commit comments
