- import torch
from torch import nn
from torch4keras.trainer import *

@@ -10,24 +9,5 @@ def __init__(self, *args, **kwargs):
        nn.Module.__init__(self)
        Trainer.__init__(self, *args, **kwargs)

-
- class BaseModelDP(nn.DataParallel, BaseModel):
-     '''DataParallel mode for multi-GPU use; swapping the order of the parent classes also causes problems
-     '''
-     def __init__(self, *args, **kwargs):
-         BaseModel.__init__(self)
-         nn.DataParallel.__init__(self, *args, **kwargs)
-
-
- class BaseModelDDP(nn.parallel.DistributedDataParallel, BaseModel):
-     '''DistributedDataParallel mode for multi-GPU use; swapping the order of the parent classes also causes problems
-     '''
-     def __init__(self, *args, master_rank=0, **kwargs):
-         BaseModel.__init__(self)
-         nn.parallel.DistributedDataParallel.__init__(self, *args, **kwargs)
-
-         # By default, only the process with master_rank=0 prints progress information
-         assert isinstance(master_rank, (int, list, tuple)), 'Args `master_rank` only supports int, list, tuple'
-         if isinstance(master_rank, int):
-             master_rank = [master_rank]
-         self.verbose = (torch.distributed.get_rank() in master_rank)
+ BaseModelDP = TrainerDP
+ BaseModelDDP = TrainerDDP
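
After this change `BaseModelDP` and `BaseModelDDP` are plain aliases for `TrainerDP` and `TrainerDDP`. For context, the removed `BaseModelDDP` only forwarded its arguments to `nn.parallel.DistributedDataParallel` and used `master_rank` to gate logging, so pre-change usage would have looked roughly like the sketch below. The import path, `ToyModel`, and the process-group setup are illustrative assumptions not shown in this diff, and whether `TrainerDDP` keeps the same `master_rank` argument is not shown here either.

```python
# Pre-change usage sketch, inferred only from the removed class above:
# BaseModelDDP forwarded *args/**kwargs to nn.parallel.DistributedDataParallel
# and used `master_rank` to decide which rank prints progress (self.verbose).
import torch
from torch import nn
from torch4keras.model import BaseModelDDP  # assumed location of the file changed above

class ToyModel(nn.Module):
    # A real torch4keras model would typically subclass BaseModel; a plain
    # nn.Module keeps this sketch self-contained.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(x)

# Normally handled by torchrun / the launch script.
torch.distributed.init_process_group(backend='nccl')
rank = torch.distributed.get_rank()

net = ToyModel().to(rank)
# device_ids goes straight through to DistributedDataParallel; master_rank=0
# means only rank 0 gets verbose=True and prints training logs.
model = BaseModelDDP(net, device_ids=[rank], master_rank=0)
```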