Skip to content

Commit 458a0e8

Browse files
committed
fix: run lint
1 parent 7a68e07 commit 458a0e8

File tree

2 files changed

+6
-18
lines changed

2 files changed

+6
-18
lines changed

src/MaxText/layers/deepseek.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ def __init__(
6565
self.quant = quant
6666
self.rngs = rngs
6767

68-
batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(
69-
self.config, self.model_mode
70-
)
68+
batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
7169
self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)
7270

7371
self.pre_self_attention_layer_norm = RMSNorm(
@@ -119,9 +117,7 @@ def __init__(
119117
rngs=rngs,
120118
)
121119

122-
self.dropout = Dropout(
123-
rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs
124-
)
120+
self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)
125121

126122
def __call__(
127123
self,
@@ -162,9 +158,7 @@ def with_logical_constraint(self, x):
162158
return nn.with_logical_constraint(x, self.logical_axis_names)
163159

164160
def dropout_op(self, x, deterministic):
165-
return self.with_logical_constraint(
166-
self.dropout(x, deterministic=deterministic)
167-
)
161+
return self.with_logical_constraint(self.dropout(x, deterministic=deterministic))
168162

169163
def pre_attention_norm_op(self, x):
170164
return self.with_logical_constraint(self.pre_self_attention_layer_norm(x))
@@ -311,9 +305,7 @@ def __init__(
311305
self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
312306
config=self.config,
313307
mesh=mesh,
314-
kernel_init=initializers.nd_dense_init(
315-
1.0, "fan_in", "truncated_normal"
316-
),
308+
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
317309
kernel_axes=("embed", None),
318310
dtype=self.config.dtype,
319311
weight_dtype=self.config.weight_dtype,

tests/pipeline_parallelism_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,9 @@ def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_cla
6969
else:
7070
if issubclass(single_pipeline_stage_class, nnx_wrappers.ToLinen):
7171
rngs = nnx.Rngs(params=0)
72-
single_pipeline_stage = single_pipeline_stage_class(
73-
config=config, mesh=mesh, model_mode=model_mode, rngs=rngs
74-
)
72+
single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode, rngs=rngs)
7573
else:
76-
single_pipeline_stage = single_pipeline_stage_class(
77-
config=config, mesh=mesh, model_mode=model_mode
78-
)
74+
single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode)
7975

8076
def get_inputs(batch_size, sequence, features):
8177
"""Get random inputs, and random dummy targets

0 commit comments

Comments
 (0)