forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrecompute.py
More file actions
829 lines (725 loc) · 32.6 KB
/
recompute.py
File metadata and controls
829 lines (725 loc) · 32.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import copy
import inspect
import random
import weakref
from typing import TYPE_CHECKING, Any, TypedDict
import numpy as np
import paddle
from paddle import framework
from paddle.autograd import PyLayer
from paddle.base.framework import EagerParamBase
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
from paddle.framework import core, in_dynamic_mode
from ..utils.log_util import logger
if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

    from typing_extensions import NotRequired

    from paddle.nn import Sequential

    class _Ctx(TypedDict):
        """Context options accepted by ``recompute_sequential``."""

        # Number of chunks the layer list is split into; recompute_sequential
        # falls back to 1 when the key is absent. TypedDict fields cannot carry
        # a default value, so the key is declared optional instead.
        segments: NotRequired[int]
        # Whether the forward RNG state is replayed during recomputation;
        # recompute_sequential falls back to True when the key is absent.
        preserve_rng_state: NotRequired[bool]


__all__ = []
def _varbase_help(param):
    """Build a new ``EagerParamBase`` that shares ``param``'s buffer.

    The new parameter is created with the same shape, dtype, trainable flag
    and name as ``param``, plus a deep copy of ``param.__dict__`` passed as
    extra keyword arguments. ``param._share_buffer_to`` then makes it use
    ``param``'s underlying storage.
    """
    extra_attrs = copy.deepcopy(param.__dict__)
    core_attrs = {
        "shape": param.shape,
        "dtype": param.dtype,
        "trainable": param.trainable,
        "name": param.name,
    }
    clone = EagerParamBase(**core_attrs, **extra_attrs)
    param._share_buffer_to(clone)
    return clone
def _detach_tensor(tensor):
    """Return a graph-detached counterpart of ``tensor`` keeping its ``stop_gradient``.

    Parameters (``EagerParamBase``) go through ``_varbase_help``, which builds
    a new parameter sharing the same buffer; other tensors use ``detach()``.
    """
    if isinstance(tensor, EagerParamBase):
        return _varbase_help(tensor)
    detached = tensor.detach()
    # detach() does not carry stop_gradient over, so copy it explicitly.
    detached.stop_gradient = tensor.stop_gradient
    return detached


def detach_variable(inputs):
    """Detach every tensor in ``inputs`` from the autograd graph.

    Args:
        inputs: a sequence whose items are tensors, plain tuples whose first
            element is a tensor (every element must then be a tensor), or
            arbitrary non-tensor values.

    Returns:
        tuple: same length as ``inputs``. Tensors are replaced by detached
        counterparts, tuples of tensors by tuples of detached counterparts, and
        every other value (including an empty tuple) is passed through as-is.

    Raises:
        AssertionError: a tuple starts with a tensor but contains a non-tensor.
    """
    out = []
    for inp in inputs:
        if isinstance(inp, core.eager.Tensor):
            out.append(_detach_tensor(inp))
        # ``inp`` is checked for emptiness first: indexing ``inp[0]`` on an
        # empty tuple would raise IndexError.
        elif (
            type(inp) is tuple
            and inp
            and isinstance(inp[0], core.eager.Tensor)
        ):
            detached_tuple = []
            for item in inp:
                # detach all tensors in the tuple
                assert isinstance(item, core.eager.Tensor)
                detached_tuple.append(_detach_tensor(item))
            out.append(tuple(detached_tuple))
        else:
            # the inp is not a tensor nor a non-empty tuple of tensors
            out.append(inp)
    return tuple(out)
def check_recompute_necessary(inputs):
    """Log a warning when no tensor among ``inputs`` requires gradient.

    Tensors are inspected directly and plain tuples are searched one level
    deep; every other value is ignored. If all tensors found have
    ``stop_gradient=True`` (or no tensor is found at all), recomputing the
    block in backward is pointless and a warning is emitted.
    """
    stop_flags = []
    for item in inputs:
        if isinstance(item, paddle.Tensor):
            stop_flags.append(item.stop_gradient)
        elif type(item) is tuple:
            stop_flags.extend(
                elem.stop_gradient
                for elem in item
                if isinstance(elem, paddle.Tensor)
            )
    if not all(stop_flags):
        return
    logger.warning(
        "[Recompute]: None of the inputs to current recompute block need grad, "
        "therefore there is NO need to recompute this block in backward !"
    )
