Skip to content

Commit 3d57ee0

Browse files
authored
Simplify dtype and device settings (#57)
1 parent 8c49209 commit 3d57ee0

8 files changed

Lines changed: 36 additions & 51 deletions

File tree

gllm/layers/layernorm.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,12 @@ def __init__(
99
self,
1010
hidden_size: int,
1111
eps: float,
12-
dtype: torch.dtype = None,
1312
) -> None:
1413
super().__init__()
1514
self.variance_epsilon = eps
1615
self.variance_size_override = None
1716
self.hidden_size = hidden_size
18-
self.weight = nn.Parameter(torch.ones(
19-
hidden_size, dtype=dtype))
17+
self.weight = nn.Parameter(torch.ones(hidden_size))
2018
self.has_weight = True
2119

2220
def forward(

gllm/layers/rotary_embedding.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@ def __init__(
1515
max_position_embeddings: int,
1616
base: float,
1717
is_neox_style: bool,
18-
dtype,
1918
) -> None:
2019
super().__init__()
2120
self.head_size = head_size
@@ -26,7 +25,7 @@ def __init__(
2625

2726
cache = self._compute_cos_sin_cache()
2827

29-
cache = cache.to(dtype=dtype,device='cuda')
28+
cache = cache.to(dtype=torch.get_default_dtype())
3029
self.register_buffer("cos_sin_cache", cache, persistent=False)
3130

3231
def _compute_inv_freq(self, base):
@@ -84,14 +83,13 @@ def __init__(
8483
max_position_embeddings: int,
8584
base: int,
8685
is_neox_style: bool,
87-
scaling_factors: Union[List[float], float],
88-
dtype: torch.dtype,
86+
scaling_factors: Union[List[float], float]
8987
) -> None:
9088
if isinstance(scaling_factors, float):
9189
scaling_factors = [scaling_factors]
9290
self.scaling_factors: List[float] = scaling_factors # noqa
9391
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
94-
is_neox_style, dtype)
92+
is_neox_style)
9593
# Lazy initialized.
9694
self._scaling_factor_to_offset: Dict[float, int]
9795

@@ -142,7 +140,6 @@ def __init__(
142140
max_position_embeddings: int,
143141
base: int,
144142
is_neox_style: bool,
145-
dtype: torch.dtype,
146143
scaling_factor: float,
147144
low_freq_factor: float,
148145
high_freq_factor: float,
@@ -153,7 +150,7 @@ def __init__(
153150
self.high_freq_factor = high_freq_factor
154151
self.orig_max_position = orig_max_position
155152
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
156-
is_neox_style, dtype)
153+
is_neox_style)
157154

158155
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
159156
inv_freqs = super()._compute_inv_freq(base)

gllm/model_loader.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,6 @@ def __init__(self, load_format, model_path):
2020
self.model_path = model_path
2121
self.load_config()
2222
self.load_format = load_format
23-
24-
def get_dtype(self, dtype: str):
25-
if dtype == 'float16':
26-
return torch.float16
27-
elif dtype == 'bfloat16':
28-
return torch.bfloat16
29-
else:
30-
assert 0
3123

3224
def get_finish_tokens(self):
3325
return self.get_model_type().get_finish_tokens(self.config)

gllm/model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ def init(self, mp_load_progress=None):
3939
memory_manager_cls = PrefixMemoryManager if self.enable_prefix_caching else MemoryManager
4040
self.memory_manager = memory_manager_cls(
4141
gpu_memory_util=self.gpu_memory_util, num_layers=self.model.num_layers,
42-
dtype=self.model.dtype, page_size=self.page_size, kv_head_num=self.model.num_kv_heads,
42+
dtype=self.model_loader.dtype, page_size=self.page_size, kv_head_num=self.model.num_kv_heads,
4343
kv_head_dim=self.model.head_dim, vocab_size=self.model_loader.vocab_size)
4444

4545
def encode(self, content, chat: bool = False):

gllm/models/chatglm.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -25,17 +25,17 @@ def __init__(self, layer_id: int, config):
2525
self.scaling = self.head_dim**-0.5
2626

2727
self.rotary_emb = RotaryEmbedding(
28-
self.head_dim, self.head_dim // 2, config.seq_length, getattr(config,'rope_theta',10000), False, config.torch_dtype)
28+
self.head_dim, self.head_dim // 2, config.seq_length, getattr(config,'rope_theta',10000), False)
2929
self.attn = FlashAttention(
3030
layer_id, self.scaling, self.num_heads, self.num_kv_heads, self.head_dim, self.hidden_size)
3131

3232
self.projection_size = config.kv_channels * self.num_heads
3333
self.qkv_hidden_size = self.projection_size + 2 * \
3434
self.head_dim * config.multi_query_group_num
3535
self.query_key_value = nn.Linear(self.hidden_size, self.qkv_hidden_size,
36-
bias=config.add_bias_linear or config.add_qkv_bias, dtype=config.torch_dtype, device='cuda')
36+
bias=config.add_bias_linear or config.add_qkv_bias)
3737
self.dense = nn.Linear(self.projection_size, self.hidden_size,
38-
bias=config.add_bias_linear, dtype=config.torch_dtype, device='cuda')
38+
bias=config.add_bias_linear)
3939

