diff --git a/experiment_plan.md b/experiment_plan.md new file mode 100644 index 0000000000..3232152a71 --- /dev/null +++ b/experiment_plan.md @@ -0,0 +1,34 @@ +# Experiment Plan + - [x] CDF图 **DDL: 9月5日** + XPUTimer要使用CUPTI版本的 + - [x] FSDP: + - [x] Megatron: + - [x] DLRM: + - [ ] 使用FSDP、Megatron、DLRM构造任务100个 **DDL: 9月8日前构造完脚本,9月12日前跑完50个,9月16日前跑完100个** + 20个有问题的,80个没有问题的,80个没问题的可以直接跑,然后看CDF、空泡率有没有不正常的,这种就是假阳性。 + - [ ] FSDP + - [ ] 模型大小, 7B, 13B, 70B + - [ ] GPU规模, 8, 16, 32, 48, 64 + - [ ] Megatron + - [ ] 模型大小, 7B, 13B, 70B + - [ ] GPU规模, 8, 16, 32, 48, 64 + - [ ] DLRM + - [ ] 模型大小, + - [ ] GPU规模, + 10 + +| GPU | 8 | 16 | 32 | 64 | +| :--------------: | :--: | :--: | :--: | :--: | +| Megatron | 8 | 8 | 8 | 4 | +| FSDP | 8 | 8 | 8 | 4 | +| DLRM | 8 | 8 | 8 | 0 | +| Megatron-badsync | 1 | 1 | 1 | 1 | +| FSDP-Mem | 1 | 1 | 1 | 1 | +| FSDP-OPshape | 1 | 1 | 1 | 1 | +| FSDP-Dataloader | 1 | 1 | 1 | 1 | +| DLRM-badsync | 1 | 1 | 1 | 1 | + + - [x] Greyhound: [https://github.com/wutianyuan1/Greyhound](https://github.com/wutianyuan1/Greyhound) **DDL: 9月12日前** + - [x] 首先是overhead,这里的XPU timer使用CUPTI版本 + - [x] 其次是能不能检测慢 + - [x] 以及FSDP、megatron、DLRM的适配性 diff --git a/xpu_timer/experiments/dlrm/dlrm_s_pytorch.py b/xpu_timer/experiments/dlrm/dlrm_s_pytorch.py new file mode 100644 index 0000000000..e1b42b7466 --- /dev/null +++ b/xpu_timer/experiments/dlrm/dlrm_s_pytorch.py @@ -0,0 +1,1908 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Description: an implementation of a deep learning recommendation model (DLRM) +# The model input consists of dense and sparse features. The former is a vector +# of floating point values. The latter is a list of sparse indices into +# embedding tables, which consist of vectors of floating point values. 
+# The selected vectors are passed to mlp networks denoted by triangles, +# in some cases the vectors are interacted through operators (Ops). +# +# output: +# vector of values +# model: | +# /\ +# /__\ +# | +# _____________________> Op <___________________ +# / | \ +# /\ /\ /\ +# /__\ /__\ ... /__\ +# | | | +# | Op Op +# | ____/__\_____ ____/__\____ +# | |_Emb_|____|__| ... |_Emb_|__|___| +# input: +# [ dense features ] [sparse indices] , ..., [sparse indices] +# +# More precise definition of model layers: +# 1) fully connected layers of an mlp +# z = f(y) +# y = Wx + b +# +# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk]) +# z = Op(e1,...,ek) +# obtain vectors e1=E[:,p1], ..., ek=E[:,pk] +# +# 3) Operator Op can be one of the following +# Sum(e1,...,ek) = e1 + ... + ek +# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek] +# Cat(e1,...,ek) = [e1', ..., ek']' +# where ' denotes transpose operation +# +# References: +# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang, +# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu, +# Alisson G. Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii, +# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko, +# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong, +# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and +# Recommendation Systems", CoRR, arXiv:1906.00091, 2019 + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse + +# miscellaneous +import builtins +import datetime +import json +import sys +import time + +# onnx +# The onnx import causes deprecation warnings every time workers +# are spawned during testing. So, we filter out those warnings. 
+import warnings + +# data generation +import dlrm_data_pytorch as dp + +# For distributed run +import extend_distributed as ext_dist +import mlperf_logger + +# numpy +import numpy as np +import optim.rwsadagrad as RowWiseSparseAdagrad +import sklearn.metrics + +# pytorch +import torch +import torch.nn as nn + +# dataloader +try: + from internals import fbDataLoader, fbInputBatchFormatter + + has_internal_libs = True +except ImportError: + has_internal_libs = False + +from torch._ops import ops +from torch.autograd.profiler import record_function +from torch.nn.parallel.parallel_apply import parallel_apply +from torch.nn.parallel.replicate import replicate +from torch.nn.parallel.scatter_gather import gather, scatter +from torch.nn.parameter import Parameter +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.tensorboard import SummaryWriter + +# mixed-dimension trick +from tricks.md_embedding_bag import md_solver, PrEmbeddingBag + +# quotient-remainder trick +from tricks.qr_embedding_bag import QREmbeddingBag + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + try: + import onnx + except ImportError as error: + print("Unable to import onnx. ", error) + +# from torchviz import make_dot +# import torch.nn.functional as Functional +# from torch.nn.parameter import Parameter + +exc = getattr(builtins, "IOError", "FileNotFoundError") + + +def time_wrap(use_gpu): + if use_gpu: + torch.cuda.synchronize() + return time.time() + + +def dlrm_wrap(X, lS_o, lS_i, use_gpu, device, ndevices=1): + with record_function("DLRM forward"): + if use_gpu: # .cuda() + # lS_i can be either a list of tensors or a stacked tensor. 
+ # Handle each case below: + if ndevices == 1: + lS_i = ( + [S_i.to(device) for S_i in lS_i] + if isinstance(lS_i, list) + else lS_i.to(device) + ) + lS_o = ( + [S_o.to(device) for S_o in lS_o] + if isinstance(lS_o, list) + else lS_o.to(device) + ) + return dlrm(X.to(device), lS_o, lS_i) + + +def loss_fn_wrap(Z, T, use_gpu, device): + with record_function("DLRM loss compute"): + if args.loss_function == "mse" or args.loss_function == "bce": + return dlrm.loss_fn(Z, T.to(device)) + elif args.loss_function == "wbce": + loss_ws_ = dlrm.loss_ws[T.data.view(-1).long()].view_as(T).to(device) + loss_fn_ = dlrm.loss_fn(Z, T.to(device)) + loss_sc_ = loss_ws_ * loss_fn_ + return loss_sc_.mean() + + +# The following function is a wrapper to avoid checking this multiple times in the +# loop below. +def unpack_batch(b): + if args.data_generation == "internal": + return fbInputBatchFormatter(b, args.data_size) + else: + # Experiment with unweighted samples + return b[0], b[1], b[2], b[3], torch.ones(b[3].size()), None + + +class LRPolicyScheduler(_LRScheduler): + def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps): + self.num_warmup_steps = num_warmup_steps + self.decay_start_step = decay_start_step + self.decay_end_step = decay_start_step + num_decay_steps + self.num_decay_steps = num_decay_steps + + if self.decay_start_step < self.num_warmup_steps: + sys.exit("Learning rate warmup must finish before the decay starts") + + super(LRPolicyScheduler, self).__init__(optimizer) + + def get_lr(self): + step_count = self._step_count + if step_count < self.num_warmup_steps: + # warmup + scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps + lr = [base_lr * scale for base_lr in self.base_lrs] + self.last_lr = lr + elif self.decay_start_step <= step_count and step_count < self.decay_end_step: + # decay + decayed_steps = step_count - self.decay_start_step + scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2 +
min_lr = 0.0000001 + lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs] + self.last_lr = lr + else: + if self.num_decay_steps > 0: + # freeze at last, either because we're after decay + # or because we're between warmup and decay + lr = self.last_lr + else: + # do not adjust + lr = self.base_lrs + return lr + + +### define dlrm in PyTorch ### +class DLRM_Net(nn.Module): + def create_mlp(self, ln, sigmoid_layer): + # build MLP layer by layer + layers = nn.ModuleList() + for i in range(0, ln.size - 1): + n = ln[i] + m = ln[i + 1] + + # construct fully connected operator + LL = nn.Linear(int(n), int(m), bias=True) + + # initialize the weights + # with torch.no_grad(): + # custom Xavier input, output or two-sided fill + mean = 0.0 # std_dev = np.sqrt(variance) + std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n) + W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32) + std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1)) + bt = np.random.normal(mean, std_dev, size=m).astype(np.float32) + # approach 1 + LL.weight.data = torch.tensor(W, requires_grad=True) + LL.bias.data = torch.tensor(bt, requires_grad=True) + # approach 2 + # LL.weight.data.copy_(torch.tensor(W)) + # LL.bias.data.copy_(torch.tensor(bt)) + # approach 3 + # LL.weight = Parameter(torch.tensor(W),requires_grad=True) + # LL.bias = Parameter(torch.tensor(bt),requires_grad=True) + layers.append(LL) + + # construct sigmoid or relu operator + if i == sigmoid_layer: + layers.append(nn.Sigmoid()) + else: + layers.append(nn.ReLU()) + + # approach 1: use ModuleList + # return layers + # approach 2: use Sequential container to wrap all layers + return torch.nn.Sequential(*layers) + + def create_emb(self, m, ln, weighted_pooling=None): + emb_l = nn.ModuleList() + v_W_l = [] + for i in range(0, ln.size): + if ext_dist.my_size > 1: + if i not in self.local_emb_indices: + continue + n = ln[i] + + # construct embedding operator + if self.qr_flag and n > self.qr_threshold: + EE = 
QREmbeddingBag( + n, + m, + self.qr_collisions, + operation=self.qr_operation, + mode="sum", + sparse=True, + ) + elif self.md_flag and n > self.md_threshold: + base = max(m) + _m = m[i] if n > self.md_threshold else base + EE = PrEmbeddingBag(n, _m, base) + # use np initialization as below for consistency... + W = np.random.uniform( + low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, _m) + ).astype(np.float32) + EE.embs.weight.data = torch.tensor(W, requires_grad=True) + else: + EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True) + # initialize embeddings + # nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n)) + W = np.random.uniform( + low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) + ).astype(np.float32) + # approach 1 + EE.weight.data = torch.tensor(W, requires_grad=True) + # approach 2 + # EE.weight.data.copy_(torch.tensor(W)) + # approach 3 + # EE.weight = Parameter(torch.tensor(W),requires_grad=True) + if weighted_pooling is None: + v_W_l.append(None) + else: + v_W_l.append(torch.ones(n, dtype=torch.float32)) + emb_l.append(EE) + return emb_l, v_W_l + + def __init__( + self, + m_spa=None, + ln_emb=None, + ln_bot=None, + ln_top=None, + arch_interaction_op=None, + arch_interaction_itself=False, + sigmoid_bot=-1, + sigmoid_top=-1, + sync_dense_params=True, + loss_threshold=0.0, + ndevices=-1, + qr_flag=False, + qr_operation="mult", + qr_collisions=0, + qr_threshold=200, + md_flag=False, + md_threshold=200, + weighted_pooling=None, + loss_function="bce", + ): + super(DLRM_Net, self).__init__() + + if ( + (m_spa is not None) + and (ln_emb is not None) + and (ln_bot is not None) + and (ln_top is not None) + and (arch_interaction_op is not None) + ): + # save arguments + self.ndevices = ndevices + self.output_d = 0 + self.parallel_model_batch_size = -1 + self.parallel_model_is_not_prepared = True + self.arch_interaction_op = arch_interaction_op + self.arch_interaction_itself = arch_interaction_itself + self.sync_dense_params = sync_dense_params 
+ self.loss_threshold = loss_threshold + self.loss_function = loss_function + if weighted_pooling is not None and weighted_pooling != "fixed": + self.weighted_pooling = "learned" + else: + self.weighted_pooling = weighted_pooling + # create variables for QR embedding if applicable + self.qr_flag = qr_flag + if self.qr_flag: + self.qr_collisions = qr_collisions + self.qr_operation = qr_operation + self.qr_threshold = qr_threshold + # create variables for MD embedding if applicable + self.md_flag = md_flag + if self.md_flag: + self.md_threshold = md_threshold + + # If running distributed, get local slice of embedding tables + if ext_dist.my_size > 1: + n_emb = len(ln_emb) + if n_emb < ext_dist.my_size: + sys.exit( + "only (%d) sparse features for (%d) devices, table partitions will fail" + % (n_emb, ext_dist.my_size) + ) + self.n_global_emb = n_emb + self.n_local_emb, self.n_emb_per_rank = ext_dist.get_split_lengths( + n_emb + ) + self.local_emb_slice = ext_dist.get_my_slice(n_emb) + self.local_emb_indices = list(range(n_emb))[self.local_emb_slice] + + # create operators + if ndevices <= 1: + self.emb_l, w_list = self.create_emb(m_spa, ln_emb, weighted_pooling) + if self.weighted_pooling == "learned": + self.v_W_l = nn.ParameterList() + for w in w_list: + self.v_W_l.append(Parameter(w)) + else: + self.v_W_l = w_list + self.bot_l = self.create_mlp(ln_bot, sigmoid_bot) + self.top_l = self.create_mlp(ln_top, sigmoid_top) + + # quantization + self.quantize_emb = False + self.emb_l_q = [] + self.quantize_bits = 32 + + # specify the loss function + if self.loss_function == "mse": + self.loss_fn = torch.nn.MSELoss(reduction="mean") + elif self.loss_function == "bce": + self.loss_fn = torch.nn.BCELoss(reduction="mean") + elif self.loss_function == "wbce": + self.loss_ws = torch.tensor( + np.fromstring(args.loss_weights, dtype=float, sep="-") + ) + self.loss_fn = torch.nn.BCELoss(reduction="none") + else: + sys.exit( + "ERROR: --loss-function=" + self.loss_function + " is not 
supported" + ) + + def apply_mlp(self, x, layers): + # approach 1: use ModuleList + # for layer in layers: + # x = layer(x) + # return x + # approach 2: use Sequential container to wrap all layers + return layers(x) + + def apply_emb(self, lS_o, lS_i, emb_l, v_W_l): + # WARNING: notice that we are processing the batch at once. We implicitly + # assume that the data is laid out such that: + # 1. each embedding is indexed with a group of sparse indices, + # corresponding to a single lookup + # 2. for each embedding the lookups are further organized into a batch + # 3. for a list of embedding tables there is a list of batched lookups + + ly = [] + for k, sparse_index_group_batch in enumerate(lS_i): + sparse_offset_group_batch = lS_o[k] + + # embedding lookup + # We are using EmbeddingBag, which implicitly uses sum operator. + # The embeddings are represented as tall matrices, with sum + # happening vertically across 0 axis, resulting in a row vector + # E = emb_l[k] + + if v_W_l[k] is not None: + per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch) + else: + per_sample_weights = None + + if self.quantize_emb: + s1 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement() + s2 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement() + print("quantized emb sizes:", s1, s2) + + if self.quantize_bits == 4: + QV = ops.quantized.embedding_bag_4bit_rowwise_offsets( + self.emb_l_q[k], + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + elif self.quantize_bits == 8: + QV = ops.quantized.embedding_bag_byte_rowwise_offsets( + self.emb_l_q[k], + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + + ly.append(QV) + else: + E = emb_l[k] + V = E( + sparse_index_group_batch, + sparse_offset_group_batch, + per_sample_weights=per_sample_weights, + ) + + ly.append(V) + + # print(ly) + return ly + + # using quantizing functions from 
caffe2/aten/src/ATen/native/quantized/cpu + def quantize_embedding(self, bits): + n = len(self.emb_l) + self.emb_l_q = [None] * n + for k in range(n): + if bits == 4: + self.emb_l_q[k] = ops.quantized.embedding_bag_4bit_prepack( + self.emb_l[k].weight + ) + elif bits == 8: + self.emb_l_q[k] = ops.quantized.embedding_bag_byte_prepack( + self.emb_l[k].weight + ) + else: + return + self.emb_l = None + self.quantize_emb = True + self.quantize_bits = bits + + def interact_features(self, x, ly): + if self.arch_interaction_op == "dot": + # concatenate dense and sparse features + (batch_size, d) = x.shape + T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) + # perform a dot product + Z = torch.bmm(T, torch.transpose(T, 1, 2)) + # append dense feature with the interactions (into a row vector) + # approach 1: all + # Zflat = Z.view((batch_size, -1)) + # approach 2: unique + _, ni, nj = Z.shape + # approach 1: tril_indices + # offset = 0 if self.arch_interaction_itself else -1 + # li, lj = torch.tril_indices(ni, nj, offset=offset) + # approach 2: custom + offset = 1 if self.arch_interaction_itself else 0 + li = torch.tensor([i for i in range(ni) for j in range(i + offset)]) + lj = torch.tensor([j for i in range(nj) for j in range(i + offset)]) + Zflat = Z[:, li, lj] + # concatenate dense features and interactions + R = torch.cat([x] + [Zflat], dim=1) + elif self.arch_interaction_op == "cat": + # concatenation features (into a row vector) + R = torch.cat([x] + ly, dim=1) + else: + sys.exit( + "ERROR: --arch-interaction-op=" + + self.arch_interaction_op + + " is not supported" + ) + + return R + + def forward(self, dense_x, lS_o, lS_i): + if ext_dist.my_size > 1: + # multi-node multi-device run + return self.distributed_forward(dense_x, lS_o, lS_i) + elif self.ndevices <= 1: + # single device run + return self.sequential_forward(dense_x, lS_o, lS_i) + else: + # single-node multi-device run + return self.parallel_forward(dense_x, lS_o, lS_i) + + def 
distributed_forward(self, dense_x, lS_o, lS_i): + batch_size = dense_x.size()[0] + # WARNING: # of ranks must be <= batch size in distributed_forward call + if batch_size < ext_dist.my_size: + sys.exit( + "ERROR: batch_size (%d) must be larger than number of ranks (%d)" + % (batch_size, ext_dist.my_size) + ) + if batch_size % ext_dist.my_size != 0: + sys.exit( + "ERROR: batch_size %d can not split across %d ranks evenly" + % (batch_size, ext_dist.my_size) + ) + + dense_x = dense_x[ext_dist.get_my_slice(batch_size)] + lS_o = lS_o[self.local_emb_slice] + lS_i = lS_i[self.local_emb_slice] + + if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)): + sys.exit( + "ERROR: corrupted model input detected in distributed_forward call" + ) + + # embeddings + with record_function("DLRM embedding forward"): + ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) + + # WARNING: Note that at this point we have the result of the embedding lookup + # for the entire batch on each rank. We would like to obtain partial results + # corresponding to all embedding lookups, but part of the batch on each rank. + # Therefore, matching the distribution of output of bottom mlp, so that both + # could be used for subsequent interactions on each device. 
+ if len(self.emb_l) != len(ly): + sys.exit("ERROR: corrupted intermediate result in distributed_forward call") + + a2a_req = ext_dist.alltoall(ly, self.n_emb_per_rank) + + with record_function("DLRM bottom nlp forward"): + x = self.apply_mlp(dense_x, self.bot_l) + + ly = a2a_req.wait() + ly = list(ly) + + # interactions + with record_function("DLRM interaction forward"): + z = self.interact_features(x, ly) + + # top mlp + with record_function("DLRM top nlp forward"): + p = self.apply_mlp(z, self.top_l) + + # clamp output if needed + if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: + z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold)) + else: + z = p + + return z + + def sequential_forward(self, dense_x, lS_o, lS_i): + # process dense features (using bottom mlp), resulting in a row vector + x = self.apply_mlp(dense_x, self.bot_l) + # debug prints + # print("intermediate") + # print(x.detach().cpu().numpy()) + + # process sparse features (using embeddings), resulting in a list of row vectors + ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) + # for y in ly: + # print(y.detach().cpu().numpy()) + + # interact features (dense and sparse) + z = self.interact_features(x, ly) + # print(z.detach().cpu().numpy()) + + # obtain probability of a click (using top mlp) + p = self.apply_mlp(z, self.top_l) + + # clamp output if needed + if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: + z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold)) + else: + z = p + + return z + + def parallel_forward(self, dense_x, lS_o, lS_i): + ### prepare model (overwrite) ### + # WARNING: # of devices must be <= batch size in parallel_forward call + batch_size = dense_x.size()[0] + ndevices = min(self.ndevices, batch_size, len(self.emb_l)) + device_ids = range(ndevices) + # WARNING: must redistribute the model if mini-batch size changes (this is common + # for last mini-batch, when # of elements in the dataset/batch size is
not even + if self.parallel_model_batch_size != batch_size: + self.parallel_model_is_not_prepared = True + + if self.parallel_model_is_not_prepared or self.sync_dense_params: + # replicate mlp (data parallelism) + self.bot_l_replicas = replicate(self.bot_l, device_ids) + self.top_l_replicas = replicate(self.top_l, device_ids) + self.parallel_model_batch_size = batch_size + + if self.parallel_model_is_not_prepared: + # distribute embeddings (model parallelism) + t_list = [] + w_list = [] + for k, emb in enumerate(self.emb_l): + d = torch.device("cuda:" + str(k % ndevices)) + t_list.append(emb.to(d)) + if self.weighted_pooling == "learned": + w_list.append(Parameter(self.v_W_l[k].to(d))) + elif self.weighted_pooling == "fixed": + w_list.append(self.v_W_l[k].to(d)) + else: + w_list.append(None) + self.emb_l = nn.ModuleList(t_list) + if self.weighted_pooling == "learned": + self.v_W_l = nn.ParameterList(w_list) + else: + self.v_W_l = w_list + self.parallel_model_is_not_prepared = False + + ### prepare input (overwrite) ### + # scatter dense features (data parallelism) + # print(dense_x.device) + dense_x = scatter(dense_x, device_ids, dim=0) + # distribute sparse features (model parallelism) + if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)): + sys.exit("ERROR: corrupted model input detected in parallel_forward call") + + t_list = [] + i_list = [] + for k, _ in enumerate(self.emb_l): + d = torch.device("cuda:" + str(k % ndevices)) + t_list.append(lS_o[k].to(d)) + i_list.append(lS_i[k].to(d)) + lS_o = t_list + lS_i = i_list + + ### compute results in parallel ### + # bottom mlp + # WARNING: Note that the self.bot_l is a list of bottom mlp modules + # that have been replicated across devices, while dense_x is a tuple of dense + # inputs that has been scattered across devices on the first (batch) dimension. + # The output is a list of tensors scattered across devices according to the + # distribution of dense_x. 
+ x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids) + # debug prints + # print(x) + + # embeddings + ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) + # debug prints + # print(ly) + + # butterfly shuffle (implemented inefficiently for now) + # WARNING: Note that at this point we have the result of the embedding lookup + # for the entire batch on each device. We would like to obtain partial results + # corresponding to all embedding lookups, but part of the batch on each device. + # Therefore, matching the distribution of output of bottom mlp, so that both + # could be used for subsequent interactions on each device. + if len(self.emb_l) != len(ly): + sys.exit("ERROR: corrupted intermediate result in parallel_forward call") + + t_list = [] + for k, _ in enumerate(self.emb_l): + d = torch.device("cuda:" + str(k % ndevices)) + y = scatter(ly[k], device_ids, dim=0) + t_list.append(y) + # adjust the list to be ordered per device + ly = list(map(lambda y: list(y), zip(*t_list))) + # debug prints + # print(ly) + + # interactions + z = [] + for k in range(ndevices): + zk = self.interact_features(x[k], ly[k]) + z.append(zk) + # debug prints + # print(z) + + # top mlp + # WARNING: Note that the self.top_l is a list of top mlp modules that + # have been replicated across devices, while z is a list of interaction results + # that by construction are scattered across devices on the first (batch) dim. + # The output is a list of tensors scattered across devices according to the + # distribution of z. 
+ p = parallel_apply(self.top_l_replicas, z, None, device_ids) + + ### gather the distributed results ### + p0 = gather(p, self.output_d, dim=0) + + # clamp output if needed + if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: + z0 = torch.clamp( + p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold) + ) + else: + z0 = p0 + + return z0 + + +def dash_separated_ints(value): + vals = value.split("-") + for val in vals: + try: + int(val) + except ValueError: + raise argparse.ArgumentTypeError( + "%s is not a valid dash separated list of ints" % value + ) + + return value + + +def dash_separated_floats(value): + vals = value.split("-") + for val in vals: + try: + float(val) + except ValueError: + raise argparse.ArgumentTypeError( + "%s is not a valid dash separated list of floats" % value + ) + + return value + + +def inference( + args, + dlrm, + best_acc_test, + best_auc_test, + test_ld, + device, + use_gpu, + log_iter=-1, +): + test_accu = 0 + test_samp = 0 + + if args.mlperf_logging: + scores = [] + targets = [] + + for i, testBatch in enumerate(test_ld): + # early exit if nbatches was set by the user and was exceeded + if nbatches > 0 and i >= nbatches: + break + + X_test, lS_o_test, lS_i_test, T_test, W_test, CBPP_test = unpack_batch( + testBatch + ) + + # Skip the batch if batch size not multiple of total ranks + if ext_dist.my_size > 1 and X_test.size(0) % ext_dist.my_size != 0: + print("Warning: Skiping the batch %d with size %d" % (i, X_test.size(0))) + continue + + # forward pass + Z_test = dlrm_wrap( + X_test, + lS_o_test, + lS_i_test, + use_gpu, + device, + ndevices=ndevices, + ) + ### gather the distributed results on each rank ### + # For some reason it requires explicit sync before all_gather call if + # tensor is on GPU memory + if Z_test.is_cuda: + torch.cuda.synchronize() + (_, batch_split_lengths) = ext_dist.get_split_lengths(X_test.size(0)) + if ext_dist.my_size > 1: + Z_test = ext_dist.all_gather(Z_test, batch_split_lengths) + + if 
args.mlperf_logging: + S_test = Z_test.detach().cpu().numpy() # numpy array + T_test = T_test.detach().cpu().numpy() # numpy array + scores.append(S_test) + targets.append(T_test) + else: + with record_function("DLRM accuracy compute"): + # compute loss and accuracy + S_test = Z_test.detach().cpu().numpy() # numpy array + T_test = T_test.detach().cpu().numpy() # numpy array + + mbs_test = T_test.shape[0] # = mini_batch_size except last + A_test = np.sum((np.round(S_test, 0) == T_test).astype(np.uint8)) + + test_accu += A_test + test_samp += mbs_test + + if args.mlperf_logging: + with record_function("DLRM mlperf sklearn metrics compute"): + scores = np.concatenate(scores, axis=0) + targets = np.concatenate(targets, axis=0) + + metrics = { + "recall": lambda y_true, y_score: sklearn.metrics.recall_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "precision": lambda y_true, y_score: sklearn.metrics.precision_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "f1": lambda y_true, y_score: sklearn.metrics.f1_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "ap": sklearn.metrics.average_precision_score, + "roc_auc": sklearn.metrics.roc_auc_score, + "accuracy": lambda y_true, y_score: sklearn.metrics.accuracy_score( + y_true=y_true, y_pred=np.round(y_score) + ), + } + + validation_results = {} + for metric_name, metric_function in metrics.items(): + validation_results[metric_name] = metric_function(targets, scores) + writer.add_scalar( + "mlperf-metrics-test/" + metric_name, + validation_results[metric_name], + log_iter, + ) + acc_test = validation_results["accuracy"] + else: + acc_test = test_accu / test_samp + writer.add_scalar("Test/Acc", acc_test, log_iter) + + model_metrics_dict = { + "nepochs": args.nepochs, + "nbatches": nbatches, + "nbatches_test": nbatches_test, + "state_dict": dlrm.state_dict(), + "test_acc": acc_test, + } + + if args.mlperf_logging: + is_best = validation_results["roc_auc"] > best_auc_test + if is_best: + best_auc_test = 
validation_results["roc_auc"] + model_metrics_dict["test_auc"] = best_auc_test + print( + "recall {:.4f}, precision {:.4f},".format( + validation_results["recall"], + validation_results["precision"], + ) + + " f1 {:.4f}, ap {:.4f},".format( + validation_results["f1"], validation_results["ap"] + ) + + " auc {:.4f}, best auc {:.4f},".format( + validation_results["roc_auc"], best_auc_test + ) + + " accuracy {:3.3f} %, best accuracy {:3.3f} %".format( + validation_results["accuracy"] * 100, best_acc_test * 100 + ), + flush=True, + ) + else: + is_best = acc_test > best_acc_test + if is_best: + best_acc_test = acc_test + print( + " accuracy {:3.3f} %, best {:3.3f} %".format( + acc_test * 100, best_acc_test * 100 + ), + flush=True, + ) + return model_metrics_dict, is_best + + +def run(): + ### parse arguments ### + parser = argparse.ArgumentParser( + description="Train Deep Learning Recommendation Model (DLRM)" + ) + # model related parameters + parser.add_argument("--arch-sparse-feature-size", type=int, default=2) + parser.add_argument( + "--arch-embedding-size", type=dash_separated_ints, default="4-3-2" + ) + # j will be replaced with the table number + parser.add_argument("--arch-mlp-bot", type=dash_separated_ints, default="4-3-2") + parser.add_argument("--arch-mlp-top", type=dash_separated_ints, default="4-2-1") + parser.add_argument( + "--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot" + ) + parser.add_argument("--arch-interaction-itself", action="store_true", default=False) + parser.add_argument("--weighted-pooling", type=str, default=None) + # embedding table options + parser.add_argument("--md-flag", action="store_true", default=False) + parser.add_argument("--md-threshold", type=int, default=200) + parser.add_argument("--md-temperature", type=float, default=0.3) + parser.add_argument("--md-round-dims", action="store_true", default=False) + parser.add_argument("--qr-flag", action="store_true", default=False) + 
parser.add_argument("--qr-threshold", type=int, default=200) + parser.add_argument("--qr-operation", type=str, default="mult") + parser.add_argument("--qr-collisions", type=int, default=4) + # activations and loss + parser.add_argument("--activation-function", type=str, default="relu") + parser.add_argument("--loss-function", type=str, default="mse") # or bce or wbce + parser.add_argument( + "--loss-weights", type=dash_separated_floats, default="1.0-1.0" + ) # for wbce + parser.add_argument("--loss-threshold", type=float, default=0.0) # 1.0e-7 + parser.add_argument("--round-targets", type=bool, default=False) + # data + parser.add_argument("--data-size", type=int, default=1) + parser.add_argument("--num-batches", type=int, default=0) + parser.add_argument( + "--data-generation", + type=str, + choices=["random", "dataset", "internal"], + default="random", + ) # synthetic, dataset or internal + parser.add_argument( + "--rand-data-dist", type=str, default="uniform" + ) # uniform or gaussian + parser.add_argument("--rand-data-min", type=float, default=0) + parser.add_argument("--rand-data-max", type=float, default=1) + parser.add_argument("--rand-data-mu", type=float, default=-1) + parser.add_argument("--rand-data-sigma", type=float, default=1) + parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log") + parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte + parser.add_argument("--raw-data-file", type=str, default="") + parser.add_argument("--processed-data-file", type=str, default="") + parser.add_argument("--data-randomize", type=str, default="total") # or day or none + parser.add_argument("--data-trace-enable-padding", type=bool, default=False) + parser.add_argument("--max-ind-range", type=int, default=-1) + parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1] + parser.add_argument("--num-indices-per-lookup", type=int, default=10) + parser.add_argument("--num-indices-per-lookup-fixed", 
type=bool, default=False) + parser.add_argument("--num-workers", type=int, default=0) + parser.add_argument("--memory-map", action="store_true", default=False) + # training + parser.add_argument("--mini-batch-size", type=int, default=1) + parser.add_argument("--nepochs", type=int, default=1) + parser.add_argument("--learning-rate", type=float, default=0.01) + parser.add_argument("--print-precision", type=int, default=5) + parser.add_argument("--numpy-rand-seed", type=int, default=123) + parser.add_argument("--sync-dense-params", type=bool, default=True) + parser.add_argument("--optimizer", type=str, default="sgd") + parser.add_argument( + "--dataset-multiprocessing", + action="store_true", + default=False, + help="The Kaggle dataset can be multiprocessed in an environment \ + with more than 7 CPU cores and more than 20 GB of memory. \n \ + The Terabyte dataset can be multiprocessed in an environment \ + with more than 24 CPU cores and at least 1 TB of memory.", + ) + # inference + parser.add_argument("--inference-only", action="store_true", default=False) + # quantize + parser.add_argument("--quantize-mlp-with-bit", type=int, default=32) + parser.add_argument("--quantize-emb-with-bit", type=int, default=32) + # onnx + parser.add_argument("--save-onnx", action="store_true", default=False) + # gpu + parser.add_argument("--use-gpu", action="store_true", default=False) + # distributed + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dist-backend", type=str, default="") + # debugging and profiling + parser.add_argument("--print-freq", type=int, default=1) + parser.add_argument("--test-freq", type=int, default=-1) + parser.add_argument("--test-mini-batch-size", type=int, default=-1) + parser.add_argument("--test-num-workers", type=int, default=-1) + parser.add_argument("--print-time", action="store_true", default=False) + parser.add_argument("--print-wall-time", action="store_true", default=False) + parser.add_argument("--debug-mode", 
action="store_true", default=False) + parser.add_argument("--enable-profiling", action="store_true", default=False) + parser.add_argument("--plot-compute-graph", action="store_true", default=False) + parser.add_argument("--tensor-board-filename", type=str, default="run_kaggle_pt") + # store/load model + parser.add_argument("--save-model", type=str, default="") + parser.add_argument("--load-model", type=str, default="") + # mlperf logging (disables other output and stops early) + parser.add_argument("--mlperf-logging", action="store_true", default=False) + # stop at target accuracy Kaggle 0.789, Terabyte (sub-sampled=0.875) 0.8107 + parser.add_argument("--mlperf-acc-threshold", type=float, default=0.0) + # stop at target AUC Terabyte (no subsampling) 0.8025 + parser.add_argument("--mlperf-auc-threshold", type=float, default=0.0) + parser.add_argument("--mlperf-bin-loader", action="store_true", default=False) + parser.add_argument("--mlperf-bin-shuffle", action="store_true", default=False) + # mlperf gradient accumulation iterations + parser.add_argument("--mlperf-grad-accum-iter", type=int, default=1) + # LR policy + parser.add_argument("--lr-num-warmup-steps", type=int, default=0) + parser.add_argument("--lr-decay-start-step", type=int, default=0) + parser.add_argument("--lr-num-decay-steps", type=int, default=0) + + global args + global nbatches + global nbatches_test + global writer + args = parser.parse_args() + + if args.dataset_multiprocessing: + assert sys.version_info[0] >= 3 and sys.version_info[1] > 7, ( + "The dataset_multiprocessing " + + "flag is susceptible to a bug in Python 3.7 and under. 
" + + "https://github.com/facebookresearch/dlrm/issues/172" + ) + + if args.mlperf_logging: + mlperf_logger.log_event(key=mlperf_logger.constants.CACHE_CLEAR, value=True) + mlperf_logger.log_start( + key=mlperf_logger.constants.INIT_START, log_all_ranks=True + ) + + if args.weighted_pooling is not None: + if args.qr_flag: + sys.exit("ERROR: quotient remainder with weighted pooling is not supported") + if args.md_flag: + sys.exit("ERROR: mixed dimensions with weighted pooling is not supported") + if args.quantize_emb_with_bit in [4, 8]: + if args.qr_flag: + sys.exit( + "ERROR: 4 and 8-bit quantization with quotient remainder is not supported" + ) + if args.md_flag: + sys.exit( + "ERROR: 4 and 8-bit quantization with mixed dimensions is not supported" + ) + if args.use_gpu: + sys.exit("ERROR: 4 and 8-bit quantization on GPU is not supported") + + ### some basic setup ### + np.random.seed(args.numpy_rand_seed) + np.set_printoptions(precision=args.print_precision) + torch.set_printoptions(precision=args.print_precision) + torch.manual_seed(args.numpy_rand_seed) + + if args.test_mini_batch_size < 0: + # if the parameter is not set, use the training batch size + args.test_mini_batch_size = args.mini_batch_size + if args.test_num_workers < 0: + # if the parameter is not set, use the same parameter for training + args.test_num_workers = args.num_workers + + use_gpu = args.use_gpu and torch.cuda.is_available() + + if not args.debug_mode: + ext_dist.init_distributed( + local_rank=args.local_rank, use_gpu=use_gpu, backend=args.dist_backend + ) + + if use_gpu: + torch.cuda.manual_seed_all(args.numpy_rand_seed) + torch.backends.cudnn.deterministic = True + if ext_dist.my_size > 1: + ngpus = 1 + device = torch.device("cuda", ext_dist.my_local_rank) + else: + ngpus = torch.cuda.device_count() + device = torch.device("cuda", 0) + print("Using {} GPU(s)...".format(ngpus)) + else: + device = torch.device("cpu") + print("Using CPU...") + + ### prepare training data ### + ln_bot = 
np.fromstring(args.arch_mlp_bot, dtype=int, sep="-") + # input data + + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP) + mlperf_logger.barrier() + mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START) + mlperf_logger.barrier() + + if args.data_generation == "dataset": + train_data, train_ld, test_data, test_ld = dp.make_criteo_data_and_loaders(args) + table_feature_map = {idx: idx for idx in range(len(train_data.counts))} + nbatches = args.num_batches if args.num_batches > 0 else len(train_ld) + nbatches_test = len(test_ld) + + ln_emb = train_data.counts + # enforce maximum limit on number of vectors per embedding + if args.max_ind_range > 0: + ln_emb = np.array( + list( + map( + lambda x: x if x < args.max_ind_range else args.max_ind_range, + ln_emb, + ) + ) + ) + else: + ln_emb = np.array(ln_emb) + m_den = train_data.m_den + ln_bot[0] = m_den + elif args.data_generation == "internal": + if not has_internal_libs: + raise Exception("Internal libraries are not available.") + NUM_BATCHES = 5000 + nbatches = args.num_batches if args.num_batches > 0 else NUM_BATCHES + train_ld, feature_to_num_embeddings = fbDataLoader(args.data_size, nbatches) + ln_emb = np.array(list(feature_to_num_embeddings.values())) + m_den = ln_bot[0] + else: + # input and target at random + ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-") + m_den = ln_bot[0] + train_data, train_ld, test_data, test_ld = dp.make_random_data_and_loader( + args, ln_emb, m_den + ) + nbatches = args.num_batches if args.num_batches > 0 else len(train_ld) + nbatches_test = len(test_ld) + + args.ln_emb = ln_emb.tolist() + if args.mlperf_logging: + print("command line args: ", json.dumps(vars(args))) + + ### parse command line arguments ### + m_spa = args.arch_sparse_feature_size + ln_emb = np.asarray(ln_emb) + num_fea = ln_emb.size + 1 # num sparse + num dense features + + m_den_out = ln_bot[ln_bot.size - 1] + if 
args.arch_interaction_op == "dot": + # approach 1: all + # num_int = num_fea * num_fea + m_den_out + # approach 2: unique + if args.arch_interaction_itself: + num_int = (num_fea * (num_fea + 1)) // 2 + m_den_out + else: + num_int = (num_fea * (num_fea - 1)) // 2 + m_den_out + elif args.arch_interaction_op == "cat": + num_int = num_fea * m_den_out + else: + sys.exit( + "ERROR: --arch-interaction-op=" + + args.arch_interaction_op + + " is not supported" + ) + arch_mlp_top_adjusted = str(num_int) + "-" + args.arch_mlp_top + ln_top = np.fromstring(arch_mlp_top_adjusted, dtype=int, sep="-") + + # sanity check: feature sizes and mlp dimensions must match + if m_den != ln_bot[0]: + sys.exit( + "ERROR: arch-dense-feature-size " + + str(m_den) + + " does not match first dim of bottom mlp " + + str(ln_bot[0]) + ) + if args.qr_flag: + if args.qr_operation == "concat" and 2 * m_spa != m_den_out: + sys.exit( + "ERROR: 2 arch-sparse-feature-size " + + str(2 * m_spa) + + " does not match last dim of bottom mlp " + + str(m_den_out) + + " (note that the last dim of bottom mlp must be 2x the embedding dim)" + ) + if args.qr_operation != "concat" and m_spa != m_den_out: + sys.exit( + "ERROR: arch-sparse-feature-size " + + str(m_spa) + + " does not match last dim of bottom mlp " + + str(m_den_out) + ) + else: + if m_spa != m_den_out: + sys.exit( + "ERROR: arch-sparse-feature-size " + + str(m_spa) + + " does not match last dim of bottom mlp " + + str(m_den_out) + ) + if num_int != ln_top[0]: + sys.exit( + "ERROR: # of feature interactions " + + str(num_int) + + " does not match first dimension of top mlp " + + str(ln_top[0]) + ) + + # assign mixed dimensions if applicable + if args.md_flag: + m_spa = md_solver( + torch.tensor(ln_emb), + args.md_temperature, # alpha + d0=m_spa, + round_dim=args.md_round_dims, + ).tolist() + + # test prints (model arch) + if args.debug_mode: + print("model arch:") + print( + "mlp top arch " + + str(ln_top.size - 1) + + " layers, with input to output 
dimensions:" + ) + print(ln_top) + print("# of interactions") + print(num_int) + print( + "mlp bot arch " + + str(ln_bot.size - 1) + + " layers, with input to output dimensions:" + ) + print(ln_bot) + print("# of features (sparse and dense)") + print(num_fea) + print("dense feature size") + print(m_den) + print("sparse feature size") + print(m_spa) + print( + "# of embeddings (= # of sparse features) " + + str(ln_emb.size) + + ", with dimensions " + + str(m_spa) + + "x:" + ) + print(ln_emb) + + print("data (inputs and targets):") + for j, inputBatch in enumerate(train_ld): + X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch) + + torch.set_printoptions(precision=4) + # early exit if nbatches was set by the user and has been exceeded + if nbatches > 0 and j >= nbatches: + break + print("mini-batch: %d" % j) + print(X.detach().cpu()) + # transform offsets to lengths when printing + print( + torch.IntTensor( + [ + np.diff( + S_o.detach().cpu().tolist() + list(lS_i[i].shape) + ).tolist() + for i, S_o in enumerate(lS_o) + ] + ) + ) + print([S_i.detach().cpu() for S_i in lS_i]) + print(T.detach().cpu()) + + global ndevices + ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1 + + ### construct the neural network specified above ### + # WARNING: to obtain exactly the same initialization for + # the weights we need to start from the same random seed. 
+ # np.random.seed(args.numpy_rand_seed) + global dlrm + dlrm = DLRM_Net( + m_spa, + ln_emb, + ln_bot, + ln_top, + arch_interaction_op=args.arch_interaction_op, + arch_interaction_itself=args.arch_interaction_itself, + sigmoid_bot=-1, + sigmoid_top=ln_top.size - 2, + sync_dense_params=args.sync_dense_params, + loss_threshold=args.loss_threshold, + ndevices=ndevices, + qr_flag=args.qr_flag, + qr_operation=args.qr_operation, + qr_collisions=args.qr_collisions, + qr_threshold=args.qr_threshold, + md_flag=args.md_flag, + md_threshold=args.md_threshold, + weighted_pooling=args.weighted_pooling, + loss_function=args.loss_function, + ) + + # test prints + if args.debug_mode: + print("initial parameters (weights and bias):") + for param in dlrm.parameters(): + print(param.detach().cpu().numpy()) + # print(dlrm) + + if use_gpu: + # Custom Model-Data Parallel + # the mlps are replicated and use data parallelism, while + # the embeddings are distributed and use model parallelism + dlrm = dlrm.to(device) # .cuda() + if dlrm.ndevices > 1: + dlrm.emb_l, dlrm.v_W_l = dlrm.create_emb( + m_spa, ln_emb, args.weighted_pooling + ) + else: + if dlrm.weighted_pooling == "fixed": + for k, w in enumerate(dlrm.v_W_l): + dlrm.v_W_l[k] = w.cuda() + + # distribute data parallel mlps + if ext_dist.my_size > 1: + if use_gpu: + device_ids = [ext_dist.my_local_rank] + dlrm.bot_l = ext_dist.DDP(dlrm.bot_l, device_ids=device_ids) + dlrm.top_l = ext_dist.DDP(dlrm.top_l, device_ids=device_ids) + else: + dlrm.bot_l = ext_dist.DDP(dlrm.bot_l) + dlrm.top_l = ext_dist.DDP(dlrm.top_l) + + if not args.inference_only: + if use_gpu and args.optimizer in ["rwsadagrad", "adagrad"]: + sys.exit("GPU version of Adagrad is not supported by PyTorch.") + # specify the optimizer algorithm + opts = { + "sgd": torch.optim.SGD, + "rwsadagrad": RowWiseSparseAdagrad.RWSAdagrad, + "adagrad": torch.optim.Adagrad, + } + + parameters = ( + dlrm.parameters() + if ext_dist.my_size == 1 + else [ + { + "params": [p for emb in 
dlrm.emb_l for p in emb.parameters()], + "lr": args.learning_rate, + }, + # TODO check this lr setup + # bottom mlp has no data parallelism + # need to check how do we deal with top mlp + { + "params": dlrm.bot_l.parameters(), + "lr": args.learning_rate, + }, + { + "params": dlrm.top_l.parameters(), + "lr": args.learning_rate, + }, + ] + ) + optimizer = opts[args.optimizer](parameters, lr=args.learning_rate) + lr_scheduler = LRPolicyScheduler( + optimizer, + args.lr_num_warmup_steps, + args.lr_decay_start_step, + args.lr_num_decay_steps, + ) + + ### main loop ### + + # training or inference + best_acc_test = 0 + best_auc_test = 0 + skip_upto_epoch = 0 + skip_upto_batch = 0 + total_time = 0 + total_loss = 0 + total_iter = 0 + total_samp = 0 + + if args.mlperf_logging: + mlperf_logger.mlperf_submission_log("dlrm") + mlperf_logger.log_event( + key=mlperf_logger.constants.SEED, value=args.numpy_rand_seed + ) + mlperf_logger.log_event( + key=mlperf_logger.constants.GLOBAL_BATCH_SIZE, value=args.mini_batch_size + ) + + # Load model is specified + if not (args.load_model == ""): + print("Loading saved model {}".format(args.load_model)) + if use_gpu: + if dlrm.ndevices > 1: + # NOTE: when targeting inference on multiple GPUs, + # load the model as is on CPU or GPU, with the move + # to multiple GPUs to be done in parallel_forward + ld_model = torch.load(args.load_model) + else: + # NOTE: when targeting inference on single GPU, + # note that the call to .to(device) has already happened + ld_model = torch.load( + args.load_model, + map_location=torch.device("cuda"), + # map_location=lambda storage, loc: storage.cuda(0) + ) + else: + # when targeting inference on CPU + ld_model = torch.load(args.load_model, map_location=torch.device("cpu")) + dlrm.load_state_dict(ld_model["state_dict"]) + ld_j = ld_model["iter"] + ld_k = ld_model["epoch"] + ld_nepochs = ld_model["nepochs"] + ld_nbatches = ld_model["nbatches"] + ld_nbatches_test = ld_model["nbatches_test"] + ld_train_loss = 
ld_model["train_loss"] + ld_total_loss = ld_model["total_loss"] + if args.mlperf_logging: + ld_gAUC_test = ld_model["test_auc"] + ld_acc_test = ld_model["test_acc"] + if not args.inference_only: + optimizer.load_state_dict(ld_model["opt_state_dict"]) + best_acc_test = ld_acc_test + total_loss = ld_total_loss + skip_upto_epoch = ld_k # epochs + skip_upto_batch = ld_j # batches + else: + args.print_freq = ld_nbatches + args.test_freq = 0 + + print( + "Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format( + ld_k, ld_nepochs, ld_j, ld_nbatches, ld_nbatches_test + ) + ) + print( + "Training state: loss = {:.6f}".format( + ld_train_loss, + ) + ) + if args.mlperf_logging: + print( + "Testing state: accuracy = {:3.3f} %, auc = {:.3f}".format( + ld_acc_test * 100, ld_gAUC_test + ) + ) + else: + print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100)) + + if args.inference_only: + # Currently only dynamic quantization with INT8 and FP16 weights are + # supported for MLPs and INT4 and INT8 weights for EmbeddingBag + # post-training quantization during the inference. 
+ # By default we don't do the quantization: quantize_{mlp,emb}_with_bit == 32 (FP32) + assert args.quantize_mlp_with_bit in [ + 8, + 16, + 32, + ], "only support 8/16/32-bit but got {}".format(args.quantize_mlp_with_bit) + assert args.quantize_emb_with_bit in [ + 4, + 8, + 32, + ], "only support 4/8/32-bit but got {}".format(args.quantize_emb_with_bit) + if args.quantize_mlp_with_bit != 32: + if args.quantize_mlp_with_bit in [8]: + quantize_dtype = torch.qint8 + else: + quantize_dtype = torch.float16 + dlrm = torch.quantization.quantize_dynamic( + dlrm, {torch.nn.Linear}, quantize_dtype + ) + if args.quantize_emb_with_bit != 32: + dlrm.quantize_embedding(args.quantize_emb_with_bit) + # print(dlrm) + + print("time/loss/accuracy (if enabled):") + + if args.mlperf_logging: + # LR is logged twice for now because of a compliance checker bug + mlperf_logger.log_event( + key=mlperf_logger.constants.OPT_BASE_LR, value=args.learning_rate + ) + mlperf_logger.log_event( + key=mlperf_logger.constants.OPT_LR_WARMUP_STEPS, + value=args.lr_num_warmup_steps, + ) + + # use logging keys from the official HP table and not from the logging library + mlperf_logger.log_event( + key="sgd_opt_base_learning_rate", value=args.learning_rate + ) + mlperf_logger.log_event( + key="lr_decay_start_steps", value=args.lr_decay_start_step + ) + mlperf_logger.log_event( + key="sgd_opt_learning_rate_decay_steps", value=args.lr_num_decay_steps + ) + mlperf_logger.log_event(key="sgd_opt_learning_rate_decay_poly_power", value=2) + + tb_file = "./" + args.tensor_board_filename + writer = SummaryWriter(tb_file) + + ext_dist.barrier() + with torch.autograd.profiler.profile( + args.enable_profiling, use_cuda=use_gpu, record_shapes=True + ) as prof: + if not args.inference_only: + k = 0 + total_time_begin = 0 + while k < args.nepochs: + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_start( + key=mlperf_logger.constants.BLOCK_START, + metadata={ + 
mlperf_logger.constants.FIRST_EPOCH_NUM: (k + 1), + mlperf_logger.constants.EPOCH_COUNT: 1, + }, + ) + mlperf_logger.barrier() + mlperf_logger.log_start( + key=mlperf_logger.constants.EPOCH_START, + metadata={mlperf_logger.constants.EPOCH_NUM: (k + 1)}, + ) + + if k < skip_upto_epoch: + continue + + if args.mlperf_logging: + previous_iteration_time = None + + for j, inputBatch in enumerate(train_ld): + if j == 0 and args.save_onnx: + X_onnx, lS_o_onnx, lS_i_onnx, _, _, _ = unpack_batch(inputBatch) + + if j < skip_upto_batch: + continue + + X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch) + + if args.mlperf_logging: + current_time = time_wrap(use_gpu) + if previous_iteration_time: + iteration_time = current_time - previous_iteration_time + else: + iteration_time = 0 + previous_iteration_time = current_time + else: + t1 = time_wrap(use_gpu) + + # early exit if nbatches was set by the user and has been exceeded + if nbatches > 0 and j >= nbatches: + break + + # Skip the batch if batch size not multiple of total ranks + if ext_dist.my_size > 1 and X.size(0) % ext_dist.my_size != 0: + print( + "Warning: Skiping the batch %d with size %d" + % (j, X.size(0)) + ) + continue + + mbs = T.shape[0] # = args.mini_batch_size except maybe for last + + # forward pass + Z = dlrm_wrap( + X, + lS_o, + lS_i, + use_gpu, + device, + ndevices=ndevices, + ) + + if ext_dist.my_size > 1: + T = T[ext_dist.get_my_slice(mbs)] + W = W[ext_dist.get_my_slice(mbs)] + + # loss + E = loss_fn_wrap(Z, T, use_gpu, device) + + # compute loss and accuracy + L = E.detach().cpu().numpy() # numpy array + # training accuracy is not disabled + # S = Z.detach().cpu().numpy() # numpy array + # T = T.detach().cpu().numpy() # numpy array + + # # print("res: ", S) + + # # print("j, train: BCE ", j, L) + + # mbs = T.shape[0] # = args.mini_batch_size except maybe for last + # A = np.sum((np.round(S, 0) == T).astype(np.uint8)) + + with record_function("DLRM backward"): + # scaled error gradient propagation + # 
(where we do not accumulate gradients across mini-batches) + if ( + args.mlperf_logging + and (j + 1) % args.mlperf_grad_accum_iter == 0 + ) or not args.mlperf_logging: + optimizer.zero_grad() + # backward pass + E.backward() + + # optimizer + if ( + args.mlperf_logging + and (j + 1) % args.mlperf_grad_accum_iter == 0 + ) or not args.mlperf_logging: + optimizer.step() + lr_scheduler.step() + + if args.mlperf_logging: + total_time += iteration_time + else: + t2 = time_wrap(use_gpu) + total_time += t2 - t1 + + total_loss += L * mbs + total_iter += 1 + total_samp += mbs + + should_print = ((j + 1) % args.print_freq == 0) or ( + j + 1 == nbatches + ) + should_test = ( + (args.test_freq > 0) + and (args.data_generation in ["dataset", "random"]) + and (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches)) + ) + + # print time, loss and accuracy + if should_print or should_test: + gT = 1000.0 * total_time / total_iter if args.print_time else -1 + total_time = 0 + + train_loss = total_loss / total_samp + total_loss = 0 + + str_run_type = ( + "inference" if args.inference_only else "training" + ) + + wall_time = "" + if args.print_wall_time: + wall_time = " ({})".format(time.strftime("%H:%M")) + + print( + "Finished {} it {}/{} of epoch {}, {:.2f} ms/it,".format( + str_run_type, j + 1, nbatches, k, gT + ) + + " loss {:.6f}".format(train_loss) + + wall_time, + flush=True, + ) + + log_iter = nbatches * k + j + 1 + writer.add_scalar("Train/Loss", train_loss, log_iter) + + total_iter = 0 + total_samp = 0 + + # testing + if should_test: + epoch_num_float = (j + 1) / len(train_ld) + k + 1 + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_start( + key=mlperf_logger.constants.EVAL_START, + metadata={ + mlperf_logger.constants.EPOCH_NUM: epoch_num_float + }, + ) + + # don't measure training iter time in a test iteration + if args.mlperf_logging: + previous_iteration_time = None + print( + "Testing at - {}/{} of epoch {},".format(j + 1, nbatches, k) + ) + 
model_metrics_dict, is_best = inference( + args, + dlrm, + best_acc_test, + best_auc_test, + test_ld, + device, + use_gpu, + log_iter, + ) + + if ( + is_best + and not (args.save_model == "") + and not args.inference_only + ): + model_metrics_dict["epoch"] = k + model_metrics_dict["iter"] = j + 1 + model_metrics_dict["train_loss"] = train_loss + model_metrics_dict["total_loss"] = total_loss + model_metrics_dict["opt_state_dict"] = ( + optimizer.state_dict() + ) + print("Saving model to {}".format(args.save_model)) + torch.save(model_metrics_dict, args.save_model) + + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.EVAL_STOP, + metadata={ + mlperf_logger.constants.EPOCH_NUM: epoch_num_float + }, + ) + + # Uncomment the line below to print out the total time with overhead + # print("Total test time for this group: {}" \ + # .format(time_wrap(use_gpu) - accum_test_time_begin)) + + if ( + args.mlperf_logging + and (args.mlperf_acc_threshold > 0) + and (best_acc_test > args.mlperf_acc_threshold) + ): + print( + "MLPerf testing accuracy threshold " + + str(args.mlperf_acc_threshold) + + " reached, stop training" + ) + break + + if ( + args.mlperf_logging + and (args.mlperf_auc_threshold > 0) + and (best_auc_test > args.mlperf_auc_threshold) + ): + print( + "MLPerf testing auc threshold " + + str(args.mlperf_auc_threshold) + + " reached, stop training" + ) + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.RUN_STOP, + metadata={ + mlperf_logger.constants.STATUS: mlperf_logger.constants.SUCCESS + }, + ) + break + + if args.mlperf_logging: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.EPOCH_STOP, + metadata={mlperf_logger.constants.EPOCH_NUM: (k + 1)}, + ) + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.BLOCK_STOP, + metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: (k + 1)}, + ) + k += 1 # 
nepochs + if args.mlperf_logging and best_auc_test <= args.mlperf_auc_threshold: + mlperf_logger.barrier() + mlperf_logger.log_end( + key=mlperf_logger.constants.RUN_STOP, + metadata={ + mlperf_logger.constants.STATUS: mlperf_logger.constants.ABORTED + }, + ) + else: + print("Testing for inference only") + inference( + args, + dlrm, + best_acc_test, + best_auc_test, + test_ld, + device, + use_gpu, + ) + + # profiling + if args.enable_profiling: + time_stamp = str(datetime.datetime.now()).replace(" ", "_") + with open("dlrm_s_pytorch" + time_stamp + "_shape.prof", "w") as prof_f: + prof_f.write( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cpu_time_total" + ) + ) + with open("dlrm_s_pytorch" + time_stamp + "_total.prof", "w") as prof_f: + prof_f.write(prof.key_averages().table(sort_by="self_cpu_time_total")) + prof.export_chrome_trace("dlrm_s_pytorch" + time_stamp + ".json") + # print(prof.key_averages().table(sort_by="cpu_time_total")) + + # plot compute graph + if args.plot_compute_graph: + sys.exit( + "ERROR: Please install pytorchviz package in order to use the" + + " visualization. Then, uncomment its import above as well as" + + " three lines below and run the code again." 
+ ) + # V = Z.mean() if args.inference_only else E + # dot = make_dot(V, params=dict(dlrm.named_parameters())) + # dot.render('dlrm_s_pytorch_graph') # write .pdf file + + # test prints + if not args.inference_only and args.debug_mode: + print("updated parameters (weights and bias):") + for param in dlrm.parameters(): + print(param.detach().cpu().numpy()) + + # export the model in onnx + if args.save_onnx: + """ + # workaround 1: tensor -> list + if torch.is_tensor(lS_i_onnx): + lS_i_onnx = [lS_i_onnx[j] for j in range(len(lS_i_onnx))] + # workaound 2: list -> tensor + lS_i_onnx = torch.stack(lS_i_onnx) + """ + # debug prints + # print("inputs", X_onnx, lS_o_onnx, lS_i_onnx) + # print("output", dlrm_wrap(X_onnx, lS_o_onnx, lS_i_onnx, use_gpu, device)) + dlrm_pytorch_onnx_file = "dlrm_s_pytorch.onnx" + batch_size = X_onnx.shape[0] + print("X_onnx.shape", X_onnx.shape) + if torch.is_tensor(lS_o_onnx): + print("lS_o_onnx.shape", lS_o_onnx.shape) + else: + for oo in lS_o_onnx: + print("oo.shape", oo.shape) + if torch.is_tensor(lS_i_onnx): + print("lS_i_onnx.shape", lS_i_onnx.shape) + else: + for ii in lS_i_onnx: + print("ii.shape", ii.shape) + + # name inputs and outputs + o_inputs = ( + ["offsets"] + if torch.is_tensor(lS_o_onnx) + else ["offsets_" + str(i) for i in range(len(lS_o_onnx))] + ) + i_inputs = ( + ["indices"] + if torch.is_tensor(lS_i_onnx) + else ["indices_" + str(i) for i in range(len(lS_i_onnx))] + ) + all_inputs = ["dense_x"] + o_inputs + i_inputs + # debug prints + print("inputs", all_inputs) + + # create dynamic_axis dictionaries + do_inputs = ( + [{"offsets": {1: "batch_size"}}] + if torch.is_tensor(lS_o_onnx) + else [ + {"offsets_" + str(i): {0: "batch_size"}} for i in range(len(lS_o_onnx)) + ] + ) + di_inputs = ( + [{"indices": {1: "batch_size"}}] + if torch.is_tensor(lS_i_onnx) + else [ + {"indices_" + str(i): {0: "batch_size"}} for i in range(len(lS_i_onnx)) + ] + ) + dynamic_axes = {"dense_x": {0: "batch_size"}, "pred": {0: "batch_size"}} + for 
do in do_inputs: + dynamic_axes.update(do) + for di in di_inputs: + dynamic_axes.update(di) + # debug prints + print(dynamic_axes) + # export model + torch.onnx.export( + dlrm, + (X_onnx, lS_o_onnx, lS_i_onnx), + dlrm_pytorch_onnx_file, + verbose=True, + opset_version=11, + input_names=all_inputs, + output_names=["pred"], + dynamic_axes=dynamic_axes, + dynamo=False, + ) + # recover the model back + dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx") + # check the onnx model + onnx.checker.check_model(dlrm_pytorch_onnx) + total_time_end = time_wrap(use_gpu) + + +if __name__ == "__main__": + run() diff --git a/xpu_timer/experiments/dlrm/train_dlrm.sh b/xpu_timer/experiments/dlrm/train_dlrm.sh new file mode 100644 index 0000000000..43d7028036 --- /dev/null +++ b/xpu_timer/experiments/dlrm/train_dlrm.sh @@ -0,0 +1,17 @@ +if [[ "install" == $1 ]]; then + pip config set global.index-url https://pypi.antfin-inc.com/artifact/repositories/simple + pip install -e ../logging + pip install tensorboard + pip install --force-reinstall /prs/xpu_timer_whl/py_xpu_timer-1.1+cu124-cp311-cp311-linux_x86_64.whl +fi + +#export CUDA_VISIBLE_DEVICES=4,5,6,7 +export XPU_TIMER_DEBUG_MODE=1 +export XPU_TIMER_BASEPORT=28888 +export NCCL_DEBUG=WARN +export WORLD_SIZE=8 +export LOCAL_WORLD_SIZE=8 + +#xpu_timer_launch python dlrm_s_pytorch.py --mini-batch-size=16 --data-size=1000000 --use-gpu +#python -m torch.distributed.launch --nproc_per_node=8 dlrm_s_pytorch.py --mini-batch-size=16 --data-size=1000000 --use-gpu +xpu_timer_launch python -m torch.distributed.launch --nproc_per_node=8 dlrm_s_pytorch.py --arch-embedding-size="80000-80000-80000-80000-80000-80000-80000-80000" --arch-sparse-feature-size=128 --arch-mlp-bot="128-128-128-128" --arch-mlp-top="512-512-512-256-1" --max-ind-range=40000000 --data-generation=random --loss-function=bce --round-targets=True --learning-rate=1.0 --mini-batch-size=2048 --print-freq=2 --print-time --test-freq=2 --test-mini-batch-size=2048 --memory-map --use-gpu 
class BaseNcclSassParser:
    """Single-step a CUDA thread under cuda-gdb and look for a repeating PC loop.

    Subclasses supply ``get_step_reg(loop)``, which inspects the SASS of the
    detected loop and returns the value of the register being polled (or
    ``None`` when no such register is recognised).

    Attributes:
        pc_addresses: list of ``(pc, asm)`` tuples recorded while stepping;
            replaced by the detected loop once one is found.
        register: register name -> integer value, captured when a loop is found.
    """

    # Number of lines of "i r" output that are read (NCCL max register count).
    MAX_GENERAL_REGISTERS = 96
    # Number of lines of "i r system" output that are read (predicate registers).
    MAX_PREDICATE_REGISTERS = 8
    # Give up looking for a loop after this many single steps.
    MAX_STEPS = 1000

    def __init__(self):
        self.pc_addresses = []
        # Always present so callers can read it even when no loop was found.
        self.register = {}

    def detect_cycle(self):
        """Return the slice of ``pc_addresses`` forming a loop, or ``None``.

        Walks two indices over the recorded PCs, one advancing by 1 and one by
        2, and reports ``pc_addresses[slow:fast]`` as soon as both point at the
        same PC.
        """
        tortoise_idx = 0
        hare_idx = 0
        total = len(self.pc_addresses)
        while hare_idx < total - 1:
            tortoise_idx += 1
            hare_idx += 2
            if hare_idx < total and self.pc_addresses[tortoise_idx][0] == self.pc_addresses[hare_idx][0]:
                return self.pc_addresses[tortoise_idx:hare_idx]
        return None

    @staticmethod
    def _parse_register_dump(dump, limit):
        """Parse ``name hex int`` lines from a gdb register dump.

        Only the first ``limit`` lines are considered. Blank lines, lines with
        fewer than three columns, and lines whose third column is not an
        integer are skipped instead of raising.
        """
        registers = {}
        for index, line in enumerate(dump.split("\n")):
            if index == limit:
                break
            fields = line.split()
            if len(fields) < 3:
                continue
            try:
                registers[fields[0]] = int(fields[2])
            except ValueError:
                continue
        return registers

    def get_register_value(self):
        """Snapshot general and system registers of the current thread into ``self.register``."""
        # Example line: "UR54 0x0 0"
        registers = self._parse_register_dump(
            gdb.execute("i r", to_string=True), self.MAX_GENERAL_REGISTERS
        )
        registers.update(
            self._parse_register_dump(
                gdb.execute("i r system", to_string=True), self.MAX_PREDICATE_REGISTERS
            )
        )
        self.register = registers

    def start_detection(self, window_size, thread_state):
        """Step the first thread of ``thread_state`` until a PC loop shows up.

        Args:
            window_size: run loop detection every ``window_size`` steps.
            thread_state: dict with ``"block_range"`` and ``"thread_range"``
                strings of the form ``"start-end"``.

        Returns:
            Whatever ``get_step_reg`` returns for the detected loop, or ``0``
            when no loop is found, the subclass finds no register, or gdb
            raises an error.
        """
        block = thread_state["block_range"].split("-")[0]
        thread = thread_state["thread_range"].split("-")[0]
        gdb.execute(f"cuda block {block}")
        gdb.execute(f"cuda thread {thread}")
        time.sleep(1)
        self.pc_addresses = []
        self.register = {}
        # Kept because the original code exposed this attribute.
        self.register_value = {}
        window_size = max(1, window_size)  # avoid modulo by zero
        loop_count = 0
        step = 0
        try:
            while True:
                # nccl is __forceinline__, ni almost equals si
                # FIXME(zhangji.zhang) running start_detection will hang in next time
                gdb.execute("ni")

                frame = gdb.selected_frame()
                pc_addr = int(frame.pc())
                insns = frame.architecture().disassemble(pc_addr, pc_addr + 16)
                self.pc_addresses.append((pc_addr, insns[0]["asm"]))

                loop_count += 1
                if loop_count > self.MAX_STEPS:
                    # not found
                    break
                if loop_count % window_size == 0:
                    loop = self.detect_cycle()
                    if loop:
                        self.get_register_value()
                        self.pc_addresses = loop
                        found = self.get_step_reg(loop)
                        # Subclasses fall through with None when nothing matches.
                        return step if found is None else found
        except gdb.error as e:
            print(f"GDB Error: {e}")
        except KeyboardInterrupt:
            print("Loop detection interrupted.")
        return step
