Skip to content

[not for land] float8 debug recipe for float8 rowwise fwd and hp bwd #2708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,9 @@ def test_linear_from_config_params(
@pytest.mark.parametrize(
"recipe_name",
[
Float8LinearRecipeName.ROWWISE,
Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
# Float8LinearRecipeName.ROWWISE,
# Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
Float8LinearRecipeName.FWD_ROWWISE_GI_ROWWISE_GW_HP,
],
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
Expand Down
60 changes: 60 additions & 0 deletions torchao/float8/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ class Float8LinearRecipeName(enum.Enum):
# * the e4m3 dtype is used across the board, including for gradients
ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"

# debug only, not for land
FWD_FLOAT8_BWD_HP = "fwd_float8_bwd_hp"

# debug only, not for land
FWD_ROWWISE_GI_ROWWISE_GW_HP = "fwd_rowwise_gi_rowwise_gw_hp"


@dataclass(frozen=True)
class Float8LinearConfig:
Expand Down Expand Up @@ -336,5 +342,59 @@ def from_recipe_name(
round_scales_to_power_of_2=True,
)

elif recipe_name is Float8LinearRecipeName.FWD_FLOAT8_BWD_HP:
# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_hp @ weight_hp
cc_go = CastConfig(scaling_type=ScalingType.DISABLED)
cc_w_gi = CastConfig(scaling_type=ScalingType.DISABLED)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(
scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
)

return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
cast_config_input_for_grad_weight=cc_i_gw,
cast_config_weight_for_grad_input=cc_w_gi,
cast_config_grad_output_for_grad_weight=cc_go_gw,
round_scales_to_power_of_2=True,
)

elif recipe_name is Float8LinearRecipeName.FWD_ROWWISE_GI_ROWWISE_GW_HP:
# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_axiswise_dim1
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w_gi = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(
scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
)

return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
cast_config_input_for_grad_weight=cc_i_gw,
cast_config_weight_for_grad_input=cc_w_gi,
cast_config_grad_output_for_grad_weight=cc_go_gw,
round_scales_to_power_of_2=True,
)

else:
raise AssertionError(f"unknown recipe_name {recipe_name}")
Loading