-
Notifications
You must be signed in to change notification settings - Fork 6k
Expand file tree
/
Copy pathauto_cast.py
More file actions
1443 lines (1236 loc) · 55.4 KB
/
auto_cast.py
File metadata and controls
1443 lines (1236 loc) · 55.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import os
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Protocol,
TypeVar,
Union,
overload,
)
import paddle
from paddle.base import core
from paddle.base.framework import (
_current_expected_place,
_dygraph_tracer,
dygraph_only,
in_dynamic_or_pir_mode,
in_pir_mode,
)
from paddle.base.wrapped_decorator import signature_safe_contextmanager
from paddle.static.amp.decorator import OptimizerWithMixedPrecision
from .amp_lists import black_list, white_list
if TYPE_CHECKING:
    # Everything in this block is seen only by static type checkers
    # (TYPE_CHECKING is False at runtime). Annotations in this module are
    # lazy strings because of `from __future__ import annotations`.
    from collections.abc import Generator
    from contextlib import AbstractContextManager

    from typing_extensions import TypeAlias, TypeGuard

    from paddle import Tensor
    from paddle._typing import PlaceLike
    from paddle._typing.dtype_like import _DTypeLiteral
    from paddle.nn import Layer
    from paddle.nn.layer.layers import _StateDict
    from paddle.static import Operator, Program

    # Level strings accepted by `amp_guard` (it upper-cases its argument
    # and rejects anything outside this set).
    _AmpLevelLiteral = Literal["O0", "OD", "O1", "O2"]
    # Collections of op names accepted as custom white/black lists.
    _CustomList: TypeAlias = Union[list[str], tuple[str, ...], set[str]]

    class _OptimizerLike(Protocol):
        """Structural type describing the optimizer methods used by AMP code."""

        def minimize(
            self,
            loss: Tensor,
            startup_program: Program,
            parameters: list[Tensor],
            no_grad_set: set[Tensor],
        ) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]: ...

        def step(
            self, closure: Callable[[], Tensor] | None
        ) -> Tensor | None: ...

        def set_state_dict(self, state_dict: dict[str, Tensor]) -> None: ...

        def clear_grad(self, set_to_zero: bool) -> None: ...

    # A single model/optimizer or a list of them; used by the
    # `amp_decorate` overloads so the return type mirrors the argument.
    _ModelsT = TypeVar("_ModelsT", "Layer", list["Layer"])
    _OptimizersT = TypeVar("_OptimizersT", "_OptimizerLike", list["_OptimizerLike"])
