[Transform] Serialize with tied weights #370
Conversation
Looks like your style checks have different behavior than what's on main. I've seen this elsewhere and I'm not sure what's causing it. Maybe our version pin on `flake8>=3.8.3` is too loose and later versions have different behavior?
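If the loose pin does turn out to be the cause, one option would be to cap the flake8 version in the dev requirements. A purely illustrative sketch of a `setup.py` excerpt; the package metadata, extras layout, and chosen upper bound are assumptions, not this repo's actual configuration:

```python
# Illustrative only: capping flake8 so a new release cannot silently change
# lint behavior in CI. Bump the upper bound deliberately after verifying it.
from setuptools import setup

setup(
    name="compressed-tensors",  # assumed package name for this sketch
    extras_require={
        "dev": ["flake8>=3.8.3,<7.0"],  # upper bound chosen for illustration
    },
)
```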
This and #391 are both ready to merge, right? We can send them over to the team for review if so.
@brian-dellabetta This is ready to merge. Afaict the only thing left from the head branch is applying transforms at a higher granularity, and it's still unknown whether that actually has an effect on accuracy.
LGTM!
👍
## Purpose ##
* Enable offline spinquant-style transforms

## Prerequisites ##
* neuralmagic/compressed-tensors#370
* neuralmagic/compressed-tensors#412
* neuralmagic/compressed-tensors#414

## Changes ##
* Added `spinquant_example.py` to the examples folder
* Added `SpinQuantModifier`, which handles the construction of a spinquant-style transform config

## Testing ##
* Added modifier serialization and correctness tests

## Evaluation ##
Using this branch and [the original SpinQuant code](https://github.com/facebookresearch/SpinQuant), we see very similar results for `meta-llama/Llama-3.2-1B-Instruct` with W4A16 quantization. Results are equivalent in hf (in-memory vs. serialized and re-loaded), and very similar in vllm.

The symmetric scales calculation in `llm-compressor` is slightly different from the original SpinQuant paper, which uses the original GPTQ implementation. When this is swapped in, results are consistent, with hadamard improving results on `gsm8k_llama` and `arc_challenge_llama`:

Scheme | Impl | gsm8k | gsm8k_llama | arc_challenge_llama
-- | -- | -- | -- | --
Hadamard+W4A16 | LC | 0.2403 | 0.2835 | 0.5262
W4A16 | LC | 0.1964 | 0.1933 | 0.4781
Hadamard+W4A16 | LC+SQscales | 0.1721 | 0.2183 | 0.485
W4A16 | LC+SQscales | 0.207 | 0.1706 | 0.4498
Hadamard+W4A16 | SQ | 0.1736 | 0.2282 | 0.4807
W4A16 | SQ | 0.1986 | 0.1774 | 0.4489

To run LC+SQscales, change [this line in CT](https://github.com/neuralmagic/compressed-tensors/blob/b2df366797b00330ec765f5891dde14e4cc74c9d/src/compressed_tensors/quantization/utils/helpers.py#L111) from

```python
scales = max_val_pos / (float(bit_range) / 2)
```

to

```python
scales = max_val_pos / (float(bit_max))
```

<details>
<summary>The following python script was used to generate these results</summary>

Clone the SpinQuant repo and paste this in the top-level directory:

```python
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from typing import Literal
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from torch import nn
import lm_eval
from transformers import LlamaForCausalLM, AutoTokenizer
import transformers

from train_utils.main import prepare_model
from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant
from utils.hadamard_utils import random_hadamard_matrix, hadamard_matrix
from utils.process_args import process_args_ptq

# model_id = "meta-llama/Llama-3.1-8B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "meta-llama/Llama-3.2-1B-Instruct"
dtype = torch.bfloat16


class RotateModule(nn.Module):
    def __init__(self, R_init):
        super(RotateModule, self).__init__()
        self.weight = nn.Parameter(R_init.to(torch.float32).to(torch.device("cuda")))

    def forward(self, x, transpose=False):
        if transpose:
            return x @ self.weight
        else:
            return self.weight @ x


def get_sq_model(
    r1r2=Literal["eye", "random-hadamard", "hadamard"],
    w_bits=Literal[4, 16],
    w_clip: bool = False,
) -> LlamaForCausalLMQuant:
    model_args, training_args, ptq_args = process_args_ptq()
    model_args.input_model = model_id
    if w_bits == 4:
        ptq_args.w_bits = 4
        ptq_args.w_groupsize = 128
        ptq_args.w_rtn = True  # if False, GPTQ is used
        ptq_args.w_clip = w_clip
    ptq_args.a_bits = 16
    ptq_args.k_bits = 16
    ptq_args.v_bits = 16
    print("=======ARGS=======", ptq_args)

    config = transformers.AutoConfig.from_pretrained(model_args.input_model)
    # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings,
    # clone lm_head from embed_tokens
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True
    model = LlamaForCausalLMQuant.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        device_map="cuda",
    )
    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

    model = prepare_model(ptq_args, model)
    for param in model.parameters():
        param.requires_grad = False

    match r1r2:
        case "eye":
            R1 = torch.eye(model.config.hidden_size, device="cuda")
        case "random-hadamard":
            R1 = random_hadamard_matrix(model.config.hidden_size, "cuda")
        case _:
            R1 = hadamard_matrix(model.config.hidden_size, "cuda")
    model.R1 = RotateModule(R1)

    for i in range(model.config.num_hidden_layers):
        # Each head dim = 128 for Llama model
        match r1r2:
            case "eye":
                R2 = torch.eye(
                    model.config.hidden_size // model.config.num_attention_heads,
                    device="cuda",
                )
            case "random-hadamard":
                R2 = random_hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
            case _:
                R2 = hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
        model.model.layers[i].self_attn.R2 = RotateModule(R2)

    model.config.use_cache = False
    return model


def get_lc_model(
    r1r2=Literal["eye", "random-hadamard", "hadamard"], w_bits=Literal[4, 16]
) -> LlamaForCausalLM:
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.modifiers.transform import SpinQuantModifier

    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_id,
        torch_dtype=dtype,
        device_map="cuda",
    )

    recipe = [
        SpinQuantModifier(
            rotations=[] if r1r2 == "eye" else ["R1", "R2"],
            transform_type="hadamard",
        )
    ]
    if w_bits == 4:
        recipe.append(
            QuantizationModifier(
                targets="Linear",
                scheme="W4A16",
                ignore=["lm_head"],
            )
        )

    oneshot(
        model=model,
        recipe=recipe,
        pipeline="datafree",
        log_dir=None,
    )

    return model


if __name__ == "__main__":
    for scales_impl in ["sq_min_hack", "lc_min_hack"]:
        for r1r2 in ["eye", "hadamard"]:
            for sq_lc in ["sq", "lc"]:
                w_bits = 4
                os.environ["SCALES_IMPL"] = scales_impl

                model = (
                    get_sq_model(r1r2=r1r2, w_bits=w_bits)
                    if sq_lc == "sq"
                    else get_lc_model(r1r2=r1r2, w_bits=w_bits)
                ).to("cuda")

                SAVE_DIR = model_id.split("/")[1] + f"-{scales_impl}-{r1r2}-w4a16"
                model.save_pretrained(SAVE_DIR, save_compressed=True)
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )
                tokenizer.save_pretrained(SAVE_DIR)

                del model
                del tokenizer
                torch.cuda.empty_cache()

                results = lm_eval.simple_evaluate(
                    # 1) hf in-memory
                    # model=lm_eval.models.huggingface.HFLM(
                    #     pretrained=model,
                    #     batch_size=32,
                    #     add_bos_token=False,
                    # ),
                    # 1/)
                    # 2) vllm serialized
                    model="vllm",
                    model_args={
                        "pretrained": SAVE_DIR,
                        "add_bos_token": False,
                        "dtype": "auto",
                        "max_model_len": 4096,
                        "gpu_memory_utilization": 0.5,
                        "enable_chunked_prefill": True,
                    },
                    # 2/)
                    # 3) hf serialized
                    # model="hf",
                    # model_args={
                    #     "pretrained": SAVE_DIR,
                    #     "add_bos_token": False,
                    #     "dtype": "auto",
                    # },
                    # device="cuda",
                    # 3/)
                    tasks=["gsm8k_llama", "gsm8k", "arc_challenge_llama"],
                    num_fewshot=8,
                    batch_size=32,
                    apply_chat_template=True,
                    fewshot_as_multiturn=True,
                )
                print(
                    f"RESULTS, {model_id} {sq_lc} R1R2 {r1r2} W_BITS {w_bits} SCALEIMPL {scales_impl}"
                )
                print(lm_eval.utils.make_table(results))
```

</details>

## Follow Ups ##
* Infer the data-free pipeline, even if a transform modifier is included
* Rotations R3 and R4
* Modify the example to use GPTQ once basic evaluation has been performed

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
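As an aside on the LC+SQscales swap described above, here is a small self-contained sketch (not from either PR) of how the two symmetric-scale denominators differ for 4-bit weights. The definitions of `bit_max` and `bit_range` are assumed to be the usual signed-integer bounds; the actual helper in `compressed-tensors` may name or derive them differently:

```python
# Sketch comparing the two symmetric-scale denominators for 4-bit weights.
# Assumes bit_max = 2**(num_bits - 1) - 1 and bit_range = 2**num_bits - 1,
# the conventional bounds for a signed integer range.
num_bits = 4
bit_max = 2 ** (num_bits - 1) - 1  # 7
bit_range = 2 ** num_bits - 1      # 15

max_val_pos = 1.0  # e.g. largest absolute weight value in a quantization group

lc_scale = max_val_pos / (float(bit_range) / 2)  # llm-compressor: divide by 7.5
sq_scale = max_val_pos / float(bit_max)          # SpinQuant/GPTQ-style: divide by 7

print(lc_scale, sq_scale)  # 0.1333..., 0.1428... -> LC scales are slightly smaller
```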
## Purpose ##
* Serialize transforms with tied weights

## Prerequisites ##
* `apply_transform_config` #348

## Semi-Prerequisites ##
* These changes are required in order to support saving models with offloaded Transforms. Models without offloading do not require these changes.
## Changes ##
* Added `_update_tied_weights`, which sets the `_dynamic_tied_weights_keys` attribute of the transform modules. This property is read by `transformers` during saving and tells `transformers` to delete duplicates of the tied weights before saving.
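For illustration only, a minimal sketch of the tied-weights idea described above. The `HadamardTransform` class and its constructor are hypothetical; only the `_dynamic_tied_weights_keys` attribute comes from this PR, and the exact way `transformers` consumes it during `save_pretrained` may differ:

```python
import torch
from torch import nn


class HadamardTransform(nn.Module):
    """Hypothetical transform module whose rotation weight may be shared."""

    def __init__(self, weight: nn.Parameter):
        super().__init__()
        self.weight = weight  # may be the same Parameter object as another module's
        # Keys listed here signal that these tensors can be duplicates of tensors
        # owned elsewhere, so only one copy should be written to the checkpoint.
        self._dynamic_tied_weights_keys = ["weight"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


# Two transforms tying the same rotation matrix:
shared = nn.Parameter(torch.eye(4))
u = HadamardTransform(shared)
v = HadamardTransform(shared)
assert u.weight is v.weight  # tied: both state-dict entries point at one tensor
```

In this sketch, something like `_update_tied_weights` would be responsible for detecting which parameters ended up shared and filling in `_dynamic_tied_weights_keys` accordingly.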
## Testing ##