4040
def forward(self, input_data: InputData, hidden_states: torch.Tensor):
4141
qkv = self.query_key_value(hidden_states)
@@ -51,10 +51,10 @@ def __init__(self, config):
5151
super().__init__()
5252
self.add_bias = config.add_bias_linear
5353
self.dense_h_to_4h = nn.Linear(
54-
config.hidden_size, config.ffn_hidden_size*2, bias=self.add_bias, dtype=config.torch_dtype, device='cuda')
54+
config.hidden_size, config.ffn_hidden_size*2, bias=self.add_bias)
5555
self.activation_func = SiluAndMul()
5656
self.dense_4h_to_h = nn.Linear(
57-
config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, dtype=config.torch_dtype, device='cuda')
57+
config.ffn_hidden_size, config.hidden_size, bias=self.add_bias)
5858

5959
def forward(self, hidden_states):
6060
# [s, b, 4hp]
@@ -73,13 +73,13 @@ def __init__(self, layer_id, config):
7373

7474
assert config.rmsnorm
7575
self.input_layernorm = RMSNorm(
76-
config.hidden_size, config.layernorm_epsilon, config.torch_dtype)
76+
config.hidden_size, config.layernorm_epsilon)
7777

7878
self.self_attention = GLMAttention(layer_id, config)
7979
self.hidden_dropout = config.hidden_dropout
8080

8181
self.post_attention_layernorm = RMSNorm(
82-
config.hidden_size, config.layernorm_epsilon, config.torch_dtype)
82+
config.hidden_size, config.layernorm_epsilon)
8383

8484
self.mlp = GLMMLP(config)
8585

@@ -127,7 +127,7 @@ def __init__(self, config):
127127
assert config.rmsnorm
128128
layer_norm_func = RMSNorm
129129
self.final_layernorm = layer_norm_func(
130-
config.hidden_size, config.layernorm_epsilon, config.torch_dtype)
130+
config.hidden_size, config.layernorm_epsilon)
131131

132132
def forward(self, input_data: InputData, hidden_states: torch.Tensor):
133133
for layer in self.layers:
@@ -145,14 +145,14 @@ def __init__(self, config):
145145
super().__init__()
146146

147147
self.embedding = nn.Embedding(
148-
config.padded_vocab_size, config.hidden_size, dtype=config.torch_dtype, device='cuda')
148+
config.padded_vocab_size, config.hidden_size)
149149

150150
self.multi_query_group_num = config.multi_query_group_num
151151
self.kv_channels = config.kv_channels
152152

153153
self.encoder = GLMTransformer(config)
154154
self.output_layer = nn.Linear(
155-
config.hidden_size, config.padded_vocab_size, bias=False, dtype=config.torch_dtype, device='cuda')
155+
config.hidden_size, config.padded_vocab_size, bias=False)
156156

157157
def forward(self, input_data: InputData, hidden_states=None):
158158
if get_pp_rank() == 0:
@@ -169,7 +169,6 @@ def __init__(self, config):
169169

170170
self.config = config
171171
self.max_model_len = config.seq_length
172-
self.dtype = config.torch_dtype
173172
self.num_kv_heads = config.multi_query_group_num
174173
self.head_dim = config.hidden_size // config.num_attention_heads
175174
self.transformer = ChatGLMModel(config)

