@@ -97,11 +97,23 @@ def __call__(
     )

     if self.model_mode == MODEL_MODE_PREFILL:
-      logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
+      logical_axis_names = (
+          "activation_batch",
+          "prefill_activation_length",
+          "activation_embed",
+      )
     elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN:
-      logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed")
+      logical_axis_names = (
+          "activation_batch_no_exp",
+          "activation_length",
+          "activation_embed",
+      )
     else:
-      logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed")
+      logical_axis_names = (
+          "activation_batch",
+          "activation_length_no_exp",
+          "activation_embed",
+      )

     if model_mode == MODEL_MODE_PREFILL:
       inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
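Note on the hunk above: the three branches only choose logical axis names; the physical sharding is applied by `_maybe_shard_with_logical`, whose body is not part of this diff. A minimal sketch of what such a helper typically reduces to, assuming Flax's logical partitioning API (the helper body here is an assumption, not the actual MaxText implementation):

```python
import flax.linen as nn

def _maybe_shard_with_logical(x, logical_axis_names):
  # Resolve logical names ("activation_batch", ...) to physical mesh axes
  # through the active nn.partitioning.axis_rules context, then constrain
  # the intermediate's sharding. Typically a no-op when no rules are in scope.
  return nn.with_logical_constraint(x, logical_axis_names)
```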
@@ -235,7 +247,11 @@ def __call__(
   ) -> jnp.ndarray:
     for lyr in range(self.num_decoder_layers):
       inputs = self.decoder_layer(
-          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
+          config=self.config,
+          mesh=self.mesh,
+          name=f"layers_{lyr}",
+          quant=self.quant,
+          model_mode=model_mode,
       )(
           inputs,
           decoder_segment_ids,
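For readers new to Linen: constructing `self.decoder_layer(...)` inside `__call__` both creates a named submodule and immediately applies it, so each loop iteration gets its own `layers_{lyr}` parameter subtree. A toy, self-contained illustration of the pattern (the `TinyDecoder` module is made up for this note):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyDecoder(nn.Module):
  num_decoder_layers: int

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    for lyr in range(self.num_decoder_layers):
      # Each iteration creates a uniquely named submodule, so parameters
      # land under params["layers_0"], params["layers_1"], ...
      x = nn.Dense(features=x.shape[-1], name=f"layers_{lyr}")(x)
    return x

params = TinyDecoder(num_decoder_layers=2).init(
    jax.random.PRNGKey(0), jnp.ones((1, 8))
)
```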
@@ -269,7 +285,10 @@ def setup(self):
     pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
     remat_policy = self.get_remat_policy()
     self.pipeline_module = pipeline.Pipeline(
-        config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
+        config=self.config,
+        mesh=self.mesh,
+        layers=pipeline_stage_module,
+        remat_policy=remat_policy,
     )

   def minimal_policy(self, with_context=False):
@@ -339,7 +358,11 @@ def get_remat_policy(self):
     elif cfg.remat_policy == "qkv_proj_offloaded":
       policy = jax.checkpoint_policies.save_and_offload_only_these_names(
           names_which_can_be_saved=[],
-          names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"],
+          names_which_can_be_offloaded=[
+              "query_proj",
+              "value_proj",
+              "key_proj",
+          ],
           offload_src="device",
           offload_dst="pinned_host",
       )
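This policy only takes effect where matching `checkpoint_name` tags appear in the forward pass. A self-contained sketch of the mechanism under a rematerialized function (the toy `attention_block` is illustrative, not MaxText's attention):

```python
import functools
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"],
    offload_src="device",
    offload_dst="pinned_host",
)

@functools.partial(jax.checkpoint, policy=policy)
def attention_block(x, wq, wk, wv):
  # Tagged intermediates are moved to pinned host memory during the
  # forward pass and fetched back for the backward pass, instead of
  # being recomputed or kept in device HBM.
  q = checkpoint_name(x @ wq, "query_proj")
  k = checkpoint_name(x @ wk, "key_proj")
  v = checkpoint_name(x @ wv, "value_proj")
  return jnp.sum((q @ k.T) @ v)
```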
@@ -395,7 +418,10 @@ def get_decoder_layers(self):
         return [mixtral.MixtralDecoderLayerToLinen]
       case DecoderBlockType.DEEPSEEK:
         if self.config.use_batch_split_schedule:
-          return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
+          return [
+              deepseek_batchsplit.DeepSeekDenseLayer,
+              deepseek_batchsplit.DeepSeekMoELayer,
+          ]
         else:
           return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
       case DecoderBlockType.GEMMA:
@@ -447,7 +473,10 @@ def map_fn(path, value):
           block_layer,
           prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config),
           policy=policy,
-          static_argnums=(4, 5),  # Deterministic and model mode are static arguments.
+          static_argnums=(
+              4,
+              5,
+          ),  # Deterministic and model mode are static arguments.
       )
       RemattedBlockLayers.append(layer)
     return RemattedBlockLayers
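The `static_argnums` indices count `self` as argument 0, so 4 and 5 line up with the `deterministic` and `model_mode` parameters of the layer's `__call__`, as the inline comment says; both must be hashable, and new values retrigger tracing. A reduced sketch with a stand-in layer:

```python
import flax.linen as nn
import jax

class ToyLayer(nn.Module):  # stand-in for a decoder-layer module
  @nn.compact
  def __call__(self, inputs, segment_ids, positions, deterministic, model_mode):
    return nn.Dense(inputs.shape[-1])(inputs)

# Arguments 4 and 5 (deterministic, model_mode) stay plain Python values
# rather than traced arrays, matching the static_argnums above.
RemattedToyLayer = nn.remat(
    ToyLayer,
    prevent_cse=True,
    policy=jax.checkpoint_policies.nothing_saveable,
    static_argnums=(4, 5),
)
```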
@@ -473,11 +502,25 @@ def get_norm_layer(self, num_features: int):
     ):
       return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
     elif self.config.decoder_block == DecoderBlockType.GPT3:
-      return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True)
+      return functools.partial(
+          gpt3.gpt3_layer_norm,
+          num_features=num_features,
+          reductions_in_fp32=False,
+          use_bias=True,
+      )
     else:
       raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

-  def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs):
+  def scan_decoder_layers(
+      self,
+      cfg,
+      decoder_layer,
+      length,
+      metadata_axis_name,
+      mesh,
+      in_axes_tuple,
+      **kwargs,
+  ):
     """scan decoder layers, calls `flax.linen.transforms.scan`"""
     initializing = self.is_mutable_collection("params")
     params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis)
@@ -500,7 +543,11 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
         metadata_params={nn.PARTITION_NAME: metadata_axis_name},
     )
     return scan_fn(
-        config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs  # pytype: disable=wrong-keyword-args
+        config=cfg,
+        mesh=mesh,
+        name=metadata_axis_name,
+        quant=self.quant,
+        **kwargs,  # pytype: disable=wrong-keyword-args
     )

   def get_pipeline_stage_module(self, decoder_blocks):
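For context, the `scan_fn` returned here comes from `flax.linen.transforms.scan`, which stacks one layer's parameters along the scan axis and threads the activations through `length` applications; `metadata_params={nn.PARTITION_NAME: metadata_axis_name}` names that stacked axis for partitioning. A reduced, self-contained sketch (toy block; the real call also wires `in_axes`, cache/intermediates axes, and RNG policies):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class ToyBlock(nn.Module):
  @nn.compact
  def __call__(self, carry, _):
    # carry is the running activation; the second output is per-step ys.
    return nn.Dense(carry.shape[-1])(carry), None

ScannedBlocks = nn.scan(
    ToyBlock,
    variable_axes={"params": 0},   # stack each layer's params on axis 0
    split_rngs={"params": True},   # distinct init RNG per layer
    length=4,
    metadata_params={nn.PARTITION_NAME: "layers"},
)

y, _ = ScannedBlocks(name="layers").init_with_output(
    jax.random.PRNGKey(0), jnp.ones((1, 8)), None
)
```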
@@ -558,7 +605,13 @@ def _apply_embedding(

     # Merge the image embeddings with the text embeddings for multimodal models
     if image_embeddings is not None and cfg.use_multimodal:
-      if cfg.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e"]:
+      if cfg.model_name in [
+          "gemma3-4b",
+          "gemma3-12b",
+          "gemma3-27b",
+          "llama4-17b-16e",
+          "llama4-17b-128e",
+      ]:
         y = multimodal_utils.merge_mm_embeddings(
             text_embeddings=y,
             vision_embeddings=image_embeddings,
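`multimodal_utils.merge_mm_embeddings` is MaxText-internal and its body is outside this diff; the general technique it names, judging only from the argument names (so treat this as an assumption), is a masked scatter of vision embeddings into the image-placeholder positions of the text sequence:

```python
import jax.numpy as jnp

def merge_embeddings_sketch(text_embeddings, vision_embeddings, mask):
  # text_embeddings:   [batch, seq_len, embed]
  # vision_embeddings: [batch, seq_len, embed], pre-aligned to the sequence
  # mask:              [batch, seq_len] bool, True at image placeholders
  return jnp.where(mask[..., None], vision_embeddings, text_embeddings)
```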
@@ -751,7 +804,10 @@ def __call__(
     remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
     if remaining_layers > 0:
       logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
-      with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
+      with (
+          self.mesh,
+          nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp),
+      ):
         y, _ = self.scan_decoder_layers(
             cfg,
             RemattedBlockLayers[0],
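The parenthesized multi-item `with` requires Python 3.10+ syntax; semantically it just stacks entering the mesh with a temporary remap of logical axis rules so that, for the trailing non-pipelined layers, pipeline-parallel activations shard as data-parallel. A small sketch of how `axis_rules` scoping behaves (the rule tuples are made up):

```python
import flax.linen as nn

rules = (("activation_batch", "data"), ("activation_embed", None))

with nn.partitioning.axis_rules(rules):
  # Inside the block, logical names resolve through `rules`:
  spec = nn.partitioning.logical_to_mesh_axes(
      ("activation_batch", "activation_embed")
  )
  # spec == PartitionSpec("data", None)
```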
@@ -838,7 +894,11 @@ def __call__(
     for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
       for index in range(num_layers):
         y = layer(
-            config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
+            config=cfg,
+            mesh=mesh,
+            name=f"{layer_prefix}_{index}",
+            quant=self.quant,
+            model_mode=self.model_mode,
         )(
             y,
             decoder_segment_ids,
@@ -868,7 +928,12 @@ def __call__(
       if cfg.decoder_block == DecoderBlockType.GPT_OSS:
         layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
       layer = RemattedBlockLayer(
-          config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
+          config=cfg,
+          mesh=mesh,
+          name=f"layers_{lyr}",
+          quant=self.quant,
+          model_mode=self.model_mode,
+          **layer_kwargs,
       )
       y = layer(
           y,
@@ -952,7 +1017,12 @@ def _apply_gemma3_scanned_blocks(
     rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
     # pytype: disable=wrong-keyword-args
     layer = RemattedGemma3Block(
-        config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
+        config=cfg,
+        mesh=mesh,
+        quant=self.quant,
+        model_mode=self.model_mode,
+        name="layers_remainder",
+        **rem_layer_kwargs,
    )
     y, _ = layer(
         y,