Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/paddle/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def minimize(
no_grad_set: set[Tensor],
) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]: ...

def step(self) -> None: ...
def step(
self, closure: Callable[[], Tensor] | None
) -> Tensor | None: ...

def set_state_dict(self, state_dict: dict[str, Tensor]) -> None: ...

Expand Down
43 changes: 37 additions & 6 deletions python/paddle/optimizer/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .optimizer import Optimizer

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence

from typing_extensions import NotRequired

Expand Down Expand Up @@ -470,33 +470,63 @@ def _append_optimize_op(self, block, param_and_grad):

@imperative_base.no_grad
@framework.non_static_only
def step(self) -> None:
def step(
self, closure: Callable[[], Tensor] | None = None
) -> Tensor | None:
"""
Execute the optimizer and update parameters once.

Args:
closure (Callable|None, optional): A closure that reevaluates the model
and returns the loss. It should be a callable that takes no arguments
and returns a Tensor. This is useful for optimizers that need to
evaluate the loss multiple times (e.g., line search). Default is None.

Returns:
None
Tensor|None: If closure is provided, returns the loss value computed by
the closure. Otherwise returns None.

Examples:
.. code-block:: pycon

>>> import paddle

>>> a = paddle.rand([2, 13], dtype="float32")
>>> x = paddle.rand([2, 13], dtype="float32")
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> adam = paddle.optimizer.Adam(
... learning_rate=0.01,
... parameters=linear.parameters(),
... )
>>> out = linear(a)
>>> out = linear(x)
>>> out.backward()
>>> adam.step()
>>> adam.clear_grad()

>>> # usage 1: not use closure
>>> adam.zero_grad()
>>> output = linear(x)
>>> loss = paddle.mean(output)
>>> loss.backward()
>>> adam.step()

>>> # usage 2: use closure
>>> def closure():
... adam.zero_grad()
... output = linear(x)
... loss = paddle.mean(output)
... loss.backward()
... return loss
>>> step_loss = adam.step(closure)
"""
loss = None
if closure is not None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loss = None
if closure is not None:
    with imperative_base.enable_grad():
        loss = closure()

...

return loss

with imperative_base.enable_grad():
loss = closure()

if paddle.base.dygraph.base.in_to_static_mode():
self._declarative_step()
return
return loss

if not isinstance(self._parameter_list[0], dict):
params_grads = []
Expand Down Expand Up @@ -550,6 +580,7 @@ def step(self) -> None:
params_grads=params_grads,
param_group_idx=idx,
)
return loss

def _multi_tensor_init(self, target_block, parameters, param_group_idx):
"""
Expand Down
41 changes: 36 additions & 5 deletions python/paddle/optimizer/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,33 +633,63 @@ def __str__(self):

@imperative_base.no_grad
@framework.non_static_only
def step(self) -> None:
def step(
self, closure: Callable[[], Tensor] | None = None
) -> Tensor | None:
"""
Execute the optimizer and update parameters once.

Args:
closure (Callable|None, optional): A closure that reevaluates the model
and returns the loss. It should be a callable that takes no arguments
and returns a Tensor. This is useful for optimizers that need to
evaluate the loss multiple times (e.g., line search). Default is None.

Returns:
None
Tensor|None: If closure is provided, returns the loss value computed by
the closure. Otherwise returns None.

Examples:
.. code-block:: pycon

>>> import paddle

>>> a = paddle.rand([2, 13], dtype="float32")
>>> x = paddle.rand([2, 13], dtype="float32")
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> opt = paddle.optimizer.AdamW(
... learning_rate=0.01,
... parameters=linear.parameters(),
... )
>>> out = linear(a)
>>> out = linear(x)
>>> out.backward()
>>> opt.step()
>>> opt.clear_grad()

>>> # usage 1: not use closure
>>> opt.zero_grad()
>>> output = linear(x)
>>> loss = paddle.mean(output)
>>> loss.backward()
>>> opt.step()

>>> # usage 2: use closure
>>> def closure():
... opt.zero_grad()
... output = linear(x)
... loss = paddle.mean(output)
... loss.backward()
... return loss
>>> step_loss = opt.step(closure)
"""
loss = None
if closure is not None:
with imperative_base.enable_grad():
loss = closure()

if paddle.base.dygraph.base.in_to_static_mode():
self._declarative_step()
return
return loss

if not isinstance(self._parameter_list[0], dict):
params_grads = []
Expand Down Expand Up @@ -725,6 +755,7 @@ def step(self) -> None:
self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
return loss

def _update_param_group(self, parameters):
self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
Expand Down
41 changes: 36 additions & 5 deletions python/paddle/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2037,33 +2037,63 @@ def _declarative_step(self):

@imperative_base.no_grad()
@framework.non_static_only
def step(self) -> None:
def step(
self, closure: Callable[[], Tensor] | None = None
) -> Tensor | None:
"""
Execute the optimizer and update parameters once.

