diff --git a/phi-3/README.md b/phi-3/README.md
new file mode 100644
index 00000000..7654b904
--- /dev/null
+++ b/phi-3/README.md
@@ -0,0 +1,41 @@
+# Phi-3-mini-4k-instruct
+[Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/) is a high-quality language model. This repo is a simple and efficient PyTorch-native implementation of Phi-3-mini-4k-instruct.
+
+## Downloading Weights
+
+```bash
+export MODEL_REPO=microsoft/Phi-3-mini-4k-instruct
+python scripts/download.py --repo_id $MODEL_REPO
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
+```
+
+## Benchmarks
+Benchmarks were run on a single RTX 3090. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
+
+| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
+| -------- | ------- | ------ | ------ |
+| Phi-3-mini-4k-instruct | Base | 106.3 | 791 |
+| | 8-bit | 160.5 | 598 |
+
+
+## Generate Text
+
+The model definition lives in `model.py`; the generation code is in `generate.py`.
+
+```bash
+python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
+```
+
+To squeeze out a little more performance, you can also compile the prefill with `--compile_prefill`, at the cost of longer compilation times.
+
+## Quantization
+### Int8 Weight-Only Quantization
+To generate this version of the model:
+```bash
+# Writes the quantized model to checkpoints/$MODEL_REPO/model_int8.pth
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
+```
+To run with int8, pass the int8 checkpoint to `generate.py`.
+```bash
+python generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth
+```
diff --git a/phi-3/generate.py b/phi-3/generate.py
new file mode 100644
index 00000000..0b7b44b0
--- /dev/null
+++ b/phi-3/generate.py
@@ -0,0 +1,436 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
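+
+# Latency-oriented generation loop for this Phi-3 port: a prefill step followed by a
+# single-token decode loop (compiled with torch.compile when --compile is passed),
+# with optional compiled prefill (--compile_prefill), support for int8/int4
+# weight-only quantized checkpoints, optional tensor parallelism via tp.py, and
+# optional speculative decoding when a draft checkpoint is supplied.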
+import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + +default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import Transformer +from sentencepiece import SentencePieceProcessor + +def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + +def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + +def speculative_decode( + model: Transformer, + draft_model: Transformer, + cur_token: torch.Tensor, + input_pos: int, + speculate_k: int, + **sampling_kwargs +) -> torch.Tensor: + # draft model inference sequentially + device = cur_token.device + orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) + draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) + + draft_tokens = torch.cat(draft_tokens) + # parallel inference on target model using draft tokens + 
target_logits = model_forward( + model, + torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), + torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) + ) + target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) + draft_probs = torch.stack(draft_probs) + # q: target prob, p: draft prob + # q >= p: always accept draft token + # q < p: q/p prob to accept draft token + p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) + rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() + + if rejected_locations.shape[0] == 0: # All draft tokens have been accepted + accept_length = speculate_k + 1 + last_token = multinomial_sample_one_no_sync(target_probs[-1]) + # fill last token into draft model + model_forward( + draft_model, + draft_tokens[-1].view(1, -1), + orig_input_pos + speculate_k, + ) + return torch.cat([draft_tokens, last_token]) + else: + accept_length = rejected_locations[0].item() + p = draft_probs[accept_length] + q = target_probs[accept_length] + new = q - p + new = torch.where(new > 0, new, 0.0) + new = new / new.sum() + next_token = multinomial_sample_one_no_sync(new) + return torch.cat([draft_tokens[:accept_length], next_token]) + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + interactive: bool, + draft_model: Transformer, + speculate_k: Optional[int] = 8, + callback = lambda x: x, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + + is_speculative = draft_model is not None + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + T_new = T + max_new_tokens + if interactive: + max_seq_length = 350 + else: + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if is_speculative and draft_model is not model: + draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone() + if is_speculative: + prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + accept_counts = [0] * (speculate_k + 1) + + if is_speculative: + input_pos = input_pos.item() # for speculative decoding easier to keep on host + while input_pos < T_new - 1: + cur_token = next_token.view(()) + + next_tokens = speculative_decode( + model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs + ) + + accept_counts[len(next_tokens) - 1] += 1 + num_added = min(T_new - input_pos - 1, len(next_tokens)) + seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] + for i in next_tokens[: num_added,]: + callback(i) + input_pos = input_pos + num_added + next_token = next_tokens[-1] + 
else: + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) + + generate_stats = { + 'accept_counts': accept_counts + } + return seq, generate_stats + +def encode_tokens(tokenizer, string, bos=True, device=default_device): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + +def _load_model(checkpoint_path, device, precision, use_tp): + use_cuda = 'cuda' in device + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + if "int8" in str(checkpoint_path): + print("Using int8 weight-only quantization!") + from quantize import WeightOnlyInt8QuantHandler + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(checkpoint_path): + print("Using int4 weight-only quantization!") + path_comps = checkpoint_path.name.split(".") + groupsize = int(path_comps[-2][1:]) + from quantize import WeightOnlyInt4QuantHandler + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + + if use_tp: + from tp import apply_tp + print("Applying tensor parallel to model ...") + apply_tp(model) + + model = model.to(device=device, dtype=precision) + return model.eval() + +def _get_model_size(model): + model_size = 0 + for name, child in model.named_children(): + if not isinstance(child, torch.nn.Embedding): + model_size += sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(child.parameters(), child.buffers()) + ] + ) + return model_size + +B_INST, E_INST = "[INST]", "[/INST]" + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + profile: Optional[Path] = None, + draft_checkpoint_path: Optional[Path] = None, + speculate_k: int = 5, + device=default_device, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. 
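+
+    Args:
+        prompt: Text prompt to condition generation on.
+        interactive: If True, read prompts from stdin instead of reusing `prompt`.
+        num_samples: Number of completions to generate.
+        max_new_tokens: Maximum number of new tokens per completion.
+        top_k: Top-k cutoff applied before sampling.
+        temperature: Sampling temperature.
+        checkpoint_path: Path to the converted model checkpoint (model.pth).
+        compile: Compile the single-token decode step with torch.compile.
+        compile_prefill: Also compile the prefill step (longer compile times).
+        profile: If set, export a Chrome trace of the last sample to this path.
+        draft_checkpoint_path: Optional draft-model checkpoint; enables speculative decoding.
+        speculate_k: Number of draft tokens proposed per speculative step.
+        device: Device to run on.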
+ """ + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + global print + from tp import maybe_init_dist + rank = maybe_init_dist() + use_tp = rank is not None + if use_tp: + if rank != 0: + # only print on rank 0 + print = lambda *args, **kwargs: None + + print(f"Using device={device}") + precision = torch.bfloat16 + is_speculative = draft_checkpoint_path is not None + is_chat = "chat" in str(checkpoint_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision, use_tp) + + if is_speculative: + draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) + else: + draft_model = None + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = _get_model_size(model) + if compile: + if is_speculative and use_tp: # and ("cuda" in device): + torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case + + if is_speculative: + global model_forward, logits_to_prob + model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) + + global decode_one_token, prefill + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + + # Uncomment to squeeze more perf out of prefill + if compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + + aggregate_metrics = { + 'tokens_per_sec': [], + 'accept_counts': [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? 
") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y, metrics = generate( + model, + encoded, + max_new_tokens, + draft_model=draft_model, + speculate_k=speculate_k, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + aggregate_metrics['accept_counts'].append(metrics['accept_counts']) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + if use_tp: + prof.export_chrome_trace(f"{profile}_rank_{rank}.json") + else: + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + print(tokenizer.decode(y.tolist())) + else: + print() + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + print("==========") + if is_speculative: + counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] + acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] + print(f"Acceptance probs: {acceptance_probs}") + print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") + + print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') + parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') + parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + parser.add_argument('--profile', 
type=Path, default=None, help='Profile path.') + parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') + parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') + parser.add_argument('--device', type=str, default=default_device, help='Device to use') + + args = parser.parse_args() + main( + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, + args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, + args.speculate_k, args.device + ) diff --git a/phi-3/model.py b/phi-3/model.py new file mode 100644 index 00000000..9c28c846 --- /dev/null +++ b/phi-3/model.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + + # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). 
Find the best config match, + # take longer name (as it have more symbols matched) + if len(config) > 1: + config.sort(key=len, reverse=True) + assert len(config[0]) != len(config[1]), name # make sure only one 'best' match + + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Phi-3-mini-4k-instruct": dict(block_size=4096, n_layer=32, n_head=32, dim=3072, intermediate_size=8192, rope_base=10000, vocab_size=32064), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype=self.output.weight.dtype + dtype = self.output.weight.dtype + # For quantized layers, dtype is encoded in scales + if hasattr(self.output, "scales"): + dtype = self.output.scales.dtype + elif hasattr(self.output, "scales_and_zeros"): + dtype = self.output.scales_and_zeros.dtype + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return 
out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def _init_rope(self): + self.rotary_emb = PhiRotaryEmbedding(self.head_dim) + + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, 2*config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + self.activation_fn = F.silu + + def forward(self, x: Tensor) -> Tensor: + up_states = self.w1(x) + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + return self.w2(up_states) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * 
freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/phi-3/quantize.py b/phi-3/quantize.py new file mode 100644 index 00000000..5f200b66 --- /dev/null +++ b/phi-3/quantize.py @@ -0,0 +1,622 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import time +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sentencepiece import SentencePieceProcessor + +try: + from GPTQ import GenericGPTQRunner, InputRecorder + from eval import get_task_dict, evaluate, lm_eval +except: + pass + +from model import Transformer + +##### Quantization Primitives ###### + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + +def group_quantize_tensor_from_qparams(w, scales, 
zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) + ) + + return w_int32 + + +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) + ) + return w_dq + + +def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_scales_and_zeros(scales_and_zeros) + return group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit, groupsize + ) + +class QuantHandler: + def __init__(self, mod): + self.mod = mod + + def create_quantized_state_dict(self) -> "StateDict": + pass + + def convert_for_runtime(self) -> "nn.Module": + pass + +class GPTQQuantHandler(QuantHandler): + """ + This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. + Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement + __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + + The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and + create_quantized_state_dict. Here is a description of each function. + + get_qparams_func: + A function that calculates the quantization qparams for an input tensor. + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + qparams: it can have any format but will need to be handled by the other defined functions below. + + quantize_func: + A function that applies quantization to an input tensor. It should be noted + that this function needs to be able to handle quantizing the entire weight tensor, a single group, + or a single column. + Args: + weight: A 2d weight tensor with non-integer dtype. + qparams: the output from get_qparams_func + Returns: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + + + dequantize_func: + A function that dequantizes an input quantized weight tensor. It should be noted + that this function needs to be able to handle dequantizing the entire weight tensor, a single group, + or a single column. 
+ Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + weight: A 2d weight tensor with non-integer dtype. + + combine_qparams_list_func: + A function that combines several qparams into one qparam. + Args: + qparams_list: a list of qparams objects, each obtained by calling get_qparams_func + on a single group from a weight tensor + Returns: + qparams: an object of the same format as the qparams above. + + skip_layer_func: + A function that determines which linear layers should be skipped during GPTQ + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + skip: boolean indicating whether layer should be skipped + + make_names_and_values_dict_func: + A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they + should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the + corresponding quantized weights and qparams. + """ + def __init__(self): + assert self.mod is not None + assert self.get_qparams_func is not None + assert self.quantize_func is not None + assert self.dequantize_func is not None + assert self.combine_qparams_list_func is not None + assert self.make_names_and_values_dict_func is not None + + @staticmethod + def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": + input_recorder = InputRecorder( + model, + tokenizer, + calibration_seq_length, + pad_calibration_inputs, + ) + + try: + lm_eval.tasks.initialize_tasks() + except: + pass + task_dict = get_task_dict(calibration_tasks) + print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) + + evaluate( + input_recorder, + task_dict, + limit=calibration_limit, + ) + inputs = input_recorder.get_recorded_inputs() + assert inputs is not None, ( + f"No inputs were collected, use a task other than {calibration_tasks}, "+ + f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ + f"{calibration_seq_length})" + ) + print(f"Obtained {len(inputs[0].values)} calibration samples") + return inputs + + @torch.no_grad() + def create_quantized_state_dict( + self, + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "StateDict": + inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) + print("Tracing model for GPTQ") + GPTQ_runner = GenericGPTQRunner( + self.mod, + inputs, + blocksize, + percdamp, + groupsize, + ).configure_quantization_mode( + self.get_qparams_func, + self.quantize_func, + self.dequantize_func, + self.combine_qparams_list_func, + self.make_names_and_values_dict_func, + self.skip_layer_func + ) + + print("Applying GPTQ to weights") + GPTQ_runner.run() + return GPTQ_runner.get_quantized_state_dict() + + def convert_for_runtime(self) -> "nn.Module": + pass + +##### Weight-only int8 per-channel quantized code ###### + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, 
WeightOnlyInt8Linear(child.in_features, child.out_features)) + else: + replace_linear_weight_only_int8_per_channel(child) + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) + cur_state_dict[f"{fqn}.weight"] = int8_weight + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + +##### weight only int4 per channel groupwise quantized code ###### + +def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + weight_int32, scales_and_zeros = group_quantize_tensor( + weight_bf16, n_bit=4, groupsize=groupsize + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) + return weight_int4pack, scales_and_zeros + + +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): + setattr(module, name, WeightOnlyInt4Linear( + child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, + )) + elif padding: + setattr(module, name, WeightOnlyInt4Linear( + child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, + )) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self, use_cuda = True): + if use_cuda: + device="cuda" + else: + device="cpu" + + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert 
not mod.bias + out_features = mod.out_features + in_features = mod.in_features + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): + if self.padding: + from model import find_multiple + import torch.nn.functional as F + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it") + continue + weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) + return self.mod + +class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + from model import find_multiple + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) + self.quantize_func = lambda w, qparams: \ + group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) + self.dequantize_func = lambda q, qparams: \ + group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() + self.combine_qparams_list_func = lambda qparams_list: \ + [torch.cat(x, dim=1) for x in zip(*qparams_list)] + # skip unless padding=True or its correctly sized + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding + ) + # we need to do the padding here, both for q and the qparams if necessary + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1024) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales_and_zeros = pack_scales_and_zeros(*qparams) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + self.make_names_and_values_dict_func = make_names_and_values_dict_func + super().__init__() + + + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) + return self.mod + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, in_features: int, out_features: int, + bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, + ) -> None: + super().__init__() + self.padding = padding + if padding: + from model import find_multiple + self.origin_in_features = in_features + in_features = 
find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) + self.register_buffer( + "scales_and_zeros", + torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, + self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + + +def quantize( + checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + mode: str = 'int8', + # following arguments only available when setting int4 quantization. + groupsize: int = 128, + # following arguments only used for GPTQ + calibration_tasks: list = ["hellaswag"], + calibration_limit: int = 1000, + calibration_seq_length: int = 100, + pad_calibration_inputs: bool = False, + percdamp: float = .01, + blocksize: int = 128, + label: str = '', +) -> None: + assert checkpoint_path.is_file(), checkpoint_path + + device = 'cpu' + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + + if mode == 'int8': + print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") + quant_handler = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f'{label}int8.pth') + + elif mode == 'int4': + print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") + quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth") + + elif mode == 'int4-gptq': + print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...") + quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + + quantized_state_dict = quant_handler.create_quantized_state_dict( + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs + ) + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth") + else: + raise ValueError(f"Invalid quantization mode {mode} needs to be one of 
[int8, int4, int4-gpptq]") + + quantize_path = dir_name / new_base_name + print(f"Writing quantized weights to {quantize_path}") + quantize_path.unlink(missing_ok=True) # remove existing file if one already there + torch.save(quantized_state_dict, quantize_path) + print(f"Quantization complete took {time.time() - t0:.02f} seconds") + return + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Quantize a model.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') + parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') + parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') + parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') + parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') + parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration') + parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower') + parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening') + parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq') + parser.add_argument('--label', type=str, default='_', help='label to add to output filename') + + args = parser.parse_args() + quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label) diff --git a/phi-3/scripts/convert_hf_checkpoint.py b/phi-3/scripts/convert_hf_checkpoint.py new file mode 100644 index 00000000..051437a2 --- /dev/null +++ b/phi-3/scripts/convert_hf_checkpoint.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
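+
+# Converts a Hugging Face Phi-3 checkpoint (safetensors shards listed in
+# model.safetensors.index.json) into the single model.pth state dict expected by
+# model.py: parameter names are remapped, and the fused qkv projection is split,
+# permuted to match this repo's rotary embedding layout, and re-concatenated.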
+import json +import re +import sys +from pathlib import Path +from typing import Optional + +import torch +from safetensors.torch import load_file + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import ModelArgs + +@torch.inference_mode() +def convert_hf_checkpoint( + *, + checkpoint_dir: Path = Path("checkpoints/micorosoft/Phi-3-mini-4k-instruct"), + model_name: Optional[str] = None, +) -> None: + if model_name is None: + model_name = checkpoint_dir.name + + config = ModelArgs.from_name(model_name) + print(f"Model config {config.__dict__}") + + # Load the json file containing weight mapping + model_map_json = checkpoint_dir / "model.safetensors.index.json" + + assert model_map_json.is_file() + + with open(model_map_json) as json_map: + bin_index = json.load(json_map) + + weight_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attention.wqkv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": "", + "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight" + } + bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} + + def permute(w, n_head): + dim = config.dim + return ( + w.view(n_head, 2, config.head_dim // 2, dim) + .transpose(1, 2) + .reshape(config.head_dim * n_head, dim) + ) + + merged_result = {} + for file in sorted(bin_files): + state_dict = load_file(str(file)) + merged_result.update(state_dict) + final_result = {} + for key, value in merged_result.items(): + print(key) + if "layers" in key: + abstract_key = re.sub(r'(\d+)', '{}', key, count=1) + layer_num = re.search(r'\d+', key).group(0) + + new_key = weight_map[abstract_key] + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = weight_map[key] + + final_result[new_key] = value + + for key in tuple(final_result.keys()): + if "wqkv" in key: + qkv = final_result[key] + q, k, v = qkv.chunk(3, dim=0) + q = permute(q, config.n_head) + k = permute(k, config.n_head) + final_result[key] = torch.cat([q, k, v]) + + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") + torch.save(final_result, checkpoint_dir / "model.pth") + + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') + parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/microsoft/Phi-3-mini-4k-instruct")) + parser.add_argument('--model_name', type=str, default=None) + + args = parser.parse_args() + convert_hf_checkpoint( + checkpoint_dir=args.checkpoint_dir, + model_name=args.model_name, + ) diff --git a/phi-3/scripts/download.py b/phi-3/scripts/download.py new file mode 100644 index 00000000..0f516996 --- /dev/null +++ b/phi-3/scripts/download.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
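+
+# Downloads a model snapshot from the Hugging Face Hub into checkpoints/<repo_id>.
+# Gated or private repositories require passing --hf_token.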
+import os +from typing import Optional + +from requests.exceptions import HTTPError + + +def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: + from huggingface_hub import snapshot_download + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) + try: + snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token) + except HTTPError as e: + if e.response.status_code == 401: + print("You need to pass a valid `--hf_token=...` to download private checkpoints.") + else: + raise e + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') + parser.add_argument('--repo_id', type=str, default="microsoft/Phi-3-mini-4k-instruct", help='Repository ID to download from.') + parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') + + args = parser.parse_args() + hf_download(args.repo_id, args.hf_token)