Skip to content

Commit a5dfa29

Browse files
author
Flax Team
committed
Changed hardcoded type "float32" to dtype (models.py)
PiperOrigin-RevId: 321632088
1 parent 0f66ea7 commit a5dfa29

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

examples/imagenet/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def apply(self, x, num_classes, num_filters=64, num_layers=50,
7878
dtype=dtype)
7979
x = jnp.mean(x, axis=(1, 2))
8080
x = nn.Dense(x, num_classes, dtype=dtype)
81-
x = jnp.asarray(x, jnp.float32)
81+
x = jnp.asarray(x, dtype)
8282
x = nn.log_softmax(x)
8383
return x
8484

0 commit comments

Comments
 (0)