Speed up nvfp4 pack/unpack w/ torch.compile #400
base: main
Conversation
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
Thanks for the contribution! Results look good.
Please test a compressed model in vLLM.
@dsikka Is there a particular test I should run? So far I've tested loading a compressed model into vLLM and generating text, and the output seems normal. I've also done small tests to confirm that the values saved exactly match the output from the previous version.
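A smoke test along those lines might look like the following (a minimal sketch; the model path and prompt are placeholders, not the exact setup used):

```python
from vllm import LLM, SamplingParams

# Placeholder path to a checkpoint packed with the new code
llm = LLM(model="path/to/compressed-model")
params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```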
@fynnsu You should do at least a quick accuracy evaluation using lm_eval, comparing an old model to a newly packed model.

```python
import lm_eval
from lm_eval.utils import make_table


def main():
    results = lm_eval.simple_evaluate(
        model="hf",
        model_args={
            "pretrained": "MODEL_ID",
            "add_bos_token": True,
            "dtype": "auto",
        },
        tasks="arc_challenge_llama",
        batch_size=128,
        apply_chat_template=True,
        fewshot_as_multiturn=True,
    )
    print(make_table(results))


if __name__ == "__main__":
    main()
```
Alternatively you can just check that the safetensors outputs are exactly the same, which is probably faster:

```python
import sys

import torch
from safetensors.torch import load_file


def compare_safetensors(file1, file2):
    data1 = load_file(file1)
    data2 = load_file(file2)

    all_keys = sorted(set(data1.keys()) | set(data2.keys()))
    differences = []

    for key in all_keys:
        if key not in data1:
            print(f"{key} missing in {file1}")
            differences.append(key)
        elif key not in data2:
            print(f"{key} missing in {file2}")
            differences.append(key)
        else:
            tensor1, tensor2 = data1[key], data2[key]
            if tensor1.shape != tensor2.shape:
                print(f"Shape mismatch for key {key}: {tensor1.shape} vs {tensor2.shape}")
                differences.append(key)
            elif not torch.allclose(tensor1, tensor2, rtol=1e-5, atol=1e-8):
                print(f"Difference found in key {key}: {torch.count_nonzero(abs(tensor1) < abs(tensor2))}")
                differences.append(key)
            else:
                print(f"Match: {key}")
    return differences


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python compare_safetensors.py <file1.safetensors> <file2.safetensors>")
        sys.exit(1)

    file1, file2 = sys.argv[1], sys.argv[2]
    diff_keys = compare_safetensors(file1, file2)
    if not diff_keys:
        print("All keys match exactly.")
    else:
        print(f"{len(diff_keys)} differing keys found.")
```
```diff
@@ -105,6 +105,7 @@ def decompress_weight(
         return decompressed_weight
 
 
+@torch.compile(fullgraph=True)
```
You may want to default to `dynamic=True` to avoid recompilation.
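For illustration, the suggested form of the decorator (a minimal sketch; the decorated function is a stand-in, not the repo's actual signature; `fullgraph` and `dynamic` are real `torch.compile` arguments):

```python
import torch


# dynamic=True asks the compiler for shape-polymorphic kernels, so tensors
# with new shapes reuse the same compiled graph instead of recompiling
# once per weight shape encountered.
@torch.compile(fullgraph=True, dynamic=True)
def pack_weight(weight: torch.Tensor) -> torch.Tensor:  # stand-in signature
    ...
```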
Added
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
I compared the outputs from the new version with the old version using both your script and … Note: I also ran these tests after the most recent change adding `dynamic=True`.
Great addition!
Applies `torch.compile` to the nvfp4 compressor as suggested in vllm-project/llm-compressor#1485. Speedups range anywhere from 3x to 25x depending on CPU/GPU.
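For context, a minimal sketch of what FP4 nibble packing/unpacking under `torch.compile` can look like (hypothetical helper names and layout, not the repo's actual implementation):

```python
import torch


# Hypothetical helpers: pack pairs of 4-bit codes (values 0-15) into uint8
# and unpack them again. Illustrative only.
@torch.compile(fullgraph=True, dynamic=True)
def pack_uint4_pairs(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor with an even last dimension, each value in [0, 16)
    low = codes[..., 0::2]
    high = codes[..., 1::2]
    return (high << 4) | low


@torch.compile(fullgraph=True, dynamic=True)
def unpack_uint4_pairs(packed: torch.Tensor) -> torch.Tensor:
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # interleave back to the original order: [low0, high0, low1, high1, ...]
    return torch.stack((low, high), dim=-1).flatten(-2)
```

A round-trip check such as `torch.equal(unpack_uint4_pairs(pack_uint4_pairs(x)), x)` is a cheap sanity test for this kind of change.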
Benchmarks

- Benchmark pack/unpack (new)
- Benchmark pack/unpack (old)

This also translates to real usage improvements:

```
time python examples/quantization_w4a16_fp4/llama3_example.py
```

- New
- Old
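As a rough illustration of how pack/unpack timings like these can be measured (a hypothetical harness, not the benchmark script behind the numbers above):

```python
import time

import torch


def bench(fn, x, iters=100):
    # First call triggers torch.compile's one-time compilation cost,
    # so it is excluded from the timed loop.
    fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


# Example (hypothetical): time the pack sketch from earlier on random codes.
# codes = torch.randint(0, 16, (4096, 4096), dtype=torch.uint8)
# print(f"{bench(pack_uint4_pairs, codes) * 1e3:.3f} ms/iter")
```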