@@ -16,7 +16,7 @@
 import torch
 import numpy as np
 from typing import List, Tuple
-from itertools import zip_longest
+from itertools import zip_longest, cycle
 from functools import partial
 
 from megatron import mpu, print_rank_0
@@ -62,6 +62,7 @@ def build_the_dataset(
     dataset_impl,
     allow_chopped,
     num_samples,
+    num_epochs,
     seq_length,
     seed,
     skip_warmup,
@@ -141,6 +142,7 @@ def build_the_dataset(
         documents,
         indexed_dataset,
         num_samples,
+        num_epochs,
         seq_length,
         seed,
         pack_impl=pack_impl,
@@ -179,6 +181,7 @@ def build_train_valid_test_datasets(
     allow_chopped,
     splits_string,
     train_valid_test_num_samples,
+    train_valid_test_epochs,
     seq_length,
     seed,
     skip_warmup,
@@ -219,6 +222,7 @@ def build_dataset(index, name):
             documents,
             indexed_dataset,
             train_valid_test_num_samples[index],
+            train_valid_test_epochs[index],
             seq_length,
             seed,
             pack_impl=pack_impl,
@@ -268,12 +272,15 @@ def get_normalized_weights_and_num_samples(
     weight_sum = sum(weights)
     assert weight_sum > 0.0
     weights = [weight / weight_sum for weight in weights]
-    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
-    # not uniformly distribute the number of samples, we still have
-    # samples left to feed to the network.
-    weighted_num_samples = []
-    for weight in weights:
-        weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
+    if num_samples is not None:
+        # Add 0.5% (the 1.005 factor) so in case the blending dataset does
+        # not uniformly distribute the number of samples, we still have
+        # samples left to feed to the network.
+        weighted_num_samples = []
+        for weight in weights:
+            weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
+    else:
+        weighted_num_samples = [None for _ in weights]
     return weights, weighted_num_samples
 
 
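The new branching in get_normalized_weights_and_num_samples is easy to sanity-check in isolation. A standalone sketch of the same logic, not part of the patch, with made-up weights and counts:

import math

def normalized_weights_and_num_samples(weights, num_samples):
    # Mirrors the patched function above, reproduced for illustration.
    weight_sum = sum(weights)
    assert weight_sum > 0.0
    weights = [w / weight_sum for w in weights]
    if num_samples is not None:
        # 0.5% oversampling headroom, the 1.005 factor above.
        return weights, [int(math.ceil(num_samples * w * 1.005)) for w in weights]
    # Epoch-driven runs pass num_samples=None and get None placeholders back.
    return weights, [None for _ in weights]

print(normalized_weights_and_num_samples([3.0, 1.0], 1000))
# ([0.75, 0.25], [754, 252])
print(normalized_weights_and_num_samples([3.0, 1.0], None))
# ([0.75, 0.25], [None, None])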
@@ -282,9 +289,9 @@ def build_weighted_datasets(
     train_num_samples,
     valid_num_samples,
     test_num_samples,
-    train_weights,
-    valid_weights,
-    test_weights,
+    train_epochs,
+    valid_epochs,
+    test_epochs,
     build_index_mappings=True,
 ):
     # build individual datasets
@@ -367,6 +374,7 @@ def build_weighted_datasets(
                    pack_impl=neox_args.pack_impl,
                    allow_chopped=neox_args.allow_chopped,
                    num_samples=train_num_samples[i],
+                   num_epochs=train_epochs,
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=(not neox_args.mmap_warmup),
@@ -391,6 +399,7 @@ def build_weighted_datasets(
                    pack_impl=neox_args.pack_impl,
                    allow_chopped=neox_args.allow_chopped,
                    num_samples=valid_num_samples[i],
+                   num_epochs=valid_epochs,
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=(not neox_args.mmap_warmup),
@@ -415,6 +424,7 @@ def build_weighted_datasets(
                    pack_impl=neox_args.pack_impl,
                    allow_chopped=neox_args.allow_chopped,
                    num_samples=test_num_samples[i],
+                   num_epochs=test_epochs,
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=(not neox_args.mmap_warmup),
@@ -469,9 +479,44 @@ def weights_by_num_docs(l: list, alpha=0.3):
     return weights
 
 
-def build_train_valid_test_data_iterators(neox_args):
+def validate_train_epochs(neox_args):
+    """Check for unsupported neox_args when using train_epochs instead of train_iters"""
+    if neox_args.train_epochs is None:
+        return
+
+    if neox_args.train_epochs and neox_args.train_iters:
+        raise ValueError(
+            "Cannot specify both train epochs and train iters simultaneously"
+        )
+
+    if neox_args.pack_impl != "packed":
+        raise ValueError(
+            "Packing implementations other than 'packed' are currently unsupported with train_epochs"
+        )
+
+    if neox_args.weight_by_num_documents:
+        raise ValueError(
+            "Weighting by number of documents is currently unsupported with train_epochs"
+        )
+
+    if neox_args.train_data_weights and (
+        not all(weight == 1.0 for weight in neox_args.train_data_weights)
+    ):
+        raise ValueError(
+            "train_data_weights != None is currently unsupported with train_epochs"
+        )
+
+    if neox_args.dataset_impl != "gpt2":
+        raise ValueError(
+            "Datasets other than gpt2 are not currently supported with train_epochs"
+        )
+
+
+def build_train_valid_test_data_loaders(neox_args):
     """XXX"""
 
+    validate_train_epochs(neox_args)
+
     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
 
     print_rank_0("> building train, validation, and test datasets ...")
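The guard is pure argument checking, so it can be exercised without any Megatron machinery. A minimal sketch using a hypothetical stand-in for neox_args that carries only the fields the guard reads:

from types import SimpleNamespace

# Hypothetical args object; real runs pass the full NeoXArgs instance.
args = SimpleNamespace(
    train_epochs=1,
    train_iters=1000,  # conflicts with train_epochs
    pack_impl="packed",
    weight_by_num_documents=False,
    train_data_weights=None,
    dataset_impl="gpt2",
)

try:
    validate_train_epochs(args)
except ValueError as err:
    print(err)  # Cannot specify both train epochs and train iters simultaneously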
@@ -489,14 +534,21 @@ def build_train_valid_test_data_iterators(neox_args):
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0 and pipe_load:
         # Number of train/valid/test samples.
-        train_iters = neox_args.train_iters
-        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
-        test_iters = neox_args.eval_iters
-        train_val_test_num_samples = [
-            train_iters * neox_args.train_batch_size,
-            eval_iters * neox_args.train_batch_size,
-            test_iters * neox_args.train_batch_size,
-        ]
+        if neox_args.train_iters is not None:
+            train_iters = neox_args.train_iters
+            eval_iters = (
+                train_iters // neox_args.eval_interval + 1
+            ) * neox_args.eval_iters
+            test_iters = neox_args.eval_iters
+            train_val_test_num_samples = [
+                train_iters * neox_args.train_batch_size,
+                eval_iters * neox_args.train_batch_size,
+                test_iters * neox_args.train_batch_size,
+            ]
+            train_val_test_epochs = [None, None, None]
+        elif neox_args.train_epochs is not None:
+            train_val_test_num_samples = [None, None, None]
+            train_val_test_epochs = [1, 1, 1]
 
         if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths):
             # when individual train / valid / test data paths are provided
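In the iteration-driven branch the sample counts follow directly from the batch and eval settings. Worked through with invented numbers (not taken from the source):

# Hypothetical settings, chosen only to make the formula concrete.
train_iters = 1000
eval_interval = 100
eval_iters = 10
train_batch_size = 32

# One eval pass per eval_interval, plus a final one: (1000 // 100 + 1) * 10 = 110
total_eval_iters = (train_iters // eval_interval + 1) * eval_iters

train_val_test_num_samples = [
    train_iters * train_batch_size,       # 32000 train samples
    total_eval_iters * train_batch_size,  # 3520 validation samples
    eval_iters * train_batch_size,        # 320 test samples
]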
@@ -517,9 +569,9 @@ def build_train_valid_test_data_iterators(neox_args):
                 train_num_samples,
                 valid_num_samples,
                 test_num_samples,
-                train_weights,
-                valid_weights,
-                test_weights,
+                train_val_test_epochs[0],
+                train_val_test_epochs[1],
+                train_val_test_epochs[2],
                 build_index_mappings=not neox_args.weight_by_num_documents,
             )
 
@@ -565,9 +617,9 @@ def build_train_valid_test_data_iterators(neox_args):
                 train_num_samples,
                 valid_num_samples,
                 test_num_samples,
-                train_weights,
-                valid_weights,
-                test_weights,
+                train_val_test_epochs[0],
+                train_val_test_epochs[1],
+                train_val_test_epochs[2],
             )
 
             if train_datasets:
@@ -585,6 +637,7 @@ def build_train_valid_test_data_iterators(neox_args):
                 data_impl=neox_args.data_impl,
                 splits_string=neox_args.split,
                 train_valid_test_num_samples=train_val_test_num_samples,
+                train_valid_test_epochs=train_val_test_epochs,
                 seq_length=neox_args.seq_length,
                 seed=neox_args.seed,
                 skip_warmup=(not neox_args.mmap_warmup),
@@ -598,9 +651,15 @@ def build_train_valid_test_data_iterators(neox_args):
             test_dataloader = make_data_loader(test_ds, neox_args=neox_args)
 
         # Flags to know if we need to do training/validation/testing.
-        do_train = train_dataloader is not None and neox_args.train_iters > 0
-        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
-        do_test = test_dataloader is not None and neox_args.eval_iters > 0
+        if neox_args.train_epochs:
+            do_train = train_dataloader is not None
+            do_valid = valid_dataloader is not None
+            do_test = test_dataloader is not None
+        else:
+            do_train = train_dataloader is not None and neox_args.train_iters > 0
+            do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
+            do_test = test_dataloader is not None and neox_args.eval_iters > 0
+
         # Need to broadcast num_tokens and num_type_tokens.
         flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
     else:
@@ -620,6 +679,19 @@ def build_train_valid_test_data_iterators(neox_args):
     neox_args.do_train = flags[0].item()
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()
+    data_loaders = {
+        "train": train_dataloader,
+        "valid": valid_dataloader,
+        "test": test_dataloader,
+    }
+    return data_loaders
+
+
+def shift_and_wrap_data_loaders(neox_args, data_loaders, loop=True):
+    """Shift start iteration and wrap data_loaders in iterators"""
+    train_dataloader = data_loaders["train"]
+    valid_dataloader = data_loaders["valid"]
+    test_dataloader = data_loaders["test"]
 
     # Shift the start iterations.
     if train_dataloader is not None:
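Splitting the old build_train_valid_test_data_iterators into a loader-building step and a wrapping step implies a two-call sequence at the call site. A sketch under that assumption; the actual call sites, and the return value of shift_and_wrap_data_loaders, fall outside this diff:

# Build the raw DataLoaders once; the rank-0 gating and the do_train /
# do_valid / do_test flags are handled inside.
data_loaders = build_train_valid_test_data_loaders(neox_args)

# Later, fast-forward each loader to the current iteration and wrap it in an
# iterator. loop=True (the new default) wraps each loader in itertools.cycle.
iterators = shift_and_wrap_data_loaders(neox_args, data_loaders, loop=True)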
@@ -645,19 +717,34 @@ def build_train_valid_test_data_iterators(neox_args):
             )
         )
 
+    def loop_iterator(data_loader):
+        while True:
+            for x in data_loader:
+                yield x
+            data_loader.start_iter = 0
+
     # Build iterators.
     if train_dataloader is not None:
-        train_data_iterator = iter(train_dataloader)
+        if loop:
+            train_data_iterator = cycle(train_dataloader)
+        else:
+            train_data_iterator = iter(train_dataloader)
     else:
         train_data_iterator = None
 
     if valid_dataloader is not None:
-        valid_data_iterator = iter(valid_dataloader)
+        if loop:
+            valid_data_iterator = cycle(valid_dataloader)
+        else:
+            valid_data_iterator = iter(valid_dataloader)
     else:
         valid_data_iterator = None
 
     if test_dataloader is not None:
-        test_data_iterator = iter(test_dataloader)
+        if loop:
+            test_data_iterator = cycle(test_dataloader)
+        else:
+            test_data_iterator = iter(test_dataloader)
     else:
         test_data_iterator = None
 
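Two looping strategies appear in this hunk: itertools.cycle, which saves a copy of every element from its first pass and then replays the cached copies, and the loop_iterator generator, which re-runs the loader each epoch and resets its start_iter between passes (within the lines shown it is defined but not called). A small self-contained illustration of both:

from itertools import cycle, islice

data = [1, 2, 3]

# cycle materializes the first pass in memory and replays it, so a cycled
# DataLoader keeps serving the same cached batches.
print(list(islice(cycle(data), 7)))  # [1, 2, 3, 1, 2, 3, 1]

def loop_iterator(iterable):
    # Generator variant: re-iterates the underlying object each epoch, which
    # is what allows resetting data_loader.start_iter between passes.
    while True:
        for x in iterable:
            yield x

print(list(islice(loop_iterator(data), 7)))  # [1, 2, 3, 1, 2, 3, 1]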