class FindNcclHang(gdb.Command):
    """cuda-gdb command ``find_nccl_hang``.

    Lists the CUDA thread ranges of the attached process, runs loop detection
    on the ranges that can be single-stepped, and pickles the resulting list
    of dicts to a FIFO or to a ``<rank>-<world_size>.nccl.status`` file.
    """

    # Mangled C++ symbol as printed by "x/i", e.g. _Z33ncclDevFunc_AllGather_RING_SIMPLEv
    _MANGLED_SYMBOL = re.compile(r"_Z[_\w\d]+")
    # Upper bound on chained BRA targets followed by can_run_instruction; a
    # branch that targets another branch would otherwise recurse forever.
    _MAX_BRANCH_DEPTH = 16

    def __init__(self):
        super().__init__("find_nccl_hang", gdb.COMMAND_USER)
        self.nccl_proto_parser = {
            "SIMPLE": SimpleParser(),
            "LL": LLParser(),
            "LL128": LL128Parser(),
        }
        self.arg_parser = get_argparser()

    def parse_arg(self, arg_str):
        """Split the gdb command argument string shell-style and parse the known options."""
        args, _unknown = self.arg_parser.parse_known_args(shlex.split(arg_str))
        return args

    def dump_value_to_disk_of_fifo(self, value):
        """Pickle ``value`` to ``--dump-path``.

        With ``--fifo`` the path is treated as a pipe: wait until it exists,
        then write the pickle bytes. Otherwise the path is a directory and the
        pickle goes to ``<rank>-<world_size>.nccl.status`` inside it.
        """
        dump_path = self.args.dump_path
        if self.args.fifo:
            while not os.path.exists(dump_path):
                print(f"waiting for {dump_path}", file=sys.stdout)
                time.sleep(1)
            with open(dump_path, "wb") as fifo:
                fifo.write(pickle.dumps(value))
            return
        Path(dump_path).mkdir(parents=True, exist_ok=True)
        dump_name = f"{dump_path}/{str(self.args.rank).zfill(5)}-{str(self.args.world_size).zfill(5)}.nccl.status"
        with open(dump_name, "wb") as f:
            pickle.dump(value, f)

    def find_nccl_proto(self):
        """Return "SIMPLE", "LL128" or "LL" from the first backtrace line mentioning nccl.

        Raises:
            ValueError: no backtrace line names a protocol.
        """
        bt = gdb.execute("bt", to_string=True)
        for line in bt.split("\n"):
            if "nccl" not in line:
                continue
            # "LL128" must be tested before "LL", which is a substring of it.
            for proto in ("SIMPLE", "LL128", "LL"):
                if proto in line:
                    return proto
        raise ValueError("No proto found")

    def invoke(self, arg, from_tty):
        """Entry point called by gdb for ``find_nccl_hang <options>``."""
        try:
            self.args = self.parse_arg(arg)
            nccl_parser = self.nccl_proto_parser[self.find_nccl_proto()]
            all_threads_state = self.list_cuda_threads()
            for state in all_threads_state:
                if not state["has_next"]:
                    continue
                hang_step = nccl_parser.start_detection(self.args.hang_window_size, state)
                # Parsers return None when no step register is recognised.
                state["hang_step"] = 0 if hang_step is None else int(hang_step)
                state["hang_sass"] = [asm for _, asm in nccl_parser.pc_addresses]
                # The parser may not have captured registers if no loop was found.
                state["registers"] = dict(getattr(nccl_parser, "register", {}))
            self.dump_value_to_disk_of_fifo(all_threads_state)
        except Exception:
            # The original handler wrote to self.result, which this class never
            # defines; report the traceback on stderr instead.
            traceback.print_exc(file=sys.stderr)
        # NOTE(review): the flattened source does not show whether these were
        # inside the except branch; emitted unconditionally here, matching the
        # sibling "sbt" command - confirm against the reader in dump_driver.
        print("EOF", file=sys.stdout)
        print("EOF", file=sys.stderr)

    @staticmethod
    def _sass_of(disassembly):
        """Return the instruction text after the last ':' of an ``x/i`` line, without '{'."""
        return disassembly.split(":")[-1].replace("{", "").strip()

    def can_run_instruction(self, sass, block, thread, _depth=0):
        """Return False when stepping ``sass`` should not be attempted.

        WARPSYNC/BSYNC instructions are rejected. For BRA the jump target is
        disassembled and checked recursively (at most ``_MAX_BRANCH_DEPTH``
        hops). ``_depth`` is internal; callers should not pass it.
        """
        # refers https://arxiv.org/html/2407.02944v1
        gdb.execute(f"cuda block {block}")
        gdb.execute(f"cuda thread {thread}")
        if "WARPSYNC" in sass or "BSYNC" in sass:
            return False
        if "BRA" not in sass:
            return True
        if _depth >= self._MAX_BRANCH_DEPTH:
            # for safe, we do not run next instruction
            return False
        address = sass.split()[-1]
        # => 0x7f349f290630 <_Z33ncclDevFunc_AllGather_RING_SIMPLEv+39984>: @P0 BRA 0x9cb0
        # when nccl is hanging, we assuming that all condition of all predicate registers are true
        # so we just find the jump destination sass code
        current = gdb.execute("x/i $pc", to_string=True)
        match = self._MANGLED_SYMBOL.search(current)
        if not match:
            # no match, for safe, we do not run next instruction
            return False
        target = gdb.execute(f"x/i {match.group()}+{address}", to_string=True)
        return self.can_run_instruction(self._sass_of(target), block, thread, _depth + 1)

    def list_cuda_threads(self):
        """Return one dict per row of ``info cuda threads``.

        Each dict has ``block_range``, ``thread_range``, ``stacks`` (backtrace
        lines), ``sass`` (instruction at $pc), ``has_next`` and ``hang_step``.
        """
        import ast  # local import: only this method needs it

        thread_info = gdb.execute("info cuda threads", to_string=True)
        states = []
        # The first two lines are skipped, as in the original (header rows).
        for line in thread_info.splitlines()[2:]:
            parts = line.split()
            if parts and parts[0] == "*":
                parts = parts[1:]
            if len(parts) < 4:
                # blank or truncated row
                continue
            # Columns are "(x,y,z)" tuples; literal_eval parses them without
            # executing arbitrary debugger output.
            start_block, start_thread, end_block, end_thread = (ast.literal_eval(p) for p in parts[:4])
            # nccl kernel are 1 d
            start_block_x, start_thread_x = start_block[0], start_thread[0]
            end_block_x, end_thread_x = end_block[0], end_thread[0]
            gdb.execute(f"cuda block {start_block_x}")
            gdb.execute(f"cuda thread {start_thread_x}")
            sass_code = self._sass_of(gdb.execute("x/i $pc", to_string=True))
            # Example backtrace output:
            # #0 0x00000100007e0c68 in ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t () at .../prims_simple.h:54 in _ZN10Primitives... inlined from prims_simple.h:277
            # #1 0x0000010002ba5e18 in ncclKernel_AllGather_RING_LL_Sum_int8_t<<<(2,1,1),(288,1,1)>>> () at .../common.h:84 in _Z10ncclKernel... inlined from all_gather_sum_i8.cu:11
            bt = [i for i in gdb.execute("backtrace", to_string=True).split("\n") if i]
            states.append(
                {
                    "block_range": f"{start_block_x}-{end_block_x}",
                    "thread_range": f"{start_thread_x}-{end_thread_x}",
                    "stacks": bt,
                    "sass": sass_code,
                    "has_next": self.can_run_instruction(sass_code, start_block_x, start_thread_x),
                    "hang_step": 0,
                }
            )
        return states
class HashableNamespace(argparse.Namespace):
    """An ``argparse.Namespace`` that can be hashed.

    The hash is computed from the attribute (name, value) pairs ordered by
    name; equality against another ``HashableNamespace`` compares the
    attribute dictionaries.
    """

    def __eq__(self, other):
        if isinstance(other, HashableNamespace):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __hash__(self):
        ordered_items = sorted(self.__dict__.items())
        return hash(tuple(ordered_items))
class CudaStacktrace:
    """Parse cuda-gdb backtrace lines into ``device_stacktrace`` protobuf entries.

    ``cuda_stacks`` is a list of dicts, one per CUDA thread range, e.g.::

        {"block": "1-1",
         "thread": "32-255",
         "sass": "...",
         "stacks": ["#0 0x... in ncclFunction_... () at .../prims_simple.h:228 in _ZN10Primitives... inlined from prims_simple.h:595",
                    "#1 0x... in ncclKernel_...<<<(2,1,1),(288,1,1)>>> () at .../common.h:84 in _Z10ncclKernel... inlined from all_gather_sum_i8.cu:11"]}

    The keys ``block_range`` / ``thread_range`` are accepted as alternatives
    to ``block`` / ``thread``.
    """

    cuda_stack_pattern = re.compile(
        r"#(\d+)\s+(0x[\d\w]+\sin\s)?([\/?\d\w_]+)(<<<.*>>>)?\s(\(.*\))(\sat\s([\/\w\d\._-]+:\d+)\sin\s([\w\d_]+)\s(inlined\s)?from\s([\/?\w\d\.]+:\d+))?"
    )

    def __init__(self, cuda_stacks: List[Dict[str, Union[str, List[str]]]], pb_message):
        libstdcxx = ctypes.CDLL("libstdc++.so.6")
        self.cxa_demangle = getattr(libstdcxx, "__cxa_demangle")
        # Return a raw pointer rather than c_char_p so the malloc'ed result
        # can be released after it has been copied.
        self.cxa_demangle.restype = ctypes.c_void_p
        self.cxa_demangle.argtypes = [
            ctypes.c_char_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.POINTER(ctypes.c_int),
        ]
        self._free = libstdcxx.free
        self._free.argtypes = [ctypes.c_void_p]
        self._free.restype = None
        self.result = pb_message
        for stacks in cuda_stacks:
            device_stacktrace = self.result.device_stacktrace.add()
            self._parse_each_thread(stacks, device_stacktrace)

    def _demangle(self, symbol):
        """Return the demangled form of ``symbol``, or ``None`` if it cannot be demangled."""
        if not symbol:
            return None
        status = ctypes.c_int(0)
        buf = self.cxa_demangle(symbol.encode("utf-8"), None, None, ctypes.byref(status))
        if status.value != 0 or not buf:
            return None
        try:
            return ctypes.string_at(buf).decode("utf-8")
        finally:
            # With a NULL output buffer __cxa_demangle allocates the result with malloc.
            self._free(buf)

    @classmethod
    def _parse_stack_line(cls, stack):
        """Split one backtrace line into its named parts, or return ``None`` on no match.

        The " at ... in ... from ..." tail is optional in the pattern, so
        ``curr_location``, ``mangled_symbol`` and ``inlined_location`` may be
        ``None``; ``kernel_args`` is ``None`` when there is no ``<<<...>>>``.
        """
        match = cls.cuda_stack_pattern.search(stack)
        if match is None:
            return None
        # The pattern has ten groups; group 6 is the wrapper around the
        # optional location tail and is not used on its own.
        (
            _frame_id,
            _addr,
            func_name,
            kernel_args,
            _func_args,
            _location_tail,
            curr_location,
            mangled_symbol,
            _inline,
            inlined_location,
        ) = match.groups()
        return {
            "func_name": func_name,
            "kernel_args": kernel_args,
            "curr_location": curr_location,
            "mangled_symbol": mangled_symbol,
            "inlined_location": inlined_location,
        }

    def _parse_each_thread(self, stack_dict, device_stacktrace):
        """Add one frame message per backtrace line of ``stack_dict`` to ``device_stacktrace``."""
        block = stack_dict.get("block", stack_dict.get("block_range", ""))
        thread = stack_dict.get("thread", stack_dict.get("thread_range", ""))
        sass = stack_dict.get("sass", "")

        for stack in stack_dict["stacks"]:
            frame = device_stacktrace.devices_frames.add()
            parsed = self._parse_stack_line(stack)
            if parsed is None:
                frame.stderr = "PARSING_REGEX_ERROR"
                frame.origin = stack
                continue
            cuda_frame = frame.cuda_frame
            cuda_frame.device_func = parsed["func_name"]
            cuda_frame.block = block
            cuda_frame.thread = thread
            cuda_frame.sass = sass
            # Optional parts are only assigned when present.
            # NOTE(review): presumably these are protobuf string fields, which
            # reject None - confirm against hosting_service.proto.
            if parsed["curr_location"] is not None:
                cuda_frame.curr_location = parsed["curr_location"]
            if parsed["inlined_location"] is not None:
                cuda_frame.inlined_location = parsed["inlined_location"]
            if parsed["kernel_args"] is not None:
                cuda_frame.kernel_args = parsed["kernel_args"]
            # Fall back to the mangled name when demangling fails.
            curr_symbol = self._demangle(parsed["mangled_symbol"]) or parsed["mangled_symbol"]
            if curr_symbol is not None:
                cuda_frame.curr_symbol = curr_symbol