# Performance-related flags for AMP runs and the values intended for them.
# `amp_guard` does not apply these automatically (see the TODO inside it);
# users are expected to set them manually.
AMP_RELATED_FLAGS_SETTING = {
    'FLAGS_cudnn_exhaustive_search': 1,
    'FLAGS_conv_workspace_size_limit': 1000,
    'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}
# The flag names alone, in the same order as the settings above.
AMP_RELATED_FLAGS = list(AMP_RELATED_FLAGS_SETTING)
# Shorthand for the core AMP level enum; `amp_guard` maps the level strings
# 'O0'/'OD'/'O1'/'O2' onto its O0/OD/O1/O2 members.
AMP_LEVEL = core.AmpLevel
_g_amp_state_ = None
def amp_state():
global _g_amp_state_
return _g_amp_state_
class AMPGlobalState:
    """
    Mutable, process-wide bookkeeping shared by the AMP helpers.

    One instance is stored in ``_amp_global_state`` and returned by
    ``amp_global_state()``. Attribute writes go straight into the instance
    ``__dict__`` (see ``__setattr__``).
    """

    model_parameters: list[Tensor]
    use_master_grad: bool
    already_register_final_backward_hook: bool
    already_classify_params_meshes: bool
    mesh2params: dict[paddle.distributed.ProcessMesh | None, list[Tensor]]
    amp_dtype: _DTypeLiteral

    def __init__(self) -> None:
        # Master-grad support: whether it is enabled, and the parameters
        # handed to `core.eager.set_master_grads` by `amp_guard`'s hook.
        self.use_master_grad = False
        self.model_parameters = []
        # Set to True when `amp_guard` registers its master-grad backward
        # hook; `master_grad_hook` resets it to False after running.
        self.already_register_final_backward_hook = False
        # For dist: parameters grouped by process mesh, classified once.
        self.mesh2params = {}
        self.already_classify_params_meshes = False
        # dtype requested by the most recent `amp_guard` call.
        self.amp_dtype = 'float32'

    def __setattr__(self, name: str, val: Any) -> None:
        vars(self)[name] = val


_amp_global_state = AMPGlobalState()


def amp_global_state() -> AMPGlobalState:
    """Return the shared module-level ``AMPGlobalState`` instance."""
    return _amp_global_state
# NOTE(zhiqiu): similar as paddle.static.amp.fp16_lists.AutoMixedPrecisionLists._update_list
# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
def _update_list(
custom_white_list: _CustomList,
custom_black_list: _CustomList,
level: _AmpLevelLiteral = 'O1',
dtype: _DTypeLiteral = 'float16',
) -> tuple[set[str], set[str]]:
"""
Update black and white list according to users' custom list.
"""
if level == 'O0':
_white_list = set()
_black_list = set()
return _white_list, _black_list
_white_list = copy.copy(white_list()[dtype][level])
_black_list = copy.copy(black_list()[dtype][level])
if custom_white_list and custom_black_list:
for op_name in custom_white_list:
if op_name in custom_black_list:
raise ValueError("Custom white list overlap custom black list")
if custom_white_list:
for op_name in custom_white_list:
if op_name in _black_list:
_black_list.remove(op_name)
_white_list.add(op_name)
if custom_black_list:
for op_name in custom_black_list:
if op_name in _white_list:
_white_list.remove(op_name)
_black_list.add(op_name)
return _white_list, _black_list
def _in_amp_guard() -> bool:
    """
    Judge whether current code block is in an O1 `amp_guard` context.

    Returns True only when a dygraph tracer exists and its AMP level is
    ``core.AmpLevel.O1``; every other situation yields False.
    """
    tracer = _dygraph_tracer()
    if not tracer:
        return False
    return bool(tracer._amp_level == core.AmpLevel.O1)
def _in_pure_fp16_guard() -> bool:
    """
    Judge whether current code block is in an O2 (pure low-precision)
    `amp_guard` context.

    Returns:
        bool: True when a dygraph tracer exists and its AMP level is
        ``core.AmpLevel.O2``. When no tracer is active this returns False,
        not the falsy tracer object itself, so the result always matches
        the ``bool`` annotation.
    """
    tracer = _dygraph_tracer()
    if not tracer:
        return False
    return bool(tracer._amp_level == core.AmpLevel.O2)
def _is_gpu_float16_supported() -> bool:
    """
    Judge whether current gpu support float16 amp.

    A compute capability major version of 7 or higher qualifies; otherwise
    the answer falls back to whether Paddle was built with ROCm.
    """
    capability = paddle.device.cuda.get_device_capability()
    if capability[0] >= 7:
        return True
    return paddle.is_compiled_with_rocm()
def _is_gpu_bfloat16_supported() -> bool:
    """
    Judge whether current gpu support bfloat16 amp.

    Requires compute capability major version >= 8 together with a CUDA
    major version >= 11; otherwise falls back to whether Paddle was built
    with ROCm.
    """
    capability = paddle.device.cuda.get_device_capability()
    cuda_version = paddle.version.cuda()
    # None and the string 'False' are both treated as "no CUDA version".
    has_cuda_11_or_newer = (
        cuda_version is not None
        and cuda_version != 'False'
        and int(cuda_version.split('.')[0]) >= 11
    )
    if capability[0] >= 8 and has_cuda_11_or_newer:
        return True
    return paddle.is_compiled_with_rocm()
def _is_xpu_float16_supported() -> bool:
    """
    Judge whether current xpu device support float16 amp.

    Only XPU2 and XPU3 support float16 amp, i.e. the device version of the
    current expected place must be at least ``core.XPUVersion.XPU2``.
    """
    device_id = _current_expected_place().get_device_id()
    device_version = core.get_xpu_device_version(device_id)
    return device_version >= core.XPUVersion.XPU2
def _is_xpu_bfloat16_supported() -> bool:
    """
    Judge whether current xpu device support bfloat16 amp.

    Only XPU3 support bfloat16 amp: the device version of the current
    expected place must be at least ``core.XPUVersion.XPU3``. Although XPU2
    supports bfloat16 computing, XPU2's bfloat16 operators haven't been
    widely covered.
    """
    device_id = _current_expected_place().get_device_id()
    device_version = core.get_xpu_device_version(device_id)
    return device_version >= core.XPUVersion.XPU3
def _is_custom_device_bfloat16_supported() -> bool:
    """
    Judge whether current custom device support bfloat16 amp.

    Only the custom device types listed below are treated as supporting it.
    """
    device_type = _current_expected_place().get_device_type()
    return device_type in ('npu', 'intel_hpu', 'iluvatar_gpu', 'metax_gpu')
def need_keep_fp32(layer: Layer, dtype: str) -> bool:
    """
    Decide whether ``layer`` must keep float32 parameters when models are
    cast to the low-precision ``dtype``.

    The checks, in priority order:

    1. The layer opted out (``_cast_to_low_precision`` is False, e.g. set by
       ``set_excluded_layers``).
    2. The layer is a BatchNorm variant.
    3. The layer's own ``_dtype`` is 'float16'.
    4. ``dtype`` is 'float16' and the layer is a LayerNorm/InstanceNorm.

    Returns:
        bool: True if the layer's parameters should stay float32.
    """
    # Highest priority. Because all the layers except BN will use bfloat16
    # params in bfloat16 training, here we provide an option to keep fp32.
    if not layer._cast_to_low_precision:
        return True

    # The BN layers will keep fp32.
    if isinstance(
        layer,
        (
            paddle.nn.BatchNorm,
            paddle.nn.BatchNorm1D,
            paddle.nn.BatchNorm2D,
            paddle.nn.BatchNorm3D,
            paddle.nn.SyncBatchNorm,
        ),
    ):
        return True

    # layer._dtype is used to set params dtype. BF16 will use bf16 params.
    if layer._dtype == 'float16':
        return True

    # Under float16 only, normalization layers also keep fp32 params.
    return dtype == 'float16' and isinstance(
        layer,
        (
            paddle.nn.LayerNorm,
            paddle.nn.InstanceNorm1D,
            paddle.nn.InstanceNorm2D,
            paddle.nn.InstanceNorm3D,
        ),
    )
def set_excluded_layers(
    models: list[Layer],
    excluded_layers: Layer | list[Layer | type[Layer]] | type[Layer] | None,
) -> None:
    """
    Mark layers whose parameters must stay float32 during AMP decoration.

    Sets ``_cast_to_low_precision = False`` on:

    * every sublayer (including itself) of each excluded Layer instance, and
    * every sublayer (including itself) of ``models`` that is an instance of
      an excluded Layer type.

    Args:
        models: The models to scan for instances of excluded layer types.
        excluded_layers: A Layer instance, a Layer subclass, a list mixing
            both, or None (nothing is excluded).

    Raises:
        TypeError: If ``excluded_layers``, or any item of the list, is
            neither a ``paddle.nn.Layer`` instance nor a subclass of it.
    """
    if excluded_layers is None:
        return

    error_message = "excluded_layers must be either a nn.Layer instance/type or a list of nn.Layer instances/types."

    def _is_layer_type(obj: Any) -> bool:
        # Guard with isinstance(obj, type): issubclass() raises its own,
        # less helpful TypeError for non-class arguments.
        return isinstance(obj, type) and issubclass(obj, paddle.nn.Layer)

    if isinstance(excluded_layers, paddle.nn.Layer) or _is_layer_type(
        excluded_layers
    ):
        candidates = [excluded_layers]
    elif isinstance(excluded_layers, list):
        candidates = excluded_layers
    else:
        raise TypeError(error_message)

    excluded_instances = []
    excluded_types = []
    for item in candidates:
        if isinstance(item, paddle.nn.Layer):
            excluded_instances.append(item)
        elif _is_layer_type(item):
            excluded_types.append(item)
        else:
            raise TypeError(error_message)

    for excluded in excluded_instances:
        for layer in excluded.sublayers(include_self=True):
            layer._cast_to_low_precision = False

    if excluded_types:
        excluded_type_tuple = tuple(excluded_types)
        for model in models:
            for layer in model.sublayers(include_self=True):
                if isinstance(layer, excluded_type_tuple):
                    layer._cast_to_low_precision = False
def _pir_apply(
self: Layer,
func: Callable[[Tensor, _DTypeLiteral], Tensor | None],
dtype: _DTypeLiteral,
include_sublayers: bool = True,
) -> None:
if include_sublayers:
for layer in self.children():
_pir_apply(layer, func, dtype, include_sublayers)
for key, param in self._parameters.items():
if param is not None:
param_applied = func(param, dtype)
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = func(buf, dtype)
self._dtype = dtype
def _pir_transform(t: Tensor, dtype: str) -> None:
    """
    Cast the PIR parameter ``t`` to ``dtype`` in place.

    Startup program: the ``builtin.set_parameter`` op for ``t.name`` is
    replaced by a ``paddle.cast`` of that op's input, re-registered under the
    same name with ``update_parameter``.

    Main program: a fresh ``parameter(t.name)`` value is created, ``t``'s
    parameter attributes are copied onto it, all uses of ``t`` are redirected
    to it, the op that defined ``t`` is removed, and ``t`` itself is
    re-pointed at the new value with ``value_assign``.

    Args:
        t: The parameter value in the default main program. Presumably a pir
            Value defined by a parameter op; TODO confirm.
        dtype: Target dtype passed to ``paddle.cast``.

    Returns:
        None. Callers observe the conversion through ``t`` itself.
    """
    main = paddle.static.default_main_program()
    startup = paddle.static.default_startup_program()
    with paddle.static.program_guard(startup):
        block = startup.global_block()
        # Find the op that initializes parameter `t`, cast its input, and
        # register the cast result under the same parameter name.
        for op in block.ops:
            if (
                op.name() == 'builtin.set_parameter'
                and op.attrs()['parameter_name'] == t.name
            ):
                param = op.operand(0).source()
                cast_param = paddle.cast(param, dtype)
                cast_param.persistable = True
                paddle._pir_ops.update_parameter(cast_param, t.name)
                block.remove_op(op)
                # Only the first match is handled; stop right after the op
                # list was modified.
                break
    # NOTE(review): presumably syncs main's parameter storage with the
    # updated startup program; TODO confirm semantics.
    main.set_parameters_from(startup)
    with paddle.static.program_guard(main):
        # Create the replacement parameter op at the start of the main block.
        paddle.pir.reset_insertion_point_to_start()
        block = main.global_block()
        cast_param = paddle._pir_ops.parameter(t.name)
        # Carry over the parameter attributes of the original value.
        cast_param.trainable = t.trainable
        cast_param.stop_gradient = t.stop_gradient
        cast_param.persistable = t.persistable
        cast_param.optimize_attr = t.optimize_attr
        cast_param.regularizer = t.regularizer
        cast_param.do_model_average = t.do_model_average
        cast_param.need_clip = t.need_clip
        cast_param.is_distributed = t.is_distributed
        cast_param.is_parameter = t.is_parameter
        # Redirect every use of the old value to the new parameter, drop the
        # op that defined the old value, then make `t` refer to the new one.
        op = t.get_defining_op()
        t.replace_all_uses_with(cast_param)
        block.remove_op(op)
        t.value_assign(cast_param)
def _pir_to_impl(
    self: Layer,
    dtype: _DTypeLiteral,
    include_sublayers: bool,
    floating_only: bool,
) -> Layer:
    """
    PIR-mode counterpart of ``Layer._to_impl`` for dtype conversion.

    Casts the parameters and buffers of ``self`` (and of its sublayers when
    ``include_sublayers`` is True) to ``dtype`` using ``_pir_transform``.
    When ``floating_only`` is True, non-floating-point tensors are left as
    they are. UserWarnings raised during the conversion are suppressed.

    Returns:
        Layer: ``self``, with ``_dtype`` set to ``dtype``.
    """

    def _cast_one(tensor: Tensor, target: _DTypeLiteral) -> Tensor | None:
        skip = floating_only and (not paddle.is_floating_point(tensor))
        if skip:
            return tensor
        return _pir_transform(tensor, target)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        _pir_apply(self, _cast_one, dtype, include_sublayers)
    self._dtype = dtype
    return self
def amp_initialize(
    models: list[Layer],
    dtype: _DTypeLiteral,
    excluded_layers: Layer | list[Layer | type[Layer]] | type[Layer],
) -> list[Layer]:
    """
    Cast the parameters of every sublayer of ``models`` to ``dtype``.

    Processing steps:

    * ``excluded_layers`` is applied first (via ``set_excluded_layers``).
    * Layers for which ``need_keep_fp32`` is True are skipped.
    * Under float16, ``FusedFeedForward`` / ``FusedMultiHeadAttention``
      delegate to their own ``_amp_decorate``.
    * Every other layer is converted non-recursively, floating-point
      tensors only, with ``_pir_to_impl`` in PIR mode or ``_to_impl``
      otherwise.

    Returns:
        list[Layer]: The same ``models`` list.
    """
    set_excluded_layers(models, excluded_layers)
    for model in models:
        for layer in model.sublayers(include_self=True):
            if need_keep_fp32(layer, dtype):
                continue
            if dtype == "float16" and isinstance(
                layer,
                (
                    paddle.incubate.nn.FusedFeedForward,
                    paddle.incubate.nn.FusedMultiHeadAttention,
                ),
            ):
                layer._amp_decorate(dtype=dtype)
            elif in_pir_mode():
                _pir_to_impl(
                    layer,
                    dtype=dtype,
                    include_sublayers=False,
                    floating_only=True,
                )
            else:
                layer._to_impl(
                    dtype=dtype, include_sublayers=False, floating_only=True
                )
    return models
def check_models(models: list[Layer]) -> None:
    """
    Validate models before pure low-precision (O2) decoration.

    Raises:
        RuntimeError: If an entry is not a ``paddle.nn.Layer``, or if it is
            already a ``paddle.DataParallel`` (models must be decorated
            before being wrapped for data parallelism).
    """
    for model in models:
        is_layer = isinstance(model, paddle.nn.Layer)
        if not is_layer:
            raise RuntimeError(
                f"Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {type(model)}."
            )
        elif isinstance(model, paddle.DataParallel):
            raise RuntimeError(
                "For distributed AMP training, you should first use paddle.amp.decorate() to decorate origin model, and then call paddle.DataParallel get distributed model."
            )
def _is_valid_optimizer(optimizer: Any) -> TypeGuard[_OptimizerLike]:
    """
    Return True if ``optimizer`` is a ``paddle.optimizer.Optimizer`` or a
    dygraph sharding optimizer (``DygraphShardingOptimizer`` or
    ``DygraphShardingOptimizerV2``).
    """
    # Imported inside the function rather than at module level.
    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
        DygraphShardingOptimizer,
        DygraphShardingOptimizerV2,
    )

    accepted_types = (
        paddle.optimizer.Optimizer,
        DygraphShardingOptimizer,
        DygraphShardingOptimizerV2,
    )
    return isinstance(optimizer, accepted_types)
