
Commit f6847de

Track API usage

1 parent 1c96994 commit f6847de

11 files changed: +98 -8 lines changed
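
For context: torch._C._log_api_usage_once is PyTorch's internal API-usage telemetry hook, and it is what this commit threads through torchao's public entry points. The sketch below illustrates the pattern the commit applies. The MyConfig class and its key are hypothetical stand-ins, and the once-per-process, no-op-by-default behavior described in the comments is our reading of the hook, not something this commit specifies.

    from dataclasses import dataclass

    import torch

    @dataclass
    class MyConfig:
        # Hypothetical config mirroring the instrumented torchao configs.
        group_size: int = 32

        def __post_init__(self):
            # Runs on every construction; the string key is recorded at most
            # once per process, and the call is a no-op unless an API-usage
            # logger is registered (c10::SetAPIUsageLogger on the C++ side).
            torch._C._log_api_usage_once("torchao.quantization.MyConfig")

    MyConfig()  # first construction records the usage event
    MyConfig()  # repeat constructions do not re-log the same key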

torchao/float8/float8_linear_utils.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from functools import partial
 from typing import Callable, List, Optional, Union
 
+import torch
 import torch.nn as nn
 
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
@@ -101,6 +102,7 @@ def convert_to_float8_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    torch._C._log_api_usage_once("torchao.quantization.convert_to_float8_training")
     if config is None:
         config = Float8LinearConfig()

torchao/float8/fsdp_utils.py

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     from torchao.float8.float8_linear import Float8Linear
 
+    torch._C._log_api_usage_once(
+        "torchao.quantization.precompute_float8_dynamic_scale_for_fsdp"
+    )
+
     float8_linears: List[Float8Linear] = [
         m
         for m in module.modules()

torchao/optim/adam.py

Lines changed: 6 additions & 0 deletions
@@ -233,6 +233,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -263,6 +264,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -293,6 +295,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -323,6 +326,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -353,6 +357,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -383,6 +388,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamWFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):

torchao/quantization/pt2e/convert.py

Lines changed: 0 additions & 3 deletions
@@ -1271,9 +1271,6 @@ def _convert_to_reference_decomposed_fx(
         reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
 
     """
-    torch._C._log_api_usage_once(
-        "quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
-    )
     return _convert_fx(
         graph_module,
         is_reference=True,

torchao/quantization/pt2e/quantize_pt2e.py

Lines changed: 3 additions & 3 deletions
@@ -106,7 +106,7 @@ def calibrate(model, data_loader):
 
         return torch_prepare_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     # TODO: check qconfig_mapping to make sure conv and bn are both configured
@@ -192,7 +192,7 @@ def train_loop(model, train_data):
 
         return torch_prepare_qat_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_qat_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     model = quantizer.transform_for_annotation(model)
@@ -309,7 +309,7 @@ def convert_pt2e(
 
         return torch_convert_pt2e(model, use_reference_representation, fold_quantize)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.convert_pt2e")
     if not isinstance(use_reference_representation, bool):
         raise ValueError(
             "Unexpected argument type for `use_reference_representation`, "

torchao/quantization/qat/api.py

Lines changed: 4 additions & 0 deletions
@@ -144,6 +144,7 @@ def __init__(
         self.__post_init__()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.qat.QATConfig")
         self.step = self.step.lower()
         all_step_values = [s.value for s in QATStep]
         if self.step not in all_step_values:
@@ -359,6 +360,7 @@ class ComposableQATQuantizer(TwoStepQuantizer):
     """
 
     def __init__(self, quantizers: List[TwoStepQuantizer]):
+        torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer")
         self.quantizers = quantizers
 
     def prepare(
@@ -385,6 +387,8 @@ def initialize_fake_quantizers(
     :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
     in the model based on the provided example inputs.
     """
+    torch._C._log_api_usage_once("torchao.quantization.qat.initialize_fake_quantizers")
+
     # avoid circular dependencies
    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
 

torchao/quantization/qat/embedding.py

Lines changed: 4 additions & 0 deletions
@@ -65,6 +65,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedEmbedding")
         if weight_config is not None:
             self.weight_fake_quantizer = FakeQuantizer(weight_config)
         else:
@@ -148,6 +149,9 @@ def __init__(
         zero_point_precision: torch.dtype = torch.int32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer"
+        )
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ class FakeQuantizer(torch.nn.Module):
 
     def __init__(self, config: FakeQuantizeConfigBase):
         super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
         self.config = config
         self.enabled = True
         self.scale: Optional[torch.Tensor] = None

torchao/quantization/qat/linear.py

Lines changed: 10 additions & 0 deletions
@@ -82,6 +82,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear")
         # initialize activation fake quantizer
         if activation_config is not None:
             self.activation_fake_quantizer = FakeQuantizer(activation_config)
@@ -209,6 +210,9 @@ def __init__(
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer"
+        )
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
@@ -412,6 +416,9 @@ def __init__(
         scales_precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyQATQuantizer"
+        )
         assert inner_k_tiles in [2, 4, 8]
         assert groupsize in [32, 64, 128, 256]
         self.inner_k_tiles = inner_k_tiles
@@ -596,6 +603,9 @@ def __init__(
         group_size: Optional[int] = 64,
         scale_precision: torch.dtype = torch.bfloat16,
     ):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Float8ActInt4WeightQATQuantizer"
+        )
         if group_size is not None:
             weight_granularity = "per_group"
         else:

torchao/quantization/quant_api.py

Lines changed: 58 additions & 1 deletion
@@ -132,6 +132,7 @@
 
 logger = logging.getLogger(__name__)
 
+# TODO: revisit this list?
 __all__ = [
     "swap_conv2d_1x1_to_linear",
     "Quantizer",
@@ -618,6 +619,8 @@ def quantize_(
         quantize_(m, int4_weight_only(group_size=32))
 
     """
+    torch._C._log_api_usage_once("torchao.quantization.quantize_")
+
     filter_fn = _is_linear if filter_fn is None else filter_fn
 
     if isinstance(config, ModuleFqnToConfig):
@@ -742,6 +745,11 @@ class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.ASYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt4WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
@@ -853,6 +861,9 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationIntxWeightConfig"
+        )
         assert TORCH_VERSION_AT_LEAST_2_6, (
             "Int8DynamicActivationIntxWeightConfig requires torch 2.6+"
         )
@@ -1003,6 +1014,11 @@ class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.SYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int4DynamicActivationInt4WeightConfig"
+        )
+
 
 # for bc
 int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
@@ -1059,6 +1075,11 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     mode: Optional[str] = "weight_only"
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.GemliteUIntXWeightOnlyConfig"
+        )
+
 
 # for BC
 gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
@@ -1128,6 +1149,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
+
 
 # for BC
 # TODO maybe change other callsites
@@ -1239,6 +1263,9 @@ class Int8WeightOnlyConfig(AOBaseConfig):
     group_size: Optional[int] = None
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
+
 
 # for BC
 int8_weight_only = Int8WeightOnlyConfig
@@ -1395,6 +1422,11 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     weight_only_decode: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt8WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
@@ -1499,6 +1531,9 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     set_inductor_config: bool = True
     VERSION: int = 1
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
+
 
 # for BC
 float8_weight_only = Float8WeightOnlyConfig
@@ -1644,9 +1679,11 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     VERSION: int = 1
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+        )
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
-
         activation_granularity, weight_granularity = _normalize_granularity(
             self.granularity
         )
@@ -1758,6 +1795,11 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig"
+        )
+
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
 def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
@@ -1809,6 +1851,11 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True)
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8StaticActivationFloat8WeightConfig"
+        )
+
 
 # for bc
 float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
@@ -1889,6 +1936,9 @@ class UIntXWeightOnlyConfig(AOBaseConfig):
     use_hqq: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
+
 
 # for BC
 uintx_weight_only = UIntXWeightOnlyConfig
@@ -1988,6 +2038,7 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
         assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"
         assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
             f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
@@ -2062,6 +2113,9 @@ class FPXWeightOnlyConfig(AOBaseConfig):
     mbits: int
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
+
 
 # for BC
 fpx_weight_only = FPXWeightOnlyConfig
@@ -2193,6 +2247,9 @@ class ModuleFqnToConfig(AOBaseConfig):
         default_factory=dict
     )
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig")
+
 
 def _module_fqn_to_config_handler(
     module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig
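
End to end, the instrumentation is transparent to callers. A hypothetical snippet, assuming a torchao build that includes this commit and that quantize_ and Int8WeightOnlyConfig are importable from torchao.quantization (consistent with the BC aliases shown in the diff):

    import torch.nn as nn

    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    model = nn.Sequential(nn.Linear(64, 64))

    # Constructing the config fires "torchao.quantization.Int8WeightOnlyConfig"
    # via __post_init__; quantize_ fires "torchao.quantization.quantize_".
    # Both are no-ops unless an API-usage logger is registered.
    quantize_(model, Int8WeightOnlyConfig())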
