
Commit a3cd59d

Merge remote-tracking branch 'origin' into kylesayrs/transform-merge

2 parents: ba85784 + 5478b43

File tree: 12 files changed (+291, -23 lines)

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 10 additions & 4 deletions

```diff
@@ -400,7 +400,10 @@ def compress_model(self, model: Module):

             # in the future, support compression on same device
             with align_module_device(module, execution_device=exec_device):
-                state_dict = module.state_dict(prefix=f"{prefix}.")
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }

             # quantization first
             if prefix in module_to_scheme:
@@ -421,7 +424,7 @@ def compress_model(self, model: Module):

                 # remove any existing parameters
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with compressed parameters
@@ -458,7 +461,10 @@ def decompress_model(self, model: Module):
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 # in the future, support decompression on same device
                 with align_module_device(module, execution_device="cpu"):
-                    state_dict = module.state_dict(prefix=f"{prefix}.")
+                    state_dict = {
+                        f"{prefix}.{name}": param
+                        for name, param in module.named_parameters(recurse=False)
+                    }

                 # sparsity first
                 if prefix in sparse_compression_targets:
@@ -483,7 +489,7 @@ def decompress_model(self, model: Module):
                 # remove any existing parameters
                 exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with decompressed parameters
```
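The switch from `module.state_dict(prefix=...)` to `named_parameters(recurse=False)` scopes (de)compression to a module's own parameters instead of sweeping in its submodules' tensors. A minimal plain-PyTorch sketch of the difference (the nested `child` module is hypothetical, for illustration only):

```python
from torch import nn

parent = nn.Linear(4, 4)
parent.child = nn.Linear(4, 4)  # hypothetical nested submodule

# state_dict() recurses into submodules, so the child's tensors leak in
print(sorted(parent.state_dict(prefix="layer.")))
# ['layer.bias', 'layer.child.bias', 'layer.child.weight', 'layer.weight']

# named_parameters(recurse=False) yields only this module's own parameters
own = {f"layer.{name}": p for name, p in parent.named_parameters(recurse=False)}
print(sorted(own))
# ['layer.bias', 'layer.weight']
```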

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 69 additions & 2 deletions

```diff
@@ -111,11 +111,22 @@ def dequantize(
         elif scale.ndim == 2:
             if scale.shape[1] == 1:
                 args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-            else:
+            # Scale height matches input or is 1 -> group quantization across columns
+            #
+            # Example 1: scale.shape[0] == 1
+            #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+            #
+            # Example 2: scale.shape[0] == x_q.shape[0]
+            #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+            elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                 group_size = int(x_q.shape[1] / scale.shape[1])
                 args = QuantizationArgs(
                     strategy=QuantizationStrategy.GROUP, group_size=group_size
                 )
+            else:
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+                )
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "
@@ -189,7 +200,63 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                f"Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                f"Block quantization requires exact division."
+            )
+
+        # reshape into blocks and transpose to make each block contiguous
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
```
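For intuition, a standalone sketch of the reshape-and-broadcast pattern the new BLOCK branch uses: the tensor is viewed as `(num_rows_blocks, num_cols_blocks, block_height, block_width)`, and two trailing `unsqueeze` calls let one scale per block broadcast over its tile. Shapes and the FP8-E4M3 clamp value of 448 are illustrative assumptions; the real path goes through `_quantize`/`_dequantize`:

```python
import torch

rows, cols = 256, 512
block_height, block_width = 128, 128
x = torch.randn(rows, cols)

# view as (num_rows_blocks, num_cols_blocks, block_height, block_width)
x_blocks = x.reshape(
    rows // block_height, block_height, cols // block_width, block_width
).transpose(1, 2)  # -> (2, 4, 128, 128)

# one scale per 128x128 tile; unsqueeze twice so it broadcasts over the tile
scale = x_blocks.abs().amax(dim=(-2, -1)) / 448.0  # 448 ~ FP8 E4M3 max
sb = scale.unsqueeze(-1).unsqueeze(-1)             # (2, 4) -> (2, 4, 1, 1)

x_q = (x_blocks / sb).clamp(-448.0, 448.0)             # scale (no rounding here)
x_dq = (x_q * sb).transpose(1, 2).reshape(rows, cols)  # restore original shape

assert torch.allclose(x, x_dq)  # symmetric round trip
```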

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 31 additions & 1 deletion

```diff
@@ -15,6 +15,7 @@

 import logging
 import math
+import warnings
 from enum import Enum
 from typing import List, Optional

@@ -172,14 +173,43 @@ def _initialize_scale_zero_point(

     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1)
+            # (output_channels, 1) - only for weights
            expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy in (
             QuantizationStrategy.TENSOR_GROUP,
             QuantizationStrategy.GROUP,
         ):
+            # GROUP/TENSOR_GROUP for both weights and activations
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
+        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+            # For block quantization, scale shape should match number of blocks - only for weights
+            if quantization_args.block_structure is None:
+                raise ValueError(
+                    "Block quantization requires block_structure to be specified"
+                )
+            block_height, block_width = quantization_args.block_structure
+            rows, cols = weight_shape[-2], weight_shape[-1]
+            num_rows_blocks = math.ceil(rows / block_height)
+            num_cols_blocks = math.ceil(cols / block_width)
+
+            # Warn if dimensions don't divide evenly
+            if rows % block_height != 0 or cols % block_width != 0:
+                warnings.warn(
+                    f"Block quantization: tensor shape {weight_shape} does not divide evenly "
+                    f"by block structure {quantization_args.block_structure}. "
+                    f"Some blocks will be incomplete which may affect quantization quality.",
+                    UserWarning,
+                )
+
+            expected_shape = (num_rows_blocks, num_cols_blocks)
+    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+        warnings.warn(
+            f"BLOCK quantization not supported for {base_name} activations. "
+            f"Falling back to tensor-level quantization.",
+            UserWarning,
+        )
+        expected_shape = 1

     # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
```
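A quick sketch of the scale-shape arithmetic above, with deliberately ragged hypothetical dimensions. Note the asymmetry with `_process_quantization` in forward.py: initialization only warns on inexact division, while quantization itself raises:

```python
import math

weight_shape = (300, 500)  # hypothetical, not block-aligned on purpose
block_height, block_width = 128, 128

num_rows_blocks = math.ceil(weight_shape[-2] / block_height)  # 3
num_cols_blocks = math.ceil(weight_shape[-1] / block_width)   # 4
expected_shape = (num_rows_blocks, num_cols_blocks)           # (3, 4)

# 300 % 128 == 44 and 500 % 128 == 116, so edge blocks are incomplete:
# _initialize_scale_zero_point emits a UserWarning here, but
# _process_quantization would raise for this weight at quantize time.
print(expected_shape)
```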

src/compressed_tensors/quantization/quant_args.py

Lines changed: 31 additions & 8 deletions

```diff
@@ -14,7 +14,7 @@

 import warnings
 from enum import Enum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     :param symmetric: whether or not quantization scale is symmetric about zero-point
     :param strategy: string id determining the scope of scale/zero-point to apply
     :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block strategy, must be
-        of the format "2x4", "8x16", etc.
+    :param block_structure: 2d block structure to use for the block strategy; must be
+        a list of two ints [rows, cols] like [128, 128].
     :param dynamic: set True to perform dynamic quantization - values will not be
         calibrated during calibration phase, instead during inference new quantization
         ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
+    block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ def validate_group(cls, value) -> Union[int, None]:

         return value

+    @field_validator("block_structure", mode="before")
+    def validate_block_structure(cls, value) -> Optional[List[int]]:
+        if value is None:
+            return value
+        # For backward compatibility, allow string format "2x4", "8x16", etc.
+        if isinstance(value, str):
+            try:
+                return [int(x) for x in value.split("x")]
+            except Exception:
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+        if isinstance(value, (list, tuple)):
+            if len(value) != 2 or not all(isinstance(v, int) for v in value):
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+            return list(value)
+        raise ValueError(
+            f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+        )
+
     @field_validator("strategy", mode="before")
     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
         if isinstance(value, str):
@@ -277,14 +299,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":

         # infer observer w.r.t. dynamic
         if dynamic:
-            if strategy not in (
+            supported_strategies = (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
                 QuantizationStrategy.TENSOR_GROUP,
-            ):
+                QuantizationStrategy.GROUP,
+            )
+            if strategy not in supported_strategies:
                 raise ValueError(
-                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
-                    "must be used for dynamic quantization",
+                    f"One of {supported_strategies} must be used for dynamic quantization"
                 )

         if (
```
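A short sketch of what the new validator accepts, assuming `QuantizationArgs` is importable from `compressed_tensors.quantization` as elsewhere in the library:

```python
from compressed_tensors.quantization import QuantizationArgs

# canonical form: a list of two ints [rows, cols]
args = QuantizationArgs(strategy="block", block_structure=[128, 128])

# legacy "RxC" strings are still parsed for backward compatibility
legacy = QuantizationArgs(strategy="block", block_structure="128x128")
assert legacy.block_structure == [128, 128]

# non-integer strings and wrong-length lists are rejected, e.g.:
# QuantizationArgs(strategy="block", block_structure="axb")      # ValueError
# QuantizationArgs(strategy="block", block_structure=[1, 2, 3])  # ValueError
```

As written, the string path does not re-check the length, so a bare `"128"` would slip through as `[128]`; only lists and tuples get the two-element check.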

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 42 additions & 0 deletions

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Optional

@@ -52,6 +53,7 @@ class QuantizationScheme(BaseModel):
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         inputs = model.input_activations
         outputs = model.output_activations
+        weights = model.weights

         if inputs is not None:
             if inputs.actorder is not None:
@@ -61,6 +63,22 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")

+        if (
+            inputs
+            and weights
+            and weights.strategy == QuantizationStrategy.GROUP
+            and inputs.strategy == QuantizationStrategy.GROUP
+            and weights.group_size != inputs.group_size
+        ):
+            warnings.warn(
+                "Using GROUP strategy for both weights and input_activations "
+                f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
+                "may complicate fused kernel implementations. Consider using "
+                "TENSOR_GROUP strategy for both or matching group sizes.",
+                UserWarning,
+                stacklevel=2,
+            )
+
         return model

@@ -243,6 +261,29 @@ def is_preset_scheme(name: str) -> bool:
     ),
 )

+# Block-wise FP8 (deepseekv3-style quantization):
+# static 128x128 per-block weights and
+# dynamic per-token-group activations
+FP8_BLOCK = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[128, 128],
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=True,
+        observer=None,
+        group_size=128,
+    ),
+)
+
 PRESET_SCHEMES = {
     # Unquantized (no-op)
     "UNQUANTIZED": UNQUANTIZED,
@@ -257,6 +298,7 @@ def is_preset_scheme(name: str) -> bool:
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
 }
```
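A brief usage sketch of the new preset, using `preset_name_to_scheme` from this module (the target name is an arbitrary example):

```python
from compressed_tensors.quantization import preset_name_to_scheme

# "Linear" as a target is illustrative only
scheme = preset_name_to_scheme("FP8_BLOCK", targets=["Linear"])

assert scheme.weights.block_structure == [128, 128]  # static per-block scales
assert scheme.input_activations.group_size == 128    # per-token-group scales
assert scheme.input_activations.dynamic              # computed at runtime
```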

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 11 additions & 2 deletions

```diff
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
-    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    elif args.strategy in (
+        QuantizationStrategy.TENSOR_GROUP,
+        QuantizationStrategy.GROUP,
+    ):
         if len(value.shape) > 2:
             value = value.squeeze(0)

@@ -187,9 +190,15 @@ def compute_dynamic_scales_and_zp(
             ),
         )
     else:
+        supported_strategies = (
+            QuantizationStrategy.TOKEN,
+            QuantizationStrategy.TENSOR,
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        )
         raise ValueError(
             "Dynamic quantization is only supported for ",
-            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
+            f"{supported_strategies}",
         )

     if not reduce_dims:
```
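For the GROUP case that now shares the TENSOR_GROUP path, reduction happens per group of columns. A minimal plain-torch sketch of the observed-extrema shapes (the 128-column groups match the FP8_BLOCK activation preset; the exact internals of `compute_dynamic_scales_and_zp` differ):

```python
import torch

group_size = 128
value = torch.randn(4, 512)                      # (tokens, hidden), hypothetical
grouped = value.unflatten(-1, (-1, group_size))  # -> (4, 4, 128)

# one extremum per token per group -> scales of shape (tokens, num_groups, 1)
observed_max = grouped.amax(dim=-1, keepdim=True)
observed_min = grouped.amin(dim=-1, keepdim=True)
print(observed_max.shape)  # torch.Size([4, 4, 1])
```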

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@
     align_module_device,
     delete_offload_module,
     has_offloaded_params,
+    match_named_modules,
     patch_attr,
     register_offload_module,
     update_offload_parameter,
```

tests/test_examples/test_bitmask_compression_ipynb.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import nbformat
 import pytest
+
+
+nbformat = pytest.importorskip("nbformat")
 from nbconvert.preprocessors import ExecutePreprocessor
```
