[CPU] Fix AWQ on CPU after refactoring #2688

Open
wants to merge 1 commit into base: main
97 changes: 60 additions & 37 deletions test/prototype/test_awq.py
@@ -5,14 +5,15 @@
# LICENSE file in the root directory of this source tree.
import copy
import tempfile
import unittest

import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.dtypes import Int4CPULayout
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
from torchao.utils import (
@@ -45,15 +46,15 @@ def forward(self, x):
return x


@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(),
reason="need to install fbgemm_gpu_genai package",
)
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_6,
reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig",
)
devices = ["cpu"]
if (
torch.cuda.is_available()
and _is_fbgemm_genai_gpu_available()
and TORCH_VERSION_AT_LEAST_2_6
):
devices.append("cuda")


class TestAWQ(TestCase):
def test_awq_config(self):
base_config = Int4WeightOnlyConfig()
@@ -68,8 +69,8 @@ def test_awq_config(self):
with self.assertRaisesRegex(ValueError, "is not one of"):
AWQConfig(base_config, step="not_supported")

def test_awq_functionality(self):
device = "cuda"
@parameterized.expand([(device,) for device in devices])
def test_awq_functionality(self, device):
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -80,13 +81,21 @@ def test_awq_functionality(self):
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

# baseline quantization
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
if device == "cuda":
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
)
torch.manual_seed(1234)
else:
assert False, "Unsupported device: {}".format(device)
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)

@@ -117,8 +126,8 @@ def test_awq_functionality(self):
loss_base = (ref_out - baseline_out).pow(2).mean().item()
assert loss_awq < loss_base

def test_awq_loading(self):
device = "cuda"
@parameterized.expand([(device,) for device in devices])
def test_awq_loading(self, device):
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -136,13 +145,20 @@ def test_awq_loading(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
if device == "cuda":
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
)
else:
assert False, "Unsupported device: {}".format(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

@@ -171,14 +187,14 @@ def test_awq_loading(self):
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

def test_awq_loading_vllm(self):
@parameterized.expand([(device,) for device in devices])
def test_awq_loading_vllm(self, device):
"""Simulate weight loading in vllm:
* prepare model weight to the same format (awq weight)
* use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint

There is also a slicing op that is omitted here; the overall end-to-end flow is tested in the vllm repo
"""
device = "cuda"
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
@@ -196,13 +212,20 @@ def test_awq_loading_vllm(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
if device == "cuda":
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
)
else:
assert False, "Unsupported device: {}".format(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

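Note: every test above repeats the same device-dependent base-config selection. The following is a minimal sketch of that selection factored into a helper for readability; make_base_config is a hypothetical name and not something this PR adds, but the constructor calls mirror the diff exactly.

```python
# Minimal sketch (not part of this PR): the per-device base-config selection used
# in the AWQ tests, factored into a hypothetical helper.
import torch

from torchao.dtypes import Int4CPULayout
from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig


def make_base_config(device: str, group_size: int):
    """Return the int4 weight-only base config the AWQ tests use for `device`."""
    if device == "cuda":
        # fbgemm int4 path: bf16 activations, int4 weights, per-group quantization
        return FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, group_size],
            preshuffle=False,
        )
    if device == "cpu":
        # CPU int4 layout consumed by the _weight_int4pack_mm_for_cpu tinygemm kernel
        return Int4WeightOnlyConfig(
            group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
        )
    raise ValueError(f"Unsupported device: {device}")
```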
23 changes: 23 additions & 0 deletions torchao/dtypes/uintx/int4_cpu_layout.py
@@ -30,6 +30,17 @@
aten = torch.ops.aten


def _same_metadata(self: "Int4CPUAQTTensorImpl", src: "Int4CPUAQTTensorImpl") -> bool:
return (
isinstance(self, Int4CPUAQTTensorImpl)
and isinstance(src, Int4CPUAQTTensorImpl)
and self.packed_weight.shape == src.packed_weight.shape
and self.scale_and_zero.shape == src.scale_and_zero.shape
and self.transposed == src.transposed
and type(self._layout) == type(src._layout)
)


@dataclass(frozen=True)
class Int4CPULayout(Layout):
"""Layout class for int4 CPU layout for affine quantized tensor, used by tinygemm kernels `_weight_int4pack_mm_for_cpu`.
@@ -208,6 +219,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)

if func is aten.copy_.default:
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)

raise NotImplementedError(
f"{cls.__name__} dispatch: attempting to run {func}, this is not supported"
)
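Note: the new aten.copy_.default branch is what allows a prepared int4 CPU weight to be overwritten in place by a checkpoint weight with matching metadata, which is the vLLM-style loading pattern exercised in the tests. Below is a minimal sketch of that pattern, under the assumption that the outer quantized tensor forwards copy_ down to its Int4CPUAQTTensorImpl; none of this code is part of the PR.

```python
# Minimal sketch (assumed usage, not part of the PR): in-place copy between two
# weights quantized with the same Int4CPULayout config. With this PR, the copy
# reaches the aten.copy_.default branch above once _same_metadata(...) is True.
import torch

from torchao.dtypes import Int4CPULayout
from torchao.quantization import Int4WeightOnlyConfig, quantize_

config = Int4WeightOnlyConfig(
    group_size=128, layout=Int4CPULayout(), set_inductor_config=False
)

src = torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16)  # "checkpoint" weights
dst = torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16)  # freshly initialized model
quantize_(src, config)
quantize_(dst, config)

# packed_weight / scale_and_zero shapes, the transposed flag, and the layout type all
# match, so the packed buffers are copied in place instead of raising a metadata error.
dst.weight.copy_(src.weight)

x = torch.randn(4, 256, dtype=torch.bfloat16)
assert torch.allclose(dst(x), src(x))
```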
94 changes: 71 additions & 23 deletions torchao/prototype/awq/example.py
@@ -93,7 +93,9 @@ def wiki2_eval(


# adapted from Hicham Badri (@mobicham)
def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
def benchmark(
model, tokenizer, max_length, tasks=None, evaluation_limit=None, device="cuda"
):
import lm_eval
import numpy as np

@@ -126,21 +128,33 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("truthfulqa_mc2", 0)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
if "winogrande" in tasks:
for task in [("winogrande", 5)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
if "arc_challenge" in tasks:
for task in [("arc_challenge", 25)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])

@@ -149,14 +163,22 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("hellaswag", 10)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
if "gsm8k" in tasks:
for task in [("gsm8k", 5)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
# ############################################
@@ -167,7 +189,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("mmlu", 5)]:
tag, fewshot = task
results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results_mmlu[tag])

@@ -188,7 +214,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
for task in [("leaderboard_bbh", 3)]:
tag, fewshot = task
results[tag] = lm_eval.evaluator.simple_evaluate(
model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
model_eval,
tasks=[tag],
num_fewshot=fewshot,
batch_size=eval_batch_size,
limit=evaluation_limit,
)["results"]
print(tag, results[tag])
results["bbh"] = results[tag]
@@ -202,7 +232,7 @@ def quantize_and_eval(
tasks: list[str],
max_seq_length: int,
calibration_limit: int,
validation_size: int,
evaluation_limit: int,
device: str,
precision: torch.dtype,
compile: bool,
@@ -223,18 +253,26 @@ def quantize_and_eval(
if quant.startswith("awq-int4wo"):
group_size = int(quant.split("-")[2])
print(f"running {quant} quantization with group size {group_size}")
# TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
from torchao.quantization import FbgemmConfig
from torchao.dtypes import Int4CPULayout
from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig

# use_hqq = True
# base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
if device == "cuda":
# TODO: this is temporary, we'll be using Int4WeightOnlyConfig for CUDA soon
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
)
else:
assert False, "Unsupported device: {}".format(device)
print(f"running {quant} prepare and calibrate")
t0 = time.time()
quant_config = AWQConfig(base_config, step="prepare")
@@ -291,7 +329,14 @@ def quantize_and_eval(
if compile:
model = torch.compile(model)

return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
return benchmark(
model,
tokenizer,
max_seq_length,
tasks=tasks,
evaluation_limit=evaluation_limit,
device=device,
)


if __name__ == "__main__":
@@ -310,8 +355,8 @@ def quantize_and_eval(
"--tasks",
nargs="+",
type=str,
help="Task to benchmark model on. Either PPL or QA",
default=["PPL"],
help="Task to benchmark model on. Here is the list of tasks you can use: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md",
default=["hellaswag"],
)
parser.add_argument(
"--calibration_limit",
@@ -320,7 +365,10 @@ def quantize_and_eval(
help="Number of samples to use for calibration. Default is 10.",
)
parser.add_argument(
"--validation_size", type=int, default=1, help="Validation size. Default is 1."
"--evaluation_limit",
type=int,
default=None,
help="Number of samples to use for evaluation. Default is None (all).",
)
parser.add_argument(
"--device",
@@ -368,7 +416,7 @@ def quantize_and_eval(
args.tasks,
args.max_seq_length,
args.calibration_limit,
args.validation_size,
args.evaluation_limit,
args.device,
args.precision,
args.compile,
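Note: an illustrative call of the updated benchmark() helper showing how the new evaluation_limit argument is forwarded to lm_eval's simple_evaluate(limit=...). The model id and values below are placeholders for illustration, not something the PR prescribes.

```python
# Illustrative sketch (placeholder model id, not part of the PR): capping the number
# of evaluation samples per task via the new evaluation_limit argument.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.prototype.awq.example import benchmark

model_id = "facebook/opt-125m"  # placeholder; any causal LM works
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# evaluation_limit=None (the default) evaluates every sample; an integer is passed
# through to lm_eval.evaluator.simple_evaluate(limit=...) for each selected task.
results = benchmark(
    model,
    tokenizer,
    max_length=2048,
    tasks=["hellaswag"],
    evaluation_limit=200,
    device="cpu",
)
print(results)
```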