Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cleanup][1/x] make hp_tensor_to_float8_dynamic only work with hp inputs #1458

Merged
merged 4 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,16 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

assert self.scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
self.config.cast_config_input.target_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
if tensor_already_casted_to_fp8(input):
input_fp8 = input
else:
assert self.scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
self.config.cast_config_input.target_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
return input_fp8

def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
Expand Down
2 changes: 0 additions & 2 deletions torchao/float8/float8_scaling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def hp_tensor_to_float8_dynamic(
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
"""
if tensor_already_casted_to_fp8(hp_tensor):
return hp_tensor
scale = tensor_to_scale(
hp_tensor,
float8_dtype,
Expand Down
27 changes: 15 additions & 12 deletions torchao/float8/float8_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)

from torchao.float8.config import ScalingType, e4m3_dtype
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8BwDynamic,
hp_tensor_to_float8_dynamic,
Expand Down Expand Up @@ -46,12 +47,13 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)
if not tensor_already_casted_to_fp8(input_tensor):
input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
Expand Down Expand Up @@ -104,12 +106,13 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)
if not tensor_already_casted_to_fp8(input_tensor):
input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
mod.config.cast_config_input.target_dtype,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
Expand Down
4 changes: 3 additions & 1 deletion torchao/float8/stateful_float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

if self.scaling_type_input is ScalingType.DELAYED:
if tensor_already_casted_to_fp8(input):
input_fp8 = input
elif self.scaling_type_input is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
input,
Expand Down
Loading