
Commit 331cb3f

Merge remote-tracking branch 'origin' into kylesayrs/forbid-extra

2 parents aeb4862 + b2df366

File tree

34 files changed: +1604 −248 lines changed


.github/actions/test/action.yml

Lines changed: 5 additions & 5 deletions
@@ -69,11 +69,11 @@ runs:
         echo "::endgroup::"

         if [[ "${ENABLE_COVERAGE}" == "true" ]]; then
-            echo "::group::consolidating coverage reports"
-            mkdir -p coverage-results
-            mv .coverage coverage-results/ || echo ".coverage file not found"
-            mv coverage-html coverage-results/ || echo "coverage-html folder not found"
-            mv coverage.json coverage-results/ || echo "coverage.json file not found"
+            echo "::group::check coverage reports"
+            if [ ! -d coverage-html ]; then
+                echo "ERROR: coverage-html folder not found"
+                exit 1
+            fi
             echo "::endgroup::"
         fi

.github/workflows/build-test.yml

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,7 @@ on:

       # test related parameters
       test_configs:
-        description: "python, label, timeout"
+        description: "python, label, timeout, etc"
         type: string
         required: true
@@ -53,6 +53,7 @@ jobs:
       python: ${{ matrix.test_config.python }}
       timeout: ${{ matrix.test_config.timeout }}
       whl: ${{ needs.BUILD.outputs.whl }}
+      code_coverage: ${{ matrix.test_config.code_coverage || false }}
       secrets: inherit

   UPLOAD:

.github/workflows/test.yml

Lines changed: 18 additions & 4 deletions
@@ -70,6 +70,10 @@ jobs:
     permissions:
       contents: 'read'
       id-token: 'write'
+      pages: 'write'
+    environment:
+      name: github-pages
+      url: ${{ steps.coverage.outputs.page_url }}

     steps:
@@ -134,6 +138,11 @@ jobs:
           suitename: test-${{ inputs.python }}-${{ inputs.test_label }}
           code_coverage: ${{ inputs.code_coverage }}

+      - name: extra info for summary
+        if: ${{ inputs.code_coverage }}
+        run: |
+          echo "EXTRA='Code Coverage: https://neuralmagic.github.io/compressed-tensors/'" >> $GITHUB_ENV
+
       - name: summary
         uses: neuralmagic/nm-actions/actions/[email protected]
         if: success() || failure()
@@ -143,6 +152,7 @@ jobs:
           python: ${{ inputs.python }}
           whl: ${{ inputs.whl }}
           test_status: ${{ steps.test.outputs.status }}
+          extra: ${{ env.EXTRA }}

       - name: copy results to GCP
         run: |
@@ -157,9 +167,13 @@ jobs:
           retention-days: 5

       - name: upload coverage report
-        uses: actions/upload-artifact@v4
-        if: (success() || failure()) && inputs.code_coverage
+        uses: actions/upload-pages-artifact@v3
+        if: ${{ inputs.code_coverage }}
         with:
-          name: coverage-results
-          path: coverage-results/*
+          path: coverage-html
           retention-days: 5
+
+      - name: deploy to Github Pages
+        id: coverage
+        uses: actions/deploy-pages@v4
+        if: ${{ inputs.code_coverage }}

.github/workflows/trigger-all.yml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ jobs:
       wf_category: ${{ inputs.wf_category || 'NIGHTLY' }}
       gitref: ${{ inputs.gitref || 'main' }}
       push_to_pypi: ${{ (github.event.schedule == '30 0 * * *') || inputs.push_to_pypi || false }}
-      test_configs: '[{"python":"3.11.4","label":"ubuntu-24.04","timeout":"40"},
+      test_configs: '[{"python":"3.11.4","label":"ubuntu-24.04","timeout":"40","code_coverage":true},
                      {"python":"3.10.12","label":"ubuntu-22.04","timeout":"40"},
                      {"python":"3.9.17","label":"k8s-h100-solo","timeout":"40"},
                      {"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]'

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 18 additions & 8 deletions
@@ -392,15 +392,18 @@ def compress_model(self, model: Module):
         for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):

             if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                module_device = get_execution_device(module).type
-                is_meta = module_device == "meta"
+                module_device = get_execution_device(module)
+                is_meta = module_device.type == "meta"

                 exec_device = "meta" if is_meta else "cpu"
                 onloading_device = "meta" if is_meta else module_device

                 # in the future, support compression on same device
                 with align_module_device(module, execution_device=exec_device):
-                    state_dict = module.state_dict(prefix=f"{prefix}.")
+                    state_dict = {
+                        f"{prefix}.{name}": param
+                        for name, param in module.named_parameters(recurse=False)
+                    }

                 # quantization first
                 if prefix in module_to_scheme:
@@ -421,7 +424,7 @@ def compress_model(self, model: Module):

                 # remove any existing parameters
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with compressed parameters
@@ -458,7 +461,10 @@ def decompress_model(self, model: Module):
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 # in the future, support decompression on same device
                 with align_module_device(module, execution_device="cpu"):
-                    state_dict = module.state_dict(prefix=f"{prefix}.")
+                    state_dict = {
+                        f"{prefix}.{name}": param
+                        for name, param in module.named_parameters(recurse=False)
+                    }

                 # sparsity first
                 if prefix in sparse_compression_targets:
@@ -483,7 +489,7 @@ def decompress_model(self, model: Module):
                 # remove any existing parameters
                 exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with decompressed parameters
@@ -747,12 +753,16 @@ def _replace_weights(self, dense_weight_generator, model: Module):

 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
-    Returns a dictionary which maps quantized module names to their quantization schemes
+    Returns a dictionary which maps quantized module names to their quantization
+    schemes. Only includes modules with weight quantization
     """
     return {
         fix_fsdp_module_name(name): module.quantization_scheme
         for name, module in model.named_modules()
-        if is_module_quantized(module)
+        if (
+            hasattr(module, "quantization_scheme")
+            and module.quantization_scheme.weights is not None
+        )
     }
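
Note: the switch from state_dict(prefix=...) to named_parameters(recurse=False) matters because state_dict recurses into submodules. A minimal standalone sketch of the difference, using toy modules rather than the library's code:

import torch
from torch import nn

parent = nn.Module()
parent.weight = nn.Parameter(torch.zeros(2, 2))
parent.child = nn.Linear(2, 2)

# state_dict(prefix=...) recurses into children and returns detached tensors
print(list(parent.state_dict(prefix="p.").keys()))
# ['p.weight', 'p.child.weight', 'p.child.bias']

# named_parameters(recurse=False) yields only the module's own parameters,
# so child modules are not pulled into the parent's compression pass
print([f"p.{name}" for name, _ in parent.named_parameters(recurse=False)])
# ['p.weight']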

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 21 additions & 0 deletions
@@ -61,6 +61,27 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_global_scale",
         )

+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        output = {
+            "weight_packed": (
+                torch.Size((weight_shape[0], weight_shape[1] // 2)),
+                torch.uint8,
+            ),
+        }
+        return output
+
     def compress_weight(
         self,
         weight: Tensor,
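
Note: the weight_shape[1] // 2 width reflects two 4-bit FP4 codes packed per uint8 byte. A minimal sketch of that packing arithmetic; pack_fp4_codes is a hypothetical stand-in, not the compressor's kernel, and the nibble order is an assumption:

import torch

def pack_fp4_codes(codes: torch.Tensor) -> torch.Tensor:
    # pack a (rows, cols) tensor of 4-bit codes into (rows, cols // 2) uint8
    assert codes.shape[-1] % 2 == 0, "column count must be even to pack pairs"
    low = codes[..., 0::2] & 0x0F   # even columns -> low nibble (assumed order)
    high = codes[..., 1::2] & 0x0F  # odd columns -> high nibble (assumed order)
    return ((high << 4) | low).to(torch.uint8)

codes = torch.randint(0, 16, (4, 8))  # 4-bit codes, stored as int64 here
packed = pack_fp4_codes(codes)
print(packed.shape)  # torch.Size([4, 4]) == (4, 8 // 2)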

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 81 additions & 21 deletions
@@ -111,11 +111,27 @@ def dequantize(
     elif scale.ndim == 2:
         if scale.shape[1] == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-        else:
+        # Scale height matches input or is 1 -> group quantization across columns
+        #
+        # Example 1: scale.shape[0] == 1
+        #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+        #
+        # Example 2: scale.shape[0] == x_q.shape[0]
+        #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+        elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
             group_size = int(x_q.shape[1] / scale.shape[1])
             args = QuantizationArgs(
                 strategy=QuantizationStrategy.GROUP, group_size=group_size
             )
+        else:
+            rows, cols = x_q.shape[-2], x_q.shape[-1]
+            block_height = rows // scale.shape[0]  # Rows per block
+            block_width = cols // scale.shape[1]   # Columns per block
+
+            args = QuantizationArgs(
+                strategy=QuantizationStrategy.BLOCK,
+                block_structure=[block_height, block_width],
+            )
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "
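
Note: a quick shape-only sanity check of the inference rule above, in plain Python independent of the library:

# x_q is (4, 8); classify by scale shape exactly as the new branches do
rows, cols = 4, 8
for scale_shape, expected in [
    ((4, 1), "channel"),  # one scale per output channel
    ((1, 4), "group"),    # group_size = 8 // 4 = 2
    ((4, 4), "group"),    # per-row groups of 2
    ((2, 2), "block"),    # block_structure = [4 // 2, 8 // 2] = [2, 4]
]:
    s0, s1 = scale_shape
    if s1 == 1:
        strategy = "channel"
    elif s0 == 1 or s0 == rows:
        strategy = "group"
    else:
        strategy = "block"
    assert strategy == expected
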
@@ -189,14 +205,67 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-        n_dims = x.shape
-        if len(n_dims) > 2:
-            x = x.squeeze(0)
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                f"Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                f"Block quantization requires exact division."
+            )
+
+        # reshape into blocks and transpose to make each block contiguous
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):

         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-        columns = output.shape[1]
+        columns = output.shape[-1]

         # TODO: make validation step for inputs
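
Note: a minimal standalone sketch of the block reshaping above, with torch.round standing in for _quantize. A (4, 8) tensor with block_structure [2, 4] splits into a (2, 2, 2, 4) grid of blocks, the per-block scale broadcasts via two trailing unit dims, and the transpose/reshape round-trips exactly:

import torch

x = torch.arange(32, dtype=torch.float32).reshape(4, 8)
block_height, block_width = 2, 4
scale = torch.full((4 // block_height, 8 // block_width), 0.5)  # one scale per block

x_blocks = x.reshape(2, block_height, 2, block_width).transpose(1, 2)
sb = scale.unsqueeze(-1).unsqueeze(-1)    # (2, 2, 1, 1) broadcasts over each block
x_q = torch.round(x_blocks / sb)          # stand-in for _quantize
restored = (x_q * sb).transpose(1, 2).reshape(x.shape)
assert torch.equal(restored, x)           # exact round-trip for this toy scale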

@@ -226,14 +295,12 @@ def _process_quantization(
         perm = torch.argsort(g_idx)
         x = safe_permute(x, perm, dim=1)

-        x = torch.reshape(
-            x,
-            (
-                x.shape[0],
-                ceil(x.shape[1] / group_size),
-                group_size,
-            ),
+        # Maintain all dimensions apart from the last dim, which is divided by the group_size
+        reshaped_dims = (
+            ceil(x.shape[-1] / group_size),
+            group_size,
         )
+        x = x.unflatten(-1, reshaped_dims)

         if do_quantize:
             output = _quantize(
@@ -256,19 +323,12 @@ def _process_quantization(
                 global_scale=global_scale,
             )

-        output = torch.reshape(
-            output,
-            (output.shape[0], output.shape[1] * output.shape[2]),
-        )
-
+        output = output.flatten(start_dim=-2)
         output = output.to(output_dtype)

         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)

-        if len(n_dims) > 2:
-            output = output.unsqueeze(0)
-
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
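
Note: the unflatten/flatten pair generalizes grouping to tensors of any rank by operating on the last dim, replacing the old squeeze/unsqueeze of a leading batch dim. A small sketch:

import torch

x = torch.randn(2, 4, 8)  # e.g. a batched activation tensor
group_size = 4
grouped = x.unflatten(-1, (8 // group_size, group_size))  # (2, 4, 2, 4)
assert torch.equal(grouped.flatten(start_dim=-2), x)      # exact round-trip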

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 31 additions & 1 deletion
@@ -15,6 +15,7 @@

 import logging
 import math
+import warnings
 from enum import Enum
 from typing import List, Optional
@@ -172,14 +173,43 @@ def _initialize_scale_zero_point(

     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1)
+            # (output_channels, 1) - only for weights
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy in (
             QuantizationStrategy.TENSOR_GROUP,
             QuantizationStrategy.GROUP,
         ):
+            # GROUP/TENSOR_GROUP for both weights and activations
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
+        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+            # For block quantization, scale shape should match number of blocks - only for weights
+            if quantization_args.block_structure is None:
+                raise ValueError(
+                    "Block quantization requires block_structure to be specified"
+                )
+            block_height, block_width = quantization_args.block_structure
+            rows, cols = weight_shape[-2], weight_shape[-1]
+            num_rows_blocks = math.ceil(rows / block_height)
+            num_cols_blocks = math.ceil(cols / block_width)
+
+            # Warn if dimensions don't divide evenly
+            if rows % block_height != 0 or cols % block_width != 0:
+                warnings.warn(
+                    f"Block quantization: tensor shape {weight_shape} does not divide evenly "
+                    f"by block structure {quantization_args.block_structure}. "
+                    f"Some blocks will be incomplete which may affect quantization quality.",
+                    UserWarning,
+                )
+
+            expected_shape = (num_rows_blocks, num_cols_blocks)
+    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+        warnings.warn(
+            f"BLOCK quantization not supported for {base_name} activations. "
+            f"Falling back to tensor-level quantization.",
+            UserWarning,
+        )
+        expected_shape = 1

     # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
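
Note: the expected-shape arithmetic above, worked for a hypothetical (512, 1024) weight with block_structure [128, 128]:

import math

weight_shape = (512, 1024)
block_height, block_width = 128, 128
expected_shape = (
    math.ceil(weight_shape[-2] / block_height),  # 4 row blocks
    math.ceil(weight_shape[-1] / block_width),   # 8 column blocks
)
print(expected_shape)  # (4, 8) -> one scale (and zero point) per block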
