@@ -65,9 +65,7 @@ def __init__(
     self.quant = quant
     self.rngs = rngs

-    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(
-        self.config, self.model_mode
-    )
+    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
     self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)

     self.pre_self_attention_layer_norm = RMSNorm(
@@ -119,9 +117,7 @@ def __init__(
         rngs=rngs,
     )

-    self.dropout = Dropout(
-        rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs
-    )
+    self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

   def __call__(
       self,
@@ -162,9 +158,7 @@ def with_logical_constraint(self, x):
     return nn.with_logical_constraint(x, self.logical_axis_names)

   def dropout_op(self, x, deterministic):
-    return self.with_logical_constraint(
-        self.dropout(x, deterministic=deterministic)
-    )
+    return self.with_logical_constraint(self.dropout(x, deterministic=deterministic))

   def pre_attention_norm_op(self, x):
     return self.with_logical_constraint(self.pre_self_attention_layer_norm(x))
@@ -311,9 +305,7 @@ def __init__(
     self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
         config=self.config,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(
-            1.0, "fan_in", "truncated_normal"
-        ),
+        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,