def _protect_tensors(seq):
    """Give recompute its own Python handles on the tensors in ``seq``.

    Pipeline parallelism may call ``_release_input``/``_release_output``,
    which clears the data pointer of the original tensor objects. To keep the
    inputs recompute needs for backward alive, every tensor in ``seq`` (and
    every tensor directly inside a tuple element of ``seq``) is replaced by
    ``tensor._new_shared_tensor()``: a new Python object over the same C++
    storage, so no data is copied. Tuple elements are rebuilt as plain
    tuples; non-tensor values are kept unchanged.

    Returns:
        list: same length as ``seq``.
    """
    protected = []
    for item in seq:
        if isinstance(item, core.eager.Tensor):
            alias = item._new_shared_tensor()
            # Returning the very same object would defeat the protection.
            assert alias is not item, (
                "_protect_tensors() must return a new Python object distinct from the original "
                "tensor, otherwise the protection against pipeline-parallel tensor "
                "release is ineffective."
            )
            protected.append(alias)
        elif isinstance(item, tuple):
            # Pipeline parallelism may pass inputs as tuples: alias each tensor
            # inside and keep non-tensor entries (int, bool, ...) unchanged.
            protected.append(
                tuple(
                    elem._new_shared_tensor()
                    if isinstance(elem, core.eager.Tensor)
                    else elem
                    for elem in item
                )
            )
        else:
            protected.append(item)
    return protected
class CustomStatesManager:
    """Registry for an optional pair of user callables that capture and
    restore extra ("custom") state around recompute.

    Both slots start empty (``None``). Each slot may be filled only once;
    a second registration fails with an ``AssertionError``.
    """

    def __init__(self):
        """Create a manager with neither callable registered."""
        self.custom_get_state_func = None
        self.custom_set_state_func = None

    def set_custom_get_state_func(self, custom_get_state_func):
        """Register the callable that returns the current custom state."""
        assert self.custom_get_state_func is None, (
            "The custom_state_manager does not support duplicate settings."
        )
        self.custom_get_state_func = custom_get_state_func

    def set_custom_set_state_func(self, custom_set_state_func):
        """Register the callable that installs a previously captured custom state."""
        assert self.custom_set_state_func is None, (
            "The custom_state_manager does not support duplicate settings."
        )
        self.custom_set_state_func = custom_set_state_func


# Process-wide manager consulted by ``recompute``.
custom_state_manager = CustomStatesManager()
@contextlib.contextmanager
def switch_rng_state_tracker(
    rng_state,
    tracker,
    numpy_state,
    random_state,
    custom_state=None,
    custom_get_state_func=None,
    custom_set_state_func=None,
):
    """Temporarily install the given RNG states; restore the current ones on exit.

    Switches, in order: the paddle RNG state, the RNG state tracker, numpy's
    global RNG, Python's ``random`` module and, when ``custom_state`` is not
    None, the custom state through ``custom_set_state_func``.

    Every original state that was captured is restored on exit, even if the
    body raises or switching fails part-way, so a failed switch cannot leave
    the process-wide RNG state altered.

    Raises:
        AssertionError: ``custom_state`` is given but one of the custom
            getter/setter callables is missing.
    """
    # Validate before touching any global state.
    if custom_state is not None:
        assert custom_get_state_func is not None
        assert custom_set_state_func is not None

    # Sentinel for "original not captured yet": only captured states are
    # restored. Restoring a captured original whose switch never happened
    # is harmless.
    unset = object()
    orig_rng_state = orig_rng_tracker = unset
    orig_numpy_state = orig_random_state = orig_custom_state = unset
    try:
        orig_rng_state = paddle.get_rng_state()
        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
        paddle.set_rng_state(rng_state)
        get_rng_state_tracker().set_states_tracker(tracker)

        orig_numpy_state = np.random.get_state()
        orig_random_state = random.getstate()
        np.random.set_state(numpy_state)
        random.setstate(random_state)

        if custom_state is not None:
            orig_custom_state = custom_get_state_func()
            custom_set_state_func(custom_state)

        yield
    finally:
        if orig_rng_state is not unset:
            paddle.set_rng_state(orig_rng_state)
        if orig_rng_tracker is not unset:
            get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
        if orig_numpy_state is not unset:
            np.random.set_state(orig_numpy_state)
        if orig_random_state is not unset:
            random.setstate(orig_random_state)
        if orig_custom_state is not unset:
            custom_set_state_func(orig_custom_state)
