@@ -20,9 +20,9 @@ def __init__(self, config):
2020 self .hidden_size = config .hidden_size
2121 self .intermediate_size = config .intermediate_size
2222 self .gate_up_proj = nn .Linear (
23- self .hidden_size , self .intermediate_size * 2 , bias = False , dtype = config . torch_dtype , device = 'cuda' )
23+ self .hidden_size , self .intermediate_size * 2 , bias = False )
2424 self .down_proj = nn .Linear (
25- self .intermediate_size , self .hidden_size , bias = False , dtype = config . torch_dtype , device = 'cuda' )
25+ self .intermediate_size , self .hidden_size , bias = False )
2626 self .act_fn = SiluAndMul ()
2727
2828 def forward (self , x : torch .Tensor ):
@@ -39,9 +39,9 @@ def __init__(self, layer_id: int, config):
3939 self .num_key_value_heads = config .num_key_value_heads
4040
4141 self .qkv_proj = nn .Linear (self .hidden_size , (self .num_heads + self .num_key_value_heads * 2 )
42- * self .head_dim , bias = False , dtype = config . torch_dtype , device = 'cuda' )
42+ * self .head_dim , bias = False )
4343 self .o_proj = nn .Linear (self .num_heads * self .head_dim ,
44- self .hidden_size , bias = False , dtype = config . torch_dtype , device = 'cuda' )
44+ self .hidden_size , bias = False )
4545
4646 self .rope_theta = getattr (config ,'rope_theta' ,10000 )
4747
@@ -55,18 +55,18 @@ def __init__(self, layer_id: int, config):
5555 "original_max_position_embeddings" ]
5656 self .rotary_emb = Llama3RotaryEmbedding (
5757 self .head_dim , self .head_dim , original_max_position ,
58- self .rope_theta , True , config . torch_dtype ,
59- rope_scaling [ 'factor' ], low_freq_factor , high_freq_factor , original_max_position )
58+ self .rope_theta , True , rope_scaling [ 'factor' ], low_freq_factor ,
59+ high_freq_factor , original_max_position )
6060 elif rope_scaling ['type' ] == 'linear' :
6161 self .rotary_emb = LinearScalingRotaryEmbedding (
6262 self .head_dim , self .head_dim , config .max_position_embeddings ,
63- self .rope_theta , True , rope_scaling ['factor' ], config . torch_dtype )
63+ self .rope_theta , True , rope_scaling ['factor' ])
6464 else :
6565 assert 0
6666 else :
6767 self .rotary_emb = RotaryEmbedding (
6868 self .head_dim , self .head_dim , config .max_position_embeddings ,
69- self .rope_theta , True , config . torch_dtype )
69+ self .rope_theta , True )
7070
7171 self .scaling = self .head_dim ** - 0.5
7272
@@ -90,10 +90,10 @@ class LlamaDecoderLayer(nn.Module):
9090 def __init__ (self , layer_id : int , config ):
9191 super ().__init__ ()
9292 self .input_layernorm = RMSNorm (
93- config .hidden_size , config .rms_norm_eps , config . torch_dtype )
93+ config .hidden_size , config .rms_norm_eps )
9494 self .self_attn = LlamaAttention (layer_id , config )
9595 self .post_attention_layernorm = RMSNorm (
96- config .hidden_size , config .rms_norm_eps , config . torch_dtype )
96+ config .hidden_size , config .rms_norm_eps )
9797 self .mlp = LlamaMLP (config )
9898
9999 def forward (self , input_data : InputData , hidden_states : torch .Tensor , residual : Optional [torch .Tensor ]):
@@ -128,10 +128,10 @@ def __init__(self, config):
128128 layer_id - self .start_layer , config ) for layer_id in range (self .start_layer , self .end_layer )])
129129 if get_pp_rank () == 0 :
130130 self .embed_tokens = nn .Embedding (
131- config .vocab_size , config .hidden_size , dtype = config . torch_dtype , device = 'cuda' )
131+ config .vocab_size , config .hidden_size )
132132 if get_pp_rank () == get_pp_size () - 1 :
133133 self .norm = RMSNorm (
134- config .hidden_size , config .rms_norm_eps , config . torch_dtype )
134+ config .hidden_size , config .rms_norm_eps )
135135
136136 def forward (self , input_data : InputData , hidden_states = None , residual = None ):
137137 if get_pp_rank () == 0 :
@@ -151,7 +151,6 @@ class LlamaForCausalLM(nn.Module):
151151 def __init__ (self , config ):
152152 super ().__init__ ()
153153 self .max_model_len = config .max_position_embeddings
154- self .dtype = config .torch_dtype
155154 self .num_kv_heads = config .num_key_value_heads
156155 self .head_dim = config .hidden_size // config .num_attention_heads
157156 self .model = LlamaModel (config )
@@ -160,7 +159,7 @@ def __init__(self, config):
160159 self .ret_residual = True
161160 if get_pp_rank () == get_pp_size () - 1 :
162161 self .lm_head = nn .Linear (
163- config .hidden_size , config .vocab_size , bias = False , dtype = config . torch_dtype , device = 'cuda' )
162+ config .hidden_size , config .vocab_size , bias = False )
164163 self .sampler = Sampler ()
165164
166165 def forward (self , input_data : InputData , hidden_states = None , residual = None ):
0 commit comments