float8_linear_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Callable, Optional

import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import Float8Linear

log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())


def swap_linear_layers(
    module: nn.Module,
    from_float_func: Callable[[nn.Linear], nn.Linear],
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
"""
Generic function to swap linear layers in a module with a new type of linear layer.
Note:
If applied to a root-level nn.Linear, the module will not be modified in place
and returned instead
Args:
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the module instance, and the FQN.
Returns:
nn.Module: The modified module with swapped linear layers.
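
    Example:
        An illustrative sketch; `to_bfloat16_linear` below is a hypothetical
        converter used only to show the expected `from_float_func` signature::

            import torch
            import torch.nn as nn

            def to_bfloat16_linear(lin: nn.Linear) -> nn.Linear:
                # cast the existing layer's parameters; any callable that
                # takes an nn.Linear and returns an nn.Linear works here
                return lin.to(torch.bfloat16)

            model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
            model = swap_linear_layers(model, to_bfloat16_linear)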
"""
    if isinstance(module, nn.Linear) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Linear with children: {module}"
            )
        return from_float_func(
            module,
        )

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"
            post_order_traversal(child_module, new_fqn, module)
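        # After visiting all children, swap this module itself if it is a
        # matching nn.Linear, reattaching the replacement to its parent.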
        if isinstance(module, nn.Linear) and (
            module_filter_fn is None or module_filter_fn(module, cur_fqn)
        ):
            assert (
                parent_module is not None
            ), f"Linear root module should return early: {module}"
            new_linear_module = from_float_func(module)
            cur_module_name = cur_fqn.split(".")[-1]
            setattr(parent_module, cur_module_name, new_linear_module)

    post_order_traversal(root_module)
    return root_module


def convert_to_float8_training(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
    config: Optional[Float8LinearConfig] = None,
) -> nn.Module:
"""
Swaps `torch.nn.Linear` in `module` with `Float8Linear`.
Args:
module: Module to modify.
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the module instance and the FQN.
config (Float8LinearConfig): configuration for conversion to float8
Returns:
nn.Module: The modified module with swapped linear layers.
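
    Example:
        A minimal sketch, assuming a PyTorch build and hardware that support
        float8 training (the model and filter below are illustrative only)::

            import torch.nn as nn

            model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 8))
            # swap every nn.Linear except the last one (FQN "2" in this Sequential)
            model = convert_to_float8_training(
                model,
                module_filter_fn=lambda mod, fqn: fqn != "2",
            )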
"""
    if config is None:
        config = Float8LinearConfig()

    from_float = lambda m: Float8Linear.from_float(
        m,
        config=config,
    )
    return swap_linear_layers(
        module,
        from_float,
        module_filter_fn=module_filter_fn,
    )