def check_optimizers(optimizers: list[Any]) -> None:
    """
    Validate optimizers before pure low-precision (O2) decoration.

    Stops at the first entry rejected by ``_is_valid_optimizer``.

    Raises:
        RuntimeError: If an entry is not an accepted optimizer type.
    """
    not_found = object()
    optimizer = next(
        (opt for opt in optimizers if not _is_valid_optimizer(opt)),
        not_found,
    )
    if optimizer is not not_found:
        raise RuntimeError(
            f"Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or DygraphShardingOptimizer, but receive {type(optimizer)}."
        )
@signature_safe_contextmanager
def amp_guard(
    enable: bool = True,
    custom_white_list: _CustomList | None = None,
    custom_black_list: _CustomList | None = None,
    level: _AmpLevelLiteral = 'O1',
    dtype: _DTypeLiteral = 'float16',
    use_promote: bool = True,
) -> Generator[None, None, None]:
    """
    Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph or PIR mode.
    If enabled, the input data type (float32 or float16) of each operator is decided
    by autocast algorithm for better performance.

    Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
    imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.

    While the context is active, the arguments of this call are available through
    ``amp_state()``; the previous value is restored on exit. If the arguments are
    rejected (an exception is raised before the context is entered), the value
    returned by ``amp_state()`` is left unchanged.

    In dynamic mode, if the current place or device does not support the requested
    ``dtype``, a warning is issued and AMP is disabled (level O0, float32) for the
    context.

    Args:
        enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
        custom_white_list(set|list|tuple|None, optional): The custom white_list. It's the set of ops that support
            fp16 calculation and are considered numerically-safe and performance-critical. These ops
            will be converted to fp16.
        custom_black_list(set|list|tuple|None, optional): The custom black_list. The set of ops that support fp16
            calculation and are considered numerically-dangerous and whose effects may also be
            observed in downstream ops. These ops will not be converted to fp16.
        level(str, optional): Auto mixed precision level, case-insensitive. Accepted values are "O0", "OD", "O1" and "O2".
            O1 represents mixed precision: the input data type of each operator is cast according to white_list and black_list.
            O2 represents pure fp16: all operators' parameters and input data are cast to fp16, except operators in black_list,
            operators without an fp16 kernel, and batchnorm. O0 disables AMP. Default is O1(amp).
        dtype(str|core.DataType, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
        use_promote(bool, optional): If True, an op that has float32 inputs is computed in float32, following the
            'Promote to the Widest' principle. Only takes effect at level 'O2'. Default is True.

    Raises:
        AssertionError: If not in dynamic or PIR mode.
        ValueError: If ``level`` is not one of O0/OD/O1/O2, if ``enable`` is True and ``dtype`` is neither
            'float16' nor 'bfloat16', or, in dynamic mode, if there is no current tracer.

    Examples:
        .. code-block:: pycon

            >>> # doctest: +REQUIRES(env:GPU)
            >>> import paddle

            >>> data = paddle.uniform([10, 3, 32, 32], paddle.float32, -1, 1)
            >>> conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
            >>> conv2d = paddle.amp.amp_decorate(models=conv2d, level='O2')
            >>> with paddle.amp.amp_guard():
            ...     conv = conv2d(data)
            ...     print(conv.dtype)
            >>> # doctest: +SKIP("This has diff in xdoctest env")
            paddle.float16
            >>> # doctest: -SKIP

            >>> with paddle.amp.amp_guard(enable=False):
            ...     conv = conv2d(data)
            ...     print(conv.dtype)
            >>> # doctest: +SKIP("This has diff in xdoctest env")
            paddle.float32
            >>> # doctest: -SKIP
    """
    assert in_dynamic_or_pir_mode(), (
        "We only support 'amp_guard' in dynamic or pir mode."
    )
    # Snapshot of the call arguments. This must stay the first local binding
    # so the dict holds only the parameters of this call. It is published via
    # `_g_amp_state_` only once validation has succeeded (see below).
    amp_state = locals()
    global _g_amp_state_
    original_state = _g_amp_state_

    # check amp_level: O0-O2
    level = level.upper()
    if level not in ['O0', 'OD', 'O1', 'O2']:
        raise ValueError("level should be O0, OD, O1 or O2.")

    # check amp_dtype: float16 or bfloat16
    if isinstance(dtype, paddle.base.core.DataType):
        dtype = dtype.name
    dtype = dtype.lower()
    if enable:
        if dtype not in ['float16', 'bfloat16']:
            raise ValueError(
                "If enable amp, dtype should be 'float16' or 'bfloat16'."
            )

    amp_dtype = dtype
    amp_global_state().amp_dtype = amp_dtype

    if level == 'OD':
        amp_level = AMP_LEVEL.OD
    elif level == 'O1':
        amp_level = AMP_LEVEL.O1
    elif level == 'O2':
        amp_level = AMP_LEVEL.O2
    elif level == 'O0':
        amp_level = AMP_LEVEL.O0

    _white_list, _black_list = _update_list(
        custom_white_list, custom_black_list, level, dtype
    )

    if in_pir_mode():
        if not enable:
            amp_level = AMP_LEVEL.O0
            amp_dtype = "float32"
        amp_attrs = core._get_amp_attrs()
        # set amp level
        original_amp_level = amp_attrs._amp_level
        amp_attrs._amp_level = amp_level
        # set amp op list
        original_white_list, original_black_list = core._get_amp_op_list()
        core._set_amp_op_list(_white_list, _black_list)
        # set amp dtype
        original_amp_dtype = amp_attrs._amp_dtype
        amp_attrs._amp_dtype = amp_dtype
        # switch promote
        if amp_level == AMP_LEVEL.O2:
            original_use_promote = amp_attrs._use_promote
            amp_attrs._use_promote = use_promote

        # Publish the argument snapshot only now: an exception raised above
        # must not leave a stale value visible through `amp_state()`.
        _g_amp_state_ = amp_state
        try:
            yield
        finally:
            _g_amp_state_ = original_state
            amp_attrs._amp_level = original_amp_level
            core._set_amp_op_list(original_white_list, original_black_list)
            amp_attrs._amp_dtype = original_amp_dtype
            if amp_level == AMP_LEVEL.O2:
                amp_attrs._use_promote = original_use_promote
    else:
        # check tracer
        tracer = _dygraph_tracer()
        if not tracer:
            raise ValueError(
                "current_tracer is None, maybe it is not in imperative mode."
            )

        # check device_type:
        # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, npu for float16 and bfloat16.
        # Maybe we will support cpu for bfloat16.
        if enable and not (
            tracer._expected_place.is_gpu_place()
            or tracer._expected_place.is_xpu_place()
            or tracer._expected_place.is_custom_place()
        ):
            warnings.warn(
                f'amp_guard can only be enabled on CUDAPlace, XPUPlace, and CustomPlace, current place is {tracer._expected_place}, so it makes no effect.'
            )
            enable = False
        if enable:
            # For xpu:
            if tracer._expected_place.is_xpu_place():
                if (dtype == 'float16') and not _is_xpu_float16_supported():
                    xpu_version = core.get_xpu_device_version(
                        _current_expected_place().get_device_id()
                    )
                    warnings.warn(
                        f'{core.XPUVersion(xpu_version)} does not support float16 amp.'
                    )
                    enable = False
                elif (dtype == 'bfloat16') and not _is_xpu_bfloat16_supported():
                    xpu_version = core.get_xpu_device_version(
                        _current_expected_place().get_device_id()
                    )
                    warnings.warn(
                        f'{core.XPUVersion(xpu_version)} does not support bfloat16 amp.'
                    )
                    enable = False
            # For custom device:
            if (
                tracer._expected_place.is_custom_place()
                and not _is_custom_device_bfloat16_supported()
                and (dtype == 'bfloat16')
            ):
                warnings.warn('CustomPlace only support float16 amp.')
                enable = False
            # For gpu float16: Compute Capability should >= 7.
            # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
            if tracer._expected_place.is_gpu_place():
                if (dtype == 'float16') and not _is_gpu_float16_supported():
                    prop = paddle.device.cuda.get_device_capability()
                    warnings.warn(
                        f"For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: {paddle.device.cuda.get_device_name()}, with Compute Capability: {prop[0]}.{prop[1]}."
                    )
                    enable = False
                elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
                    prop = paddle.device.cuda.get_device_capability()
                    cuda_version = paddle.version.cuda()
                    warnings.warn(
                        f"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: {paddle.device.cuda.get_device_name()}, with Compute Capability: {prop[0]}.{prop[1]}, current CUDA Version is: {cuda_version}."
                    )
                    enable = False

        if not enable:
            amp_level = AMP_LEVEL.O0
            amp_dtype = "float32"

        # master_grad_hook will run at the end of backward.
        # Since backward_final_hook will be cleared once they have been
        # done, we should register the hook every step.
        if (
            amp_global_state().use_master_grad
            and not amp_global_state().already_register_final_backward_hook
        ):

            def _dtensor_from_local(local_tensor, mesh, placements):
                # Rebuild the global shape: every sharded dim is scaled by
                # the size of the mesh axis it is sharded over.
                global_dims = list(local_tensor.shape)
                for idx, placement in enumerate(placements):
                    if placement.is_shard():
                        global_dims[placement.get_dim()] = (
                            global_dims[placement.get_dim()] * mesh.shape[idx]
                        )
                place = paddle.framework._current_expected_place()
                place = paddle.framework._get_paddle_place(place)
                return paddle.Tensor(
                    local_tensor,
                    dims=global_dims,
                    process_mesh=mesh,
                    placements=placements,
                    place=place,
                )

            def master_grad_hook():
                # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
                # classify the params of model into different classes according to their process_mesh.
                # Otherwise, fault will occur.
                if not amp_global_state().already_classify_params_meshes:
                    for param in amp_global_state().model_parameters:
                        if param is not None and param.process_mesh is not None:
                            if (
                                param.process_mesh
                                not in amp_global_state().mesh2params
                            ):
                                amp_global_state().mesh2params[
                                    param.process_mesh
                                ] = [param]
                            else:
                                amp_global_state().mesh2params[
                                    param.process_mesh
                                ].append(param)
                    amp_global_state().already_classify_params_meshes = True

                if len(amp_global_state().mesh2params):
                    for _, params in amp_global_state().mesh2params.items():
                        core.eager.set_master_grads(params)
                else:
                    core.eager.set_master_grads(
                        amp_global_state().model_parameters
                    )

                amp_global_state().already_register_final_backward_hook = False

            def _update_main_grad_hook(param):
                @paddle.autograd.no_grad()
                def param_hook(tmp_grad):
                    if tmp_grad is not None and tmp_grad._is_initialized():
                        if param.main_grad is None:
                            tmp = core.eager.Tensor(
                                value=tmp_grad._local_value()
                                .cast(paddle.float32)
                                .value(),
                                place=tmp_grad.place,
                                name="main_grad@" + param.name,
                            )
                            param.main_grad = _dtensor_from_local(
                                tmp,
                                tmp_grad.process_mesh,
                                tmp_grad.placements,
                            )
                        else:
                            param.main_grad._local_value().add_(
                                tmp_grad._local_value()
                            )
                        tmp_grad._clear_data()

                return param_hook

            if os.getenv("FLAGS_enable_tensor_fusion") in [
                "True",
                "true",
                "1",
            ]:
                for param in amp_global_state().model_parameters:
                    if not hasattr(param, "main_grad"):
                        param.main_grad = None
                        param._register_grad_hook(_update_main_grad_hook(param))
                os.environ["FLAGS_enable_tensor_fusion"] = "0"
            else:
                core.eager._add_backward_final_hook(master_grad_hook)
            amp_global_state().already_register_final_backward_hook = True

        if tracer:
            # enable auto_cast
            original_amp_level = tracer._amp_level
            tracer._amp_level = amp_level

            # set amp op list
            original_white_list, original_black_list = tracer._get_amp_op_list()
            tracer._set_amp_op_list(_white_list, _black_list)

            # TODO(zhiqiu) set amp related flags automatically in this guard
            # Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
            # batch_norm can run in fast mode, but batch_norm_grad can not if backward if not executed inside amp_guard.
            # So, users need to set related flags manually.
            # original_flags = get_flags(AMP_RELATED_FLAGS)
            # set_flags(AMP_RELATED_FLAGS_SETTING)

            # set amp dtype
            original_amp_dtype = tracer._amp_dtype
            tracer._amp_dtype = amp_dtype

            # switch promote
            if amp_level == AMP_LEVEL.O2:
                original_use_promote = tracer._use_promote
                tracer._use_promote = use_promote

        # Publish the argument snapshot only now: an exception raised above
        # must not leave a stale value visible through `amp_state()`.
        _g_amp_state_ = amp_state
        # restore status
        try:
            yield
        finally:
            _g_amp_state_ = original_state
            if tracer:
                tracer._amp_level = original_amp_level
                tracer._set_amp_op_list(
                    original_white_list, original_black_list
                )
                # set_flags(original_flags)
                tracer._amp_dtype = original_amp_dtype
                if amp_level == AMP_LEVEL.O2:
                    tracer._use_promote = original_use_promote
