module.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# This file is adapted from module.py in Megatron-LM

import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from domino.arguments import get_args
import domino.parallel_state as mpu

_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)


def param_is_not_shared(param):
    return not hasattr(param, 'shared') or not param.shared


class DominoModule(torch.nn.Module):
    """Extensions of torch.nn.Module for Domino models."""

    def __init__(self, config=None, share_embeddings_and_output_weights=True):
        super(DominoModule, self).__init__()
        self.config = config
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints.
        """
        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

    def initialize_word_embeddings(self):
        self.share_embeddings_and_output_weights = True
        return

    def shared_embedding_or_output_weight(self):
        # `pre_process` and `language_model` are expected to be set by the
        # concrete model that subclasses DominoModule.
        if self.pre_process:
            return self.language_model.embedding.word_embeddings.weight
        else:
            if not self.share_embeddings_and_output_weights:
                raise Exception('shared_embedding_or_output_weight() called for last '
                                'stage, but share_embeddings_and_output_weights is false')
            return self.word_embeddings.weight
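
# Subclass sketch (illustrative only; `ToyModel` and `ToyLM` are hypothetical names,
# not classes defined in this repository):
#
#   class ToyModel(DominoModule):
#       def __init__(self, config):
#           super().__init__(config=config, share_embeddings_and_output_weights=True)
#           self.pre_process = True
#           self.language_model = ToyLM(config)  # exposes embedding.word_embeddings.weight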


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
    is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn
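
# Illustration (a minimal sketch of the helper above): conversion_helper applies
# `conversion` element-wise while preserving tuple/list nesting, e.g.
#   conversion_helper((1.0, [2.0, 3.0]), lambda x: x * 2)  ->  (2.0, [4.0, 6.0])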


def fp32_to_float16(val):
    """Convert fp32 `val` to fp16/bf16."""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = val.half()
        return val
    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32."""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)
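
# Illustration (a minimal sketch of the converters above): both walk nested
# tuples/lists and cast only float tensors (or the .data of Parameters/Variables),
# leaving non-tensor items untouched, e.g.
#   t = torch.ones(2)                      # dtype torch.float32
#   half = fp32_to_float16((t, [t, 'x']))  # tensors become torch.float16, 'x' passes through
#   full = float16_to_fp32(half)           # tensors cast back to torch.float32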


class Float16Module(DominoModule):
    """Wrap a module so it runs in fp16, converting between fp32 and fp16
    at the pipeline boundaries.
    """

    def __init__(self, module, args):
        super(Float16Module, self).__init__()
        self.add_module('module', module.half())

    def set_input_tensor(self, input_tensor):
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, **kwargs):
        # Cast fp32 inputs to fp16 on the first pipeline stage and cast the
        # fp16 outputs back to fp32 on the last stage.
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Retrieve state_dict from the module being wrapped."""
        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
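
# Usage sketch (illustrative only; `build_model`, the `args.fp16` flag, and the
# forward-call signature are assumptions, not APIs defined in this file):
#
#   args = get_args()
#   model = build_model(args)                  # some DominoModule subclass
#   if getattr(args, 'fp16', False):
#       model = Float16Module(model, args)     # weights in fp16, fp32 I/O at pipeline edges
#   output = model(tokens, position_ids, attention_mask)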