Args:
closure (Callable|None, optional): A closure that reevaluates the model
and returns the loss. It should be a callable that takes no arguments
and returns a Tensor. This is useful for optimizers that need to
evaluate the loss multiple times (e.g., line search). Default is None.

Returns:
None
Tensor|None: If closure is provided, returns the loss value computed by
the closure. Otherwise returns None.

Examples:
.. code-block:: pycon

>>> import paddle

>>> a = paddle.arange(26, dtype="float32").reshape([2, 13])
>>> x = paddle.arange(26, dtype="float32").reshape([2, 13])
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> adam = paddle.optimizer.Adam(
... learning_rate=0.01,
... parameters=linear.parameters(),
... )
>>> out = linear(a)
>>> out = linear(x)
>>> out.backward()
>>> adam.step()
>>> adam.clear_grad()

>>> # usage 1: not use closure
>>> adam.zero_grad()
>>> output = linear(x)
>>> loss = paddle.mean(output)
>>> loss.backward()
>>> adam.step()

>>> # usage 2: use closure
>>> def closure():
... adam.zero_grad()
... output = linear(x)
... loss = paddle.mean(output)
... loss.backward()
... return loss
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the example code for each of the methods that now accept a closure.

>>> step_loss = adam.step(closure)
"""
loss = None
if closure is not None:
with imperative_base.enable_grad():
loss = closure()

if paddle.base.dygraph.base.in_to_static_mode():
self._declarative_step()
return
return loss

if not isinstance(self._param_groups[0], dict):
params_grads = []
Expand Down Expand Up @@ -2111,6 +2141,7 @@ def step(self) -> None:
params_grads=params_grads,
param_group_idx=idx,
)
return loss

def _add_param_group(self, param_group):
"""
Expand Down
58 changes: 58 additions & 0 deletions test/legacy_test/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,64 @@ def test_weight_decay_int(self):
adam.step()
adam.zero_grad(False)

def test_step_without_closure(self):
    """step() called without a closure must update parameters and return None.

    Covers the backward-compatible path of the new ``step(closure=None)``
    signature for Adam, AdamW and ASGD: the classic
    zero_grad -> forward -> backward -> step cycle must keep working, and
    the return value must stay ``None`` so existing callers are unaffected.
    """
    paddle.seed(100)
    numpy.random.seed(100)
    paddle.disable_static()
    x = paddle.arange(26, dtype="float32").reshape([2, 13])
    linear = paddle.nn.Linear(13, 5)
    # All three optimizers share the same layer's parameters on purpose:
    # this is a smoke test of the step() signature, not of convergence.
    optimizers = [
        paddle.optimizer.Adam(
            learning_rate=0.01,
            parameters=linear.parameters(),
        ),
        paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=linear.parameters(),
        ),
        paddle.optimizer.ASGD(
            learning_rate=0.01,
            parameters=linear.parameters(),
        ),
    ]
    for optimizer in optimizers:
        optimizer.zero_grad()
        output = linear(x)
        loss = paddle.mean(output)
        loss.backward()
        # Without a closure, step() must return None (old contract preserved).
        self.assertIsNone(optimizer.step())

def test_step_with_closure(self):
    """step(closure) must invoke the closure and return the loss it computed.

    Exercises the new closure path of ``step`` for Adam, AdamW and ASGD:
    the closure performs zero_grad -> forward -> backward and returns the
    loss, and ``step`` must surface that loss to the caller (not ``None``).
    """
    paddle.seed(100)
    numpy.random.seed(100)
    paddle.disable_static()
    x = paddle.arange(26, dtype="float32").reshape([2, 13])
    linear = paddle.nn.Linear(13, 5)
    optimizers = [
        paddle.optimizer.Adam(
            learning_rate=0.01,
            parameters=linear.parameters(),
        ),
        paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=linear.parameters(),
        ),
        paddle.optimizer.ASGD(
            learning_rate=0.01,
            parameters=linear.parameters(),
        ),
    ]
    for optimizer in optimizers:
        # Bind the loop variable as a default argument so each closure
        # captures its own optimizer (avoids the late-binding pitfall).
        def closure(opt=optimizer):
            opt.zero_grad()
            output = linear(x)
            loss = paddle.mean(output)
            loss.backward()
            return loss

        loss = optimizer.step(closure)
        # The closure's loss must be returned, proving the closure ran.
        self.assertIsNotNone(loss)


if __name__ == '__main__':
paddle.enable_static()
Expand Down
Loading