class StateDictHook:
    """
    State-dict hook that casts every floating-point entry to ``save_dtype``
    in place.

    Each cast tensor keeps the ``name`` of the tensor it replaces;
    non-floating entries are left untouched. The casts run under a fresh
    dygraph ``Tracer`` via ``_dygraph_guard``.
    """

    def __init__(self, save_dtype: str) -> None:
        self._save_dtype = save_dtype

    def __call__(self, state_dict: _StateDict) -> None:
        target_dtype = self._save_dtype
        with paddle.base.framework._dygraph_guard(paddle.base.dygraph.Tracer()):
            for key in state_dict:
                original = state_dict[key]
                if not paddle.is_floating_point(original):
                    continue
                converted = paddle.cast(original, target_dtype)
                converted.name = original.name
                state_dict[key] = converted
def _set_multi_precision(
    optimizer: _OptimizerLike, multi_precision: bool
) -> None:
    """
    Set ``_multi_precision`` on ``optimizer``.

    Dygraph sharding optimizers (V1/V2) are unwrapped to their
    ``_inner_opt`` first. Nothing happens if the resulting optimizer has no
    ``_multi_precision`` attribute.
    """
    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
        DygraphShardingOptimizer,
        DygraphShardingOptimizerV2,
    )

    target = optimizer
    if isinstance(
        optimizer, (DygraphShardingOptimizer, DygraphShardingOptimizerV2)
    ):
        target = optimizer._inner_opt

    if hasattr(target, "_multi_precision"):
        target._multi_precision = multi_precision
