@@ -312,46 +312,42 @@ def static_pylayer(forward_fn, inputs, backward_fn=None, name=None):
312
312
Examples:
313
313
.. code-block:: python
314
314
315
- >>> import paddle
316
- >>> import numpy as np
317
-
318
- >>> paddle.enable_static()
319
-
320
- >>> def forward_fn(x):
321
- ... return paddle.exp(x)
322
-
323
- >>> def backward_fn(dy):
324
- ... return 2 * paddle.exp(dy)
325
-
326
- >>> main_program = paddle.static.Program()
327
- >>> start_program = paddle.static.Program()
328
-
329
- >>> place = paddle.CPUPlace()
330
- >>> exe = paddle.static.Executor(place)
331
- >>> with paddle.static.program_guard(main_program, start_program):
332
- ... data = paddle.static.data(name="X", shape=[None, 5], dtype="float32")
333
- ... data.stop_gradient = False
334
- ... ret = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn)
335
- ... data_grad = paddle.static.gradients([ret], data)[0]
336
-
337
- >>> exe.run(start_program)
338
- >>> x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
339
- >>> x, x_grad, y = exe.run(
340
- ... main_program,
341
- ... feed={"X": x},
342
- ... fetch_list=[
343
- ... data.name,
344
- ... data_grad.name,
345
- ... ret.name
346
- ... ],
347
- ... )
348
-
349
- >>> print(x)
350
- [[1. 2. 3. 4. 5.]]
351
- >>> print(x_grad)
352
- [[5.4365635 5.4365635 5.4365635 5.4365635 5.4365635]]
353
- >>> print(y)
354
- [[ 2.7182817 7.389056 20.085537 54.59815 148.41316 ]]
315
+ >>> import paddle
316
+ >>> import numpy as np
317
+
318
+ >>> paddle.enable_static()
319
+
320
+ >>> def forward_fn(x):
321
+ ... return paddle.exp(x)
322
+
323
+ >>> def backward_fn(dy):
324
+ ... return 2 * paddle.exp(dy)
325
+
326
+ >>> main_program = paddle.static.Program()
327
+ >>> start_program = paddle.static.Program()
328
+
329
+ >>> place = paddle.CPUPlace()
330
+ >>> exe = paddle.static.Executor(place)
331
+ >>> with paddle.static.program_guard(main_program, start_program):
332
+ ... data = paddle.static.data(name="X", shape=[None, 5], dtype="float32")
333
+ ... data.stop_gradient = False
334
+ ... ret = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn)
335
+ ... data_grad = paddle.static.gradients([ret], data)[0]
336
+
337
+ >>> exe.run(start_program)
338
+ >>> x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
339
+ >>> x, x_grad, y = exe.run(
340
+ ... main_program,
341
+ ... feed={"X": x},
342
+ ... fetch_list=[data, data_grad, ret],
343
+ ... )
344
+
345
+ >>> print(x)
346
+ [[1. 2. 3. 4. 5.]]
347
+ >>> print(x_grad)
348
+ [[5.4365635 5.4365635 5.4365635 5.4365635 5.4365635]]
349
+ >>> print(y)
350
+ [[ 2.7182817 7.389056 20.085537 54.59815 148.41316 ]]
355
351
"""
356
352
assert (
357
353
in_dygraph_mode () is False
0 commit comments