59 changes: 45 additions & 14 deletions finetune_t0.py
@@ -2,7 +2,9 @@
 
 import torch
 
+from pretrain_gpt import get_batch_pipe as get_batch_pipe_gpt
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron.data.gpt_dataset import build_dataset_group as build_dataset_group_gpt
 from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.enums import PositionEmbeddingType, AttnMaskType
 from megatron.model import GPTModelPipe
@@ -48,6 +50,14 @@ def model_provider(pre_process=True, post_process=True):
     return model
 
 
+def fast_normalize(loss_mask: torch.Tensor):
+    """
+    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] -> [0,0,0,0.5,0.5,0,0,1,0,0,0.33,0.33,0.33], i.e. scale each run of 1s to sum to 1.
+    """
+    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
+    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
+    return loss_mask / counts
+
 def get_batch_pipe(data):
     """
     Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -57,6 +67,9 @@ def get_batch_pipe(data):
     decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
     decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]
     """
+    if 'text' in data:
+        return get_batch_pipe_gpt(data)
+
     args = get_args()
     tokenizer = get_tokenizer()
 
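A quick sanity check of `fast_normalize` above (an illustrative snippet, not part of the diff; values follow the docstring example):

    import torch
    mask = torch.tensor([0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1.])
    print(fast_normalize(mask))
    # tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 1.0000,
    #         0.0000, 0.0000, 0.3333, 0.3333, 0.3333])
    # Each run of 1s now sums to 1, so mask.sum() equals the number of target spans (here 3).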
@@ -95,6 +108,10 @@ def get_batch_pipe(data):
         segment_ids=segment_ids.long(),
     )
 
+    if args.norm_target_loss:
+        loss_mask = loss_mask.view(-1)
+        loss_mask = fast_normalize(loss_mask)
+
     if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
 
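Flattening the mask to 1-D here is load-bearing: the new branch added to `CrossEntropy` in megatron/model/gpt_model.py further down selects the per-target reduction precisely by checking `loss_mask.dim() == 1`.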
@@ -142,20 +159,34 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
                           eval(f"args.{s}_weighted_split_splits"),
                           eval(f"args.{s}_weighted_split_names"))
         for paths, weights, splits, name in data_groups:
-            d = build_dataset_group(
-                dataset_group_name=name,
-                paths=paths,
-                weights=weights,
-                splits=splits,
-                data_impl=args.data_impl,
-                train_valid_test_num_samples=train_val_test_num_samples,
-                seq_length=args.seq_length + 1,
-                pad_token=tokenizer.pad,
-                eos_token=tokenizer.eos,
-                seed=args.seed,
-                skip_warmup=(not args.mmap_warmup),
-                train_valid_test=s
-            )
+            if "merged-meg-ds_v3_pii" in paths[0]:
+                d = build_dataset_group_gpt(
+                    dataset_group_name=name,
+                    paths=paths,
+                    weights=weights,
+                    splits=splits,
+                    data_impl=args.data_impl,
+                    train_valid_test_num_samples=train_val_test_num_samples,
+                    seq_length=args.seq_length,
+                    seed=args.seed,
+                    skip_warmup=(not args.mmap_warmup),
+                    train_valid_test=s
+                )
+            else:
+                d = build_dataset_group(
+                    dataset_group_name=name,
+                    paths=paths,
+                    weights=weights,
+                    splits=splits,
+                    data_impl=args.data_impl,
+                    train_valid_test_num_samples=train_val_test_num_samples,
+                    seq_length=args.seq_length + 1,
+                    pad_token=tokenizer.pad,
+                    eos_token=tokenizer.eos,
+                    seed=args.seed,
+                    skip_warmup=(not args.mmap_warmup),
+                    train_valid_test=s
+                )
             eval(f"{s}_ds").append(d)
     else:
         raise NotImplementedError("No dataloading argument passed")
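The split between the two builders turns on the sample layout: paths whose first entry contains "merged-meg-ds_v3_pii" hold plain GPT pretraining data (hence no pad/eos tokens and `seq_length` as-is), while packed multi-task data goes through `build_dataset_group` with `seq_length + 1`, presumably because a packed sample carries the extra token needed for the usual one-token shift. A sketch of that shift (field name `decoder_token_ids` assumed for illustration):

    tokens = sample["decoder_token_ids"]          # packed sample, length seq_length + 1
    input_ids, labels = tokens[:-1], tokens[1:]   # each of length seq_length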
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -930,6 +930,8 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--loss-on-targets-only', action='store_true',
                        help='Mask loss on input sequence.')
+    group.add_argument('--norm-target-loss', action='store_true',
+                       help='Normalize the loss per target. Used for multi-task finetuning with packing.')
     group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
                        help='Some objectives require us to sample loss_mask. This might introduce bias towards '
                             'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
5 changes: 4 additions & 1 deletion megatron/model/gpt_model.py
@@ -15,7 +15,6 @@
 
 """GPT-2 model."""
 
-from functools import partial
 import torch
 
 from megatron import get_args
@@ -186,6 +185,10 @@ def CrossEntropy(output, labels):
         else:
             average_tokens_per_sample = sequence_length
         expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
+    elif args.norm_target_loss and (loss_mask.dim() == 1):
+        expected_num_of_target_seqs = loss_mask.sum()
+        loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
+        return loss
     else:
         expected_number_of_tokens = loss_mask.sum()
 
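With `fast_normalize` applied upstream, each target's mask sums to 1, so `loss_mask.sum()` counts target sequences and the new branch averages the per-target mean losses instead of the per-token mean. A minimal worked example (made-up per-token losses, not part of the diff):

    import torch
    # Two packed targets: lengths 2 and 3.
    losses = torch.tensor([0., 0., 4., 2., 0., 3., 3., 3.])
    mask   = torch.tensor([0., 0., .5, .5, 0., 1/3, 1/3, 1/3])  # fast_normalize output
    loss   = torch.sum(losses * mask) / mask.sum()
    # (4 + 2)/2 = 3 for the first target, (3 + 3 + 3)/3 = 3 for the second; (3 + 3)/2 = 3.0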
5 changes: 3 additions & 2 deletions megatron/training.py
@@ -183,7 +183,7 @@ def pretrain(train_valid_test_dataset_provider,
     timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
     print_rank_0('training ...')
 
-    iteration = 0
+    iteration = args.iteration
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
@@ -199,7 +199,8 @@ def pretrain(train_valid_test_dataset_provider,
                                    iterator, model,
                                    iteration, False, data_group_name=name)
 
-    if args.save and iteration != 0:
+    # Do not save if the iteration has not changed
+    if args.save and iteration != args.iteration:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)
 
     if args.do_test:
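Together these two changes make `pretrain` resume-aware: `args.iteration` holds the step restored from a checkpoint (0 on a fresh run), so an evaluation-only run no longer reports iteration 0, and the save guard skips writing a checkpoint when no training step was actually taken.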