Skip to content

Commit 601c6d6

Browse files
authored
Fix for AdaDelta (#603)
- state was being read from parameter "s" - but being stored in parameter "u"
1 parent ba8d6bf commit 601c6d6

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/mlx/optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def apply_single(
284284
eps = self.eps
285285

286286
v = state.get("v", mx.zeros_like(gradient))
287-
u = state.get("s", mx.zeros_like(gradient))
287+
u = state.get("u", mx.zeros_like(gradient))
288288

289289
v = rho * v + (1 - rho) * mx.square(gradient)
290290
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient

0 commit comments

Comments
 (0)