
Commit d3f845f

szyszyzys authored and facebook-github-bot committed

Add quantized tensor for groupwise lut based quantization (#2676)

Summary: Pull Request resolved: #2676
Reviewed By: metascroy
Differential Revision: D79119915

1 parent dc36108 · commit d3f845f

1 file changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import TorchAOBaseTensor

# --- C++ Op Accessor Functions ---


def get_pack_op(weight_nbit: int):
    """Gets the C++ packing function from the 'torchao' namespace."""
    op_name = f"_pack_groupwise_{weight_nbit}bit_weight_with_lut"
    if not hasattr(torch.ops.torchao, op_name):
        raise NotImplementedError(f"Packing op for {weight_nbit}-bit not found.")
    return getattr(torch.ops.torchao, op_name)


def get_linear_op(weight_nbit: int):
    """Gets the C++ fused linear function from the 'torchao' namespace."""
    op_name = f"_linear_groupwise_{weight_nbit}bit_weight_with_lut"
    if not hasattr(torch.ops.torchao, op_name):
        raise NotImplementedError(f"Linear op for {weight_nbit}-bit not found.")
    return getattr(torch.ops.torchao, op_name)
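
# For example (illustrative), get_pack_op(4) resolves to
# torch.ops.torchao._pack_groupwise_4bit_weight_with_lut and get_linear_op(4)
# resolves to torch.ops.torchao._linear_groupwise_4bit_weight_with_lut,
# assuming the torchao C++ kernels are built and registered for that bit width.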


aten = torch.ops.aten


class GroupwiseLutQuantizedTensor(TorchAOBaseTensor):
    """
    Tensor subclass that stores a single packed weight buffer for groupwise
    LUT-based quantization, structured to be robust under torch.export.
    """

    tensor_data_attrs = [
        "packed_weight",
    ]
    tensor_attributes = [
        "bit_width",
        "lut_group_size",
        "scale_group_size",
        "shape",
        "dtype",
    ]

    @staticmethod
    def __new__(
        cls,
        packed_weight: torch.Tensor,
        bit_width: int,
        lut_group_size: int,
        scale_group_size: int,
        shape: torch.Size,
        dtype: torch.dtype,
    ):
        kwargs = {
            "device": packed_weight.device,
            "dtype": dtype,
            "layout": packed_weight.layout,
            "requires_grad": False,
        }
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(
        self,
        packed_weight: torch.Tensor,
        bit_width: int,
        lut_group_size: int,
        scale_group_size: int,
        shape: torch.Size,
        dtype: torch.dtype,
    ):
        self.packed_weight = packed_weight
        self.bit_width = bit_width
        self.lut_group_size = lut_group_size
        self.scale_group_size = scale_group_size

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(shape={self.shape}, dtype={self.dtype}, "
            f"bit_width={self.bit_width}, lut_group_size={self.lut_group_size}, "
            f"scale_group_size={self.scale_group_size}, device={self.device})"
        )

    def __tensor_flatten__(self):
        metadata = [getattr(self, attr) for attr in self.tensor_attributes]
        return self.tensor_data_attrs, metadata

    @classmethod
    def __tensor_unflatten__(cls, tensors, metadata, outer_size, outer_stride):
        return cls(
            *[tensors[name] for name in cls.tensor_data_attrs],
            *metadata,
        )

    def _apply_fn_to_data(self, fn):
        new_packed_weight = fn(self.packed_weight)
        return self.__class__(
            new_packed_weight,
            self.bit_width,
            self.lut_group_size,
            self.scale_group_size,
            self.shape,
            self.dtype,
        )

    @classmethod
    def from_packed_data(
        cls,
        int_data: torch.Tensor,
        luts: torch.Tensor,
        scales: torch.Tensor,
        bit_width: int,
        lut_group_size: int,
        scale_group_size: int,
        original_shape: torch.Size,
        bias: Optional[torch.Tensor] = None,
        target: str = "auto",
    ):
        """
        A factory function that uses the C++ packing op to create an instance
        of the GroupwiseLutQuantizedTensor.
        """
        # 1. Get the correct C++ packing operator based on the bit width
        pack_op = get_pack_op(bit_width)

        # 2. Call the C++ op to get the single packed weight tensor
        packed_weight = pack_op(
            int_data,
            luts,
            scale_group_size,
            lut_group_size,
            scales,
            bias,
            target,
        )

        # 3. Construct and return the custom tensor object
        return cls(
            packed_weight,
            bit_width,
            lut_group_size,
            scale_group_size,
            original_shape,
            int_data.dtype,
        )
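
# Usage sketch (illustrative): `int_data`, `luts`, and `scales` are assumed to
# come from an external groupwise LUT quantization step, and the group sizes
# below are example values only.
#
#   qweight = GroupwiseLutQuantizedTensor.from_packed_data(
#       int_data,
#       luts,
#       scales,
#       bit_width=4,
#       lut_group_size=256,
#       scale_group_size=32,
#       original_shape=weight.shape,
#       bias=None,
#   )
#
# The resulting tensor can then be used as the `weight` argument of F.linear,
# which dispatches to the fused kernel via the override below.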


implements = GroupwiseLutQuantizedTensor.implements


@implements([F.linear])
def _(func, types, args, kwargs):
    """
    Override for `torch.nn.functional.linear`. This implementation calls the
    fused C++ kernel directly, avoiding a separate dequantization step.
    """
    # The optional bias argument is not forwarded to the kernel here; bias is
    # supplied to the packing op in `from_packed_data` instead.
    input_tensor, weight_tensor, _ = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )

    # Get the correct C++ operator based on the bit width
    linear_op = get_linear_op(weight_tensor.bit_width)

    # --- Input Reshaping Logic ---
    #
    # The underlying C++ kernel (`linear_op`) computes a matrix multiplication
    # on 2D tensors only, assuming a simple (m, k) layout. We therefore
    # "flatten" a higher-rank input into a 2D matrix before the call and
    # "unflatten" the 2D output afterwards to restore the batch dimensions.

    # Store original shape to reshape the output later
    original_shape = input_tensor.shape
    k = weight_tensor.shape[1]
    # If input rank > 2, flatten all batch dimensions into one
    if input_tensor.dim() > 2:
        input_tensor = input_tensor.reshape(-1, k)

    # The 'n' dimension is the output feature dimension from the weight
    n = weight_tensor.shape[0]

    # Call the fused C++ linear operator
    output = linear_op(
        input_tensor,
        weight_tensor.packed_weight,
        weight_tensor.scale_group_size,
        weight_tensor.lut_group_size,
        n,
        k,
    )

    # Reshape the output to match the original batch dimensions
    if len(original_shape) > 2:
        output_shape = original_shape[:-1] + (n,)
        return output.reshape(output_shape)

    return output
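
# Dispatch sketch (illustrative): once a module's weight has been replaced by a
# GroupwiseLutQuantizedTensor (e.g. built via `from_packed_data` above), a call
# such as
#
#   y = F.linear(x, qweight)
#
# routes through this override and invokes the fused groupwise LUT kernel, with
# any batch dimensions of `x` flattened and restored as described above.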


@implements([aten.detach.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@implements(aten.clone.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )
