forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
54 lines (51 loc) · 1.35 KB
/
__init__.py
File metadata and controls
54 lines (51 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Lets define a few top level things here
# Needed to load Float8TrainingTensor with weights_only = True
from torch.serialization import add_safe_globals
from torchao.float8.config import (
CastConfig,
Float8GemmConfig,
Float8LinearConfig,
Float8LinearRecipeName,
ScalingGranularity,
ScalingType,
)
from torchao.float8.float8_linear_utils import (
_auto_filter_for_recipe,
convert_to_float8_training,
)
from torchao.float8.float8_training_tensor import (
Float8TrainingTensor,
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from torchao.float8.inference import Float8MMConfig
from torchao.float8.types import FP8Granularity
# Register the float8 tensor/config classes as safe for deserialization so
# checkpoints containing them load under torch.load(..., weights_only=True).
add_safe_globals([
    Float8TrainingTensor,
    ScaledMMConfig,
    GemmInputRole,
    LinearMMConfig,
    Float8MMConfig,
    ScalingGranularity,
])
# Public API of this module. Fix: "ScalingGranularity" was listed twice
# (once in the configuration group and again after "CastConfig").
__all__ = [
    # configuration
    "ScalingType",
    "ScalingGranularity",
    "Float8GemmConfig",
    "Float8LinearConfig",
    "Float8LinearRecipeName",
    "CastConfig",
    # top level UX
    "convert_to_float8_training",
    "precompute_float8_dynamic_scale_for_fsdp",
    "_auto_filter_for_recipe",
    # types
    "FP8Granularity",
    # note: Float8TrainingTensor and Float8Linear are not public APIs
]