[WIP] TransformModifier #1518
Draft: kylesayrs wants to merge 5 commits into main from kylesayrs/transform-modifier
+271 −0
Changes from all commits (5 commits):

- ba617db: wip (kylesayrs)
- 2f5b1c8: use random-hadamard, add correctness tests (kylesayrs)
- 3aa35e7: add correctness test, note that precision makes a large difference (kylesayrs)
- b6c088e: add on lifecycle methods (brian-dellabetta)
- d1eb2a1: Merge branch 'main' into kylesayrs/transform-modifier (brian-dellabetta)
@@ -0,0 +1,84 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.transform import TransformModifier
from llmcompressor import oneshot

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


# Apply the model's chat template to each example.
def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the algorithms to run.
# * apply transforms to the model before quantizing
# * quantize the weights to 4 bit with GPTQ with a group size 128
recipe = [
    TransformModifier(),
    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    pipeline="sequential",
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
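A hedged follow-up, not part of the PR: assuming a transformers installation with compressed-tensors support, the compressed checkpoint saved above can typically be reloaded like any other Hugging Face model. The directory name below is simply what SAVE_DIR evaluates to for this MODEL_ID.

# Not part of the PR: reloading the compressed checkpoint saved above.
# Assumes compressed-tensors support is available in transformers.
from transformers import AutoModelForCausalLM

loaded = AutoModelForCausalLM.from_pretrained(
    "Meta-Llama-3-8B-Instruct-W4A16-G128",  # what SAVE_DIR evaluates to
    torch_dtype="auto",
)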
@@ -0,0 +1,3 @@
# flake8: noqa

from .transform import TransformModifier
@@ -0,0 +1,40 @@
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

QUIP = TransformConfig(
    config_groups={
        "v": TransformScheme(
            type="random-hadamard",
            apply=[
                TransformArgs(
                    targets=["Linear"],
                    location="input",  # non-mergable
                    ignore="lm_head",
                ),
                TransformArgs(
                    targets=["Linear"],
                    location="weight_input",
                    inverse=True,
                    ignore="lm_head",
                ),
            ],
            randomize=True,
        ),
        "u": TransformScheme(
            type="random-hadamard",
            apply=[
                TransformArgs(
                    targets=["Linear"],
                    location="weight_output",
                    ignore="lm_head",
                ),
                TransformArgs(
                    targets=["Linear"],
                    location="output",  # non-mergable
                    inverse=True,
                    ignore="lm_head",
                ),
            ],
            randomize=True,
        ),
    }
)
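For intuition: the two TransformArgs in each scheme form an inverse pair, so rotating one side by a Hadamard matrix and the other by its inverse leaves the layer's function unchanged in exact arithmetic; the correctness test later in this diff measures how far floating point drifts from that identity. A minimal sketch, not the compressed-tensors API; it uses a deterministic Sylvester Hadamard where the PR uses randomized ones.

import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # Deterministic Hadamard matrix (n must be a power of two),
    # normalized so that H @ H.T == I.
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H / torch.tensor(float(n)).sqrt()

n = 64
W = torch.randn(n, n, dtype=torch.float64)  # a Linear layer's weight
x = torch.randn(n, dtype=torch.float64)     # an input activation
H = sylvester_hadamard(n)

# "v" scheme: location="input" rotates the activation, while
# location="weight_input" with inverse=True folds H^-1 (= H.T) into the weight.
assert torch.allclose(W @ x, (W @ H.T) @ (H @ x))

# "u" scheme: location="weight_output" rotates the weight's output side, and
# location="output" with inverse=True undoes the rotation after the layer.
assert torch.allclose(W @ x, H.T @ ((H @ W) @ x))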
src/llmcompressor/modifiers/transform/template/spinquant.py (64 additions)
@@ -0,0 +1,64 @@
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

LLAMA_SPINQUANT = TransformConfig(
    config_groups={
        "R1": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(
                    targets=["embed_tokens", "o_proj", "down_proj"],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=[
                        "q_proj",
                        "k_proj",
                        "v_proj",
                        "up_proj",
                        "gate_proj",
                        "lm_head",
                    ],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        ),
        "R2": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(
                    targets=["v_proj"],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=["o_proj"], location="weight_input", inverse=True
                ),
            ],
        ),
        "R3": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(
                    targets=["self_attn"],
                    location="k_cache",
                ),
                TransformArgs(
                    targets=["self_attn"],
                    location="q_attn",
                ),
            ],
        ),
        "R4": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(
                    targets=["down_proj"],
                    location="input",
                ),
                TransformArgs(
                    targets=["down_proj"], location="weight_input", inverse=True
                ),
            ],
        ),
    }
)
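The R1 and R2 pairs are mergeable: R1 rotates the residual stream by folding the rotation into every module that writes to it (weight_output) and its inverse into every module that reads from it (weight_input with inverse=True), so every reader sees the original activations; R3 and R4 instead act on activations at runtime. A minimal sketch of the write/read cancellation, with illustrative names and a random orthogonal matrix standing in for the Hadamard rotation.

import torch

n = 16
Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
R1 = Q  # orthogonal stand-in for the Hadamard rotation

W_write = torch.randn(n, n, dtype=torch.float64)  # e.g. o_proj (writes to the stream)
W_read = torch.randn(n, n, dtype=torch.float64)   # e.g. q_proj (reads from the stream)
h = torch.randn(n, dtype=torch.float64)

# Original path: the writer's output lands in the residual stream,
# then the reader consumes it.
y_ref = W_read @ (W_write @ h)

# Rotated path: location="weight_output" folds R1 into the writer;
# location="weight_input" with inverse=True folds R1^-1 (= R1.T) into the reader.
y_rot = (W_read @ R1.T) @ ((R1 @ W_write) @ h)
assert torch.allclose(y_ref, y_rot)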
@@ -0,0 +1,51 @@
from typing import Dict, Optional

from compressed_tensors.transform import TransformScheme, apply_transform_config

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier

from .template.quip import QUIP


class TransformModifier(Modifier):
    preset_config: Optional[str] = None
    config_groups: Optional[Dict[str, TransformScheme]] = None

    # TODO: add a model validator to check that preset_config and
    # config_groups are not both provided

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.preset_config is not None:
            # import config template and customize to model
            pass

        # config = TransformConfig(config_groups=self.config_groups)
        config = QUIP  # currently hardcoded to the QUIP template

        apply_transform_config(state.model, config)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True
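A sketch of the validator the TODO above refers to, assuming Modifier is a pydantic model as elsewhere in llmcompressor; the method name is illustrative, and imports are as in the file above.

from pydantic import model_validator

class TransformModifier(Modifier):
    preset_config: Optional[str] = None
    config_groups: Optional[Dict[str, TransformScheme]] = None

    @model_validator(mode="after")
    def validate_config_source(self):
        # At most one of preset_config / config_groups may be provided.
        if self.preset_config is not None and self.config_groups is not None:
            raise ValueError(
                "Provide either preset_config or config_groups, not both"
            )
        return self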
tests/llmcompressor/modifiers/transform/test_correctness.py (29 additions)
@@ -0,0 +1,29 @@
import pytest
import torch
from compressed_tensors.transform import apply_transform_config
from transformers import AutoModelForCausalLM

from llmcompressor.modifiers.transform.template.quip import QUIP


@pytest.mark.parametrize(
    "dtype,exp_max,exp_mse",
    [
        # constructing and running transforms in float32 can improve the
        # bfloat16 results to (~0.6562, ~0.0055)
        (torch.bfloat16, 1.1, 0.012),
        (torch.float32, 4e-4, 2e-9),
    ],
)
def test_apply_correctness(dtype, exp_max, exp_mse):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype
    )

    input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()}
    with torch.no_grad():
        true_output = model(**input)

    apply_transform_config(model, QUIP)
    with torch.no_grad():
        output = model(**input)

    assert torch.max(true_output.logits - output.logits) <= exp_max
    assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse
Review comment:
This location isn't supported yet, correct?
https://github.com/neuralmagic/compressed-tensors/pull/334/files