# Typing overloads for `amp_decorate`, checked in order:
# 1. optimizers passed  -> returns (models, optimizers).
#    `optimizers` has no default here, so a call that omits it (or passes
#    None) cannot match this overload and falls through to the next one.
# 2. optimizers omitted or None -> returns models only.
@overload
def amp_decorate(
    models: _ModelsT,
    optimizers: _OptimizersT,
    level: _AmpLevelLiteral = ...,
    dtype: _DTypeLiteral = ...,
    master_weight: bool | None = ...,
    save_dtype: _DTypeLiteral | None = ...,
    master_grad: bool = ...,
    excluded_layers: (
        Layer | list[Layer | type[Layer]] | type[Layer] | None
    ) = ...,
) -> tuple[_ModelsT, _OptimizersT]: ...


@overload
def amp_decorate(
    models: _ModelsT,
    optimizers: None = ...,
    level: _AmpLevelLiteral = ...,
    dtype: _DTypeLiteral = ...,
    master_weight: bool | None = ...,
    save_dtype: _DTypeLiteral | None = ...,
    master_grad: bool = ...,
    excluded_layers: (
        Layer | list[Layer | type[Layer]] | type[Layer] | None
    ) = ...,
) -> _ModelsT: ...
@dygraph_only
def amp_decorate(
models: _ModelsT,
optimizers: _OptimizersT | None = None,
level: _AmpLevelLiteral = 'O1',
dtype: _DTypeLiteral = 'float16',
master_weight: bool | None = None,
save_dtype: _DTypeLiteral | None = None,
master_grad: bool = False,
excluded_layers: (
Layer | list[Layer | type[Layer]] | type[Layer] | None
) = None,
) -> tuple[_ModelsT, _OptimizersT] | _ModelsT:
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm, InstanceNorm and LayerNorm.
Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode.
Args:
models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
optimizers(Optimizer|list of Optimizer|None, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm, InstanceNorm and LayerNorm. Default is O1(amp)
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool|None, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(str|None, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
master_grad(bool, optional): For level='O2', whether to use float32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight
gradients will be float32 dtype after the back propagation. Default is False, there is only float16 weight gradients.
excluded_layers(Layer|list of Layer, optional): Specify the layers not to be decorated. The weights of these layers will always keep float32 when level is O2. `excluded_layers` can be specified as
an Layer instance/type or a list of Layer instances/types. Default is None, the weights of the whole model will be casted to float16 or bfloat16.
Examples:
.. code-block:: pycon
>>> # doctest: +REQUIRES(env:GPU)
>>> # Demo1: single model and optimizer:
>>> import paddle
>>> paddle.device.set_device('gpu')
>>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
>>> optimizer = paddle.optimizer.SGD(parameters=model.parameters())
>>> model, optimizer = paddle.amp.amp_decorate(models=model, optimizers=optimizer, level='O2')
>>> data = paddle.rand([10, 3, 32, 32])
>>> with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
... output = model(data)
... print(output.dtype)
paddle.float16
>>> # Demo2: multi models and optimizers:
>>> model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
>>> optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
>>> models, optimizers = paddle.amp.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
>>> data = paddle.rand([10, 3, 32, 32])
>>> with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
... output = models[0](data)
... output2 = models[1](data)
... print(output.dtype)
... print(output2.dtype)
paddle.float16
paddle.float16
>>> # Demo3: optimizers is None:
>>> model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
>>> optimizer3 = paddle.optimizer.Adam(parameters=model2.parameters())
>>> model = paddle.amp.amp_decorate(models=model3, level='O2')
>>> data = paddle.rand([10, 3, 32, 32])
>>> with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
... output = model(data)
... print(output.dtype)
paddle.float16
"""
if level not in ['O1', 'O2']:
raise ValueError(
"level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
)
if dtype not in ['float16', 'bfloat16']:
raise ValueError("dtype only support float16 or bfloat16.")
if level == 'O1':
if optimizers is None:
return models
else:
return models, optimizers
# check tracer
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode."
)
# check device_type:
if not (
tracer._expected_place.is_gpu_place()
or tracer._expected_place.is_xpu_place()
or tracer._expected_place.is_custom_place()
):
if optimizers is None:
return models
else:
return models, optimizers
# For xpu:
if tracer._expected_place.is_xpu_place():
if (dtype == 'float16' and not _is_xpu_float16_supported()) or (
dtype == 'bfloat16' and not _is_xpu_bfloat16_supported()
):
if optimizers is None:
return models
else:
return models, optimizers
# For custom device:
if (
tracer._expected_place.is_custom_place()
and not _is_custom_device_bfloat16_supported()
and (dtype == 'bfloat16')
):
if optimizers is None:
return models
else:
return models, optimizers
# For gpu float16: Compute Capability should >= 7.
# For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
if tracer._expected_place.is_gpu_place():
if (dtype == 'float16' and not _is_gpu_float16_supported()) or (
dtype == 'bfloat16' and not _is_gpu_bfloat16_supported()
):
if optimizers is None:
return models
else:
return models, optimizers
models_is_list = False
if isinstance(models, paddle.nn.Layer):
models_is_list = False
models = [models]
check_models(models)
elif isinstance(models, list):