@@ -100,6 +100,10 @@ def norm_type(self) -> NormTypeEnum:
     def positional_embedding_type(self) -> PositionalEmbeddingType:
         return PositionalEmbeddingType.rotate_half
 
+    @property
+    def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
+        return RotateHalfConfig(theta_base=self._config.rotary_emb_base)
+
     def make_norm_layer(self) -> None:
         """
         Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
@@ -119,27 +123,6 @@ def make_norm_layer(self) -> None:
 
         self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)
 
-    def make_attn_layer(self) -> None:
-        """
-        Builds the attention layer for the model. This sets the `self.attn` attribute.
-        """
-        softmax_scale = 1.0 / (self.head_size**0.5)
-
-        rotary_config = RotateHalfConfig(theta_base=self._config.rotary_emb_base)
-
-        attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
-                                            n_heads_q=self.n_heads_q_local,
-                                            n_heads_kv=self.n_heads_kv_local,
-                                            head_size=self.head_size,
-                                            max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
-                                            scale_factor=softmax_scale,
-                                            input_dtype=self.activation_dtype,
-                                            output_dtype=self.activation_dtype,
-                                            positional_embedding_type=self.positional_embedding_type,
-                                            positional_embedding_config=rotary_config)
-
-        self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)
-
     """
     Forward implementations
     """
@@ -210,8 +193,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
         Performs unembedding of the hidden states to logits. This will only sample the final
         token of each sequence.
         """
-        logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info,
-                              self._non_transformer.final_norm)
+        logits = self.unembed(hidden_states,
+                              self._non_transformer.word_unembed,
+                              ragged_batch_info,
+                              gamma=self._non_transformer.final_norm)
 
         if self.tp_size > 1:
             comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
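
For context, the removed make_attn_layer above is presumably absorbed into the shared transformer base class, which can now build the attention module generically from the positional_embedding_type and positional_embedding_config properties. Below is a minimal sketch of what that generic construction could look like, reusing the DSSelfAttentionConfig fields from the removed code; the base-class placement and the absolute import paths are assumptions and are not shown in this diff.

# Illustrative sketch only -- not part of this diff.
# Assumes this method lives on the shared model base class and that the two
# positional-embedding properties are the only model-specific inputs it needs.
from deepspeed.inference.v2.modules import heuristics  # assumed import path
from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig  # assumed import path


def make_attn_layer(self) -> None:
    """
    Builds the attention layer for the model. This sets the `self.attn` attribute.
    """
    softmax_scale = 1.0 / (self.head_size**0.5)

    attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
                                        n_heads_q=self.n_heads_q_local,
                                        n_heads_kv=self.n_heads_kv_local,
                                        head_size=self.head_size,
                                        max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
                                        scale_factor=softmax_scale,
                                        input_dtype=self.activation_dtype,
                                        output_dtype=self.activation_dtype,
                                        positional_embedding_type=self.positional_embedding_type,
                                        # The rotary parameters now come from the new property
                                        # instead of a RotateHalfConfig built inline here.
                                        positional_embedding_config=self.positional_embedding_config)

    self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)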