def _dump_pstack(self):
    """Run pstack on the target pid and add one parsed stacktrace per thread to ``self.result``.

    Returns:
        0 when pstack is unavailable (best effort, as before) or ran
        successfully; otherwise the non-zero exit code of pstack. A parsing
        failure is recorded in ``result.pstack_stdout`` / ``pstack_stderr``
        and still returns 0.
    """
    if shutil.which(self.pstack_bin) is None:
        print(f"{self.pstack_bin} not found in path", file=sys.stderr)
        return 0
    proc = subprocess.Popen(
        [self.pstack_bin, self.pid],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    stdout, stderr = proc.communicate()
    if proc.returncode != 0:
        print(stderr, file=sys.stderr)
        return proc.returncode
    try:
        threads = []
        for raw_line in stdout.split("\n"):
            line = raw_line.strip()
            if raw_line.startswith("Thread"):
                threads.append(PstackStacktrace(line, self.result))
                continue
            if not line:
                continue
            if not threads:
                # Frames seen before any "Thread ..." header (e.g. output for
                # a single-threaded process) go to an unnamed thread instead
                # of raising IndexError.
                threads.append(PstackStacktrace("", self.result))
            threads[-1].frames.append(line)
        for t in threads:
            t.parse()
    except Exception as e:
        # Keep the raw output and the traceback so the dump is still usable.
        self.result.pstack_stdout = stdout
        tb_exception = traceback.TracebackException.from_exception(e)
        self.result.pstack_stderr = "".join(tb_exception.format())
    return 0
f"find_nccl_hang --dump-path {self.fifo_path} --fifo --rank {self.rank} --world-size {self.world_size}", + "-ex", + "detach", + "-ex", + "quit", + ] + + env = os.environ.copy() + env.pop("LD_PRELOAD", "") + env["LD_LIBRARY_PATH"] = f"{sysconfig.get_config_var('LIBDIR')}:{env.get('LD_LIBRARY_PATH', '')}" + proc = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + env=env, + ) + + fd = os.open(self.fifo_path, os.O_RDONLY | os.O_NONBLOCK) + fifo = os.fdopen(fd, "rb") + self._non_blocking_fd(proc.stdout.fileno()) + self._non_blocking_fd(proc.stderr.fileno()) + + sel = selectors.DefaultSelector() + sel.register(proc.stdout, selectors.EVENT_READ, data=(proc, sel)) + sel.register(proc.stderr, selectors.EVENT_READ, data=(proc, sel)) + sel.register(fifo, selectors.EVENT_READ, data=(proc, sel)) + + def safe_unregister(selector, fileobj): + try: + selector.unregister(fileobj) + except Exception: + pass + + def parse_to_proto(message, data): + for each in data: + device_status_message = message.device_status.add() + device_status_message.block_range = each["block_range"] + device_status_message.thread_range = each["thread_range"] + device_status_message.sass = each["sass"] + device_status_message.has_next = each["has_next"] + device_status_message.hang_step = each["hang_step"] + if "hang_sass" in each: + device_status_message.hang_sass.extend(each["hang_sass"]) + device_status_message.stack_trace.extend(each["stacks"]) + if "registers" in each: + for key, value in each["registers"].items(): + device_status_message.registers[key] = value + + has_err = False + err = [] + while sel.get_map(): + events = sel.select(timeout=None) + for key, mask in events: + if mask & selectors.EVENT_READ == 0: + # not read event + continue + if key.fileobj is fifo: + # key is named pipe + parse_to_proto(self.result, pickle.loads(key.fileobj.read())) + sel.unregister(key.fileobj) + key.fileobj.close() + continue + # others, stdout and 
stderr + for line in key.fileobj.readlines(): + line = line.strip() + if key.fileobj is proc.stderr: + err.append(line) + if line == "EOF": + sel.unregister(key.fileobj) + continue + if proc.poll() is not None: + safe_unregister(sel, proc.stderr) + safe_unregister(sel, proc.stdout) + safe_unregister(sel, fifo) + fifo.close() + has_err = proc.returncode != 0 + time.sleep(1) + if has_err: + print("\n".join(err), file=sys.stderr) + return 1 + return 0 + + def _dump_pyspy(self): + if shutil.which(self.pyspy_bin) is None: + print(f"{self.pyspy_bin} in path", file=sys.stderr) + return 1 + command = [self.pyspy_bin, "dump", "-j", "-p", self.pid] + proc = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + ) + stdout, stderr = proc.communicate() + if proc.returncode != 0: + print(stderr, file=sys.stderr) + return proc.returncode + try: + stacktrace = json.loads(stdout) + for thread in stacktrace: + thread_trace = self.result.py_stacktrace.add() + thread_trace.pid = thread["pid"] or 0 + thread_trace.owns_gil = thread["owns_gil"] or False + thread_trace.thread_name = thread["thread_name"] or "unknown" + thread_trace.os_thread_id = thread["os_thread_id"] or 0 + thread_trace.thread_id = thread["thread_id"] or 0 + thread_trace.active = thread["active"] or False + for f in thread["frames"]: + frame = thread_trace.frames.add() + frame.func_name = f["name"] or "unknown" + frame.file_name = f"{f['filename']}:{f['line']}" + frame.module = f["module"] or "unknown" + except Exception as e: + self.result.pyspy_stdout = stdout + tb_exception = traceback.TracebackException.from_exception(e) + errs = [] + for line in tb_exception.format(): + errs.append(line) + self.result.pyspy_stderr = "".join(errs) + + # print human readable stack + command = [self.pyspy_bin, "dump", "-p", self.pid] + proc = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + ) + stdout, stderr = 
def parse_one_pid_file(pid_path):
    """Map the PID printed in ``<pid_path>/sched`` to the directory name.

    The first line of the sched file has the form
    ``<comm> (<pid>, #threads: <n>)``. The PID is taken from the trailing
    parenthesised part, so digits inside the process name (``python3.8``)
    are not mistaken for it.

    Args:
        pid_path: ``pathlib.Path`` of a ``/proc/<pid>`` directory.

    Returns:
        ``{pid_from_sched: pid_path.name}`` (both strings), or ``{}`` when
        the file is missing, unreadable, or its first line has another shape.
    """
    try:
        with open(pid_path / "sched") as f:
            first_line = f.readline().strip()
    except OSError:
        # No sched file, or the process exited between listing and reading.
        return {}
    match = re.search(r"\((\d+), #threads: \d+\)$", first_line)
    if match is None:
        return {}
    return {match.group(1): pid_path.name}
def run_by_pid(args_tuple):
    """Dump stack traces for the single process described by ``args_tuple``.

    Args:
        args_tuple: one-element tuple holding the parsed argument namespace.
            ``rank`` / ``world_size`` equal to ``-1`` are filled in from the
            target's ``RANK`` / ``WORLD_SIZE`` environment variables; the
            namespace is modified in place and gains a ``state`` attribute.

    Returns:
        ``None``. Returns early, after printing a message, when the process
        does not exist, has no RANK/WORLD_SIZE, or its state cannot be read.
    """
    (args,) = args_tuple

    if not Path(f"/proc/{args.pid}/environ").exists():
        print("The process is not found...", file=sys.stderr)
        return
    envs = parse_env(args)
    if "RANK" not in envs or "WORLD_SIZE" not in envs:
        print("The RANK or WORLD_SIZE is not set, exit...")
        return
    if args.rank == -1:
        args.rank = int(envs["RANK"])
    if args.world_size == -1:
        args.world_size = int(envs["WORLD_SIZE"])

    state = get_process_state(args.pid)
    if state is None:
        # StacktraceDriver calls state.startswith(...), so None must not reach it.
        print(f"Can not read the state of process {args.pid}, exit...", file=sys.stderr)
        return
    args.state = state
    StacktraceDriver(args).dump()
def parse_host_ranks(host_list, rank_list, port, dry_run=False):
    """Expand (host, rank-spec) pairs into per-rank endpoints.

    A rank spec is either empty (``host:port``), a single rank
    (``host-<rank>:port``) or an inclusive range ``"a-b"`` (one
    ``host-<r>:port`` per rank). With ``dry_run`` the bare ``host:port``
    strings are returned; otherwise each is wrapped into the
    ``DumpKernelTrace`` URL.
    """

    def expand(host, rank):
        if not rank:
            return [f"{host}:{port}"]
        if "-" in rank:
            first, last = (int(bound) for bound in rank.split("-"))
            return [f"{host}-{idx}:{port}" for idx in range(first, last + 1)]
        return [f"{host}-{rank}:{port}"]

    endpoints = []
    for host, rank in zip(host_list, rank_list):
        endpoints.extend(expand(host, rank))

    if not dry_run:
        return [f"http://{endpoint}/HostingService/DumpKernelTrace" for endpoint in endpoints]
    return endpoints
"\0").encode("utf-8") + text = pad(text.encode("utf-8"), AES.block_size) + + cipher = AES.new(password, AES.MODE_CBC, iv) + cipher_text = cipher.encrypt(text) + return b64encode(cipher_text).decode("utf-8") + + +def main(): + # curr=`date +%s` + # curl -H 'Content-Type: application/json' \ + # -d "{\"dump_path\":\"/root/cc/dd/ee\", \"dump_count\": 110, \"dump_time\": $((curr+3))}" \ + # 127.0.0.1:18888/HostingService/DumpKernelTrace + parser = argparse.ArgumentParser() + parser.add_argument("--host", action="append", required=True, help="Specify the host") + parser.add_argument("--rank", action="append", required=True, help="Specify the host rank range") + parser.add_argument("--port", type=int, default=18888, help="Specify the port on host") + parser.add_argument("--dump-path", type=str, default="/root/timeline", help="Specify dump path") + parser.add_argument("--dump-count", type=int, default=1000, help="Specify how many events to dump") + parser.add_argument("--delay", type=int, default=5, help="Specify when dump after request") + parser.add_argument("--reset", action="store_true", help="Specify reset dump flag") + parser.add_argument("--dry-run", action="store_true", help="Dry run") + parser.add_argument("--no-nccl", action="store_true", help="Disable nccl trace") + parser.add_argument("--no-matmul", action="store_true", help="Disable matmul(fa) trace") + parser.add_argument("--no-memory", action="store_true", help="Disable memory trace") + parser.add_argument("--oss-path", type=str, default="", help="Specify oss dump path") + parser.add_argument("--oss-ak", type=str, default="", help="Specify oss ak") + parser.add_argument("--oss-sk", type=str, default="", help="Specify oss sk") + parser.add_argument("--oss-endpoint", type=str, default="", help="Specify oss endpoint") + + args = parser.parse_args() + + if len(args.host) != len(args.rank): + parser.error("--host and --rank must be provided in pairs") + + dump_kernel_type = 7 # [00][11] // first bits is matmul, 
class AllThreadsBacktrace(gdb.Command):
    """gdb command ``sbt <fifo_path>``.

    Takes a backtrace of every thread of the first inferior, prints the raw
    text, and writes a serialized ``Stacktrace`` message to ``fifo_path``.
    """

    def __init__(self):
        super().__init__("sbt", gdb.COMMAND_USER)
        # Matches the argument list " (...) " followed by "from" or "at".
        self.remove_arg_pattern = re.compile(r" \(.*\) (from|at)")
        self.line_buffer = []

    def invoke(self, arg, from_tty):
        """Entry point called by gdb; ``arg`` is the path of the output pipe."""
        fifo_path = arg
        # Start clean so a second invocation does not re-print earlier output.
        self.line_buffer = []
        result = hosting_service_pb2.Stacktrace()
        for thread in gdb.inferiors()[0].threads():
            pstack_trace = result.stacktrace.add()
            thread.switch()
            pid, tid, _ = thread.ptid
            self.line_buffer.append(f"Thread {tid}, name {thread.name}")
            # gdb reports None for threads without a name.
            pstack_trace.thread_name = thread.name or "Unknown"
            pstack_trace.pid = tid
            self.backtrace_and_parse(pstack_trace)
        print("\n".join(self.line_buffer))
        while not os.path.exists(fifo_path):
            print(f"waiting for {fifo_path}", file=sys.stdout)
            time.sleep(1)
        with open(fifo_path, "wb") as fifo:
            fifo.write(result.SerializeToString())
        print("EOF", file=sys.stdout)
        print("EOF", file=sys.stderr)

    def backtrace_and_parse(self, pstack_trace):
        """Run ``bt`` for the selected thread and add one frame per output line."""
        bt_output = gdb.execute("bt", to_string=True)
        self.line_buffer.append(bt_output)
        for orig_line in bt_output.splitlines():
            if not orig_line.strip():
                continue
            frame = pstack_trace.frames.add()
            frame.origin = orig_line
            parts = self.remove_arg_pattern.sub("@@@@@", orig_line).split("@@@@@")
            if len(parts) == 1:  # no replace
                # 6 0x000000000041c8ee in main ()
                tokens = parts[0].split()
                if len(tokens) > 3:
                    frame.func_name = tokens[3].strip()
                elif len(tokens) > 1:
                    # Frame without an address, e.g. "#1 main ()".
                    frame.func_name = tokens[1].strip()
                else:
                    frame.func_name = "??"
                frame.file_name = "??"
            elif " in " in parts[0]:
                # 21 0x00000000004e81a6 in PyEval_EvalFrameEx /usr/local/src/conda/python-3.8.18/Python/ceval.c:741
                # 3 0x00007f69dfa71429 in Monitor::wait(bool, long, bool) () from /root/jdk/lib/server/libjvm.so
                frame.func_name = parts[0].split(" in ")[-1].strip()
                frame.file_name = parts[-1].strip()
            else:
                # 7 call_function /usr/local/src/conda/python-3.8.18/Python/ceval.c:4963
                tokens = parts[0].split()
                frame.func_name = tokens[1].strip() if len(tokens) > 1 else "??"
                frame.file_name = parts[-1].strip()
class TracingData:
    """One kernel trace record and the helpers that emit it as Perfetto packets.

    The static methods build the per-rank tracks of a ``Trace`` message from
    one ``*.timeline`` file; an instance wraps a single record of that file
    and knows how to append its begin/end slice packets and annotations.
    """

    # Event name -> interned name id; must be populated before parsing starts.
    tracing_code_to_interned_data: Dict[str, int] = {}
    COMPUTE_LINE = 1
    IO_LINE = 2
    HOST_LINE = 0
    # Offsets added to a rank's base uuid to obtain the uuid of each track.
    global_line_offset: Dict[str, int] = {"compute": COMPUTE_LINE, "io": IO_LINE, "host": HOST_LINE}
    local_line_offset: Dict[str, int] = {
        "matmul": COMPUTE_LINE,
        "nccl": IO_LINE,
        "device_memory": IO_LINE,
        "host_memory": HOST_LINE,
    }
    # Indexed by the numeric ``kernel_type`` of a trace record.
    kernel_type: List[str] = ["matmul", "nccl", "device_memory"]
    # Annotation name -> interned id used as ``name_iid`` below.
    debug_annotation = {
        "count": 1,
        "delay(us)": 2,
        "comm_hash": 3,
        "nranks": 4,
        "nodes": 5,
        "input_size(bytes)": 6,
        "blocks": 7,
        "grids": 8,
        "dtype": 9,
        "TFLOPS": 10,
        "seq": 11,
        "rank": 12,
        "cublas_api": 13,
        "ldA": 15,
        "ldB": 16,
        "ldC": 17,
        "strideA": 18,
        "strideB": 19,
        "strideC": 20,
        "cublas_algo": 21,
        "trans": 22,
        "direction": 23,
        "Bandwidth(GiB/s)": 24,
        "bytes": 25,
        "collected": 26,
        "uncollectable": 27,
    }
    # Annotation string value -> interned id used as ``string_value_iid``.
    debug_annotation_string_values = {
        "cublasGemmEx": 1,
        "cublasGemmStridedBatchedEx": 2,
        "cublasSgemm": 3,
        "cublasSgemmStridedBatched": 4,
        "cublasLtMatmul": 5,
        "aclnnMatmul": 6,
    }
    timeline_version = 2
    # Parsed ``.meta`` tables keyed by absolute timeline path, so the file is
    # read once per timeline instead of once per trace record.
    _meta_cache: Dict[str, Dict[int, str]] = {}

    @staticmethod
    def parse_one_host_trace_data(packet, rank, host_traces, header_uuid, base_uuid, group_index):
        """Append the host track of one rank and a slice pair per host trace.

        Args:
            packet: Repeated packet field of the ``Trace`` being built.
            rank: Rank number used in the track's thread name.
            host_traces: Iterable of host trace records.
            header_uuid: uuid of the parent (group) track.
            base_uuid: Base uuid of this rank's tracks.
            group_index: Value stored as the track's ``pid``.

        Raises:
            KeyError: If a host trace name has no interned id.
        """
        host = packet.add()
        host.track_descriptor.uuid = base_uuid + TracingData.global_line_offset["host"]
        host.track_descriptor.parent_uuid = header_uuid
        host.track_descriptor.thread.pid = group_index
        host.track_descriptor.thread.tid = host.track_descriptor.uuid
        host.track_descriptor.thread.thread_name = f"rank {rank} host"
        uuid = host.track_descriptor.uuid
        for host_trace in host_traces:
            start = packet.add()
            iid = TracingData.tracing_code_to_interned_data[host_trace.name.replace("@", ".")]
            start_ns = host_trace.start_us * 1000
            dur_in_ns = host_trace.dur_us * 1000

            start.timestamp = start_ns
            start.track_event.type = TrackEvent.TYPE_SLICE_BEGIN
            start.track_event.track_uuid = uuid
            start.track_event.name_iid = iid
            start.trusted_packet_sequence_id = TRUSTED_PACKET_SEQUENCE_ID
            start.track_event.categories.append("host")

            count = start.track_event.debug_annotations.add()
            count.name_iid = 1  # count
            count.uint_value = host_trace.count

            if host_trace.name == "GC":
                collected = start.track_event.debug_annotations.add()
                collected.name_iid = 26  # collected
                collected.int_value = host_trace.gc_debug.collected

                uncollectable = start.track_event.debug_annotations.add()
                uncollectable.name_iid = 27  # uncollectable
                uncollectable.int_value = host_trace.gc_debug.uncollectable

            end = packet.add()
            end.trusted_packet_sequence_id = TRUSTED_PACKET_SEQUENCE_ID
            end.timestamp = start_ns + dur_in_ns
            end.track_event.type = TrackEvent.TYPE_SLICE_END
            end.track_event.track_uuid = uuid

    @staticmethod
    def parse_one_trace_data(args):
        """Convert one ``*.timeline`` file into a serialized ``Trace``.

        Args:
            args: Tuple ``(path, header_uuid, group_index)``; packed into one
                argument so the function can be mapped over a list.

        Returns:
            ``(trace_bytes, (rank, matmul_debug_info))`` where
            ``matmul_debug_info`` maps a serialized matmul debug message to
            the list of TFLOPS values computed for it.
        """
        path, header_uuid, group_index = args
        name = path.name
        # name: 00000-00001.timeline
        rank = int(name[:5])
        trace = Trace()
        packet = trace.packet
        base_uuid = header_uuid + (rank + 1) * 10000
        compute = packet.add()
        compute.track_descriptor.uuid = base_uuid + TracingData.global_line_offset["compute"]
        compute.track_descriptor.parent_uuid = header_uuid
        compute.track_descriptor.thread.pid = group_index
        compute.track_descriptor.thread.tid = compute.track_descriptor.uuid
        compute.track_descriptor.thread.thread_name = f"rank {rank} compute"

        coll = packet.add()
        coll.track_descriptor.uuid = base_uuid + TracingData.global_line_offset["io"]
        coll.track_descriptor.parent_uuid = header_uuid
        coll.track_descriptor.thread.pid = group_index
        coll.track_descriptor.thread.tid = coll.track_descriptor.uuid
        coll.track_descriptor.thread.thread_name = f"rank {rank} io"

        with path.open("rb") as f:
            timeline_traces = hook_pb2.KernelTraces()
            timeline_traces.ParseFromString(f.read())
        matmul_debug_info = defaultdict(list)

        for each_trace in timeline_traces.traces:
            # WhichOneof returns None when no member of the oneof is set;
            # pass None on so the record is emitted without debug annotations.
            debug_data_field_name = each_trace.WhichOneof("debug_data")
            if debug_data_field_name is None:
                debug_data = None
            else:
                debug_data = getattr(each_trace, debug_data_field_name)
            tracing = TracingData(
                path,
                each_trace,
                debug_data,
                timeline_traces.rank,
            )
            tracing.parse_pb(packet, base_uuid, matmul_debug_info)
        if hasattr(timeline_traces, "host_traces"):
            TracingData.parse_one_host_trace_data(
                packet, timeline_traces.rank, timeline_traces.host_traces, header_uuid, base_uuid, group_index
            )
        return trace.SerializeToString(), (timeline_traces.rank, matmul_debug_info)

    @classmethod
    def _load_naming(cls, abs_path):
        """Return the ``trace_code -> name`` table of ``abs_path + ".meta"``.

        The table is parsed on first use and cached per path.
        """
        naming_dict = cls._meta_cache.get(abs_path)
        if naming_dict is None:
            naming_dict = {}
            with open(abs_path + ".meta") as f:
                for line in f:
                    # skip extra traces
                    if HOST_TRACE_META in line:
                        continue
                    k, v = line.strip().split(",")
                    naming_dict[int(k)] = v
            cls._meta_cache[abs_path] = naming_dict
        return naming_dict

    def __init__(self, path, each_trace, debug_data=None, rank=None):
        """Wrap one trace record.

        Args:
            path: ``pathlib.Path`` of the timeline file; its ``.meta``
                sibling supplies the record's name.
            each_trace: The trace record.
            debug_data: The record's debug message, or ``None``.
            rank: Rank the record belongs to; required by ``parse_pb``.

        Raises:
            KeyError: If the record's ``trace_code`` is not in the meta file.
        """
        naming_dict = TracingData._load_naming(str(path.absolute()))
        self.name_for_encode = naming_dict[each_trace.trace_code]
        self.name = self.name_for_encode.replace("xpu_timer_", "")
        self.kernel_type = TracingData.kernel_type[each_trace.kernel_type]
        if each_trace.is_host:
            self.kernel_type = "host_memory"
        self.delay = each_trace.delay_us

        self.kernel_code = each_trace.kernel_type
        self.start = each_trace.start_us
        self.dur = each_trace.dur_us
        self.trace_id = each_trace.trace_id
        self.debug_data = debug_data
        self.rank = rank

    def add_annotation(self, debug_annotations, matmul_debug_info, packet):
        """Add the annotations that match the type of ``self.debug_data``.

        Rates (TFLOPS, bandwidth) divide by the record's duration; a record
        with a duration of 0 gets a rate of 0.0 instead of raising
        ``ZeroDivisionError``, and such a matmul sample is not added to
        ``matmul_debug_info``.

        Args:
            debug_annotations: Repeated annotation field of the begin event.
            matmul_debug_info: Mapping that collects matmul TFLOPS samples.
            packet: Not used; kept for interface compatibility.

        Raises:
            ValueError: If ``self.debug_data`` is of an unsupported type.
        """
        dur = self.dur
        if isinstance(self.debug_data, hook_pb2.MatmulDebugData):
            flop = 2
            # bmnk
            for field, value in zip("bmnk", self.debug_data.shapes):
                annotation = debug_annotations.add()
                annotation.name = field
                annotation.uint_value = value
                flop = flop * value

            # lds
            for field, value in zip(range(15, 18), self.debug_data.lds):
                annotation = debug_annotations.add()
                annotation.name_iid = field
                annotation.uint_value = value

            # strides
            for field, value in zip(range(18, 21), self.debug_data.strides):
                annotation = debug_annotations.add()
                annotation.name_iid = field
                annotation.uint_value = value
            cublas_api = debug_annotations.add()
            cublas_api.name_iid = 13
            cublas_api.string_value_iid = TracingData.debug_annotation_string_values[self.debug_data.api]

            cublas_algo = debug_annotations.add()
            cublas_algo.name_iid = 21  # cublas algo
            cublas_algo.int_value = self.debug_data.algo

            trans = debug_annotations.add()
            trans.name_iid = 22  # trans
            trans.string_value = self.debug_data.trans

            dtype = debug_annotations.add()
            dtype.name_iid = 9  # dtype
            dtype.string_value = self.debug_data.dtype

            tflops = debug_annotations.add()
            tflops.name_iid = 10  # tflops
            if dur:
                tflops.double_value = round(flop / dur / 1e6, 2)
                matmul_debug_info[self.debug_data.SerializeToString()].append(tflops.double_value)
            else:
                # No measurable duration: annotate 0.0 and keep the sample
                # out of the per-shape statistics.
                tflops.double_value = 0.0

        elif isinstance(self.debug_data, hook_pb2.FaDebugData):
            for field, value in zip("bssh", self.debug_data.shapes):
                annotation = debug_annotations.add()
                annotation.name = field
                annotation.uint_value = value

        elif isinstance(self.debug_data, hook_pb2.GroupedMatmulDebugData):
            tflops = debug_annotations.add()
            tflops.name_iid = 10
            tflops.double_value = round(self.debug_data.tflops / dur / 1e6, 2) if dur else 0.0

        elif isinstance(self.debug_data, hook_pb2.NcclDebugData):
            grids = debug_annotations.add()
            grids.name_iid = 8  # grids
            grids.string_value = f"[{','.join(map(str,self.debug_data.grids))}]"

            blocks = debug_annotations.add()
            blocks.name_iid = 7  # block
            blocks.string_value = f"[{','.join(map(str,self.debug_data.blocks))}]"

            comm_hash = debug_annotations.add()
            comm_hash.name_iid = 3  # hash
            comm_hash.uint_value = self.debug_data.comm_hash

            input_size = debug_annotations.add()
            input_size.name_iid = 6  # bytes
            input_size.uint_value = self.debug_data.input_size_in_bytes

            dtype = debug_annotations.add()
            dtype.name_iid = 9  # dtype
            dtype.string_value = self.debug_data.dtype

            nranks = debug_annotations.add()
            nranks.name_iid = 4  # nranks
            nranks.uint_value = self.debug_data.ranks

            nodes = debug_annotations.add()
            nodes.name_iid = 5  # nodes
            nodes.uint_value = self.debug_data.nodes

            seq = debug_annotations.add()
            seq.name_iid = 11  # seq num
            seq.uint_value = self.debug_data.seq

            bandwidth = debug_annotations.add()
            bandwidth.name_iid = 24  # GiB/s
            bandwidth.double_value = (
                round(self.debug_data.problem_size / (1 << 30) / dur * 1e6, 2) if dur else 0.0
            )

        elif isinstance(self.debug_data, hook_pb2.MemoryDebugData):
            direction = debug_annotations.add()
            direction.name_iid = 23  # direction
            direction.string_value = self.debug_data.direction
            bandwidth = debug_annotations.add()
            bandwidth.name_iid = 24  # GiB/s
            bandwidth.double_value = round(self.debug_data.size / (1 << 30) / dur * 1e6, 2) if dur else 0.0
            copy_bytes = debug_annotations.add()
            copy_bytes.name_iid = 25  # bytes
            copy_bytes.uint_value = self.debug_data.size
        else:
            raise ValueError("Debug data shoule be FA/Matmul/Nccl/Memory")

    def parse_pb(self, packet, uuid, matmul_debug_info):
        """Append this record's begin and end slice packets to ``packet``.

        Note that ``self.start`` is converted from microseconds to
        nanoseconds in place, so the method must be called once per instance.

        Args:
            packet: Repeated packet field of the ``Trace`` being built.
            uuid: Base uuid of the rank; the track offset is added here.
            matmul_debug_info: Mapping that collects matmul TFLOPS samples.

        Raises:
            ValueError: If the instance was created without a rank.
            KeyError: If the record's name has no interned id.
        """
        uuid = uuid + TracingData.local_line_offset[self.kernel_type]
        start = packet.add()
        self.iid = TracingData.tracing_code_to_interned_data[self.name]
        self.start = self.start * 1000
        # NOTE(review): slices longer than 100 us are drawn 10 us shorter —
        # presumably to keep back-to-back slices visually apart; confirm.
        dur_in_ns = (self.dur - 10) * 1000 if self.dur > 100 else self.dur * 1000

        start.timestamp = self.start
        start.track_event.type = TrackEvent.TYPE_SLICE_BEGIN
        start.track_event.track_uuid = uuid
        start.track_event.name_iid = self.iid
        start.trusted_packet_sequence_id = TRUSTED_PACKET_SEQUENCE_ID
        start.track_event.categories.append(self.kernel_type)

        count_annotation = start.track_event.debug_annotations.add()
        count_annotation.name_iid = 1  # count
        count_annotation.uint_value = self.trace_id

        delay_annotation = start.track_event.debug_annotations.add()
        delay_annotation.name_iid = 2  # delay(us)
        delay_annotation.uint_value = self.delay

        if self.rank is None:
            raise ValueError("Rank is not set")
        rank_annotation = start.track_event.debug_annotations.add()
        rank_annotation.name_iid = 12  # rank
        rank_annotation.uint_value = self.rank

        if self.debug_data is not None:
            self.add_annotation(start.track_event.debug_annotations, matmul_debug_info, packet)

        end = packet.add()
        end.trusted_packet_sequence_id = TRUSTED_PACKET_SEQUENCE_ID

        end.timestamp = self.start + dur_in_ns
        end.track_event.type = TrackEvent.TYPE_SLICE_END
        end.track_event.track_uuid = uuid
def parse_group(group_name, timelines, group_index, add_iternal, fd, concurrency=32):
    """Write the Perfetto packets of one rank group to ``fd``.

    A header packet describing the group is written first, followed by one
    serialized sub-trace per timeline file.

    Args:
        group_name: Name used for the group's process track.
        timelines: Paths of the timeline files that belong to the group.
        group_index: Index of the group; also used as the track's pid and to
            derive the header uuid.
        add_iternal: When true, interned data is added to every sub-trace
            before it is written.
        fd: Binary file object, opened for writing, that receives the output.
        concurrency: Passed to ``parallel_job``.

    Returns:
        A ``hook_pb2.RankMatmulInfo`` holding the matmul samples per rank.
    """
    header_uuid = group_index * int(1e8)
    trace = Trace()
    header = trace.packet.add()
    header.track_descriptor.uuid = header_uuid
    header.track_descriptor.process.pid = group_index
    header.track_descriptor.process.process_name = group_name

    args = [(path, header_uuid, group_index) for path in timelines]
    if args:
        # NOTE(review): the first file is parsed once here and the result is
        # discarded — presumably so a parse error surfaces in this process
        # rather than inside the parallel run; confirm or drop.
        # Guarded so that an empty group no longer fails with IndexError.
        TracingData.parse_one_trace_data(args[0])
    sub_traces_matmul_info = parallel_job(
        TracingData.parse_one_trace_data,
        args,
        f"generate perfetto timeline for {group_name} with parallel{concurrency}",
        concurrency,
    )
    all_matmul_info = hook_pb2.RankMatmulInfo()

    def parse_matmul_info(matmul_info):
        # Copy one rank's {serialized debug message: [tflops, ...]} mapping
        # into the aggregate message.
        rank, info = matmul_info
        mm_infos = hook_pb2.MatmulInfos()
        for mm_debug_pb, tflops in info.items():
            each_info = mm_infos.infos.add()
            each_info.mm_debug.ParseFromString(mm_debug_pb)
            each_info.tflops.extend(tflops)
        all_matmul_info.mm_infos[rank].CopyFrom(mm_infos)

    fd.write(trace.SerializeToString())
    for sub_trace, matmul_info in tqdm(sub_traces_matmul_info, desc=f"Write {group_name}"):
        if matmul_info is not None:
            parse_matmul_info(matmul_info)
        if add_iternal:
            # Interned data can only be attached to a parsed message, so the
            # sub-trace is decoded, extended and serialized again.
            trace = Trace()
            trace.ParseFromString(sub_trace)
            add_interned_data(trace)
            fd.write(trace.SerializeToString())
            continue
        fd.write(sub_trace)

    return all_matmul_info
def generate_perfetto_trace(args):
    """Build ``trace_<name>.bin`` and ``matmul_<name>.bin`` in ``args.path``.

    Args:
        args: Parsed command-line namespace; ``path``, ``groups``, ``output``
            and ``c`` are read.

    Raises:
        SystemExit: With status 1 when the directory has no ``*timeline``
            file.
    """
    timeline_dir = args.path
    timeline_dict = dict(enumerate(sorted(Path(timeline_dir).glob("*timeline"))))
    if not timeline_dict:
        print("There are no timeline files, exit...", file=sys.stderr, flush=True)
        # sys.exit rather than the interactive ``exit`` helper, which is only
        # installed by the ``site`` module.
        sys.exit(1)
    if args.groups:
        # args.groups: tp4-cp2-dp4-pp2
        groups_dict = OrderedDict((pair[:2], int(pair[2:])) for pair in args.groups.split("-"))
    else:
        groups_dict = {"dp": len(timeline_dict)}
    rank_helper = GetRankHelper(groups_dict)
    timelines = {group: [timeline_dict[i] for i in rank_helper.get_ranks(group, group_0=True)] for group in groups_dict}
    # prepare interned data
    # https://perfetto.dev/docs/reference/synthetic-track-event#interning
    all_kernel_names = set()
    for timeline in timeline_dict.values():
        meta_path = str(timeline.absolute()) + ".meta"
        with open(meta_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank line has no "code,name" pair to unpack.
                    continue
                _, name = format_function_name(line).split(",")
                all_kernel_names.add(name)
    # Sorted so the same input always yields the same ids; perfetto's
    # internal ids start at 1.
    for iid, name in enumerate(sorted(all_kernel_names), start=1):
        TracingData.tracing_code_to_interned_data[name.replace("xpu_timer_", "")] = iid

    all_matmul_info = hook_pb2.RankMatmulInfo()
    trace_name = args.output if args.output else "_".join([f"{k}{v}" for k, v in groups_dict.items()])
    # The context manager closes (and flushes) the trace file even when a
    # group fails to parse.
    with open(f"{timeline_dir}/trace_{trace_name}.bin", "wb") as fd:
        add_internel = True
        for index, (name, group_files) in enumerate(timelines.items()):
            all_matmul_info.MergeFrom(parse_group(name, group_files, index, add_internel, fd, args.c))
            add_internel = False
    serialize_to_file_in_chunks(all_matmul_info, f"{timeline_dir}/matmul_{trace_name}.bin", chunk_size=1024 * 1024)


def main():
    """Command-line entry point: parse the options and generate the outputs."""
    parser = ArgumentParser(usage="""python gen_timeline.py dump_dir""")
    parser.add_argument("--path", "-p", type=str, default=".")
    parser.add_argument("--no-matmul", action="store_true")
    parser.add_argument("--no-nccl", action="store_true")
    parser.add_argument("-c", type=int, default=16, required=False)
    parser.add_argument("--timeline-version", type=int, default=2, required=False)
    parser.add_argument(
        "--groups", type=str, default="", required=False, help='Group configurations like "tp4-cp2-dp4-pp2"'
    )
    parser.add_argument("--output", type=str, default="", required=False, help="Output name for timeline file")
    args = parser.parse_args()
    TracingData.timeline_version = args.timeline_version

    timeline_dir = args.path
    if not Path(timeline_dir).exists():
        # The f prefix was missing, so the placeholder was printed literally.
        print(f"path {timeline_dir} not exists")
        return
    generate_perfetto_trace(args)
    parse_timeline_stack(timeline_dir)


if __name__ == "__main__":
    main()
path="relative_matmul_performance.pkl", plot=False, p=0.9): + if plot: + try: + import seaborn as sns + except ImportError: + print("seaborn is not found, run `pip install seaborn` if you need plot") + return + for rank, all_debug in self.matmul_debug.mm_infos.items(): + mean_tflops = {} + for each_debug in all_debug.infos: + self.parse_gemm(each_debug, mean_tflops) + self.rank_mean_tflops[rank] = mean_tflops + + relative_performance = {} + for rank, ops in self.rank_mean_tflops.items(): + relative_performance[rank] = { + op: ops[op] / self.matmul_database[op] if op in self.matmul_database else 1.0 for op in ops + } + + df = pd.DataFrame.from_dict(relative_performance) + df.to_pickle(path) + percentile_90 = df.apply(lambda row: row.quantile(p), axis=1) + comparison_p90 = df.apply(lambda row: row < percentile_90[row.name], axis=1) + comparison_abs = df.apply(lambda row: row < p, axis=1) + for row in comparison_p90.index: + for col in comparison_p90.columns: + if comparison_p90.loc[row, col]: + print( + f"{row}, rank: {col} is slow than p{int(p * 100)}, {self.rank_mean_tflops[col][row]} vs {self.matmul_database[row]}" + ) + for row in comparison_abs.index: + for col in comparison_abs.columns: + if comparison_abs.loc[row, col]: + print( + f"{row}, rank: {col} is slow than {p}, {self.rank_mean_tflops[col][row]} vs {self.matmul_database[row]}" + ) + + if not plot: + return + + def get_figsize(df, aspect_ratio=1.0): + n_rows, n_cols = df.shape + width = max(n_cols / aspect_ratio, 1) + height = max(n_rows / aspect_ratio, 1) + return (width, height) + + plt.figure(figsize=get_figsize(df, 1.5), dpi=300) + + sns.heatmap(df, annot=True, cmap=sns.color_palette("GnBu", as_cmap=True), linewidths=0.5, cbar=True) + plt.title("Relative Performance Heatmap") + plt.xlabel("Rank") + plt.ylabel("Matrix Multiplication Operation") + plt.xticks(rotation=0) + plt.yticks(rotation=0) + plt.show() + + def parse_gemm(self, debug_data_tflops, mean_tflops): + debug_data = 
def parse_cublaslt_gemm(self, debug):
    """Run the ``cublaslt_gemm`` benchmark for one matmul configuration.

    Args:
        debug: Matmul debug record; ``shapes`` (b, m, n, k), ``dtype`` and
            the two-character ``trans`` are read.

    Returns:
        ``(command_line, tflops)``. ``tflops`` is ``-1`` when the benchmark
        cannot be started, exits with a non-zero status, or prints output
        that does not end in a number.
    """
    # ./cublaslt_gemm -m 512 -n 512 -k 8192 -b 10 -w 5 -i 100 -t fp16
    commands = ["cublaslt_gemm", "-w", "50", "-i", "10"]
    for arg, value in zip("bmnk", debug.shapes):
        commands += [f"-{arg}", str(value)]
    commands += ["-t", debug.dtype]
    transa, transb = debug.trans
    if transa == "T":
        commands.append("--trans_a")
    if transb == "T":
        commands.append("--trans_b")
    command_line = " ".join(commands)
    try:
        # run() waits for the process, so returncode is meaningful below
        # (checking Popen.returncode before communicate() always saw None).
        proc = subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError:
        # Benchmark binary missing or not executable.
        return command_line, -1
    if proc.returncode:
        return command_line, -1
    # A:T B:T m:1024 n:1024 k:2048 batch:1 time_us:790.169617 tflops:5.434173
    try:
        return command_line, float(proc.stdout.decode().split(":")[-1])
    except ValueError:
        return command_line, -1
