Skip to content

Commit 9775f84

Browse files
enable bf16 on eval (#46)
1 parent bc89134 commit 9775f84

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

training/train_single_task.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,10 @@ def _update_minbatch(train_state, batch_info):
260260
env,
261261
env_params,
262262
train_state,
263-
# TODO: make this as a static method mb?
264-
jnp.zeros((1, config.rnn_num_layers, config.rnn_hidden_dim)),
263+
jnp.zeros(
264+
(1, config.rnn_num_layers, config.rnn_hidden_dim),
265+
dtype=jnp.bfloat16 if config.enable_bf16 else None,
266+
),
265267
1,
266268
)
267269
eval_stats = jax.lax.pmean(eval_stats, axis_name="devices")

0 commit comments

Comments
 (0)