
Track API usage #2706

Merged: 2 commits, Aug 13, 2025
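
This PR adds torch._C._log_api_usage_once calls at torchao's public entry points (float8 training utilities, low-bit optimizers, PT2E quantization, QAT, and the quant_api config classes) so that instrumented builds of PyTorch can count which APIs are exercised. It also moves the existing pt2e keys from the quantization_api.* namespace to torchao.quantization.pt2e.*. A minimal sketch of the pattern, using an illustrative function and config that are not part of this PR:

import torch
from dataclasses import dataclass

def my_public_api(x: torch.Tensor) -> torch.Tensor:
    # Recorded at most once per process; in open-source builds this is
    # effectively a no-op unless a usage handler has been registered.
    torch._C._log_api_usage_once("torchao.example.my_public_api")
    return x

@dataclass
class MyExampleConfig:
    group_size: int = 32

    def __post_init__(self):
        # Config classes in this PR log from __post_init__ so that constructing
        # the config is counted, not just its later use in quantize_().
        torch._C._log_api_usage_once("torchao.example.MyExampleConfig")
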
2 changes: 2 additions & 0 deletions torchao/float8/float8_linear_utils.py
@@ -7,6 +7,7 @@
from functools import partial
from typing import Callable, List, Optional, Union

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
@@ -101,6 +102,7 @@ def convert_to_float8_training(
Returns:
nn.Module: The modified module with swapped linear layers.
"""
torch._C._log_api_usage_once("torchao.float8.convert_to_float8_training")
if config is None:
config = Float8LinearConfig()

4 changes: 4 additions & 0 deletions torchao/float8/fsdp_utils.py
@@ -39,6 +39,10 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:

from torchao.float8.float8_linear import Float8Linear

torch._C._log_api_usage_once(
"torchao.float8.precompute_float8_dynamic_scale_for_fsdp"
)

float8_linears: List[Float8Linear] = [
m
for m in module.modules()
6 changes: 6 additions & 0 deletions torchao/optim/adam.py
@@ -233,6 +233,7 @@ def __init__(
bf16_stochastic_round=bf16_stochastic_round,
is_adamw=False,
)
torch._C._log_api_usage_once("torchao.optim.Adam8bit")

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -263,6 +264,7 @@ def __init__(
bf16_stochastic_round=bf16_stochastic_round,
is_adamw=False,
)
torch._C._log_api_usage_once("torchao.optim.Adam4bit")

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -293,6 +295,7 @@ def __init__(
bf16_stochastic_round=bf16_stochastic_round,
is_adamw=False,
)
torch._C._log_api_usage_once("torchao.optim.AdamFp8")

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -323,6 +326,7 @@ def __init__(
bf16_stochastic_round=bf16_stochastic_round,
is_adamw=True,
)
torch._C._log_api_usage_once("torchao.optim.AdamW8bit")

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -353,6 +357,7 @@ def __init__(
bf16_stochastic_round=bf16_stochastic_round,
is_adamw=True,
)
torch._C._log_api_usage_once("torchao.optim.AdamW4bit")

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -383,6 +388,7 @@ def __init__(
bf16_stochastic_round=bf16_stochastic_round,
is_adamw=True,
)
torch._C._log_api_usage_once("torchao.optim.AdamWFp8")

@staticmethod
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
3 changes: 0 additions & 3 deletions torchao/quantization/pt2e/convert.py
@@ -1266,9 +1266,6 @@ def _convert_to_reference_decomposed_fx(
reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)

"""
torch._C._log_api_usage_once(
"quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
)
return _convert_fx(
graph_module,
is_reference=True,
6 changes: 3 additions & 3 deletions torchao/quantization/pt2e/quantize_pt2e.py
@@ -106,7 +106,7 @@ def calibrate(model, data_loader):

return torch_prepare_pt2e(model, quantizer)

torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_pt2e")
original_graph_meta = model.meta
node_name_to_scope = _get_node_name_to_scope(model)
# TODO: check qconfig_mapping to make sure conv and bn are both configured
@@ -192,7 +192,7 @@ def train_loop(model, train_data):

return torch_prepare_qat_pt2e(model, quantizer)

torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_qat_pt2e")
original_graph_meta = model.meta
node_name_to_scope = _get_node_name_to_scope(model)
model = quantizer.transform_for_annotation(model)
@@ -304,7 +304,7 @@ def convert_pt2e(

return torch_convert_pt2e(model, use_reference_representation, fold_quantize)

torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
torch._C._log_api_usage_once("torchao.quantization.pt2e.convert_pt2e")
if not isinstance(use_reference_representation, bool):
raise ValueError(
"Unexpected argument type for `use_reference_representation`, "
4 changes: 4 additions & 0 deletions torchao/quantization/qat/api.py
@@ -146,6 +146,7 @@ def __init__(
self.__post_init__()

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.qat.QATConfig")
self.step = self.step.lower()
all_step_values = [s.value for s in QATStep]
if self.step not in all_step_values:
@@ -377,6 +378,7 @@ class ComposableQATQuantizer(TwoStepQuantizer):
"""

def __init__(self, quantizers: List[TwoStepQuantizer]):
torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer")
self.quantizers = quantizers

def prepare(
@@ -403,6 +405,8 @@ def initialize_fake_quantizers(
:class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizerBase`
in the model based on the provided example inputs.
"""
torch._C._log_api_usage_once("torchao.quantization.qat.initialize_fake_quantizers")

# avoid circular dependencies
from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer

4 changes: 4 additions & 0 deletions torchao/quantization/qat/embedding.py
@@ -65,6 +65,7 @@ def __init__(
*args,
**kwargs,
)
torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedEmbedding")
if weight_config is not None:
self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
else:
@@ -148,6 +149,9 @@ def __init__(
zero_point_precision: torch.dtype = torch.int32,
) -> None:
super().__init__()
torch._C._log_api_usage_once(
"torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer"
)
self.bit_width = 4
self.group_size: int = group_size
self.scale_precision: torch.dtype = scale_precision
1 change: 1 addition & 0 deletions torchao/quantization/qat/fake_quantizer.py
@@ -66,6 +66,7 @@ class IntxFakeQuantizer(FakeQuantizerBase):

def __init__(self, config: IntxFakeQuantizeConfig):
super().__init__()
torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
self.config = config
self.enabled = True
self.scale: Optional[torch.Tensor] = None
10 changes: 10 additions & 0 deletions torchao/quantization/qat/linear.py
@@ -81,6 +81,7 @@ def __init__(
*args,
**kwargs,
)
torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear")
# initialize activation fake quantizer
if activation_config is not None:
self.activation_fake_quantizer = FakeQuantizerBase.from_config(
@@ -210,6 +211,9 @@ def __init__(
scales_precision: torch.dtype = torch.float32,
) -> None:
super().__init__()
torch._C._log_api_usage_once(
"torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer"
)
self.groupsize: int = groupsize
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
@@ -413,6 +417,9 @@ def __init__(
scales_precision: torch.dtype = torch.bfloat16,
) -> None:
super().__init__()
torch._C._log_api_usage_once(
"torchao.quantization.qat.Int4WeightOnlyQATQuantizer"
)
assert inner_k_tiles in [2, 4, 8]
assert groupsize in [32, 64, 128, 256]
self.inner_k_tiles = inner_k_tiles
@@ -594,6 +601,9 @@ def __init__(
group_size: Optional[int] = 64,
scale_precision: torch.dtype = torch.bfloat16,
):
torch._C._log_api_usage_once(
"torchao.quantization.qat.Float8ActInt4WeightQATQuantizer"
)
if group_size is not None:
weight_granularity = "per_group"
else:
59 changes: 58 additions & 1 deletion torchao/quantization/quant_api.py
@@ -127,6 +127,7 @@

logger = logging.getLogger(__name__)

# TODO: revisit this list?
__all__ = [
"swap_conv2d_1x1_to_linear",
"Quantizer",
@@ -510,6 +511,8 @@ def quantize_(
quantize_(m, int4_weight_only(group_size=32))

"""
torch._C._log_api_usage_once("torchao.quantization.quantize_")

filter_fn = _is_linear if filter_fn is None else filter_fn

if isinstance(config, ModuleFqnToConfig):
@@ -619,6 +622,11 @@ class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
act_mapping_type: MappingType = MappingType.ASYMMETRIC
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int8DynamicActivationInt4WeightConfig"
)


# for BC
int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
@@ -729,6 +737,9 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
layout: Layout = QDQLayout()

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int8DynamicActivationIntxWeightConfig"
)
assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
)
@@ -876,6 +887,11 @@ class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
act_mapping_type: MappingType = MappingType.SYMMETRIC
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int4DynamicActivationInt4WeightConfig"
)


# for bc
int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
@@ -932,6 +948,11 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
mode: Optional[str] = "weight_only"
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.GemliteUIntXWeightOnlyConfig"
)


# for BC
gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
@@ -1005,6 +1026,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
packing_format: PackingFormat = PackingFormat.PLAIN
VERSION: int = 1

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")


# for BC
# TODO maybe change other callsites
@@ -1178,6 +1202,9 @@ class Int8WeightOnlyConfig(AOBaseConfig):
group_size: Optional[int] = None
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")


# for BC
int8_weight_only = Int8WeightOnlyConfig
@@ -1334,6 +1361,11 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
weight_only_decode: bool = False
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int8DynamicActivationInt8WeightConfig"
)


# for BC
int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
@@ -1438,6 +1470,9 @@ class Float8WeightOnlyConfig(AOBaseConfig):
set_inductor_config: bool = True
version: int = 2

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")


# for BC
float8_weight_only = Float8WeightOnlyConfig
@@ -1586,9 +1621,11 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
version: int = 2

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
)
if self.mm_config is None:
self.mm_config = Float8MMConfig(use_fast_accum=True)

activation_granularity, weight_granularity = _normalize_granularity(
self.granularity
)
@@ -1705,6 +1742,11 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
activation_dtype: torch.dtype = e5m2_dtype
weight_dtype: torch.dtype = e4m3_dtype

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig"
)


@register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
@@ -1756,6 +1798,11 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True)
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Float8StaticActivationFloat8WeightConfig"
)


# for bc
float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
@@ -1836,6 +1883,9 @@ class UIntXWeightOnlyConfig(AOBaseConfig):
use_hqq: bool = False
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")


# for BC
uintx_weight_only = UIntXWeightOnlyConfig
@@ -1934,6 +1984,7 @@ class IntxWeightOnlyConfig(AOBaseConfig):
layout: Layout = QDQLayout()

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
)
@@ -2007,6 +2058,9 @@ class FPXWeightOnlyConfig(AOBaseConfig):
mbits: int
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")


# for BC
fpx_weight_only = FPXWeightOnlyConfig
@@ -2138,6 +2192,9 @@ class ModuleFqnToConfig(AOBaseConfig):
default_factory=dict
)

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig")


def _module_fqn_to_config_handler(
module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig