
Commit 2dc8fd7

add axiswise scaling to Float8Linear

Summary:

This PR adds support for axiswise scaling of all arguments of all gemms, and ensures that training with axiswise scaling works end to end. A future PR will add more granular configurability, optimize performance, and add docs.

Test Plan:

```
// tests pass
./test/float8/test_everything.sh

// sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8:
// 1. verify that performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane; performance is not there yet
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 10553fb
ghstack-comment-id: 2368837904
Pull Request resolved: #920

1 parent 0224576 · commit 2dc8fd7
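For reference, below is a minimal sketch of what the new axiswise configuration looks like from user code. It is based only on the config surface exercised in the diffs below (`CastConfig`, `Float8LinearConfig`, `ScalingType`, `ScalingGranularity`, `convert_to_float8_training`); exact defaults and keyword handling may differ from the final API.

```python
# Minimal sketch (not the canonical recipe): enable all-axiswise, all-dynamic
# scaling, the combination smoke-tested in this PR.
import torch
import torch.nn as nn

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    ScalingGranularity,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.AXISWISE,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)

# swap nn.Linear modules for Float8Linear with the config above
m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).to("cuda", torch.bfloat16)
convert_to_float8_training(m, config=config)
```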

File tree

9 files changed: +462 -55 lines changed


benchmarks/float8/bench_linear_float8.py

Lines changed: 27 additions & 5 deletions
```diff
@@ -14,7 +14,12 @@

 import torch
 import torch.utils.benchmark as benchmark
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
     linear_requires_sync,
@@ -107,35 +112,49 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")

     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)

     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )

     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
@@ -167,7 +186,7 @@ def main(
         copy.deepcopy(linear_ref),
         config=config,
     )
-    scaling_repr = linear_float8.scaling_repr()
+    scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"

     if fast_accum:
         linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -310,6 +329,7 @@ def invoke_main() -> None:
     parser.add_argument("--scaling_type_input", type=str, required=False)
     parser.add_argument("--scaling_type_weight", type=str, required=False)
     parser.add_argument("--scaling_type_grad_output", type=str, required=False)
+    parser.add_argument("--scaling_granularity", type=str, required=False)
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
     kwargs = {}
@@ -327,6 +347,8 @@ def invoke_main() -> None:
         kwargs["scaling_type_weight"] = args.scaling_type_weight
     if args.scaling_type_grad_output is not None:
         kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
+    if args.scaling_granularity is not None:
+        kwargs["scaling_granularity"] = args.scaling_granularity
     main(
         output_path,
         not args.disable_compile,
```
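The benchmark now accepts a `--scaling_granularity` flag. Below is a short sketch of the string-to-enum-to-config plumbing it performs, mirroring the diff above; the lowercase strings are assumed to be the enum values, as implied by the `"tensorwise"` default.

```python
# Sketch of how the new --scaling_granularity string is consumed,
# mirroring the plumbing in the diff above.
from torchao.float8.config import CastConfig, ScalingGranularity, ScalingType

granularity = ScalingGranularity("axiswise")  # or "tensorwise" (the default)
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=granularity,
)
```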

benchmarks/float8/bench_matmul.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -13,6 +13,8 @@
 import torch.nn as nn
 import torch.utils.benchmark as benchmark

+from torchao.float8.config import ScalingGranularity
+
 from utils import (
     get_name_to_shapes_iter,
     profiler_output_to_filtered_time_by_kernel_name,
@@ -75,6 +77,7 @@ def run(
     K: Optional[int] = None,
     N: Optional[int] = None,
     use_gpu_kernel_time: bool = False,
+    scaling_granularity: str = "tensorwise",
 ):
     device = "cuda"

@@ -84,6 +87,7 @@ def run(
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     fast_accum_vals = [True, False]
+    scaling_granularity = ScalingGranularity(scaling_granularity)

     for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
@@ -109,8 +113,13 @@ def run(
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
         A = torch.zeros(M, K, device=device, dtype=d1)
         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-        scale_a = torch.tensor([1.0], device=device)
-        scale_b = torch.tensor([1.0], device=device)
+        if scaling_granularity == ScalingGranularity.TENSORWISE:
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        else:
+            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)

         def do_matmul(A, B):
             nonlocal scale_a
```
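For context on the scale shapes above: with axiswise scaling of an `(M, K) @ (K, N)` matmul, `A` carries one scale per row and `B` one scale per column. The snippet below only illustrates how such scales are typically derived from an amax reduction along the contraction dimension; it is not the benchmark's code, which passes unit scales.

```python
# Illustration only (not the benchmark's code): deriving axiswise scales via
# an amax reduction along the contraction dimension.
import torch

M, K, N = 1024, 2048, 4096
A_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B_hp = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
scale_a = f8_max / A_hp.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)  # shape (M, 1)
scale_b = f8_max / B_hp.abs().amax(dim=0, keepdim=True).clamp(min=1e-12)  # shape (1, N)
```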

benchmarks/float8/profile_linear_float8.py

Lines changed: 23 additions & 4 deletions
```diff
@@ -22,7 +22,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
+    ScalingGranularity,
+)
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -252,6 +257,7 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
+    scaling_granularity: str = "tensorwise",
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -263,28 +269,41 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+    scaling_granularity = ScalingGranularity(scaling_granularity)

     if scaling_type_input is ScalingType.STATIC:
         cast_config_input=CastConfig(
             scaling_type=scaling_type_input,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_input=CastConfig(scaling_type=scaling_type_input)
+        cast_config_input=CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight=CastConfig(
             scaling_type=scaling_type_weight,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
        )
     else:
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight=CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output=CastConfig(
             scaling_type=scaling_type_grad_output,
             static_scale=torch.tensor([1.0], device="cuda"),
+            scaling_granularity=scaling_granularity,
         )
     else:
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output=CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )

     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
```

test/float8/test_base.py

Lines changed: 30 additions & 3 deletions
```diff
@@ -324,6 +324,10 @@ def _test_linear_impl(
     "scaling_type_grad_output",
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
+@pytest.mark.parametrize(
+    "scaling_granularity",
+    [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
+)
 @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
 @pytest.mark.parametrize("linear_bias", [False, True])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -334,33 +338,56 @@ def test_linear(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    scaling_granularity: ScalingGranularity,
     linear_dtype: torch.dtype,
     linear_bias: bool,
 ):
+    if scaling_granularity is ScalingGranularity.AXISWISE:
+        if (
+            scaling_type_input != ScalingType.DYNAMIC or
+            scaling_type_weight != ScalingType.DYNAMIC or
+            scaling_type_grad_output != ScalingType.DYNAMIC or
+            linear_dtype != torch.bfloat16 or
+            (not is_cuda_9_0)
+        ):
+            pytest.skip()
+
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
     m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

     if scaling_type_input is ScalingType.STATIC:
         cast_config_input = CastConfig(
             scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
             static_scale=torch.tensor([1.0], device="cuda"),
         )
     else:
-        cast_config_input = CastConfig(scaling_type=scaling_type_input)
+        cast_config_input = CastConfig(
+            scaling_type=scaling_type_input,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_weight is ScalingType.STATIC:
         cast_config_weight = CastConfig(
             scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
             static_scale=torch.tensor([1.0], device="cuda"),
         )
     else:
-        cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
+        cast_config_weight = CastConfig(
+            scaling_type=scaling_type_weight,
+            scaling_granularity=scaling_granularity,
+        )
     if scaling_type_grad_output is ScalingType.STATIC:
         cast_config_grad_output = CastConfig(
             scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
             static_scale=torch.tensor([1.0], device="cuda"),
         )
     else:
-        cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
+        cast_config_grad_output = CastConfig(
+            scaling_type=scaling_type_grad_output,
+            scaling_granularity=scaling_granularity,
+        )

     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
```
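Note that the axiswise cases above only run when all three casts are dynamic, the linear dtype is bfloat16, and the GPU is SM 9.0 (H100) or newer. `is_cuda_9_0` is a helper defined elsewhere in the test file and not shown in this diff; a typical definition, given here as an assumption, would be:

```python
# Assumed definition of the is_cuda_9_0 helper referenced above; the actual
# helper in test_base.py may differ.
import torch

is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
```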
