Commit 9986ffa

add adadelta for torch (#534)
Co-authored-by: Haifeng Jin <[email protected]>
1 parent 4fcd567 commit 9986ffa

File tree: 5 files changed, +66 -8 lines
keras_core/backend/torch/optimizers/torch_adadelta.py (new file)

+56

@@ -0,0 +1,56 @@
+import torch
+
+from keras_core import ops
+from keras_core import optimizers
+from keras_core.backend.torch.optimizers import torch_parallel_optimizer
+
+
+class Adadelta(
+    torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adadelta
+):
+    def _parallel_update_step(
+        self,
+        grads,
+        variables,
+        learning_rate,
+    ):
+        keras_variables = variables
+        variables = [v.value for v in variables]
+
+        dtype = variables[0].dtype
+        lr = ops.cast(learning_rate, dtype)
+        rho = self.rho
+
+        accumulated_grads = [
+            self._accumulated_grads[self._get_variable_index(variable)].value
+            for variable in keras_variables
+        ]
+        accumulated_delta_vars = [
+            self._accumulated_delta_vars[
+                self._get_variable_index(variable)
+            ].value
+            for variable in keras_variables
+        ]
+        torch._foreach_mul_(accumulated_grads, rho)
+        torch._foreach_add_(
+            accumulated_grads, torch._foreach_mul(grads, grads), alpha=1 - rho
+        )
+
+        def rms(x):
+            return torch._foreach_sqrt(torch._foreach_add(x, self.epsilon))
+
+        delta_vars = torch._foreach_mul(
+            torch._foreach_div(
+                torch._foreach_mul(rms(accumulated_delta_vars), grads),
+                rms(accumulated_grads),
+            ),
+            -1,
+        )
+        torch._foreach_mul_(accumulated_delta_vars, rho)
+        torch._foreach_add_(
+            accumulated_delta_vars,
+            torch._foreach_mul(delta_vars, delta_vars),
+            alpha=1 - rho,
+        )
+
+        torch._foreach_add_(variables, delta_vars, alpha=lr)
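For reference, the torch._foreach_* calls in the file above apply the standard Adadelta recurrence across all variable tensors at once. Below is a minimal single-tensor sketch of that recurrence in NumPy; it is illustrative only (the function and argument names are not part of the commit):

import numpy as np


def adadelta_step(x, g, acc_grad, acc_delta, lr, rho, epsilon):
    # Accumulate the squared gradient: E[g^2] <- rho * E[g^2] + (1 - rho) * g^2
    acc_grad = rho * acc_grad + (1 - rho) * g * g
    # Scale the gradient by the ratio of RMS terms, with a negative sign
    delta = -np.sqrt(acc_delta + epsilon) / np.sqrt(acc_grad + epsilon) * g
    # Accumulate the squared update: E[dx^2] <- rho * E[dx^2] + (1 - rho) * dx^2
    acc_delta = rho * acc_delta + (1 - rho) * delta * delta
    # Apply the update, scaled by the learning rate
    x = x + lr * delta
    return x, acc_grad, acc_delta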

keras_core/backend/torch/optimizers/torch_optimizer.py

+2

@@ -7,12 +7,14 @@
 class TorchOptimizer(BaseOptimizer):
     def __new__(cls, *args, **kwargs):
         # Import locally to avoid circular imports.
+        from keras_core.backend.torch.optimizers import torch_adadelta
         from keras_core.backend.torch.optimizers import torch_adam
         from keras_core.backend.torch.optimizers import torch_adamw
         from keras_core.backend.torch.optimizers import torch_rmsprop
         from keras_core.backend.torch.optimizers import torch_sgd

         OPTIMIZERS = {
+            optimizers.Adadelta: torch_adadelta.Adadelta,
             optimizers.Adam: torch_adam.Adam,
             optimizers.AdamW: torch_adamw.AdamW,
             optimizers.RMSprop: torch_rmsprop.RMSprop,
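With this registration in place, constructing the public Adadelta class under the torch backend is expected to resolve to the torch-specific subclass. A hedged usage sketch, assuming the backend is selected via the KERAS_BACKEND environment variable before keras_core is imported:

import os

os.environ["KERAS_BACKEND"] = "torch"  # assumption: backend chosen via env var

from keras_core import optimizers

# Instantiating the public class should dispatch, via TorchOptimizer.__new__,
# to the torch-specific subclass registered in the OPTIMIZERS mapping above.
opt = optimizers.Adadelta(learning_rate=0.5)
print(type(opt))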

keras_core/backend/torch/optimizers/torch_rmsprop.py

+2 -3

@@ -57,9 +57,8 @@ def _parallel_update_step(
                 self._momentums[self._get_variable_index(variable)].value
                 for variable in keras_variables
             ]
-            momentum_list = torch._foreach_add(
-                increments, momentum_list, alpha=self.momentum
-            )
+            torch._foreach_mul_(momentum_list, self.momentum)
+            torch._foreach_add_(momentum_list, increments)
             torch._foreach_add_(variables, momentum_list, alpha=-1)
         else:
             torch._foreach_add_(variables, increments, alpha=-1)
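The rewritten RMSprop lines accumulate momentum in place, so the stored momentum slots themselves carry the accumulated value across steps, instead of a temporary list returned by the out-of-place torch._foreach_add. A single-tensor sketch of the equivalent recurrence (values are illustrative):

import torch

buf = torch.zeros(3)                       # stored momentum slot
increment = torch.tensor([0.1, 0.2, 0.3])  # learning-rate-scaled gradient term
momentum = 0.9

# buf <- momentum * buf + increment, applied in place so the slot persists
buf.mul_(momentum).add_(increment)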

keras_core/backend/torch/optimizers/torch_sgd.py

+1 -1

@@ -15,7 +15,7 @@ def _parallel_update_step(
         variables = [v.value for v in variables]
         if self.momentum != 0:
             bufs = [
-                self.momentums[self._get_variable_index(variable.value)].value
+                self.momentums[self._get_variable_index(variable)].value
                 for variable in keras_variables
             ]

keras_core/optimizers/adadelta_test.py

+5 -4

@@ -1,6 +1,7 @@
 import numpy as np

 from keras_core import backend
+from keras_core import ops
 from keras_core import testing
 from keras_core.optimizers.adadelta import Adadelta

@@ -16,7 +17,7 @@ def test_config(self):

     def test_single_step(self):
         optimizer = Adadelta(learning_rate=0.5)
-        grads = np.array([1.0, 6.0, 7.0, 2.0])
+        grads = ops.array([1.0, 6.0, 7.0, 2.0])
         vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
         optimizer.apply_gradients(zip([grads], [vars]))
         self.assertAllClose(

@@ -25,7 +26,7 @@ def test_single_step(self):

     def test_weight_decay(self):
         grads, var1, var2, var3 = (
-            np.zeros(()),
+            ops.zeros(()),
             backend.Variable(2.0),
             backend.Variable(2.0, name="exclude"),
             backend.Variable(2.0),

@@ -49,8 +50,8 @@ def test_correctness_with_golden(self):
         optimizer = Adadelta(learning_rate=1.0, rho=0.8, epsilon=1e-6)

         x = backend.Variable(np.ones([10]))
-        grads = np.arange(0.1, 1.1, 0.1)
-        first_grads = np.full((10,), 0.01)
+        grads = ops.arange(0.1, 1.1, 0.1)
+        first_grads = ops.full((10,), 0.01)

         golden = np.tile(
             [[0.9978], [0.9947], [0.9915], [0.9882], [0.9849]], (1, 10)
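The test changes swap NumPy arrays for keras_core.ops calls when building gradients, likely so that the gradients are backend-native tensors (torch tensors under the torch backend), which the fused torch._foreach_* updates operate on. A small hedged check of that assumption:

import os

os.environ["KERAS_BACKEND"] = "torch"  # assumption: torch backend selected

from keras_core import ops

g = ops.array([1.0, 6.0, 7.0, 2.0])
print(type(g))  # expected: a torch.Tensor under the torch backend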

0 commit comments