Commit ccd5512: Track API usage

Parent: 418593c

11 files changed (+96, -8 lines)
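
This commit instruments torchao's public entry points with torch._C._log_api_usage_once: the quantize_ API and its configs, the QAT quantizers and fake-quantized modules, the float8 training helpers, and the low-bit optimizers now each record a stable key the first time they are used. The three pt2e entry points also move from the upstream quantization_api.* namespace to torchao.quantization.pt2e.*, and a leftover log call is removed from _convert_to_reference_decomposed_fx.

A minimal sketch of how the new events can be observed, assuming torchao is installed. PYTORCH_API_USAGE_STDERR=1 is PyTorch's built-in switch that echoes API-usage keys to stderr; the toy model and script name here are hypothetical:

    # save as check_usage.py, run as: PYTORCH_API_USAGE_STDERR=1 python check_usage.py
    import torch.nn as nn

    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    model = nn.Sequential(nn.Linear(32, 32))

    # Constructing the config fires "torchao.quantization.Int8WeightOnlyConfig"
    # via the __post_init__ added in this commit; quantize_ then fires
    # "torchao.quantization.quantize_". Each key is logged once per process.
    quantize_(model, Int8WeightOnlyConfig())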

torchao/float8/float8_linear_utils.py (2 additions, 0 deletions)

@@ -7,6 +7,7 @@
 from functools import partial
 from typing import Callable, List, Optional, Union
 
+import torch
 import torch.nn as nn
 
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
@@ -101,6 +102,7 @@ def convert_to_float8_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    torch._C._log_api_usage_once("torchao.quantization.convert_to_float8_training")
     if config is None:
         config = Float8LinearConfig()
 
torchao/float8/fsdp_utils.py (4 additions, 0 deletions)

@@ -39,6 +39,10 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     from torchao.float8.float8_linear import Float8Linear
 
+    torch._C._log_api_usage_once(
+        "torchao.quantization.precompute_float8_dynamic_scale_for_fsdp"
+    )
+
     float8_linears: List[Float8Linear] = [
         m
         for m in module.modules()

torchao/optim/adam.py (6 additions, 0 deletions)

@@ -233,6 +233,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -263,6 +264,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -293,6 +295,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -323,6 +326,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -353,6 +357,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -383,6 +388,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamWFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
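
The optimizer hooks above follow the same once-per-process pattern. A short usage sketch; it assumes these classes are re-exported from torchao.optim, as the keys in the diff suggest:

    import torch

    from torchao.optim import AdamW8bit

    params = [torch.nn.Parameter(torch.randn(16, 16))]
    opt = AdamW8bit(params)  # records "torchao.optim.AdamW8bit" on first construction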

torchao/quantization/pt2e/convert.py (0 additions, 3 deletions)

@@ -1271,9 +1271,6 @@ def _convert_to_reference_decomposed_fx(
         reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
 
     """
-    torch._C._log_api_usage_once(
-        "quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
-    )
     return _convert_fx(
         graph_module,
         is_reference=True,

torchao/quantization/pt2e/quantize_pt2e.py (3 additions, 3 deletions)

@@ -106,7 +106,7 @@ def calibrate(model, data_loader):
 
         return torch_prepare_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     # TODO: check qconfig_mapping to make sure conv and bn are both configured
@@ -192,7 +192,7 @@ def train_loop(model, train_data):
 
         return torch_prepare_qat_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_qat_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     model = quantizer.transform_for_annotation(model)
@@ -309,7 +309,7 @@ def convert_pt2e(
 
         return torch_convert_pt2e(model, use_reference_representation, fold_quantize)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.convert_pt2e")
     if not isinstance(use_reference_representation, bool):
         raise ValueError(
             "Unexpected argument type for `use_reference_representation`, "

torchao/quantization/qat/api.py (4 additions, 0 deletions)

@@ -144,6 +144,7 @@ def __init__(
         self.__post_init__()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.qat.QATConfig")
         self.step = self.step.lower()
         all_step_values = [s.value for s in QATStep]
         if self.step not in all_step_values:
@@ -359,6 +360,7 @@ class ComposableQATQuantizer(TwoStepQuantizer):
     """
 
     def __init__(self, quantizers: List[TwoStepQuantizer]):
+        torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer")
         self.quantizers = quantizers
 
     def prepare(
@@ -385,6 +387,8 @@ def initialize_fake_quantizers(
     :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
     in the model based on the provided example inputs.
     """
+    torch._C._log_api_usage_once("torchao.quantization.qat.initialize_fake_quantizers")
+
     # avoid circular dependencies
     from torchao.quantization.qat.fake_quantizer import FakeQuantizer
 

torchao/quantization/qat/embedding.py (4 additions, 0 deletions)

@@ -65,6 +65,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedEmbedding")
         if weight_config is not None:
             self.weight_fake_quantizer = FakeQuantizer(weight_config)
         else:
@@ -148,6 +149,9 @@ def __init__(
         zero_point_precision: torch.dtype = torch.int32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer"
+        )
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision

torchao/quantization/qat/fake_quantizer.py (1 addition, 0 deletions)

@@ -44,6 +44,7 @@ class FakeQuantizer(torch.nn.Module):
 
     def __init__(self, config: FakeQuantizeConfigBase):
         super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
         self.config = config
         self.enabled = True
         self.scale: Optional[torch.Tensor] = None

torchao/quantization/qat/linear.py (10 additions, 0 deletions)

@@ -82,6 +82,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear")
         # initialize activation fake quantizer
         if activation_config is not None:
             self.activation_fake_quantizer = FakeQuantizer(activation_config)
@@ -209,6 +210,9 @@ def __init__(
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer"
+        )
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
@@ -412,6 +416,9 @@ def __init__(
         scales_precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyQATQuantizer"
+        )
         assert inner_k_tiles in [2, 4, 8]
         assert groupsize in [32, 64, 128, 256]
         self.inner_k_tiles = inner_k_tiles
@@ -596,6 +603,9 @@ def __init__(
         group_size: Optional[int] = 64,
         scale_precision: torch.dtype = torch.bfloat16,
     ):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Float8ActInt4WeightQATQuantizer"
+        )
         if group_size is not None:
             weight_granularity = "per_group"
         else:

torchao/quantization/quant_api.py (56 additions, 1 deletion)

@@ -125,6 +125,7 @@
 
 logger = logging.getLogger(__name__)
 
+# TODO: revisit this list?
 __all__ = [
     "swap_conv2d_1x1_to_linear",
     "Quantizer",
@@ -611,6 +612,8 @@ def quantize_(
         quantize_(m, int4_weight_only(group_size=32))
 
     """
+    torch._C._log_api_usage_once("torchao.quantization.quantize_")
+
     filter_fn = _is_linear if filter_fn is None else filter_fn
 
     if isinstance(config, ModuleFqnToConfig):
@@ -735,6 +738,11 @@ class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.ASYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt4WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
@@ -846,6 +854,9 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationIntxWeightConfig"
+        )
         assert TORCH_VERSION_AT_LEAST_2_6, (
             "Int8DynamicActivationIntxWeightConfig requires torch 2.6+"
         )
@@ -996,6 +1007,11 @@ class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.SYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int4DynamicActivationInt4WeightConfig"
+        )
+
 
 # for bc
 int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
@@ -1052,6 +1068,11 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     mode: Optional[str] = "weight_only"
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.GemliteUIntXWeightOnlyConfig"
+        )
+
 
 # for BC
 gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
@@ -1121,6 +1142,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
+
 
 # for BC
 # TODO maybe change other callsites
@@ -1232,6 +1256,9 @@ class Int8WeightOnlyConfig(AOBaseConfig):
     group_size: Optional[int] = None
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
+
 
 # for BC
 int8_weight_only = Int8WeightOnlyConfig
@@ -1388,6 +1415,11 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     weight_only_decode: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt8WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
@@ -1490,6 +1522,9 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
+
 
 # for BC
 float8_weight_only = Float8WeightOnlyConfig
@@ -1620,9 +1655,11 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     set_inductor_config: bool = True
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+        )
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
-
         activation_granularity, weight_granularity = _normalize_granularity(
             self.granularity
         )
@@ -1712,6 +1749,11 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig"
+        )
+
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
 def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
@@ -1764,6 +1806,9 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     set_inductor_config: bool = True
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8StaticActivationFloat8WeightConfig"
+        )
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
 
@@ -1847,6 +1892,9 @@ class UIntXWeightOnlyConfig(AOBaseConfig):
     use_hqq: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
+
 
 # for BC
 uintx_weight_only = UIntXWeightOnlyConfig
@@ -1946,6 +1994,7 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
         assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"
         assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
             f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
@@ -2020,6 +2069,9 @@ class FPXWeightOnlyConfig(AOBaseConfig):
     mbits: int
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
+
 
 # for BC
 fpx_weight_only = FPXWeightOnlyConfig
@@ -2151,6 +2203,9 @@ class ModuleFqnToConfig(AOBaseConfig):
         default_factory=dict
     )
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig")
+
 
 def _module_fqn_to_config_handler(
     module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig
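
The quant_api.py changes above all share one shape: a dataclass __post_init__ that records a usage key when the config is constructed. A self-contained sketch of that pattern follows; ExampleConfig is hypothetical and not part of torchao:

    from dataclasses import dataclass

    import torch


    @dataclass
    class ExampleConfig:
        # hypothetical stand-in for the AOBaseConfig subclasses above
        group_size: int = 128

        def __post_init__(self):
            # __post_init__ runs on every construction, but the logger
            # deduplicates: each distinct key is recorded at most once per process
            torch._C._log_api_usage_once("torchao.quantization.ExampleConfig")


    ExampleConfig()                # first construction: key is logged
    ExampleConfig(group_size=64)   # same key: deduplicated, no new event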
