 from MaxText import max_logging
 from MaxText.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
 from MaxText.layers import initializers, nnx_wrappers
-from MaxText.layers.linears import mlp_block
+from MaxText.layers.linears import DenseGeneral, MlpBlock
 from MaxText.layers import models
 from MaxText.layers import quantizations
-from MaxText.layers.attentions import KVQuant, attention_op_as_linen, dense_general
+from MaxText.layers.attentions import AttentionOp, KVQuant
 from MaxText.layers.initializers import Initializer, NdInitializer, nd_dense_init
 from MaxText.layers.quantizations import AqtQuantization as Quant
 
@@ -166,7 +166,7 @@ class Gpt3MultiHeadAttention(nnx.Module):
     head_dim: dimension of each head.
     max_target_length: maximum length of output
     max_prefill_predict_length: size of the maximum prefill
-    self. mesh: device self.mesh
+    mesh: device mesh
     dtype: the dtype of the computation.
     dropout_rate: dropout rate
     kernel_init: initializer for the kernel of the Dense layers.
@@ -180,30 +180,32 @@ class Gpt3MultiHeadAttention(nnx.Module):
   """
 
   def __init__(
-      self,
-      config: Config,
-      num_heads: int,
-      head_dim: int,
-      max_target_length: int,
-      max_prefill_predict_length: int,
-      mesh: Mesh,
-      attention_kernel: str,
-      dtype: DType = jnp.float32,
-      weight_dtype: DType = jnp.float32,
-      dropout_rate: float = 0.0,
-      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
-      float32_qk_product: bool = False,  # computes logits in float32 for stability.
-      float32_logits: bool = True,  # cast logits in float32 for stability.
-      fused_qkv: bool = True,
-      quant: Optional[Quant] = None,
-      kv_quant: Optional[KVQuant] = None,
-      use_bias: bool = True,
-      input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
-      query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
-      **kwargs: Any
+      self,
+      config: Config,
+      num_heads: int,
+      feature_dim: Array,
+      head_dim: int,
+      max_target_length: int,
+      max_prefill_predict_length: int,
+      mesh: Mesh,
+      attention_kernel: str,
+      dtype: DType = jnp.float32,
+      weight_dtype: DType = jnp.float32,
+      dropout_rate: float = 0.0,
+      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
+      float32_qk_product: bool = False,  # computes logits in float32 for stability.
+      float32_logits: bool = True,  # cast logits in float32 for stability.
+      fused_qkv: bool = True,
+      quant: Optional[Quant] = None,
+      kv_quant: Optional[KVQuant] = None,
+      use_bias: bool = True,
+      input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+      query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      rngs: Optional[nnx.Rngs] = None,
+      **kwargs: Any,
   ):
     self.config = config
     self.num_heads = num_heads
@@ -227,59 +229,48 @@ def __init__(
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
-
-  def qkv_projection(self, inputs: Array, proj_name: str):
-    """Fused QKV projection"""
-
-    qkv_proj = dense_general(
-        inputs_shape=inputs.shape,
+    self.rngs = rngs if rngs is not None else kwargs.get("rngs", nnx.Rngs(0))
+    print(f'feature_dim: {feature_dim}')
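+    # Build the fused-QKV, per-Q/K/V and output projection layers once at construction
+    # time; their DenseGeneral kernels become part of this module's NNX state.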
+    self.qkv_projection_layer = self.create_projection_layer(feature_dim, ("embed", "qkv", "heads", "kv"))
+    self.q_projection_layer = self.create_projection_layer(feature_dim, ("embed", "heads", "kv"))
+    self.k_projection_layer = self.create_projection_layer(feature_dim, ("embed", "heads", "kv"))
+    self.v_projection_layer = self.create_projection_layer(feature_dim, ("embed", "heads", "kv"))
+    self.out_projection_layer = self.create_projection_layer(feature_dim, ("heads", "kv", "embed"))
+
+  def create_projection_layer(self, input_shape: Array, kernel_axes: str):
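+    """Builds a DenseGeneral projection with the given kernel axis names.
+
+    Note: every projection built here currently shares the fused
+    (3, num_heads, head_dim) output shape and contracts over the last input axis.
+    """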
+    return DenseGeneral(
+        in_features_shape=input_shape,
         out_features_shape=(3, self.num_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
-        kernel_axes=("embed", "qkv", "heads", "kv"),
+        kernel_axes=kernel_axes,
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
-        name=proj_name,
         quant=self.quant,
         use_bias=self.use_bias,
         matmul_precision=self.config.matmul_precision,
-    )(inputs)
+        rngs=self.rngs,
+    )
+
+  def qkv_projection(self, projection_layer: Any, inputs: Array, proj_name: str):
+    """Fused QKV projection"""
+    print(f'qkv_projection in_features_shape: {inputs.shape}')
+
+    qkv_proj = projection_layer(inputs)
+
     qkv_proj = checkpoint_name(qkv_proj, "qkv_proj")
     query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...]
     return query, key, value
 
-  def projection(self, inputs: Array, proj_name: str) -> Array:
+  def projection(self, projection_layer: Any, inputs: Array, proj_name: str) -> Array:
     """individual projection for one of q, k and v."""
-    proj = dense_general(
-        inputs_shape=inputs.shape,
-        out_features_shape=(self.num_heads, self.head_dim),
-        axis=-1,
-        kernel_init=self.kernel_init,
-        kernel_axes=("embed", "heads", "kv"),
-        dtype=self.dtype,
-        weight_dtype=self.weight_dtype,
-        name=proj_name,
-        quant=self.quant,
-        use_bias=self.use_bias,
-        matmul_precision=self.config.matmul_precision,
-    )(inputs)
+    proj = projection_layer(inputs)
     return proj
 
-  def out_projection(self, output_dim: int, out: Array) -> Array:
+  def out_projection(self, projection_layer: Any, output_dim: int, out: Array) -> Array:
     """output projection"""
-    out_proj = dense_general(
-        inputs_shape=out.shape,
-        out_features_shape=output_dim,
-        axis=(-2, -1),
-        kernel_init=self.kernel_init,
-        kernel_axes=("heads", "kv", "embed"),
-        dtype=self.dtype,
-        weight_dtype=self.weight_dtype,
-        name="out",
-        quant=self.quant,
-        use_bias=self.use_bias,
-        matmul_precision=self.config.matmul_precision,
-    )(out)
+
+    out_proj = projection_layer(out)
     return out_proj
 
   def __call__(
@@ -292,11 +283,12 @@ def __call__(
   ):
     inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
     if self.fused_qkv:
-      query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
+      print(f'inputs_q size: {inputs_q.shape}')
+      query, key, value = self.qkv_projection(self.qkv_projection_layer, inputs_q, proj_name="qkv_proj")
     else:
-      query = self.projection(inputs_q, proj_name="query")
-      key = self.projection(inputs_q, proj_name="key")
-      value = self.projection(inputs_q, proj_name="value")
+      query = self.projection(self.q_projection_layer, inputs_q, proj_name="query")
+      key = self.projection(self.k_projection_layer, inputs_q, proj_name="key")
+      value = self.projection(self.v_projection_layer, inputs_q, proj_name="value")
 
     depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
     query /= depth_scaling
@@ -309,7 +301,7 @@ def __call__(
     value = nn.with_logical_constraint(value, self.value_axis_names)
     value = checkpoint_name(value, "value_proj")
 
-    attention_op = attention_op_as_linen(
+    attention_op = AttentionOp(
         config=self.config,
         mesh=self.mesh,
         attention_kernel=self.attention_kernel,
@@ -328,7 +320,7 @@ def __call__(
     out = nn.with_logical_constraint(out, self.out_axis_names)
 
     # apply output projection, output dim is set to the input dim.
-    out = self.out_projection(inputs_q.shape[-1], out)
+    out = self.out_projection(self.out_projection_layer, inputs_q.shape[-1], out)
     out = checkpoint_name(out, "out_proj")
     return out
 
@@ -340,16 +332,14 @@ def __call__(
 
 class Gpt3DecoderLayer(nnx.Module):
   """Transformer decoder layer that attends to the encoder."""
+
   def __init__(
-      self,
-      config: models.Config,
-      mesh: Mesh,
-      quant: Optional[Quant] = None,
-      **kwargs: Any
+      self, config: models.Config, mesh: Mesh, quant: Optional[Quant] = None, rngs: Optional[nnx.Rngs] = None, **kwargs: Any
   ):
     self.config = config
     self.mesh = mesh
     self.quant = quant
+    self.rngs = rngs if rngs is not None else kwargs.get("rngs", nnx.Rngs(0))
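+    # Fall back to a default nnx.Rngs(0) stream when no RNGs are supplied.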
 
   def __call__(
       self,
@@ -364,14 +354,14 @@ def __call__(
   ):
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-    lnx = gpt3_layer_norm(
+    lnx = Gpt3LayerNorm(
         num_features=inputs.shape[-1],
         dtype=self.config.dtype,
-        name="pre_self_attention_norm",
         kernel_axes=("norm",),
         epsilon=self.config.normalization_layer_epsilon,
         reductions_in_fp32=False,
         use_bias=True,
+        rngs=self.rngs,
     )(inputs)
 
     lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
@@ -380,10 +370,15 @@ def __call__(
     assert (
         self.config.num_query_heads == self.config.num_kv_heads
     ), f"{self.config.num_query_heads=} should be the same as {self.config.num_kv_heads=} in gpt3"
+
+    # Todo: make the axis names parameters
+    lnx_sharding = nn.with_logical_constraint(lnx, (BATCH, LENGTH, EMBED))
+
     attention_layer = Gpt3MultiHeadAttention(
         config=self.config,
         num_heads=self.config.num_query_heads,
         dtype=self.config.dtype,
+        feature_dim=lnx_sharding.shape,
         weight_dtype=self.config.weight_dtype,
         head_dim=self.config.head_dim,
         max_target_length=self.config.max_target_length,
@@ -410,18 +405,18 @@ def __call__(
     attention_lnx += inputs
 
     # MLP block.
-    mlp_lnx = mlp_block(
+    mlp_lnx = MlpBlock(
         in_features=attention_lnx.shape[-1],
         intermediate_dim=self.config.mlp_dim,
         activations=self.config.mlp_activations,
         intermediate_dropout_rate=self.config.dropout_rate,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
-        name="mlp",
         use_bias=True,
         use_pre_norm=True,
         config=self.config,
         quant=self.quant,
+        rngs=self.rngs,
     )(attention_lnx, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
 
@@ -448,9 +443,10 @@ def __call__(
     else:
       return layer_output
 
+
 def gpt3_decoder_layer_class() -> nn.Module:
   """Creates a Gemma3DecoderLayer Linen module."""
   return nnx_wrappers.to_linen_class(
       Gpt3DecoderLayer,
       metadata_fn=initializers.variable_to_logically_partitioned,
-  )
+  )
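
Note on the projection factory introduced above: every layer built through create_projection_layer currently shares the fused (3, num_heads, head_dim) output shape and axis=-1, including the per-Q/K/V and output projections. Below is a minimal sketch of a per-projection parameterization, assuming DenseGeneral accepts exactly the keyword arguments this diff already passes; the extra parameters and the example shapes are illustrative and not part of this commit:

  def create_projection_layer(self, input_shape, out_features_shape, axis, kernel_axes):
    # Same DenseGeneral construction as in the diff, but the output feature shape and
    # contraction axis are supplied per projection rather than fixed to the fused QKV shape.
    return DenseGeneral(
        in_features_shape=input_shape,
        out_features_shape=out_features_shape,
        axis=axis,
        kernel_init=self.kernel_init,
        kernel_axes=kernel_axes,
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        quant=self.quant,
        use_bias=self.use_bias,
        matmul_precision=self.config.matmul_precision,
        rngs=self.rngs,
    )

  # Example construction inside __init__, mirroring the shapes of the replaced dense_general calls:
  # self.qkv_projection_layer = self.create_projection_layer(
  #     feature_dim, (3, self.num_heads, self.head_dim), -1, ("embed", "qkv", "heads", "kv"))
  # self.q_projection_layer = self.create_projection_layer(
  #     feature_dim, (self.num_heads, self.head_dim), -1, ("embed", "heads", "kv"))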