def main():
    """Command-line entry point: replay the recorded matmuls and report.

    Raises:
        ValueError: If ``--p`` is not strictly between 0 and 1. The value is
            checked before any trace is loaded so a bad threshold fails
            immediately instead of after the benchmarks have run.
    """
    parser = ArgumentParser(usage="""python parse_perfetto.py""")
    parser.add_argument("--path", default="matmul.bin", help="Trace path for diff")
    parser.add_argument("--output-path", default="relative_matmul_performance.pkl")
    parser.add_argument("--plot", action="store_true")
    parser.add_argument("--p", default=0.9, type=float)

    args = parser.parse_args()
    if args.p >= 1.0:
        raise ValueError("p should less than 1")
    # Written as a negated comparison so that NaN is rejected as well.
    if not args.p > 0.0:
        raise ValueError("p should be greater than 0")
    MatmulPlayBack(args.path).run(args.output_path, args.plot, args.p)


if __name__ == "__main__":
    main()
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "timeline_path = \"/Users/sangbo/local_data/projects/xpu_timer/paper/experiments/figs/trace_dp4.bin\"\n", + "save_image = \"./fsdp.svg\"\n", + "plot_tflops_grouped_by_operation(timeline_path, save_image)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d0820c64", + "metadata": {}, + "outputs": [], + "source": [ + "timeline_path = \"/Users/sangbo/local_data/projects/xpu_timer/paper/experiments/figs/trace_dp4.bin\"\n", + "trace = PerfettoParser(trace=timeline_path)\n", + "filter_rank =[ i for i in range(8)]\n", + "matmul = trace.parse(tflops_sql)\n", + "# Loop over each name and rank to plot them separately\n", + "matmul = matmul[matmul['rank'].isin(filter_rank)]\n", + "grouped = matmul.groupby(['name', 'rank'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "33ed1294", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from pylab import mpl\n", + "import warnings\n", + "\n", + "# Suppress all warnings\n", + "warnings.filterwarnings('ignore')\n", + "plt.rcParams['axes.unicode_minus'] = False\n", + "def plot_sss(data):\n", + " if(len(data) < 10):\n", + " return\n", + " y = np.array(data)\n", + " x = np.arange(1, (len(y)+1))\n", + "\n", + " z1=np.polyfit(x,y,deg=1)#deg=100,100次多项式,返回值为系数\n", + " p1=np.poly1d(z1)#通过多项式系数,返回方程\n", + " print(p1)#输出方程\n", + " print(np.polyval(p1,12))#进行预测\n", + " print(np.polyval(z1,13))#这两种方法都可以,可以p1,或者z1作为参数,12,13为x,输入x得到预测值\n", + "\n", + " y_pred=p1(x)#预测值\n", + " plt.plot(x,y,'*',label='origin value')\n", + " plt.plot(x,y_pred,'r',label='pred value')\n", + "# plt.title('多项式拟合')\n", + " plt.xlabel('xlable')\n", + " plt.ylabel('ylabel')\n", + " plt.legend(loc=3, borderaxespad=0., bbox_to_anchor=(0, 0))#画出图例\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1eef78b1", + "metadata": {}, + "outputs": [ + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "===============mm_bmnk_1_126464_8192_4096_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_126464_8192_4096_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_126464_8192_4096_------------2====================\n", + "===============mm_bmnk_1_126464_8192_4096_------------3====================\n", + "===============mm_bmnk_1_4096_126464_8192_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_4096_126464_8192_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_4096_126464_8192_------------2====================\n", + "===============mm_bmnk_1_4096_126464_8192_------------3====================\n", + "===============mm_bmnk_1_4096_4096_8192_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_4096_4096_8192_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_4096_4096_8192_------------2====================\n", + "===============mm_bmnk_1_4096_4096_8192_------------3====================\n", + "===============mm_bmnk_1_4096_512_8192_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_4096_512_8192_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_4096_512_8192_------------2====================\n", + "===============mm_bmnk_1_4096_512_8192_------------3====================\n", + "===============mm_bmnk_1_4096_8192_126464_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_4096_8192_126464_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_4096_8192_126464_------------2====================\n", + "===============mm_bmnk_1_4096_8192_126464_------------3====================\n", + "===============mm_bmnk_1_4096_8192_4096_------------0====================\n", + "[1801, 1829]\n", + "===============mm_bmnk_1_4096_8192_4096_------------1====================\n", + "[]\n", + 
"===============mm_bmnk_1_4096_8192_4096_------------2====================\n", + "===============mm_bmnk_1_4096_8192_4096_------------3====================\n", + "===============mm_bmnk_1_4096_8192_512_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_4096_8192_512_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_4096_8192_512_------------2====================\n", + "===============mm_bmnk_1_4096_8192_512_------------3====================\n", + "===============mm_bmnk_1_4096_8192_8192_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_4096_8192_8192_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_4096_8192_8192_------------2====================\n", + "===============mm_bmnk_1_4096_8192_8192_------------3====================\n", + "===============mm_bmnk_1_512_8192_4096_------------0====================\n", + " \n", + "1 x + 1200\n", + "1212.0\n", + "1213.0\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228]\n", + "===============mm_bmnk_1_512_8192_4096_------------1====================\n", + " \n", + "1 x + 1200\n", + "1212.0\n", + "1213.0\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228]\n", + "===============mm_bmnk_1_512_8192_4096_------------2====================\n", + "===============mm_bmnk_1_512_8192_4096_------------3====================\n", + "===============mm_bmnk_1_8192_4096_8192_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_8192_4096_8192_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_8192_4096_8192_------------2====================\n", + "===============mm_bmnk_1_8192_4096_8192_------------3====================\n", + "===============mm_bmnk_1_8192_8192_4096_------------0====================\n", + "[]\n", + "===============mm_bmnk_1_8192_8192_4096_------------1====================\n", + "[]\n", + "===============mm_bmnk_1_8192_8192_4096_------------2====================\n", + "===============mm_bmnk_1_8192_8192_4096_------------3====================\n" + ] + } + ], + "source": [ + "for (name, rank), df in grouped:\n", + " print(f\"==============={name}------------{rank}====================\")\n", + " filter_df = df[df['TFLOPS'] < 200]\n", + " if rank in [0, 1, 7]:\n", + " plot_sss(filter_df['count'].tolist())\n", + " print(filter_df['count'].tolist())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "52876a13", + "metadata": {}, + "outputs": [], + "source": [ + "xx = [6021, 6022, 6025, 6026, 6029, 6030, 6033, 6034, 6037, 6038, 6041, 6042, 6045, 6046, 6049, 6050, 6053, 6054, 6057, 6058, 6061, 6062]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6cf13ce1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "22" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"len(xx)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9706fd52", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "yy = np.array(xx)\n", + "tt = yy[1:] - yy[0:-1]\n", + "tt" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "68604e64", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.sum(tt == 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e6cee559", + "metadata": {}, + "outputs": [], + "source": [ + "from parse_perfetto import analysis_host_issue" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "22660827", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gc_timeline_path = \"/Users/sangbo/local_data/projects/xpu_timer/paper/experiments/figs/trace_dp4.bin\"\n", + "output_dir = \"./bad_gc\"\n", + "group_str = \"tp2-cp1-dp32-pp4\"\n", + "\n", + "\n", + "bad_gc = analysis_host_issue(gc_timeline_path, output_dir, group_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "6ceb77cc", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "no_sync_timeline_path = \"/Users/sangbo/local_data/projects/xpu_timer/paper/experiments/figs/trace_dp4.bin\"\n", + "output_dir = \"./no_sync\"\n", + "group_str = \"tp2-cp1-dp32-pp4\"\n", + "\n", + "no_sync = analysis_host_issue(no_sync_timeline_path, output_dir, group_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "5cc9094f", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "bad_sync_timeline_path = \"/Users/sangbo/local_data/projects/xpu_timer/paper/experiments/figs/trace_dp4.bin\"\n", + "output_dir = \"./bad_sync\"\n", + "group_str = \"tp2-cp1-dp32-pp4\"\n", + "\n", + "bad_sync = analysis_host_issue(bad_sync_timeline_path, output_dir, group_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "752a3fd6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['AllGather', 'ReduceScatter'])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bad_gc.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "429b2860", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['AllGather', 'ReduceScatter'])" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "no_sync.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "505f9e3f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['AllGather', 'ReduceScatter'])" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bad_sync.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "975ac4e9", + "metadata": {}, + "outputs": [], + "source": [ + "# flake8: noqa: E402\n", + "import os\n", + "import random\n", + "from argparse import ArgumentParser\n", + "from collections import OrderedDict\n", + "from pathlib import Path\n", + "from typing import Dict\n", + "\n", + "import matplotlib.colors as mcolors\n", + "import matplotlib.gridspec as gridspec\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "from perfetto.trace_processor import TraceProcessor\n", + "from util import GetRankHelper\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": 33, + "id": "4739001c", + "metadata": {}, + "outputs": [], + "source": [ + "dist_strategy = \"tp2-cp1-dp32-pp4\"\n", + "groups_dict = OrderedDict((pair[:2], int(pair[2:])) for pair in dist_strategy.split(\"-\"))\n", + "rank_helper = GetRankHelper(groups_dict)\n", + "group_ranks = {group: rank_helper.get_ranks(group) for group in groups_dict}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "87f054f4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0,\n", + " 1,\n", + " 2,\n", + " 4,\n", + " 6,\n", + " 8,\n", + " 10,\n", + " 12,\n", + " 14,\n", + " 16,\n", + " 18,\n", + " 20,\n", + " 22,\n", + " 24,\n", + " 26,\n", + " 28,\n", + " 30,\n", + " 32,\n", + " 34,\n", + " 36,\n", + " 38,\n", + " 40,\n", + " 42,\n", + " 44,\n", + " 46,\n", + " 48,\n", + " 50,\n", + " 52,\n", + " 54,\n", + " 56,\n", + " 58,\n", + " 60,\n", + " 62,\n", + " 64,\n", + " 128,\n", + " 192}" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ranks = []\n", + "ranks.extend(group_ranks['tp'][0])\n", + "ranks.extend(group_ranks['dp'][0])\n", + "ranks.extend(group_ranks['pp'][0])\n", + "ranks.sort()\n", + "ranks = set(ranks)\n", + "ranks" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "c0ff862b", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_nccl_host_issue_delay_hist(dfs_list, labels, filter_ranks):\n", + "\n", + " unique_ops = dfs_list[0].keys()\n", + " # Determine the number of subplots needed\n", + " num_plots = len(unique_ops)\n", + " num_cols = 2 # Define the number of columns in the subplot grid\n", + " num_rows = (num_plots + num_cols - 1) // num_cols + 1# Calculate the number of rows needed\n", + " \n", + " fig = plt.figure(figsize=(30, num_rows * 5))\n", + " gs = fig.add_gridspec(num_rows,num_cols)\n", + " axe = fig.add_subplot(gs[0, :])\n", + " \n", + " for i, dfs in enumerate(dfs_list):\n", + " delay = []\n", + " for k,v in 
dfs.items():\n", + " v = v[v['rank'].isin(filter_ranks)]\n", + " delay.extend(v['delay'].tolist())\n", + " # Step 1: Sort the data\n", + " sorted_delay = np.sort(delay)\n", + "\n", + " # Step 2: Calculate the CDF\n", + " cdf_delay = np.arange(1, len(sorted_delay) + 1) / len(sorted_delay)\n", + "\n", + " # Step 3: Plot the CDF\n", + " axe.plot(sorted_delay, cdf_delay, marker='.', linestyle='-', label=f'{labels[i]}')\n", + "\n", + " axe.set_xlabel('host issue (ms)')\n", + " axe.set_ylabel('Ratio Comulative Distribution (%)')\n", + " axe.set_title(\"CDF of Host Issue\")\n", + " axe.legend(loc=2)\n", + "\n", + " # Create histograms for each unique op\n", + " for i, unique_op in enumerate(unique_ops):\n", + " axe = fig.add_subplot(gs[(i+2)//2, (i+2)%2])\n", + " \n", + " for j, dfs in enumerate(dfs_list):\n", + " delay = []\n", + " for k,v in dfs.items():\n", + " v = v[v['rank'].isin(filter_ranks)]\n", + " v = v[v['op'] == unique_op]\n", + " delay.extend(v['delay'].tolist())\n", + " # Step 1: Sort the data\n", + " sorted_delay = np.sort(delay)\n", + "\n", + " # Step 2: Calculate the CDF\n", + " cdf_delay = np.arange(1, len(sorted_delay) + 1) / len(sorted_delay)\n", + "\n", + " # Step 3: Plot the CDF\n", + " axe.plot(sorted_delay, cdf_delay, marker='.', linestyle='-', label=f'{labels[j]}')\n", + " \n", + " axe.set_xlabel('host issue (ms)')\n", + " axe.set_ylabel('Ratio Comulative Distribution (%)')\n", + " axe.set_title(f\"CDF of {unique_op} Host Issue\")\n", + " axe.legend(loc=2)\n", + "\n", + " \n", + "\n", + " plt.tight_layout()\n", + " plt.grid(True)\n", + " plt.savefig(\"cdf_of_host_issue.svg\", dpi=300)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "ca3e3b1e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_nccl_host_issue_delay_hist([no_sync, bad_sync, bad_gc], labels = [\"best\", \"bad sync\", \"bad gc\"], filter_ranks=ranks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a481096", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8bb12e7-441a-407c-bb5b-fb14ca7af1b6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d87e6f4-b611-48c4-bfba-0bca72dfcf0f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebb138b1-a8a1-43fe-9d88-479cbc260677", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b67030c4-12e2-4f54-ae78-5545e8efffcf", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/xpu_timer/experiments/figs/parse_perfetto.py b/xpu_timer/experiments/figs/parse_perfetto.py new file mode 100644 index 0000000000..5b3e3ad502 --- /dev/null +++ b/xpu_timer/experiments/figs/parse_perfetto.py @@ -0,0 +1,687 @@ +# flake8: noqa: E402 +import os +import random +from argparse import ArgumentParser +from collections import OrderedDict +from pathlib import Path +from typing import Dict + +import matplotlib.colors as mcolors +import matplotlib.gridspec as gridspec +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from 
perfetto.trace_processor import TraceProcessor, TraceProcessorConfig +from util import GetRankHelper + +plt.rcParams["figure.dpi"] = 300 + +from types import SimpleNamespace + +nccl_sql = SimpleNamespace( + sql=""" +include perfetto module slices.slices; + +WITH comm_hash + AS (SELECT t1.arg_set_id, + t1.int_value AS hash, + t2.int_value AS seq, + t2.RANK AS rank, + t3.delay as delay + FROM (SELECT DISTINCT arg_set_id, + int_value + FROM args + WHERE KEY = 'debug.comm_hash') t1 + JOIN ( (SELECT DISTINCT arg_set_id, + int_value + FROM args + WHERE KEY = 'debug.seq') seq + JOIN (SELECT DISTINCT arg_set_id, + int_value AS RANK + FROM args + WHERE KEY = 'debug.rank') RANK + ON seq.arg_set_id = RANK.arg_set_id ) t2 + ON t1.arg_set_id = t2.arg_set_id + JOIN (SELECT DISTINCT arg_set_id, int_value AS delay FROM args WHERE key = 'debug.delay(us)') t3 + ON t1.arg_set_id = t3.arg_set_id) +SELECT CAST(id as INT) AS id, + CAST(ts as INT) AS ts, + CAST(dur as INT) AS dur, + name AS name, + CAST(comm_hash.hash as INT) as hash, + CAST(comm_hash.seq as INT) as seq, + CAST(comm_hash.rank as INT) as rank, + CAST(comm_hash.delay as INT) as delay +FROM _slice_with_thread_and_process_info + JOIN comm_hash + ON comm_hash.arg_set_id = +_slice_with_thread_and_process_info.arg_set_id +""", + table_dtype={ + "name": str, + "rank": int, + "ts": int, + "dur": int, + "hash": int, + "id": int, + "seq": int, + "delay": int, + }, +) + +tflops_sql = SimpleNamespace( + sql=""" + include perfetto module slices.slices; + +WITH tflops + AS (SELECT t1.arg_set_id, + t2.rank AS rank, + t1.TFLOPS as TFLOPS, + t3.count as count + FROM (SELECT DISTINCT arg_set_id, real_value as TFLOPS FROM args where key ='debug.TFLOPS') t1 + JOIN (SELECT DISTINCT arg_set_id, int_value as rank FROM args where key ='debug.rank') t2 + ON t1.arg_set_id = t2.arg_set_id + JOIN (SELECT DISTINCT arg_set_id, int_value as count FROM args where key = 'debug.count') t3 + ON t1.arg_set_id = t3.arg_set_id) +SELECT CAST(id as INT) AS id, 
+ CAST(ts as INT) AS ts, + CAST(dur as INT) AS dur, + name AS name, + CAST(tflops.rank as INT) as rank, + CAST(tflops.TFLOPS as DOUBLE) as TFLOPS, + CAST(tflops.count as INT) as count +FROM _slice_with_thread_and_process_info + JOIN tflops + ON tflops.arg_set_id = +_slice_with_thread_and_process_info.arg_set_id +""", + table_dtype={ + "name": str, + "rank": int, + "ts": int, + "dur": int, + "TFLOPS": float, + "id": int, + "count": int, + }, +) + +xccl_sql = SimpleNamespace( + sql=""" +include perfetto module slices.slices; +WITH bandwidth AS ( + SELECT + t1.arg_set_id, + t2.rank AS rank, + t1.Bandwidth AS Bandwidth, + t3.seq AS seq, + MAX(t1.Bandwidth) OVER (PARTITION BY t3.seq) AS max_bandwidth + FROM + (SELECT DISTINCT arg_set_id, real_value AS Bandwidth FROM args WHERE key = 'debug.Bandwidth(GiB/s)') t1 + JOIN + (SELECT DISTINCT arg_set_id, int_value AS rank FROM args WHERE key = 'debug.rank') t2 + ON + t1.arg_set_id = t2.arg_set_id + JOIN + (SELECT DISTINCT arg_set_id, int_value AS seq FROM args WHERE key = 'debug.seq') t3 + ON + t1.arg_set_id = t3.arg_set_id +) +SELECT + CAST(id AS INT) AS id, + CAST(ts AS INT) AS ts, + CAST(dur AS INT) AS dur, + name AS name, + CAST(bandwidth.rank AS INT) AS rank, + CAST(bandwidth.max_bandwidth AS DOUBLE) AS 'Bandwidth', + CAST(bandwidth.seq AS INT) AS seq +FROM + _slice_with_thread_and_process_info +JOIN + bandwidth +ON + bandwidth.arg_set_id = _slice_with_thread_and_process_info.arg_set_id; + """, + table_dtype={ + "name": str, + "rank": int, + "ts": int, + "dur": int, + "Bandwidth": float, + "id": int, + "seq": int, + }, +) + + +class PerfettoParser: + def __init__(self, trace): + conf = TraceProcessorConfig( + bin_path="/Users/sangbo/Downloads/software/mac-arm64/trace_processor_shell", + ) + self.trace = TraceProcessor(trace=trace, config=conf) + + def parse(self, sql_spec): + qr = self.trace.query(sql_spec.sql) + df = qr.as_pandas_dataframe() + for key in df.keys(): + df[key] = df[key].astype(sql_spec.table_dtype[key]) + 
def analysis_launch_time(trace):
    """Compute each launch's offset from the first launch in its (hash, seq) group.

    Rows are grouped by ``("hash", "seq")``. Within a group, ``relative_ts_ms``
    is ``(ts - min(ts)) / 1e6``. Rows whose offset is exactly 0 (the earliest
    launch of every group) are dropped.

    Args:
        trace: DataFrame with at least the columns ``name``, ``hash``, ``seq``
            and ``ts``.  # NOTE(review): ``ts`` is presumably nanoseconds - confirm

    Returns:
        DataFrame with columns ``name``, ``seq``, ``ts``, ``relative_ts_ms``,
        groups concatenated in group-key order with a fresh integer index.
        An empty trace yields an empty DataFrame with those columns instead of
        the ``ValueError`` that ``pd.concat`` raises for an empty list.
    """
    columns = ["name", "seq", "ts", "relative_ts_ms"]
    frames = []
    for _, group in trace[["name", "hash", "seq", "ts"]].groupby(["hash", "seq"]):
        relative_ms = (group["ts"] - group["ts"].min()) / 1e6
        # assign() builds a new frame, so the groupby slice is never mutated.
        frame = group.drop(columns=["hash"]).assign(relative_ts_ms=relative_ms)
        frames.append(frame.loc[frame["relative_ts_ms"] != 0])

    if not frames:
        # pd.concat([]) raises "No objects to concatenate"; keep the schema instead.
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)
def plot_tflops_box(path, image_path):
    """Draw a per-rank box plot of matmul TFLOPS and save it to ``image_path``.

    Args:
        path: Trace file handed to ``PerfettoParser``.
        image_path: Destination file for the rendered figure.
    """
    trace = PerfettoParser(trace=path)
    matmul = trace.parse(tflops_sql)
    matmul = matmul[["rank", "TFLOPS"]]

    line_style = dict(linestyle="-", linewidth=1.5)
    matmul.boxplot(
        column="TFLOPS",
        by="rank",
        color=dict(boxes="r", whiskers="r", medians="r", caps="r"),
        boxprops=line_style,
        flierprops=line_style,
        medianprops=line_style,
        whiskerprops=line_style,
        capprops=line_style,
        showfliers=False,
        grid=True,
        rot=0,
    )
    plt.title("Box Plot of TFLOPS by Rank")
    # boxplot(by=...) adds its own figure-level title; blank it out.
    plt.suptitle("")
    plt.xlabel("Rank")
    plt.ylabel("TFLOPS")
    # Save before show(): once show() returns the current figure may already be
    # closed, in which case savefig() would write an empty canvas.
    plt.savefig(image_path, dpi=300)
    plt.show()
plt.savefig(image_path, dpi=300) + + +def plot_nccl_host_issue_delay_hist(path, save_file_name, max_delay=1e5, min_duration=0.005): + trace = PerfettoParser(trace=path) + nccl = trace.parse(nccl_sql) + nccl = nccl[["name", "rank", "delay", "seq", "hash", "dur"]] + nccl = nccl[nccl["delay"] < max_delay] + nccl = nccl[nccl["dur"] > min_duration] + nccl["op"] = nccl["name"].apply(lambda x: x.split("_")[0]) + nccl["delay"] = nccl["delay"].apply(lambda x: x / 1000) + unique_ops = nccl["op"].unique() + # Determine the number of subplots needed + num_plots = len(unique_ops) + num_cols = 2 # Define the number of columns in the subplot grid + num_rows = (num_plots + num_cols - 1) // num_cols # Calculate the number of rows needed + +# fig, axes = plt.subplots(num_rows+1, num_cols, figsize=(30, num_rows * 5)) + fig = plt.figure(figsize=(30, num_rows * 5)) + gs = fig.add_gridspec(num_rows+1,num_cols) + axe = fig.add_subplot(gs[0, :]) + + data = nccl["delay"] + # Create the histogram and get the counts and bin edges + counts, bins = np.histogram(data, bins=10, density=False) + + # Calculate the bin widths + bin_widths = np.diff(bins) + + # Convert counts to percentages + percentages = counts / counts.sum() * 100 + + # Plot the bar chart with percentages + axe.bar(bins[:-1], percentages, width=bin_widths, edgecolor="black", alpha=0.75, align="edge") + + axe.set_xlabel("delay(ms)") + axe.set_ylabel("Percentage") + axe.set_title(f"Histogram All Operators with Percentages") + + dfs = {} + # Create histograms for each unique op + for i, unique_op in enumerate(unique_ops): + axe = fig.add_subplot(gs[(i+2)//2, (i+2)%2]) + subset = nccl[nccl["op"] == unique_op] + dfs[unique_op] = subset + data = subset["delay"] + # Create the histogram and get the counts and bin edges + counts, bins = np.histogram(data, bins=10, density=False) + + # Calculate the bin widths + bin_widths = np.diff(bins) + + # Convert counts to percentages + percentages = counts / counts.sum() * 100 + + # Plot the bar 
def analysis_nccl_kernel_run_duration_ratio(trace):
    """Relate each kernel's launch offset to the shortest kernel in its group.

    Rows are grouped by ``("hash", "seq")``. For every group the following
    columns are appended, in this order:

    * ``dur_ms``: ``dur / 1e6``
    * ``relative_ts_ms``: ``(ts - min(ts)) / 1e6``
    * ``ratio_dur_diff``: ``relative_ts_ms / min(dur_ms)``
    * ``kernel_time``: ``min(dur_ms)`` of the group

    Args:
        trace: DataFrame with the columns ``name``, ``hash``, ``seq``, ``dur``,
            ``ts`` and ``rank``.

    Returns:
        The per-group frames concatenated in group-key order with a fresh
        integer index.
    """
    selected = trace[["name", "hash", "seq", "dur", "ts", "rank"]]

    annotated = []
    for _, members in selected.groupby(["hash", "seq"]):
        duration_ms = members["dur"] / 1e6
        offset_ms = (members["ts"] - members["ts"].min()) / 1e6
        shortest_ms = duration_ms.min()
        annotated.append(
            members.assign(
                dur_ms=duration_ms,
                relative_ts_ms=offset_ms,
                ratio_dur_diff=offset_ms / shortest_ms,
                kernel_time=shortest_ms,
            )
        )

    return pd.concat(annotated, ignore_index=True)
def plot_nccl_kernel_run_duration_ratio(path, image_path, threshold=1e5):
    """Box-plot launch-skew / kernel-duration ratio per collective op.

    The trace is parsed with ``nccl_sql`` and annotated by
    ``analysis_nccl_kernel_run_duration_ratio``. The op name is the part of
    ``name`` before the first ``_``. Ops containing "Send" or "Recv" are
    skipped. For every remaining op, ``dur_ms`` is cut into 10 bins and
    ``ratio_dur_diff`` is box-plotted per bin midpoint.

    Args:
        path: Trace file handed to ``PerfettoParser``.
        image_path: Accepted for interface compatibility; this function does
            not write any file.
        threshold: Accepted for interface compatibility; currently unused.

    Returns:
        Dict mapping op name to its DataFrame (including ``dur_bins_mid``).
        Empty when no op remains after the Send/Recv filter.
    """
    # threshold is for hpu, because api is not stable
    trace = PerfettoParser(trace=path)
    nccl = analysis_nccl_kernel_run_duration_ratio(trace.parse(nccl_sql))
    nccl["op"] = nccl["name"].apply(lambda x: x.split("_")[0])
    unique_ops = [op for op in nccl["op"].unique() if "Send" not in op and "Recv" not in op]
    if not unique_ops:
        # plt.subplots(0, 2) raises ValueError; there is nothing to plot anyway.
        return {}

    num_cols = 2  # Number of columns in the subplot grid
    num_rows = (len(unique_ops) + num_cols - 1) // num_cols

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, num_rows * 5))
    axes = axes.flatten()  # Flatten in case axes is a 2D array

    line_style = dict(linestyle="-", linewidth=1.5)
    dfs = {}
    for axe, op in zip(axes, unique_ops):
        # Copy so the new column is not written into a slice of ``nccl``.
        df = nccl[nccl["op"] == op].copy()

        # Create 10 bins based on 'dur_ms' and label each row with the
        # integer midpoint of its bin.
        bins = pd.cut(df["dur_ms"], bins=10)
        df["dur_bins_mid"] = bins.apply(lambda x: int(x.mid))
        dfs[op] = df

        df.boxplot(
            column="ratio_dur_diff",
            by="dur_bins_mid",
            ax=axe,
            color=dict(boxes="r", whiskers="r", medians="r", caps="r"),
            boxprops=line_style,
            flierprops=line_style,
            medianprops=line_style,
            whiskerprops=line_style,
            capprops=line_style,
        )
        axe.set_title(f"Box {op} Plot")
        axe.set_xlabel("kernel duration(ms)")
        axe.set_ylabel("lanuch_time_diff/kernel_duration")
        axe.legend()
        axe.grid(True)

    plt.tight_layout()
    # boxplot(by=...) adds a figure-level title; blank it out.
    fig.suptitle("")
    plt.show()
    return dfs
def analysis_host_issue(timeline_path, output_dir, dist_strategy, per_group=False):
    """Plot host-issue delay histograms for a trace and return the per-op frames.

    Args:
        timeline_path: Trace file forwarded to ``plot_nccl_host_issue_delay_hist``.
        output_dir: Output location forwarded to the plotting helpers.
        dist_strategy: Dash-separated parallelism spec such as
            ``"tp2-cp1-dp32-pp4"``; each item is a two-character group name
            followed by an integer size.
        per_group: When true, additionally run
            ``plot_nccl_host_issue_delay_seq_in_same_communicator`` for the
            first rank list of every parallel group that has more than one
            rank. Defaults to ``False``, which matches the previous behaviour
            (that code sat after an unconditional ``return`` and never ran).

    Returns:
        The dict returned by ``plot_nccl_host_issue_delay_hist``.

    Raises:
        ValueError: If a size in ``dist_strategy`` is not an integer.
    """
    groups_dict = OrderedDict((pair[:2], int(pair[2:])) for pair in dist_strategy.split("-"))
    rank_helper = GetRankHelper(groups_dict)
    group_ranks = {group: rank_helper.get_ranks(group) for group in groups_dict}

    dfs = plot_nccl_host_issue_delay_hist(timeline_path, output_dir)
    if not per_group:
        return dfs

    # ``items`` must be called; the former dead code iterated the bound method.
    for parallel_method, rank_groups in group_ranks.items():
        first_group = rank_groups[0]
        if len(first_group) == 1:
            continue

        print(f"Analysis {parallel_method} Group: {first_group}")
        plot_nccl_host_issue_delay_seq_in_same_communicator(
            dfs,
            delay=0.05,
            image_dir=f"{output_dir}/{parallel_method}",
            filter_op=dfs.keys(),
            filter_rank=first_group,
        )
    return dfs
def main():
    """Command-line entry point: parse arguments and dispatch to a plot routine.

    ``--type`` selects the analysis:

    * ``tflops-box`` / ``xccl-box``: one box plot per ``--path``, written to
      ``<output>/<trace name>_{tflops,xccl}_boxplot.svg``.
    * ``performance-diff``: compare all ``--path`` traces, labelled by
      ``--name`` or by the trace file name.
    * ``host-issue``: run ``analysis_host_issue`` on the first ``--path``;
      requires ``--dist``.
    """

    def trace_label(trace_path):
        # Drop only a trailing ".bin". str.strip(".bin") treats its argument as
        # a character set and would also eat leading/trailing '.', 'b', 'i', 'n'.
        file_name = Path(trace_path).name
        return file_name[: -len(".bin")] if file_name.endswith(".bin") else file_name

    parser = ArgumentParser(usage="""python parse_perfetto.py""")
    parser.add_argument("--path", action="append", required=True, help="Trace path for diff")
    parser.add_argument("--name", action="append", required=False, help="Trace name for diff")
    parser.add_argument(
        "--type",
        default="tflops-box",
        # "host-issue" is handled below, so it has to be an accepted choice.
        choices=["tflops-box", "xccl-box", "performance-diff", "host-issue"],
    )
    parser.add_argument("--output", default="./")
    parser.add_argument(
        "--dist",
        default=None,
        help="Parallel strategy for --type host-issue, e.g. tp2-cp1-dp32-pp4",
    )

    args = parser.parse_args()
    dir_path = Path(args.output)
    dir_path.mkdir(parents=True, exist_ok=True)

    if args.type == "tflops-box":
        for p in args.path:
            plot_tflops_box(p, f"{args.output}/{trace_label(p)}_tflops_boxplot.svg")
    elif args.type == "xccl-box":
        for p in args.path:
            plot_xccl_box(p, f"{args.output}/{trace_label(p)}_xccl_boxplot.svg")
    elif args.type == "performance-diff":
        if args.name:
            if len(args.name) != len(args.path):
                raise ValueError(
                    f"When set --name, length of --name and --path must be same, {len(args.name)} vs {len(args.path)}"
                )
            diff_data = dict(zip(args.name, args.path))
        else:
            diff_data = {trace_label(p): p for p in args.path}
        plot_diff(diff_data, args.output)
    elif args.type == "host-issue":
        if args.dist is None:
            parser.error("--dist is required when --type is host-issue")
        analysis_host_issue(args.path[0], args.output, args.dist)
class StackTrie:
    """Prefix tree of stack frames that records which ranks reached each frame.

    ``all_ranks`` is the full set of ranks in the job; it is used to report,
    for every frame, both the ranks that have the frame and those that do not.
    """

    def __init__(self, all_ranks):
        self.root = TrieNode()
        self.all_ranks = all_ranks

    def insert(self, words, rank):
        """Insert one stack (outermost frame first) observed on ``rank``."""
        cursor = self.root
        for frame in words:
            child = cursor.children.get(frame)
            if child is None:
                child = cursor.children[frame] = TrieNode()
            child.ranks.add(rank)
            cursor = child
        # The last visited node (the root for an empty stack) terminates a stack.
        cursor.is_end_of_stack = True
        cursor.add_rank(rank)

    def _format_rank_str(self, ranks):
        """Return ``"@<has>|<missing>"`` with consecutive ranks folded.

        With eight ranks (0-7) and ``ranks == {0, 1, 2, 5, 6, 7}`` the result
        is ``"@0-2/5-7|3-4"``: ranks 0-2 and 5-7 have this frame, ranks 3-4
        do not.
        """

        def _fold(values):
            # Greedily merge each value into the previous run when it extends
            # that run by exactly one; otherwise start a new run.
            runs = []
            for value in sorted(values):
                if runs and runs[-1][1] + 1 == value:
                    runs[-1][1] = value
                else:
                    runs.append([value, value])
            return "/".join(
                str(first) if first == last else f"{first}-{last}"
                for first, last in runs
            )

        # Compute the complement first, matching the original evaluation order.
        missing = list(self.all_ranks - set(ranks))
        present = list(ranks)
        return f"@{_fold(present)}|{_fold(missing)}"

    def _traverse_with_all_stack(self, node, path):
        """Yield one ``;``-joined line per complete stack below ``node``.

        Every frame in a yielded line carries its own rank suffix.
        """
        for word, child in node.children.items():
            labelled_path = path + [word + self._format_rank_str(child.ranks)]
            if child.is_end_of_stack:
                yield ";".join(labelled_path)
            yield from self._traverse_with_all_stack(child, labelled_path)

    def __iter__(self):
        yield from self._traverse_with_all_stack(self.root, [])
class GetRankHelper:
    """Map named parallel groups onto global ranks.

    ``groups`` is an ordered mapping from group name to group size. Ranks
    ``0 .. world_size - 1`` are laid out in an N-d array whose axes are the
    groups in reverse order, so the first group listed occupies the last
    (fastest-varying) axis and therefore consists of consecutive ranks.
    """

    def __init__(self, groups: typing.OrderedDict[str, int]):
        sizes = list(groups.values())
        last_axis = len(groups) - 1

        self.world_size = 1
        for size in sizes:
            self.world_size = self.world_size * size

        # The first listed group maps to the last axis, the last to axis 0.
        self.name_to_axis: Dict[str, int] = {
            name: last_axis - position for position, name in enumerate(groups)
        }
        self.ranks = np.arange(self.world_size).reshape(sizes[::-1])

    def get_ranks(self, group, group_0=False):
        """Return the ranks that make up each instance of ``group``.

        Args:
            group: a group name given at construction time.
            group_0: when true, return only the group instance that starts at
                the first rank, as a 1-d array.

        Returns:
            A 2-d array with one row per group instance, or a 1-d array when
            ``group_0`` is true.

        Raises:
            KeyError: if ``group`` is unknown.
        """
        axis = self.name_to_axis[group]
        group_size = self.ranks.shape[axis]
        # Distance, in ranks, between neighbouring members of one group.
        step = self.ranks.strides[axis] // self.ranks.itemsize

        # Take index 0 along the group's axis: the first rank of each instance.
        selector = tuple(
            0 if dim == axis else slice(None) for dim in range(self.ranks.ndim)
        )
        starts = self.ranks[selector].ravel()

        if group_0:
            head = starts[0]
            return np.array(range(head, head + step * group_size, step))
        return np.array(
            [list(range(start, start + step * group_size, step)) for start in starts]
        )
def compute_training_flops(
    batch_size,
    sequence_length,
    hidden_size,
    vocab_size,
    intermediate_size,
    num_layers,
    use_gradient_checkpointing=False,
    use_peft=False,
    use_gqa=False,
    kv_head_ratio=1,
):
    """Estimate the matmul FLOPs of one training step.

    The formula follows the appendix "Floating-point operations" of
    "Efficient Large-Scale Language Model Training on GPU Clusters Using
    Megatron-LM". Only matrix multiplications are counted. The backward pass
    is taken to cost twice the forward pass; with ``use_peft`` it is taken to
    cost the same as the forward pass, for simplicity.

    Args:
        batch_size, sequence_length, hidden_size, vocab_size,
        intermediate_size, num_layers: model and batch dimensions.
        use_gradient_checkpointing: count one extra forward pass per decoder
            layer in the hardware FLOPs.
        use_peft: use a forward+backward factor of 2 instead of 3.
        use_gqa: scale the key/value projection cost by ``1 / kv_head_ratio``.
        kv_head_ratio: number of query heads per key/value head.

    Returns:
        ``(hardware_flops, model_flops)``. The two are equal unless gradient
        checkpointing is enabled.
    """
    # Query projection: [b, s, h] -> [b, s, h]
    q_proj = batch_size * 2 * sequence_length * hidden_size**2
    kv_proj = (
        2 * batch_size * 2 * sequence_length * hidden_size * hidden_size / kv_head_ratio
        if use_gqa
        else 2 * q_proj
    )
    attn_core = (
        2 * batch_size * hidden_size * sequence_length**2
        + 4 * batch_size * sequence_length * hidden_size**2
    )
    attn_forward = (q_proj + kv_proj) + attn_core

    # Llama-2 style MLP with a gate projection: three linear layers.
    mlp_forward = 3 * 2 * batch_size * sequence_length * hidden_size * intermediate_size
    logits_forward = 2 * batch_size * sequence_length * hidden_size * vocab_size

    layer_forward = attn_forward + mlp_forward
    # Forward FLOPs of the whole model without gradient checkpointing.
    plain_forward = num_layers * layer_forward + logits_forward

    factor = 2 if use_peft else 3
    model_flops = plain_forward * factor
    if use_gradient_checkpointing:
        # Each decoder layer is recomputed once; the logits layer is not.
        hardware_flops = (
            num_layers * layer_forward * (factor + 1) + logits_forward * factor
        )
        return hardware_flops, model_flops
    return model_flops, model_flops