gllm/models/llama.py

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,9 @@ def __init__(self, config):
2020
self.hidden_size = config.hidden_size
2121
self.intermediate_size = config.intermediate_size
2222
self.gate_up_proj = nn.Linear(
23-
self.hidden_size, self.intermediate_size*2, bias=False, dtype=config.torch_dtype, device='cuda')
23+
self.hidden_size, self.intermediate_size*2, bias=False)
2424
self.down_proj = nn.Linear(
25-
self.intermediate_size, self.hidden_size, bias=False, dtype=config.torch_dtype, device='cuda')
25+
self.intermediate_size, self.hidden_size, bias=False)
2626
self.act_fn = SiluAndMul()
2727

2828
def forward(self, x: torch.Tensor):
@@ -39,9 +39,9 @@ def __init__(self, layer_id: int, config):
3939
self.num_key_value_heads = config.num_key_value_heads
4040

4141
self.qkv_proj = nn.Linear(self.hidden_size, (self.num_heads+self.num_key_value_heads*2)
42-
* self.head_dim, bias=False, dtype=config.torch_dtype, device='cuda')
42+
* self.head_dim, bias=False)
4343
self.o_proj = nn.Linear(self.num_heads*self.head_dim,
44-
self.hidden_size, bias=False, dtype=config.torch_dtype, device='cuda')
44+
self.hidden_size, bias=False)
4545

4646
self.rope_theta = getattr(config,'rope_theta',10000)
4747

@@ -55,18 +55,18 @@ def __init__(self, layer_id: int, config):
5555
"original_max_position_embeddings"]
5656
self.rotary_emb = Llama3RotaryEmbedding(
5757
self.head_dim, self.head_dim, original_max_position,
58-
self.rope_theta, True, config.torch_dtype,
59-
rope_scaling['factor'], low_freq_factor, high_freq_factor, original_max_position)
58+
self.rope_theta, True, rope_scaling['factor'], low_freq_factor,
59+
high_freq_factor, original_max_position)
6060
elif rope_scaling['type'] == 'linear':
6161
self.rotary_emb = LinearScalingRotaryEmbedding(
6262
self.head_dim, self.head_dim, config.max_position_embeddings,
63-
self.rope_theta, True, rope_scaling['factor'], config.torch_dtype)
63+
self.rope_theta, True, rope_scaling['factor'])
6464
else:
6565
assert 0
6666
else:
6767
self.rotary_emb = RotaryEmbedding(
6868
self.head_dim, self.head_dim, config.max_position_embeddings,
69-
self.rope_theta, True, config.torch_dtype)
69+
self.rope_theta, True)
7070

7171
self.scaling = self.head_dim**-0.5
7272

@@ -90,10 +90,10 @@ class LlamaDecoderLayer(nn.Module):
9090
def __init__(self, layer_id: int, config):
9191
super().__init__()
9292
self.input_layernorm = RMSNorm(
93-
config.hidden_size, config.rms_norm_eps, config.torch_dtype)
93+
config.hidden_size, config.rms_norm_eps)
9494
self.self_attn = LlamaAttention(layer_id, config)
9595
self.post_attention_layernorm = RMSNorm(
96-
config.hidden_size, config.rms_norm_eps, config.torch_dtype)
96+
config.hidden_size, config.rms_norm_eps)
9797
self.mlp = LlamaMLP(config)
9898

9999
def forward(self, input_data: InputData, hidden_states: torch.Tensor, residual: Optional[torch.Tensor]):
@@ -128,10 +128,10 @@ def __init__(self, config):
128128
layer_id-self.start_layer, config) for layer_id in range(self.start_layer, self.end_layer)])
129129
if get_pp_rank() == 0:
130130
self.embed_tokens = nn.Embedding(
131-
config.vocab_size, config.hidden_size, dtype=config.torch_dtype, device='cuda')
131+
config.vocab_size, config.hidden_size)
132132
if get_pp_rank() == get_pp_size() - 1:
133133
self.norm = RMSNorm(
134-
config.hidden_size, config.rms_norm_eps, config.torch_dtype)
134+
config.hidden_size, config.rms_norm_eps)
135135

