forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
297 lines (251 loc) · 9.81 KB
/
inference.py
File metadata and controls
297 lines (251 loc) · 9.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Defines an nn module designed to be used during inference
"""
import math
from typing import List, NamedTuple, Optional, Tuple, Union
import torch
from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
from torchao.float8.types import FP8Granularity
from torchao.quantization.granularity import (
PerBlock,
PerRow,
PerTensor,
)
from torchao.utils import (
is_MI300,
is_sm_at_least_89,
)
# Short alias for torch.Tensor, used in the type annotations of this module.
Tensor = torch.Tensor
class Float8MMConfig(NamedTuple):
    """
    Immutable bundle of options controlling how a float8 scaled_mm is run.

    Every field defaults to False, so ``Float8MMConfig()`` is a valid
    "plain" configuration.

    Attributes:
        emulate (bool): When True, emulate the matmuls in fp32.
        use_fast_accum (bool): When True, request the fast-accumulation
            option of scaled_mm.
        pad_inner_dim (bool): When True, pad the inner dimension of both
            operands with 0s. This is needed for matmuls whose inner
            dimension is not aligned to 16.
    """

    emulate: bool = False
    use_fast_accum: bool = False
    pad_inner_dim: bool = False
def preprocess_data(
    a_data: Tensor,
    b_data: Tensor,
    scaled_mm_config: Float8MMConfig,
) -> Tuple[Tensor, Tensor]:
    """Prepare the raw fp8 data tensors of both operands for addmm.

    Args:
        a_data: Input tensor A (left operand).
        b_data: Input tensor B (right operand).
        scaled_mm_config: Configuration for _scaled_mm; only
            ``pad_inner_dim`` is consulted here.

    Returns:
        The (possibly padded and re-laid-out) tensors A and B in the format
        for _scaled_mm.

    Raises:
        AssertionError: If padding is requested and the inner dimensions of
            A and B differ.
    """
    lhs, rhs = a_data, b_data

    if scaled_mm_config.pad_inner_dim:
        lhs_inner = lhs.size(1)
        rhs_inner = rhs.size(0)
        assert lhs_inner == rhs_inner, (
            f"Inner dims must match for mm, got {lhs_inner} and {rhs_inner}"
        )
        # Pad the shared (inner) dimension on both operands.
        lhs = pad_tensor_for_matmul(lhs, dims=1)
        rhs = pad_tensor_for_matmul(rhs, dims=0)

    # A is made contiguous whenever its strides are not row-major.
    if not is_row_major(lhs.stride()):
        lhs = lhs.contiguous()

    # B gets the opposite treatment: when it *is* row-major, its memory is
    # rewritten through transpose -> contiguous -> transpose.
    if is_row_major(rhs.stride()):
        rhs = rhs.t().contiguous().t()

    return lhs, rhs
def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]):
    """Reshape a scale tensor into the 2D layout handed to _scaled_mm.

    Args:
        input_scale: Scale tensor to reformat.
        input_shape: Shape of the data the scale belongs to. It is accepted
            for interface compatibility and is not read by this function.

    Returns:
        A ``(1, 1)`` tensor when the scale holds a single element (the
        PerTensor case); otherwise the scale with a trailing unit dimension
        appended and all leading dimensions collapsed into one.
    """
    # Single-element scale: just give it the (1, 1) shape.
    if input_scale.numel() == 1:
        return input_scale.reshape(1, 1)

    # Per-row/block scale: append a trailing dimension of size 1 ...
    expanded = input_scale.unsqueeze(-1)
    if expanded.dim() <= 2:
        return expanded

    # ... and flatten the leading dimensions, mirroring
    # input_data.reshape(-1, input_data.shape[-1]).
    return expanded.reshape(-1, expanded.shape[-1])
def addmm_float8_unwrapped_inference(
    a_data: Tensor,
    a_scale: Tensor,
    b_data: Tensor,
    b_scale: Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    use_fast_accum: bool = False,
) -> Tensor:
    """
    Run a scaled float8 matmul (with optional bias) on plain tensors.

    This is the unwrapped version of addmm_float8: it takes the raw data and
    scale tensors instead of Float8TrainingTensors, so the same logic can be
    shared between subclassed and non subclassed versions of the linear module.

    Args:
        a_data: Data of the left operand.
        a_scale: Scale passed to _scaled_mm as ``scale_a``.
        b_data: Data of the right operand.
        b_scale: Scale passed to _scaled_mm as ``scale_b``.
        output_dtype: Dtype requested for the result (``out_dtype``).
        output_scale: Optional scale passed as ``scale_result``.
        bias: Optional bias added to the product.
        use_fast_accum: Forwarded to _scaled_mm unchanged.

    Returns:
        The result of ``torch._scaled_mm``, including the bias when given.
    """
    # Arguments that are identical for both call shapes below.
    mm_kwargs = dict(
        scale_a=a_scale,
        scale_b=b_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
        use_fast_accum=use_fast_accum,
    )

    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32, so the
        # matmul runs without it and the bias is added afterwards.
        return torch._scaled_mm(a_data, b_data, **mm_kwargs) + bias

    return torch._scaled_mm(a_data, b_data, bias=bias, **mm_kwargs)
def _slice_scale_for_dimension(
    scale: torch.Tensor,
    data_shape: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
) -> torch.Tensor:
    """
    Slice the scale tensor so it stays aligned with a slice of the data tensor.

    The block size along each dimension is reconstructed as
    ``data_shape[i] // scale.shape[i]``. Along a blocked dimension the
    requested data range is widened to whole blocks: the start is rounded
    down and the end is rounded up to a block boundary.

    Args:
        scale: Scale tensor belonging to the data tensor.
        data_shape: Shape of the data tensor before slicing.
        dim: Dimension along which the data tensor is sliced.
        start: Start index of the data slice (may be None or negative).
        end: End index of the data slice (may be None or negative).
        step: Step of the data slice.

    Returns:
        The sliced scale, or ``scale`` itself when ``dim`` lies beyond the
        dimensions described by ``data_shape``.

    Raises:
        NotImplementedError: If ``step > 1`` along a blocked dimension.
    """
    aten = torch.ops.aten

    # One scale per data element: slice the scale exactly like the data.
    # Compare as tuples so a list `data_shape` can match a torch.Size.
    if tuple(scale.shape) == tuple(data_shape):
        return aten.slice.Tensor(scale, dim, start, end, step)

    # Reconstruct block sizes based on data shape and scale shape.
    block_sizes = tuple(
        size // scale.shape[i] for i, size in enumerate(data_shape)
    )

    if dim >= len(block_sizes):
        # Slicing beyond the dimensions we care about.
        return scale

    block_size_for_dim = block_sizes[dim]
    if block_size_for_dim == 1:
        # Scale is per-element along this dimension; slice away as normal.
        return aten.slice.Tensor(scale, dim, start, end, step)

    # There is blocking in this dimension.
    if step > 1:
        raise NotImplementedError(
            "Slicing with step > 1 is not implemented for scale tensors."
        )

    dim_size = data_shape[dim]

    scale_start = None
    if start is not None:
        if start < 0:
            # Resolve a negative index against the data extent first.
            start = max(start + dim_size, 0)
        scale_start = start // block_size_for_dim

    scale_end = None
    if end is not None:
        if end < 0:
            # Without this, an end in (-block, 0) would ceil-divide to 0,
            # which slice() reads as "stop at 0" and yields an empty scale.
            end = max(end + dim_size, 0)
        scale_end = (end + block_size_for_dim - 1) // block_size_for_dim

    return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
def _is_rowwise_scaled(x: torch.Tensor) -> bool:
    """Report whether a quantized tensor is rowwise scaled.

    The tensor counts as rowwise scaled when its ``block_size`` is 1 for
    every dimension except the last, and equals the full size of the last
    dimension.

    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    rowwise_block = [1] * (x.dim() - 1)
    rowwise_block.append(x.shape[-1])
    return list(x.block_size) == rowwise_block
