diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index ffe71131..362a2eab 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -12,6 +12,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
+torch._inductor.config.cpp.enable_kernel_profile = True
 
 def device_sync(device):
     if "cuda" in device:
@@ -132,7 +133,7 @@ def encode_tokens(tokenizer, string, bos=True, device='cuda'):
     tokens = tokenizer.encode(string)
     if bos:
         tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
+    return torch.tensor(tokens, dtype=torch.int, device=args.device)
 
 def _load_model(checkpoint_path, device, precision, use_tp):
     with torch.device('meta'):
@@ -248,8 +249,13 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if device == 'cuda':
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], use_cuda=True)
+                profile_sort = 'self_cuda_time_total'
+            elif device == 'cpu':
+                prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU])
+                profile_sort = 'self_cpu_time_total'
         with prof:
             y = generate(
                 model,
@@ -263,6 +269,8 @@ def callback(x):
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
+        if hasattr(prof, "key_averages"):
+            print(prof.key_averages().table(sort_by=profile_sort, row_limit=-1))
         if hasattr(prof, "export_chrome_trace"):
             if use_tp:
                 prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
diff --git a/mixtral-moe/quantize.py b/mixtral-moe/quantize.py
index 6312863c..f4857907 100644
--- a/mixtral-moe/quantize.py
+++ b/mixtral-moe/quantize.py
@@ -98,6 +98,20 @@ def convert_for_runtime(self):
         return self.mod
 
 
+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
 class WeightOnlyBit8Linear(torch.nn.Module):
     __constants__ = ['in_features', 'out_features']
     in_features: int
@@ -115,7 +129,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+        # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+        return linear_forward_int8(
+            input,
+            self.weight, self.scales, self.out_features)
 
 
 class ConditionalFeedForwardBit8(nn.Module):
diff --git a/quantize.py b/quantize.py
index db477754..e99fadcb 100644
--- a/quantize.py
+++ b/quantize.py
@@ -337,6 +337,18 @@ def convert_for_runtime(self):
         replace_linear_weight_only_int8_per_channel(self.mod)
         return self.mod
 
+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
 
 class WeightOnlyInt8Linear(torch.nn.Module):
     __constants__ = ['in_features', 'out_features']
@@ -354,7 +366,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+        # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+        return linear_forward_int8(
+            input,
+            self.weight, self.scales, self.out_features)
 
 
 ##### weight only int4 per channel groupwise quantized code ######
@@ -502,16 +519,10 @@ def __init__(
 
         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        if use_cuda:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-            )
-        else:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
-            )
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+        )
         self.register_buffer(
             "scales_and_zeros",
             torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
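
For reference, a minimal standalone sketch (not part of the patch) of the CPU int8 weight-only path that the new linear_forward_int8 helper routes through. It assumes a PyTorch build in which torch.ops.aten._weight_int8pack_mm is available on CPU; the toy shapes and the per-channel quantization of the random weight below are illustrative choices, not taken from the repository.

    # Sketch only: exercises the CPU path of the patched forward, under the
    # assumption that torch.ops.aten._weight_int8pack_mm exists on this build.
    import torch
    import torch.nn.functional as F

    def linear_forward_int8(x, weight_int8pack, scales, out_features):
        # Same logic as the helper added in the patch: plain bfloat16 matmul on
        # CUDA, fused int8-weight matmul kernel on CPU.
        if x.is_cuda:
            return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
        origin_x_size = x.size()
        x = x.reshape(-1, origin_x_size[-1])
        c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
        return c.reshape(origin_x_size[:-1] + (out_features,))

    if __name__ == "__main__":
        out_features, in_features = 64, 128  # illustrative sizes
        w = torch.randn(out_features, in_features, dtype=torch.bfloat16)
        # Symmetric per-output-channel int8 quantization of the weight.
        scales = (w.abs().amax(dim=1) / 127.0).clamp(min=1e-8)
        w_int8 = torch.clamp(torch.round(w / scales.unsqueeze(1)), -127, 127).to(torch.int8)

        x = torch.randn(4, in_features, dtype=torch.bfloat16)
        y = linear_forward_int8(x, w_int8, scales, out_features)
        # Reference: the eager bfloat16 path that the patched forward comments out.
        y_ref = F.linear(x, w_int8.to(dtype=x.dtype)) * scales
        print("max abs diff:", (y - y_ref).abs().max().item())

The printed difference reflects only kernel-level accumulation differences between the fused int8 kernel and the dequantize-then-matmul reference, which is why the patch treats the two paths as interchangeable.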