
Commit 8a4efcc

Merge branch 'main' into support_multi_compressor

2 parents 158720a + de945c6

File tree: 18 files changed, +311 -114 lines


pyproject.toml

Lines changed: 9 additions & 0 deletions
@@ -5,3 +5,12 @@ build-backend = "setuptools.build_meta"
 [tool.black]
 line-length = 88
 target-version = ['py36']
+
+[tool.pytest.ini_options]
+markers = [
+    "unit: tests to ensure code correctness and regression test functionality",
+    "smoke: quick tests to check basic functionality",
+    "sanity: tests to ensure that new changes do not break existing functionality",
+    "regression: detailed tests to ensure major functions work correctly",
+    "integration: tests which integrate with a third party service such as HF",
+]
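The new `[tool.pytest.ini_options]` table registers the custom markers so pytest recognizes them instead of emitting `PytestUnknownMarkWarning`. A minimal sketch of how a registered marker is used (hypothetical test, not part of this commit):

```python
import pytest


@pytest.mark.unit
def test_addition():
    # selected with `pytest -m unit`, deselected with `pytest -m "not unit"`
    assert 1 + 1 == 2
```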

src/compressed_tensors/base.py

Lines changed: 8 additions & 3 deletions
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SPARSITY_CONFIG_NAME = "sparsity_config"
+# configs
 QUANTIZATION_CONFIG_NAME = "quantization_config"
-COMPRESSION_CONFIG_NAME = "compression_config"
-KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+SPARSITY_CONFIG_NAME = "sparsity_config"
+TRANSFORM_CONFIG_NAME = "transform_config"
+
+# required fields
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+
+# auxiliary configs
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 54 additions & 35 deletions
@@ -29,6 +29,7 @@
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORM_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
+from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
 
     sparsity_config: Optional[SparsityCompressionConfig] = None
     quantization_config: Optional[QuantizationConfig] = None
+    transform_config: Optional[TransformConfig] = None
 
     @classmethod
     def from_pretrained(
@@ -144,6 +147,8 @@ def from_compression_config(
 
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
+        # TODO: transform config is not supported by CompressedTensorsConfig yet
+
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -177,7 +182,6 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-
         if quantization_format is not None:
             # llmcompressor incorrectly passes in a CompressionFormat when
             # the value string is expected - handle both cases
@@ -194,17 +198,23 @@ def from_pretrained_model(
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
+
+        # use config passed as argument
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
             )
 
-        if sparsity_config is None and quantization_config is None:
+        # use config attached to model
+        transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
+
+        if not any((quantization_config, sparsity_config, transform_config)):
             return None
 
         return cls(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
+            transform_config=transform_config,
         )
 
     @staticmethod
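`from_pretrained_model` now also checks for a transform config attached to the model object, and returns a compressor if any of the three configs is present. A hedged usage sketch (the model variable and save directory are placeholders; import path assumed from this repo's layout):

```python
from compressed_tensors.compressors import ModelCompressor

# `model` is a torch.nn.Module that may carry quantization, sparsity,
# and/or transform configs; from_pretrained_model returns None if none exist
compressor = ModelCompressor.from_pretrained_model(model)
if compressor is not None:
    compressor.update_config(save_directory="./model_out")
```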
@@ -283,13 +293,17 @@ def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transform_config: Optional[TransformConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transform_config = transform_config
+
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
             Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
         ] = None
+        # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -718,44 +732,49 @@ def update_config(self, save_directory: str):
 
         :param save_directory: path to a folder containing a HF model config
         """
-        if self.quantization_config is None and self.sparsity_config is None:
-            return
-
-        config_file_path = os.path.join(save_directory, CONFIG_NAME)
-        if not os.path.exists(config_file_path):
-            _LOGGER.warning(
-                f"Could not find a valid model config file in "
-                f"{save_directory}. Compression config will not be saved."
-            )
+        # this check is also done in `from_pretrained_model`,
+        # but not in `from_pretrained` or `from_compression_config`
+        if not any(
+            (self.quantization_config, self.sparsity_config, self.transform_config)
+        ):
             return
 
-        with open(config_file_path, "r") as config_file:
-            config_data = json.load(config_file)
-
-        # required metadata whenever a quantization or sparsity config is present
+        # write to config.json file, regardless of whether it exists already
         # overwrite previous config and version if already existing
-        config_data[QUANTIZATION_CONFIG_NAME] = {}
-        config_data[QUANTIZATION_CONFIG_NAME][
-            COMPRESSION_VERSION_NAME
-        ] = compressed_tensors.__version__
-
-        if self.quantization_config is not None:
-            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
+        config_file_path = os.path.join(save_directory, CONFIG_NAME)
+        if os.path.exists(config_file_path):
+            with open(config_file_path, "r") as file:
+                config_data = json.load(file)
         else:
-            config_data[QUANTIZATION_CONFIG_NAME][
-                QUANTIZATION_METHOD_NAME
-            ] = DEFAULT_QUANTIZATION_METHOD
-
-        # quantization and sparsity configs
-        if self.quantization_config is not None:
-            quant_config_data = self.quantization_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-        if self.sparsity_config is not None:
-            sparsity_config_data = self.sparsity_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME][
-                SPARSITY_CONFIG_NAME
-            ] = sparsity_config_data
+            config_data = {}
+
+        # serialize configs into json
+        qconfig_data = (
+            self.quantization_config.model_dump(exclude=["quant_method", "format"])
+            if self.quantization_config is not None
+            else {}
+        )
+        sconfig_data = (
+            self.sparsity_config.model_dump()
+            if self.sparsity_config is not None
+            else {}
+        )
+        tconfig_data = (
+            self.transform_config.model_dump()
+            if self.transform_config is not None
+            else {}
+        )
 
+        # construct compression (quantization) config
+        config_data[QUANTIZATION_CONFIG_NAME] = {
+            COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
+            QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
+            SPARSITY_CONFIG_NAME: sconfig_data,
+            TRANSFORM_CONFIG_NAME: tconfig_data,
+            **qconfig_data,
+        }
+
+        # write results to config.json file
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
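With this rewrite, `update_config` no longer bails out when `config.json` is missing; it builds a single nested block under the `quantization_config` key and creates the file if needed. A sketch of the resulting layout, assuming `DEFAULT_QUANTIZATION_METHOD` is `"compressed-tensors"` (all values illustrative):

```python
# approximate shape of config.json after update_config (placeholder values)
{
    "quantization_config": {
        "version": "x.y.z",                    # COMPRESSION_VERSION_NAME
        "quant_method": "compressed-tensors",  # QUANTIZATION_METHOD_NAME
        "sparsity_config": {},                 # SPARSITY_CONFIG_NAME, empty if absent
        "transform_config": {},                # TRANSFORM_CONFIG_NAME, empty if absent
        # ...fields from QuantizationConfig.model_dump(), minus quant_method/format
    }
}
```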

src/compressed_tensors/quantization/quant_args.py

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
 __all__ = [
@@ -358,6 +358,8 @@ def pytorch_dtype(self) -> torch.dtype:
     def get_observer(self) -> str:
         return self.observer
 
+    model_config = ConfigDict(extra="forbid")
+
 
 def round_to_quantized_type(
     tensor: torch.Tensor, args: QuantizationArgs
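With `extra="forbid"`, constructing `QuantizationArgs` with an unrecognized keyword now raises a `ValidationError` instead of silently ignoring it, which catches misspelled arguments early. A minimal sketch (import path assumed):

```python
from pydantic import ValidationError

from compressed_tensors.quantization import QuantizationArgs

QuantizationArgs(num_bits=8)  # fine: known field

try:
    QuantizationArgs(num_bits=8, symetric=False)  # misspelling of `symmetric`
except ValidationError as err:
    print(err)  # pydantic reports "Extra inputs are not permitted"
```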

src/compressed_tensors/quantization/quant_config.py

Lines changed: 8 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@
     module_type,
     parse_out_kv_cache_args,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from torch.nn import Module
 
 
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
+    # `run_compressed` is a dummy, unused arg for backwards compatibility
+    # see: https://github.com/huggingface/transformers/pull/39324
+    run_compressed: Annotated[Any, Field(exclude=True)] = None
 
     def model_post_init(self, __context):
         """
@@ -254,3 +257,6 @@ def requires_calibration_data(self):
             return True
 
         return False
+
+    # TODO set `extra="forbid"` when upstream transformers is compatible
+    model_config = ConfigDict(extra="ignore")
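The `run_compressed` field is accepted purely so configs written by the linked transformers change still parse; `Field(exclude=True)` keeps it out of any serialized output. A quick sketch of the intended round trip (import path assumed):

```python
from compressed_tensors.quantization import QuantizationConfig

# older configs may still carry `run_compressed`; it parses without error...
config = QuantizationConfig(config_groups={}, run_compressed=True)

# ...but is never written back out
assert "run_compressed" not in config.model_dump()
```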

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 4 additions & 2 deletions
@@ -14,7 +14,7 @@
 
 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
@@ -23,7 +23,7 @@
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 
 
 __all__ = [
@@ -83,6 +83,8 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
 
         return model
 
+    model_config = ConfigDict(extra="forbid")
+
 
 """
 Pre-Set Quantization Scheme Args

src/compressed_tensors/transform/apply.py

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
+
+    # attach config to model for compression/serialization
+    setattr(model, TRANSFORM_CONFIG_NAME, config)
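Attaching the config as a plain model attribute is what lets `ModelCompressor.from_pretrained_model` rediscover it later via `getattr(model, TRANSFORM_CONFIG_NAME, None)`. A condensed round-trip sketch (the `Linear` stand-in and empty config are illustrative only; export of `apply_transform_config` is assumed):

```python
import torch

from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import TransformConfig, apply_transform_config

model = torch.nn.Linear(64, 64)             # stand-in for a real model
config = TransformConfig(config_groups={})  # illustrative empty config

apply_transform_config(model, config)
assert getattr(model, TRANSFORM_CONFIG_NAME) is config
```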

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,6 @@
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -34,6 +33,7 @@
     register_offload_module,
     update_offload_parameter,
 )
+from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 15 additions & 8 deletions
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
@@ -85,15 +84,18 @@ def __init__(
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
-        self._scale = math.sqrt(weight.size(0))
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
+        self._precision = scheme.precision if args.is_online() else torch.float64
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -105,6 +107,11 @@ def forward(self, value: Tensor) -> Tensor:
             weight = weight.T
 
         return (
-            apply_transform_weight(weight, value, self.args.location, self.module_type)
+            apply_transform_weight(
+                weight.to(self._precision),
+                value.to(self._precision),
+                self.args.location,
+                self.module_type,
+            )
             / self._scale
-        )
+        ).to(value.dtype)
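The `1 / sqrt(n)` scale is what makes the Hadamard transform orthonormal: for an n x n Hadamard matrix H, `H @ H.T == n * I`, so `H / sqrt(n)` cancels with its transpose and a transform followed by its inverse leaves values unchanged. Performing the matmul in higher precision (as the new `_precision` logic does, with float64 for offline transforms) keeps that cancellation numerically tight. A standalone numeric sketch with plain tensors, not the classes above:

```python
import torch

n = 4
# 4x4 Hadamard matrix via the Sylvester construction
h2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.float64)
h4 = torch.kron(h2, h2)

q = h4 / torch.tensor(float(n)).sqrt()  # orthonormal: q @ q.T == I
assert torch.allclose(q @ q.T, torch.eye(n, dtype=torch.float64))

x = torch.randn(3, n, dtype=torch.float64)
y = x @ q                          # apply the scaled transform
assert torch.allclose(y @ q.T, x)  # the inverse recovers the input
```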
