Skip to content

Commit 3885950

Browse files
post-merge fixes
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 1876c3b commit 3885950

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,8 +754,8 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
754754
fix_fsdp_module_name(name): module.quantization_scheme
755755
for name, module in model.named_modules()
756756
if (
757-
hasattr(module, "quantization_scheme") and
758-
module.quantization_scheme.weights is not None
757+
hasattr(module, "quantization_scheme")
758+
and module.quantization_scheme.weights is not None
759759
)
760760
}
761761

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import math
1616
from typing import Optional
1717

18-
import math
1918
import torch
2019
from compressed_tensors.transform import TransformArgs, TransformScheme
2120
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
@@ -24,10 +23,7 @@
2423
apply_transform_weight,
2524
get_transform_size,
2625
)
27-
from compressed_tensors.utils import (
28-
get_execution_device,
29-
get_offloaded_device,
30-
)
26+
from compressed_tensors.utils import get_execution_device, get_offloaded_device
3127
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
3228
from torch import Tensor, device, dtype
3329
from torch.nn import Module, Parameter
@@ -107,7 +103,8 @@ def forward(self, value: Tensor) -> Tensor:
107103

108104
if self.args.inverse:
109105
weight = weight.T
110-
111-
return apply_transform_weight(
112-
weight, value, self.args.location, self.module_type
113-
) / self._scale
106+
107+
return (
108+
apply_transform_weight(weight, value, self.args.location, self.module_type)
109+
/ self._scale
110+
)

tests/test_transform/factory/test_correctness.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,17 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
9191

9292
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
9393
@pytest.mark.parametrize("randomize", (True, False))
94-
def test_correctness_model(type, randomize, model_apply, offload=False):
94+
@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
95+
def test_correctness_model(
96+
type, randomize, input_batch_size, model_apply, offload=False
97+
):
9598
# load model
9699
model = model_apply[0]
97100
if offload:
98101
model = offloaded_dispatch(model, torch.device("cuda"))
99102

100103
# get output
101-
input = torch.rand((17, 5, model.fcs[0].in_features))
104+
input = torch.rand((input_bathc_size, 5, model.fcs[0].in_features))
102105
if offload:
103106
input = input.to(torch.device("cuda"))
104107
true_output = model(input)

0 commit comments

Comments
 (0)