def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
    """Report whether a quantized tensor is tensorwise scaled.

    The tensor counts as tensorwise scaled when, for every dimension, the
    corresponding ``block_size`` entry is either -1 or equal to the size of
    that dimension.

    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    for axis in range(x.ndim):
        extent = x.block_size[axis]
        # A dimension disqualifies the tensor as soon as its block neither
        # is -1 nor spans the whole dimension.
        if extent != -1 and extent != x.shape[axis]:
            return False
    return True
def _is_1_128_scaled(x: torch.Tensor) -> bool:
    """Checks if a quantized tensor is scaled with a block size of 1x128

    The block size must have at least two entries, every entry except the
    last must be exactly 1, and the last entry must be 128.

    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    b = x.block_size
    # Check each leading entry individually rather than their product: a
    # product of 1 is also reached by an even number of -1 entries, which
    # is not a 1x128 layout.
    return len(b) >= 2 and all(size == 1 for size in b[:-1]) and b[-1] == 128
def _is_128_128_scaled(x: torch.Tensor) -> bool:
    """Report whether a quantized tensor is scaled with a block size of 128x128.

    Only a two-entry ``block_size`` whose entries are both 128 qualifies.

    Args:
        x: quantized tensor (should have `block_size` attribute)
    """
    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
    dims = x.block_size
    if len(dims) != 2:
        return False
    first, second = dims[0], dims[1]
    return first == 128 and second == 128