def main() -> None:
    """Run a synthetic FSDP training loop on a Llama-architecture model.

    Builds a ``LlamaForCausalLM`` from a hard-coded config on the meta device,
    wraps it in FSDP with bf16 mixed precision and activation checkpointing,
    and trains on random tokens from ``DummyDataset``. Rank 0 prints the loss,
    the step time and the achieved TFLOPS for every step.

    Reads ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` from the environment and
    raises ``KeyError`` if any of them is missing.
    """
    # Initialize the process group
    dist.init_process_group(backend="nccl")

    # Get local rank and world size
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Hard-coded model and batch dimensions for this experiment.
    num_layers = 10
    hidden_size = 4096
    intermediate_size = 8192
    vocab_size = 126464
    num_head = 64
    num_kv_head = 8
    batch_size = 2
    seq_length = 4096
    # Query heads per key/value head; feeds the GQA term of the FLOPs estimate.
    kv_head_ratio = num_head // num_kv_head
    torch.cuda.set_device(local_rank)

    # NOTE(review): LlamaConfig documents its epsilon parameter as
    # `rms_norm_eps`; `layer_norm_eps` is presumably accepted as an extra
    # kwarg and not used by the model -- confirm.
    config = LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_head,
        num_key_value_heads=num_kv_head,
        intermediate_size=intermediate_size,
        max_position_embeddings=seq_length,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        # attn_implementation="flash_attention_2",
        use_cache=False,
        use_bfloat16=True
    )

    # Alternative kept for reference: build on CPU on local rank 0 only.
    #init_device = "cpu" if local_rank == 0 else "meta"
    init_device = "meta"

    # Optional Liger kernels, currently disabled:
    # from liger_kernel.transformers import apply_liger_kernel_to_llama
    # apply_liger_kernel_to_llama(
    #     rope=True,
    #     swiglu=True,
    #     cross_entropy=True,
    #     fused_linear_cross_entropy=False,
    #     rms_norm=True
    # )

    # Build the module on the meta device; FSDP materializes it below through
    # param_init_fn.
    with torch.device(init_device):
        model = LlamaForCausalLM(config)

    # Hardware FLOPs of one step for this process's batch (the second return
    # value, model FLOPs, is discarded).
    flop, _ = compute_training_flops(
        batch_size,
        seq_length,
        hidden_size,
        vocab_size,
        intermediate_size,
        num_layers,
        use_gradient_checkpointing=True,
        use_gqa=True,
        kv_head_ratio=kv_head_ratio,
    )

    # Random-token dataset, sharded across ranks by the distributed sampler.
    dataset = DummyDataset(vocab_size=vocab_size, max_length=seq_length)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler)

    #param_init_fn = lambda m: m.to_empty(device=torch.device("cuda")) if local_rank != 0 else None
    # Allocates each module's storage on the GPU without copying values.
    # NOTE(review): the weights are presumably left uninitialized, which is
    # fine for a throughput experiment on random data -- confirm.
    param_init_fn = lambda m: m.to_empty(device=torch.device("cuda"))
    # Wrap every LlamaDecoderLayer in its own FSDP unit.
    wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer,},)
    #model = model.to(dtype=torch.bfloat16)

    model = FSDP(model, device_id=local_rank, auto_wrap_policy=wrap_policy,
                 mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
                 sync_module_states=False, param_init_fn=param_init_fn,
                 forward_prefetch=True, limit_all_gathers=True, use_orig_params=True)

    # Activation checkpointing on every decoder layer.
    apply_fsdp_checkpointing(model, LlamaDecoderLayer)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    # Training Loop
    def save_profile(prof):
        # Profiler callback: write one Chrome trace per rank.
        prof.export_chrome_trace(f"fsdp_trace_{rank}.json")

    epoch = 0
    iters = 0
    # NOTE(review): this profiler is constructed but never started or stepped
    # (the prof.start()/step()/stop() calls below are commented out), so no
    # trace is written as the code stands.
    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=1000000,
            repeat=1),
        on_trace_ready=save_profile,
        record_shapes=False,
        with_stack=False)

    #prof = nullcontext()
    # prof.start()

    # Placeholder; rebound to a float step duration on rank 0 inside the loop.
    dur = []
    model.train()
    # No early exit: the loop consumes the whole dataloader once.
    for input_ids, labels in dataloader:
        input_ids, labels = input_ids.to(local_rank), labels.to(local_rank)
        start = time.time()
        optimizer.zero_grad()
        loss = model(input_ids=input_ids, labels=labels).loss
        loss.backward()
        optimizer.step()
        # Wait for the GPU so the wall-clock time covers the whole step.
        torch.cuda.synchronize()
        if rank == 0:
            dur = time.time() - start
            tflops = flop / dur / 1e12
            # `epoch` is never incremented, so this always prints "Epoch 0".
            print(f"Epoch {epoch}, Loss: {loss.item()} time {dur} tflops {tflops}")
        iters += 1
        # if iters > 10:
        #     break
        # prof.step()
    # prof.stop()

    print("Training Complete")
    dist.destroy_process_group()
compute_training_flops( + batch_size, + seq_length, + hidden_size, + vocab_size, + intermediate_size, + num_layers, + use_gradient_checkpointing=True, + use_gqa=True, + kv_head_ratio=kv_head_ratio, + ) + + + config = LlamaConfig( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_head, + num_key_value_heads=num_kv_head, + intermediate_size=intermediate_size, + max_position_embeddings=seq_length, + initializer_range=0.02, + layer_norm_eps=1e-5, + attn_implementation="flash_attention_2", + use_cache=False, + use_bfloat16=True + ) + + #init_device = "cpu" if local_rank == 0 else "meta" + init_device = "meta" + + from liger_kernel.transformers import apply_liger_kernel_to_llama + apply_liger_kernel_to_llama( + rope=True, + swiglu=True, + cross_entropy=True, + fused_linear_cross_entropy=False, + rms_norm=True + ) + + ds_config = { + "train_batch_size": batch_size * world_size, + "train_micro_batch_size_per_gpu": batch_size, + #"steps_per_print": 10, + "zero_optimization": { + "stage": 3, + "overlap_comm": True, + }, + "bf16": { + "enabled": True, + }, + "activation_checkpointing": { + "partition_activations": True, # Partition activations across GPUs + #"contiguous_memory_optimization": True, # Optimize contiguous memory usage + }, + } + + kwargs = {} + kwargs["config"] = ds_config + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = LlamaForCausalLM(config) + kwargs["model"] = model + + from deepspeed.ops.adam import FusedAdam + #optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + optimizer = FusedAdam(model.parameters(), lr=1e-4) + kwargs["optimizer"] = optimizer + model_engine, optimizer, _, _ = deepspeed.initialize(**kwargs) + #from remote_pdb import RemotePdb + #RemotePdb("127.0.0.1", 16666+rank).set_trace() + + dataset = DummyDataset(vocab_size=vocab_size, max_length=seq_length) + sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) 
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler) + model_engine.train() + start = end = 0 + dur = [] + epoch = 0 + # Training Loop + def save_profile(prof): + prof.export_chrome_trace(f"ds_trace_{rank}.json") + + epoch = 0 + iters = 0 + # prof = torch.profiler.profile( + # schedule=torch.profiler.schedule( + # wait=1, + # warmup=1, + # active=3, + # repeat=1), + # on_trace_ready=save_profile, + # record_shapes=True, + # with_stack=True) + + # #prof = nullcontext() + # prof.start() + + for step, (input_ids, labels) in enumerate(dataloader): + start = time.time() + input_ids, labels = input_ids.to(local_rank), labels.to(local_rank) + optimizer.zero_grad() + loss = model_engine(input_ids=input_ids, labels=labels).loss + model_engine.backward(loss) + model_engine.step() + torch.cuda.synchronize() + dur = time.time() - start + tflops = flop / dur / 1e12 + if rank == 0: + print(f"Epoch {epoch}, Step {step}, Loss {loss.item()}, time {dur}, {tflops}") + # if step > 10: + # break + # prof.step() + # prof.stop() + +def main_qwen_vl(): + #dist.init_process_group(backend="nccl") + + #local_rank = int(os.environ["LOCAL_RANK"]) + #rank = int(os.environ["RANK"]) + #world_size = int(os.environ["WORLD_SIZE"]) + world_size = 1 + local_rank = rank = 0 + + torch.cuda.set_device(local_rank) + config ={ + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "vision_start_token_id": 151652, + "vision_end_token_id": 151653, + "vision_token_id": 151654, + "image_token_id": 151655, + "video_token_id": 151656, + "hidden_act": "silu", + "hidden_size": 8192 // 4, + "initializer_range": 0.02, + "intermediate_size": 29568 // 4, + "max_position_embeddings": 32768, + "max_window_layers": 80, + "model_type": "qwen2_vl", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": 
"bfloat16", + "transformers_version": "4.41.2", + "use_cache": False, + "use_sliding_window": False, + "vision_config": { + "depth": 32, + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 8192, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2 + }, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16 // 4, + 24 // 4, + 24 // 4 + ] + }, + "vocab_size": 152064 + } + + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel, Qwen2VisionTransformerPretrainedModel, Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + + qwen_config = Qwen2VLConfig(**config) + preprocess_config = { + "min_pixels": 3136, + "max_pixels": 12845056, + "patch_size": 14, + "temporal_patch_size": 2, + "merge_size": 2, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "image_processor_type": "Qwen2VLImageProcessor", + "processor_class": "Qwen2VLProcessor" + } + preprocess = Qwen2VLImageProcessor(**preprocess_config) + with torch.device('cuda'): + #model = Qwen2VLForConditionalGeneration(qwen_config) + text_model = Qwen2VLModel(qwen_config) + vision_model = Qwen2VisionTransformerPretrainedModel(qwen_config.vision_config) + image = [torch.ones(1280, 1280, 3, dtype=torch.uint8) for _ in range(10)] + if rank == 0: + text = torch.randint(low=0, high=qwen_config.vocab_size, size=(1, 4096,)).cuda() + #text[-32:] = 1516545 + #print((text == 151655).sum().item()) + data = preprocess(image, return_tensors="pt") + image_grid_thw = data['image_grid_thw'].cuda() + image_hidden = data['pixel_values'].cuda() + # + t = text_model(text) + v = vision_model(image_hidden, image_grid_thw) + breakpoint() + #m = 
model(input_ids=text,pixel_values=image_hidden,image_grid_thw=image_grid_thw) + print(1) + + #num_layers = 10 + #hidden_size = 8192 + #intermediate_size = 32768 + #vocab_size = 126464 + #num_head = 128 + #num_kv_head = 16 + #batch_size = 2 + #seq_length = 4096 + #kv_head_ratio = num_head // num_kv_head + #torch.cuda.set_device(local_rank) + + + #config = LlamaConfig( + # vocab_size=vocab_size, + # hidden_size=hidden_size, + # num_hidden_layers=num_layers, + # num_attention_heads=num_head, + # num_key_value_heads=num_kv_head, + # intermediate_size=intermediate_size, + # max_position_embeddings=seq_length, + # initializer_range=0.02, + # layer_norm_eps=1e-5, + # attn_implementation="flash_attention_2", + # use_cache=False, + # use_bfloat16=True + #) + + ##init_device = "cpu" if local_rank == 0 else "meta" + #init_device = "meta" + + #from liger_kernel.transformers import apply_liger_kernel_to_llama + #apply_liger_kernel_to_llama( + # rope=True, + # swiglu=True, + # cross_entropy=True, + # fused_linear_cross_entropy=False, + # rms_norm=True + #) + + #with torch.device(init_device): + # model = LlamaForCausalLM(config) + + # + #flop, _ = compute_training_flops( + # batch_size, + # seq_length, + # hidden_size, + # vocab_size, + # intermediate_size, + # num_layers, + # use_gradient_checkpointing=True, + # use_gqa=True, + # kv_head_ratio=kv_head_ratio, + #) + + + #dataset = DummyDataset(vocab_size=vocab_size, max_length=seq_length) + #sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) + #dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler) + + ##param_init_fn = lambda m: m.to_empty(device=torch.device("cuda")) if local_rank != 0 else None + #param_init_fn = lambda m: m.to_empty(device=torch.device("cuda")) + #wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer,},) + ##model = model.to(dtype=torch.bfloat16) + + #model = FSDP(model, 
device_id=local_rank, auto_wrap_policy=wrap_policy, + # mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), + # sync_module_states=False, param_init_fn=param_init_fn, + # forward_prefetch=True, limit_all_gathers=True, use_orig_params=True) + + #apply_fsdp_checkpointing(model, LlamaDecoderLayer) + #optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + ## Training Loop + #def save_profile(prof): + # prof.export_chrome_trace(f"fsdp_trace_{rank}.json") + + #epoch = 0 + #iters = 0 + #prof = torch.profiler.profile( + # schedule=torch.profiler.schedule( + # wait=1, + # warmup=1, + # active=3, + # repeat=1), + # on_trace_ready=save_profile, + # record_shapes=True, + # with_stack=True) + + ##prof = nullcontext() + ##prof.start() + + #dur = [] + #with prof: + # model.train() + # for input_ids, labels in dataloader: + # start = time.time() + # input_ids, labels = input_ids.to(local_rank), labels.to(local_rank) + # optimizer.zero_grad() + # loss = model(input_ids=input_ids, labels=labels).loss + # loss.backward() + # optimizer.step() + # torch.cuda.synchronize() + # if rank == 0: + # dur = time.time() - start + # tflops = flop / dur / 1e12 + # print(f"Epoch {epoch}, Loss: {loss.item()} time {dur} tflops {tflops}") + # iters += 1 + # if iters > 10: + # break + # prof.step() + # epoch += 1 + + #print("Training Complete") + #dist.destroy_process_group() +def mllama(): + from transformers import MllamaForConditionalGeneration, AutoProcessor + from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaVisionConfig, MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaVisionEncoderLayer, MllamaSelfAttentionDecoderLayer + from liger_kernel.transformers import apply_liger_kernel_to_mllama + + dist.init_process_group(backend="nccl") + + # Get local rank and world size + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + 
torch.cuda.set_device(local_rank) + + apply_liger_kernel_to_mllama( + rope=True, + swiglu=True, + cross_entropy=False, + fused_linear_cross_entropy=True, + rms_norm=True + ) + config = { + "architectures": [ + "MllamaForConditionalGeneration" + ], + "image_token_index": 128256, + "model_type": "mllama", + "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": 128000, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "cross_attention_layers": [ + 3, + 8, + 13, + 18, + 23, + 28, + 33, + 38, + 43, + 48, + 53, + 58, + 63, + 68, + 73, + 78, + 83, + 88, + 93, + 98 + ], + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": [ + 128001, + 128008, + 128009 + ], + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "silu", + "hidden_size": 4096, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "initializer_range": 0.02, + "intermediate_size": 28672, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 131072, + "min_length": 0, + "model_type": "mllama_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 64, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 40, + "num_key_value_heads": 8, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 128004, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "rms_norm_eps": 1e-05, + "rope_scaling": { + 
"factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "sep_token_id": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": False, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "bfloat16", + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + "use_cache": False, + "vocab_size": 128256 + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.45.0.dev0", + "vision_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_heads": 16, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "gelu", + "hidden_size": 1280, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "image_size": 560, + "intermediate_layers_indices": [ + 3, + 7, + 15, + 23, + 30 + ], + "intermediate_size": 5120, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_num_tiles": 4, + "min_length": 0, + "model_type": "mllama_vision_model", + "no_repeat_ngram_size": 0, + "norm_eps": 1e-05, + "num_beam_groups": 1, + "num_beams": 1, + "num_channels": 3, + "num_global_layers": 8, + "num_hidden_layers": 32, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "patch_size": 
14, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "supported_aspect_ratios": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ], + [ + 1, + 3 + ], + [ + 1, + 4 + ], + [ + 2, + 1 + ], + [ + 2, + 2 + ], + [ + 3, + 1 + ], + [ + 4, + 1 + ] + ], + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "bfloat16", + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + "vision_output_dim": 7680 + } + } + vision_config = MllamaVisionConfig(**config['vision_config']) + text_config = MllamaTextConfig(**config['text_config']) + model_config = MllamaConfig(vision_config, text_config, torch_dtype="bfloat16") + data = torch.load('dummy.pth', map_location='cuda') + label = torch.randint(low=0, high=config['text_config']['vocab_size'], size=data['input_ids'].shape) + with torch.device('meta'): + model = MllamaForConditionalGeneration(model_config) + param_init_fn = lambda m: m.to_empty(device=torch.device("cuda")) + wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={ MllamaVisionEncoderLayer, MllamaSelfAttentionDecoderLayer},) + #model = model.to(dtype=torch.bfloat16) + + model = FSDP(model, device_id=local_rank, auto_wrap_policy=wrap_policy, + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), + sync_module_states=False, param_init_fn=param_init_fn, + forward_prefetch=True, limit_all_gathers=True, use_orig_params=True) + + apply_fsdp_checkpointing(model, (MllamaVisionEncoderLayer, MllamaSelfAttentionDecoderLayer)) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + def save_profile(prof): + prof.export_chrome_trace(f"mllama_trace_{rank}.json") + + 
# Sentinel string for "patch environment variable not set".
ANT_PATCH_ENV_NOT_SET = "ANT_PATCH_ENV_NOT_SET"


@unique
class PatchStatus(Enum):
    """Outcome of a patch check, by check type (version or environment).

    ``str(status)`` gives the human-readable message from
    ``PATCH_STATUS_MSG``.
    """

    PATCH_OK_VERSION_CHECK = auto()
    PATCH_OK_ENV_CHECK = auto()
    PATCH_FAIL_VERSION_CHECK = auto()
    PATCH_FAIL_ENV_CHECK = auto()
    UNKNOWN = auto()

    def __str__(self):
        # The table is defined after the class; it is resolved at call time.
        message = PATCH_STATUS_MSG[self]
        return message


# Display text for every status. The "Y,"/"N," prefix marks success or failure.
PATCH_STATUS_MSG = {
    PatchStatus.PATCH_OK_VERSION_CHECK: "Y,Patch OK(VERSION)",
    PatchStatus.PATCH_OK_ENV_CHECK: "Y,Patch OK(ENV)",
    PatchStatus.PATCH_FAIL_VERSION_CHECK: "N,Patch Fail(VERSION)",
    PatchStatus.PATCH_FAIL_ENV_CHECK: "N,Patch Fail(ENV)",
    PatchStatus.UNKNOWN: "unknown",
}

# Demo: prints "Y,Patch OK(VERSION)".
a = PatchStatus.PATCH_OK_VERSION_CHECK
print(a)
+ +"""LayerNormLinear API""" +import os +import warnings +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from torch.nn import init + +from .. import cpp_extensions as tex + +from .base import ( + get_workspace, + get_ub, + TransformerEngineBaseModule, + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, +) +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..utils import ( + divide, + get_default_init_method, + init_method_constant, + cast_if_needed, + assert_dim_for_fp8_exec, + clear_tensor_data, + requires_grad, +) +from ..distributed import ( + set_tensor_model_parallel_attributes, + get_distributed_world_size, + allreduce, + reduce_scatter_along_first_dim, + gather_along_first_dim, + _fsdp_scatter_tensors, + _fsdp_gather_tensors, +) +from ..constants import GemmParallelModes, dist_group_type, TE_DType +from ..jit import no_torch_dynamo +from ..graph import is_graph_capturing +from ._common import _apply_normalization, _noop_cat +from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode +from ..tensor import QuantizedTensor + +__all__ = ["LayerNormLinear"] + + +class _LayerNormLinear(torch.autograd.Function): + """LayerNormLinear semi-top level module + Calls custom cuda extensions. 
+ """ + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + ln_weight: torch.Tensor, + ln_bias: Union[torch.Tensor, None], + weight: torch.Tensor, + weight_fp8: Optional[torch.Tensor], + bias: torch.Tensor, + use_bias: bool, + eps: float, + is_first_microbatch: Union[bool, None], + fp8: bool, + fp8_calibration: bool, + fp8_meta: Dict[str, Any], + fuse_wgrad_accumulation: bool, + cpu_offloading: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, + sequence_parallel: bool, + tensor_parallel: bool, + activation_dtype: torch.dtype, + parallel_mode: Union[str, None], + return_layernorm_output: bool, + return_layernorm_output_gathered: bool, + is_grad_enabled: bool, + fwd_ln_sm_margin: int, + bwd_ln_sm_margin: int, + zero_centered_gamma: bool, + normalization: str, + ub_bulk_wgrad: bool, + ub_bulk_dgrad: bool, + ub_overlap_rs_dgrad: bool, + ub_overlap_ag: bool, + ub_name: str, + fp8_output: bool, + fsdp_group: Union[dist_group_type, None], + ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + # Make sure input dimensions are compatible + in_features = ln_weight.numel() + assert inp.shape[-1] == in_features, "GEMM not possible" + inputmat = inp.view((-1, in_features)) + if fp8: + assert_dim_for_fp8_exec(inputmat) + assert_dim_for_fp8_exec(weight) + + # Cast for native AMP + inputmat = cast_if_needed(inputmat, activation_dtype) + ln_weight = cast_if_needed(ln_weight, activation_dtype) + if ln_bias is not None: + ln_bias = cast_if_needed(ln_bias, activation_dtype) + + if ub_overlap_ag: + tp_world_size = get_distributed_world_size(tp_group) + if tp_world_size == 1 or (not is_grad_enabled): + ub_overlap_ag = False + if ub_overlap_ag: + dim_size = list(inputmat.size()) + dim_size[0] = dim_size[0] * tp_world_size + ub_obj_lnout = get_ub(ub_name + "_fprop") + if return_layernorm_output: + # First prepare LN output in higher precision, + # which will be later copied to a FP8 UB + ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format) + 
else: + ln_out = ub_obj_lnout.get_ubuf_output(0) + else: + ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype + ln_out = torch.empty_like( + inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + ) + + # Objects for FP8 cast + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + ln_out_scale_inv = None + if fp8: + ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) + + # Launch normalization kernel + ln_out, mu, rsigma = _apply_normalization( + inputmat, + ln_out, + ln_weight, + ln_bias, + eps, + fp8 and not return_layernorm_output, + fp8_meta, + normalization, + fwd_ln_sm_margin, + zero_centered_gamma, + is_grad_enabled, + fp8_scale_inv=ln_out_scale_inv, + ) + + # Column Parallel Linear + ln_out_gathered = False + if ub_overlap_ag: + ln_out_total = ub_obj_lnout.get_ubuf_output(1) + if not return_layernorm_output: + ln_out = torch.empty_like(ln_out) + if ub_obj_lnout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif parallel_mode == "column" and sequence_parallel: + ln_out_gathered = True + ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) + else: + ln_out_total = ln_out + + # If residual connection is after LN, we need `ln_out_return` + # tensor in higher precision, this comes at the cost + # of an extra fp8 cast. 
+ if return_layernorm_output: + ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out + if fp8: + if ub_overlap_ag: + ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0) + tex.cast_to_fp8( + ln_out, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + out=ln_out_fp8, + scale_inv=ln_out_scale_inv, + ) + ln_out = torch.empty_like(ln_out_fp8) + else: + ln_out_total = tex.cast_to_fp8( + ln_out_total, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + scale_inv=ln_out_scale_inv, + ) + if ln_out_gathered: + rank = torch.distributed.get_rank(tp_group) + slice_start = rank * ln_out.size(0) + slice_end = (rank + 1) * ln_out.size(0) + ln_out = ln_out_total[slice_start:slice_end, ...] + else: + ln_out = ln_out_total + + if fp8: + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + bias = cast_if_needed(bias, bias_dtype) if use_bias else bias + + # Use FP8 weights + if weight_fp8 is None: + weight_fp8 = weight + + assert isinstance(weight_fp8, Float8Tensor) + + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. 
+ if is_in_onnx_export_mode(): + ln_out_scale_inv.fill_(ln_out_scale_inv.item()) + + if fp8_output: + out_index, meta_tensor, output_te_dtype, output_dtype = ( + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_meta["scaling_fwd"], + fp8_dtype_forward, + torch.uint8, + ) + else: + out_index, meta_tensor, output_te_dtype, output_dtype = ( + None, + None, + None, + activation_dtype, + ) + out, _ = tex.fp8_gemm( + weight_fp8._data, + weight_fp8._scale_inv, + 0, + weight_fp8._fp8_dtype, + ln_out_total, + ln_out_scale_inv, + 0, + fp8_dtype_forward, + output_dtype, + get_workspace(), + bias=bias, + use_bias=use_bias, + use_split_accumulator=_2X_ACC_FPROP, + ub_algo=ub_algo if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, + out_index=out_index, + fp8_meta_tensor=meta_tensor, + D_dtype=output_te_dtype, + ) + if output_dtype == torch.uint8: + out = Float8Tensor( + data=out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype=fp8_dtype_forward, + dtype=activation_dtype, + ) + else: + # Cast for native AMP + weight = cast_if_needed(weight, activation_dtype) + bias = cast_if_needed(bias, activation_dtype) if use_bias else bias + + if fp8_calibration: + # amax of input + amin, amax = ln_out_total.aminmax() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( + -amin, amax + ).float() + # amax of weight + amin, amax = weight.aminmax() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( + -amin, amax + ).float() + + out, _, _ = tex.gemm( + weight, + ln_out_total, + activation_dtype, + get_workspace(), + bias=bias, + use_bias=use_bias, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, + ) + + if is_grad_enabled: + if cpu_offloading: + if fp8 and weight_fp8 is 
not None: + weight_fp8.weight_offloading = True + ln_weight.weight_offloading = True + weight.weight_offloading = True + + inputmat.activation_offloading = True + if normalization == "LayerNorm": + mu.activation_offloading = True + rsigma.activation_offloading = True + ln_out.activation_offloading = True + + # Scatter intermediate/activation tensors saved for the backward pass + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + ctx.fsdp_group = fsdp_group + ctx.fsdp_shapes = _fsdp_scatter_tensors( + fsdp_group, + mu, + rsigma, + weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, + ln_out if weight.requires_grad else None, + ) + + ctx.save_for_backward( + inputmat, + ln_weight, + mu, + rsigma, + weight, + weight_fp8, + weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, + ln_out if weight.requires_grad else None, + ln_out_scale_inv, + ) + + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_meta = fp8_meta + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.cpu_offloading = cpu_offloading + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = use_bias + ctx.sequence_parallel = sequence_parallel + ctx.tensor_parallel = tensor_parallel + ctx.inp_shape = inp.shape + ctx.parallel_mode = parallel_mode + ctx.tp_group = tp_group + ctx.tp_size = tp_size + ctx.return_layernorm_output = return_layernorm_output + ctx.return_layernorm_output_gathered = ( + return_layernorm_output_gathered and ln_out_gathered + ) + ctx.bwd_ln_sm_margin = bwd_ln_sm_margin + ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_bulk_wgrad = ub_bulk_wgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + ctx.ub_name = ub_name + ctx.requires_dgrad = inp.requires_grad + ctx.normalization = normalization + ctx.reduce_and_update_bwd_fp8_tensors = False + if ctx.fp8 and requires_grad(inp, 
ln_weight, ln_bias, weight, bias): + ctx.reduce_and_update_bwd_fp8_tensors = ( + ctx.reduce_and_update_bwd_fp8_tensors + or FP8GlobalStateManager.is_first_fp8_module() + ) + + # Row Parallel Linear + if parallel_mode == "row" and sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif parallel_mode == "row" and tensor_parallel: + out, _ = allreduce(out, tp_group) + + # [*, in_features] -> [*, out_features] except first dimension changes for SP + out = out.view(-1, *inp.shape[1:-1], out.shape[-1]) + + if return_layernorm_output: + if return_layernorm_output_gathered: + shape = list(inp.shape) + shape[0] *= tp_size + return out, ln_out_return.view(shape) + return out, ln_out_return.view_as(inp) + return out + + @staticmethod + def backward( + ctx, *grad_outputs: Tuple[torch.Tensor, ...] + ) -> Tuple[Union[torch.Tensor, None], ...]: + if isinstance(grad_outputs[0], Float8Tensor): + ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ + 0 + ]._scale_inv + + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): + ( + inputmat, + ln_weight, + mu, + rsigma, + weight, + weight_fp8, + main_grad, + ln_out, + ln_out_scale_inv, + ) = ctx.saved_tensors + + # Gather intermediate/activation tensors if needed + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + _fsdp_gather_tensors( + ctx.fsdp_group, + ctx.fsdp_shapes, + mu, + rsigma, + weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + ln_out, + ) + + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + weight = torch.nn.Parameter(weight, weight.requires_grad) + weight.main_grad = main_grad + + if ctx.ub_overlap_rs_dgrad: + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_overlap_rs_dgrad = False + if ctx.ub_bulk_dgrad: + tp_world_size = 
get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1 or not weight.requires_grad: + ctx.ub_bulk_dgrad = False + if ctx.ub_bulk_dgrad: + dim_size = list(ln_out.size()) + dim_size[0] = dim_size[0] * tp_world_size + ub_obj_lnout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) + ( + grad_output, + grad_output_c, + grad_output_t, + grad_bias, + ) = TransformerEngineBaseModule.grad_output_preprocess( + ctx, grad_outputs[0], ctx.parallel_mode == "row" + ) + + if ctx.ub_bulk_wgrad: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1 or not weight.requires_grad: + ctx.ub_bulk_wgrad = False + + # Column Parallel Linear + # Overlap input AG with dgrad + if ( + weight.requires_grad + and (not ctx.ub_bulk_dgrad) + and ctx.parallel_mode == "column" + and ctx.sequence_parallel + ): + ln_out_total, _ = gather_along_first_dim(ln_out, ctx.tp_group, async_op=False) + handle = None + else: + ln_out_total = ln_out + handle = None + + if ctx.is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + + dgrad_size = list(grad_output.size()) + dgrad_size[1] = weight.size(1) + if ctx.ub_bulk_wgrad: # allocate dgrad output + ub_obj_dgrad = get_ub(ctx.ub_name + "_wgrad") + dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + elif ctx.ub_overlap_rs_dgrad: + ub_obj_dgrad = get_ub(ctx.ub_name + "_dgrad") + dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + else: + dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) + + if ctx.ub_bulk_dgrad: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_obj = ub_obj_lnout + elif ctx.ub_overlap_rs_dgrad: + dim_size = list(grad_output.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = weight.size(1) + rs_out = torch.empty( + dim_size, 
dtype=ctx.activation_dtype, device=grad_output.device + ) + if ub_obj_dgrad.is_p2p_overlap(): + if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj = ub_obj_dgrad + else: + ub_algo = None + ub_obj = None + + if ctx.fp8: + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + out_index, meta_tensor, out_te_type, out_type = ( + None, + None, + None, + ctx.activation_dtype, + ) + if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): + out_index = tex.FP8BwdTensors.GRAD_INPUT1 + meta_tensor = ctx.fp8_meta["scaling_bwd"] + out_te_type = fp8_dtype_backward + out_type = torch.uint8 + ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + + # DGRAD: Evaluated unconditionally to feed into Linear backward + _ = tex.fp8_gemm( + weight_fp8.transpose_2d(), + weight_fp8._scale_inv, + 0, + weight_fp8._fp8_dtype, + ( + grad_output_c._data + if isinstance(grad_output_c, Float8Tensor) + else grad_output_c + ), + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, + fp8_dtype_backward, + out_type, + get_workspace(), + out=dgrad, + use_split_accumulator=_2X_ACC_DGRAD, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, + out_index=out_index, + fp8_meta_tensor=meta_tensor, + D_dtype=out_te_type, + ) + clear_tensor_data(grad_output_c) + else: + # DGRAD: Evaluated unconditionally to feed into Linear backward + _, _, _ = tex.gemm( + weight, + grad_output, + ctx.activation_dtype, + get_workspace(), + out=dgrad, + layout="NN", + grad=True, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out 
if ctx.ub_overlap_rs_dgrad else None, + ) + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_lnout.get_ubuf_output(1) + + # Overlap dgrad-RS/AR with wgrad + if ctx.parallel_mode == "column" and ctx.sequence_parallel: + if not ctx.ub_bulk_dgrad and handle is not None: + handle.wait() + if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: + dgrad = dgrad + grad_outputs[1].view_as(dgrad) + dgrad, _ = reduce_scatter_along_first_dim( + dgrad, ctx.tp_group, async_op=False + ) + handle = None + elif ctx.parallel_mode == "column" and ctx.tensor_parallel: + dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + + if weight.requires_grad: + if ctx.fp8: + # WGRAD + extra_output_tensor = None + if ctx.ub_bulk_wgrad: + if ub_obj_dgrad.is_fp8_ubuf(): + dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output + extra_output_tensor = torch.empty( + dim_size, dtype=ctx.activation_dtype, device=dgrad.device + ) + dgrad = extra_output_tensor + else: + dgrad = ub_obj_dgrad.get_ubuf_output(0) + if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: + ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) + wgrad, _ = tex.fp8_gemm( + ln_out_total_t, + ln_out_scale_inv, + 0, + fp8_dtype_forward, + ( + grad_output_t._data + if isinstance(grad_output_t, Float8Tensor) + else grad_output_t + ), + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, + fp8_dtype_backward, + ctx.activation_dtype, + get_workspace(), + accumulate=accumulate_wgrad_into_param_main_grad, + out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor, + ) + clear_tensor_data(ln_out_total_t, grad_output_t) + else: + ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( 
+ ln_out_total, + ln_out_scale_inv, + 0, + fp8_dtype_forward, + TE_DType[ctx.activation_dtype], + ) + wgrad, _, _ = tex.gemm( + ln_out_total_c, + grad_output, + ctx.activation_dtype, + get_workspace(), + layout="NT", + grad=True, + accumulate=accumulate_wgrad_into_param_main_grad, + out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor, + ) + clear_tensor_data(ln_out_total_c) + else: + # WGRAD + wgrad, grad_bias, _ = tex.gemm( + ln_out_total, + grad_output, + ctx.activation_dtype, + get_workspace(), + layout="NT", + grad=True, + use_bias=ctx.use_bias, + accumulate=accumulate_wgrad_into_param_main_grad, + out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + ) + clear_tensor_data(ln_out_total) + if ctx.ub_bulk_wgrad: + dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + + # Column Parallel Linear + if ( + (not ctx.ub_bulk_wgrad) + and ctx.parallel_mode == "column" + and ctx.tensor_parallel + and handle is not None + ): + handle.wait() + + # LayerNorm gradient + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out.view(inputmat.shape) + else: + dgrad = dgrad.view(inputmat.shape) + + # Residual gradient + if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + dgrad = dgrad + grad_outputs[1].view_as(dgrad) + + if ctx.normalization == "LayerNorm": + dgrad, dgamma, dbeta = tex.layernorm_bwd( + dgrad, + inputmat, + mu, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, + ) + elif ctx.normalization == "RMSNorm": + dgrad, dgamma = tex.rmsnorm_bwd( + dgrad, + inputmat, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, + ) + dbeta = None + clear_tensor_data(mu) + 
clear_tensor_data(rsigma) + + if not ctx.use_bias: + grad_bias = None + + if weight.requires_grad: + # Handle custom DDP from mcore. + if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): + weight.grad_added_to_main_grad = True + if getattr(weight, "zero_out_wgrad", False): + wgrad = torch.zeros( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + wgrad = torch.empty( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + elif ctx.fuse_wgrad_accumulation: + wgrad = None + else: + wgrad = None + + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + + # Scatter fp8 weight buffers + if ctx.fp8 and not isinstance(weight, Float8Tensor): + _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + + return ( + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + dgamma, + dbeta, + wgrad, + None, # weight_fp8 + grad_bias, + None, # use_bias + None, # eps + None, # is_first_microbatch + None, # fp8 + None, # fp8_calibration + None, # fp8_meta + None, # fuse_wgrad_accumulation + None, # cpu_offloading + None, # tp_group + None, # tp_size + None, # sequence_parallel + None, # tensor_parallel + None, # activation_dtype + None, # parallel_mode + None, # return_layernorm_output + None, # return_layernorm_output_gathered + None, # is_grad_enabled + None, # fwd_ln_sm_margin + None, # bwd_ln_sm_margin + None, # zero_centered_gamma + None, # normalization + None, # ub_bulk_wgrad + None, # ub_bulk_dgrad + None, # ub_overlap_rs_dgrad + None, # ub_overlap_ag + None, # ub_name + None, # fp8_output + None, # fsdp_group + ) + + +class LayerNormLinear(TransformerEngineBaseModule): + r""" + Applies layer normalization followed by linear transformation to the incoming data. 
+ + Parameters + ---------- + in_features : int + size of each input sample. + out_features : int + size of each output sample. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + bias : bool, default = `True` + if set to `False`, the layer will not learn an additive bias. + normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. + init_method : Callable, default = `None` + used for initializing weights in the following way: `init_method(weight)`. + When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module is + taken post layernorm. + return_layernorm_output_gathered : bool, default = `False` + if set to `True`, output of layernorm is returned after the all + gather operation. Ignored if return_layernorm_output is False. + Example use case: with sequence parallel, input to residual connection + for transformer module (e.g. LoRA) will need to be gathered. + Returning layernorm output gathered will prevent a redundant gather. + parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None + Configuration for splitting the weight and bias tensors along dim 0 into + multiple PyTorch parameters. If a list or tuple of strings is provided, + they are used to make the names of equally-sized parameters. If a dict + (preferably an OrderedDict) is provided, the keys are used as names and + values as split sizes along dim 0. The resulting parameters will have + names that end in `_weight` or `_bias`, so trailing underscores are + stripped from any provided names. 
+ zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + device : Union[torch.device, str], default = "cuda" + The device on which the parameters of the model will be allocated. It is the user's + responsibility to ensure all parameters are moved to the GPU before running the + forward pass. + + Parallelism parameters + ---------------------- + sequence_parallel : bool, default = `False` + if set to `True`, uses sequence parallelism. + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + tp_size : int, default = 1 + used as TP (tensor parallel) world size when TP groups are not formed during + initialization. In this case, users must call the + `set_tensor_parallel_group(tp_group)` method on the initialized module before the + forward pass to supply the tensor parallel group needed for tensor and sequence + parallel collectives. + parallel_mode : {None, 'column', 'row'}, default = `None` + used to decide whether this Linear layer is Column Parallel Linear or Row + Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_. + When set to `None`, no communication is performed. + + Optimization parameters + ----------------------- + fuse_wgrad_accumulation : bool, default = 'False' + if set to `True`, enables fusing of creation and accumulation of + the weight gradient. When enabled, it is assumed that the weights + have an additional `main_grad` attribute (used instead of the + regular `grad`) which is a pre-allocated buffer of the correct + size to accumulate gradients in. + return_bias : bool, default = `False` + when set to `True`, this module will not apply the additive bias itself, but + instead return the bias value during the forward pass together with the + output of the linear transformation :math:`y = xA^T`. 
This is useful when + the bias addition can be fused to subsequent operations. + params_dtype : torch.dtype, default = `torch.get_default_dtype()` + it controls the type used to allocate the initial parameters. Useful when + the model is trained with lower precision and the original FP32 parameters + would not fit in GPU memory. + """ + + def __init__( + self, + in_features: int, + out_features: int, + eps: float = 1e-5, + sequence_parallel: bool = False, + fuse_wgrad_accumulation: bool = False, + tp_group: Optional[dist_group_type] = None, + tp_size: int = 1, + get_rng_state_tracker: Optional[Callable] = None, + init_method: Optional[Callable] = None, + bias: bool = True, + normalization: str = "LayerNorm", + return_bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + parallel_mode: Optional[str] = None, + return_layernorm_output: bool = False, + return_layernorm_output_gathered: bool = False, + parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, + zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_overlap_ag: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_name: Optional[str] = None, + ) -> None: + super().__init__() + + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.in_features = in_features + self.out_features = out_features + self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + self.normalization = normalization + assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" 
+ self.use_bias = bias + self.return_bias = return_bias + self.apply_bias = self.use_bias and not return_bias + self.return_layernorm_output = return_layernorm_output + self.return_layernorm_output_gathered = return_layernorm_output_gathered + self.zero_centered_gamma = zero_centered_gamma + self.ub_bulk_wgrad = ub_bulk_wgrad + self.ub_bulk_dgrad = ub_bulk_dgrad + self.ub_overlap_ag = ub_overlap_ag + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]): + assert ub_name is not None, "Userbuffer name [string] is not set." + self.ub_name = ub_name + + if tp_group is None: + self.tp_size = tp_size + if tp_size == 1: + self.set_tensor_parallel_group(tp_group) + else: + self.tp_size = get_distributed_world_size(tp_group) + self.set_tensor_parallel_group(tp_group) + self.set_nccl_overlap_warning_if_tp() + + self.parallel_mode = parallel_mode + assert ( + self.parallel_mode in GemmParallelModes + ), f"parallel_mode {parallel_mode} not supported" + + if self.parallel_mode == "column": + self.out_features = divide(self.out_features, self.tp_size) + elif self.parallel_mode == "row": + self.in_features = divide(self.in_features, self.tp_size) + + if init_method is None: + init_method = get_default_init_method() + + self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + + self.eps = eps + layer_norm_weight = torch.nn.Parameter( + torch.empty(in_features, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_weight", + layer_norm_weight, + init_fn=init_method_constant(float(not self.zero_centered_gamma)), + ) + if self.normalization != "RMSNorm": + layer_norm_bias = torch.nn.Parameter( + torch.empty(in_features, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) + ) + else: + self.layer_norm_bias = None + + # Initialize params in FP8 + with_fp8_params = 
FP8GlobalStateManager.with_fp8_parameters() + + # Contiguous buffers for params + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=params_dtype, + ) + bias_tensor = None + if self.use_bias: + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=params_dtype, + ) + + # Configure parameter splits + self.weight_names = [] + self.bias_names = [] + self.parameter_split_sizes = [] + if parameters_split is None: + # Split into a single parameter by default + self.weight_names = ["weight"] + self.bias_names = ["bias"] + self.parameter_split_sizes = [out_features] + elif not parameters_split: + raise ValueError("Cannot split weight buffer into 0 parameters") + elif isinstance(parameters_split, dict): + # Split parameters with provided sizes + for name, split_size in parameters_split.items(): + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + elif all(isinstance(name, str) for name in parameters_split): + # Split parameters evenly + split_size = out_features // len(parameters_split) + for name in parameters_split: + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + else: + raise TypeError("Invalid configuration for parameters split") + + # Make sure parameter splits are valid + if sum(self.parameter_split_sizes) != out_features: + raise ValueError( + f"Trying to split weight buffer ({out_features=}) " + f"with split sizes {self.parameter_split_sizes}" + ) + + # Adjust parameter splits for tensor-parallel distribution + if self.parallel_mode == "column": + for i, size in enumerate(self.parameter_split_sizes): + if size % self.tp_size != 0: + raise RuntimeError( + f"Attempting to distribute a parameter with out_features={size} " + f"between {self.tp_size} tensor-parallel processes" + ) + 
self.parameter_split_sizes[i] = size // self.tp_size + + # Construct weight parameters + # Note: Register weights together so that they are adjacent to + # each other in LayerNormLinear.parameters(). This makes it + # more likely that they will stay contiguous if the weights + # are manipulated externally, e.g. by FSDP. + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + + # Check if parameters are subviews of buffers + is_subview = (split_start, split_end) != (0, self.out_features) + if is_subview and with_fp8_params: + raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") + + # Construct weight parameter + self.register_parameter( + self.weight_names[i], + torch.nn.Parameter(weight_tensor[split_start:split_end]), + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + # Construct bias parameters if needed + if self.use_bias: + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + self.register_parameter( + self.bias_names[i], + torch.nn.Parameter(bias_tensor[split_start:split_end]), + init_fn=init_method_constant(0.0), + ) + else: + for name in self.bias_names: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, name, bias) + + if with_fp8_params: + self.init_fp8_metadata() + + self.reset_parameters(defer_init=(device == "meta")) + + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.parallel_mode == "row" and self.apply_bias: + self.gemm_bias_unfused_add = True + else: + self.gemm_bias_unfused_add = False + + # These many SMs are subtracted from the total SM count when calling forward + # and backward LayerNorm C APIs. These envvars can be used to prevent the LN + # kernels from using all SMs in the device. 
def reset_layer_norm_parameters(self) -> None:
    """Re-initialize the layer norm weight and bias (deprecated entry point).

    Emits a ``DeprecationWarning`` pointing callers at
    ``LayerNormLinear.reset_parameters()``. The weight is then filled with
    zeros when ``zero_centered_gamma`` is set and with ones otherwise; the
    bias, if one exists, is filled with zeros.
    """
    # NOTE: keep this call directly in the method body; `stacklevel=2`
    # attributes the warning to the caller of this method.
    warnings.warn(
        "This method will be deprecated in an upcoming release. "
        "Update your code to use LayerNormLinear.reset_parameters() instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    fill_weight = init.zeros_ if self.zero_centered_gamma else init.ones_
    fill_weight(self.layer_norm_weight)
    if self.layer_norm_bias is None:
        return
    init.zeros_(self.layer_norm_bias)


def reset_parameters(self, defer_init=False):
    """Reset parameters via the parent class, then tag them for parallelism.

    Parameters
    ----------
    defer_init : bool, default = False
        Forwarded to the parent implementation. When true, no parallelism
        attributes are attached by this method.
    """
    super().reset_parameters(defer_init=defer_init)
    if defer_init:
        return

    # Layer norm parameters carry the sequence-parallel flag.
    self.layer_norm_weight.sequence_parallel = self.sequence_parallel
    if self.normalization != "RMSNorm":
        self.layer_norm_bias.sequence_parallel = self.sequence_parallel

    # Linear weights: partitioned along dim 1 in "row" mode, dim 0 otherwise.
    partition_dim = 1 if self.parallel_mode == "row" else 0
    for weight_name in self.weight_names:
        set_tensor_model_parallel_attributes(
            tensor=getattr(self, weight_name),
            is_parallel=True,
            dim=partition_dim,
            stride=1,
        )

    # Linear biases: nothing to tag when the module has no bias.
    if not self.use_bias:
        return
    for bias_name in self.bias_names:
        bias_param = getattr(self, bias_name)
        if self.parallel_mode == "row":
            bias_param.sequence_parallel = self.sequence_parallel
        elif self.parallel_mode == "column":
            set_tensor_model_parallel_attributes(bias_param, True, 0, 1)
@no_torch_dynamo()
def forward(
    self,
    inp: torch.Tensor,
    is_first_microbatch: Optional[bool] = None,
    fp8_output: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    Apply layer normalization to the input followed by a linear transformation.

    Parameters
    ----------
    inp : torch.Tensor
         Input tensor.
    is_first_microbatch : {True, False, None}, default = None
                         During training using either gradient accumulation or
                         pipeline parallelism a minibatch of data is further split
                         into microbatches. Between the microbatches of the same minibatch
                         the model weights are not updated. Setting this parameter indicates
                         whether the current microbatch is the first in a minibatch or not.
                         When set, this parameter enables additional optimizations:

                         * during FP8 training, it allows caching of the FP8 versions of
                           the weights
                         * it also allows skipping gradient accumulation during the
                           first microbatch (since it is the first gradient being
                           produced)
    fp8_output : bool, default = False
                Passed through unchanged to `_LayerNormLinear`.
                NOTE(review): presumably requests an FP8 output tensor; confirm
                against `_LayerNormLinear.forward`.

    Returns
    -------
    The linear output `out`. When `return_bias` is set the (cast) bias tensor
    is returned alongside it, and when `return_layernorm_output` is set the
    layer norm output is appended as the last element of the returned tuple.
    """

    # Whenever a skip-update tensor is registered, the microbatch flag is
    # overridden to False regardless of what the caller passed.
    skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
    if skip_fp8_weight_update is not None:
        is_first_microbatch = False

    with self.prepare_forward(inp, is_first_microbatch) as inp:

        # Get concatenated weight and bias tensors
        unfused_weights = [getattr(self, name) for name in self.weight_names]
        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
            if self.fp8:
                # Quantized weights are only usable as-is when there is a single one.
                if len(unfused_weights) != 1:
                    raise RuntimeError(
                        "Splitting QuantizedTensor into multiple params is not supported"
                    )
            else:
                # Not running in FP8: fall back to dequantized copies of the weights.
                unfused_weights = [w.dequantize() for w in unfused_weights]
        weight_tensor = _noop_cat(unfused_weights)
        if self.use_bias:
            bias_tensor = _noop_cat(
                [getattr(self, name) for name in self.bias_names],
            )
        else:
            bias_tensor = getattr(self, self.bias_names[0])  # Unused

        # Initialize FP8 weights if needed
        weight_fp8 = None
        if self.fp8:
            if isinstance(weight_tensor, Float8Tensor):
                # Make sure transpose cache is valid, if present
                # Note: Transpose cache may have been invalidated
                # externally, e.g. by optimizer.
                if weight_tensor._transpose is not None:
                    weight_tensor.transpose_2d(
                        fill_cache=True,
                        noop_flag=skip_fp8_weight_update,
                    )
            else:
                # FP8 cast to workspace buffer
                # The workspace is refreshed when no microbatch info is given
                # or on the first microbatch; it is only cached by name when
                # microbatch info is available.
                update_workspace = is_first_microbatch is None or is_first_microbatch
                weight_fp8 = self.get_fp8_workspace(
                    tensor=weight_tensor,
                    fp8_meta_forward=True,
                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                    cache_name=(None if is_first_microbatch is None else "weight"),
                    update_workspace=update_workspace,
                    skip_update_flag=skip_fp8_weight_update,
                )

        from ..cpu_offload import CPUOffloadEnabled

        if torch.is_grad_enabled():
            fwd_fn = _LayerNormLinear.apply
            args = []
        else:
            # Calling `forward` directly (no autograd graph); the leading None
            # presumably stands in for the autograd `ctx` argument.
            fwd_fn = _LayerNormLinear.forward
            args = [None]
        # NOTE: these are passed positionally, so the order below is part of
        # the contract with `_LayerNormLinear` and must not be rearranged.
        args += (
            inp,
            self.layer_norm_weight,
            self.layer_norm_bias,
            weight_tensor,
            weight_fp8,
            bias_tensor,
            # Bias is applied inside the GEMM only when it is not added separately below.
            self.apply_bias and not self.gemm_bias_unfused_add,
            self.eps,
            is_first_microbatch,
            self.fp8,
            self.fp8_calibration,
            self.fp8_meta,
            self.fuse_wgrad_accumulation,
            CPUOffloadEnabled,
            self.tp_group,
            self.tp_size,
            self.sequence_parallel,
            self.tp_size > 1,
            self.activation_dtype,
            self.parallel_mode,
            self.return_layernorm_output,
            self.return_layernorm_output_gathered,
            torch.is_grad_enabled(),
            # Training uses the forward SM margin, inference uses its own margin.
            self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma,
            self.normalization,
            self.ub_bulk_wgrad,
            self.ub_bulk_dgrad,
            self.ub_overlap_rs_dgrad,
            self.ub_overlap_ag,
            self.ub_name,
            fp8_output,
            self.fsdp_group,
        )
        out = fwd_fn(*args)

    # With `return_layernorm_output` the call above yields an (out, ln_out) pair.
    if self.return_layernorm_output:
        out, ln_out = out

    # Bias that could not be fused into the GEMM is added here.
    if self.gemm_bias_unfused_add:
        out = out + cast_if_needed(bias_tensor, self.activation_dtype)

    if self.return_bias:
        if self.return_layernorm_output:
            return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out
        return out, cast_if_needed(bias_tensor, self.activation_dtype)
    if self.return_layernorm_output:
        return out, ln_out
    return out
