from typing import List, Optional

import torch
-from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    TorchAOBaseTensor,
-    fill_defaults,
)

__all__ = [
@@ -42,12 +40,12 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
    int4 quantization with preshuffled packing format (for all granularities)

    Tensor Attributes:
-        _data: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
+        qdata: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
          preshuffling is specific to fbgemm kernels, see Note for motivation, detailed layout doc is WIP
        for bf16 activation:
-          group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+          group_scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
            dtype is the same as the original Tensor dtype
-          group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+          group_zero: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
            dtype is the same as the original Tensor dtype
        for float8 activation:
          group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor
@@ -57,9 +55,6 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):

    Non-Tensor Attributes:
        block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
-        shape_multiplier: the multiplier from _data to the real weight shape, since
-          we pack the weight for int4; for example, when we pack the last dimension of
-          a 2D tensor, the shape_multiplier will be [1, 2]
        shape: shape of the original Tensor

    Note on details of preshuffling for the fbgemm kernel:
@@ -80,104 +75,48 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
        requires symmetric quantization
    """

-    tensor_data_attrs = ["_data", "group_scale"]
-    tensor_attributes = ["block_size", "shape_multiplier", "shape"]
+    tensor_data_names = ["qdata", "group_scale"]
+    optional_tensor_data_names = ["group_zero", "row_scale"]
+    tensor_attribute_names = ["block_size", "shape"]

    def __new__(
        cls,
-        _data,
+        qdata,
        group_scale,
        group_zero,
        row_scale,
        block_size,
-        shape_multiplier,
        shape,
    ):
        kwargs = {}
-        kwargs["device"] = _data.device
+        kwargs["device"] = qdata.device
        kwargs["dtype"] = group_scale.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
-        _data: torch.Tensor,
+        qdata: torch.Tensor,
        group_scale: torch.Tensor,
        group_zero: Optional[torch.Tensor],
        row_scale: Optional[torch.Tensor],
        block_size: List[int],
-        shape_multiplier: List[int],
        shape: List[int],
    ):
        # one and only one of group_zero and row_scale should be None
        assert group_zero is None or row_scale is None
        assert not (group_zero is not None and row_scale is not None)
-        self._data = _data
+        self.qdata = qdata
        self.group_scale = group_scale
        self.group_zero = group_zero
        self.row_scale = row_scale
-        self.shape_multiplier = shape_multiplier
        self.block_size = block_size

-    def __tensor_flatten__(self):
-        if getattr(self, "group_zero") is None:
-            assert getattr(self, "row_scale") is not None
-            return self.tensor_data_attrs + ["row_scale"], [
-                getattr(self, attr) for attr in self.tensor_attributes
-            ]
-        else:
-            return self.tensor_data_attrs + ["group_zero"], [
-                getattr(self, attr) for attr in self.tensor_attributes
-            ]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        tensors = [tensor_data_dict[name] for name in cls.tensor_data_attrs]
-        tensors.append(tensor_data_dict.get("group_zero", None))
-        tensors.append(tensor_data_dict.get("row_scale", None))
-        return cls(
-            *tensors,
-            *tensor_attributes,
-        )
-
-    def _apply_fn_to_data(self, fn):
-        tensors = [fn(getattr(self, name)) for name in self.tensor_data_attrs]
-        t1 = getattr(self, "group_zero")
-        tensors.append(fn(t1) if t1 is not None else None)
-        t2 = getattr(self, "row_scale")
-        tensors.append(fn(t2) if t2 is not None else None)
-        return self.__class__(
-            *tensors,
-            *[getattr(self, attr) for attr in self.tensor_attributes],
-        )
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}(weight={self._data}, block_size={self.block_size}, "
-            f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, device={self.device}, dtype={self.dtype}, "
-            f"requires_grad={self.requires_grad})"
-        )
-
    def _quantization_type(self):
        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

-    def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        device = kwargs.pop("device")
-        return self.__class__(
-            self._data.to(device),
-            self.group_scale.to(device),
-            self.group_zero.to(device) if self.group_zero is not None else None,
-            self.row_scale.to(device) if self.row_scale is not None else None,
-            self.block_size,
-            self.shape_multiplier,
-            self.shape,
-        )
-
    @classmethod
-    def from_float(
+    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
@@ -237,17 +176,12 @@ def from_float(
            group_zero = None
            row_scale = group_zero_or_row_scale

-        shape_multiplier = [1] * wq.ndim
-        shape_multiplier[-1] = 2
-
-        del w
        return Int4PreshuffledTensor(
-            _data=wq,
+            qdata=wq,
            group_scale=group_scale,
            group_zero=group_zero,
            row_scale=row_scale,
            block_size=block_size,
-            shape_multiplier=shape_multiplier,
            shape=original_shape,
        )

@@ -265,7 +199,7 @@ def _(func, types, args, kwargs):
    orig_input_size = input_tensor.size()
    orig_out_features = weight_tensor.shape[-2]

-    wq = weight_tensor._data.contiguous()
+    wq = weight_tensor.qdata.contiguous()
    group_scale = weight_tensor.group_scale.contiguous()
    # bf16 activation
    if weight_tensor.group_zero is not None:
@@ -295,16 +229,17 @@ def _(func, types, args, kwargs):
    )
    orig_input_size = input_tensor.size()
    orig_out_features = weight_tensor.shape[-2]
-    assert weight_tensor.shape_multiplier[-1] == 2

-    wq = weight_tensor._data.contiguous()
+    wq = weight_tensor.qdata.contiguous()
    group_scale = weight_tensor.group_scale.contiguous()
    if weight_tensor.group_zero is not None:
+        # bfloat16 activation
        group_zero = weight_tensor.group_zero.contiguous()
        res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
            input_tensor, wq, group_scale, group_zero
        )
    else:
+        # fp8 activation
        assert weight_tensor.row_scale is not None
        row_scale = weight_tensor.row_scale.contiguous()
        xq, x_scale = quantize_fp8_row(input_tensor)
@@ -322,125 +257,6 @@ def _(func, types, args, kwargs):
    return res


-@implements([aten.detach.default, aten.alias.default])
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-    )
-
-
-@implements(aten.clone.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-    )
-
-
-def _same_metadata(self: "Int4PreshuffledTensor", src: "Int4PreshuffledTensor") -> bool:
-    return (
-        isinstance(self, Int4PreshuffledTensor)
-        and isinstance(src, Int4PreshuffledTensor)
-        and self.shape == src.shape
-        and self._data.shape == src._data.shape
-        and self.group_scale.shape == src.group_scale.shape
-        and (
-            self.group_zero.shape == src.group_zero.shape
-            if self.group_zero is not None
-            else src.group_zero is None
-        )
-        and (
-            self.row_scale.shape == src.row_scale.shape
-            if self.row_scale is not None
-            else src.row_scale is None
-        )
-        and self.block_size == src.block_size
-        and self.shape_multiplier == src.shape_multiplier
-    )
-
-
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
-    )
-
-
-@implements(aten.cat.default)
-def _(func, types, args, kwargs):
-    tensors, dim = fill_defaults(args, 2, [[], 0])
-    tensor_0 = tensors[0]
-    if dim < 0:
-        dim = dim + tensor_0.ndim
-
-    for i in range(1, len(tensors)):
-        assert tensor_0._data.ndim == tensors[i]._data.ndim
-        assert tensor_0.group_scale.ndim == tensors[i].group_scale.ndim
-        assert tensor_0.group_zero.ndim == tensors[i].group_zero.ndim
-        assert tensor_0.block_size == tensors[i].block_size
-        assert tensor_0.shape_multiplier == tensors[i].shape_multiplier
-
-    _data = [t._data for t in tensors]
-    group_scale = [t.group_scale for t in tensors]
-    group_zero = [t.group_zero for t in tensors]
-
-    # with groupwise quantization, group_scale, _data and the original shape
-    # have the same number of dimensions, so the original dim argument applies
-    # to both _data and group_scale
-    cat_data = aten.cat.default(_data, dim)
-    if cat_data.ndim == 2:
-        sz_dim = 1 - dim
-    else:
-        sz_dim = dim
-
-    cat_group_scale = aten.cat.default(group_scale, sz_dim)
-    cat_group_zero = aten.cat.default(group_zero, sz_dim)
-    new_shape = list(cat_data.shape)
-    for i in range(len(tensor_0.shape_multiplier)):
-        new_shape[i] *= tensor_0.shape_multiplier[i]
-    new_shape = tuple(new_shape)
-    new = tensor_0.__class__(
-        cat_data,
-        cat_group_scale,
-        cat_group_zero,
-        block_size=tensor_0.block_size,
-        shape_multiplier=tensor_0.shape_multiplier,
-        shape=new_shape,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new)
-
-
-@implements(aten.transpose.int)
-def _(func, types, args, kwargs):
-    self, dim0, dim1 = args
-    _data = self._data.transpose(dim0, dim1).contiguous()
-    shape_multiplier = self.shape_multiplier.copy()
-    shape_multiplier[dim0], shape_multiplier[dim1] = (
-        shape_multiplier[dim1],
-        shape_multiplier[dim0],
-    )
-
-    tensor_shape = list(_data.shape)
-    for i in range(len(shape_multiplier)):
-        tensor_shape[i] *= shape_multiplier[i]
-    tensor_shape = tuple(tensor_shape)
-    new = self.__class__(
-        _data,
-        self.group_scale,
-        self.group_zero,
-        self.block_size,
-        shape_multiplier,
-        tensor_shape,
-    )
-    return return_and_correct_aliasing(func, args, kwargs, new)
-
-
Int4PreshuffledTensor.__module__ = "torchao.quantization"

if TORCH_VERSION_AT_LEAST_2_5:
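
A minimal usage sketch (not part of the diff) of the tensor after this change. It assumes a CUDA machine with fbgemm_gpu's int4 preshuffled kernels available, that Int4PreshuffledTensor is importable from torchao.quantization (inferred from the __module__ assignment above), and that from_hp with only a weight and block_size produces the bf16-activation layout (group_zero set, row_scale None); the shapes in the comments come from the class docstring.

import torch
import torch.nn.functional as F
from torchao.quantization import Int4PreshuffledTensor

N, K, group_size = 256, 512, 128
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# groupwise quantization: block_size (1, group_size)
qt = Int4PreshuffledTensor.from_hp(w, block_size=[1, group_size])

# Per the docstring, for a 2D weight with the bf16-activation layout:
#   qdata packs two int4 values per byte along K -> (N, K/2) = (256, 256)
#   group_scale / group_zero are (K/group_size, N) = (4, 256)
print(qt.qdata.shape, qt.group_scale.shape)
print(qt.shape)  # the wrapper still reports the original (N, K) = (256, 512)

# The linear override in this file routes F.linear to the fbgemm preshuffled kernel.
x = torch.randn(32, K, dtype=torch.bfloat16, device="cuda")
y = F.linear(x, qt)
print(y.shape)  # (32, 256)

The deletions above rely on the same design choice: once tensor_data_names, optional_tensor_data_names and tensor_attribute_names are declared, TorchAOBaseTensor can derive the __tensor_flatten__/__tensor_unflatten__, _apply_fn_to_data, to, detach, clone and repr plumbing from those names, so the hand-written versions removed in this diff are no longer needed.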