diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 8fb3921f6..a14981732 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -22,7 +22,12 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingGranularity,
+    ScalingType,
+)
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
@@ -30,6 +35,7 @@
     sync_float8_amax_and_scale_history,
 )
 from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -51,6 +57,7 @@
 
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
@@ -58,7 +65,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     return True
 
 
-class TestFloat8Tensor(unittest.TestCase):
+class TestFloat8Tensor:
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
             x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
-            self.assertTrue(x3_hp.dtype == hp_dtype)
+            assert x3_hp.dtype == hp_dtype
 
     def test_differentiable_casts(self) -> None:
         lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -103,7 +110,7 @@ def test_index_put(self):
         fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
         fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
 
-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             b[index] = fp8_a
             fp8_b[index] = a
             fp8_b_bad[index] = fp8_a
@@ -117,7 +124,7 @@ def test_copy_(self):
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
         torch.testing.assert_close(b, fp8_a.to_original_precision())
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             fp8_a.copy_(b)  # Should fail
 
         fp8_b = Float8Tensor(
@@ -129,6 +136,105 @@ def test_copy_(self):
         fp8_b.copy_(fp8_a)
         torch.testing.assert_close(fp8_a._data, fp8_b._data)
 
+    @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
+    @pytest.mark.parametrize("axiswise_dim", [0, -1])
+    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+        a = torch.randn(*shape, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=axiswise_dim,
+        )
+        a_dq = a_fp8.to_original_precision()
+        sqnr = compute_error(a, a_dq)
+        assert sqnr >= 25.0
+
+    def test_axiswise_reshape(self):
+        a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+
+        # if we scale across dim0, we can only reshape to [3, -1]
+        a_fp8_d0 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=0,
+        )
+        assert list(a_fp8_d0._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d0._scale.shape) == [1, 5, 7]
+
+        a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
+        assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
+        assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d0.to_original_precision(),
+            a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)
+
+        # if we scale across dim2, we can only reshape to [-1, 7]
+        a_fp8_d2 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        assert list(a_fp8_d2._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d2._scale.shape) == [3, 5, 1]
+
+        a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
+        assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
+        assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d2.to_original_precision(),
+            a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
+
+    @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    def test_axiswise_gemm(self, a_shape):
+        a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
+        b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
+
+        linear_mm_config = LinearMMConfig()
+
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.INPUT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        a_fp8 = a_fp8.reshape(-1, a_shape[-1])
+        b_fp8 = hp_tensor_to_float8_dynamic(
+            b,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,  # will be transposed
+        )
+        c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
+        a = a.reshape(-1, a_shape[-1])
+        c_ref = torch.mm(a, b.t())
+        sqnr = compute_error(c_ref, c_fp8_compute)
+        assert sqnr >= 25.0
 
 
 class TestFloat8Linear:
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index 0fa25b9bb..16e638738 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -26,6 +26,18 @@ def short_str(self):
         return "sta"
 
 
+class ScalingGranularity(enum.Enum):
+    """
+    Defines the granularity of scaling strategies for casting to float8
+    """
+
+    # A single scaling factor for the entire tensor
+    TENSORWISE = "tensorwise"
+    # Scaling factors computed along one axis of the tensor, reducing it to
+    # size 1.
+    AXISWISE = "axiswise"
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """
@@ -146,6 +158,8 @@ class Float8LinearConfig:
     # save the fp8_weight_transpose for backward, which is an un-sharded weight and costs a high memory utilization.
     # The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
     # For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.
+    # TODO(future PR): either enable by default or have a warning and set up the
+    # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index dd9255625..4558695e3 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -9,7 +9,6 @@
 
 import dataclasses
 import enum
-import logging
 
 from typing import Optional
 
@@ -50,8 +49,6 @@
     WeightWithStaticFloat8CastTensor,
 )
 
-logger = logging.getLogger(__name__)
-
 
 # this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
 @torch._dynamo.allow_in_graph
@@ -191,15 +188,6 @@ def __init__(self, *args, **kwargs):
         # would be initialized in every iteration.
         self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward
 
-        # See the comments in config.py for more details of this option.
-        if (
-            self.config.enable_pre_and_post_forward
-            and not self.config.force_recompute_fp8_weight_in_bwd
-        ):
-            logger.warning(
-                "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
-            )
-
     def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
index f8115649b..1bf9faaa4 100644
--- a/torchao/float8/float8_ops.py
+++ b/torchao/float8/float8_ops.py
@@ -19,6 +19,15 @@
 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
 
 
+def _assert_tensorwise_scale(aten_op, scale):
+    assert (
+        # TODO(future PR): figure out why tensorwise scaling can have
+        # both rank 0 and rank 1
+        len(scale.shape)
+        in (0, 1)
+    ), f"{aten_op} with axiswise scaling is not supported yet"
+
+
 def implements(aten_ops):
     """Register aten ops to the float8 op table"""
 
@@ -45,6 +54,7 @@ def decorator(func):
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
         new_data,
@@ -55,10 +65,82 @@ def float8_desugar_op(aten_op, args, kwargs=None):
     )
 
 
+@implements(
+    [
+        aten.t.default,
+        aten.transpose.int,
+    ]
+)
+def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
+
+    if aten_op == aten.transpose.int:
+        _assert_tensorwise_scale(aten_op, args[0]._scale)
+
+    old_axiswise_dim = args[0]._axiswise_dim
+    new_axiswise_dim = old_axiswise_dim
+    if old_axiswise_dim is not None:
+        # transposing swaps which axis the scale was reduced across
+        if old_axiswise_dim == 0:
+            new_axiswise_dim = -1
+        else:
+            new_axiswise_dim = 0
+
+    return Float8Tensor(
+        new_data,
+        new_scale,
+        args[0]._orig_dtype,
+        args[0]._linear_mm_config,
+        args[0]._gemm_input_role,
+        new_axiswise_dim,
+    )
+
+
+@implements([aten.view.default])
+def float8_view(aten_op, args, kwargs=None):
+    if len(args[0]._scale.shape) < 2:
+        # tensorwise scaling
+        return float8_desugar_op(aten_op, args, kwargs)
+
+    t, new_shape = args[0], args[1]
+    # for now, only support reshaping to [-1, dim] or [dim, -1]
+    axiswise_dim = t._axiswise_dim
+    if len(new_shape) == 2:
+
+        if axiswise_dim == 0:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale_shape = [1, new_shape[-1]]
+            new_scale = aten_op(t._scale, new_scale_shape, **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+                t._axiswise_dim,
+            )
+        elif axiswise_dim == -1 or axiswise_dim == (len(t.shape) - 1):
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale_shape = [new_shape[0], 1]
+            new_scale = aten_op(t._scale, new_scale_shape, **kwargs)
+            new_axiswise_dim = -1
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+                new_axiswise_dim,
+            )
+    raise AssertionError(
+        f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} t._axiswise_dim {t._axiswise_dim} new_shape {new_shape} is not supported yet."
+    )
+
+
 @implements([aten.split.Tensor])
 def float8_split(aten_op, args, kwargs=None):
     new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)
-
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     def make_float8(data):
         return Float8Tensor(
             data,
@@ -102,6 +184,7 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._gemm_input_role is gemm_input_role
         ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
+        _assert_tensorwise_scale(aten_op, chunk._scale)
         chunk_data.append(chunk._data.view(torch.uint8))
 
     new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -118,6 +201,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
         "addmm" -> out
         "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
 
     def unwrap(x):
         if isinstance(x, Float8Tensor):
@@ -230,6 +314,7 @@ def float8_addmm(aten_op, args, kwargs=None):
 
 @implements([aten.is_same_size.default])
 def float8_is_same_size(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     return args[0].shape == args[1].shape
 
 
@@ -239,6 +324,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     when the input is a Float8Tensor, presenting as a fp32
     tensor.
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8Tensor)
     assert (
         len(kwargs) == 1 and "dtype" in kwargs
@@ -266,6 +352,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     """
     override funcol with FP8 handling
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(
         fp8_input, Float8Tensor
@@ -285,6 +372,7 @@
 
 @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
 def wait_tensor_fp8(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(fp8_input, Float8Tensor)
 
@@ -305,6 +393,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values = args[2]
     assert isinstance(fp8_self, Float8Tensor)
     assert isinstance(fp8_values, Float8Tensor)
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert fp8_self._scale == fp8_values._scale
     assert fp8_self.dtype == fp8_values.dtype
     assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -335,8 +424,10 @@ def copy_fp8(aten_op, args, kwargs=None):
 
     if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
         src_hp = src.to_original_precision()
+        _assert_tensorwise_scale(aten_op, src._scale)
         return aten_op(self, src_hp, *args[2:], **kwargs)
     elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        _assert_tensorwise_scale(aten_op, src._scale)
         assert (
             self._orig_dtype == src._orig_dtype
         ), "Expecting both Float8Tensors to be of the same dtype"
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index e9e195176..3207c0c9f 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -12,6 +12,8 @@
 
 import torch
 
+from torchao.float8.config import ScalingGranularity
+
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -37,6 +39,8 @@ def hp_tensor_to_float8_dynamic(
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
     device_mesh = None,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -50,16 +54,26 @@ def hp_tensor_to_float8_dynamic(
         reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
         gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
           the 3 fwd/bwd gemms of linear
+        scaling_granularity: Defines the scaling granularity
+        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
     """
     if tensor_already_casted_to_fp8(hp_tensor):
         return hp_tensor
-    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax, device_mesh)
+    scale = tensor_to_scale(
+        hp_tensor,
+        float8_dtype,
+        reduce_amax,
+        device_mesh,
+        scaling_granularity,
+        axiswise_dim,
+    )
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
         float8_dtype,
         linear_mm_config,
         gemm_input_role,
+        axiswise_dim,
     )
diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py
index 8927cf4e7..eb1030df8 100644
--- a/torchao/float8/float8_tensor.py
+++ b/torchao/float8/float8_tensor.py
@@ -152,6 +152,7 @@ def forward(
         float8_dtype=e4m3_dtype,
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+        axiswise_dim: Optional[int] = None,
     ):
         """
         This function will apply the scaling, and then convert to a Float8Tensor
@@ -183,6 +184,7 @@ def forward(
                 tensor.dtype,
                 linear_mm_config=linear_mm_config,
                 gemm_input_role=gemm_input_role,
+                axiswise_dim=axiswise_dim,
             )
             return DTensor.from_local(
                 inner_float8_tensor,
@@ -199,6 +201,7 @@ def forward(
             tensor.dtype,
             linear_mm_config=linear_mm_config,
             gemm_input_role=gemm_input_role,
+            axiswise_dim=axiswise_dim,
         )
 
     @staticmethod
@@ -229,6 +232,7 @@ def hp_tensor_and_scale_to_float8(
     float8_dtype=e4m3_dtype,
     linear_mm_config: Optional[LinearMMConfig] = None,
     gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+    axiswise_dim: Optional[int] = None,
 ):
     """
     Given a high precision tensor `hp_tensor` and a precalculated scale `s`,
@@ -245,9 +249,10 @@ def hp_tensor_and_scale_to_float8(
           the 3 fwd/bwd gemms of linear
         gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
           the 3 fwd/bwd gemms of linear
+        axiswise_dim: for rowwise scaling, contains the axis scaled across
     """
     return _ToFloat8ConstrFunc.apply(
-        hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role
+        hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role, axiswise_dim
     )
 
 
@@ -261,11 +266,19 @@ class Float8Tensor(torch.Tensor):
     * `_data`: the underlying e4m3 or e5m2 data
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
       by scale to go from fp32 range to fp8 range, and divide by scale to go
-      from fp8 range to fp32 range.
+      from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible
+      with `_data`. For example:
+      - if scaling is tensorwise, `_scale` is a scalar tensor
+      - if scaling is axiswise and _data.shape is [3, 5], `_scale` could have
+        shape [1, 5] or [3, 1]. `axiswise_dim` defines the scaling axis.
+      - if scaling is axiswise and _data.shape is [2, 3, 5], `_scale` could have
+        shape [1, 1, 5] or [2, 1, 1]. `axiswise_dim` defines the scaling
+        axis. Non-one entries which are not the first or last element are not
+        supported.
     * `_orig_dtype`: the original dtype of the tensor used to create this tensor.
-    * `_emulate`: if true using fp32 emulation for the matmuls, helpful
-      if you don't have access to h100 hardware.
+    * `_axiswise_dim`: for axiswise scaling only, contains the axis scaled
+      across. Only values of 0 or -1 are supported.
 
     Intended usage of this abstraction:
     1. to bundle raw data + fp8 metadata together for easy passing through
@@ -280,6 +293,7 @@ class Float8Tensor(torch.Tensor):
     _scale: torch.Tensor
     _orig_dtype: torch.dtype
     _linear_mm_config: LinearMMConfig
+    _axiswise_dim: Optional[int]
     __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"]
 
     def __new__(
@@ -289,13 +303,8 @@ def __new__(
         orig_dtype: torch.dtype,
         linear_mm_config: Optional[LinearMMConfig],
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+        axiswise_dim: Optional[int] = None,
     ):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )
-
         self = torch.Tensor._make_wrapper_subclass(
             cls,
             data.size(),
@@ -313,17 +322,20 @@ def __new__(
             linear_mm_config if linear_mm_config is not None else LinearMMConfig()
         )
        self._gemm_input_role = gemm_input_role
+        assert axiswise_dim in (None, 0, -1), f"unsupported axiswise_dim {axiswise_dim}"
+        self._axiswise_dim = axiswise_dim
 
         return self
 
     def __repr__(self):
-        return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"
+        return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"
 
     def __tensor_flatten__(self):
         ctx = {
             "_orig_dtype": self._orig_dtype,
             "_linear_mm_config": self._linear_mm_config,
             "_gemm_input_role": self._gemm_input_role,
+            "_axiswise_dim": self._axiswise_dim,
         }
         return ["_data", "_scale"], ctx
 
@@ -336,6 +348,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
             metadata["_orig_dtype"],
             metadata["_linear_mm_config"],
             metadata["_gemm_input_role"],
+            metadata["_axiswise_dim"],
         )
 
     def to_original_precision(self):
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 49a2a1152..b6f42c508 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 
 import torchao.float8.config as config
+from torchao.float8.config import ScalingGranularity
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -102,9 +103,18 @@ def amax_history_to_scale_stack(
 
 @torch.no_grad()
 def tensor_to_amax(
-    x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
+    x: torch.Tensor,
+    reduce_amax: bool = False,
+    device_mesh=None,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
-    amax = torch.max(torch.abs(x))
+    if scaling_granularity is ScalingGranularity.TENSORWISE:
+        amax = torch.max(torch.abs(x))
+    else:
+        assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
+        assert axiswise_dim is not None, "unsupported"
+        amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
@@ -122,8 +132,16 @@ def tensor_to_scale(
     float8_dtype: torch.dtype,
     reduce_amax: bool = False,
     device_mesh=None,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh)
+    amax = tensor_to_amax(
+        x,
+        reduce_amax,
+        device_mesh,
+        scaling_granularity,
+        axiswise_dim,
+    )
     return amax_to_scale(amax, float8_dtype, x.dtype)
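
A minimal sketch of the scaling math that the new tensor_to_amax / tensor_to_scale code paths implement, written in plain PyTorch rather than through torchao (illustrative only: the tensor names below are invented, and the clamping / zero-amax edge cases a production cast needs are omitted):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

x = torch.randn(16, 32, dtype=torch.bfloat16)

# tensorwise: a single abs-max over the whole tensor -> scalar scale
amax_tensorwise = torch.max(torch.abs(x))
scale_tensorwise = FP8_MAX / amax_tensorwise.float()

# axiswise: abs-max reduced along one dim with keepdim=True, so the
# resulting [16, 1] scale broadcasts directly against the data
amax_axiswise = torch.amax(torch.abs(x), dim=-1, keepdim=True)
scale_axiswise = FP8_MAX / amax_axiswise.float()

x_fp8 = (x.float() * scale_axiswise).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_dq = x_fp8.float() / scale_axiswise  # dequantize back to high precision
print(scale_tensorwise.shape, scale_axiswise.shape, x_dq.shape)

Keeping keepdim=True is what lets the per-axis scale broadcast against `_data` without reshaping, which is also why the new reshape test expects `_scale` shapes such as [1, 5, 7] and [3, 5, 1].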
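
For reference, a hedged end-to-end usage sketch of the API surface added in this diff, closely following test_axiswise_gemm above. It assumes a CUDA device with compute capability 9.0 or newer, as the new test does, and uses only names that appear in this diff:

import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig

a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
cfg = LinearMMConfig()

# one scale per row of `a`: scaled across the last dim
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    torch.float8_e4m3fn,
    cfg,
    gemm_input_role=GemmInputRole.INPUT,
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)
# `b` is also scaled across the last dim; b_fp8.t() below routes through the
# new aten.t.default handler, which transposes data and scale together
b_fp8 = hp_tensor_to_float8_dynamic(
    b,
    torch.float8_e4m3fn,
    cfg,
    gemm_input_role=GemmInputRole.WEIGHT,
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)
c = torch.mm(a_fp8, b_fp8.t())  # float8 gemm with axiswise scales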