Speed up nvfp4 pack/unpack w/ torch.compile #400

Open: fynnsu wants to merge 2 commits into main

Conversation

@fynnsu commented Jul 22, 2025

Applies torch.compile to the nvfp4 compressor, as suggested in vllm-project/llm-compressor#1485.

Speedups range anywhere from 3x to 25x depending on CPU/GPU.
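
For context, the change boils down to decorating the elementwise pack/unpack routines with torch.compile so PyTorch fuses them into a single kernel instead of launching many small ops. The sketch below illustrates the pattern only; the function names and exact packing layout are assumptions, not the actual compressed-tensors implementation.

import torch

@torch.compile(fullgraph=True)
def pack_uint4_pairs(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor holding 4-bit values (0-15); pack two values per byte.
    assert codes.shape[-1] % 2 == 0
    low = codes[..., 0::2]
    high = codes[..., 1::2]
    return (high << 4) | low

@torch.compile(fullgraph=True)
def unpack_uint4_pairs(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_uint4_pairs: recover the interleaved 4-bit values.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack((low, high), dim=-1).flatten(start_dim=-2)

fullgraph=True asks the compiler to capture the whole function as one graph and error on graph breaks, rather than silently falling back to eager for parts of it.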

Benchmarks

Benchmark pack/unpack (new)
./benchmark_fp4_packing.sh 

Running benchmark on CPU...
Creating tensor of shape (8192, 8192) on cpu...
Benchmarking on tensor of shape torch.Size([8192, 8192]) (67,108,864 elements) on cpu
Iter 1/3 - Pack: 1.8949s
Iter 2/3 - Pack: 0.0454s
Iter 3/3 - Pack: 0.0464s
Iter 1/3 - Unpack: 0.0784s
Iter 2/3 - Unpack: 0.0374s
Iter 3/3 - Unpack: 0.0377s

Benchmark Results:
Device: cpu
Tensor shape: 8192x8192 (67,108,864 elements)
Average pack time: 0.6622s (101.34M elements/s)
Average unpack time: 0.0512s (1311.30M elements/s)
Compression ratio: 8.00x
--------------------------------
Running benchmark on GPU...
Creating tensor of shape (8192, 8192) on cuda...
Benchmarking on tensor of shape torch.Size([8192, 8192]) (67,108,864 elements) on cuda:0
Iter 1/3 - Pack: 0.4240s
Iter 2/3 - Pack: 0.0027s
Iter 3/3 - Pack: 0.0027s
Iter 1/3 - Unpack: 0.0459s
Iter 2/3 - Unpack: 0.0005s
Iter 3/3 - Unpack: 0.0005s

Benchmark Results:
Device: cuda:0
Tensor shape: 8192x8192 (67,108,864 elements)
Average pack time: 0.1431s (468.85M elements/s)
Average unpack time: 0.0156s (4299.47M elements/s)
Compression ratio: 8.00x

Benchmark pack/unpack (old)
./benchmark_fp4_packing.sh 

Running benchmark on CPU...
Creating tensor of shape (8192, 8192) on cpu...
Benchmarking on tensor of shape torch.Size([8192, 8192]) (67,108,864 elements) on cpu
Iter 1/3 - Pack: 1.1415s
Iter 2/3 - Pack: 1.0510s
Iter 3/3 - Pack: 0.9702s
Iter 1/3 - Unpack: 0.1212s
Iter 2/3 - Unpack: 0.0985s
Iter 3/3 - Unpack: 0.0991s

Benchmark Results:
Device: cpu
Tensor shape: 8192x8192 (67,108,864 elements)
Average pack time: 1.0542s (63.66M elements/s)
Average unpack time: 0.1062s (631.62M elements/s)
Compression ratio: 8.00x
--------------------------------
Running benchmark on GPU...
Creating tensor of shape (8192, 8192) on cuda...
Benchmarking on tensor of shape torch.Size([8192, 8192]) (67,108,864 elements) on cuda:0
Iter 1/3 - Pack: 0.1073s
Iter 2/3 - Pack: 0.0649s
Iter 3/3 - Pack: 0.0650s
Iter 1/3 - Unpack: 0.0081s
Iter 2/3 - Unpack: 0.0050s
Iter 3/3 - Unpack: 0.0050s

Benchmark Results:
Device: cuda:0
Tensor shape: 8192x8192 (67,108,864 elements)
Average pack time: 0.0791s (848.65M elements/s)
Average unpack time: 0.0060s (11118.17M elements/s)
Compression ratio: 8.00x
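
One thing to keep in mind when reading the per-iteration numbers: in the compiled (new) version the first iteration absorbs torch.compile's one-time compilation cost (e.g. 1.89s vs ~0.046s on CPU), so the steady-state iterations are the fair comparison, and the reported averages understate the steady-state gain. A minimal sketch of this kind of timing loop (the benchmark script itself isn't reproduced here, so names and structure are assumptions):

import time
import torch

def time_fn(fn, x: torch.Tensor, iters: int = 3) -> None:
    # Time fn(x) a few times; with torch.compile the first call also pays compilation cost.
    for i in range(iters):
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        print(f"Iter {i + 1}/{iters}: {time.perf_counter() - start:.4f}s")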

This also translates to real usage improvements:

time python examples/quantization_w4a16_fp4/llama3_example.py

New

...
(Model preparation and test generation output excluded)
...
Compressing model: 423it [00:20, 20.30it/s]

real    2m37.888s
user    10m28.019s
sys     1m19.335s

Old

...
(Model preparation and test generation output excluded)
...
Compressing model: 423it [01:27,  4.83it/s]

real    3m59.430s
user    22m39.828s
sys     6m37.413s

@brian-dellabetta (Contributor) left a comment

Thanks for the contribution! results look good

@dsikka (Collaborator) left a comment

please test compressed model in vllm

@fynnsu (Author) commented Jul 24, 2025

please test compressed model in vllm

@dsikka Is there a particular test I should run?

So far I've tested loading a compressed model into vllm and generating text and the output seems normal. I've also done small tests to confirm that the values saved exactly match the output from the previous version.
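
A minimal version of that vLLM smoke test (the checkpoint path is a placeholder; this is a sketch, not the exact script used) might look like:

from vllm import LLM, SamplingParams

# Load the compressed checkpoint in vLLM and generate a short completion.
# "path/to/compressed-model" is a placeholder for the saved output directory.
llm = LLM(model="path/to/compressed-model")
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)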

@fynnsu requested a review from dsikka on July 25, 2025 at 13:46
@kylesayrs (Contributor) commented

@fynnsu You should do at least a quick accuracy evaluation using lm_eval, comparing an old model to a newly packed model.

import lm_eval
from lm_eval.utils import make_table

def main():
    results = lm_eval.simple_evaluate(
        model="hf",
        model_args={
            "pretrained": "MODEL_ID",
            "add_bos_token": True,
            "dtype": "auto",
        },
        tasks="arc_challenge_llama",
        batch_size=128,
        apply_chat_template=True,
        fewshot_as_multiturn=True,
    )

    print(make_table(results))


if __name__ == "__main__":
    main()
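
Since the request above was specifically to test the compressed model in vLLM, the same evaluation can also be pointed at lm_eval's vLLM backend; a sketch of the changed call (assumes the vllm extra is installed, e.g. pip install lm_eval[vllm]):

import lm_eval
from lm_eval.utils import make_table

# Same evaluation through the vLLM backend instead of HF transformers.
# "MODEL_ID" remains a placeholder for the compressed checkpoint path.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={"pretrained": "MODEL_ID", "dtype": "auto"},
    tasks="arc_challenge_llama",
    batch_size=128,
    apply_chat_template=True,
    fewshot_as_multiturn=True,
)
print(make_table(results))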

@kylesayrs (Contributor) commented Jul 29, 2025

Alternatively you can just check that the safetensor outputs are exactly the same, which is probably faster

import sys
import torch
from safetensors.torch import load_file

def compare_safetensors(file1, file2):
    data1 = load_file(file1)
    data2 = load_file(file2)

    keys1 = set(data1.keys())
    keys2 = set(data2.keys())

    all_keys = sorted(keys1.union(keys2))
    differences = []

    for key in all_keys:
        if key not in data1:
            print(f"{key} missing in {file1}")
            differences.append(key)
        elif key not in data2:
            print(f"{key} missing in {file2}")
            differences.append(key)
        else:
            tensor1 = data1[key]
            tensor2 = data2[key]
            if tensor1.shape != tensor2.shape or not torch.allclose(tensor1, tensor2, rtol=1e-5, atol=1e-8):
                print(f"Difference found in key: {key}: {torch.count_nonzero(abs(tensor1) < abs(tensor2))}")
                differences.append(key)
            else:
                print(f"{key}: match")

    return differences

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python compare_safetensors.py <file1.safetensors> <file2.safetensors>")
        sys.exit(1)

    file1, file2 = sys.argv[1], sys.argv[2]
    diff_keys = compare_safetensors(file1, file2)

    if not diff_keys:
        print("All keys match exactly.")
    else:
        print(f"{len(diff_keys)} differing keys found.")

@@ -105,6 +105,7 @@ def decompress_weight(
return decompressed_weight


@torch.compile(fullgraph=True)
@mgoin (Member) left a comment

You may want to default to dynamic=True to avoid recompilation
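
Concretely, the suggestion amounts to something like the following (the function name is a placeholder, not the actual one in the diff); dynamic=True compiles with symbolic shapes, so layers with differently shaped weights reuse one compiled graph instead of each triggering a recompile:

import torch

@torch.compile(fullgraph=True, dynamic=True)
def pack_uint4_pairs(codes: torch.Tensor) -> torch.Tensor:
    # Same packing sketch as above; dynamic=True avoids per-shape recompilation.
    return (codes[..., 1::2] << 4) | codes[..., 0::2]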

@fynnsu (Author) replied

Added

@fynnsu (Author) commented Jul 30, 2025

Alternatively you can just check that the safetensor outputs are exactly the same, which is probably faster

I compared the outputs from the new version with the old version using both your script and diff -r on saved compressed directories. The outputs are exactly the same (including the safetensors files).

Note: I also ran these tests after the most recent change adding dynamic=True to the torch.compile calls as suggested by @mgoin.

@rahul-tuli (Member) left a comment

Great addition!