def _granularity_is_a_1_128_w_128_128(
    g: Union[
        FP8Granularity,
        Tuple[FP8Granularity, FP8Granularity],
        list[FP8Granularity],
    ],
) -> bool:
    """Report whether ``g`` is the pair (PerBlock([1, 128]), PerBlock([128, 128])).

    Args:
        g: Granularity specification; it must support ``len`` and indexing.

    Returns:
        True only when ``g`` has exactly two entries, the first equal to
        ``PerBlock([1, 128])`` and the second equal to ``PerBlock([128, 128])``.
    """
    if len(g) != 2:
        return False
    first, second = g[0], g[1]
    return first == PerBlock([1, 128]) and second == PerBlock([128, 128])
def _normalize_granularity(
    granularity: Optional[
        Union[
            FP8Granularity,
            Tuple[FP8Granularity, FP8Granularity],
            list[FP8Granularity],
        ]
    ],
) -> Tuple[FP8Granularity, FP8Granularity]:
    """Turn a granularity specification into a 2-tuple of granularities.

    Accepted inputs:
        * ``None`` -> ``(PerTensor(), PerTensor())``
        * a single ``PerTensor`` or ``PerRow`` -> that instance, duplicated
        * a tuple/list of length 2 that is PerTensor/PerTensor,
          PerRow/PerRow, or PerBlock([1, 128])/PerBlock([128, 128])
          -> the same pair as a tuple

    Args:
        granularity: The specification to normalize.

    Returns:
        The normalized pair of granularities.

    Raises:
        ValueError: If the specification has an invalid form, uses an
            unsupported combination, or mixes different granularity types.
    """
    if granularity is None:
        return (PerTensor(), PerTensor())

    if isinstance(granularity, (PerTensor, PerRow)):
        return (granularity, granularity)

    if not (isinstance(granularity, (tuple, list)) and len(granularity) == 2):
        raise ValueError(f"Invalid granularity specification: {granularity}.")

    first, second = granularity[0], granularity[1]
    both_per_tensor = isinstance(first, PerTensor) and isinstance(second, PerTensor)
    both_per_row = isinstance(first, PerRow) and isinstance(second, PerRow)
    blockwise_pair = _granularity_is_a_1_128_w_128_128(granularity)

    if not (both_per_tensor or both_per_row or blockwise_pair):
        raise ValueError(f"Unsupported granularity types: {granularity}.")

    if not isinstance(first, type(second)):
        raise ValueError(
            f"Different granularities for activation and weight are not supported: {granularity}."
        )

    return tuple(granularity)
def _check_hardware_support(
    granularities: Tuple[FP8Granularity, FP8Granularity],
) -> None:
    """
    Verify that the current hardware can run the requested granularities.

    PerTensor/PerTensor and PerRow/PerRow pairs pass when either
    ``is_sm_at_least_89()`` or ``is_MI300()`` holds. The
    PerBlock([1, 128])/PerBlock([128, 128]) pair passes only when
    ``is_sm_at_least_89()`` holds.

    Args:
        granularities: Tuple of (activation_granularity, weight_granularity)

    Raises:
        AssertionError: If hardware doesn't support the requested granularity
        ValueError: If invalid granularity type is provided
    """
    first, second = granularities[0], granularities[1]
    both_per_tensor = isinstance(first, PerTensor) and isinstance(second, PerTensor)
    both_per_row = isinstance(first, PerRow) and isinstance(second, PerRow)
    blockwise_pair = _granularity_is_a_1_128_w_128_128(granularities)

    if both_per_tensor or both_per_row:
        assert is_sm_at_least_89() or is_MI300(), (
            "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
        )
        return

    if blockwise_pair:
        # TODO(future PR): look into AMD support
        assert is_sm_at_least_89(), (
            "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
        )
        return

    raise ValueError(f"Invalid granularities {granularities}.")