-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
303 lines (248 loc) · 10.6 KB
/
main.py
File metadata and controls
303 lines (248 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torchvision.models as models
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
import os
import time
from multiprocessing import Process, log_to_stderr
import csv
from gossip_module.utils import flatten_tensors, flatten_tensors_grad, unflatten_tensors, unflatten_tensors_grad
from fsdp_custom import FullyShardedDataParallel as FSDP
#from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from dp_custom import DataParallel_Custom as DP
from auto_wrap_custom import enable_wrap, auto_wrap, wrap
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
from torch_scheduler import ShardScheduler
from randomScheduleGenerator import randomSchedule
import threading
import argparse
import timeit
import numpy as np
from torchsummary import summary
import copy
def module_check(module):
    """Visit every submodule of ``module`` in depth-first pre-order.

    A debugging aid: the per-module inspection is currently commented out, so
    the walk has no side effects and the function returns ``None``.
    """
    stack = [module]
    while stack:
        node = stack.pop()
        # Debug hook -- uncomment to inspect each leaf module:
        # if not list(node.children()):
        #     print(node)
        children = [child for _, child in node.named_children()]
        # Push in reverse so children are popped in their declared order.
        stack.extend(reversed(children))
class Trainer:
    """Benchmark harness for a ResNet trained through the custom FSDP wrapper.

    The constructor builds the model on the local GPU, pre-generates a pool of
    random input batches, wraps the model with ``auto_wrap`` under
    ``enable_wrap(**self.fsdp_params)``, registers per-parameter locks and
    condition variables that are shared by reference with the wrapper and the
    ``ShardScheduler``, and prints ``torch.cuda.memory_allocated()`` (in MiB)
    at each construction stage.
    """

    def __init__(self, world_size, rank, shard, precision):
        """Build model, synthetic data, sharded wrapper and optimizer.

        Args:
            world_size: number of ranks in the process group.
            rank: global rank of this process; the local GPU used is
                ``rank % torch.cuda.device_count()``.
            shard: 0 requests ``DataParallel_Custom``, 1 requests FSDP.  See the
                NOTE(review) below: the wrapper actually used is always FSDP.
            precision: 0 for full precision, anything else for mixed precision
                (likewise not yet forwarded to the wrapper).

        Raises:
            RuntimeError: if no CUDA device is visible.
        """
        torch.backends.cudnn.benchmark = True
        self.world_size = world_size
        print(f'world_size : {world_size}')
        ngpus_per_node = torch.cuda.device_count()
        if ngpus_per_node == 0:
            # Without this guard ``rank % 0`` fails with an opaque ZeroDivisionError.
            raise RuntimeError("Trainer requires at least one visible CUDA device")
        self.shard = shard
        self.rank = rank
        print(f'rank : {rank}')
        local_device = rank % ngpus_per_node
        self.device = torch.device(f"cuda:{local_device}")
        torch.cuda.set_device(local_device)
        print(f"cuda:{local_device}")
        self.process_groups = []

        self.batch_size = 16
        self.image_size = 42  # spatial size of the batches generated by train()
        self.classification_num = 1000

        self._log_memory("before init model")
        # Bottleneck blocks with layer counts [3, 8, 36, 3] is the ResNet-152
        # layout (not ResNet-18).
        self.model = ResNet(Bottleneck, [3, 8, 36, 3])
        self.model.cuda()
        self._log_memory("after init model")

        # Per-parameter synchronisation state.  These dicts are handed by
        # reference to the wrapper (through fsdp_params) and to ShardScheduler,
        # and are populated after auto_wrap once the wrapped parameters exist.
        self._rs_locks = {}
        self._ag_locks = {}
        self._ag_fsdp_locks = {}
        self._rs_conditions = {}
        self._ag_conditions = {}
        self._ag_fsdp_conditions = {}
        self._forward_conditions = {}
        self._backward_conditions = {}
        self._lazy_init_locks = {}
        self._lazy_init_conditions = {}
        self._partition_counts = {}
        self._scheduled_comms = {}
        self._done_counts = {}
        self.model_parameter_names = {}
        module_check(self.model)

        # Pool of random batches cycled by benchmark_step(); one random label
        # tensor is shared by every batch in the pool.
        pool_image_size = 80
        self.datasets = []
        self.target = None
        self.data_index = 0
        self._log_memory("before init dataset")
        for _ in range(100):
            self.datasets.append(
                torch.rand(self.batch_size, 3, pool_image_size, pool_image_size).cuda())
        self.target = torch.randint(
            0, self.classification_num, (self.batch_size,)).cuda()
        self._log_memory("after init dataset")
        summary(self.model, (3, pool_image_size, pool_image_size))
        self.profiled_memory_utilization = []

        # NOTE(review): the requested wrapper class and mixed-precision flag are
        # resolved here but NOT forwarded -- fsdp_params below always uses
        # wrapper_cls=FSDP and mixed_precision=False.  Wiring them in requires
        # DataParallel_Custom to accept the FSDP-specific kwargs; confirm first.
        self.wrap_cls = FSDP if self.shard == 1 else DP
        self.mixed_precision = precision != 0

        self.comm_stream = torch.cuda.Stream()
        self.fsdp_params = dict(
            wrapper_cls=FSDP, mixed_precision=False, flatten_parameters=True,
            done_counts=self._done_counts, partition_counts=self._partition_counts,
            rs_locks=self._rs_locks, ag_locks=self._ag_locks,
            ag_fsdp_locks=self._ag_fsdp_locks,
            rs_conditions=self._rs_conditions, ag_conditions=self._ag_conditions,
            ag_fsdp_conditions=self._ag_fsdp_conditions,
            forward_conditions=self._forward_conditions,
            backward_conditions=self._backward_conditions,
            lazy_init_locks=self._lazy_init_locks,
            lazy_init_conditions=self._lazy_init_conditions,
            memory_record=self.profiled_memory_utilization,
            comm_stream=self.comm_stream,
            model_parameter_names=self.model_parameter_names,
        )
        self.sharded_module = None
        self.optimizer = None
        self.criterion = None
        self.partition_threshold = 20000
        with enable_wrap(**self.fsdp_params):
            self.sharded_module = auto_wrap(self.model)
            self._scheduled_comms = randomSchedule(self.sharded_module)
            for name, param in self.sharded_module.named_parameters():
                self._register_param_sync(name, param)

        self._log_memory("before init optimizer")
        base_optimizer = torch.optim.Adam(
            self.sharded_module.parameters(), lr=0.001, betas=(0.9, 0.999),
            eps=1e-08, weight_decay=0, amsgrad=False)
        # ShardScheduler wraps Adam and receives the same synchronisation dicts.
        # The positional 10**6 argument's meaning is defined by ShardScheduler.
        self.optimizer = ShardScheduler(
            self.sharded_module, self.sharded_module.named_parameters(),
            self.world_size, self.rank, base_optimizer,
            self.partition_threshold, self._done_counts, self._partition_counts,
            self._rs_locks, self._ag_locks, self._ag_fsdp_locks,
            self._rs_conditions, self._ag_conditions, self._ag_fsdp_conditions,
            self._forward_conditions, self._backward_conditions,
            self._lazy_init_locks, self._lazy_init_conditions,
            10**6, self.comm_stream, self._scheduled_comms)
        self._log_memory("after init optimizer")
        self.criterion = nn.CrossEntropyLoss()
        self.scaler = GradScaler()

    @staticmethod
    def _log_memory(tag):
        """Print ``tag`` followed by the CUDA memory currently allocated, in MiB."""
        print(f"{tag} {torch.cuda.memory_allocated() / 1024 / 1024}")

    def _register_param_sync(self, name, param):
        """Create the counters, locks and condition variables kept for ``param``."""
        # Presumably the number of chunks the parameter is split into for
        # communication -- TODO confirm against ShardScheduler.
        self._partition_counts[param] = (param.numel() // self.partition_threshold) + 1
        self._done_counts[param] = 0
        self._rs_locks[param] = threading.Lock()
        self._ag_locks[param] = threading.Lock()
        # The all-gather lock starts out held; presumably released by the
        # scheduler/wrapper when the gather may proceed -- TODO confirm.
        self._ag_locks[param].acquire()
        self._ag_fsdp_locks[param] = threading.Lock()
        self._rs_conditions[param] = threading.Condition(threading.Lock())
        self._ag_conditions[param] = threading.Condition(threading.Lock())
        self._ag_fsdp_conditions[param] = threading.Condition(threading.Lock())
        self._forward_conditions[param] = threading.Condition(threading.Lock())
        self._backward_conditions[param] = threading.Condition(threading.Lock())
        self._lazy_init_locks[param] = threading.Lock()
        self._lazy_init_conditions[param] = threading.Condition(threading.Lock())
        self.model_parameter_names[param] = name

    def benchmark_step(self):
        """Run one forward + backward pass on the next pre-generated batch.

        Used as the ``timeit`` target.  This method does not call
        ``optimizer.step()`` or ``zero_grad()``; presumably the ShardScheduler
        hooks drive parameter updates -- TODO confirm, otherwise gradients
        accumulate across calls.
        """
        with enable_wrap(**self.fsdp_params):
            data = self.datasets[self.data_index % len(self.datasets)]
            self.data_index += 1
            self._log_memory("before forward")
            output = self.sharded_module(data)
            self._log_memory("after forward")
            loss = self.criterion(output, self.target)
            self._log_memory("before backward")
            loss.backward()

    def train(self, num_iters=2):
        """Run ``num_iters`` forward + backward passes on fresh random batches.

        Labels are drawn uniformly from ``[0, classification_num)``.
        """
        for itr in range(num_iters):
            print(itr)
            batch = torch.rand(
                self.batch_size, 3, self.image_size, self.image_size).cuda()
            # randint's upper bound is exclusive, so pass classification_num
            # itself to make the last class reachable.
            target = torch.randint(
                0, self.classification_num, (self.batch_size,)).cuda()
            with enable_wrap(**self.fsdp_params):
                output = self.sharded_module(batch)
                loss = self.criterion(output, target)
                loss.backward()
def main():
    """Join the NCCL process group, run the benchmark and dump memory records.

    Command-line flags:
        --rank            global rank of this process (default 0)
        --shard           0 tags results "DP", otherwise "FSDP"
        --mixed_precision 0 tags results full precision, otherwise mixed
        --world_size      number of ranks in the group (default 2)
    """
    # Defaults for the env:// rendezvous; values already present in the
    # environment take precedence instead of being overwritten.
    os.environ.setdefault('MASTER_ADDR', '210.107.197.218')
    os.environ.setdefault('MASTER_PORT', '30000')
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', dest='rank', default=0, type=int)
    parser.add_argument('--shard', dest='shard', default=0, type=int)
    parser.add_argument('--mixed_precision', dest='mixed_precision', default=0, type=int)
    parser.add_argument('--world_size', dest='world_size', default=2, type=int)
    args = parser.parse_args()
    world_size = args.world_size
    rank = args.rank
    shard = args.shard
    mixed_precision = args.mixed_precision
    dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)

    trainer = Trainer(world_size, rank, shard, mixed_precision)
    print(torch.cuda.memory_allocated() / 1024 / 1024)

    num_measurements = 1
    steps_per_measurement = 5
    img_secs = []
    for x in range(num_measurements):
        # Named ``elapsed`` to avoid shadowing the imported ``time`` module.
        elapsed = timeit.timeit(trainer.benchmark_step, number=steps_per_measurement)
        # Each benchmark_step consumes one batch of trainer.batch_size images.
        img_sec = trainer.batch_size * steps_per_measurement / elapsed
        print('Iter #%d: %.1f img/sec per GPU' % (x, img_sec))
        img_secs.append(img_sec)

    shard_tag = "DP" if shard == 0 else "FSDP"
    # NOTE(review): "preicision" is misspelt; left as-is so existing result
    # file names do not change.
    mixed_precision_tag = "full_preicision" if mixed_precision == 0 else "mixed_precision"
    with open(f'{shard_tag}_{mixed_precision_tag}_memory_utilization.csv', 'w', newline='') as f:
        csv.writer(f).writerows([record] for record in trainer.profiled_memory_utilization)

    img_sec_mean = np.mean(img_secs)
    img_sec_conf = 1.96 * np.std(img_secs)
    print('Img/sec : %.1f +-%.1f' % (img_sec_mean, img_sec_conf))
    print('Total img/sec on %d GPU(s): %.1f +-%.1f' % (
        world_size, world_size * img_sec_mean, world_size * img_sec_conf))


if __name__ == '__main__':
    main()