diff --git a/examples/run_evalharness_deepspeed.md b/examples/run_evalharness_deepspeed.md index 695d9d0aa..eee4d70e6 100644 --- a/examples/run_evalharness_deepspeed.md +++ b/examples/run_evalharness_deepspeed.md @@ -29,7 +29,7 @@ Also make sure `data` is not on one of the limited paritions like WORKSF. Then install datasets for the tasks: ``` python ./tasks/eval_harness/download.py --task_list -arc_challenge,arc_easy,boolq,copa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc +arc_challenge,arc_easy,boolq,copa,hellaswag,lambada_openai,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc ``` and make sure that `export HF_DATASETS_OFFLINE=1` diff --git a/examples/run_evalharness_deepspeed.slurm b/examples/run_evalharness_deepspeed.slurm index e58ed9608..82d2ff024 100644 --- a/examples/run_evalharness_deepspeed.slurm +++ b/examples/run_evalharness_deepspeed.slurm @@ -85,7 +85,7 @@ CMD="./tasks/eval_harness/evaluate.py \ --seq-length $SEQ_LEN \ --adaptive_seq_len \ --eval_fp32 \ - --task_list arc_challenge,arc_easy,boolq,copa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sst,webqs,wic,winogrande,wnli,wsc,triviaqa,sciq \ + --task_list arc_challenge,arc_easy,boolq,copa,hellaswag,lambada_openai,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sst,webqs,wic,winogrande,wnli,wsc,triviaqa,sciq \ $MEGATRON_REQUIRED_ARGS \ " diff --git a/examples/run_evalharness_tr11-176b-ml.slurm b/examples/run_evalharness_tr11-176b-ml.slurm index 6d4849461..273215c08 100644 --- a/examples/run_evalharness_tr11-176b-ml.slurm +++ b/examples/run_evalharness_tr11-176b-ml.slurm @@ -89,7 +89,7 @@ CMD="./tasks/eval_harness/evaluate.py \ --bf16 \ --inference \ --seq-length $SEQ_LEN \ - --task_list arc_challenge,arc_easy,boolq,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc \ + --task_list arc_challenge,arc_easy,boolq,copa,headqa,hellaswag,lambada_openai,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,rte,sciq,sst,triviaqa,webqs,wic,winogrande,wnli,wsc \ --deepspeed \ --deepspeed_config ds_config.json \ --bootstrap_iters 2 \ diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py index 14650a6e5..7a15bb735 100644 --- a/finetune_t0_non_causal_decoder.py +++ b/finetune_t0_non_causal_decoder.py @@ -94,7 +94,11 @@ def get_batch_pipe(data): segment_ids=segment_ids.long(), ) - if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: + if args.position_embedding_type not in [ + PositionEmbeddingType.alibi, + PositionEmbeddingType.rotary, + PositionEmbeddingType.xpos, + ]: raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.") return (tokens, position_ids, attention_mask), (labels, loss_mask) diff --git a/megatron/arguments.py b/megatron/arguments.py index c18235a78..160a5afb7 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -24,7 +24,7 @@ import torch import deepspeed -from megatron.enums import PositionEmbeddingType +from megatron.enums import PositionEmbeddingType, UL2ModelType import megatron from megatron.logging import log_levels @@ -49,6 +49,7 @@ def 
parse_args(extra_args_provider=None, defaults={}, parser = _add_autoresume_args(parser) parser = _add_biencoder_args(parser) parser = _add_vit_args(parser) + parser = _add_ul2_args(parser) parser = _add_logging_args(parser) parser = _add_zero_args(parser) parser = _add_memoryopt_args(parser) @@ -310,6 +311,17 @@ def parse_args(extra_args_provider=None, defaults={}, ) args.skip_train_iteration_range = skip_train_iteration_range + args.ul2_model_type = UL2ModelType(args.ul2_model_type) + if ( + args.ul2_model_type is not UL2ModelType.ENCODER_DECODER + and args.decoder_seq_length is not None + ): + print( + f'WARNING: `--decoder_seq_length` is ignored when ' + f'`--ul2-model-type` is not ' + f'"{UL2ModelType.ENCODER_DECODER.value}"!' + ) + if args.use_bnb_optimizer: try: import bitsandbytes as bnb @@ -398,7 +410,7 @@ def _add_network_size_args(parser): group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], choices=list(PositionEmbeddingType), default=PositionEmbeddingType.absolute, - help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.' + help='Define position embedding type ("absolute" | "rotary" | "alibi" | "xpos"). "absolute" by default.' ) group.add_argument('--glu-activation', type=str, choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(), @@ -901,6 +913,13 @@ def __call__(self, parser, args, values, option_string=None): help='Probability of replacing a token with mask.') group.add_argument('--short-seq-prob', type=float, default=0.1, help='Probability of producing a short sequence.') + group.add_argument('--no-add-mask-tokens', action='store_false', + help='Whether not to add sentinel tokens for masked ' + 'spans in span corruption tasks.', + dest='add_mask_tokens') + group.add_argument('--pack-samples', action='store_true', + help='Whether to pack samples in span corruption ' + 'datasets (T5 or UL2). GPT dataset is always packed.') group.add_argument('--mmap-warmup', action='store_true', help='Warm up mmap files.') group.add_argument('--num-workers', type=int, default=2, @@ -1024,6 +1043,64 @@ def _add_vit_args(parser): return parser +def _add_ul2_args(parser): + group = parser.add_argument_group(title="UL2") + + group.add_argument('--ul2-model-type', type=str, default='ED', + choices=['ED', 'ND', 'CD'], + help='What type of model to use for UL2 pretraining. ' + 'ED = encoder-decoder; ND = non-causal decoder-only; ' + 'CD = causal decoder-only') + group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float, + default=None, + help='Probability of each denoising objective to be ' + 'selected. Uniform distribution by default.') + group.add_argument('--ul2-denoisers', nargs='+', type=str, + default=['R', 'R', 'S', 'X', 'X', 'X', 'X'], + choices=['R', 'S', 'X'], + help='What type of UL2 denoising objective the other ' + 'UL2 configurations refer to.') + group.add_argument('--ul2-mean-span-lengths', nargs='+', type=float, + default=[3, 8, 0.25, 3, 8, 64, 64], + help='Mean length for sampling span lengths. ' + 'Numbers < 1 indicate a mean length of the sequence ' + 'length times that number.') + group.add_argument('--ul2-mask-ratios', nargs='+', type=float, + default=[0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5], + help='Ratio of masked token in the full sequence.') + group.add_argument('--ul2-r-denoiser-token', type=str, default='[R]', + help='What token to prepend for the UL2 R-denoising ' + 'objective. 
If empty, do not prepend a token for this ' + 'objective.') + group.add_argument('--ul2-s-denoiser-token', type=str, default='[S]', + help='What token to prepend for the UL2 S-denoising ' + 'objective. If empty, do not prepend a token for this ' + 'objective.') + group.add_argument('--ul2-x-denoiser-token', type=str, default='[X]', + help='What token to prepend for the UL2 X-denoising ' + 'objective. If empty, do not prepend a token for this ' + 'objective.') + group.add_argument('--ul2-scale-normal-std', action='store_true', + help='Whether to scale the standard deviation when ' + 'using a normal distribution for span length sampling.') + group.add_argument('--ul2-like-ul2r', action='store_true', + help='Whether to use the updated implementation as ' + 'described in the UL2R paper. This only changes the ' + 'implementation, not the objective configurations!') + group.add_argument('--ul2-pack-any', action='store_true', + help='When `--pack-samples` is also given, whether to ' + 'pack different denoisers into one sample. If not ' + 'given, the same denoiser is used for all packed ' + 'samples.') + group.add_argument('--ul2-pack-no-repeat-prompt', action='store_false', + help='When `--pack-samples` is also given and ' + '`--ul2-pack-any` is *not* given, whether to ' + 'repeat the prompt token for each packed sample.', + dest='ul2_pack_repeat_prompt') + + return parser + + def _add_zero_args(parser): """Text generate arguments.""" diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py index 3841e263e..e3e528e74 100644 --- a/megatron/data/dataset_utils.py +++ b/megatron/data/dataset_utils.py @@ -18,6 +18,8 @@ # https://github.com/google-research/albert/blob/master/create_pretraining_data.py # with some modifications. +import bisect +from enum import Enum import math import os import time @@ -37,8 +39,17 @@ DSET_TYPE_BERT = 'standard_bert' DSET_TYPE_ICT = 'ict' DSET_TYPE_T5 = 't5' +DSET_TYPE_UL2 = 'ul2' -DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] +DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_UL2] + + +class SamplingStyle(Enum): + POISSON = 'poisson' + GEOMETRIC = 'geometric' + UNIFORM = 'uniform' + NORMAL = 'normal' + UNSCALED_NORMAL = 'unscaled normal' def analyze_data_prefix(data_prefix): @@ -183,6 +194,35 @@ def is_start_piece(piece): return not piece.startswith("##") +def get_ngram_indices( + idx, + ngrams, + cand_indexes, + num_to_predict, + num_filtered_tokens, + prefix_lm, +): + if prefix_lm: + # Find first index which is greater than the number of + # predictions. + first_gt_index = bisect.bisect_right( + cand_indexes, + [num_filtered_tokens - num_to_predict], + ) + # Then move one index before to get less than or equal to the + # number of predictions, handling not going below 0. + first_le_index = max(1, first_gt_index) - 1 + + tail_cand_indexes = cand_indexes[first_le_index:] + ngram_index = [ + tail_cand_indexes[i:] + for i in range(len(tail_cand_indexes)) + ] + else: + ngram_index = [cand_indexes[idx:idx + n] for n in ngrams] + return ngram_index + + def create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, @@ -194,15 +234,23 @@ def create_masked_lm_predictions(tokens, favor_longer_ngram=False, do_permutation=False, geometric_dist=False, - masking_style="bert"): + masking_style="bert", + sampling_style=SamplingStyle.POISSON, + prefix_lm=False): """Creates the predictions for the masked LM objective. 
Note: Tokens here are vocab ids and not text tokens.""" + if not isinstance(sampling_style, SamplingStyle): + sampling_style = SamplingStyle(sampling_style) + # Backward-compatibility + if geometric_dist: + sampling_style = SamplingStyle.GEOMETRIC cand_indexes = [] # Note(mingdachen): We create a list for recording if the piece is # the starting piece of current token, where 1 means true, so that # on-the-fly whole word masking is possible. token_boundary = [0] * len(tokens) + num_filtered_tokens = 0 for (i, token) in enumerate(tokens): if token == cls_id or token == sep_id: @@ -221,6 +269,7 @@ def create_masked_lm_predictions(tokens, cand_indexes.append([i]) if is_start_piece(vocab_id_to_token_dict[token]): token_boundary[i] = 1 + num_filtered_tokens += 1 output_tokens = list(tokens) @@ -231,11 +280,26 @@ def create_masked_lm_predictions(tokens, return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) - num_to_predict = min(max_predictions_per_seq, - max(1, int(round(len(tokens) * masked_lm_prob)))) + if ( + sampling_style is SamplingStyle.NORMAL + or sampling_style is SamplingStyle.UNSCALED_NORMAL + ): + # First, we get the center of our normal distribution from + # `max_ngrams`. Keeping the meaning of `max_ngrams` this way + # plays nicely with the other probability distributions in terms + # of math. + normal_mean = (max_ngrams + 1) / 2 + normal_std = ( + math.sqrt(normal_mean) + if sampling_style is not SamplingStyle.UNSCALED_NORMAL + else 1.0 + ) + # However, we do not want to bound the maximum length of + # n-grams. + max_ngrams = num_filtered_tokens - 1 ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) - if not geometric_dist: + if sampling_style is SamplingStyle.POISSON: # Note(mingdachen): # By default, we set the probilities to favor shorter ngram sequences. pvals = 1. / np.arange(1, max_ngrams + 1) @@ -243,14 +307,30 @@ def create_masked_lm_predictions(tokens, if favor_longer_ngram: pvals = pvals[::-1] - ngram_indexes = [] - for idx in range(len(cand_indexes)): - ngram_index = [] - for n in ngrams: - ngram_index.append(cand_indexes[idx:idx + n]) - ngram_indexes.append(ngram_index) + if prefix_lm: + # We only do one span searching loop anyway, so this does not + # matter in terms of random search. However, we do want to allow + # sequences greater than the mean ratio. + num_to_predict = max_predictions_per_seq - np_rng.shuffle(ngram_indexes) + ngram_index_indexes = np.array([0]) + else: + num_to_predict = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + + ngram_index_indexes = np.arange(len(cand_indexes)) + np_rng.shuffle(ngram_index_indexes) + + def get_ngram_indices_(idx): + return get_ngram_indices( + idx, + ngrams, + cand_indexes, + num_to_predict, + num_filtered_tokens, + prefix_lm, + ) + ngram_indexes = map(get_ngram_indices_, ngram_index_indexes) (masked_lms, masked_spans) = ([], []) covered_indexes = set() @@ -261,20 +341,39 @@ def create_masked_lm_predictions(tokens, continue # Note(mingdachen): # Skip current piece if they are covered in lm masking or previous ngrams. 
+ is_covered = False for index_set in cand_index_set[0]: for index in index_set: if index in covered_indexes: - continue + is_covered = True + break + if is_covered: + break + if is_covered: + continue - if not geometric_dist: + if sampling_style is SamplingStyle.POISSON: n = np_rng.choice(ngrams[:len(cand_index_set)], p=pvals[:len(cand_index_set)] / pvals[:len(cand_index_set)].sum(keepdims=True)) - else: + elif sampling_style is SamplingStyle.GEOMETRIC: # Sampling "n" from the geometric distribution and clipping it to # the max_ngrams. Using p=0.2 default from the SpanBERT paper # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1) n = min(np_rng.geometric(0.2), max_ngrams) + elif sampling_style is SamplingStyle.UNIFORM: + n = np_rng.choice(ngrams[:len(cand_index_set)]) + elif ( + sampling_style is SamplingStyle.NORMAL + or sampling_style is SamplingStyle.UNSCALED_NORMAL + ): + n = round(np.clip( + np_rng.normal(loc=normal_mean, scale=normal_std), + 1, + len(cand_index_set), + )) + else: + raise ValueError('unknown sampling style') index_set = sum(cand_index_set[n - 1], []) n -= 1 @@ -324,7 +423,8 @@ def create_masked_lm_predictions(tokens, label=[tokens[index] for index in index_set])) assert len(masked_lms) <= num_to_predict - np_rng.shuffle(ngram_indexes) + np_rng.shuffle(ngram_index_indexes) + ngram_indexes = map(get_ngram_indices_, ngram_index_indexes) select_indexes = set() if do_permutation: @@ -450,7 +550,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, prefixes[i], data_impl, splits_string, datasets_train_valid_test_num_samples[i], max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + seed, skip_warmup, binary_head, max_seq_length_dec, + dataset_type=dataset_type) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -522,6 +623,7 @@ def build_dataset(index, name): from megatron.data.bert_dataset import BertDataset from megatron.data.ict_dataset import ICTDataset from megatron.data.t5_dataset import T5Dataset + from megatron.data.ul2_dataset import UL2Dataset dataset = None if splits[index + 1] > splits[index]: # Get the pointer to the original doc-idx so we can set it later. 
@@ -553,13 +655,40 @@ def build_dataset(index, name): **kwargs ) elif dataset_type == DSET_TYPE_T5: + args = get_args() dataset = T5Dataset( indexed_dataset=indexed_dataset, masked_lm_prob=masked_lm_prob, max_seq_length_dec=max_seq_length_dec, short_seq_prob=short_seq_prob, + add_mask_tokens=args.add_mask_tokens, + pack_samples=args.pack_samples, **kwargs ) + elif dataset_type == DSET_TYPE_UL2: + args = get_args() + dataset = UL2Dataset( + indexed_dataset=indexed_dataset, + model_type=args.ul2_model_type, + denoiser_ratios=args.ul2_denoiser_ratios, + denoisers=args.ul2_denoisers, + mean_span_lengths=args.ul2_mean_span_lengths, + mask_ratios=args.ul2_mask_ratios, + add_mask_tokens=args.add_mask_tokens, + pack_samples=args.pack_samples, + denoiser_tokens={ + 'R': args.ul2_r_denoiser_token, + 'S': args.ul2_s_denoiser_token, + 'X': args.ul2_x_denoiser_token, + }, + scale_normal_std=args.ul2_scale_normal_std, + like_ul2r=args.ul2_like_ul2r, + pack_any=args.ul2_pack_any, + pack_repeat_prompt=args.ul2_pack_repeat_prompt, + max_seq_length_dec=max_seq_length_dec, + short_seq_prob=short_seq_prob, + **kwargs, + ) elif dataset_type == DSET_TYPE_BERT: dataset = BertDataset( indexed_dataset=indexed_dataset, @@ -714,15 +843,7 @@ def get_samples_mapping(indexed_dataset, print_rank_0(' > elasped time to build and save samples mapping ' '(seconds): {:4f}'.format( time.time() - start_time)) - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) - torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( - torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + dp_pp_barrier() # Load indexed dataset. 
print_rank_0(' > loading indexed mapping from {}'.format( @@ -735,3 +856,18 @@ def get_samples_mapping(indexed_dataset, samples_mapping.shape[0])) return samples_mapping + + +def dp_pp_barrier(): + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce( + counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size( + group=mpu.get_tensor_model_parallel_group()) + ) diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py index 4edf14207..944affb70 100644 --- a/megatron/data/decoder_packed_mtf_dataset.py +++ b/megatron/data/decoder_packed_mtf_dataset.py @@ -7,7 +7,7 @@ from megatron import print_rank_0, mpu, logging from megatron.data.blendable_dataset import BlendableDataset from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \ - get_train_valid_test_split_ + get_train_valid_test_split_, dp_pp_barrier from megatron.data.mtf_dataset import MTFDataset from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset @@ -437,15 +437,7 @@ def _build_index_mappings( print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping' ' (seconds): {:4f}'.format(time.time() - start_time)) - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) - torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( - torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + dp_pp_barrier() # Load mappings. start_time = time.time() diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index 0db1aa2fe..05944503d 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -15,16 +15,21 @@ """GPT style dataset.""" +import math import os import time import numpy as np import torch -from megatron import mpu, print_rank_0 +from megatron import print_rank_0 from megatron.data.blendable_dataset import BlendableDataset -from megatron.data.dataset_utils import get_datasets_weights_and_num_samples -from megatron.data.dataset_utils import get_train_valid_test_split_, get_split_by_range_ +from megatron.data.dataset_utils import ( + dp_pp_barrier, + get_datasets_weights_and_num_samples, + get_split_by_range_, + get_train_valid_test_split_, +) from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset @@ -237,7 +242,7 @@ def __init__(self, name, data_prefix, documents, indexed_dataset, assert np.max(documents) < indexed_dataset.sizes.shape[0] # Build index mappings. - self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.doc_idx, self.sample_idx, self.shuffle_idx = build_index_mappings( self.name, data_prefix, documents, self.indexed_dataset.sizes, num_samples, seq_length, seed) @@ -247,35 +252,50 @@ def __len__(self): return self.sample_idx.shape[0] - 1 def __getitem__(self, idx): - # Get the shuffled index. - idx = self.shuffle_idx[idx] - # Start and end documents and offsets. 
- doc_index_f = self.sample_idx[idx][0] - doc_index_l = self.sample_idx[idx + 1][0] - offset_f = self.sample_idx[idx][1] - offset_l = self.sample_idx[idx + 1][1] - # If we are within the same document, just extract the chunk. - if doc_index_f == doc_index_l: - sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1) - else: - # Otherwise, get the rest of the initial document. - sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f)] - # Loop over all in between documents and add the entire document. - for i in range(doc_index_f + 1, doc_index_l): - sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) - # And finally add the relevant portion of last document. - sample_list.append(self.indexed_dataset.get( - self.doc_idx[doc_index_l], - length=offset_l + 1)) - sample = np.concatenate(sample_list) - + sample_list = get_samples(self.indexed_dataset, self.doc_idx, + self.sample_idx, self.shuffle_idx, idx) + sample = np.concatenate(sample_list) return {'text': np.array(sample, dtype=np.int64)} -def _build_index_mappings(name, data_prefix, documents, sizes, +def get_samples(indexed_dataset, doc_idx, sample_idx, shuffle_idx, idx): + # Get the shuffled index. + idx = shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = sample_idx[idx][0] + doc_index_l = sample_idx[idx + 1][0] + offset_f = sample_idx[idx][1] + offset_l = sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + sample = indexed_dataset.get(doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1) + sample_list = [sample] + else: + # Otherwise, get the rest of the initial document. + sample_list = [indexed_dataset.get(doc_idx[doc_index_f], + offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(indexed_dataset.get(doc_idx[i])) + # And finally add the relevant portion of last document. + sample_list.append(indexed_dataset.get( + doc_idx[doc_index_l], + length=offset_l + 1)) + return sample_list + + +def _get_filename_prefix(data_prefix, name, num_samples, seq_length, seed): + _filename = data_prefix + _filename += '_{}_indexmap'.format(name) + _filename += '_{}ns'.format(num_samples) + _filename += '_{}sl'.format(seq_length) + _filename += '_{}s'.format(seed) + return _filename + + +def build_index_mappings(name, data_prefix, documents, sizes, num_samples, seq_length, seed, cutoff_last_epoch=0.95): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. @@ -290,11 +310,8 @@ def _build_index_mappings(name, data_prefix, documents, sizes, np_rng = np.random.RandomState(seed=seed) # Filename of the index mappings. 
- _filename = data_prefix - _filename += '_{}_indexmap'.format(name) - _filename += '_{}ns'.format(num_samples) - _filename += '_{}sl'.format(seq_length) - _filename += '_{}s'.format(seed) + _filename = _get_filename_prefix( + data_prefix, name, num_samples, seq_length, seed) doc_idx_filename = _filename + '_doc_idx.npy' sample_idx_filename = _filename + '_sample_idx.npy' shuffle_idx_filename = _filename + '_shuffle_idx.npy' @@ -379,15 +396,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes, print_rank_0(' > elasped time to build and save shuffle-idx mapping' ' (seconds): {:4f}'.format(time.time() - start_time)) - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) - torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( - torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + dp_pp_barrier() # Load mappings. start_time = time.time() @@ -409,6 +418,85 @@ def _build_index_mappings(name, data_prefix, documents, sizes, return doc_idx, sample_idx, shuffle_idx +def build_index_mappings_full_docs( + name, data_prefix, documents, sizes, + num_samples, seq_length, seed): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + _filename = _get_filename_prefix( + data_prefix, name, num_samples, seq_length, seed) + _filename += '_fd' # Full docs + doc_idx_filename = _filename + '_doc_idx.npy' + sample_idx_filename = _filename + '_sample_idx.npy' + shuffle_idx_filename = _filename + '_shuffle_idx.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0: + if (not os.path.isfile(doc_idx_filename)) or \ + (not os.path.isfile(sample_idx_filename)) or \ + (not os.path.isfile(shuffle_idx_filename)): + + print_rank_0( + ' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') + + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert sizes.dtype == np.int32 + # sample_idx = helpers.build_sample_idx_full_docs( + # sizes, doc_idx, seq_length, num_samples) + doc_idx, sample_idx = _build_sample_idx_full_docs( + sizes, documents, seq_length, num_samples, np_rng) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx and ' + 'sample-idx mapping (seconds): {:4f}'.format( + time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + num_samples_ = sample_idx.shape[0] + shuffle_idx = _build_shuffle_idx( + num_samples_, sample_idx.shape[0], np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0( + ' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + + dp_pp_barrier() + + # Load mappings. 
+ start_time = time.time() + print_rank_0(' > loading doc-idx mapping from {}'.format( + doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading sample-idx mapping from {}'.format( + sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading shuffle-idx mapping from {}'.format( + shuffle_idx_filename)) + shuffle_idx = np.load( + shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + sample_idx.shape[0])) + num_epochs = math.ceil(len(doc_idx) / len(documents)) + print_rank_0(' total number of epochs: {}'.format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + def _num_tokens(documents, sizes): """Total number of tokens in the dataset.""" return np.sum(sizes[documents]) @@ -494,6 +582,60 @@ def _build_sample_idx(sizes, doc_idx, seq_length, return sample_idx +def _build_sample_idx_full_docs( + sizes, documents, seq_length, num_samples, np_rng): + """Sample index mapping is a 1D array with sizes + [number-of-samples] where [..., 0] contains + the last index into `doc_idx`. + """ + sample_idx = np.zeros([num_samples], dtype=np.int32) + # If we only manage to pack one sample each time, we need this many + # epochs. + min_epochs = math.ceil(num_samples / len(documents)) + + doc_idx = _build_doc_idx(documents, min_epochs, np_rng, False) + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + while sample_index < num_samples: + + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + is_multiple = False + while remaining_seq_length != 0: + if doc_idx_index >= len(doc_idx): + # Extend doc-idx. + doc_idx = np.concatenate([doc_idx, _build_doc_idx( + documents, 1, np_rng, False)]) + + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] + # And add it to the current sequence. + remaining_seq_length -= doc_length + if is_multiple: + remaining_seq_length -= 1 + # If we have more than a full sequence, set remaining length + # to zero so we return from the while loop. + if remaining_seq_length <= 0: + remaining_seq_length = 0 + else: + # Otherwise, start from the begining of the next document. + doc_idx_index += 1 + is_multiple = True + # Record the sequence. + sample_idx[sample_index] = doc_idx_index + + # Reset to next document. + doc_idx_index += 1 + sample_index += 1 + + # `doc_idx_index` is already incremented by one. We want to include + # it because the last document index is inclusive. 
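+        # A short orientation note on how this 1-D mapping is consumed (see
+        # `get_samples` in t5_dataset.py): `sample_idx[i]` stores the index
+        # into `doc_idx` of the *last* document packed into sample `i`, so
+        # sample `i` covers the inclusive document range
+        # `doc_idx[sample_idx[i - 1] + 1 : sample_idx[i] + 1]`
+        # (and `doc_idx[0 : sample_idx[0] + 1]` for the first sample).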
+ return doc_idx[:doc_idx_index], sample_idx + + def _build_shuffle_idx(num_samples, total_size, np_rng): """Build the range [0, size) and shuffle.""" print(' > building shuffle index with split [0, {}) and [{}, {}) ' diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index d0d312544..0676e697f 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -99,8 +99,8 @@ def write_longs(f, a): 3: np.int16, 4: np.int32, 5: np.int64, - 6: np.float, - 7: np.double, + 7: np.single, + 6: np.double, 8: np.uint16 } @@ -273,7 +273,7 @@ class IndexedDatasetBuilder(object): np.int16: 2, np.int32: 4, np.int64: 8, - np.float: 4, + np.single: 4, np.double: 8 } diff --git a/megatron/data/t5_dataset.py b/megatron/data/t5_dataset.py index 42110b923..84d43fb7c 100644 --- a/megatron/data/t5_dataset.py +++ b/megatron/data/t5_dataset.py @@ -25,13 +25,35 @@ create_masked_lm_predictions, get_samples_mapping ) +from megatron.data.gpt_dataset import build_index_mappings_full_docs + + +class LengthExceededError(ValueError): + def __init__(self, msg=None): + if msg is None: + msg = ( + 'The sequence input became too long. ' + 'Try to increase `--seq-length` or `--encoder-seq-length`.' + ) + super().__init__(msg) + + +class DecoderLengthExceededError(ValueError): + def __init__(self, msg=None): + if msg is None: + msg = ( + 'The sequence input for the decoder became too long. ' + 'Try to increase `--decoder-seq-length`.' + ) + super().__init__(msg) + class T5Dataset(torch.utils.data.Dataset): def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, max_seq_length_dec, - short_seq_prob, seed): + short_seq_prob, add_mask_tokens, pack_samples, seed): # Params to store. self.name = name @@ -42,17 +64,33 @@ def __init__(self, name, indexed_dataset, data_prefix, # Dataset. self.indexed_dataset = indexed_dataset - - # Build the samples mapping. - self.samples_mapping = get_samples_mapping(self.indexed_dataset, - data_prefix, - num_epochs, - max_num_samples, - self.max_seq_length - 2, # account for added tokens - short_seq_prob, - self.seed, - self.name, - False) + self.pack_samples = pack_samples + + # Minimum number of tokens added: BOS and EOS. + min_added_tokens = 2 + if self.pack_samples: + ( + self.doc_idx, + self.sample_idx, + self.shuffle_idx, + ) = build_index_mappings_full_docs( + self.name, data_prefix, self.indexed_dataset.get_doc_idx()[:-1], + self.indexed_dataset.sizes, max_num_samples, + self.max_seq_length - min_added_tokens, self.seed) + else: + # Build the samples mapping. + self.samples_mapping = get_samples_mapping( + self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + # account for added tokens + self.max_seq_length - min_added_tokens, + short_seq_prob, + self.seed, + self.name, + False, + ) # Vocab stuff. 
tokenizer = get_tokenizer() @@ -64,31 +102,76 @@ def __init__(self, name, indexed_dataset, data_prefix, self.pad_id = tokenizer.pad self.bos_id = tokenizer.bos_token_id self.eos_id = tokenizer.eos_token_id - self.sentinel_tokens = tokenizer.additional_special_tokens_ids - assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script" + if add_mask_tokens: + self.sentinel_tokens = tokenizer.additional_special_tokens_ids + assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script" + else: + self.sentinel_tokens = None def __len__(self): - return self.samples_mapping.shape[0] + if self.pack_samples: + return self.sample_idx.shape[0] + else: + return self.samples_mapping.shape[0] def __getitem__(self, idx): - - start_index, end_index, seq_length = self.samples_mapping[idx] - sample = [] - for index in range(start_index, end_index): - sample.append(self.indexed_dataset[index]) # Note that this rng state should be numpy and not python since # python randint is inclusive whereas the numpy one is exclusive. np_rng = np.random.RandomState(seed=(self.seed + idx)) - return build_training_sample(sample, seq_length, - self.max_seq_length, # needed for padding - self.max_seq_length_dec, - self.vocab_id_list, - self.vocab_id_to_token_dict, - self.cls_id, self.sep_id, - self.mask_id, self.pad_id, - self.masked_lm_prob, np_rng, - self.bos_id, self.eos_id, - self.sentinel_tokens) + if self.pack_samples: + samples_dict = self._pack_samples(np_rng, idx) + else: + start_index, end_index, seq_length = self.samples_mapping[idx] + sample = [] + for index in range(start_index, end_index): + sample.append(self.indexed_dataset[index]) + samples_dict = build_training_sample( + sample, seq_length, + self.max_seq_length, # needed for padding + self.max_seq_length_dec, self.vocab_id_list, + self.vocab_id_to_token_dict, self.cls_id, self.sep_id, + self.mask_id, self.pad_id, self.masked_lm_prob, np_rng, + self.bos_id, self.eos_id, self.sentinel_tokens) + return samples_dict + + def _pack_samples(self, np_rng, idx): + samples = get_samples(self.indexed_dataset, self.doc_idx, + self.sample_idx, self.shuffle_idx, idx) + samples_dict = create_samples_dict( + self.max_seq_length, self.max_seq_length_dec) + prev_len = 0 + prev_len_dec = 0 + + for sample in samples: + remaining_seq_len = self.max_seq_length - prev_len + seq_length = min(remaining_seq_len, len(sample)) + + result_sample = build_training_sample( + [sample], seq_length, + self.max_seq_length, # needed for padding + self.max_seq_length_dec, self.vocab_id_list, + self.vocab_id_to_token_dict, self.cls_id, self.sep_id, + self.mask_id, self.pad_id, self.masked_lm_prob, np_rng, + self.bos_id, self.eos_id, self.sentinel_tokens) + maybe_lens = update_samples_dict( + samples_dict, + result_sample, + self.max_seq_length, + self.max_seq_length_dec, + prev_len, + prev_len_dec, + self.pad_id, + ) + if maybe_lens is None: + # We are exceeding our sequence length already. + break + + len_enc, len_dec = maybe_lens + prev_len += len_enc + prev_len_dec += len_dec + + add_final_padding(samples_dict, prev_len, prev_len_dec, self.pad_id) + return samples_dict def build_training_sample(sample, target_seq_length, @@ -104,6 +187,8 @@ def build_training_sample(sample, target_seq_length, target_seq_length: Desired sequence length. max_seq_length: Maximum length of the sequence. All values are padded to this length. + max_seq_length_dec: Maximum length of the decoder input sequence. 
All + values are padded to this length. vocab_id_list: List of vocabulary ids. Used to pick a random id. vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. cls_id: Start of example id. @@ -157,29 +242,42 @@ def build_training_sample(sample, target_seq_length, return train_sample -def pad_and_convert_to_numpy(tokens, masked_positions, - masked_labels, pad_id, - max_seq_length, max_seq_length_dec, - masked_spans=None, bos_id=None, - eos_id=None, sentinel_tokens=None): - """Pad sequences and convert them to numpy.""" +def merge_subsequent_masks(tokens, masked_spans=None, bos_id=None, + eos_id=None, sentinel_tokens=None, prefix_lm=False): + if prefix_lm: + assert len(masked_spans) <= 1, \ + 'Received more than one masked span for PrefixLM masking' + elif sentinel_tokens is not None: + sentinel_tokens = collections.deque(sentinel_tokens) + + insert_mask_tokens = not prefix_lm and sentinel_tokens is not None - sentinel_tokens = collections.deque(sentinel_tokens) t5_input = [] (t5_decoder_in, t5_decoder_out) = ([bos_id], []) (start_index, end_index) = (0, None) for span in masked_spans: - flag = sentinel_tokens.popleft() - - # Append the same tokens in decoder input and output - t5_decoder_in.append(flag) + end_index = span.index[0] + # The part of the sequence that is visible before the masked + # span starts. Starting from beginning or end of last masked + # span. + before_mask = tokens[start_index:end_index] + + if insert_mask_tokens: + flag = sentinel_tokens.popleft() + + # Append the same tokens in decoder input and output + t5_decoder_in.append(flag) + t5_decoder_out.append(flag) + elif not prefix_lm: + # Append visible part of input sequence. + t5_decoder_in.extend(before_mask) + t5_decoder_out.extend(before_mask) t5_decoder_in.extend(span.label) - t5_decoder_out.append(flag) t5_decoder_out.extend(span.label) - end_index = span.index[0] - t5_input.extend(tokens[start_index: end_index]) - t5_input.append(flag) + t5_input.extend(before_mask) + if insert_mask_tokens: + t5_input.append(flag) # the next start index is the token after the last span token start_index = span.index[-1] + 1 @@ -189,6 +287,19 @@ def pad_and_convert_to_numpy(tokens, masked_positions, # Add the remaining tokens to the t5 input t5_input.extend(tokens[start_index:]) + return t5_input, t5_decoder_in, t5_decoder_out + + +def pad_and_convert_to_numpy(tokens, masked_positions, + masked_labels, pad_id, + max_seq_length, max_seq_length_dec, + masked_spans=None, bos_id=None, + eos_id=None, sentinel_tokens=None, + prefix_lm=False): + """Pad sequences and convert them to numpy.""" + + t5_input, t5_decoder_in, t5_decoder_out = merge_subsequent_masks( + tokens, masked_spans, bos_id, eos_id, sentinel_tokens, prefix_lm) # assert (len(t5_input) - len(masked_spans)) + \ # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens) @@ -198,7 +309,8 @@ def pad_and_convert_to_numpy(tokens, masked_positions, # Encoder-side padding mask. num_tokens = len(t5_input) padding_length = max_seq_length - num_tokens - assert padding_length >= 0 + if padding_length < 0: + raise LengthExceededError() assert len(masked_positions) == len(masked_labels) # Tokens.. @@ -208,7 +320,8 @@ def pad_and_convert_to_numpy(tokens, masked_positions, # Decoder-side padding mask. 
num_tokens_dec = len(t5_decoder_in) padding_length_dec = max_seq_length_dec - num_tokens_dec - assert padding_length_dec >= 0 + if padding_length_dec < 0: + raise DecoderLengthExceededError() filler_dec = [pad_id] * padding_length_dec tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64) @@ -268,3 +381,139 @@ def make_history_mask_3d(block): history_mask = (arange[None, ] <= arange[:, None])[None, ] history_mask = history_mask.expand(batch, length, length) return history_mask + + +def get_samples(indexed_dataset, doc_idx, sample_idx, shuffle_idx, idx): + # Get the shuffled index. + idx = shuffle_idx[idx] + # Start and end documents. + if idx == 0: + doc_index_f = 0 + else: + doc_index_f = sample_idx[idx - 1] + 1 + doc_index_l = sample_idx[idx] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + sample = indexed_dataset.get(doc_idx[doc_index_f]) + sample_list = [sample] + else: + # Otherwise, get the rest of the initial document. + sample_list = [indexed_dataset.get(doc_idx[doc_index_f])] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(indexed_dataset.get(doc_idx[i])) + # And finally add the relevant portion of last document. + sample_list.append(indexed_dataset.get( + doc_idx[doc_index_l])) + return sample_list + + +def create_samples_dict(max_seq_length, max_seq_length_dec): + samples_dict = { + 'text_enc': np.empty((max_seq_length,), dtype=np.int64), + 'text_dec': np.empty( + (max_seq_length_dec,), dtype=np.int64), + 'labels': np.empty( + (max_seq_length_dec,), dtype=np.int64), + 'loss_mask': np.zeros( + (max_seq_length_dec,), dtype=np.int64), + 'truncated': 0, + 'enc_mask': np.zeros( + (max_seq_length, max_seq_length), + dtype=np.int64, + ), + 'dec_mask': np.zeros( + (max_seq_length_dec, max_seq_length_dec), + dtype=np.int64, + ), + 'enc_dec_mask': np.zeros( + (max_seq_length_dec, max_seq_length), + dtype=np.int64, + ), + } + return samples_dict + + +def _remove_padding(result_sample, pad_id): + # Remove padding + padding_start = np.argmax(result_sample['text_enc'] == pad_id) + padding_start_dec = np.argmax(result_sample['text_dec'] == pad_id) + if padding_start == 0 and padding_start_dec == 0: + return + elif padding_start == 0: + padding_start = None + elif padding_start_dec == 0: + padding_start_dec = None + + result_sample['text_enc'] = result_sample['text_enc'][:padding_start] + for key in ['text_dec', 'labels', 'loss_mask']: + result_sample[key] = result_sample[key][:padding_start_dec] + result_sample['enc_mask'] = \ + result_sample['enc_mask'][:padding_start, :padding_start] + result_sample['enc_dec_mask'] = \ + result_sample['enc_dec_mask'][:padding_start_dec, :padding_start] + result_sample['dec_mask'] = \ + result_sample['dec_mask'][:padding_start_dec, :padding_start_dec] + + +def get_lens(key, prev_len, prev_len_dec, len_enc, len_dec): + assert key != 'enc_dec_mask' + if key in ['text_enc', 'enc_mask']: + offset = prev_len + length = len_enc + else: + offset = prev_len_dec + length = len_dec + return offset, length + + +def update_samples_dict( + samples_dict, + result_sample, + max_seq_len, + max_seq_len_dec, + prev_len, + prev_len_dec, + pad_id, +): + _remove_padding(result_sample, pad_id) + + len_enc = len(result_sample['text_enc']) + len_dec = len(result_sample['text_dec']) + + if ( + prev_len + len_enc > max_seq_len + or prev_len_dec + len_dec > max_seq_len_dec + ): + return None + + for key in ['text_enc', 'text_dec', 
'labels']: + curr_sample = result_sample[key] + offset, length = get_lens( + key, prev_len, prev_len_dec, len_enc, len_dec) + samples_dict[key][offset:offset + length] = curr_sample + + samples_dict['loss_mask'][ + prev_len_dec:prev_len_dec + len_dec, + ] += result_sample['loss_mask'] + samples_dict['enc_mask'][ + prev_len:prev_len + len_enc, + prev_len:prev_len + len_enc, + ] += result_sample['enc_mask'] + samples_dict['dec_mask'][ + prev_len_dec:prev_len_dec + len_dec, + prev_len_dec:prev_len_dec + len_dec, + ] += result_sample['dec_mask'] + samples_dict['enc_dec_mask'][ + prev_len_dec:prev_len_dec + len_dec, + prev_len:prev_len + len_enc, + ] += result_sample['enc_dec_mask'] + + samples_dict['truncated'] += result_sample['truncated'] + return len_enc, len_dec + + +def add_final_padding(samples_dict, prev_len, prev_len_dec, pad_id): + samples_dict['text_enc'][prev_len:] = pad_id + samples_dict['text_dec'][prev_len_dec:] = pad_id + samples_dict['labels'][prev_len_dec:] = -1 diff --git a/megatron/data/ul2_dataset.py b/megatron/data/ul2_dataset.py new file mode 100644 index 000000000..e8d3862ed --- /dev/null +++ b/megatron/data/ul2_dataset.py @@ -0,0 +1,558 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""UL2-style dataset.""" + +from collections import ChainMap +import math + +import numpy as np +import torch + +from megatron import get_tokenizer +from megatron.data.dataset_utils import ( + create_masked_lm_predictions, + get_samples_mapping, + SamplingStyle +) +from megatron.data.gpt_dataset import build_index_mappings_full_docs +from megatron.data.t5_dataset import ( + add_final_padding, + create_samples_dict as t5_create_samples_dict, + get_samples, + LengthExceededError, + make_history_mask, + merge_subsequent_masks, + pad_and_convert_to_numpy, + update_samples_dict, +) +from megatron.enums import UL2ModelType + + +def is_decoder_only(ul2_model_type): + """Return whether we use a decoder-only model.""" + assert isinstance(ul2_model_type, UL2ModelType) + return ul2_model_type is not UL2ModelType.ENCODER_DECODER + + +def is_prefix_lm(ul2_model_type): + """Return whether we use a non-causal decoder-only model.""" + assert isinstance(ul2_model_type, UL2ModelType) + return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER + + +class UL2Dataset(torch.utils.data.Dataset): + def __init__(self, name, indexed_dataset, data_prefix, + num_epochs, max_num_samples, model_type, + denoiser_ratios, denoisers, mean_span_lengths, + mask_ratios, add_mask_tokens, pack_samples, + denoiser_tokens, scale_normal_std, like_ul2r, + pack_any, pack_repeat_prompt, + max_seq_length, max_seq_length_dec, + short_seq_prob, seed): + super().__init__() + + if denoiser_ratios is None: + # Uniform distribution by default. 
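+            # For example, with the default seven denoiser configurations
+            # (R, R, S, X, X, X, X) each objective is selected with
+            # probability 1/7.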
+ denoiser_ratios = [1 / len(denoisers)] * len(denoisers) + + assert ( + len(denoiser_ratios) == len(denoisers) + == len(mean_span_lengths) == len(mask_ratios) + ), ( + 'some UL2 configurations do not correspond to the amount of ' + 'denoising objectives' + ) + + # Params to store. + self.name = name + self.seed = seed + self.max_seq_length = max_seq_length + self.max_seq_length_dec = max_seq_length_dec + + self.model_type = model_type + self.denoiser_ratios = [ + denoiser_ratio / sum(denoiser_ratios) + for denoiser_ratio in denoiser_ratios + ] + self.denoisers = [denoiser.upper() for denoiser in denoisers] + self.mean_span_lengths = mean_span_lengths + self.mask_ratios = mask_ratios + self.scale_normal_std = scale_normal_std + self.like_ul2r = like_ul2r + + # Dataset. + self.indexed_dataset = indexed_dataset + self.pack_samples = pack_samples + self.pack_any = pack_any + self.repeat_prompt = pack_repeat_prompt + + # Minimum number of tokens added: BOS and EOS. + min_added_tokens = 2 + if is_decoder_only(model_type): + # Here we also add a SEP token. + min_added_tokens += 1 + + # Build the samples mapping. + if self.pack_samples: + ( + self.doc_idx, + self.sample_idx, + self.shuffle_idx, + ) = build_index_mappings_full_docs( + self.name, data_prefix, + self.indexed_dataset.get_doc_idx()[:-1], + self.indexed_dataset.sizes, max_num_samples, + self.max_seq_length - min_added_tokens, self.seed) + else: + self.samples_mapping = get_samples_mapping( + self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + # account for added tokens + self.max_seq_length - min_added_tokens, + short_seq_prob, + self.seed, + self.name, + False, + ) + + # Vocab stuff. + tokenizer = get_tokenizer() + # Some tokenizers split their vocabularies. Here we handle both + # cases. + if ( + hasattr(tokenizer, 'tokenizer') + and hasattr(tokenizer.tokenizer, 'special_tokens_decoder') + ): + inv_vocab = ChainMap( + tokenizer.inv_vocab, + tokenizer.tokenizer.special_tokens_decoder, + ) + vocab = ChainMap( + tokenizer.vocab, tokenizer.tokenizer.special_tokens) + else: + inv_vocab = tokenizer.inv_vocab + vocab = tokenizer.vocab + self.vocab_id_list = list(inv_vocab.keys()) + self.vocab_id_to_token_dict = inv_vocab + # Replace empty string tokens with `None` – we want to ignore + # those. + self.cls_ids = { + denoiser: vocab[token] if token else None + for (denoiser, token) in denoiser_tokens.items() + } + # cls_token = self.vocab_id_to_token_dict[tokenizer.cls] + # if cls_token not in self.cls_ids: + # self.cls_ids[cls_token] = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + self.bos_id = tokenizer.bos_token_id + self.eos_id = tokenizer.eos_token_id + + if add_mask_tokens: + # Filter out denoiser tokens. + self.sentinel_tokens = [ + token + for token in tokenizer.additional_special_tokens_ids + if token not in self.cls_ids.values() + ] + assert len(self.sentinel_tokens) > 0, \ + "Provide the argument --vocab-extra-ids 100 to the script" + else: + self.sentinel_tokens = None + + def __len__(self): + if self.pack_samples: + return self.sample_idx.shape[0] + else: + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. 
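+        # Seeding with `self.seed + idx` keeps the masking drawn for a given
+        # sample index reproducible, independent of epoch or data-loader
+        # worker ordering.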
+ np_rng = np.random.RandomState(seed=(self.seed + idx)) + # Denoiser selection + denoiser_index = np_rng.choice( + np.arange(len(self.denoisers)), + p=self.denoiser_ratios, + ) + + if self.pack_samples: + samples_dict = self._pack_samples(np_rng, idx, denoiser_index) + else: + start_index, end_index, seq_length = self.samples_mapping[idx] + sample = [] + for index in range(start_index, end_index): + sample.append(self.indexed_dataset[index]) + samples_dict = build_training_sample( + sample, seq_length, + self.max_seq_length, # needed for padding + self.max_seq_length_dec, self.vocab_id_list, + self.vocab_id_to_token_dict, self.cls_ids, self.sep_id, + self.mask_id, self.pad_id, self.model_type, denoiser_index, + self.denoisers, self.mean_span_lengths, + self.mask_ratios, self.scale_normal_std, self.like_ul2r, + np_rng, self.bos_id, self.eos_id, self.sentinel_tokens) + return samples_dict + + def _pack_samples(self, np_rng, idx, denoiser_index): + samples = get_samples(self.indexed_dataset, self.doc_idx, + self.sample_idx, self.shuffle_idx, idx) + samples_dict = create_samples_dict( + self.max_seq_length, self.max_seq_length_dec, self.model_type) + prev_len = 0 + prev_len_dec = 0 + cls_ids = self.cls_ids + + for sample in samples: + remaining_seq_len = self.max_seq_length - prev_len + seq_length = min(remaining_seq_len, len(sample)) + + result_sample = build_training_sample( + [sample], seq_length, + self.max_seq_length, # needed for padding + self.max_seq_length_dec, self.vocab_id_list, + self.vocab_id_to_token_dict, cls_ids, self.sep_id, + self.mask_id, self.pad_id, self.model_type, denoiser_index, + self.denoisers, self.mean_span_lengths, + self.mask_ratios, self.scale_normal_std, self.like_ul2r, + np_rng, self.bos_id, self.eos_id, self.sentinel_tokens) + if is_decoder_only(self.model_type): + maybe_lens = update_samples_dict_decoder_only( + samples_dict, + result_sample, + self.max_seq_length, + prev_len, + self.pad_id, + ) + else: + maybe_lens = update_samples_dict( + samples_dict, + result_sample, + self.max_seq_length, + self.max_seq_length_dec, + prev_len, + prev_len_dec, + self.pad_id, + ) + if maybe_lens is None: + # We are exceeding our sequence length already. + break + + if is_decoder_only(self.model_type): + len_enc = maybe_lens + else: + len_enc, len_dec = maybe_lens + prev_len_dec += len_dec + prev_len += len_enc + + if not self.repeat_prompt and not self.pack_any: + cls_ids = {self.denoisers[denoiser_index]: None} + + if self.pack_any: + denoiser_index = np_rng.choice( + np.arange(len(self.denoisers)), + p=self.denoiser_ratios, + ) + + if is_decoder_only(self.model_type): + samples_dict['text'][prev_len:] = self.pad_id + samples_dict['labels'][prev_len:] = -1 + else: + add_final_padding( + samples_dict, prev_len, prev_len_dec, self.pad_id) + + return samples_dict + + +def build_training_sample(sample, target_seq_length, + max_seq_length, max_seq_length_dec, + vocab_id_list, vocab_id_to_token_dict, + cls_ids, sep_id, mask_id, pad_id, + model_type, denoiser_index, + denoisers, mean_span_lengths, + mask_ratios, scale_normal_std, like_ul2r, + np_rng, bos_id=None, eos_id=None, + sentinel_tokens=None): + """Build training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + max_seq_length_dec: Maximum length of the decoder input sequence. All + values are padded to this length. 
+ vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_ids: Start of example ids. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + model_type: What type of model is used. + denoiser_index: Index of selected denoising objective. + denoisers: What type of UL2 denoising objective the other UL2 + configurations refer to. + mean_span_lengths: Mean length for sampling span lengths. Numbers < 1 + indicate a mean length of the sequence length times that number. + mask_ratios: Ratio of masked token in the full sequence. + scale_normal_std: Whether to scale the standard deviation when using a + normal distribution for span length sampling. + like_ul2r: Whether to use the updated implementation as specified in + the UL2R paper. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + bos_id: start of decoder example id + eos_id: end of generation id + sentinel_tokens: unique value to be substituted for every replaced span + """ + add_mask_tokens = sentinel_tokens is not None + + # Denoiser selection + denoiser = denoisers[denoiser_index] + masked_lm_prob = mask_ratios[denoiser_index] + + assert target_seq_length <= max_seq_length + + # flatten sentences into one list + tokens = [token for sentence in sample for token in sentence] + + # Prepend objective token. + cls_id = cls_ids.get(denoiser, False) + if cls_id is False: + raise ValueError('unknown denoiser') + + # If objective token is `None`, ignore it. + if cls_id is not None: + tokens = [cls_id] + tokens + + max_num_tokens = target_seq_length + if ( + is_decoder_only(model_type) + and denoiser != 'S' + and add_mask_tokens + ): + # Keep space for repeated `extra_id` tokens; not the most data + # efficient since we calculate this based on the maximum number + # of possible `extra_id` tokens. + safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob)) + truncated = len(tokens) > safe_max_seq_len + tokens = tokens[:safe_max_seq_len] + else: + # If we are S-denoising, we know three tokens are going to be + # added: `bos`, `sep`, and `eos`. Same when not adding mask + # tokens. + if ( + is_decoder_only(model_type) and denoiser == 'S' + or not add_mask_tokens + ): + max_num_tokens -= 3 + + # If we have a decoder-only model and do not add mask tokens, we + # basically duplicate the sequence. So cut the maximum length in + # half. + if ( + is_decoder_only(model_type) + and denoiser != 'S' + and not add_mask_tokens + ): + max_num_tokens = max_num_tokens // 2 + + # Truncate to `target_sequence_length`. + truncated = len(tokens) > max_num_tokens + tokens = tokens[:max_num_tokens] + + # Masking. + mean_ngrams = mean_span_lengths[denoiser_index] + if mean_ngrams < 1: + # Ensure we always obtain at least one `max_ngrams`. 
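+            # e.g. (illustrative numbers only): a 512-token sequence with a
+            # configured mean span length of 0.25 gives mean_ngrams = 128
+            # and, just below, max_ngrams = 2 * 128 - 1 = 255; the normal
+            # span-length sampling in create_masked_lm_predictions is then
+            # centred on (max_ngrams + 1) / 2 = 128, i.e. back on the
+            # requested mean.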
+ mean_ngrams = max(1, round(len(tokens) * mean_ngrams)) + max_ngrams = mean_ngrams * 2 - 1 + + if denoiser == 'R' or denoiser == 'X': + if like_ul2r: + sampling_style = SamplingStyle.UNIFORM + elif scale_normal_std: + sampling_style = SamplingStyle.NORMAL + else: + sampling_style = SamplingStyle.UNSCALED_NORMAL + prefix_lm = False + max_predictions_per_seq = len(tokens) - 1 + elif denoiser == 'S': + sampling_style = SamplingStyle.UNIFORM + prefix_lm = True + max_predictions_per_seq = min( + round(masked_lm_prob * len(tokens)) * 2 - 1, + len(tokens) - 1, + ) + else: + raise ValueError('unknown denoiser') + + # Ensure we always have at least one prediction. + max_predictions_per_seq = max(1, max_predictions_per_seq) + ( + tokens, masked_positions, masked_labels, _, masked_spans, + ) = create_masked_lm_predictions( + tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, + cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, + max_ngrams=max_ngrams, masking_style="t5", + sampling_style=sampling_style, prefix_lm=prefix_lm, + ) + + if is_decoder_only(model_type): + # Concatenate to one sequence. + tokens_enc, tokens_dec_in, labels = merge_subsequent_masks( + tokens, masked_spans, bos_id, eos_id, sentinel_tokens, prefix_lm) + + # Move EOS tokens to end of sequence. + while tokens_enc and tokens_enc[-1] == eos_id: + del tokens_enc[-1] + tokens_dec_in.append(eos_id) + labels.append(eos_id) + + # Move BOS token to start of sequence. + tokens_dec_in = tokens_dec_in[1:] + if not add_mask_tokens: + # Do not reproduce objective token when not using masking + # tokens. + tokens_dec_in = tokens_dec_in[1:] + labels = labels[1:] + + num_labels = len(labels) + + # Do not add separator token if S-denoising. + separator = [sep_id] if denoiser != 'S' else [] + tokens = ( + [bos_id] + + tokens_enc + + separator + + tokens_dec_in + ) + + # Pad and convert to NumPy. + padding_length = max_seq_length - len(tokens) + if padding_length < 0: + raise LengthExceededError() + filler = [pad_id] * padding_length + + tokens = np.array(tokens + filler, dtype=np.int64) + labels = np.array(( + tokens_enc + + separator + + labels + + filler + ), dtype=np.int64) + + loss_mask = np.zeros(len(tokens), dtype=np.int64) + labels_start_neg_index = -(num_labels + padding_length) + labels_end_neg_index = -padding_length if padding_length > 0 else None + loss_mask[labels_start_neg_index:labels_end_neg_index] = 1 + + dec_mask = make_history_mask(tokens) + if is_prefix_lm(model_type): + dec_mask[:labels_start_neg_index, :labels_start_neg_index] = 1 + + train_sample = { + 'text': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'truncated': int(truncated), + 'dec_mask': dec_mask, + } + else: + # Padding. 
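+        # Same path as T5Dataset.build_training_sample, except that
+        # `prefix_lm` is forwarded so S-denoising (at most one masked span,
+        # no sentinel tokens) is handled by merge_subsequent_masks inside
+        # pad_and_convert_to_numpy.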
+ ( + tokens_enc, tokens_dec_in, labels, enc_mask, + dec_mask, enc_dec_mask, loss_mask, + ) = pad_and_convert_to_numpy(tokens, masked_positions, + masked_labels, pad_id, max_seq_length, + max_seq_length_dec, masked_spans, + bos_id, eos_id, sentinel_tokens, + prefix_lm) + + train_sample = { + 'text_enc': tokens_enc, + 'text_dec': tokens_dec_in, + 'labels': labels, + 'loss_mask': loss_mask, + 'truncated': int(truncated), + 'enc_mask': enc_mask, + 'dec_mask': dec_mask, + 'enc_dec_mask': enc_dec_mask, + } + return train_sample + + +def create_samples_dict(max_seq_length, max_seq_length_dec, model_type): + if is_decoder_only(model_type): + samples_dict = { + 'text': np.empty((max_seq_length,), dtype=np.int64), + 'labels': np.empty((max_seq_length,), dtype=np.int64), + 'loss_mask': np.zeros((max_seq_length,), dtype=np.int64), + 'truncated': 0, + 'dec_mask': np.zeros( + (max_seq_length, max_seq_length), + dtype=np.int64, + ), + } + else: + samples_dict = t5_create_samples_dict( + max_seq_length, max_seq_length_dec) + return samples_dict + + +def _remove_padding(result_sample, pad_id): + # Remove padding + padding_start = np.argmax(result_sample['text'] == pad_id) + if padding_start == 0: + return + result_sample['text'] = result_sample['text'][:padding_start] + for key in ['labels', 'loss_mask']: + result_sample[key] = result_sample[key][:padding_start] + result_sample['dec_mask'] = \ + result_sample['dec_mask'][:padding_start, :padding_start] + + +def update_samples_dict_decoder_only( + samples_dict, + result_sample, + max_seq_len, + prev_len, + pad_id, +): + _remove_padding(result_sample, pad_id) + len_enc = len(result_sample['text']) + + if prev_len + len_enc > max_seq_len: + return None + + for key in ['text', 'labels']: + curr_sample = result_sample[key] + samples_dict[key][prev_len:prev_len + len_enc] = curr_sample + + samples_dict['loss_mask'][ + prev_len:prev_len + len_enc, + ] += result_sample['loss_mask'] + samples_dict['dec_mask'][ + prev_len:prev_len + len_enc, + prev_len:prev_len + len_enc, + ] += result_sample['dec_mask'] + + samples_dict['truncated'] += result_sample['truncated'] + return len_enc diff --git a/megatron/enums.py b/megatron/enums.py index 90d00a071..c749ab915 100644 --- a/megatron/enums.py +++ b/megatron/enums.py @@ -33,3 +33,9 @@ class PositionEmbeddingType(enum.Enum): rotary = 1 absolute = 2 alibi = 3 + xpos = 4 + +class UL2ModelType(enum.Enum): + ENCODER_DECODER = 'ED' + NON_CAUSAL_DECODER = 'ND' + CAUSAL_DECODER = 'CD' diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 07192e2bf..5fd2ceab4 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -188,7 +188,6 @@ def forward_fused_softmax(self, input, mask): if self.attn_mask_type == AttnMaskType.causal: assert sq == sk, "causal mask is only for self attention" - assert mask is None, "Mask is silently ignored due to the use of a custom kernel" # input is 3D tensor (attn_batches, sq, sk) input = input.view(-1, sq, sk) @@ -214,8 +213,7 @@ def forward_torch_softmax(self, input, mask): if self.scale is not None: input = input * self.scale - if self.attn_mask_type == AttnMaskType.causal: - assert mask is None + if self.attn_mask_type == AttnMaskType.causal and mask is None: assert input.shape[2] == input.shape[3] mask = self.get_causal_mask(input.shape[2]) diff --git a/megatron/model/glu_activations.py b/megatron/model/glu_activations.py index c479d9683..5f30ef846 100644 --- a/megatron/model/glu_activations.py +++ b/megatron/model/glu_activations.py @@ -4,6 
+4,7 @@ from megatron import logging from megatron.model.utils import log_debug_usage +from megatron import mpu logger = logging.get_logger(__name__) @@ -38,10 +39,133 @@ def __init__(self): super().__init__(F.silu) +class _T5GLUBase(nn.Module): + def __init__( + self, + in_features, + out_features, + activation_fn=torch.sigmoid, + bias=False, + gather_output=True, + init_method=torch.nn.init.xavier_normal_, + ): + super().__init__() + self.linear = mpu.ColumnParallelLinear( + in_features, + out_features, + bias=bias, + gather_output=gather_output, + init_method=init_method, + ) + self.nonlinear = mpu.ColumnParallelLinear( + in_features, + out_features, + bias=bias, + gather_output=gather_output, + init_method=init_method, + ) + self.activation_fn = activation_fn + + def forward(self, x): + output = self.linear(x)[0] * self.activation_fn(self.nonlinear(x)[0]) + return output, None + + +class T5LiGLU(_T5GLUBase): + def __init__( + self, + in_features, + out_features, + bias=False, + gather_output=True, + init_method=torch.nn.init.xavier_normal_, + ): + super().__init__( + in_features, + out_features, + activation_fn=nn.Identity(), + bias=bias, + gather_output=gather_output, + init_method=init_method, + ) + + +class T5GEGLU(_T5GLUBase): + def __init__( + self, + in_features, + out_features, + bias=False, + gather_output=True, + init_method=torch.nn.init.xavier_normal_, + ): + super().__init__( + in_features, + out_features, + activation_fn=F.gelu, + bias=bias, + gather_output=gather_output, + init_method=init_method, + ) + + +class T5ReGLU(_T5GLUBase): + def __init__( + self, + in_features, + out_features, + bias=False, + gather_output=True, + init_method=torch.nn.init.xavier_normal_, + ): + super().__init__( + in_features, + out_features, + activation_fn=F.relu, + bias=bias, + gather_output=gather_output, + init_method=init_method, + ) + + +class T5SwiGLU(_T5GLUBase): + def __init__( + self, + in_features, + out_features, + bias=False, + gather_output=True, + init_method=torch.nn.init.xavier_normal_, + ): + super().__init__( + in_features, + out_features, + activation_fn=F.silu, + bias=bias, + gather_output=gather_output, + init_method=init_method, + ) + + +def replaces_linear(wrapped_glu_act): + """Return whether the GLU activation wrapped by `log_debug_usage` + contains a type. 
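+ In that case the activation is a class that builds its own pair of parallel linear layers and replaces the usual `ColumnParallelLinear` projection.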
+ """ + return ( + hasattr(wrapped_glu_act, '__closure__') + and wrapped_glu_act.__closure__ + and isinstance(wrapped_glu_act.__closure__[0].cell_contents, type) + ) + + liglu = log_debug_usage(logger, "Using GLU activation: LiGLU.")(torch.jit.script(LiGLU())) geglu = log_debug_usage(logger, "Using GLU activation: GELU.")(torch.jit.script(GEGLU())) reglu = log_debug_usage(logger, "Using GLU activation: ReGLU.")(torch.jit.script(ReGLU())) swiglu = log_debug_usage(logger, "Using GLU activation: SwiGLU.")(torch.jit.script(SwiGLU())) +t5_liglu = log_debug_usage(logger, "Using GLU activation: T5LiGLU.")(T5LiGLU) +t5_geglu = log_debug_usage(logger, "Using GLU activation: T5GELU.")(T5GEGLU) +t5_reglu = log_debug_usage(logger, "Using GLU activation: T5ReGLU.")(T5ReGLU) +t5_swiglu = log_debug_usage(logger, "Using GLU activation: T5SwiGLU.")(T5SwiGLU) GLU_ACTIVATIONS = { @@ -49,4 +173,8 @@ def __init__(self): "liglu": liglu, "reglu": reglu, "swiglu": swiglu, + "t5_geglu": t5_geglu, + "t5_liglu": t5_liglu, + "t5_reglu": t5_reglu, + "t5_swiglu": t5_swiglu, } diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 3494f9e4e..3ca68ef9e 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -48,4 +48,111 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16 cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) \ No newline at end of file + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +# Original implementation adjusted from https://github.com/sunyt32/torchscale + +def fixed_pos_embedding(x, base): + seq_len, dim = x.shape + inv_freq = 1.0 / (base ** (torch.arange(0, dim) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + ) + return torch.cos(sinusoid_inp), torch.sin(sinusoid_inp) + + +class XPosEmbedding(torch.nn.Module): + """ + xPos positional embeddings from https://arxiv.org/abs/2212.10554. + """ + + def __init__(self, head_dim, freq_base=10000, scale_base=512, gamma=0.4, precision=torch.half): + super().__init__() + self.scale_base = scale_base + self.register_buffer( + "scale", + ( + (torch.arange(0, head_dim, 2) + gamma * head_dim) + / ((1.0 + gamma) * head_dim) + ), + ) + self.max_seq_len_cached = None + self.precision = precision + self.freq_base = freq_base + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + scale = ( + self.scale + ** ( + torch.arange(0, seq_len, 1) - seq_len // 2 + ).to(self.scale).div(self.scale_base)[:, None] + ) + + if ( + self.max_seq_len_cached is None + or (seq_len > self.max_seq_len_cached) + ): + self.max_seq_len_cached = seq_len + cos, sin = fixed_pos_embedding(scale, self.freq_base) + self.cos_cached = cos + self.sin_cached = sin + if self.precision == torch.bfloat16: + self.cos_cached = self.cos_cached.bfloat16() + self.sin_cached = self.sin_cached.bfloat16() + return ( + self.cos_cached[:seq_len], + self.sin_cached[:seq_len], + scale, + ) + + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... 
(d j)')\ + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m.unsqueeze(1) + + +def _apply_xpos_emb(x, cos, sin, scale): + # x is assumed to be (seq_len, batch_size, dim) here. + cos = duplicate_interleave(cos * scale) + sin = duplicate_interleave(sin * scale) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +@torch.jit.script +def apply_xpos_emb(q, k, cos, sin, scale, offset: int = 0): + # q/k are assumed to be (seq_len, batch_size, dim) here. + cos = cos[offset:q.shape[0] + offset] + sin = sin[offset:q.shape[0] + offset] + scale = scale[offset:q.shape[0] + offset] + return ( + _apply_xpos_emb(q, cos, sin, scale), + _apply_xpos_emb(k, cos, sin, 1.0 / scale), + ) + + +def apply_xpos_emb_torch(q, k, cos, sin, scale, offset: int = 0): + # q/k are assumed to be (seq_len, batch_size, dim) here. + cos = cos[offset:q.shape[0] + offset] + sin = sin[offset:q.shape[0] + offset] + scale = scale[offset:q.shape[0] + offset] + return ( + _apply_xpos_emb(q, cos, sin, scale), + _apply_xpos_emb(k, cos, sin, 1.0 / scale), + ) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 03e6faaec..077b24763 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,8 +30,15 @@ import deepspeed -from .glu_activations import GLU_ACTIVATIONS -from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb +from .glu_activations import GLU_ACTIVATIONS, replaces_linear +from .positional_embeddings import ( + apply_rotary_pos_emb, + apply_rotary_pos_emb_torch, + apply_xpos_emb, + apply_xpos_emb_torch, + RotaryEmbedding, + XPosEmbedding, +) # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) @@ -69,19 +76,34 @@ def __init__(self, init_method, output_layer_init_method): super(ParallelMLP, self).__init__() args = get_args() - # Project to ffn_hidden_size - self.dense_h_to_4h = mpu.ColumnParallelLinear( - args.hidden_size, - # GLU is a special activation that divides the dimension by a factor 2. - 2 * args.ffn_hidden_size if args.glu_activation else args.ffn_hidden_size, - gather_output=False, - init_method=init_method, - skip_bias_add=True) + if args.glu_activation: + glu_activation = GLU_ACTIVATIONS[args.glu_activation] + else: + glu_activation = None + # Project to ffn_hidden_size + if replaces_linear(glu_activation): + self.dense_h_to_4h = glu_activation( + args.hidden_size, + args.ffn_hidden_size, + gather_output=False, + init_method=init_method) + else: + self.dense_h_to_4h = mpu.ColumnParallelLinear( + args.hidden_size, + # GLU is a special activation that divides the dimension by a factor 2. + # Only the case for non-T5 GLU. 
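+ # (T5-style GLU activations take the `replaces_linear` branch above and provide their own parallel linear layers.)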
+ 2 * args.ffn_hidden_size if args.glu_activation else args.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu - if args.glu_activation: - self.activation_func = GLU_ACTIVATIONS[args.glu_activation] + + if replaces_linear(glu_activation): + self.activation_func = nn.Identity() + elif glu_activation: + self.activation_func = glu_activation elif args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: @@ -91,6 +113,7 @@ def __init__(self, init_method, output_layer_init_method): self.dense_4h_to_h = mpu.RowParallelLinear( args.ffn_hidden_size, args.hidden_size, + bias=not replaces_linear(glu_activation), input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) @@ -101,7 +124,9 @@ def forward(self, hidden_states): # [s, b, 4hp] intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) - if self.bias_gelu_fusion: + if bias_parallel is None: + intermediate_parallel = self.activation_func(intermediate_parallel) + elif self.bias_gelu_fusion: intermediate_parallel = \ bias_gelu_impl(intermediate_parallel, bias_parallel) else: @@ -204,6 +229,11 @@ def __init__(self, init_method, if self.position_embedding_type == PositionEmbeddingType.rotary: self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype) + elif self.position_embedding_type == PositionEmbeddingType.xpos: + self.xpos_emb = XPosEmbedding( + self.hidden_size_per_attention_head, + precision=args.params_dtype, + ) def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None, alibi=None): @@ -291,16 +321,23 @@ def forward(self, hidden_states, attention_mask, layer_past=None, matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]] # Rotary embeddings - if self.position_embedding_type == PositionEmbeddingType.rotary: - apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb - + if self.position_embedding_type in [ + PositionEmbeddingType.rotary, PositionEmbeddingType.xpos]: seq_len = key_layer.shape[0] offset = 0 if layer_past is not None and layer_past.numel() > 0: offset = layer_past[0].shape[0] seq_len += offset + + if self.position_embedding_type == PositionEmbeddingType.rotary: + apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset) + elif self.position_embedding_type == PositionEmbeddingType.xpos: + apply_xpos_fn = apply_xpos_emb_torch if self.bf16 else apply_xpos_emb + cos, sin, scale = self.xpos_emb(value_layer, seq_len=seq_len) + query_layer, key_layer = apply_xpos_fn( + query_layer, key_layer, cos, sin, scale, offset=offset) # Raw attention scores. 
[b * np, sq, sk] if alibi is None: @@ -406,8 +443,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None, def bias_dropout_add(x, bias, residual, prob, training): - # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor - out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) out = residual + out return out @@ -420,13 +459,13 @@ def _bias_dropout_add(x, bias, residual, prob): @torch.jit.script def bias_dropout_add_fused_train(x, bias, residual, prob): - # type: (Tensor, Tensor, Tensor, float) -> Tensor + # type: (Tensor, Optional[Tensor], Tensor, float) -> Tensor return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x, bias, residual, prob): - # type: (Tensor, Tensor, Tensor, float) -> Tensor + # type: (Tensor, Optional[Tensor], Tensor, float) -> Tensor return bias_dropout_add(x, bias, residual, prob, False) @@ -579,7 +618,7 @@ def forward(self, hidden_states, attention_mask, with torch.enable_grad(): output = bias_dropout_add_func( mlp_output, - mlp_bias.expand_as(residual), + mlp_bias.expand_as(residual) if mlp_bias is not None else None, residual, self.hidden_dropout) diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index c419d9bf6..e6d3c95bb 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -423,7 +423,7 @@ def __init__(self, input_size, output_size, bias=True, else: self.register_parameter('bias', None) - self.bias_tp_auto_sync = args.sync_tp_duplicated_parameters + self.bias_tp_auto_sync = bias and args.sync_tp_duplicated_parameters def forward(self, input_): diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 09304b1dd..8b4de875a 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -30,17 +30,38 @@ def build_tokenizer(args): # Select and instantiate the tokenizer. assert args.vocab_file is not None or args.tokenizer_type == "PretrainedFromHF" + + if hasattr(args, '_is_ul2') and args._is_ul2: + ul2_denoiser_tokens = [ + args.ul2_r_denoiser_token, + args.ul2_s_denoiser_token, + args.ul2_x_denoiser_token, + ] + else: + ul2_denoiser_tokens = [] + if args.tokenizer_type == 'BertWordPieceLowerCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, - lower_case=True, - vocab_extra_ids=args.vocab_extra_ids) + tokenizer = _BertWordPieceTokenizer( + vocab_file=args.vocab_file, + lower_case=True, + vocab_extra_ids=args.vocab_extra_ids, + ul2_denoiser_tokens=ul2_denoiser_tokens, + ) elif args.tokenizer_type == 'BertWordPieceCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, - lower_case=False, - vocab_extra_ids=args.vocab_extra_ids) + tokenizer = _BertWordPieceTokenizer( + vocab_file=args.vocab_file, + lower_case=False, + vocab_extra_ids=args.vocab_extra_ids, + ul2_denoiser_tokens=ul2_denoiser_tokens, + ) elif args.tokenizer_type == 'GPT2BPETokenizer': assert args.merge_file is not None - tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) + tokenizer = _GPT2BPETokenizer( + args.vocab_file, + args.merge_file, + vocab_extra_ids=args.vocab_extra_ids, + ul2_denoiser_tokens=ul2_denoiser_tokens, + ) elif args.tokenizer_type == "PretrainedFromHF": assert args.tokenizer_name_or_path is not None @@ -55,7 +76,11 @@ def build_tokenizer(args): if args.rank == 0: print(" vocab file is un-used. 
loading tokenizer from pre-trained model") - tokenizer = _AutoTokenizer(args.tokenizer_name_or_path, vocab_extra_ids=args.vocab_extra_ids) + tokenizer = _AutoTokenizer( + args.tokenizer_name_or_path, + vocab_extra_ids=args.vocab_extra_ids, + ul2_denoiser_tokens=ul2_denoiser_tokens, + ) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type)) @@ -155,7 +180,13 @@ def mask(self): class _BertWordPieceTokenizer(AbstractTokenizer): """Original BERT wordpiece tokenizer.""" - def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): + def __init__( + self, + vocab_file, + lower_case=True, + vocab_extra_ids=0, + ul2_denoiser_tokens=None, + ): if lower_case: name = 'BERT Lower Case' else: @@ -184,6 +215,13 @@ def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): additional_special_tokens = [] additional_special_tokens.extend( ["".format(i) for i in range(vocab_extra_ids)]) + + if ul2_denoiser_tokens is None: + ul2_denoiser_tokens = [] + self._ul2_tokens = ul2_denoiser_tokens + for value in self._ul2_tokens: + self.add_token(value) + self.add_additional_special_tokens(additional_special_tokens) def add_token(self, token): @@ -282,21 +320,63 @@ def additional_special_tokens_ids(self): def additional_special_tokens(self, value): self._additional_special_tokens = value + @property + def ul2_token_ids(self): + return [self.vocab[k] for k in self._ul2_tokens] + class _GPT2BPETokenizer(AbstractTokenizer): """Original GPT2 BPE tokenizer.""" - def __init__(self, vocab_file, merge_file): + def __init__( + self, + vocab_file, + merge_file, + vocab_extra_ids=0, + ul2_denoiser_tokens=None, + ): name = 'GPT2 BPE' super().__init__(name) + self._extra_id_tokens = [ + f"" for i in range(vocab_extra_ids)] + + if ul2_denoiser_tokens is None: + ul2_denoiser_tokens = [] + self._ul2_tokens = ul2_denoiser_tokens + + special_tokens = self._extra_id_tokens.copy() + if self._ul2_tokens: + special_tokens.extend(self._ul2_tokens) + extra_ul2_tokens = [ + '', + '', + '', + '', + '', + ] + special_tokens.extend(extra_ul2_tokens) + self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace', - special_tokens=[], max_len=None) + special_tokens=special_tokens, + max_len=None) + if self._ul2_tokens: + self.sep_id = self.tokenizer.special_tokens[''] + self.mask_id = self.tokenizer.special_tokens[''] + self.pad_id = self.tokenizer.special_tokens[''] + self._bos_token_id = self.tokenizer.special_tokens[''] + self._eos_token_id = self.tokenizer.special_tokens[''] + else: + self.sep_id = None + self.mask_id = None + self.pad_id = None + self._bos_token_id = None + self._eos_token_id = None self.eod_id = self.tokenizer.encoder['<|endoftext|>'] @property def vocab_size(self): - return len(self.tokenizer.encoder) + return len(self.tokenizer) @property def vocab(self): @@ -312,22 +392,109 @@ def tokenize(self, text): def detokenize(self, token_ids): return self.tokenizer.decode(token_ids) + @property + def sep(self): + if self.sep_id is None: + raise AttributeError( + 'GPT tokenizer does not have a SEP token by default; ' + 'please add it to the `special_tokens`') + return self.sep_id + + @property + def mask(self): + if self.mask_id is None: + raise AttributeError( + 'GPT tokenizer does not have a MASK token by default; ' + 'please add it to the `special_tokens`') + return self.mask_id + + @property + def pad(self): + if self.pad_id is None: + raise AttributeError( + 'GPT tokenizer does not have a PAD token by default; ' + 'please add it to the `special_tokens`') 
+ return self.pad_id + + @property + def bos_token_id(self): + if self._bos_token_id is None: + raise AttributeError( + 'GPT tokenizer does not have a BOS token by default; ' + 'please add it to the `special_tokens`') + return self._bos_token_id + + @property + def eos_token_id(self): + if self._eos_token_id is None: + raise AttributeError( + 'GPT tokenizer does not have a EOS token by default; ' + 'please add it to the `special_tokens`') + return self._eos_token_id + @property def eod(self): return self.eod_id + @property + def additional_special_tokens_ids(self): + return [ + self.tokenizer.special_tokens[k] for k in self._extra_id_tokens] + + @property + def ul2_tokens_ids(self): + return [self.tokenizer.special_tokens[k] for k in self._ul2_tokens] + class _AutoTokenizer(AbstractTokenizer): """AutoTokenizer for Hf Pretrained model loading.""" - def __init__(self, tokenizer_name_or_path, vocab_extra_ids): + def __init__( + self, + tokenizer_name_or_path, + vocab_extra_ids, + ul2_denoiser_tokens=None, + ): name = tokenizer_name_or_path super().__init__(name) hf_tokenizer_kwargs = {} + if vocab_extra_ids > 0: # TODO @thomasw21 we might need to concatenate to a pre-existing list? - hf_tokenizer_kwargs["additional_special_tokens"] = [f"" for _id in range(vocab_extra_ids)] - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs) + self._extra_id_tokens = [ + f"" for _id in range(vocab_extra_ids)] + hf_tokenizer_kwargs["additional_special_tokens"] = \ + self._extra_id_tokens + else: + self._extra_id_tokens = [] + + if ul2_denoiser_tokens is None: + ul2_denoiser_tokens = [] + self._ul2_tokens = ul2_denoiser_tokens + + if self._ul2_tokens: + additional_tokens = hf_tokenizer_kwargs.setdefault( + 'additional_special_tokens', []) + additional_tokens.extend(self._ul2_tokens) + + try: + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name_or_path, **hf_tokenizer_kwargs) + except ValueError as e: + # Try to catch the exception raised when we have to pass + # `extra_ids` explicitly because its default does not match. 
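+ # Any other `ValueError` is re-raised unchanged.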
+ if not ( + str(e).startswith('Both extra_ids ') + and str(e).endswith( + 'the additional_special_tokens must include the ' + 'extra_ids tokens' + ) + ): + raise e + + hf_tokenizer_kwargs['extra_ids'] = len(self._extra_id_tokens) + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name_or_path, **hf_tokenizer_kwargs) self.encoder = self.tokenizer.get_vocab() self.decoder = {v: k for k, v in self.encoder.items()} @@ -388,10 +555,26 @@ def eos(self): candidate = self.tokenizer.eos_token_id return self._check_token_candidate(candidate) + @property + def bos_token_id(self): + """Id of the beginning of sentence token in the vocabulary.""" + candidate = self.tokenizer.bos_token_id + return self._check_token_candidate(candidate) + + @property + def eos_token_id(self): + """Id of the end of sentence token in the vocabulary.""" + candidate = self.tokenizer.eos_token_id + return self._check_token_candidate(candidate) + @property def additional_special_tokens_ids(self): """ All the additional special tokens you may want to use (list of strings).""" - return self.tokenizer.additional_special_tokens_ids + return [self.vocab[k] for k in self._extra_id_tokens] + + @property + def ul2_token_ids(self): + return [self.vocab[k] for k in self._ul2_tokens] @staticmethod def _check_token_candidate(candidate): diff --git a/megatron/utils.py b/megatron/utils.py index 893f58dd2..f658da932 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -376,15 +376,18 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_ :param eod_token: int, token_id used to signal end of document :param partial_prefix_indices: this agument can have multiple types: - None, it signals that all prefix indices are randomly sampled. - - List[Optional[int]], its length has to be equal to mini batch size. It stores all the indices for per row prefix. - Optional means that if set to None, we allows ourselves to sample one randomly. - - List[List[Optional[int]]], it follows the following rules: + - False, it signals that prefix indices always go to the end of the document. + - List[Union[int, bool, None]], its length has to be equal to mini batch size. It stores all the indices for per row prefix. + If set to None, we allow ourselves to sample one randomly. + If set to False, the current row will be attended to completely. + - List[List[Union[int, bool, None]]], it follows the following rules: - The first dimension refers to that sample, ie len(partial_prefix_indices) == len(data) - The second dimension refers to the number of document of that sample, ie len(partial_prefix_indices[b]) == (data[b] == eod_token).sum() (+1 for the last partial document). - partial_prefix_indices have to be interleaved with eod_indices, ie eod_indices[b][d-1] < partial_prefix_indices[b][d] < eod_indices[b][d] + 1 or is None. - - Optional means that if set to None, we allows ourselves to sample one randomly. + - If set to None, we allow ourselves to sample one randomly. + - If set to False, the current document will be attended to completely. :param reset_attention_mask: bool, determines if prefixes are to be per document or per row. 
:return Depending if prefix is per document or per row, the method returns: - List[List[int]]: prefix indices for each document in case of per document prefix @@ -393,7 +396,7 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_ micro_batch_size, seq_length = data.size() prefix_indices = [] - assert partial_prefix_indices is None or len(partial_prefix_indices) == micro_batch_size, f"partial_prefix_indices has to be None or its length equal to {micro_batch_size}, got {len(partial_prefix_indices)}" + assert partial_prefix_indices is None or partial_prefix_indices is False or len(partial_prefix_indices) == micro_batch_size, f"partial_prefix_indices has to be None or its length equal to {micro_batch_size}, got {len(partial_prefix_indices)}" for batch_id in range(micro_batch_size): # Prefix lm per document. if reset_attention_mask: @@ -411,14 +414,28 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_ ) prev_index = 0 - assert partial_prefix_indices is None or len(partial_prefix_indices[batch_id]) == len(eod_indices), f"The number of prefixes has to match the number of documents, complete or partial. Got {len(partial_prefix_indices[batch_id])} prefixes and {len(eod_indices)} documents" + assert partial_prefix_indices is None or partial_prefix_indices is False or len(partial_prefix_indices[batch_id]) == len(eod_indices), f"The number of prefixes has to match the number of documents, complete or partial. Got {len(partial_prefix_indices[batch_id])} prefixes and {len(eod_indices)} documents" for doc_id, eod_index in enumerate(eod_indices): - assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], list), f"Per document prefix has to store a list on indices for each row, got {partial_prefix_indices[batch_id]}" + assert partial_prefix_indices is None or partial_prefix_indices is False or isinstance(partial_prefix_indices[batch_id], list), f"Per document prefix has to store a list on indices for each row, got {partial_prefix_indices[batch_id]}" # Prefix index is defined as the first index that isn't attended by all tokens in a document - if partial_prefix_indices is None or partial_prefix_indices[batch_id][doc_id] is None: + if ( + partial_prefix_indices is None + or ( + partial_prefix_indices is not False + and partial_prefix_indices[batch_id][doc_id] is None + ) + ): # We need to randomly generate a prefix index that satisfies the interleave condition in the docstring prefix_index = randint(prev_index + 1, eod_index) + elif ( + partial_prefix_indices is False + or ( + partial_prefix_indices is not None + and partial_prefix_indices[batch_id][doc_id] is False + ) + ): + prefix_index = eod_index else: # We get value from partial_prefix_indices, and run validation on that value prefix_index = partial_prefix_indices[batch_id][doc_id] @@ -429,14 +446,28 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_ # Prefix lm per row. 
else: - assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], int), \ + assert partial_prefix_indices is None or partial_prefix_indices is False or isinstance(partial_prefix_indices[batch_id], int), \ f"Per document prefix has to store an int for each row, got {partial_prefix_indices[batch_id]}" # Prefix index is defined as the first index that isn't attended by all previous tokens in a document prefix_index: int - if partial_prefix_indices is None or partial_prefix_indices[batch_id] is None: + if ( + partial_prefix_indices is None + or ( + partial_prefix_indices is not False + and partial_prefix_indices[batch_id] is None + ) + ): # 0 being the first prefix index makes no sense since 0 always attends to itself, and there are no other tokens before. prefix_index = randint(1, seq_length) + elif ( + partial_prefix_indices is False + or ( + partial_prefix_indices is not None + and partial_prefix_indices[batch_id] is False + ) + ): + prefix_index = seq_length else: # We get value from partial_prefix_indices, and run validation on that value prefix_index = partial_prefix_indices[batch_id] @@ -536,4 +567,4 @@ def dump_weights(preamble, iteration, model, optimizer, tensor=None): # hostname = socket.gethostname() # pid = os.getpid() # global_rank = torch.distributed.get_rank() - #fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt" \ No newline at end of file + #fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt" diff --git a/pretrain_ul2.py b/pretrain_ul2.py new file mode 100644 index 000000000..cb7aa61cd --- /dev/null +++ b/pretrain_ul2.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Pretrain UL2""" + +import argparse +from functools import partial + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +import torch + +from megatron import ( + get_args, + get_timers, + mpu, + print_rank_0 +) +from megatron.data.dataset_utils import build_train_valid_test_datasets +from megatron.data.ul2_dataset import ( + is_decoder_only as _is_decoder_only, + is_prefix_lm as _is_prefix_lm, +) +from megatron.enums import AttnMaskType +from megatron.model.gpt_model import GPTModel, GPTModelPipe +from megatron.model.t5_model import T5Model, t5_position_ids +from megatron.training import pretrain +from megatron.utils import average_losses_across_data_parallel_group + + +def is_decoder_only(): + """Return whether we use a decoder-only model.""" + args = get_args() + return _is_decoder_only(args.ul2_model_type) + + +def is_prefix_lm(): + """Return whether we use a non-causal decoder-only model.""" + args = get_args() + return _is_prefix_lm(args.ul2_model_type) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + + see_memory_usage("Before Building Model", force=True) + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=( + None + if args.remote_device == 'none' + else args.remote_device + ), + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + + print_rank_0('building UL2 model ...') + if is_decoder_only(): + print_rank_0('Using decoder-only UL2 model.') + if args.deepspeed: + args.pretrain_causal_attention = not is_prefix_lm() + model = GPTModelPipe( + num_tokentypes=0, + parallel_output=True, + attn_mask_type=( + AttnMaskType.prefix + if is_prefix_lm() + else AttnMaskType.causal + ), + ) + # This is a hack to give us a reference to + # `get_batch_pipe` from within `training.py`. + # We need to call `model.set_batch_fn` after + # `deepspeed.initialize`. + model._megatron_batch_fn = get_batch_pipe + else: + model = GPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + prefix_lm=is_prefix_lm(), + ) + else: + assert pre_process and post_process and not args.deepspeed, \ + "Encoder-decoder model doesn't yet support pipelining" + print_rank_0('Using encoder-decoder UL2 model.') + model = T5Model(num_tokentypes=0, parallel_output=True) + see_memory_usage("After Building Model", force=True) + return model + + +def get_batch(data_iterator): + """Build the batch.""" + + if is_decoder_only(): + keys = ['text', 'labels', 'loss_mask', 'dec_mask'] + else: + keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', + 'enc_mask', 'dec_mask', 'enc_dec_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. 
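+ # Decoder-only UL2 consumes a single packed sequence; encoder-decoder UL2 keeps separate encoder and decoder streams.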
+ if is_decoder_only(): + tokens = data_b['text'].long() + labels = data_b['labels'].long() + loss_mask = data_b['loss_mask'].float() + + dec_mask = (data_b['dec_mask'] < 0.5) + dec_mask = dec_mask.unsqueeze(1) + return tokens, loss_mask, labels, dec_mask + else: + tokens_enc = data_b['text_enc'].long() + tokens_dec = data_b['text_dec'].long() + labels = data_b['labels'].long() + loss_mask = data_b['loss_mask'].float() + + enc_mask = (data_b['enc_mask'] < 0.5) + dec_mask = (data_b['dec_mask'] < 0.5) + enc_dec_mask = (data_b['enc_dec_mask'] < 0.5) + + return tokens_enc, tokens_dec, loss_mask, labels, \ + enc_mask, dec_mask, enc_dec_mask + + +def get_batch_pipe(data): + """Modification of `get_batch` to work on `next(data_iterator)` + instead of `data_iterator`. + """ + + if is_decoder_only(): + keys = ['text', 'labels', 'loss_mask', 'dec_mask'] + else: + keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', + 'enc_mask', 'dec_mask', 'enc_dec_mask'] + datatype = torch.int64 + + # Broadcast data. + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + if is_decoder_only(): + tokens = data_b['text'].long() + labels = data_b['labels'].long() + loss_mask = data_b['loss_mask'].float() + + dec_mask = (data_b['dec_mask'] < 0.5) + dec_mask = dec_mask.unsqueeze(1) + + position_ids = t5_position_ids(tokens) + return (tokens, position_ids, dec_mask), (labels, loss_mask) + else: + tokens_enc = data_b['text_enc'].long() + tokens_dec = data_b['text_dec'].long() + labels = data_b['labels'].long() + loss_mask = data_b['loss_mask'].float() + + enc_mask = (data_b['enc_mask'] < 0.5) + dec_mask = (data_b['dec_mask'] < 0.5) + enc_dec_mask = (data_b['enc_dec_mask'] < 0.5) + + # This will probably be incorrect. Need to adapt this if + # pipelining for encoder-decoder models is ever implemented (and + # implemented similarly to the GPT model). + return (tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask), \ + (labels, loss_mask) + + +def loss_func(loss_mask, output_tensor): + if is_decoder_only(): + lm_loss_ = output_tensor + else: + lm_loss_, _ = output_tensor + + lm_loss_ = lm_loss_.float() + lm_loss = torch.sum( + lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() + + loss = lm_loss + averaged_losses = average_losses_across_data_parallel_group([lm_loss]) + + return loss, {'lm loss': averaged_losses[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
+ timers('batch generator').start() + if is_decoder_only(): + (tokens, loss_mask, lm_labels, dec_mask) = get_batch(data_iterator) + else: + ( + tokens_enc, tokens_dec, loss_mask, lm_labels, + enc_mask, dec_mask, enc_dec_mask, + ) = get_batch(data_iterator) + timers('batch generator').stop() + + # Forward model lm_labels + if is_decoder_only(): + position_ids = t5_position_ids(tokens) + output_tensor = model(tokens, position_ids, dec_mask, + labels=lm_labels) + else: + output_tensor = model(tokens_enc, + tokens_dec, + enc_mask, + dec_mask, + enc_dec_mask, + tokentype_ids=None, + lm_labels=lm_labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for UL2 ...') + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + max_seq_length=args.encoder_seq_length, + max_seq_length_dec=args.decoder_seq_length, + masked_lm_prob=args.mask_prob, + short_seq_prob=args.short_seq_prob, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + dataset_type='ul2') + print_rank_0("> finished creating UL2 datasets ...") + + return train_ds, valid_ds, test_ds + + +def extra_args_provider(parser): + parser.add_argument('--_is_ul2', default=True, help=argparse.SUPPRESS) + return parser + + +if __name__ == "__main__": + + pretrain(train_valid_test_datasets_provider, model_provider, forward_step, + extra_args_provider=extra_args_provider, + args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) diff --git a/tasks/eval_harness/evaluate.py b/tasks/eval_harness/evaluate.py index 68dd649fd..2549c1fac 100644 --- a/tasks/eval_harness/evaluate.py +++ b/tasks/eval_harness/evaluate.py @@ -1,4 +1,6 @@ +import argparse from functools import reduce +import importlib from logging import logMultiprocessing import os import sys @@ -13,7 +15,6 @@ import torch.nn.functional as F from lm_eval.tasks import ALL_TASKS -from pretrain_gpt import model_provider import numpy as np import torch @@ -21,18 +22,28 @@ from megatron import print_rank_0 from megatron import get_tokenizer from megatron import mpu +from megatron.data.t5_dataset import ( + make_attention_mask_3d, + make_history_mask_3d, +) from megatron.training import setup_model_and_optimizer, get_model from megatron.mpu.mappings import gather_from_tensor_model_parallel_region -from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model +from megatron.utils import ( + get_ltor_masks_and_position_ids, + get_prefix_indices, + unwrap_model, +) from megatron.p2p_communication import recv_forward, send_forward import pickle import json from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron.model.distributed import DistributedDataParallel as LocalDDP +from megatron.model.gpt_model import GPTModelPipe from megatron.model.module import Float16Module from deepspeed.runtime.pipe import schedule +from deepspeed.runtime.pipe.engine import PipelineEngine class EvalHarnessAdaptor(GPT2LM): def __init__(self, model, tokenizer): @@ -44,6 +55,20 @@ def __init__(self, model, tokenizer): self.EOT_TOKEN_ID = tokenizer.eod self._max_length = args.seq_length + self._prefix_tokens = args.prefix_tokens + self._prefix_token_ids = [ + 
self.tokenizer.tokenizer.convert_tokens_to_ids(token) + for token in self._prefix_tokens + ] + + # TODO More general check for pipelined models would be desirable. + self._is_encoder_decoder = not ( + isinstance(self.model, GPTModelPipe) + or isinstance(self.model, PipelineEngine) + or hasattr(self.model, 'language_model') + and hasattr(self.model.language_model, 'add_decoder') + and not self.model.language_model.add_decoder + ) # For ds we split into mini batches and then micro batches to keep pipelining api happy. # With Megatron we just go to micro_batches directly @@ -74,6 +99,14 @@ def batch_size(self): def device(self): return self._device + def _prepend_prefix_token_ids(self, tokens): + if not self._prefix_token_ids: + pass + elif tokens and tokens[0] == self.EOT_TOKEN_ID: + tokens = tokens[:1] + self._prefix_token_ids + tokens[1:] + else: + tokens = self._prefix_token_ids + tokens + return tokens def loglikelihood(self, requests): new_reqs = [] @@ -129,14 +162,37 @@ def _collate(x): reord = utils.Reorderer(requests, _collate) for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size): - inps, contlens, inplens, padding_length = [], [], [], None + inps, ctxlens, contlens, inplens, padding_length = [], [], [], [], None for _, context_enc, continuation_enc in chunk: # when too long to fit in context, truncate from the left + context_len = len(context_enc) + len(self._prefix_tokens) + total_len = context_len + len(continuation_enc) + + context_num_truncated = max( + total_len - self.max_length + 1, 0) + # Need actual truncated length of context here + # (without prefix tokens). + continuation_num_truncated = max( + context_num_truncated - len(context_enc), 0) + + context_enc = context_enc[context_num_truncated:] + continuation_enc = \ + continuation_enc[continuation_num_truncated:] + + # Add prefix token after truncation. + context_enc = self._prepend_prefix_token_ids(context_enc) + inp = torch.tensor( - (context_enc + continuation_enc)[-(self.max_length + 1):][:-1] - , dtype=torch.long).to(self.device) + context_enc + continuation_enc, + dtype=torch.long, + ).to(self.device) inplen, = inp.shape + if len(continuation_enc) == 0: + ctxlen = 1 + else: + ctxlen = max(context_len - context_num_truncated, 1) + cont = continuation_enc # since in _collate we make sure length is descending, the longest is always the first one. 
@@ -153,8 +209,9 @@ def _collate(x): contlens.append(cont) inplens.append(inplen) + ctxlens.append(ctxlen) - logits = self._model_call(torch.cat(inps, dim=0)) + logits = self._model_call((torch.cat(inps, dim=0), ctxlens)) res_len += len(chunk) if logits is not None: multi_logits = F.log_softmax(logits, dim=-1).cpu() # [batch, seq, vocab] @@ -185,27 +242,94 @@ def _collate(x): def create_model_inputs(self, tokens): args = get_args() - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - tokens, - self.EOT_TOKEN_ID, - args.reset_position_ids, - args.reset_attention_mask, - args.eod_mask_loss, - prefix_indices=None, - loss_on_targets_only=False) - - return (tokens, position_ids, attention_mask), (tokens, loss_mask) + if isinstance(tokens, tuple) and len(tokens) == 2: + tokens, ctxlens = tokens + else: + ctxlens = None + + # TODO Handle encoder-only + if not self._is_encoder_decoder: + if args.prefix_lm and ctxlens is not None: + prefix_indices = get_prefix_indices( + tokens, + self.EOT_TOKEN_ID, + partial_prefix_indices=ctxlens, + reset_attention_mask=args.reset_attention_mask + ) + else: + if args.prefix_lm: + print( + 'Warning: requested PrefixLM inputs, but cannot determine ' + 'prefix length – prefix is empty.' + ) + prefix_indices = None + + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + self.EOT_TOKEN_ID, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + prefix_indices=prefix_indices, + loss_on_targets_only=False) + return (tokens, position_ids, attention_mask), (tokens, loss_mask) + else: + assert ctxlens is not None + + # Split tokens to separate encoder and decoder input. + # No BOS token used with eval harness, so we do not need to + # worry about the decoder receiving it in mind of the split. + enc_tokens = torch.stack([ + F.pad(tok[:ctxlen], (0, len(tok) - ctxlen), value=0) + for (tok, ctxlen) in zip(tokens, ctxlens) + ]) + dec_tokens = torch.stack([ + F.pad(tok[ctxlen:], (0, ctxlen), value=0) + for (tok, ctxlen) in zip(tokens, ctxlens) + ]) + + enc_attn_mask = make_attention_mask_3d(enc_tokens, enc_tokens) + dec_attn_mask = make_attention_mask_3d(dec_tokens, dec_tokens) + dec_attn_mask *= make_history_mask_3d(dec_tokens) + enc_dec_attn_mask = make_attention_mask_3d(dec_tokens, enc_tokens) + + loss_mask = torch.ones( + dec_tokens.shape[:2], + device=dec_tokens.device, + dtype=dec_tokens.dtype, + ) + for (i, ctxlen) in enumerate(ctxlens): + if ctxlen != 0: + loss_mask[i, -ctxlen:] = 0 + + return ( + (enc_tokens, dec_tokens, enc_attn_mask, + dec_attn_mask, enc_dec_attn_mask), + (dec_tokens, loss_mask) + ) def _model_call(self, inps): args = get_args() + if isinstance(inps, tuple) and len(inps) == 2: + inps, ctxlens = inps + else: + ctxlens = None + if args.deepspeed: self.model.set_batch_fn(self.create_model_inputs) # round up to multiple of micro_batch_size new_size = ((len(inps) + args.micro_batch_size-1) // args.micro_batch_size) * args.micro_batch_size padded = F.pad(inps, (0, 0, 0, new_size-len(inps)), value = 0) + ctxlens = ctxlens + [1] * (new_size - len(ctxlens)) # dummy data iterator for pipelining. 
- data_iterator = list((torch.stack(inp) for inp in utils.chunks(padded, args.micro_batch_size))) + data_iterator = list(( + (torch.stack(inp), ctxlen) + for (inp, ctxlen) in zip( + utils.chunks(padded, args.micro_batch_size), + utils.chunks(ctxlens, args.micro_batch_size), + ) + )) self.model.micro_batches = len(data_iterator) if self.adaptive_seq_len: @@ -239,7 +363,9 @@ def _model_call(self, inps): # Forward pass through the model. unwrapped_model = unwrap_model(self.model, (torchDDP, LocalDDP, Float16Module)) unwrapped_model.set_input_tensor(input_tensor) - output = self.model(*self.create_model_inputs(inps)[0]) + output = self.model(*self.create_model_inputs((inps, ctxlens))[0]) + if isinstance(output, tuple): + output = output[0] send_forward(output) if mpu.is_pipeline_last_stage(): @@ -260,7 +386,7 @@ def tokenizer_encode(self, text): from megatron.initialize import initialize_megatron import megatron -from tools.convert_checkpoint.deepspeed_checkpoint import DeepSpeedCheckpoint +from deepspeed.checkpoint.deepspeed_checkpoint import DeepSpeedCheckpoint from tools.convert_checkpoint.deepspeed_to_megatron import _create_rank_checkpoint def override_args(args, override_args, skip_keys, skip_if_specified_keys): @@ -293,12 +419,18 @@ def load_ds_checkpoint_and_setup_megatron(args): if not os.path.exists(args.load): raise ValueError(f"checkpoint path {args.load} doesn't exit") - ds_checkpoint = DeepSpeedCheckpoint(args.load, - tp_degree=args.tensor_model_parallel_size, - pp_degree=args.pipeline_model_parallel_size) - - - cp_args = ds_checkpoint.get_args() + try: + is_ds_cp = True + ds_checkpoint = DeepSpeedCheckpoint(args.load, + tp_degree=args.tensor_model_parallel_size, + pp_degree=args.pipeline_model_parallel_size) + + cp_args = ds_checkpoint.get_args() + except (AssertionError, ZeroDivisionError): + is_ds_cp = False + cp_path = os.path.join(args.load, 'mp_rank_00', 'model_optim_rng.pt') + state_dict = torch.load(cp_path, map_location='cpu') + cp_args = state_dict['args'] # Merge the current args with the checkpoint args. skip_keys = [ 'abort_on_unmet_fused_kernel_constraints', @@ -340,9 +472,16 @@ def load_ds_checkpoint_and_setup_megatron(args): # Initializing megatron will update eg. tokenizer size. Override again. override_args(args, cp_args, skip_keys, skip_if_specified) + model_provider = importlib.import_module( + f'pretrain_{args.model_name}', + ).model_provider + # print final arguments. _print_args(args) - if args.deepspeed: + if not is_ds_cp: + model = get_model(model_provider)[0] + model.load_state_dict(state_dict['model'], strict=True) + elif args.deepspeed: # Hack #3: # Loading pipelined models in deepspeed with different TP than it was originally trained on fails @@ -382,14 +521,28 @@ def tasks_args(parser): """Provide extra arguments required for tasks.""" group = parser.add_argument_group(title='Evaluation options') + group.add_argument('--model_name', type=str, default="gpt", + help=( + 'Which model architecture to use (must exist as ' + '`pretrain_{model_name}.py` script).' 
+ )) group.add_argument('--task_list', type=str, default = "all", help='Either "all" or comma separated list of tasks.') group.add_argument('--results_path', type=str, default = results_path_default, help='Path to where the results will be stored.') group.add_argument('--adaptive_seq_len', default = False, action='store_true', help='Should the sequence length be adapted to the batch during evaluation, if in fp16 the results will be slightly different due to numerical errors but greatly speed up evaluation.') group.add_argument('--eval_fp32', default = False, action='store_true', help='Should the evaluation run in fp32') group.add_argument('--intermed_results', default = False, action='store_true', help='Whether to print & write intermediate results for each task') + group.add_argument('--num_fewshot', type=int, default=0, + help='How many examples to show.') group.add_argument('--bootstrap_iters', type=int, default=100000, help='How many iterations to use for stderr estimation') group.add_argument('--micro_bs_multiplier', type=int, default=1, help='Increase the global batch size to remove bubble when pipeline parallel') + group.add_argument('--prefix_lm', action='store_true', + help='Whether to adjust attention masks for a PrefixLM ' + 'decoder-only model.') + group.add_argument('--prefix_tokens', type=str, nargs='*', default=[], + help='Tokens to add at the front of the input sequence.') + # Automatically add UL2 tokens. + group.add_argument('--_is_ul2', default=True, help=argparse.SUPPRESS) return parser from megatron.global_vars import _parse_args @@ -411,13 +564,15 @@ def main(): task_list = ALL_TASKS if args.task_list == 'all' else args.task_list.split(',') task_dict = tasks.get_task_dict(task_list) - model.module.activation_checkpoint_interval = 0 + if hasattr(model, 'module'): + model.module.activation_checkpoint_interval = 0 model._compute_loss = False model.fwd_outputs = [] tokenizer = get_tokenizer() adaptor = EvalHarnessAdaptor(model, tokenizer) - + num_fewshot = args.num_fewshot + if args.intermed_results: global_results = {"results": {}, "versions": {}} timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') @@ -426,7 +581,7 @@ def main(): # Backup file in case of interruption during writing results_path_backup = args.results_path.replace(".json", f"_lm-eval_{iteration_id}_{timestamp}_backup.json") for task_name, task in task_dict.items(): - results = evaluator.evaluate(adaptor, {task_name: task}, False, 0, None, bootstrap_iters=args.bootstrap_iters) + results = evaluator.evaluate(adaptor, {task_name: task}, False, num_fewshot, None, bootstrap_iters=args.bootstrap_iters) global_results["results"] = {**global_results["results"], **results["results"]} global_results["versions"] = {**global_results["versions"], **results["versions"]} if mpu.is_pipeline_last_stage() and mpu.get_tensor_model_parallel_rank() == 0: @@ -436,7 +591,7 @@ def main(): with open(results_path_backup, 'w') as outfile: json.dump(global_results, outfile, indent=4) else: - global_results = evaluator.evaluate(adaptor, task_dict, False, 0, None, bootstrap_iters=args.bootstrap_iters) + global_results = evaluator.evaluate(adaptor, task_dict, False, num_fewshot, None, bootstrap_iters=args.bootstrap_iters) if mpu.is_pipeline_last_stage() and mpu.get_tensor_model_parallel_rank() == 0: print(json.dumps(global_results, indent=2)) with open(args.results_path, 'w') as outfile: diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 
74e5ca7c9..385c22fd3 100755 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -4,7 +4,7 @@ import os import torch from collections import OrderedDict -from .deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint +from deepspeed.checkpoint.deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint MODEL_KEY = 'model' ARGS_KEY = 'args'