[not for land] debug logging for float8 training #2701

Status: Open · wants to merge 3 commits into main
87 changes: 86 additions & 1 deletion test/float8/test_base.py
@@ -22,7 +22,10 @@
e5m2_dtype,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_linear_utils import (
_populate_debug_fqns,
convert_to_float8_training,
)
from torchao.float8.float8_ops import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
get_maybe_axiswise_dim,
@@ -395,6 +398,88 @@ def test_linear_from_recipe(
config,
)

@pytest.mark.parametrize(
"recipe_name",
[
Float8LinearRecipeName.TENSORWISE,
Float8LinearRecipeName.ROWWISE,
Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
],
)
def test_debug_logging(self, recipe_name):
x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
m = nn.Sequential(
nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
nn.Sequential(
nn.ReLU(),
nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
),
)
config = Float8LinearConfig.from_recipe_name(recipe_name)

@torch.no_grad()
def mean_absolute_percentage_error(x_ref, x):
tmp = torch.abs(x_ref - x) / torch.clamp(torch.abs(x_ref), min=1e-9)
# clamp the relative error so that reference values close to 0 do not
# significantly impact the results
tmp = torch.clamp(tmp, max=1e3)
return torch.mean(tmp)

iter_counter = 0
iter_fqn_gemm_name_to_data = {}

@torch._dynamo.disable
def debug_logging_fn(fqn, gemm_name, a_hp, b_hp, a_fp8, b_fp8):
"""
Example debugging function - this is user defined, easy to customize
1. captures M, K, N
2. captures MAPE for high precision vs float8 gemm
3. leaves data on GPU, so the user can move it to CPU at their
convenience
"""
M, K = a_hp.shape
K2, N = b_hp.shape
assert K == K2
res_hp = a_hp @ b_hp
res_fp8 = a_fp8 @ b_fp8
mape = mean_absolute_percentage_error(res_hp, res_fp8)
iter_fqn_gemm_name_to_data[(iter_counter, fqn, gemm_name)] = (M, K, N), mape

object.__setattr__(config, "_debug_logging_fn", debug_logging_fn)
m = convert_to_float8_training(m, config=config)
_populate_debug_fqns(m)

# iter 0
m = torch.compile(m)
y = m(x)
y.sum().backward()

# iter 1
iter_counter += 1
m = torch.compile(m)
y = m(x)
y.sum().backward()

if recipe_name == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
# check length is num_float8_layers * num_gemms_per_layer * num_iters
assert len(iter_fqn_gemm_name_to_data) == 2 * 2 * (iter_counter + 1)
# check that some of the expected debug logs exist
assert (0, "0", "output") in iter_fqn_gemm_name_to_data
assert (1, "1.1", "grad_input") in iter_fqn_gemm_name_to_data
else:
# check length is num_float8_layers * num_gemms_per_layer * num_iters
assert len(iter_fqn_gemm_name_to_data) == 2 * 3 * (iter_counter + 1)
# check that some of the expected debug logs exist
assert (0, "0", "output") in iter_fqn_gemm_name_to_data
assert (0, "1.1", "grad_weight") in iter_fqn_gemm_name_to_data
assert (1, "1.1", "grad_input") in iter_fqn_gemm_name_to_data

# check logged data is what we expect
example_data = iter_fqn_gemm_name_to_data[(1, "1.1", "grad_input")]
assert example_data[0] == (16, 64, 32)
assert type(example_data[1]) == torch.Tensor
assert example_data[1].shape == torch.Size()

@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
13 changes: 12 additions & 1 deletion torchao/float8/config.py
@@ -7,7 +7,7 @@
import enum
import logging
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch

@@ -204,6 +204,17 @@ class Float8LinearConfig:
# same value in the forward pass as the backward passes.
round_scales_to_power_of_2: bool = False

# If specified, the debug fqn, the name of each gemm
# (output/grad_input/grad_weight), and the high precision and float8 inputs to
# each gemm are passed to this function at each iteration. The intended use
# case is accuracy and performance logging for debugging. This feature is a
# prototype and the API may change.
_debug_logging_fn: Optional[
Callable[
[str, str, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None
]
] = None

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
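
For illustration, a minimal sketch of how a user might wire up this hook. This is not part of the diff: the import paths and the print-based callback are assumptions, and object.__setattr__ is used to work around the frozen dataclass, the same way the test in this PR does.

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

# hypothetical user-defined callback matching the documented signature
def my_debug_logging_fn(fqn, gemm_name, a_hp, b_hp, a_fp8, b_fp8):
    # a_hp/b_hp are the high precision gemm inputs, a_fp8/b_fp8 the float8 versions
    print(fqn, gemm_name, tuple(a_hp.shape), tuple(b_hp.shape))

config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)
# the dataclass is frozen, so the prototype field is set the same way the test above does
object.__setattr__(config, "_debug_logging_fn", my_debug_logging_fn)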
51 changes: 48 additions & 3 deletions torchao/float8/float8_linear.py
@@ -25,7 +25,9 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


@torch._dynamo.allow_in_graph
# TODO(before land): remove the two comment lines below
# note: need to remove torch._dynamo.allow_in_graph for logging to work with torch.compile
# @torch._dynamo.allow_in_graph
class matmul_with_hp_or_float8_args(torch.autograd.Function):
"""
Like torch.matmul, but with the arguments in either high precision or float8.
@@ -41,10 +43,12 @@ def forward(
weight_hp_t: torch.Tensor,
linear_mm_config: LinearMMConfig,
config: Float8LinearConfig,
debug_fqn: Optional[str],
):
ctx.save_for_backward(input_hp, weight_hp_t)
ctx.linear_mm_config = linear_mm_config
ctx.config = config
ctx.debug_fqn = debug_fqn

c = config

@@ -87,13 +91,26 @@
orig_shape = input_maybe_fp8.shape
input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1])
res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t)

if config._debug_logging_fn is not None:
input_hp_reshaped = input_hp.reshape(-1, orig_shape[-1])
config._debug_logging_fn(
debug_fqn,
"output",
input_hp_reshaped,
weight_hp_t,
input_maybe_fp8_reshaped,
weight_maybe_fp8_t,
)

res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
return res_bits

@staticmethod
def backward(ctx, grad_output):
input_hp, weight_hp_t = ctx.saved_tensors
c = ctx.config
debug_fqn = ctx.debug_fqn

# the reshapes are needed in order to make the shapes compatible with
# torch.mm
@@ -144,6 +161,15 @@ def backward(ctx, grad_output):
grad_output_reshaped_maybe_fp8_dim0,
weight_t_maybe_fp8_dim0.t(),
)
if c._debug_logging_fn is not None:
c._debug_logging_fn(
debug_fqn,
"grad_input",
grad_output_reshaped,
weight_hp_t.t(),
grad_output_reshaped_maybe_fp8_dim0,
weight_t_maybe_fp8_dim0.t(),
)
grad_input = grad_input.reshape(
*grad_output_orig_shape[:-1], grad_input.shape[-1]
)
@@ -198,8 +224,22 @@ def backward(ctx, grad_output):
grad_output_reshaped_maybe_fp8_dim1.t(),
input_reshaped_maybe_fp8_dim1,
)

empty_grads = None, None
if c._debug_logging_fn is not None:
# don't log if this gemm is in high precision
this_gemm_is_hp = (
c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED
)
if not this_gemm_is_hp:
c._debug_logging_fn(
debug_fqn,
"grad_weight",
grad_output_reshaped.t(),
input_hp_reshaped,
grad_output_reshaped_maybe_fp8_dim1.t(),
input_reshaped_maybe_fp8_dim1,
)

empty_grads = None, None, None

return grad_input, grad_weight.t(), *empty_grads

@@ -252,6 +292,10 @@ def __init__(self, *args, **kwargs):
),
)

# debugging only, API may change at any time. This is expected to be
# set by the user via a separate API call (see _populate_debug_fqns).
self._debug_fqn: Optional[str] = None

def forward(self, input: torch.Tensor) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
@@ -266,6 +310,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.weight.t(),
self.linear_mm_config,
self.config,
self._debug_fqn,
)

if self.bias is not None:
10 changes: 10 additions & 0 deletions torchao/float8/float8_linear_utils.py
@@ -198,3 +198,13 @@ def _auto_filter_for_tensorwise(
if K <= 4096 and N <= 1024:
return False
return True


def _populate_debug_fqns(model: nn.Module):
"""Populates the `_debug_fqn` attribute on each `Float8Linear` child of
`model`, useful for debugging. Note that this API is a prototype and may
change in the future.
"""
for name, mod in model.named_modules():
if isinstance(mod, Float8Linear):
mod._debug_fqn = name
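
Putting it together, a rough end-to-end sketch of how this prototype flow is expected to be used, modeled on test_debug_logging above; the model, recipe, callback, and config import path are illustrative assumptions.

import torch
import torch.nn as nn
from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_linear_utils import (
    _populate_debug_fqns,
    convert_to_float8_training,
)

logs = {}

def debug_logging_fn(fqn, gemm_name, a_hp, b_hp, a_fp8, b_fp8):
    # record only shapes here; the tensors are left on the GPU
    logs[(fqn, gemm_name)] = (tuple(a_hp.shape), tuple(b_hp.shape))

config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)
object.__setattr__(config, "_debug_logging_fn", debug_logging_fn)

m = nn.Sequential(
    nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
)
m = convert_to_float8_training(m, config=config)
_populate_debug_fqns(m)  # sets Float8Linear._debug_fqn to each module's name

x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
# expected keys, e.g.: ("0", "output"), ("0", "grad_input"), ("0", "grad_weight")
print(sorted(logs.keys()))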