def _act_func(activation: str):
    """Return the ``(forward, backward)`` kernel pair for an activation name.

    The pair is looked up on ``tex`` as ``tex.<name>`` and ``tex.d<name>``.
    Callers index ``[0]`` for the forward kernel and ``[1]`` for the
    backward kernel.

    Raises
    ------
    NotImplementedError
        If ``activation`` is not one of the supported names.
    """
    supported = ("gelu", "relu", "geglu", "reglu", "swiglu", "qgelu", "srelu")
    # Built eagerly for every supported name, exactly like a dict literal.
    kernel_pairs = {
        name: (getattr(tex, name), getattr(tex, "d" + name)) for name in supported
    }
    pair = kernel_pairs.get(activation)
    if pair is None:
        raise NotImplementedError("Activation type " + activation + " is not supported!")
    return pair
@staticmethod
def forward(
    ctx,
    inp: torch.Tensor,
    ln_weight: torch.Tensor,
    ln_bias: torch.Tensor,
    fc1_weight: torch.Tensor,
    fc1_weight_fp8: Optional[torch.Tensor],
    fc1_bias: torch.Tensor,
    use_fc1_bias: bool,
    fc2_weight: torch.Tensor,
    fc2_weight_fp8: Optional[torch.Tensor],
    fc2_bias: torch.Tensor,
    use_fc2_bias: bool,
    eps: float,
    is_first_microbatch: Union[bool, None],
    fp8: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    fuse_wgrad_accumulation: bool,
    cpu_offloading: bool,
    tp_group: Union[dist_group_type, None],
    tp_size: int,
    sequence_parallel: bool,
    tensor_parallel: bool,
    activation_dtype: torch.dtype,
    return_layernorm_output: bool,
    return_layernorm_output_gathered: bool,
    bias_gelu_nvfusion: bool,
    set_parallel_mode: bool,
    is_grad_enabled: bool,
    fwd_ln_sm_margin: int,
    bwd_ln_sm_margin: int,
    zero_centered_gamma: bool,
    activation: str,
    normalization: str,
    ub_bulk_wgrad: bool,
    ub_bulk_dgrad: bool,
    ub_overlap_rs_dgrad: bool,
    ub_overlap_rs: bool,
    ub_overlap_ag: bool,
    gemm_gelu_fusion: bool,
    fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
    """Forward pass: normalization -> FC1 GEMM -> activation -> FC2 GEMM.

    Runs either an FP8 path (`tex.fp8_gemm`) or a higher-precision path
    (`tex.gemm`) depending on `fp8`. The `ub_*` flags select
    communication-overlap variants obtained through `get_ub(...)`; when the
    tensor-parallel world size is 1 the forward-pass overlap flags are
    turned off. When `is_grad_enabled` is true, the tensors and settings the
    backward pass reads are stored on `ctx`.

    Returns `fc2_out`, or the pair `(fc2_out, layer norm output)` when
    `return_layernorm_output` is set. The layer norm output is viewed with
    its first dimension scaled by `tp_size` when
    `return_layernorm_output_gathered` is set, and viewed like `inp`
    otherwise.
    """
    # Make sure input dimensions are compatible
    in_features = ln_weight.numel()
    assert inp.shape[-1] == in_features, "GEMM not possible"
    # Flatten all leading dimensions into one.
    inputmat = inp.view((-1, in_features))
    if fp8:
        assert_dim_for_fp8_exec(inputmat)
        assert_dim_for_fp8_exec(fc1_weight)
        assert_dim_for_fp8_exec(fc2_weight)

    # Index 0 of the pair is the forward activation kernel.
    activation_func = _act_func(activation)[0]

    # Cast for native AMP
    inputmat = cast_if_needed(inputmat, activation_dtype)
    ln_weight = cast_if_needed(ln_weight, activation_dtype)
    if ln_bias is not None:
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

    tp_world_size = get_distributed_world_size(tp_group)
    # All-gather overlap is disabled for a single rank, when grad is
    # disabled, or when the layer norm output has to be returned.
    if ub_overlap_ag:
        if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
            ub_overlap_ag = False
    if ub_overlap_ag:
        ub_obj_lnout = get_ub("fc1_fprop")
        ln_out = ub_obj_lnout.get_ubuf_output(0)
    else:
        # The norm writes uint8 (FP8 storage) only when running FP8 and the
        # layer norm output is not returned to the caller.
        ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
        ln_out = torch.empty_like(
            inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format
        )
    ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs

    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

    ln_out, mu, rsigma = _apply_normalization(
        inputmat,
        ln_out,
        ln_weight,
        ln_bias,
        eps,
        fp8 and not return_layernorm_output,
        fp8_meta,
        normalization,
        fwd_ln_sm_margin,
        zero_centered_gamma,
        is_grad_enabled,
    )

    # Column Parallel Linear
    ln_out_gathered = False
    if ub_overlap_ag:
        ln_out_total = ub_obj_lnout.get_ubuf_output(1)
        ln_out = torch.empty_like(ln_out)
        if ub_obj_lnout.is_atomic_gemm():
            ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
        else:
            ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
    elif set_parallel_mode and sequence_parallel:
        ln_out_gathered = True
        ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
    else:
        ln_out_total = ln_out

    # If residual connection is after LN, we need `ln_out`
    # tensor in higher precision, this comes at the cost
    # of an extra fp8 cast.
    if return_layernorm_output:
        ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
        if fp8:
            if ub_overlap_ag:
                ln_out = tex.cast_to_fp8(
                    ln_out,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                ln_out_total = tex.cast_to_fp8(
                    ln_out_total,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
                if ln_out_gathered:
                    # Recover this rank's slice of the gathered FP8 tensor.
                    rank = torch.distributed.get_rank(tp_group)
                    slice_start = rank * ln_out.size(0)
                    slice_end = (rank + 1) * ln_out.size(0)
                    ln_out = ln_out_total[slice_start:slice_end, ...]
                else:
                    ln_out = ln_out_total

    if fp8:
        # FP32 activations use BF16 biases on the FP8 path.
        bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
        fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
        fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias

        # Use FP8 weights
        if fc1_weight_fp8 is None:
            fc1_weight_fp8 = fc1_weight
        if fc2_weight_fp8 is None:
            fc2_weight_fp8 = fc2_weight

        assert isinstance(fc1_weight_fp8, Float8Tensor)
        assert isinstance(fc2_weight_fp8, Float8Tensor)

        # Perform FP8 GEMM
        fp8_gemm_args = [
            fc1_weight_fp8._data,
            fc1_weight_fp8._scale_inv,
            0,
            fc1_weight_fp8._fp8_dtype,
            ln_out_total,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
            activation_dtype,
            get_workspace(),
        ]
        fp8_gemm_kwargs = dict(
            bias=fc1_bias,
            use_bias=use_fc1_bias,
            use_split_accumulator=_2X_ACC_FPROP,
            ub_algo=ub_algo_ag if ub_overlap_ag else None,
            ub=ub_obj_lnout if ub_overlap_ag else None,
            extra_output_tensor=ln_out if ub_overlap_ag else None,
        )
        if gemm_gelu_fusion:
            # Index 8 is the positional output dtype (`activation_dtype` above).
            fp8_gemm_args[8] = torch.uint8  # out_dtype
            fp8_gemm_kwargs.update(
                dict(
                    gelu=True,
                    out_index=tex.FP8FwdTensors.GEMM2_INPUT,
                    fp8_meta_tensor=fp8_meta["scaling_fwd"],
                    D_dtype=fp8_dtype_forward,
                )
            )
        fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs)
        if not is_grad_enabled:
            clear_tensor_data(ln_out_total)

        # Perform activation
        if gemm_gelu_fusion:
            gelu_out, fc1_out = fp8_gemm_out
        else:
            fc1_out, _ = fp8_gemm_out
            gelu_out = activation_func(
                fc1_out,
                fp8_meta["scaling_fwd"],
                tex.FP8FwdTensors.GEMM2_INPUT,
                fp8_dtype_forward,
            )
        if not is_grad_enabled:
            clear_tensor_data(fc1_out)

        # Defaults for a non-FP8 FC2 output; overridden below for an FP8 buffer.
        fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
            None,
            None,
            None,
            activation_dtype,
        )
        if ub_overlap_rs:
            ub_obj_fc2out = get_ub("fc2_fprop")
            fc2_out = ub_obj_fc2out.get_ubuf_output(1)
            dim_size = list(gelu_out.size())
            dim_size[0] = dim_size[0] // tp_world_size
            dim_size[1] = fc2_weight_fp8.size(0)
            rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
            if ub_obj_fc2out.is_p2p_overlap():
                if ub_obj_fc2out.is_atomic_gemm():
                    ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
                else:
                    ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
            else:
                if ub_obj_fc2out.is_atomic_gemm():
                    ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
                else:
                    ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS

            if ub_obj_fc2out.is_fp8_ubuf():
                fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT
                fc2_meta_tensor = fp8_meta["scaling_fwd"]
                fc2_te_type = fp8_dtype_forward
                out_type = torch.uint8
                ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index])
        else:
            dim_size = list(gelu_out.size())
            dim_size[1] = fc2_weight_fp8.size(0)
            fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)

        # The result is written into `fc2_out` (and `rs_out` when overlapping).
        _ = tex.fp8_gemm(
            fc2_weight_fp8._data,
            fc2_weight_fp8._scale_inv,
            0,
            fc2_weight_fp8._fp8_dtype,
            gelu_out,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM2_INPUT,
            fp8_dtype_forward,
            out_type,
            get_workspace(),
            bias=fc2_bias,
            use_bias=use_fc2_bias,
            use_split_accumulator=_2X_ACC_FPROP,
            out=fc2_out,
            ub_algo=ub_algo_rs if ub_overlap_rs else None,
            ub=ub_obj_fc2out if ub_overlap_rs else None,
            extra_output_tensor=rs_out if ub_overlap_rs else None,
            out_index=fc2_out_index,
            fp8_meta_tensor=fc2_meta_tensor,
            D_dtype=fc2_te_type,
        )
        if not is_grad_enabled:
            clear_tensor_data(gelu_out)
    else:
        # Cast for native AMP
        fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
        fc2_weight = cast_if_needed(fc2_weight, activation_dtype)
        fc1_bias = cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias
        fc2_bias = cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias

        if fp8_calibration:
            # amax of fc1 input
            amin, amax = ln_out_total.aminmax()
            fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max(
                -amin, amax
            ).float()
            # amax of fc1 weight
            amin, amax = fc1_weight.aminmax()
            fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max(
                -amin, amax
            ).float()

        # When `bias_gelu_nvfusion` is set, bias and GELU are applied by
        # `bias_gelu_fused` below instead of inside the GEMM.
        fc1_outputs = tex.gemm(
            fc1_weight,
            ln_out_total,
            activation_dtype,
            get_workspace(),
            bias=fc1_bias,
            use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
            gelu=not bias_gelu_nvfusion and (activation == "gelu"),
            ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
            ub=ub_obj_lnout if ub_overlap_ag else None,
            extra_output_tensor=ln_out if ub_overlap_ag else None,
        )
        if not is_grad_enabled:
            clear_tensor_data(ln_out_total)

        if bias_gelu_nvfusion:
            fc1_out, _, _ = fc1_outputs
            gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
        else:
            if activation == "gelu":
                # With GELU fused into the GEMM, the activated output comes
                # first and the pre-activation output last.
                gelu_out, _, fc1_out = fc1_outputs
            else:
                fc1_out, _, _ = fc1_outputs
                gelu_out = activation_func(
                    fc1_out, None, tex.FP8FwdTensors.GEMM2_INPUT, TE_DType[fc1_out.dtype]
                )
        if not is_grad_enabled:
            clear_tensor_data(fc1_out)

        if fp8_calibration:
            # amax of fc2 input
            amin, amax = gelu_out.aminmax()
            fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = torch.max(
                -amin, amax
            ).float()
            # amax of fc2 weight
            amin, amax = fc2_weight.aminmax()
            fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = torch.max(
                -amin, amax
            ).float()

        if ub_overlap_rs:
            ub_obj_fc2out = get_ub("fc2_fprop")
            fc2_out = ub_obj_fc2out.get_ubuf_output(1)
            dim_size = list(gelu_out.size())
            dim_size[0] = dim_size[0] // tp_world_size
            dim_size[1] = fc2_weight.size(0)
            rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
            if ub_obj_fc2out.is_p2p_overlap():
                ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
            else:
                ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
        else:
            dim_size = list(gelu_out.size())
            dim_size[1] = fc2_weight.size(0)
            fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
        # The result is written into `fc2_out` (and `rs_out` when overlapping).
        _ = tex.gemm(
            fc2_weight,
            gelu_out,
            activation_dtype,
            get_workspace(),
            bias=fc2_bias,
            use_bias=use_fc2_bias,
            out=fc2_out,
            ub_algo=ub_algo_rs if ub_overlap_rs else None,
            ub=ub_obj_fc2out if ub_overlap_rs else None,
            extra_output_tensor=rs_out if ub_overlap_rs else None,
        )
        if not is_grad_enabled:
            clear_tensor_data(gelu_out)

    if is_grad_enabled:
        if cpu_offloading:
            # Tag weights and saved activations with offloading marker attributes.
            if fp8 and fc1_weight_fp8 is not None:
                fc1_weight_fp8.weight_offloading = True
            if fp8 and fc2_weight_fp8 is not None:
                fc2_weight_fp8.weight_offloading = True
            ln_weight.weight_offloading = True
            fc1_weight.weight_offloading = True
            fc2_weight.weight_offloading = True
            fc1_bias.weight_offloading = True

            inputmat.activation_offloading = True
            if normalization == "LayerNorm":
                mu.activation_offloading = True
            rsigma.activation_offloading = True
            ln_out.activation_offloading = True
            fc1_out.activation_offloading = True
            gelu_out.activation_offloading = True

        # Scatter intermediate/activation tensors saved for the backward pass
        # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
        #       shards/unshards the base weights so we don't do it ourselves
        ctx.fsdp_group = fsdp_group
        ctx.fsdp_shapes = _fsdp_scatter_tensors(
            fsdp_group,
            mu,
            rsigma,
            ln_out,
            fc1_out,
            gelu_out,
            fc1_weight_fp8 if fp8 and not isinstance(fc1_weight, Float8Tensor) else None,
            fc2_weight_fp8 if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
        )

        # NOTE: `backward` unpacks `ctx.saved_tensors` positionally, so this
        # order must stay in sync with it.
        ctx.save_for_backward(
            inputmat,
            ln_weight,
            mu,
            rsigma,
            ln_out if fc1_weight.requires_grad else None,
            fc1_out,
            gelu_out if fc2_weight.requires_grad else None,
            fc1_weight,
            fc1_weight_fp8,
            fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
            fc2_weight,
            fc2_weight_fp8,
            fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
            fc1_bias,
            fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
        )

        # Non-tensor settings read back by `backward`.
        ctx.activation_dtype = activation_dtype
        ctx.activation = activation
        ctx.fp8 = fp8
        ctx.fp8_meta = fp8_meta
        ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        ctx.cpu_offloading = cpu_offloading
        ctx.is_first_microbatch = is_first_microbatch
        ctx.use_fc1_bias = use_fc1_bias
        ctx.use_fc2_bias = use_fc2_bias
        ctx.sequence_parallel = sequence_parallel
        ctx.tensor_parallel = tensor_parallel
        ctx.inp_shape = inp.shape
        ctx.tp_group = tp_group
        ctx.tp_size = tp_size
        ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
        ctx.return_layernorm_output = return_layernorm_output
        # Only recorded as gathered if a gather actually happened above.
        ctx.return_layernorm_output_gathered = (
            return_layernorm_output_gathered and ln_out_gathered
        )
        ctx.set_parallel_mode = set_parallel_mode
        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
        ctx.zero_centered_gamma = zero_centered_gamma
        ctx.ub_bulk_wgrad = ub_bulk_wgrad
        ctx.ub_bulk_dgrad = ub_bulk_dgrad
        ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
        ctx.ub_overlap_ag = ub_overlap_ag
        ctx.requires_dgrad = inp.requires_grad
        ctx.normalization = normalization
        ctx.reduce_and_update_bwd_fp8_tensors = False
        if ctx.fp8 and requires_grad(
            inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
        ):
            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()

    # Row Parallel Linear
    if ub_overlap_rs:
        fc2_out = rs_out
    elif set_parallel_mode and sequence_parallel:
        fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
    elif set_parallel_mode and tensor_parallel:
        fc2_out, _ = allreduce(fc2_out, tp_group)

    # [*, in_features] -> [*, out_features] except first dimension changes for SP
    fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])

    if return_layernorm_output:
        if return_layernorm_output_gathered:
            shape = list(inp.shape)
            shape[0] *= tp_size
            return fc2_out, ln_out_return.view(shape)
        return fc2_out, ln_out_return.view_as(inp)
    return fc2_out
