Skip to content

Commit 33ade7b

Browse files
committed
[Distributed] fix TypeError of multi learning_rate input.
Signed-off-by: 泊霆 <[email protected]>
1 parent 6dae552 commit 33ade7b

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

tensorflow/python/distribute/hvd_strategy.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,9 @@ def wraps_optimizer(cls):
388388
HvdOptimizer
389389
'''
390390
class HvdOptimizer(cls, optimizer.Optimizer):
391-
def __init__(self, *args, **kwargs):
392-
kwargs["learning_rate"] = kwargs.get("learning_rate", 0.001) *\
393-
HvdContext.get().world_size
394-
super(HvdOptimizer, self).__init__(*args, **kwargs)
391+
def __init__(self, learning_rate=0.001, *args, **kwargs):
392+
learning_rate = learning_rate * HvdContext.get().world_size
393+
super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs)
395394

396395
def compute_gradients(self, loss, **kwargs):
397396
loss = hvd.allreduce(loss, op=hvd.Sum)
@@ -1449,4 +1448,4 @@ def export(export_dir_base,
14491448
as_text=as_text,
14501449
clear_devices=clear_devices,
14511450
strip_default_attrs=strip_default_attrs,
1452-
modes=[mode])
1451+
modes=[mode])

0 commit comments

Comments
 (0)