
Commit fbf4f0e

Cristian Garcia authored and The tunix Authors committed
change defaults for Dropout and BatchNorm
Changes `Dropout.deterministic` and `BatchNorm.use_running_average` to be None by default. Users now have to explicitly provide them by either:

1. Passing them to the constructor, e.g. `self.bn = nnx.BatchNorm(..., use_running_average=False)`
2. Passing them to `__call__`, e.g. `self.dropout(x, deterministic=False)`
3. Using `nnx.view` to create a view of the model with specific values: `train_model = nnx.view(model, deterministic=False, use_running_average=False)`

PiperOrigin-RevId: 877557949
1 parent d55b500 commit fbf4f0e
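The resolution order the commit message describes (call-site value wins, then the constructor value, and an error if neither is given) can be sketched with a minimal stand-in class. This is an illustrative toy, not the real `flax.nnx.Dropout` implementation; the class name and the inverted-scaling body are assumptions for demonstration only.

```python
class Dropout:
    """Toy layer showing None-default flag resolution (not flax.nnx code)."""

    def __init__(self, rate, deterministic=None):
        self.rate = rate
        self.deterministic = deterministic  # None means "not decided yet"

    def __call__(self, x, deterministic=None):
        # Call-site argument takes priority over the constructor value.
        if deterministic is None:
            deterministic = self.deterministic
        # If neither was provided, refuse to guess a mode.
        if deterministic is None:
            raise ValueError(
                "No `deterministic` value provided: pass it to the "
                "constructor or to __call__."
            )
        if deterministic:
            return x  # eval mode: identity
        # Train mode: inverted scaling only (random masking omitted for brevity).
        return x / (1.0 - self.rate)


layer = Dropout(rate=0.5)                 # deterministic left as None
print(layer(2.0, deterministic=True))     # call-site value: eval mode
print(Dropout(rate=0.5, deterministic=True)(2.0))  # constructor value
```

With neither source set, calling the layer raises instead of silently picking a mode, which is the behavioral change the new `None` defaults enforce.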

File tree: 1 file changed (+4, -4 lines)


tunix/models/gemma3/vision.py

Lines changed: 4 additions & 4 deletions
@@ -170,7 +170,7 @@ def __init__(
         param_dtype=dtype_mm,
         rngs=rngs,
     )
-    self.dropout = nnx.Dropout(rate=dropout)
+    self.dropout = nnx.Dropout(rate=dropout, deterministic=False)

     self.shd_config = shd_config

@@ -246,7 +246,7 @@ def __init__(
         param_dtype=self.dtype_mm,
         rngs=rngs,
     )
-    self.dropout = nnx.Dropout(rate=self.dropout_rate)
+    self.dropout = nnx.Dropout(rate=self.dropout_rate, deterministic=False)
     self.fc2 = nnx.Linear(
         mlp_dim,
         self.width,
@@ -328,7 +328,7 @@ def __init__(
         dtype_mm=dtype_mm,
         rngs=rngs,
     )
-    self.dropout = nnx.Dropout(rate=dropout)
+    self.dropout = nnx.Dropout(rate=dropout, deterministic=False)
     self.ln2 = nnx.LayerNorm(
         num_features=width,
         epsilon=1e-6,
@@ -511,7 +511,7 @@ def __init__(
         sharding=shd_config.emb_pos_kernel if shd_config else (),
     )

-    self.dropout = nnx.Dropout(rate=dropout)
+    self.dropout = nnx.Dropout(rate=dropout, deterministic=False)

     self.transformer = Encoder(
         width=width,