_act_func(ctx.activation)[1] + + if ctx.ub_overlap_rs_dgrad: + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_overlap_rs_dgrad = False + if ctx.ub_bulk_dgrad: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1 or not fc1_weight.requires_grad: + ctx.ub_bulk_dgrad = False + if ctx.ub_bulk_dgrad: + dim_size = list(ln_out.size()) + dim_size[0] = dim_size[0] * tp_world_size + ub_obj_lnout = get_ub("fc1_dgrad") + ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) + if ctx.ub_overlap_ag: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_overlap_ag = False + + if ctx.ub_overlap_ag: + dim_size = list(grad_outputs[0].size()) + dim_size[0] = dim_size[0] * tp_world_size + ctx.ub_obj_gradout = get_ub("fc2_dgrad") + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + + ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess + ( + grad_output, + grad_output_c, + grad_output_t, + fc2_bias_grad, + ) = TransformerEngineBaseModule.grad_output_preprocess(ctx, grad_outputs[0], True) + + if ctx.ub_bulk_wgrad: + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1 or not fc1_weight.requires_grad: + ctx.ub_bulk_wgrad = False + # Column Parallel Linear + # Overlap input AG with dgrad + if ( + fc1_weight.requires_grad + and (not ctx.ub_bulk_dgrad) + and ctx.set_parallel_mode + and ctx.sequence_parallel + ): + ln_out_total, _ = gather_along_first_dim(ln_out, ctx.tp_group, async_op=False) + handle = None + else: + ln_out_total = ln_out + handle = None + + if ctx.is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + + if ctx.fp8: 
+ fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + + # FC2 DGRAD; Unconditional + fc2_dgrad, _ = tex.fp8_gemm( + fc2_weight_fp8.transpose_2d(), + fc2_weight_fp8._scale_inv, + 0, + fc2_weight_fp8._fp8_dtype, + grad_output_c, + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, + fp8_dtype_backward, + ctx.activation_dtype, + get_workspace(), + use_split_accumulator=_2X_ACC_DGRAD, + ub_algo=ub_algo if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ) + if ctx.ub_overlap_ag: + grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) + clear_tensor_data(grad_output_c) + + # FC2 WGRAD + if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: + if fc2_weight.requires_grad: + gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) + clear_tensor_data(gelu_out) + fc2_wgrad, _ = tex.fp8_gemm( + gelu_out_t, + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype_forward, + grad_output_t, + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, + fp8_dtype_backward, + ctx.activation_dtype, + get_workspace(), + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + ) + clear_tensor_data(gelu_out_t, grad_output_t) + + if ctx.activation == "gelu": + fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused( + fc2_dgrad, + fc1_out, + ctx.fp8_meta["scaling_bwd"], + tex.FP8BwdTensors.GRAD_OUTPUT2, + fp8_dtype_backward, + ) + else: + dgelu = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) + fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_fused( + dgelu, + ctx.fp8_meta["scaling_bwd"], + tex.FP8BwdTensors.GRAD_OUTPUT2, + fp8_dtype_backward, + ) + clear_tensor_data(fc1_out) + else: + if fc2_weight.requires_grad: + 
gelu_out_c = torch.ops.tex_ts.cast_from_fp8_ts( + gelu_out, + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM2_INPUT, + fp8_dtype_forward, + TE_DType[ctx.activation_dtype], + ) + clear_tensor_data(gelu_out) + fc2_wgrad, _, _ = tex.gemm( + gelu_out_c, + grad_output, + ctx.activation_dtype, + get_workspace(), + layout="NT", + grad=True, + use_bias=False, + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ) + clear_tensor_data(gelu_out_c) + + if ctx.activation == "gelu": + fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( + fc2_dgrad, fc1_out, fc1_bias + ) + else: + dgelu_no_fp8 = activation_func( + fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype] + ) + fc1_bias_grad = dgelu_no_fp8.sum(dim=0) + clear_tensor_data(fc1_out) + + dgelu = tex.cast_to_fp8( + dgelu_no_fp8, + ctx.fp8_meta["scaling_bwd"], + tex.FP8BwdTensors.GRAD_OUTPUT2, + fp8_dtype_backward, + ) + dgelu_t = None + + out_index, meta_tensor, out_te_type, out_type = ( + None, + None, + None, + ctx.activation_dtype, + ) + fc1_dgrad_size = list(dgelu.size()) + fc1_dgrad_size[1] = fc1_weight.size(1) + # Get/alloc fc1_dgrad + if ctx.ub_bulk_wgrad: # allocate dgrad output + ub_obj_dgrad = get_ub("fc1_wgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + elif ctx.ub_overlap_rs_dgrad: + ub_obj_dgrad = get_ub("fc1_dgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + else: + fc1_dgrad = torch.empty( + fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device + ) + + # FP8 RS + if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): + out_index = tex.FP8BwdTensors.GRAD_INPUT2 + meta_tensor = ctx.fp8_meta["scaling_bwd"] + out_te_type = fp8_dtype_backward + out_type = torch.uint8 + ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + + # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + if ctx.ub_bulk_dgrad: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG 
+ ub_obj = ub_obj_lnout + elif ctx.ub_overlap_rs_dgrad: + dim_size = list(dgelu.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc1_weight_fp8.size(1) + rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) + if ub_obj_dgrad.is_p2p_overlap(): + if ub_obj_dgrad.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ub_obj_dgrad.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj = ub_obj_dgrad + else: + ub_algo = None + ub_obj = None + # FC1 DGRAD: Unconditional + _ = tex.fp8_gemm( + fc1_weight_fp8.transpose_2d(), + fc1_weight_fp8._scale_inv, + 0, + fc1_weight_fp8._fp8_dtype, + dgelu, + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT2, + fp8_dtype_backward, + out_type, + get_workspace(), + out=fc1_dgrad, + use_split_accumulator=_2X_ACC_DGRAD, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, + out_index=out_index, + fp8_meta_tensor=meta_tensor, + D_dtype=out_te_type, + ) + else: + # FC2 DGRAD; Unconditional + fc2_dgrad, _, _ = tex.gemm( + fc2_weight, + grad_output, + ctx.activation_dtype, + get_workspace(), + layout="NN", + gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == "gelu"), + grad=True, + gelu_input=fc1_out, + ub_algo=( + tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None + ), + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ) + + # FC2 WGRAD + if fc2_weight.requires_grad: + fc2_wgrad, fc2_bias_grad, _ = tex.gemm( + gelu_out, + grad_output, + ctx.activation_dtype, + get_workspace(), + layout="NT", + grad=True, + use_bias=ctx.use_fc2_bias, + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ) + clear_tensor_data(gelu_out) + + if ctx.bias_gelu_nvfusion and 
ctx.activation == "gelu": + fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) + else: + if ctx.activation != "gelu": + fc2_dgrad = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) + + # For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM + # and will not be calculated in case wgrad is not required. + if not fc1_weight.requires_grad: + fc1_bias_grad = fc2_dgrad.sum(dim=0) + + # Overwrite data. Deleting the tensor does not release underlying memory. + clear_tensor_data(fc1_out) + dgelu = fc2_dgrad + + fc1_dgrad_size = list(dgelu.size()) + fc1_dgrad_size[1] = fc1_weight.size(1) + if ctx.ub_bulk_wgrad: # allocate dgrad output + ub_obj_dgrad = get_ub("fc1_wgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + elif ctx.ub_overlap_rs_dgrad: + ub_obj_dgrad = get_ub("fc1_dgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + else: + fc1_dgrad = torch.empty( + fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device + ) + + # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + if ctx.ub_bulk_dgrad: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_obj = ub_obj_lnout + elif ctx.ub_overlap_rs_dgrad: + dim_size = list(dgelu.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc1_weight.size(1) + rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) + if ub_obj_dgrad.is_p2p_overlap(): + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj = ub_obj_dgrad + else: + ub_algo = None + ub_obj = None + # FC1 DGRAD: Unconditional + _ = tex.gemm( + fc1_weight, + dgelu, + ctx.activation_dtype, + get_workspace(), + out=fc1_dgrad, + layout="NN", + grad=True, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, + ) + + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_lnout.get_ubuf_output(1) + # Overlap dgrad-RS/AR with 
wgrad + if ctx.set_parallel_mode and ctx.sequence_parallel: + if not ctx.ub_bulk_dgrad and handle is not None: + handle.wait() + if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: + fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) + fc1_dgrad, _ = reduce_scatter_along_first_dim( + fc1_dgrad, ctx.tp_group, async_op=False + ) + handle = None + elif ctx.set_parallel_mode and ctx.tensor_parallel: + fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + + if fc1_weight.requires_grad: + if ctx.fp8: + # FC1 WGRAD + extra_output_tensor = None + if ctx.ub_bulk_wgrad: + if ub_obj_dgrad.is_fp8_ubuf(): + dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output + extra_output_tensor = torch.empty( + dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device + ) + fc1_dgrad = extra_output_tensor + else: + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) + if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: + ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) + fc1_wgrad, _ = tex.fp8_gemm( + ln_out_total_t, + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + dgelu_t, + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT2, + fp8_dtype_backward, + ctx.activation_dtype, + get_workspace(), + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor, + ) + clear_tensor_data(ln_out_total_t, dgelu_t) + else: + ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( + ln_out_total, + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + TE_DType[ctx.activation_dtype], + ) + fc1_wgrad, _, _ = tex.gemm( + 
ln_out_total_c, + dgelu_no_fp8, + ctx.activation_dtype, + get_workspace(), + layout="NT", + grad=True, + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=( + tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + ), + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + extra_output_tensor=extra_output_tensor, + ) + clear_tensor_data(ln_out_total_c, dgelu_no_fp8) + else: + # FC1 WGRAD + fc1_wgrad_outputs = tex.gemm( + ln_out_total, + dgelu, + ctx.activation_dtype, + get_workspace(), + layout="NT", + grad=True, + use_bias=not ctx.bias_gelu_nvfusion, + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + ) + clear_tensor_data(ln_out_total, dgelu) + + if ctx.bias_gelu_nvfusion: + fc1_wgrad, _, _ = fc1_wgrad_outputs + else: + fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs + if ctx.ub_bulk_wgrad: + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + + # Column Parallel Linear + if ( + (not ctx.ub_bulk_wgrad) + and ctx.set_parallel_mode + and ctx.tensor_parallel + and handle is not None + ): + handle.wait() + + # LayerNorm gradient + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out.view(inputmat.shape) + else: + dgrad = fc1_dgrad.view(inputmat.shape) + + # Residual gradient + if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + dgrad = dgrad + grad_outputs[1].view_as(dgrad) + + if ctx.normalization == "LayerNorm": + dgrad, dgamma, dbeta = tex.layernorm_bwd( + dgrad, + inputmat, + mu, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, + ) + elif ctx.normalization == "RMSNorm": + dgrad, dgamma = tex.rmsnorm_bwd( + dgrad, + inputmat, + rsigma, + ln_weight, + ctx.bwd_ln_sm_margin, + ctx.zero_centered_gamma, + ) + dbeta = None + 
clear_tensor_data(mu) + clear_tensor_data(rsigma) + + if fc1_weight.requires_grad: + # Handle custom DDP from mcore. + if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): + fc1_weight.grad_added_to_main_grad = True + if getattr(fc1_weight, "zero_out_wgrad", False): + fc1_wgrad = torch.zeros( + fc1_weight.main_grad.shape, + dtype=fc1_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + fc1_wgrad = torch.empty( + fc1_weight.main_grad.shape, + dtype=fc1_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + elif ctx.fuse_wgrad_accumulation: + fc1_wgrad = None + else: + fc1_wgrad = None + + if fc2_weight.requires_grad: + # Handle custom DDP from mcore. + if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"): + fc2_weight.grad_added_to_main_grad = True + if getattr(fc2_weight, "zero_out_wgrad", False): + fc2_wgrad = torch.zeros( + fc2_weight.main_grad.shape, + dtype=fc2_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + fc2_wgrad = torch.empty( + fc2_weight.main_grad.shape, + dtype=fc2_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + elif ctx.fuse_wgrad_accumulation: + fc2_wgrad = None + else: + fc2_wgrad = None + + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + + # Scatter Fp8 tranposed-weight buffers + if ctx.fp8: + _fsdp_scatter_tensors( + ctx.fsdp_group, + fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, + fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, + ) + + return ( + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + dgamma, + dbeta, + fc1_wgrad, + None, # fc1_weight_fp8 + fc1_bias_grad if ctx.use_fc1_bias else None, + None, # use_fc1_bias + fc2_wgrad, + None, # fc2_weight_fp8 + fc2_bias_grad if ctx.use_fc2_bias else None, + 
None, # use_fc2_bias + None, # eps + None, # is_first_microbatch + None, # fp8 + None, # fp8_calibration + None, # fp8_meta + None, # fuse_wgrad_accumulation + None, # cpu_offloading + None, # tp_group + None, # tp_size + None, # sequence_parallel + None, # tensor_parallel + None, # activation_dtype + None, # return_layernorm_output + None, # return_layernorm_output_gathered + None, # bias_gelu_nvfusion + None, # set_parallel_mode + None, # is_grad_enabled + None, # fwd_ln_sm_margin + None, # bwd_ln_sm_margin + None, # zero_centered_gamma + None, # activation + None, # normalization + None, # ub_bulk_wgrad + None, # ub_bulk_dgrad + None, # ub_overlap_rs_dgrad + None, # ub_overlap_rs + None, # ub_overlap_ag + None, # gemm_gelu_fusion + None, # fsdp_group + ) + + +class LayerNormMLP(TransformerEngineBaseModule): + r""" + Applies layer normalization on the input followed by the MLP module, consisting of + 2 successive linear transformations, separated by the GeLU activation. + + Parameters + ---------- + hidden_size : int + size of each input sample. + ffn_hidden_size : int + intermediate size to which input samples are projected. + eps : float, default = 1e-5 + a value added to the denominator of layer normalization for numerical stability. + bias : bool, default = `True` + if set to `False`, the FC1 and FC2 layers will not learn an additive bias. + normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' + type of normalization applied. + activation : str, default = 'gelu' + activation function used. + Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'. + init_method : Callable, default = `None` + used for initializing FC1 weights in the following way: `init_method(weight)`. + When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. + output_layer_init_method : Callable, default = `None` + used for initializing FC2 weights in the following way: + `output_layer_init_method(weight)`. 
When set to `None`, defaults to + `torch.nn.init.normal_(mean=0.0, std=0.023)`. + return_layernorm_output : bool, default = `False` + if set to `True`, output of layernorm is returned from the forward + together with the output of the linear transformation. + Example use case: residual connection for transformer module + is taken post layernorm. + return_layernorm_output_gathered : bool, default = `False` + if set to `True`, output of layernorm is returned after the all + gather operation. Ignored if return_layernorm_output is False. + Example use case: with sequence parallel, input to residual connection + for transformer module (e.g. LoRA) will need to be gathered. + Returning layernorm output gathered will prevent a redundant gather. + zero_centered_gamma : bool, default = 'False' + if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + the LayerNorm formula changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * + (1 + \gamma) + \beta + device : Union[torch.device, str], default = "cuda" + The device on which the parameters of the model will be allocated. It is the user's + responsibility to ensure all parameters are moved to the GPU before running the + forward pass. + + Parallelism parameters + ---------------------- + set_parallel_mode : bool, default = `False` + if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row + Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_. + sequence_parallel : bool, default = `False` + if set to `True`, uses sequence parallelism. + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + tp_size : int, default = 1 + used as TP (tensor parallel) world size when TP groups are not formed during + initialization. In this case, users must call the + `set_tensor_parallel_group(tp_group)` method on the initialized module before the + forward pass to supply the tensor parallel group needed for tensor and sequence + parallel collectives.
+ + Optimization parameters + ----------------------- + fuse_wgrad_accumulation : bool, default = 'False' + if set to `True`, enables fusing of creation and accumulation of + the weight gradient. When enabled, it is assumed that the weights + have an additional `main_grad` attribute (used instead of the + regular `grad`) which is a pre-allocated buffer of the correct + size to accumulate gradients in. + return_bias : bool, default = `False` + when set to `True`, this module will not apply the additive bias for FC2, but + instead return the bias value during the forward pass together with the + output of the linear transformation :math:`y = xA^T`. This is useful when + the bias addition can be fused to subsequent operations. + params_dtype : torch.dtype, default = `torch.get_default_dtype()` + it controls the type used to allocate the initial parameters. Useful when + the model is trained with lower precision and the original FP32 parameters + would not fit in GPU memory. + seq_length: int + sequence length of input samples. Needed for JIT Warmup, a technique where jit fused + functions are warmed up before training to ensure same kernels are used for forward + propagation and activation recompute phase. + micro_batch_size: int + batch size per training step. Needed for JIT Warmup, a technique where jit + fused functions are warmed up before training to ensure same kernels are + used for forward propagation and activation recompute phase.
+ """ + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + eps: float = 1e-5, + sequence_parallel: bool = False, + return_bias: bool = False, + get_rng_state_tracker: Optional[Callable] = None, + tp_group: Optional[dist_group_type] = None, + tp_size: int = 1, + init_method: Optional[Callable] = None, + bias: bool = True, + normalization: str = "LayerNorm", + activation: str = "gelu", + output_layer_init_method: Optional[Callable] = None, + fuse_wgrad_accumulation: bool = False, + params_dtype: Optional[torch.dtype] = None, + return_layernorm_output: bool = False, + return_layernorm_output_gathered: bool = False, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + set_parallel_mode: bool = False, + zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_ag: bool = False, + ) -> None: + super().__init__() + + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + self.normalization = normalization + assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" 
+ self.use_bias = bias + self.activation = activation + self.return_bias = return_bias + self.apply_bias = bias and not return_bias + self.return_layernorm_output = return_layernorm_output + self.return_layernorm_output_gathered = return_layernorm_output_gathered + self.bias_gelu_nvfusion = ( + bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) and self.activation == "gelu" + ) + self.set_parallel_mode = set_parallel_mode + self.zero_centered_gamma = zero_centered_gamma + self.ub_bulk_wgrad = ub_bulk_wgrad + self.ub_bulk_dgrad = ub_bulk_dgrad + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + self.ub_overlap_rs = ub_overlap_rs + self.ub_overlap_ag = ub_overlap_ag + # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap + self.gemm_gelu_fusion = ( + bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) + and self.activation == "gelu" + and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) + ) + + if tp_group is None: + self.tp_size = tp_size + if tp_size == 1: + self.set_tensor_parallel_group(tp_group) + else: + self.tp_size = get_distributed_world_size(tp_group) + self.set_tensor_parallel_group(tp_group) + self.set_nccl_overlap_warning_if_tp() + + if init_method is None: + init_method = get_default_init_method() + if output_layer_init_method is None: + output_layer_init_method = get_default_init_method() + + self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + self.size_per_partition = divide(ffn_hidden_size, self.tp_size) + + # Initialize params in FP8 + with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() + + # LN init + self.eps = eps + layer_norm_weight = Parameter(torch.empty(hidden_size, device=device, dtype=params_dtype)) + self.register_parameter( + "layer_norm_weight", + layer_norm_weight, + init_fn=init_method_constant(float(not self.zero_centered_gamma)), + ) + if self.normalization != "RMSNorm": + layer_norm_bias = Parameter(torch.empty(hidden_size, device=device, dtype=params_dtype)) + 
self.register_parameter( + "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) + ) + else: + self.layer_norm_bias = None + + # FC1 init + if self.activation in ["reglu", "geglu", "swiglu"]: + fc1_output_features = 2 * self.size_per_partition + else: + fc1_output_features = self.size_per_partition + + fc1_weight = Parameter( + torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype) + ) + self.register_parameter( + "fc1_weight", + fc1_weight, + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + if self.use_bias: + fc1_bias = Parameter( + torch.empty(fc1_output_features, device=device, dtype=params_dtype) + ) + self.register_parameter("fc1_bias", fc1_bias, init_fn=init_method_constant(0.0)) + else: + self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device) + + # FC2 init + fc2_weight = Parameter( + torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype) + ) + self.register_parameter( + "fc2_weight", + fc2_weight, + init_fn=output_layer_init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + + if self.use_bias: + fc2_bias = Parameter(torch.empty(hidden_size, device=device, dtype=params_dtype)) + self.register_parameter("fc2_bias", fc2_bias, init_fn=init_method_constant(0.0)) + else: + self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device) + + if with_fp8_params: + self.init_fp8_metadata(num_gemms=2) + + self.reset_parameters(defer_init=(device == "meta")) + + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.set_parallel_mode and self.apply_bias: + self.gemm_bias_unfused_add = True + else: + self.gemm_bias_unfused_add = False + + if self.bias_gelu_nvfusion: + set_jit_fusion_options() + if seq_length and micro_batch_size: + warmup_jit_bias_gelu_all_dtypes( + self.size_per_partition, 
seq_length, micro_batch_size + ) + + # These many SMs are subtracted from the total SM count when calling forward + # and backward LayerNorm C APIs. These envvars can be used to prevent the LN + # kernels from using all SMs in the device. This is useful for cases such as + # communication overlap with LN. + self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) + self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + + def reset_layer_norm_parameters(self) -> None: + """Init LN params""" + warnings.warn( + "This method will be deprecated in an upcoming release. " + "Update your code to use LayerNormMLP.reset_parameters() instead.", + DeprecationWarning, + stacklevel=2, + ) + if not self.zero_centered_gamma: + init.ones_(self.layer_norm_weight) + else: + init.zeros_(self.layer_norm_weight) + if self.layer_norm_bias is not None: + init.zeros_(self.layer_norm_bias) + + def reset_parameters(self, defer_init=False): + super().reset_parameters(defer_init=defer_init) + + if not defer_init: + # Set parallel attributes for layer norm parameters + setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + + # Set parallel attributes for linear parameters + set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) + set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) + if self.use_bias: + set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) + if self.set_parallel_mode: + setattr(self.fc2_bias, "sequence_parallel", self.sequence_parallel) + + @no_torch_dynamo() + def forward( + self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + """ + Apply layer normalization to the input followed by a feedforward network (MLP Block). 
+ + Parameters + ---------- + inp : torch.Tensor + Input tensor. + is_first_microbatch : {True, False, None}, default = None + During training using either gradient accumulation or + pipeline parallelism a minibatch of data is further split + into microbatches. Between the microbatches of the same minibatch + the model weights are not updated. Setting this parameter indicates + whether the current microbatch is the first in a minibatch or not. + When set, this parameter enables additional optimizations: + + * during FP8 training, it allows caching of the FP8 versions of + the weights + * it also allows skipping gradient accumulation during the + first microbatch (since it is the first gradient being + produced) + """ + + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + + with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + + # Get weight tensors + fc1_weight = self.fc1_weight + fc2_weight = self.fc2_weight + if not self.fp8: + if isinstance(fc1_weight, Float8Tensor): + fc1_weight = fc1_weight.from_float8() + if isinstance(fc2_weight, Float8Tensor): + fc2_weight = fc2_weight.from_float8() + + # Cast weights to FP8 if needed + fc1_weight_fp8 = None + fc2_weight_fp8 = None + if self.fp8: + update_workspace = is_first_microbatch is None or is_first_microbatch + if isinstance(fc1_weight, Float8Tensor): + if fc1_weight._transpose is not None: + fc1_weight.transpose_2d( + fill_cache=True, + noop_flag=skip_fp8_weight_update, + ) + else: + cache_name = None + if is_first_microbatch is not None: + cache_name = "fc1_weight" + fc1_weight_fp8 = self.get_fp8_workspace( + tensor=fc1_weight, + fp8_meta_forward=True, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + cache_name=cache_name, + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + ) + if isinstance(fc2_weight, Float8Tensor): + if fc2_weight._transpose is not None: + 
fc2_weight.transpose_2d( + fill_cache=True, + noop_flag=skip_fp8_weight_update, + ) + else: + cache_name = None + if is_first_microbatch is not None: + cache_name = "fc2_weight" + fc2_weight_fp8 = self.get_fp8_workspace( + tensor=fc2_weight, + fp8_meta_forward=True, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + cache_name=cache_name, + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + ) + + # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode + if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): + self.bias_gelu_nvfusion = False + + from ..cpu_offload import CPUOffloadEnabled + + if torch.is_grad_enabled(): + fwd_fn = _LayerNormMLP.apply + args = [] + else: + fwd_fn = _LayerNormMLP.forward + args = [None] + args += ( + inp, + self.layer_norm_weight, + self.layer_norm_bias, + fc1_weight, + fc1_weight_fp8, + self.fc1_bias, + self.use_bias, + fc2_weight, + fc2_weight_fp8, + self.fc2_bias, + self.apply_bias and not self.gemm_bias_unfused_add, + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.fp8_meta, + self.fuse_wgrad_accumulation, + CPUOffloadEnabled, + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + self.bias_gelu_nvfusion, + self.set_parallel_mode, + torch.is_grad_enabled(), + self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.activation, + self.normalization, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_overlap_rs_dgrad, + self.ub_overlap_rs, + self.ub_overlap_ag, + self.gemm_gelu_fusion, + self.fsdp_group, + ) + out = fwd_fn(*args) + + if self.return_layernorm_output: + out, ln_out = out + + if self.gemm_bias_unfused_add: + out = out + cast_if_needed(self.fc2_bias, self.activation_dtype) + + if self.return_bias: + if 
self.return_layernorm_output: + return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out + return out, cast_if_needed(self.fc2_bias, self.activation_dtype) + if self.return_layernorm_output: + return out, ln_out + return out diff --git a/xpu_timer/experiments/scripts/llama3.1/train_llama3.1_8b.sh b/xpu_timer/experiments/scripts/llama3.1/train_llama3.1_8b.sh new file mode 100755 index 0000000000..4530b86995 --- /dev/null +++ b/xpu_timer/experiments/scripts/llama3.1/train_llama3.1_8b.sh @@ -0,0 +1,213 @@ +#!/bin/bash +set -ex + +export PYTHONPATH=$PWD/ant_patches:$PYTHONPATH + +export OMP_NUM_THREADS=1 +#export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_DEBUG_SUBSYS=INIT # disable aistudio default nccl env + +GPUS_PER_NODE=$(nvidia-smi -L | wc -l) +#GPUS_PER_NODE=4 +#export CUDA_VISIBLE_DEVICES=4,5,6,7 + +lines=`echo $POD_NAME | grep edljob | wc -l` +if [ $lines -eq 0 ]; then + WORLD_SIZE=${WORLD_SIZE:-1} + NODE_RANK=${RANK:-0} + MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} + RANDOM_PORT=$[$RANDOM + 20000] + MASTER_PORT=${MASTER_PORT:-$RANDOM_PORT} + GPU_NUM=$((${GPUS_PER_NODE}*${WORLD_SIZE})) + echo "---> from pytorch runtime, WORLD_SIZE: ${WORLD_SIZE}, NODE_RANK: ${NODE_RANK}, MASTER_ADDR: ${MASTER_ADDR}, MASTER_PORT: ${MASTER_PORT}" + LAUNCHER=" \ + torchrun \ + --nproc_per_node ${GPUS_PER_NODE} \ + --nnodes ${WORLD_SIZE} \ + --node_rank ${NODE_RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT} \ + " +else + WORLD_SIZE=${WORKER_NUM:-1} + NODE_RANK=${RANK:-0} + MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} + RANDOM_PORT=$[$RANDOM + 20000] + MASTER_PORT=${MASTER_PORT:-$RANDOM_PORT} + GPU_NUM=$((${GPUS_PER_NODE}*${WORLD_SIZE})) + echo "---> from edl runtime, WORLD_SIZE: ${WORLD_SIZE}, NODE_RANK: ${NODE_RANK}" + LAUNCHER=" \ + python -m atorch.distributed.run --fault_tolerant --network-check \ + --max_restarts=1 \ + --nnode=$WORLD_SIZE \ + --nproc_per_node=$GPUS_PER_NODE \ + --rdzv_conf join_timeout=300 \ + " +fi + +MODEL_SIZE=${MODEL_SIZE:-8} 
+if [[ ${MODEL_SIZE} == 8 ]]; then HIDDEN_SIZE=4096; NUM_HEADS=32; NUM_QUERY_GROUPS=8; NUM_LAYERS=32; FFN_HIDDEN_SIZE=14336; MAX_POSITION_EMBEDDINGS=131072; VOCAB_SIZE=128256; +elif [[ ${MODEL_SIZE} == 70 ]]; then HIDDEN_SIZE=8192; NUM_HEADS=64; NUM_QUERY_GROUPS=8; NUM_LAYERS=80; FFN_HIDDEN_SIZE=28672; MAX_POSITION_EMBEDDINGS=131072; VOCAB_SIZE=128256; +elif [[ ${MODEL_SIZE} == 405 ]]; then HIDDEN_SIZE=16384; NUM_HEADS=128; NUM_QUERY_GROUPS=8; NUM_LAYERS=126; FFN_HIDDEN_SIZE=53248; MAX_POSITION_EMBEDDINGS=131072; VOCAB_SIZE=128256; +elif [[ ${MODEL_SIZE} == 1 ]]; then HIDDEN_SIZE=2048; NUM_HEADS=32; NUM_QUERY_GROUPS=8; NUM_LAYERS=16; FFN_HIDDEN_SIZE=8192; MAX_POSITION_EMBEDDINGS=131072; VOCAB_SIZE=128256; +elif [[ ${MODEL_SIZE} == 3 ]]; then HIDDEN_SIZE=3072; NUM_HEADS=24; NUM_QUERY_GROUPS=8; NUM_LAYERS=28; FFN_HIDDEN_SIZE=8192; MAX_POSITION_EMBEDDINGS=131072; VOCAB_SIZE=128256; +else echo "invalid MODEL_SIZE: ${MODEL_SIZE}"; exit 1 +fi + +DEVICE_MODEL=$(nvidia-smi -i 0 -q | grep "Product Name" | awk -F: '{ print $2 }') +DEVICE_MODEL=$(echo "$DEVICE_MODEL" | xargs) # drop white space +DEVICE_MODEL=NVIDIA + +JOB_DIR="/tmp/llama3_${MODEL_SIZE}B_${DEVICE_MODEL}_${GPU_NUM}p" +echo $JOB_DIR +mkdir -p ${JOB_DIR} +CHECKPOINT_PATH=${JOB_DIR} # +TENSORBOARD_LOGS_PATH=${JOB_DIR} + + +cp -r ${0} ${JOB_DIR} +#pip list > ${JOB_DIR}/pip_list.txt + +DATASET_DIR="/dnn_training_sys/dataset/nlp/fineweb-edu/CC-MAIN-2024-10/" +DATASET0="${DATASET_DIR}/CC-MAIN-2024-10_0000_text_document" +DATASET1="${DATASET_DIR}/CC-MAIN-2024-10_0001_text_document" +DATASET2="${DATASET_DIR}/CC-MAIN-2024-10_0002_text_document" +DATASET4="${DATASET_DIR}/CC-MAIN-2024-10_0004_text_document" + +DATA_PATH="0.25 ${DATASET0} 0.25 ${DATASET1} 0.25 ${DATASET2} 0.25 ${DATASET4}" +LOG_PATH="${JOB_DIR}/debug_llama_${RANK}.txt" + + + +GPT_MODEL_ARGS=( + --num-layers $NUM_LAYERS + --hidden-size $HIDDEN_SIZE + --ffn-hidden-size $FFN_HIDDEN_SIZE + --num-attention-heads $NUM_HEADS + --group-query-attention + 
--num-query-groups $NUM_QUERY_GROUPS + --max-position-embeddings $MAX_POSITION_EMBEDDINGS + --vocab-size $VOCAB_SIZE + --position-embedding-type "rope" + --rotary-base 500000 + --rotary-percent 1.0 + --swiglu + --untie-embeddings-and-output-weights + --normalization "RMSNorm" + --norm-epsilon "1e-05" + --disable-bias-linear + --transformer-impl "transformer_engine" +) + +TRAINING_ARGS=( + --micro-batch-size 2 + --global-batch-size 128 + --seq-length "4096" + --train-iters 800 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 + --init-method-std 0.006 + --clip-grad 1.0 + --bf16 + --lr "3.0e-4" + --lr-decay-style cosine + --min-lr "3.0e-5" + --lr-warmup-fraction 0.001 + --lr-decay-iters 20000 + --seed 42 +) + + +if [ "$DEVICE_MODEL" = "A800-SXM4-80GB" ] || [ "$DEVICE_MODEL" = "A100-SXM4-80GB" ]; then + # Ampere GPUs do not support multicast. If `--tp-comm-overlap` is set on Ampere-arch GPUs, this env must be set. + export UB_SKIPMC=1 +fi +export NVTE_FLASH_ATTN=1 + +# deterministic computation +#export PYTORCH_JIT=0 +#export NVTE_TORCH_COMPILE=0 +#export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +#export NCCL_ALGO="Ring" +#export CUBLAS_WORKSPACE_CONFIG=":4096:8" + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size 4 + --use-distributed-optimizer + --no-async-tensor-model-parallel-allreduce + --sequence-parallel + --manual-gc + --manual-gc-interval 50 +) + + #--pipeline-model-parallel-size 4 + #--num-layers-per-virtual-pipeline-stage 2 + + #--overlap-param-gather + #--overlap-grad-reduce +# some optional args + +# --use-distributed-optimizer +# --overlap-param-gather +# --overlap-grad-reduce +# --context-parallel-size 2 +# --tp-comm-overlap +# --decoder-first-pipeline-num-layers +# --decoder-last-pipeline-num-layers + + +DATA_ARGS=( + --mock-data + --tokenizer-type "NullTokenizer" +) + + +EVAL_AND_LOGGING_ARGS=( + --save-interval 1000 + --eval-interval 1000 + --save $CHECKPOINT_PATH + --load $CHECKPOINT_PATH + --ckpt-format "torch_dist" + --async-save + 
--eval-iters 250 + --log-interval 1 + --log-throughput + --tensorboard-dir $TENSORBOARD_LOGS_PATH + --log-timers-to-tensorboard + --log-memory-to-tensorboard + --log-world-size-to-tensorboard + --log-validation-ppl-to-tensorboard +) + +KERNEL_ARGS=( + --use-flash-attn + --no-masked-softmax-fusion + --attention-softmax-in-fp32 +) +# +# --deterministic-mode +# --use-flash-attn +# --cross-entropy-loss-fusion + +PROFILING_ARGS=( + --profile + --use-pytorch-profiler + --profile-ranks 0 + --profile-step-start 10 + --profile-step-end 20 +) + +CMD="${LAUNCHER} pretrain_gpt.py \ + ${GPT_MODEL_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${EVAL_AND_LOGGING_ARGS[@]} \ + ${KERNEL_ARGS[@]} \ + " + + #${PROFILING_ARGS[@]} \ +echo ${CMD} +nohup ${CMD} > ${LOG_PATH} 2>&1 & +#${CMD} 2>&1 | tee ${LOG_PATH} diff --git a/xpu_timer/experiments/scripts/pretrain_bert.py b/xpu_timer/experiments/scripts/pretrain_bert.py new file mode 100644 index 0000000000..35884ecdc4 --- /dev/null +++ b/xpu_timer/experiments/scripts/pretrain_bert.py @@ -0,0 +1,193 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ +"""Pretrain BERT""" + +from functools import partial + +import torch +import torch.nn.functional as F + +from megatron.training import get_args +from megatron.training import get_tokenizer +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.core import tensor_parallel +from megatron.core.enums import ModelType +import megatron.legacy.model +from megatron.core.models.bert.bert_model import BertModel +from megatron.training import pretrain +from megatron.training.utils import average_losses_across_data_parallel_group +from megatron.training.arguments import core_transformer_config_from_args +from megatron.core.transformer.spec_utils import import_module +from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec, bert_layer_local_spec +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.bert_dataset import BERTMaskedWordPieceDataset, BERTMaskedWordPieceDatasetConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core import mpu, tensor_parallel + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building BERT model ...') + + args = get_args() + config = core_transformer_config_from_args(args) + num_tokentypes = 2 if args.bert_binary_head else 0 + + if args.use_legacy_models: + model = megatron.legacy.model.BertModel( + config=config, + num_tokentypes=num_tokentypes, + add_binary_head=args.bert_binary_head, + parallel_output=True, + pre_process=pre_process, + post_process=post_process) + else: + if args.spec is None: + transformer_layer_spec = bert_layer_with_transformer_engine_spec #default spec + elif args.spec[0] == 'local': + print_rank_0('Using Local spec for transformer layers') + transformer_layer_spec = bert_layer_local_spec + else : + transformer_layer_spec = import_module(args.spec) + + model = BertModel( + 
config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + num_tokentypes=num_tokentypes, + add_binary_head=args.bert_binary_head, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + parallel_output=True, + pre_process=pre_process, + post_process=post_process) + + return model + + +def get_batch(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ['text', 'types', 'labels', + 'is_random', 'loss_mask', 'padding_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens = data_b['text'].long() + types = data_b['types'].long() + sentence_order = data_b['is_random'].long() + loss_mask = data_b['loss_mask'].float() + lm_labels = data_b['labels'].long() + padding_mask = data_b['padding_mask'].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + +def loss_func(loss_mask, sentence_order, output_tensor): + lm_loss_, sop_logits = output_tensor + + lm_loss_ = lm_loss_.float() + loss_mask = loss_mask.float() + lm_loss = torch.sum( + lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() + + if sop_logits is not None: + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), + sentence_order.view(-1), + ignore_index=-1) + sop_loss = sop_loss.float() + loss = lm_loss + sop_loss + averaged_losses = average_losses_across_data_parallel_group( + [lm_loss, sop_loss]) + return loss, {'lm loss': averaged_losses[0], + 'sop loss': averaged_losses[1]} + else: + loss = lm_loss + averaged_losses = average_losses_across_data_parallel_group( + [lm_loss]) + return loss, {'lm loss': averaged_losses[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the BERT train, validation, and test datasets.

    Args:
        train_val_test_num_samples: Passed through unchanged to
            ``BlendedMegatronDatasetBuilder``.

    Returns:
        The ``(train_ds, valid_ds, test_ds)`` triple produced by the builder.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    combined_blend = get_blend_from_list(args.data_path)
    per_split_blends = [
        get_blend_from_list(split_paths)
        for split_paths in (
            args.train_data_path,
            args.valid_data_path,
            args.test_data_path,
        )
    ]

    dataset_config = BERTMaskedWordPieceDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=combined_blend,
        blend_per_split=per_split_blends,
        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        path_to_cache=args.data_cache_path,
        tokenizer=tokenizer,
        masking_probability=args.mask_prob,
        short_sequence_probability=args.short_seq_prob,
        masking_max_ngram=3,
        masking_do_full_word=True,
        masking_do_permutation=False,
        masking_use_longer_ngrams=False,
        masking_use_geometric_distribution=False,
        classification_head=args.bert_binary_head,
    )

    def _is_built_on_this_rank():
        # Predicate handed to the builder: true only on tensor-parallel rank 0.
        return mpu.get_tensor_model_parallel_rank() == 0

    print_rank_0('> building train, validation, and test datasets for BERT ...')

    builder = BlendedMegatronDatasetBuilder(
        BERTMaskedWordPieceDataset,
        train_val_test_num_samples,
        _is_built_on_this_rank,
        dataset_config,
    )
    train_ds, valid_ds, test_ds = builder.build()

    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If ``use_legacy_models`` is set, the legacy GPT model is returned;
    otherwise the mcore GPT model is built.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model

    Raises:
        RuntimeError: If ``--fp8-param-gather`` is set but ``fp8_model_init``
            cannot be imported from TransformerEngine.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        elif use_te:
            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.fp8)
        else:
            transformer_layer_spec = get_gpt_layer_local_spec(
                args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)

        build_model_context = nullcontext
        build_model_context_args = {}
        if args.fp8_param_gather:
            # Only the import is guarded: a missing TransformerEngine symbol is
            # the one failure the message below describes. Any other error
            # propagates unchanged instead of being mislabeled as "not found".
            try:
                from transformer_engine.pytorch import fp8_model_init
            except ImportError as err:
                raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.") from err

            build_model_context = fp8_model_init
            build_model_context_args["enabled"] = True

            # Check if fp8_model_init supports preserve_high_precision_init_val
            if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
                build_model_context_args["preserve_high_precision_init_val"] = True

        with build_model_context(**build_model_context_args):
            model = GPTModel(
                config=config,
                transformer_layer_spec=transformer_layer_spec,
                vocab_size=args.padded_vocab_size,
                max_sequence_length=args.max_position_embeddings,
                pre_process=pre_process,
                post_process=post_process,
                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
                parallel_output=True,
                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
                position_embedding_type=args.position_embedding_type,
                rotary_percent=args.rotary_percent,
                rotary_base=args.rotary_base
            )

    return model
def forward_step(data_iterator, model: GPTModel):
    """Run one forward pass and hand back the output with its loss closure.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model

    Returns:
        ``(output_tensor, loss_fn)`` where ``loss_fn`` is ``loss_func`` with
        this batch's ``loss_mask`` already bound.
    """
    args = get_args()  # not read below; call retained so behavior is unchanged
    timers = get_timers()

    # Batch fetch is bracketed by the 'batch-generator' timer and by the
    # straggler detector in its data (bdata) mode.
    timers('batch-generator', log_level=2).start()
    with stimer(bdata=True):
        batch = get_batch(data_iterator)
        tokens, labels, loss_mask, attention_mask, position_ids = batch
    timers('batch-generator').stop()

    # The model call itself is bracketed by the detector's default mode.
    with stimer:
        model_output = model(tokens, position_ids, attention_mask, labels=labels)

    loss_closure = partial(loss_func, loss_mask)
    return model_output, loss_closure
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the GPT train, validation, and test datasets.

    Args:
        train_val_test_num_samples : Passed through unchanged to
            ``BlendedMegatronDatasetBuilder``.

    Returns:
        The ``(train_ds, valid_ds, test_ds)`` triple produced by the builder.
    """
    args = get_args()
    dataset_config = core_gpt_dataset_config_from_args(args)

    # --mock-data swaps in the mock dataset class.
    dataset_cls = MockGPTDataset if args.mock_data else GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    builder = BlendedMegatronDatasetBuilder(
        dataset_cls,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        dataset_config,
    )
    train_ds, valid_ds, test_ds = builder.build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
def get_group_world_size_rank():
    """Return the data-parallel group with this process's rank and the group size.

    Note the return order is ``(group, rank, world_size)``, which differs
    from the order suggested by the function name.
    """
    dp_group = mpu.get_data_parallel_group()
    dp_rank = dist.get_rank(group=dp_group)
    dp_world_size = dist.get_world_size(group=dp_group)
    return dp_group, dp_rank, dp_world_size
def loss_func(output_tensor):
    """Compute the in-batch retrieval loss and top-k accuracies for ICT.

    Query and context logits are gathered across the data-parallel group and
    scored against each other by inner product; row ``i`` of the score matrix
    is treated as matching column ``i``.

    Args:
        output_tensor: Pair ``(query_logits, context_logits)``.

    Returns:
        ``(loss, stats_dict)``: the NLL loss multiplied by the data-parallel
        world size, and a dict holding the data-parallel averaged ``loss``
        plus one ``top{k}_acc`` entry (in percent) for every ``k`` in
        ``args.retriever_report_topk_accuracies``.
    """
    args = get_args()
    query_logits, context_logits = output_tensor

    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # Take the batch size from the gathered tensor. The gather runs over the
    # data-parallel group, so the global world size would give the wrong
    # label count whenever it differs from the data-parallel group size.
    global_batch_size = all_query_logits.shape[0]

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))
    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    _, sorted_indices = torch.topk(softmax_scores,
                                   k=softmax_scores.shape[1], sorted=True)

    # Sample i's correct context is column i.
    labels = torch.arange(global_batch_size, dtype=torch.long,
                          device=softmax_scores.device)

    def topk_accuracy(k):
        # Fraction of rows whose own index is among their k best columns,
        # computed on device instead of one membership test per sample.
        hits = (sorted_indices[:, :k] == labels.unsqueeze(1)).any(dim=1)
        return hits.float().mean().view(1)

    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {
        'top{}_acc'.format(k): v * 100
        for k, v in zip(args.retriever_report_topk_accuracies, reduced_losses[1:])
    }
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the ICT train, validation, and test datasets.

    Args:
        train_val_test_num_samples: Forwarded as
            ``train_valid_test_num_samples`` to the legacy dataset builder.

    Returns:
        The ``(train_ds, valid_ds, test_ds)`` triple from
        ``build_train_valid_test_datasets``.
    """
    args = get_args()
    print_rank_0('> building train, validation, and test datasets for BERT ICT...')

    builder_kwargs = dict(
        data_prefix=args.data_path,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        binary_head=False,
        dataset_type='ict',
    )
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**builder_kwargs)

    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds
def model_provider(pre_process=True, post_process=True) -> MambaModel:
    """Builds the model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        MambaModel: The returned model

    Raises:
        AssertionError: If ``args.use_legacy_models`` is set.
        ValueError: If no layer spec was given (``args.spec`` is ``None``).
    """
    args = get_args()

    print_rank_0('building Mamba model ...')
    config = core_transformer_config_from_args(args)

    assert not args.use_legacy_models, "Mamba only supported in Mcore!"

    if args.spec is None:
        # Raising a bare string is itself a TypeError; raise a real exception
        # so the intended message reaches the user.
        raise ValueError("You must provide a valid Mamba layer spec!")
    mamba_stack_spec = import_module(args.spec)

    model = MambaModel(
        config=config,
        mamba_stack_spec=mamba_stack_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        hybrid_attention_ratio=args.hybrid_attention_ratio,
        hybrid_mlp_ratio=args.hybrid_mlp_ratio,
        hybrid_override_pattern=args.hybrid_override_pattern,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
        rotary_base=args.rotary_base
    )

    # Report the parameter count of each decoder layer held on this rank.
    for layer_idx in range(model.decoder.num_layers_per_pipeline_rank):
        layer_params = count_parameters_in_layer(model, f'decoder.layers.{layer_idx}.')
        print_rank_0(f" == params layer {layer_idx}: {layer_params}")

    return model
+ + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + + Returns: + the loss scalar for this micro-batch + the number of non-padded tokens in this microbatch + a dict containing reporting metrics on the loss and number of tokens across + the data parallel ranks + """ + args = get_args() + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + total_tokens = loss_mask.sum() + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)]) + + if args.context_parallel_size > 1: + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + + # Check individual rank losses are not NaN prior to DP all-reduce. + if args.check_for_nan_in_loss_and_grad: + global_rank = torch.distributed.get_rank() + assert not loss[0].isnan(), ( + f'Rank {global_rank}: found NaN in local forward loss calculation. ' + f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' + ) + + # Reduce loss for logging. + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + return ( + loss[0] * args.context_parallel_size, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])}, + ) + + +def forward_step(data_iterator, model: MambaModel): + """Forward training step. + + Args: + data_iterator : Input data iterator + model (MambaModel): The GPT Model + """ + args = get_args() + timers = get_timers() + + # Get the batch. 
def core_gpt_dataset_config_from_args(args):
    """Translate parsed training arguments into a ``GPTDatasetConfig``.

    Args:
        args: Argument namespace; the data, split, cache and masking related
            attributes read below are copied into the config.

    Returns:
        GPTDatasetConfig: The populated dataset configuration.
    """
    tokenizer = get_tokenizer()

    combined_blend = get_blend_from_list(args.data_path)
    per_split_blends = [
        get_blend_from_list(split_paths)
        for split_paths in (
            args.train_data_path,
            args.valid_data_path,
            args.test_data_path,
        )
    ]

    dataset_config = GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=combined_blend,
        blend_per_split=per_split_blends,
        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
    )
    return dataset_config
+ """ + args = get_args() + + config = core_gpt_dataset_config_from_args(args) + + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + is_dataset_built_on_rank, + config + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + pretrain(train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) diff --git a/xpu_timer/experiments/scripts/pretrain_retro.py b/xpu_timer/experiments/scripts/pretrain_retro.py new file mode 100644 index 0000000000..0aecbf14ce --- /dev/null +++ b/xpu_timer/experiments/scripts/pretrain_retro.py @@ -0,0 +1,245 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
def core_model_provider(pre_process=True, post_process=True):
    """Build the Retro model using Megatron-Core.

    Args:
        pre_process (bool, optional): Forwarded to ``RetroModel``. Defaults to True.
        post_process (bool, optional): Forwarded to ``RetroModel``. Defaults to True.

    Returns:
        RetroModel: The constructed model.
    """

    args = get_args()
    config = get_retro_config()

    # NOTE: Experimental customization feature
    if args.spec is not None:
        # This file never imports `import_module` at module level, so import
        # it here; otherwise a custom --spec fails with NameError.
        from megatron.core.transformer.spec_utils import import_module

        block_spec = import_module(args.spec)()
    else:
        block_spec = get_retro_decoder_block_spec(config, use_transformer_engine=True)

    print_rank_0('building GPT model ...')
    model = RetroModel(
        config=config,
        transformer_layer_spec=block_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent
    )
    return model
def forward_step(data_iterator, model):
    """Run one Retro forward pass and return the output with its loss closure.

    Args:
        data_iterator: Iterator handed to ``get_batch``.
        model: Callable model; it receives the retriever tensors under
            ``retriever_*`` keyword names for legacy models and ``context_*``
            names for core models with the retriever enabled.

    Returns:
        ``(output_tensor, loss_fn)`` where ``loss_fn`` is ``loss_func`` with
        this batch's ``loss_mask`` already bound.
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    use_retriever = args.retro_add_retriever
    batch = get_batch(data_iterator)
    if use_retriever:
        (tokens, labels, loss_mask, attention_mask, position_ids,
         neighbor_tokens, neighbor_attention_mask, neighbor_position_ids) = batch
    else:
        tokens, labels, loss_mask, attention_mask, position_ids = batch
        neighbor_tokens = neighbor_attention_mask = neighbor_position_ids = None
    timers('batch-generator').stop()

    # Pick the keyword names the selected model implementation expects.
    if args.use_legacy_models:
        forward_kwargs = dict(
            retriever_input_ids=neighbor_tokens,
            retriever_position_ids=neighbor_position_ids,
            retriever_attn_mask=neighbor_attention_mask,
        )
    elif args.retro_add_retriever:
        forward_kwargs = dict(
            context_input_ids=neighbor_tokens,
            context_position_ids=neighbor_position_ids,
            context_mask=neighbor_attention_mask,
        )
    else:
        forward_kwargs = {}

    model_output = model(tokens, position_ids, attention_mask,
                         labels=labels, **forward_kwargs)

    return model_output, partial(loss_func, loss_mask)