136136
def forward(self, input_data: InputData, hidden_states=None, residual=None):
137137
if get_pp_rank() == 0:
@@ -151,7 +151,6 @@ class LlamaForCausalLM(nn.Module):
151151
def __init__(self, config):
152152
super().__init__()
153153
self.max_model_len = config.max_position_embeddings
154-
self.dtype = config.torch_dtype
155154
self.num_kv_heads = config.num_key_value_heads
156155
self.head_dim = config.hidden_size // config.num_attention_heads
157156
self.model = LlamaModel(config)
@@ -160,7 +159,7 @@ def __init__(self, config):
160159
self.ret_residual = True
161160
if get_pp_rank() == get_pp_size() - 1:
162161
self.lm_head = nn.Linear(
163-
config.hidden_size, config.vocab_size, bias=False, dtype=config.torch_dtype, device='cuda')
162+
config.hidden_size, config.vocab_size, bias=False)
164163
self.sampler = Sampler()
165164

166165
def forward(self, input_data: InputData, hidden_states=None, residual=None):

gllm/models/qwen2.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ def __init__(self, layer_id: int, config):
4848
self.hidden_size, (self.num_heads+self.num_kv_heads*2)*self.head_dim, bias=True)
4949
self.o_proj = nn.Linear(self.num_heads*self.head_dim, self.hidden_size, bias=False)
5050
self.rotary_emb = RotaryEmbedding(
51-
self.head_dim, self.head_dim, self.max_position_embeddings, self.rope_theta, True, config.torch_dtype)
51+
self.head_dim, self.head_dim, self.max_position_embeddings, self.rope_theta, True)
5252
self.attn = FlashAttention(
5353
layer_id, self.scaling, self.num_heads, self.num_kv_heads, self.head_dim, self.hidden_size)
5454

@@ -67,9 +67,9 @@ def __init__(self, layer_id: int, config, attention_type=Qwen2Attention, mlp_typ
6767
self.self_attn = attention_type(layer_id, config)
6868
self.mlp = mlp_type(config)
6969
self.input_layernorm = RMSNorm(
70-
config.hidden_size, config.rms_norm_eps, config.torch_dtype)
70+
config.hidden_size, config.rms_norm_eps)
7171
self.post_attention_layernorm = RMSNorm(
72-
config.hidden_size, config.rms_norm_eps, config.torch_dtype)
72+
config.hidden_size, config.rms_norm_eps)
7373

7474
def forward(self, input_data: InputData, hidden_states: torch.Tensor, residual: Optional[torch.Tensor]):
7575
if residual is None:
@@ -102,7 +102,7 @@ def __init__(self, config, decoder_layer_type=Qwen2DecoderLayer):
102102
])
103103
if get_pp_rank() == get_pp_size() - 1:
104104
self.norm = RMSNorm(
105-
config.hidden_size, config.rms_norm_eps, config.torch_dtype)
105+
config.hidden_size, config.rms_norm_eps)
106106

107107
def forward(self, input_data: InputData, hidden_states=None, residual=None):
108108
if get_pp_rank() == 0:

gllm/models/qwen3.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,11 @@ def __init__(self, layer_id, config):
2929
self.hidden_size, (self.num_heads+self.num_kv_heads*2)*self.head_dim, bias=self.qkv_bias)
3030
self.o_proj = nn.Linear(self.num_heads*self.head_dim, self.hidden_size, bias=False)
3131
self.rotary_emb = RotaryEmbedding(
32-
self.head_dim, self.head_dim, config.max_position_embeddings, self.rope_theta, True, config.torch_dtype)
32+
self.head_dim, self.head_dim, config.max_position_embeddings, self.rope_theta, True)
3333
self.attn = FlashAttention(
3434
layer_id, self.scaling, self.num_heads, self.num_kv_heads, self.head_dim, self.hidden_size)
35-
self.q_norm = RMSNorm(self.head_dim, config.rms_norm_eps, config.torch_dtype)
36-
self.k_norm = RMSNorm(self.head_dim, config.rms_norm_eps, config.torch_dtype)
35+
self.q_norm = RMSNorm(self.head_dim, config.rms_norm_eps)
36+
self.k_norm = RMSNorm(self.head_dim, config.rms_norm_eps)
3737

3838
def forward(self, input_data: InputData, hidden_states: torch.Tensor):
3939
qkv = self.qkv_proj(hidden_states)

0 commit comments

Comments
 (0)