class RecomputeFunction(PyLayer):
    """Re-entrant (PyLayer based) implementation of ``recompute``.

    ``forward`` runs ``run_function`` under ``paddle.no_grad()``, so no
    intermediate activations are recorded. It keeps only what is needed to
    replay the computation: the inputs and, optionally, the RNG state, plus
    the AMP settings. ``backward`` re-runs ``run_function`` with gradients
    enabled under that saved state. It then back-propagates the incoming
    gradients through the freshly built graph.
    """

    @staticmethod
    def forward(
        ctx,
        run_function,
        preserve_rng_state,
        offload_indices,
        custom_get_state_func,
        custom_set_state_func,
        *args,
        **kwargs,
    ):
        """Run ``run_function`` without building a graph and save its inputs.

        Args:
            ctx: PyLayer context used to stash state for ``backward``.
            run_function (callable): computation to run now and replay in
                ``backward``.
            preserve_rng_state (bool): if True, capture the paddle RNG state,
                the RNG state tracker, numpy / ``random`` state and the custom
                state so the replay draws the same random numbers.
            offload_indices: positions in ``args`` of tensors whose buffers are
                replaced by a host copy (pinned when compiled with CUDA) after
                the forward pass; ``backward`` moves them back.
            custom_get_state_func (callable): returns the custom state.
            custom_set_state_func (callable): restores a custom state.
            *args: positional inputs of ``run_function``. Tensors and tuples of
                tensors are saved for backward; other values are kept as-is.
            **kwargs: keyword inputs of ``run_function``, replayed verbatim.

        Returns:
            The outputs of ``run_function(*args, **kwargs)``.

        Raises:
            ValueError: the current AMP level / dtype is unsupported, or a
                tuple argument mixes tensors with non-tensors or mixes
                ``stop_gradient`` values.
            AssertionError: ``offload_indices`` points at a tuple argument.
        """
        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.offload_indices = offload_indices
        ctx.kwargs = kwargs

        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
        # None tensor inputs will be filtered in backward inputs.

        # NOTE recompute with restore RNG only support one scenario where one process for one cuda gpu.
        # one process with multiple gpu and mix-gpu-cpu scenarios are not support
        if ctx.preserve_rng_state:
            ctx.fw_rng_state = paddle.get_rng_state()
            ctx.fwd_rng_state_tracker = (
                get_rng_state_tracker().get_states_tracker()
            )
            ctx.fwd_numpy_state = np.random.get_state()
            ctx.fwd_random_state = random.getstate()
            ctx.fwd_custom_state = custom_get_state_func()
            ctx.custom_get_state_func = custom_get_state_func
            ctx.custom_set_state_func = custom_set_state_func

        # TODO support AMP
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = (
            False if tracer._amp_level == core.AmpLevel.O0 else True
        )
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            ctx.amp_level = 'O1'
        else:
            raise ValueError(f"unsupported amp level: {tracer._amp_level}")

        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError(f"unsupported amp dtype: {tracer._amp_dtype}")

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args, **kwargs)

        # save input for backward
        ctx.inputs = []
        ctx.tensor_indices = []
        ctx.duplicate_tensor = [False] * len(args)
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                if i in ctx.offload_indices:
                    cpu_arg = (
                        arg.pin_memory()
                        if core.is_compiled_with_cuda()
                        else arg.cpu()
                    )
                    # Make ``arg`` use the host copy's buffer.
                    cpu_arg._share_buffer_to(arg)
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            elif type(arg) is tuple:
                assert i not in ctx.offload_indices, (
                    f"offload_indices should not contain tensor tuple in position{i}"
                )
                is_tensors = [paddle.is_tensor(a) for a in arg]
                if all(is_tensors):
                    # the tuple is a tuple of tensors
                    tensors_stop_gradient = [a.stop_gradient for a in arg]
                    if not all(tensors_stop_gradient) and any(
                        tensors_stop_gradient
                    ):
                        # tensors in the tuple have different stop_gradient value, which pylayer doesn't support
                        raise ValueError(
                            "Recompute receive a tuple containing tensor holds different stop gradient."
                        )
                    tensor_inputs.append(arg)
                    ctx.tensor_indices.append(i)
                    # Mark the tuple is a tuple of tensors
                    ctx.duplicate_tensor[i] = True
                    ctx.inputs.append(None)
                elif any(is_tensors):
                    # the tuple contains tensors and non-tensor values
                    raise ValueError(
                        "Recompute receive a tuple containing tensor and non-tensor at same time."
                    )
                else:
                    ctx.inputs.append(arg)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and return gradients w.r.t. the inputs.

        Args:
            ctx: PyLayer context filled in by ``forward``.
            *args: gradients of the outputs of ``forward``, one per output.

        Returns:
            One entry per tensor / tuple-of-tensors input of ``forward``: the
            input's gradient, a tuple of gradients for a tuple input, or
            ``None`` for a tuple whose tensors all have ``stop_gradient=True``.
            Returned as a tuple in dynamic mode, a list otherwise.

        Raises:
            RuntimeError: none of the recomputed outputs requires gradient.
        """
        with paddle.base.dygraph.guard():
            # TODO need to check the recompute calling is valid or not

            # Restore inputs
            inputs = list(ctx.inputs)
            tensor_indices = ctx.tensor_indices
            duplicate_tensor = ctx.duplicate_tensor
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(tensor_indices):
                # ``offload_indices`` holds positions in forward's ``args``, so
                # it must be matched against ``idx``. ``i`` is only the position
                # among the saved tensors and differs from ``idx`` whenever a
                # non-tensor argument precedes the tensor.
                if idx in ctx.offload_indices:
                    inputs[idx] = tensors[i].to(
                        paddle.base.framework._current_expected_place()
                    )
                    # NOTE(zhiqiu): tensor.to(device) will set stop_gradient=True, which may break the graph
                    inputs[idx].stop_gradient = tensors[i].stop_gradient
                else:
                    inputs[idx] = tensors[i]

            # Enable gradient recording for the replay (intended to match
            # paddle.enable_grad()).
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # NOTE support AMP
            # need restore auto_cast state as well as w/b list
            rng_guard = (
                switch_rng_state_tracker(
                    ctx.fw_rng_state,
                    ctx.fwd_rng_state_tracker,
                    ctx.fwd_numpy_state,
                    ctx.fwd_random_state,
                    ctx.fwd_custom_state,
                    ctx.custom_get_state_func,
                    ctx.custom_set_state_func,
                )
                if ctx.preserve_rng_state
                else contextlib.nullcontext()
            )
            with (
                rng_guard,
                paddle.amp.auto_cast(
                    enable=ctx.is_fw_autocast,
                    custom_white_list=ctx.amp_white_list,
                    custom_black_list=ctx.amp_black_list,
                    level=ctx.amp_level,
                    dtype=ctx.amp_dtype,
                ),
            ):
                detached_inputs = detach_variable(tuple(inputs))
                outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

            if isinstance(outputs, core.eager.Tensor):
                outputs = (outputs,)
            assert len(outputs) == len(args)

            # run backward() with only tensor that requires grad
            forward_outputs_with_grad = []
            # NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
            # pylayer will force the stop_gradient of attention mask to be False, which will make the number of
            # tensor that need grad does not match.
            # the following backward_inputs_with_grad is used to avoid this case.
            backward_inputs_with_grad = []
            for out, grad in zip(outputs, args):
                if isinstance(out, core.eager.Tensor) and not out.stop_gradient:
                    forward_outputs_with_grad.append(out)
                    backward_inputs_with_grad.append(grad)

            if len(forward_outputs_with_grad) == 0:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(
                    forward_outputs_with_grad, backward_inputs_with_grad
                )

            grads = []
            for idx, inp in enumerate(detached_inputs):
                if isinstance(inp, core.eager.Tensor):
                    grads.append(inp._grad_ivar())
                elif type(inp) is tuple and duplicate_tensor[idx]:
                    # input is a tuple and is a tuple of tensors
                    if all(t.stop_gradient for t in inp):
                        # all tensors in the tuple doesn't need grad, only return a None for the whole tuple
                        grads.append(None)
                    else:
                        # all tensors in the tuple need grad, should return a tuple of grads
                        grads.append(tuple(t._grad_ivar() for t in inp))

            if in_dynamic_mode():
                grads = tuple(grads)
            else:
                grads = list(grads)
            return grads
def _recompute_without_reentrant(
    function,
    custom_get_state_func,
    custom_set_state_func,
    preserve_rng_state=True,
    *args,
    **kwargs,
):
    """
    recompute without reentrant, that means use hook to implement the recompute function rather than re-entrant autograd.

    During the forward pass, the outer ``pack`` hook replaces every tensor that
    autograd would save with an empty placeholder object and records a weak
    reference to it. The first time backward asks for a saved tensor (storage
    is empty), ``unpack`` re-runs ``function`` with grad enabled and the
    forward-time AMP settings (and RNG state when requested). Inner hooks map
    each placeholder, in recording order, to the tensor saved by the replay.
    ``unpack`` then pops the requested entry. This relies on the replay saving
    tensors in the same order as the original forward pass.

    Args:
        function: callable whose saved activations are dropped and recomputed.
        custom_get_state_func / custom_set_state_func: read / restore extra
            custom state; only used when ``preserve_rng_state`` is True.
        preserve_rng_state (bool): capture the device RNG state, the RNG state
            tracker, numpy / ``random`` state and the custom state now, and
            replay them during recomputation.
        *args, **kwargs: forwarded to ``function``.

    Returns:
        The outputs of ``function(*args, **kwargs)``.

    Raises:
        RuntimeError: ``preserve_rng_state`` is requested on an unsupported
            device.
        ValueError: the current AMP level or dtype is not supported.
    """
    if preserve_rng_state:
        cur_device = paddle.get_device()
        if cur_device.startswith('gpu:'):
            fw_cuda_rng_state = paddle.get_cuda_rng_state()
        elif 'cpu' in cur_device:
            fw_cuda_rng_state = paddle.get_rng_state()
        elif 'xpu:' in cur_device:
            fw_cuda_rng_state = paddle.get_rng_state()
        elif (
            cur_device.split(':')[0]
            in paddle.device.get_all_custom_device_type()
        ):
            fw_cuda_rng_state = paddle.get_rng_state(cur_device)
        else:
            raise RuntimeError(
                f"Recompute with RNG preserve is not support current device: {cur_device}."
            )
        fwd_cuda_rng_state_tracker = (
            get_rng_state_tracker().get_states_tracker()
        )
        fwd_numpy_state = np.random.get_state()
        fwd_random_state = random.getstate()
        fwd_custom_state = custom_get_state_func()

    tracer = framework._dygraph_tracer()
    is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
    # Unsupported levels / dtypes are rejected here, as in
    # RecomputeFunction.forward. Otherwise amp_level / amp_dtype would stay
    # unbound, and the error would only surface inside ``unpack`` during
    # backward.
    if tracer._amp_level == core.AmpLevel.O2:
        amp_level = 'O2'
    elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
        amp_level = 'O1'
    else:
        raise ValueError(f"unsupported amp level: {tracer._amp_level}")

    if tracer._amp_dtype == 'float16':
        amp_dtype = 'float16'
    elif tracer._amp_dtype in ('bfloat16', 'float32'):
        amp_dtype = 'bfloat16'
    else:
        raise ValueError(f"unsupported amp dtype: {tracer._amp_dtype}")

    amp_white_list, amp_black_list = tracer._get_amp_op_list()

    class Intermediate_Holder:
        pass

    # placeholder -> recomputed tensor; weak keys so entries disappear with
    # their placeholders.
    storage = weakref.WeakKeyDictionary()
    # Weak references to placeholders, in the order autograd saved them.
    holder_list = []

    def pack(x):
        res = Intermediate_Holder()
        holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        if len(storage) == 0:

            def inner_pack(inner_x):
                nonlocal unpack_counter
                unpack_counter += 1

                # Dereference once so the placeholder cannot be collected
                # between the liveness check and its use as a key.
                holder = holder_list[unpack_counter - 1]()
                if holder is None:
                    return
                if inner_x is None:
                    storage[holder] = None
                    return
                if hasattr(inner_x, "main_grad") or inner_x.grad is not None:
                    storage[holder] = inner_x
                else:
                    if inner_x.is_dist():
                        tmp_tensor = core.eager.Tensor(inner_x)
                    else:
                        tmp_tensor = core.eager.Tensor(
                            inner_x.dtype,
                            inner_x.shape,
                            inner_x.name + "cpy",
                            core.VarDesc.VarType.DENSE_TENSOR,
                            inner_x.persistable,
                        )
                    inner_x._unsafe_share_buffer_to(tmp_tensor)
                    storage[holder] = tmp_tensor
                return

            def inner_unpack(inner_x):
                raise Exception("An unexpected backward called on a tensor!")

            rng_guard = (
                switch_rng_state_tracker(
                    fw_cuda_rng_state,
                    fwd_cuda_rng_state_tracker,
                    fwd_numpy_state,
                    fwd_random_state,
                    fwd_custom_state,
                    custom_get_state_func,
                    custom_set_state_func,
                )
                if preserve_rng_state
                else contextlib.nullcontext()
            )
            with (
                rng_guard,
                paddle.set_grad_enabled(True),
                paddle.amp.auto_cast(
                    enable=is_fw_autocast,
                    custom_white_list=amp_white_list,
                    custom_black_list=amp_black_list,
                    level=amp_level,
                    dtype=amp_dtype,
                ),
                paddle.autograd.saved_tensors_hooks(
                    inner_pack, inner_unpack
                ),
            ):
                function(*args, **kwargs)

        if x not in storage:
            raise Exception(
                "Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute."
            )

        return storage.pop(x)

    with paddle.autograd.saved_tensors_hooks(pack, unpack):
        outputs = function(*args, **kwargs)

    return outputs
def recompute(function, *args, **kwargs):
    """
    Recompute intermediate activations to save memory.

    Parameters:
        function(paddle.nn.Layer): layer or sequence of layers that describes part of forward pass of the model
              whose intermediate activations will be released to save memory in forward stage and will be recomputed
              in backward stage for gradient calculation.
        *args(Tensor): inputs to the function.
        **kwargs(Dict): keyword arguments of ``function``, plus the following options, which are consumed by
              recompute itself and are not forwarded to ``function``:

              - ``preserve_rng_state`` (bool, default True): whether to save the forward RNG state. If True, it is
                restored when the forward pass is recomputed during backward.
              - ``use_reentrant`` (bool, default True): ``True`` uses the PyLayer implementation of recompute,
                ``False`` uses the saved-tensor-hook implementation.
              - ``offload_indices`` (list, default ``[]``): only recognized when ``use_reentrant=True``; positions
                (in the flattened positional argument list) of tensor inputs whose buffers are moved to host memory
                after the forward pass.

              With ``use_reentrant=True`` every argument is flattened into a positional argument following the
              signature of ``function`` (values collected by a ``**kwargs`` parameter included), and functions with
              keyword-only parameters are rejected with ``ValueError``.

    Returns:
        Output of function on args.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED, env:GPU)
            >>> import paddle
            >>> from paddle.distributed.fleet.utils import recompute
            >>> import random
            >>> paddle.seed(2023)
            >>> def get_fc_block(block_idx, input_size, is_last=False):
            ...     block_name = "block_" + str(block_idx)
            ...     block = paddle.nn.Sequential(
            ...         (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
            ...         (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
            ...         (block_name + "_relu_1", paddle.nn.ReLU()),
            ...         (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
            ...         (block_name + "_relu_2", paddle.nn.ReLU()),
            ...     )
            ...     if is_last:
            ...         block.add_sublayer(
            ...             block_name + "_fc_2",
            ...             paddle.nn.Linear(
            ...                 input_size, 1, bias_attr=False
            ...             )
            ...         )
            ...     else:
            ...         block.add_sublayer(
            ...             block_name + "_fc_2",
            ...             paddle.nn.Linear(input_size, input_size, bias_attr=False)
            ...         )
            ...     return block

            >>> class Naive_fc_net(paddle.nn.Layer):
            ...     def __init__(self, input_size=10,
            ...                 recompute_blocks=[1, 3],
            ...                 recompute_kwargs={}):
            ...         super().__init__()
            ...         self.recompute_blocks = recompute_blocks
            ...         self.recompute_kwargs = recompute_kwargs
            ...         self.runfunc0 = get_fc_block(0, input_size, is_last=False)
            ...         self.runfunc1 = get_fc_block(1, input_size, is_last=False)
            ...         self.runfunc2 = get_fc_block(2, input_size, is_last=False)
            ...         self.runfunc3 = get_fc_block(3, input_size, is_last=False)
            ...         self.runfunc4 = get_fc_block(4, input_size, is_last=True)
            ...         self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
            ...     def forward(self, inputs):
            ...         nums = len(self.total_func)
            ...         for i in range(nums):
            ...             if i in self.recompute_blocks:
            ...                 inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True})
            ...             else:
            ...                 inputs = self.total_func[i](inputs)
            ...         return inputs

            >>> def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
            ...     gen = paddle.seed(10)
            ...     gen.manual_seed(10)
            ...     random.seed(10)
            ...     if cuda_state:
            ...         paddle.set_cuda_rng_state(cuda_state)
            ...     batch_size, input_size = 1, 10
            ...     model = Naive_fc_net(
            ...         input_size,
            ...         recompute_blocks=recompute_block,
            ...         recompute_kwargs=recompute_kwargs)
            ...     optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
            ...     loss_ = []
            ...     param_ = []
            ...     grad_ = []
            ...     for _ in range(5):
            ...         x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
            ...         y_pred = model(x)
            ...         loss = y_pred.mean()
            ...         loss_.append(loss.item())
            ...         loss.backward()
            ...         optimizer.step()
            ...         param_.append(model.parameters()[9])
            ...         grad_.append(model.parameters()[3]._grad_ivar())
            ...         optimizer.clear_grad()
            ...     return loss_, param_, grad_

            >>> cuda_state = paddle.get_cuda_rng_state()
            >>> # without recompute
            >>> loss_ref, param_ref, grad_ref = run_model(
            ...     cuda_state, recompute_block=[]
            ... )

            >>> loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
            >>> print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
            >>> # The result of the recompute_loss should be the same as the normal_loss.
            normal_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0], recompute_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0]

    """
    # Recompute's own options are consumed here and never reach ``function``.
    preserve = kwargs.pop('preserve_rng_state', True)
    use_reentrant = kwargs.pop('use_reentrant', True)

    manager = custom_state_manager
    if manager.custom_get_state_func is not None:
        custom_get_state_func = manager.custom_get_state_func
        custom_set_state_func = manager.custom_set_state_func
    else:
        assert manager.custom_set_state_func is None
        custom_get_state_func = lambda x=None: None
        custom_set_state_func = lambda x=None: None

    if not in_dynamic_mode():
        from paddle.distributed.auto_parallel.interface import (
            recompute as static_auto_recompute,
        )

        return static_auto_recompute(function)(*args, **kwargs)

    if framework._dygraph_tracer()._has_grad:
        check_recompute_necessary([*args, *kwargs.values()])

    if not use_reentrant:
        return _recompute_without_reentrant(
            function,
            custom_get_state_func,
            custom_set_state_func,
            preserve,
            *args,
            **kwargs,
        )

    offload_indices = kwargs.pop('offload_indices', [])

    # Flatten `position-args + keyword-args` into a single positional list
    # following the signature of the callable that will actually run.
    target = (
        function.forward
        if isinstance(function, paddle.nn.Layer)
        else function
    )
    signature = inspect.signature(target)
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()

    flat_args = []
    for value, parameter in zip(
        bound.arguments.values(), signature.parameters.values()
    ):
        kind = parameter.kind
        if kind == parameter.VAR_POSITIONAL:
            flat_args.extend(value)
        elif kind in (parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD):
            flat_args.append(value)
        elif kind == parameter.VAR_KEYWORD:
            flat_args.extend(value.values())
        elif kind == parameter.KEYWORD_ONLY:
            raise ValueError(
                "Currently, keyword-only arguments are not supported when you want to send kwargs(dict parameter) to function with use_reentrant=True."
            )
        else:
            raise ValueError("Unknown parameter kind.")

    # Give recompute its own tensor handles so that pipeline-parallel release
    # of the caller's tensors cannot invalidate what backward needs.
    return RecomputeFunction.apply(
        function,
        preserve,
        offload_indices,
        custom_get_state_func,
        custom_set_state_func,
        *_protect_tensors(flat_args),
    )
def recompute_sequential(
    ctx: _Ctx,
    functions: Sequential | Sequence[Callable[..., Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    recompute intermediate activations to save the memory for 'Sequential' models. use 'ctx' to transmit some context params, it is similar to 'recompute_hybrid' API.

    ``functions`` is split into ``segments`` consecutive chunks. The first
    ``segments - 1`` chunks are run through :func:`recompute`; the last chunk
    runs normally, so with ``segments=1`` nothing is recomputed. Each function
    receives the previous function's output as its single argument.

    Parameters:
        ctx(dict): include 'segments' and 'preserve_rng_state' keys, the key 'segments' (int, default 1), represents the number of chunks to create in the model,
                   the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. If it is True, then the last forward rng value will be
                   restored when the forward recalculation of backpropagation is performed. A 'segments' value larger
                   than the number of functions is clamped to that number.
        functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model
              whose intermediate activations will be released to save memory in forward stage and will be recomputed
              in backward stage for gradient calculation.
        *args(Tensor): inputs(tuple) to the first function; a single input is expected.
        **kwargs(Dict): forwarded to :func:`recompute` for every recomputed segment (e.g. ``use_reentrant``);
              they are not passed to the final, non-recomputed segment.

    Returns:
        Output of function on args and kwargs.

    Raises:
        ValueError: if ``ctx['segments']`` is smaller than 1.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle
            >>> from paddle.incubate.distributed.fleet import recompute_sequential
            >>> input = paddle.ones(shape=[8, 10])
            >>> model = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 2))
            >>> output = recompute_sequential({'segments' : 1}, model, input)
    """
    segments = ctx.get('segments', 1)
    preserve_rng_state = ctx.get('preserve_rng_state', True)
    if segments < 1:
        raise ValueError(
            f"ctx['segments'] must be a positive integer, but got {segments}."
        )

    def _run_func(begin, end, funcs):
        # Chain funcs[begin..end] (inclusive), feeding each output forward.
        def do_run(value):
            for i in range(begin, end + 1):
                value = funcs[i](value)
            return value

        return do_run

    if isinstance(functions, paddle.nn.Sequential):
        functions = list(functions.children())

    # Never use more segments than there are functions; otherwise
    # segment_size would be 0 and range() would reject the zero step.
    num_functions = len(functions)
    segments = min(segments, num_functions)

    end = -1
    if segments > 1:
        segment_size = num_functions // segments
        for begin in range(0, segment_size * (segments - 1), segment_size):
            end = begin + segment_size - 1
            output = recompute(
                _run_func(begin, end, functions),
                *args,
                preserve_rng_state=preserve_rng_state,
                **kwargs,
            )
            # Pass the segment output on as ONE argument; unpacking it (e.g.
            # `*tensor`) would iterate over the tensor's first axis.
            args = (output,)
    return _run_func(end + 1, num_functions - 1, functions)(*args)