+ if args.retro_add_retriever: + return get_retro_datasets( + config=retro_config, + gpt_datasets=gpt_datasets, + sample_length=args.seq_length, + eod_token_id=get_tokenizer().eod, + ) + + # Multi-split GPT datasets. + else: + return ( + gpt_datasets["train"][0], + gpt_datasets["valid"][0], + gpt_datasets["test"][0], + ) + + +if __name__ == "__main__": + + # Temporary for transition to core datasets. + train_valid_test_datasets_provider.is_distributed = True + + pretrain(train_valid_test_datasets_provider, + model_provider, + ModelType.retro_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) diff --git a/xpu_timer/experiments/scripts/pretrain_t5.py b/xpu_timer/experiments/scripts/pretrain_t5.py new file mode 100644 index 0000000000..253d4b19c6 --- /dev/null +++ b/xpu_timer/experiments/scripts/pretrain_t5.py @@ -0,0 +1,299 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Pretrain T5""" + +from copy import deepcopy +from functools import partial +from typing import Union + +import torch + +from megatron.training import ( + get_args, + get_timers, + get_tokenizer, + print_rank_0 +) +from megatron.core import mpu, tensor_parallel +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.t5_dataset import ( + T5MaskedWordPieceDataset, + T5MaskedWordPieceDatasetConfig, +) +from megatron.core.enums import ModelType +from megatron.core.models.T5 import T5Model +from megatron.training import pretrain +from megatron.training.arguments import core_transformer_config_from_args +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset, T5MaskedWordPieceDatasetConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.models.T5.t5_spec import (get_t5_encoder_with_transformer_engine_block_spec, + 
def model_provider(
    pre_process=True, post_process=True, add_encoder=True, add_decoder=True
) -> Union[LegacyT5Model, T5Model]:
    """Build the T5 model (legacy or core implementation).

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
        add_encoder (bool, optional): Forwarded to the model constructor. Defaults to True.
        add_decoder (bool, optional): Forwarded to the model constructor. Defaults to True.

    Returns:
        Union[LegacyT5Model, T5Model]: ``LegacyT5Model`` when
        ``args.use_legacy_models`` is set, otherwise ``T5Model``.

    Raises:
        AssertionError: If the encoder tensor-parallel size is neither 0 nor
            equal to the decoder's, or if pipeline parallelism is enabled
            without an encoder pipeline size.
        ValueError: If ``args.transformer_impl`` is neither ``"local"`` nor
            ``"transformer_engine"`` (core model path only).
    """

    args = get_args()

    assert (
        args.encoder_tensor_model_parallel_size == 0 or
        args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size
    ), "Because word embeddings are shared between the encoder & decoder, these have to have the same tensor parallel size."

    config = core_transformer_config_from_args(args)
    if args.use_legacy_models:
        model = LegacyT5Model(
            config=config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
        )
    else:
        # The encoder gets its own copy of the config so its layer count and
        # pipeline size can differ from the decoder's.
        encoder_config = deepcopy(config)
        encoder_config.num_layers = args.encoder_num_layers

        if args.pipeline_model_parallel_size > 1:
            assert args.encoder_pipeline_model_parallel_size > 0, "Need to know how to shard the encoder & decoder."

        if args.encoder_pipeline_model_parallel_size > 0:
            encoder_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size

        encoder_layers_per_pipeline = encoder_config.num_layers // encoder_config.pipeline_model_parallel_size
        decoder_layers_per_pipeline = config.num_layers // config.pipeline_model_parallel_size

        if args.transformer_impl == "local":
            en_block_spec = get_t5_encoder_with_local_block_spec(encoder_layers_per_pipeline)
            de_block_spec = get_t5_decoder_with_local_block_spec(decoder_layers_per_pipeline)
        elif args.transformer_impl == "transformer_engine":
            en_block_spec = get_t5_encoder_with_transformer_engine_block_spec(
                encoder_layers_per_pipeline
            )
            de_block_spec = get_t5_decoder_with_transformer_engine_block_spec(
                decoder_layers_per_pipeline
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # en_block_spec / de_block_spec further down.
            raise ValueError(
                f"Unsupported transformer implementation: {args.transformer_impl!r}; "
                "expected 'local' or 'transformer_engine'."
            )

        print_rank_0('building T5 model ...')
        model = T5Model(
            config=config,
            encoder_config=encoder_config,
            transformer_encoder_layer_spec=en_block_spec,
            transformer_decoder_layer_spec=de_block_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            add_encoder=add_encoder,
            add_decoder=add_decoder
        )

    return model
def forward_step(data_iterator, model: T5Model):
    """Run one forward pass of the T5 model on the next batch.

    Args:
        data_iterator : Input data iterator, handed to ``get_batch``.
        model (T5Model): The T5 Model

    Returns:
        A pair ``(output_tensor, loss_fn)`` where ``loss_fn`` is ``loss_func``
        with the batch's loss mask bound as its first argument.
    """
    args = get_args()  # retrieved as before; not otherwise used in this function
    batch_timers = get_timers()

    # Time the batch construction.
    batch_timers('batch generator', log_level=2).start()
    (
        tokens_enc,
        tokens_dec,
        loss_mask,
        lm_labels,
        enc_mask,
        dec_mask,
        enc_dec_mask,
    ) = get_batch(data_iterator)
    batch_timers('batch generator').stop()

    # Forward pass; the labels are passed by keyword.
    masks = (enc_mask, dec_mask, enc_dec_mask)
    output_tensor = model(tokens_enc, tokens_dec, *masks, lm_labels=lm_labels)

    masked_loss = partial(loss_func, loss_mask)
    return output_tensor, masked_loss
def t5_embedding_ranks(pp_ranks):
    """Return the pipeline ranks that hold T5 embeddings.

    These are the encoder's first rank plus the decoder's first and last
    ranks, with duplicates collapsed.

    Args:
        pp_ranks: A list of global ranks that constitute a pipeline group.

    Returns:
        A list with one, two or three ranks taken from ``pp_ranks``.
    """
    opts = get_args()

    head, tail = pp_ranks[0], pp_ranks[-1]

    # encoder size is also the index to the first rank of the decoder.
    decoder_start = opts.encoder_pipeline_model_parallel_size

    if len(pp_ranks) == 1:
        return [head]

    decoder_first = pp_ranks[decoder_start]
    if decoder_first == head or decoder_first == tail:
        # The decoder's first rank coincides with an end of the pipeline.
        return [head, tail]
    return [head, decoder_first, tail]
def loss_func(labels, output_tensor):
    """Compute the cross-entropy loss and top-1 accuracy for one batch.

    Args:
        labels: Target class indices, as accepted by ``F.cross_entropy``.
        output_tensor: Model output; it is made contiguous and cast to float
            before use as logits.

    Returns:
        A pair ``(loss, stats)`` where ``loss`` is the local cross-entropy
        loss and ``stats`` maps ``"loss"`` and ``"accuracy"`` to the values
        returned by ``average_losses_across_data_parallel_group``.
    """
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)

    # Fraction of samples whose arg-max class matches the label.
    predictions = logits.argmax(-1)
    accuracy = predictions.eq(labels).float().mean()

    reduced = average_losses_across_data_parallel_group([loss, accuracy])
    stats = {"loss": reduced[0], "accuracy": reduced[1]}
    return loss, stats
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train and validation datasets for the vision model.

    Args:
        train_val_test_num_samples: Accepted for interface compatibility;
            this function does not read it.

    Returns:
        ``(train_ds, valid_ds, None)`` -- no test dataset is built.
    """
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets "
        "for VIT ..."
    )
    target_size = (args.img_h, args.img_w)
    datasets = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=target_size,
    )
    train_ds, valid_ds = datasets
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None
def get_batch(data_iterator):
    """Fetch the next sample and move its images and labels with ``.cuda()``.

    Args:
        data_iterator: Iterator yielding indexable samples whose element 0 is
            either one object with ``.cuda()`` or a list of them, and whose
            element 1 is the labels.

    Returns:
        ``(images, labels)``; ``images`` is a list when element 0 was a list.
    """
    sample = next(data_iterator)

    # only data parallelism; no need for broadcast
    views = sample[0]
    images = [view.cuda() for view in views] if isinstance(views, list) else views.cuda()
    labels = sample[1].cuda()

    return images, labels
def forward_step(data_iterator, model):
    """Run the model on the next batch and return its output with a loss closure.

    Args:
        data_iterator: Iterator handed to ``get_batch``.
        model: Callable model; it is also bound into the returned loss function.

    Returns:
        ``(model(images), loss_fn)`` where ``loss_fn`` is ``loss_func`` with
        ``model`` and the batch labels already bound.
    """
    timers = get_timers()

    # Time the batch construction.
    timers("batch-generator", log_level=2).start()
    images, labels = get_batch(data_iterator)
    timers("batch-generator").stop()

    model_output = model(images)
    bound_loss = partial(loss_func, model, labels)
    return model_output, bound_loss
def model_provider(pre_process=True, post_process=True):
    """Build the inpainting model selected by ``args.vision_backbone_type``.

    Args:
        pre_process: Forwarded to the model constructor. Defaults to True.
        post_process: Forwarded to the model constructor. Defaults to True.

    Returns:
        A ``VitInpaintingModel`` for backbone ``'vit'`` or a
        ``MitInpaintingModel`` for backbone ``'mit'``.

    Raises:
        Exception: If the backbone type is neither ``'vit'`` nor ``'mit'``.
    """
    args = get_args()
    config = core_transformer_config_from_args(args)

    # Pick the class first; both are constructed with the same keywords.
    backbone = args.vision_backbone_type
    if backbone == 'vit':
        model_cls = VitInpaintingModel
    elif backbone == 'mit':
        model_cls = MitInpaintingModel
    else:
        raise Exception('{} vision backbone is not supported.'.format(
            args.vision_backbone_type))

    return model_cls(config=config,
                     pre_process=pre_process,
                     post_process=post_process)
def process_non_loss_data(data, iteration, writer):
    """Log validation images and the mean PSNR / SSIM to ``writer``.

    Args:
        data: Sized iterable of ``(output_tb, ssim, psnr)`` triples.
            ``output_tb`` is clamped in place to ``[0, 1]`` and logged with
            ``dataformats='NCHW'``; ``ssim`` and ``psnr`` must provide
            ``.item()``.
        iteration: Step value used for every logged image and scalar.
        writer: Object providing ``add_images`` and ``add_scalar``.

    Returns:
        None. When ``data`` is empty nothing is logged.
    """
    # Nothing was collected: skip logging instead of dividing by zero below.
    if len(data) == 0:
        return

    psnr_sum = 0
    ssim_sum = 0
    for output_tb, ssim, psnr in data:
        # Clamp to [0, 1] in place before the images are logged.
        output_tb[output_tb < 0] = 0
        output_tb[output_tb > 1] = 1
        writer.add_images("gt-input-output-vald", output_tb,
                          global_step=iteration, walltime=None,
                          dataformats='NCHW')
        psnr_sum += psnr.item()
        ssim_sum += ssim.item()

    # Mean over all collected entries.
    psnr = psnr_sum / len(data)
    ssim = ssim_sum / len(data)
    writer.add_scalar('PSNR generate value-validation', psnr, iteration)
    writer.add_scalar('SSIM generate value-validation', ssim, iteration)
def get_num_image_tokens():
    """Return the number of image tokens implied by the current arguments.

    This is the number of whole ``patch_dim`` patches along the image height
    times the number along the width, plus one when
    ``args.disable_vision_class_token`` is not set.
    """
    cfg = get_args()

    patches_high = cfg.img_h // cfg.patch_dim
    patches_wide = cfg.img_w // cfg.patch_dim
    token_count = patches_high * patches_wide

    # One extra token for the class token unless it is disabled.
    if not cfg.disable_vision_class_token:
        token_count = token_count + 1
    return token_count
+ + Returns: + model (megatron.core.models.multimodal.llava_model.LLaVAModel): A multimodal model + """ + args = get_args() + + num_image_tokens = get_num_image_tokens() + args.decoder_seq_length = args.seq_length + num_image_tokens + args.seq_length = num_image_tokens + args.max_position_embeddings = max(args.max_position_embeddings, args.decoder_seq_length) + + print_rank_0('building a multimodal model ...') + language_transformer_config = core_transformer_config_from_args(get_args()) + + if args.spec is not None: + language_transformer_layer_spec = import_module(args.spec) + elif args.transformer_impl == "transformer_engine": + language_transformer_layer_spec = decoder_model_with_transformer_engine_default_spec( + args.num_experts, args.moe_grouped_gemm + ) + else: # transformer_impl == "local" + language_transformer_layer_spec = decoder_model_with_local_default_spec( + args.num_experts, args.moe_grouped_gemm + ) + + if args.transformer_impl == "transformer_engine": + vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() + else: # transformer_impl == "local" + vision_transformer_layer_spec = get_vit_layer_with_local_spec() + + # TODO: Make these configurable via input .yaml config. + vision_transformer_config = deepcopy(language_transformer_config) + vision_transformer_config.num_layers = args.encoder_num_layers + vision_transformer_config.first_pipeline_num_layers = None + vision_transformer_config.last_pipeline_num_layers = None + + vision_projection_type = "mlp" + vision_projection_config = deepcopy(language_transformer_config) + + if args.encoder_pipeline_model_parallel_size > 0: + assert ( + args.encoder_pipeline_model_parallel_size == 1 + ), "ViT can only live on 1 pipeline stage." 
+ vision_transformer_config.pipeline_model_parallel_size = ( + args.encoder_pipeline_model_parallel_size + ) + vision_projection_config.pipeline_model_parallel_size = ( + args.encoder_pipeline_model_parallel_size + ) + if args.encoder_tensor_model_parallel_size > 0: + vision_transformer_config.tensor_model_parallel_size = ( + args.encoder_tensor_model_parallel_size + ) + vision_projection_config.tensor_model_parallel_size = ( + args.encoder_tensor_model_parallel_size + ) + + vision_projection_modules = deepcopy(language_transformer_layer_spec.submodules.mlp.submodules) + + model = LLaVAModel( + language_transformer_config=language_transformer_config, + language_transformer_layer_spec=language_transformer_layer_spec, + language_vocab_size=args.padded_vocab_size, + language_max_sequence_length=args.max_position_embeddings, + vision_transformer_config=vision_transformer_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + drop_vision_class_token=args.disable_vision_class_token, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_modules, + vision_projection_type=vision_projection_type, + parallel_output=parallel_output, + language_position_embedding_type=args.position_embedding_type, + language_rotary_percent=args.rotary_percent, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + ) + + return model + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train, validation, and test sets. + + Returns: + train_ds, val_ds, test_ds (megatron.core.datasets.multimodal_dataset.MockMultimodalDataset): Train, validation, and test datasets, respectively. 
def _preprocess_data_for_llava(data):
    """Extend a data sample by one leading image-token position for LLaVA.

    Note: This doesn't support all the different modes in the official LLaVA repo yet.

    Args:
        data (dict): Data sample; the keys ``"tokens"``, ``"labels"``,
            ``"loss_mask"`` and ``"position_ids"`` are read and overwritten.

    Returns:
        data (dict): The same dict object, with each of those four entries
        one element longer.
    """
    # Put the image token index in front of the text tokens.
    text_tokens = data["tokens"]
    image_token = IMAGE_TOKEN_INDEX * torch.ones(
        1, dtype=text_tokens.dtype, device=text_tokens.device
    )
    extended_tokens = torch.cat([image_token, text_tokens])
    data["tokens"] = extended_tokens

    # The label added in front is the token at index 1 of the extended sequence.
    data["labels"] = torch.cat([extended_tokens[1].unsqueeze(0), data["labels"]])

    # The added position gets a zero loss-mask entry.
    mask = data["loss_mask"]
    leading_zero = torch.zeros(1, dtype=mask.dtype, device=mask.device)
    data["loss_mask"] = torch.cat([leading_zero, mask])

    # Append one more position id: the last id plus one.
    positions = data["position_ids"]
    next_position = positions[-1].unsqueeze(0) + 1
    data["position_ids"] = torch.cat([positions, next_position])

    return data
def add_vlm_extra_args(parser):
    """Register the vision-language-model specific command line flags.

    Args:
        parser: Argument parser to extend; it must provide
            ``add_argument_group``.

    Returns:
        The same ``parser`` object, now with a ``--disable-vision-class-token``
        store-true flag (default False).
    """
    vlm_group = parser.add_argument_group(
        title='vision language model specific arguments'
    )
    vlm_group.add_argument(
        "--disable-vision-class-token",
        action="store_true",
        default=False,
    )
    return parser
def llava_position_embedding_ranks(pp_ranks):
    """Return the single pipeline rank that holds LLaVA's position embeddings.

    For a one-rank pipeline group that rank is returned; otherwise the
    decoder's first rank is.

    Args:
        pp_ranks: A list of global ranks that constitute a pipeline group.

    Returns:
        A one-element list containing the selected rank.
    """
    # encoder size is also the index to the first rank of the decoder.
    decoder_start = get_args().encoder_pipeline_model_parallel_size

    final_rank = pp_ranks[-1]
    return [final_rank] if len(pp_ranks) == 1 else [pp_ranks[decoder_start]]
package_info.__download_url__ +__homepage__ = package_info.__homepage__ +__keywords__ = package_info.__keywords__ +__license__ = package_info.__license__ +__package_name__ = package_info.__package_name__ +__repository_url__ = package_info.__repository_url__ +__version__ = package_info.__version__ + + +with open("megatron/core/README.md", "r", encoding='utf-8') as fh: + long_description = fh.read() +long_description_content_type = "text/markdown" + +############################################################################### +# Extension Making # +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # + +extra_compile_args = ( + subprocess.check_output(["python3", "-m", "pybind11", "--includes"]) + .decode("utf-8") + .strip() + .split() +) + +############################################################################### + +setuptools.setup( + name=__package_name__, + # Versions should comply with PEP440. For a discussion on single-sourcing + # the version across setup.py and the project code, see + # https://packaging.python.org/en/latest/single_source_version.html + version=__version__, + description=__description__, + long_description=long_description, + long_description_content_type=long_description_content_type, + # The project's main homepage. + url=__repository_url__, + download_url=__download_url__, + # Author details + author=__contact_names__, + author_email=__contact_emails__, + # maintainer Details + maintainer=__contact_names__, + maintainer_email=__contact_emails__, + # The licence under which the project is released + license=__license__, + classifiers=[ + # How mature is this project? 
Common values are + # 1 - Planning + # 2 - Pre-Alpha + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + # 6 - Mature + # 7 - Inactive + 'Development Status :: 5 - Production/Stable', + # Indicate who your project is intended for + 'Intended Audience :: Developers', + 'Intended Audience :: Science/Research', + 'Intended Audience :: Information Technology', + # Indicate what your project relates to + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Scientific/Engineering :: Image Recognition', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: Utilities', + # Pick your license as you wish (should match "license" above) + 'License :: OSI Approved :: BSD License', + # Supported python versions + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + # Additional Setting + 'Environment :: Console', + 'Natural Language :: English', + 'Operating System :: OS Independent', + ], + packages=setuptools.find_namespace_packages(include=["megatron.core", "megatron.core.*"]), + ext_modules=[ + Extension( + "megatron.core.datasets.helpers", + sources=["megatron/core/datasets/helpers.cpp"], + language="c++", + extra_compile_args=extra_compile_args, + ) + ], + # Add in any packaged data. + include_package_data=True, + # PyPI package information. 
def compute_training_flops(
    batch_size,
    sequence_length,
    hidden_size,
    vocab_size,
    intermediate_size,
    num_layers,
    use_gradient_checkpointing=False,
    use_peft=False,
    use_gqa=False,
    kv_head_ratio=1,
):
    """Estimate the matrix-multiplication FLOPs of one training step.

    The source of formula:
    Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM's
    (APPENDIX: FLOATING-POINT OPERATIONS)

    Assuming that backward pass has twice FLOPs as many as forward pass. Only
    matrix multiplication FLOPs are computed. For use_peft, backward pass FLOPS
    is a little more than the forward pass; they are assumed equal for
    simplicity here.

    Args:
        batch_size: Number of sequences in the batch (B).
        sequence_length: Tokens per sequence (s).
        hidden_size: Model hidden dimension (h).
        vocab_size: Vocabulary size used by the logits projection.
        intermediate_size: Hidden dimension of the MLP.
        num_layers: Number of decoder layers.
        use_gradient_checkpointing: When True, the hardware FLOPs include one
            extra forward pass of every decoder layer (recomputation).
        use_peft: When True, forward+backward is counted as 2x the forward
            pass instead of 3x.
        use_gqa: When True, the key/value projection cost is divided by
            ``kv_head_ratio``.
        kv_head_ratio: Ratio of attention heads to key/value heads.

    Returns:
        tuple: ``(hardware_flops, model_flops)``. The two are equal unless
        gradient checkpointing is enabled. Values are floats when ``use_gqa``
        is set (true division), integers otherwise.
    """
    # Query projection, [b,s,n] -> [b,s,n]: 2*B*s*h^2
    query_proj_flops = batch_size * 2 * sequence_length * hidden_size**2
    if use_gqa:
        # Key and value projections are narrower by the head ratio under GQA.
        key_value_proj_flops = 2 * query_proj_flops / kv_head_ratio
    else:
        key_value_proj_flops = 2 * query_proj_flops
    attention_proj_flops = query_proj_flops + key_value_proj_flops

    # Per the cited appendix: attention scores (2*B*s^2*h) and attention over
    # values (2*B*s^2*h), plus the post-attention output projection (2*B*s*h^2).
    # The previous code had the 4 and the 2 swapped between the two terms.
    attention_flops = (
        4 * batch_size * hidden_size * sequence_length**2
        + 2 * batch_size * sequence_length * hidden_size**2
    )
    attention_forward_flops = attention_proj_flops + attention_flops

    # llama2 use gate_proj, has 3 Linears
    two_mlps_forward_flops = (
        3 * 2 * batch_size * sequence_length * hidden_size * intermediate_size
    )
    logits_forward_flops = 2 * batch_size * sequence_length * hidden_size * vocab_size
    decoder_layer_forward_flops = attention_forward_flops + two_mlps_forward_flops

    # forward FLOPs without gradient checkpointing
    forward_flops_wo_gc = (
        num_layers * decoder_layer_forward_flops + logits_forward_flops
    )

    # forward + backward multiplier: backward counts 2x forward (1x with PEFT).
    factor = 2 if use_peft else 3
    model_flops = forward_flops_wo_gc * factor
    if not use_gradient_checkpointing:
        return model_flops, model_flops

    # Checkpointing re-runs each decoder layer's forward once; the logits
    # projection is not recomputed.
    hardware_flops = (
        num_layers * decoder_layer_forward_flops * (factor + 1)
        + logits_forward_flops * factor
    )
    return hardware_flops, model_flops
def main():
    """Benchmark FSDP training of a LLaMA-architecture model on synthetic tokens.

    The function initialises an NCCL process group, builds a 10-layer
    ``LlamaForCausalLM`` on the meta device, shards it with FSDP (bf16
    parameters, activation checkpointing on every ``LlamaDecoderLayer``) and
    runs one pass over ``DummyDataset``. Rank 0 prints the loss, the wall-clock
    step time and the per-rank TFLOPS derived from ``compute_training_flops``.

    No weight initialisation is performed here: parameters are allocated on
    the GPU through ``to_empty``.

    Raises:
        KeyError: If ``LOCAL_RANK``, ``RANK`` or ``WORLD_SIZE`` is missing from
            the environment.
    """
    # Initialize the process group
    dist.init_process_group(backend="nccl")

    # Get local rank and world size
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    num_layers = 10
    hidden_size = 4096
    intermediate_size = 8192
    vocab_size = 126464
    num_head = 64
    num_kv_head = 8
    batch_size = 2
    seq_length = 4096
    kv_head_ratio = num_head // num_kv_head
    torch.cuda.set_device(local_rank)

    config = LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_head,
        num_key_value_heads=num_kv_head,
        intermediate_size=intermediate_size,
        max_position_embeddings=seq_length,
        initializer_range=0.02,
        # LLaMA normalises with RMSNorm; its epsilon is `rms_norm_eps`.
        # The former `layer_norm_eps` keyword is not a LlamaConfig field, so
        # the intended 1e-5 never reached the model.
        rms_norm_eps=1e-5,
        use_cache=False,
        use_bfloat16=True,
    )

    # Build the module on the meta device; FSDP allocates the real storage.
    with torch.device("meta"):
        model = LlamaForCausalLM(config)

    # Only the hardware FLOPs (first element) are used for the TFLOPS figure.
    flop, _ = compute_training_flops(
        batch_size,
        seq_length,
        hidden_size,
        vocab_size,
        intermediate_size,
        num_layers,
        use_gradient_checkpointing=True,
        use_gqa=True,
        kv_head_ratio=kv_head_ratio,
    )

    dataset = DummyDataset(vocab_size=vocab_size, max_length=seq_length)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler)

    param_init_fn = lambda m: m.to_empty(device=torch.device("cuda"))
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    model = FSDP(
        model,
        device_id=local_rank,
        auto_wrap_policy=wrap_policy,
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
        sync_module_states=False,
        param_init_fn=param_init_fn,
        forward_prefetch=True,
        limit_all_gathers=True,
        use_orig_params=True,
    )

    apply_fsdp_checkpointing(model, LlamaDecoderLayer)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    # Training Loop
    epoch = 0
    model.train()
    for input_ids, labels in dataloader:
        input_ids, labels = input_ids.to(local_rank), labels.to(local_rank)
        start = time.time()
        optimizer.zero_grad()
        loss = model(input_ids=input_ids, labels=labels).loss
        loss.backward()
        optimizer.step()
        # Wait for the GPU so the wall-clock time covers the whole step.
        torch.cuda.synchronize()
        if rank == 0:
            dur = time.time() - start
            tflops = flop / dur / 1e12
            print(f"Epoch {epoch}, Loss: {loss.item()} time {dur} tflops {tflops}")

    print("Training Complete")
    dist.destroy_process_group()
compute_training_flops( + batch_size, + seq_length, + hidden_size, + vocab_size, + intermediate_size, + num_layers, + use_gradient_checkpointing=True, + use_gqa=True, + kv_head_ratio=kv_head_ratio, + ) + + + config = LlamaConfig( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_head, + num_key_value_heads=num_kv_head, + intermediate_size=intermediate_size, + max_position_embeddings=seq_length, + initializer_range=0.02, + layer_norm_eps=1e-5, + attn_implementation="flash_attention_2", + use_cache=False, + use_bfloat16=True + ) + + #init_device = "cpu" if local_rank == 0 else "meta" + init_device = "meta" + + from liger_kernel.transformers import apply_liger_kernel_to_llama + apply_liger_kernel_to_llama( + rope=True, + swiglu=True, + cross_entropy=True, + fused_linear_cross_entropy=False, + rms_norm=True + ) + + ds_config = { + "train_batch_size": batch_size * world_size, + "train_micro_batch_size_per_gpu": batch_size, + #"steps_per_print": 10, + "zero_optimization": { + "stage": 3, + "overlap_comm": True, + }, + "bf16": { + "enabled": True, + }, + "activation_checkpointing": { + "partition_activations": True, # Partition activations across GPUs + #"contiguous_memory_optimization": True, # Optimize contiguous memory usage + }, + } + + kwargs = {} + kwargs["config"] = ds_config + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = LlamaForCausalLM(config) + kwargs["model"] = model + + from deepspeed.ops.adam import FusedAdam + #optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + optimizer = FusedAdam(model.parameters(), lr=1e-4) + kwargs["optimizer"] = optimizer + model_engine, optimizer, _, _ = deepspeed.initialize(**kwargs) + #from remote_pdb import RemotePdb + #RemotePdb("127.0.0.1", 16666+rank).set_trace() + + dataset = DummyDataset(vocab_size=vocab_size, max_length=seq_length) + sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) 
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler) + model_engine.train() + start = end = 0 + dur = [] + epoch = 0 + # Training Loop + def save_profile(prof): + prof.export_chrome_trace(f"ds_trace_{rank}.json") + + epoch = 0 + iters = 0 + # prof = torch.profiler.profile( + # schedule=torch.profiler.schedule( + # wait=1, + # warmup=1, + # active=3, + # repeat=1), + # on_trace_ready=save_profile, + # record_shapes=True, + # with_stack=True) + + # #prof = nullcontext() + # prof.start() + + for step, (input_ids, labels) in enumerate(dataloader): + start = time.time() + input_ids, labels = input_ids.to(local_rank), labels.to(local_rank) + optimizer.zero_grad() + loss = model_engine(input_ids=input_ids, labels=labels).loss + model_engine.backward(loss) + model_engine.step() + torch.cuda.synchronize() + dur = time.time() - start + tflops = flop / dur / 1e12 + if rank == 0: + print(f"Epoch {epoch}, Step {step}, Loss {loss.item()}, time {dur}, {tflops}") + # if step > 10: + # break + # prof.step() + # prof.stop() + +def main_qwen_vl(): + #dist.init_process_group(backend="nccl") + + #local_rank = int(os.environ["LOCAL_RANK"]) + #rank = int(os.environ["RANK"]) + #world_size = int(os.environ["WORLD_SIZE"]) + world_size = 1 + local_rank = rank = 0 + + torch.cuda.set_device(local_rank) + config ={ + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "vision_start_token_id": 151652, + "vision_end_token_id": 151653, + "vision_token_id": 151654, + "image_token_id": 151655, + "video_token_id": 151656, + "hidden_act": "silu", + "hidden_size": 8192 // 4, + "initializer_range": 0.02, + "intermediate_size": 29568 // 4, + "max_position_embeddings": 32768, + "max_window_layers": 80, + "model_type": "qwen2_vl", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": 
"bfloat16", + "transformers_version": "4.41.2", + "use_cache": False, + "use_sliding_window": False, + "vision_config": { + "depth": 32, + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 8192, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2 + }, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16 // 4, + 24 // 4, + 24 // 4 + ] + }, + "vocab_size": 152064 + } + + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel, Qwen2VisionTransformerPretrainedModel, Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + + qwen_config = Qwen2VLConfig(**config) + preprocess_config = { + "min_pixels": 3136, + "max_pixels": 12845056, + "patch_size": 14, + "temporal_patch_size": 2, + "merge_size": 2, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "image_processor_type": "Qwen2VLImageProcessor", + "processor_class": "Qwen2VLProcessor" + } + preprocess = Qwen2VLImageProcessor(**preprocess_config) + with torch.device('cuda'): + #model = Qwen2VLForConditionalGeneration(qwen_config) + text_model = Qwen2VLModel(qwen_config) + vision_model = Qwen2VisionTransformerPretrainedModel(qwen_config.vision_config) + image = [torch.ones(1280, 1280, 3, dtype=torch.uint8) for _ in range(10)] + if rank == 0: + text = torch.randint(low=0, high=qwen_config.vocab_size, size=(1, 4096,)).cuda() + #text[-32:] = 1516545 + #print((text == 151655).sum().item()) + data = preprocess(image, return_tensors="pt") + image_grid_thw = data['image_grid_thw'].cuda() + image_hidden = data['pixel_values'].cuda() + # + t = text_model(text) + v = vision_model(image_hidden, image_grid_thw) + breakpoint() + #m = 
model(input_ids=text,pixel_values=image_hidden,image_grid_thw=image_grid_thw) + print(1) + + #num_layers = 10 + #hidden_size = 8192 + #intermediate_size = 32768 + #vocab_size = 126464 + #num_head = 128 + #num_kv_head = 16 + #batch_size = 2 + #seq_length = 4096 + #kv_head_ratio = num_head // num_kv_head + #torch.cuda.set_device(local_rank) + + + #config = LlamaConfig( + # vocab_size=vocab_size, + # hidden_size=hidden_size, + # num_hidden_layers=num_layers, + # num_attention_heads=num_head, + # num_key_value_heads=num_kv_head, + # intermediate_size=intermediate_size, + # max_position_embeddings=seq_length, + # initializer_range=0.02, + # layer_norm_eps=1e-5, + # attn_implementation="flash_attention_2", + # use_cache=False, + # use_bfloat16=True + #) + + ##init_device = "cpu" if local_rank == 0 else "meta" + #init_device = "meta" + + #from liger_kernel.transformers import apply_liger_kernel_to_llama + #apply_liger_kernel_to_llama( + # rope=True, + # swiglu=True, + # cross_entropy=True, + # fused_linear_cross_entropy=False, + # rms_norm=True + #) + + #with torch.device(init_device): + # model = LlamaForCausalLM(config) + + # + #flop, _ = compute_training_flops( + # batch_size, + # seq_length, + # hidden_size, + # vocab_size, + # intermediate_size, + # num_layers, + # use_gradient_checkpointing=True, + # use_gqa=True, + # kv_head_ratio=kv_head_ratio, + #) + + + #dataset = DummyDataset(vocab_size=vocab_size, max_length=seq_length) + #sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) + #dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler) + + ##param_init_fn = lambda m: m.to_empty(device=torch.device("cuda")) if local_rank != 0 else None + #param_init_fn = lambda m: m.to_empty(device=torch.device("cuda")) + #wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer,},) + ##model = model.to(dtype=torch.bfloat16) + + #model = FSDP(model, 
def mllama():
    """Run an FSDP training loop for an Mllama model on a fixed, pre-saved batch.

    The batch is loaded from ``dummy.pth`` onto the GPU and reused for every
    step; labels are random token ids with the shape of ``input_ids``. The
    model is built on the meta device, allocated through ``to_empty`` (no
    weight initialisation happens in this function), sharded with FSDP and
    given activation checkpointing on the vision encoder and self-attention
    decoder layers. Rank 0 prints the loss and wall-clock time of each step.
    """
    from transformers import MllamaForConditionalGeneration, AutoProcessor
    from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaVisionConfig, MllamaTextConfig
    from transformers.models.mllama.modeling_mllama import MllamaVisionEncoderLayer, MllamaSelfAttentionDecoderLayer
    from liger_kernel.transformers import apply_liger_kernel_to_mllama

    dist.init_process_group(backend="nccl")

    # Rank information comes from the environment.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)

    apply_liger_kernel_to_mllama(
        rope=True,
        swiglu=True,
        cross_entropy=False,
        fused_linear_cross_entropy=True,
        rms_norm=True,
    )

    text_settings = {
        "_name_or_path": "",
        "add_cross_attention": False,
        "architectures": None,
        "bad_words_ids": None,
        "begin_suppress_tokens": None,
        "bos_token_id": 128000,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": None,
        # 3, 8, 13, ..., 98 (every fifth layer index, twenty entries)
        "cross_attention_layers": list(range(3, 99, 5)),
        "decoder_start_token_id": None,
        "diversity_penalty": 0.0,
        "do_sample": False,
        "dropout": 0,
        "early_stopping": False,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": [128001, 128008, 128009],
        "exponential_decay_length_penalty": None,
        "finetuning_task": None,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
        "initializer_range": 0.02,
        "intermediate_size": 28672,
        "is_decoder": False,
        "is_encoder_decoder": False,
        "label2id": {"LABEL_0": 0, "LABEL_1": 1},
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 131072,
        "min_length": 0,
        "model_type": "mllama_text_model",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 64,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_hidden_layers": 40,
        "num_key_value_heads": 8,
        "num_return_sequences": 1,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
        "pad_token_id": 128004,
        "prefix": None,
        "problem_type": None,
        "pruned_heads": {},
        "remove_invalid_values": False,
        "repetition_penalty": 1.0,
        "return_dict": True,
        "return_dict_in_generate": False,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
        "rope_theta": 500000.0,
        "sep_token_id": None,
        "suppress_tokens": None,
        "task_specific_params": None,
        "temperature": 1.0,
        "tf_legacy_loss": False,
        "tie_encoder_decoder": False,
        "tie_word_embeddings": False,
        "tokenizer_class": None,
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": "bfloat16",
        "torchscript": False,
        "typical_p": 1.0,
        "use_bfloat16": False,
        "use_cache": False,
        "vocab_size": 128256,
    }

    vision_settings = {
        "_name_or_path": "",
        "add_cross_attention": False,
        "architectures": None,
        "attention_heads": 16,
        "bad_words_ids": None,
        "begin_suppress_tokens": None,
        "bos_token_id": None,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": None,
        "decoder_start_token_id": None,
        "diversity_penalty": 0.0,
        "do_sample": False,
        "early_stopping": False,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": None,
        "exponential_decay_length_penalty": None,
        "finetuning_task": None,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "hidden_act": "gelu",
        "hidden_size": 1280,
        "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
        "image_size": 560,
        "intermediate_layers_indices": [3, 7, 15, 23, 30],
        "intermediate_size": 5120,
        "is_decoder": False,
        "is_encoder_decoder": False,
        "label2id": {"LABEL_0": 0, "LABEL_1": 1},
        "length_penalty": 1.0,
        "max_length": 20,
        "max_num_tiles": 4,
        "min_length": 0,
        "model_type": "mllama_vision_model",
        "no_repeat_ngram_size": 0,
        "norm_eps": 1e-05,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_channels": 3,
        "num_global_layers": 8,
        "num_hidden_layers": 32,
        "num_return_sequences": 1,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
        "pad_token_id": None,
        "patch_size": 14,
        "prefix": None,
        "problem_type": None,
        "pruned_heads": {},
        "remove_invalid_values": False,
        "repetition_penalty": 1.0,
        "return_dict": True,
        "return_dict_in_generate": False,
        "sep_token_id": None,
        "supported_aspect_ratios": [
            [1, 1], [1, 2], [1, 3], [1, 4],
            [2, 1], [2, 2], [3, 1], [4, 1],
        ],
        "suppress_tokens": None,
        "task_specific_params": None,
        "temperature": 1.0,
        "tf_legacy_loss": False,
        "tie_encoder_decoder": False,
        "tie_word_embeddings": True,
        "tokenizer_class": None,
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": "bfloat16",
        "torchscript": False,
        "typical_p": 1.0,
        "use_bfloat16": False,
        "vision_output_dim": 7680,
    }

    config = {
        "architectures": ["MllamaForConditionalGeneration"],
        "image_token_index": 128256,
        "model_type": "mllama",
        "text_config": text_settings,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.45.0.dev0",
        "vision_config": vision_settings,
    }

    vision_config = MllamaVisionConfig(**config['vision_config'])
    text_config = MllamaTextConfig(**config['text_config'])
    model_config = MllamaConfig(vision_config, text_config, torch_dtype="bfloat16")

    # One fixed batch, reused for every step below.
    data = torch.load('dummy.pth', map_location='cuda')
    label = torch.randint(
        low=0,
        high=config['text_config']['vocab_size'],
        size=data['input_ids'].shape,
    )

    with torch.device('meta'):
        model = MllamaForConditionalGeneration(model_config)

    checkpointed_layers = (MllamaVisionEncoderLayer, MllamaSelfAttentionDecoderLayer)

    def param_init_fn(m):
        return m.to_empty(device=torch.device("cuda"))

    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(checkpointed_layers),
    )
    precision = MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True)
    model = FSDP(
        model,
        device_id=local_rank,
        auto_wrap_policy=wrap_policy,
        mixed_precision=precision,
        sync_module_states=False,
        param_init_fn=param_init_fn,
        forward_prefetch=True,
        limit_all_gathers=True,
        use_orig_params=True,
    )

    apply_fsdp_checkpointing(model, checkpointed_layers)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    model.train()
    for i in range(100000000):
        start = time.time()
        optimizer.zero_grad()
        loss = model(**data, labels=label).loss
        loss.backward()
        optimizer.step()
        # Wait for the GPU so the measured time covers the whole step.
        torch.cuda.synchronize()
        dur = time.time() - start
        if rank == 0:
            print(f"Step {i}, Loss {loss.item()}, time {dur}")