# Code generator: reads a model DAG from a GML file and emits a complete
# PyTorch training script on stdout. The print(...) calls below write lines
# of the generated script; everything else runs at generation time.
import sys

import networkx as nx

# Header of the generated script.
print("import torch")
print("import matplotlib.pyplot as plt")
print("from torch import Tensor")
print("import numpy as np")
print("import torch.nn.functional as F")
print("import torch.optim as optim")
print("import networkx as nx")
print("import time")
print("import sys")
print("import warnings")
print("")
print("from torch.overrides import (")
print("\thas_torch_function, has_torch_function_unary, has_torch_function_variadic,")
print("\thandle_torch_function)")

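# Generation-time setup. Assumption: the globals gmlfile, depth, ratio, DAG,
# topo, root, n and All_mem are consumed by the generator loops below but were
# never defined in the committed file; this is a minimal sketch that reads
# them from the generator's own command line. depth must equal the Latency
# argument later passed to the generated script.
gmlfile = sys.argv[1]        # GML file describing the model DAG
depth = int(sys.argv[2])     # number of pipeline stages
ratio = float(sys.argv[3])   # weight of the memory-entropy term in the loss

DAG = nx.read_gml(gmlfile)
DAG = nx.relabel_nodes(DAG, {str(v): int(v) for v in DAG.nodes()})
topo = list(nx.topological_sort(DAG))                     # topological order
root = [v for v in DAG.nodes() if DAG.in_degree(v) == 0]  # source nodes
n = DAG.number_of_nodes()
All_mem = sum(int(DAG.nodes[v]['parameter']) for v in DAG.nodes())  # total model size
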
print("")
print("seed = 30")
print("torch.manual_seed(seed)")

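# Emit legal_check(): given a stage assignment L (one stage index per node),
# it returns 'illegal' if any edge points to an earlier stage, otherwise the
# number of stage boundaries crossed by communicating edges. The m_k
# accumulators are unrolled here, one per stage boundary.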
print("def legal_check(L, gmlfile):")
print("\tDAG = nx.read_gml(gmlfile)")
print("\tmapping = {str(node): int(node) for node in DAG.nodes()}")
print("\tgraph = nx.relabel_nodes(DAG, mapping)")
for k in range(int(depth) - 1):
    print("\tm_%d = 0" % (k))
print("\tillegal_edges = []")
print("\tcost = 0")
print("\tfor edge in graph.edges():")
print("\t\tnode1, node2 = edge")
print("\t\toutput = int(graph.edges()[int(node1), int(node2)]['parameter'])")
print("\t\tif output > 0:")
print("\t\t\toutput = 1")
print("\t\tif L[int(node1)] - L[int(node2)] > 0:")
print("\t\t\treturn 'illegal'")
print("\t\telif L[int(node1)] - L[int(node2)] < 0:")
for k in range(int(depth) - 1):
    print("\t\t\ta = %d - L[int(node1)]" % (k))
    print("\t\t\tb = %d - L[int(node2)]" % (k + 1))
    print("\t\t\tif a >= 0 and b <= 0:")
    print("\t\t\t\tm_%d += output" % (k))
for r in range(int(depth) - 1):
    print("\tcost += m_%d" % (r))
print("\treturn cost")
print("")

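# Emit multiconditional_gumbel_softmax(): a Gumbel-Softmax draw whose
# probabilities are reweighted by the cumulative one-hot vectors of the
# node's DAG predecessors (the list D), biasing each node toward stages no
# earlier than any predecessor. It follows the structure of
# torch.nn.functional.gumbel_softmax and also returns the cumulative sum
# used to condition successor nodes.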
print("def multiconditional_gumbel_softmax(logits: Tensor, D: list, batch: int = 16, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1):")
print("\tif has_torch_function_unary(logits):")
print("\t\treturn handle_torch_function(multiconditional_gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim)")
print("\tif eps != 1e-10:")
print('\t\twarnings.warn("`eps` parameter is deprecated and has no effect.")')
print("")
print("\tgumbels = (")
print("\t\t-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()")
print("\t)  # ~Gumbel(0,1)")
print("\tgumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)")
print("\tgumbels = gumbels.softmax(dim)")
print("\tbias = torch.arange(logits.shape[1]+1, 1, -1).log().repeat(batch, 1).float().cuda()")
print("\tfor i in range(len(D)):")
print("\t\tgumbels = gumbels.mul(bias).mul(D[i])")
print("\ty_soft = gumbels.softmax(dim)")
print("")
print("\tif hard:")
print("\t\tindex = y_soft.max(dim, keepdim=True)[1]")
print("\t\ty_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)")
print("\t\tret = y_hard - y_soft.detach() + y_soft")
print("\telse:")
print("\t\tret = y_soft")
print("\treturn ret, ret.cumsum(dim=1)")
print("")

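# Emit the entropy losses of the generated script: entropy() scores how
# evenly the per-stage parameter mass is spread relative to the memory
# budget mem; entropy_CC() (unused in the code below) is the analogous term
# over per-boundary communication counts.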
print("def entropy(list_nextT, dim, V, mem, bs, bias=1e10):  # dim = pipeline stages, V = number of nodes")
print("\tn_all = torch.cat(list_nextT, dim=1).view(bs, V, dim) + torch.ones(bs, V, dim).cuda()/bias  # prevent log(0) NaN overflow")
print("\tsum_per_pipeline = torch.sum(n_all, 1)")
print("\tentropy = (-sum_per_pipeline/mem) * torch.log(sum_per_pipeline/mem)")
print("\treturn -torch.sum(entropy, -1), sum_per_pipeline")
print("")
print("def entropy_CC(list_nextT, depth, bs, C, bias=1e10):")
print("\tall = torch.stack(list_nextT).t() + torch.ones(bs, depth-1).cuda()/bias")
print("\tentropy = (-all/C) * torch.log(all/C)")
print("\treturn -torch.sum(entropy, -1)")
print("")

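# Emit the ScheduleNet module: one (BS x depth) logits Parameter per DAG
# node; root nodes start with logits strongly biased toward stage 0.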
print("class ScheduleNet(torch.nn.Module):")
print("\tdef __init__(self, temp, depth, BS, nodes=%d):" % n)
print("\t\tsuper(ScheduleNet, self).__init__()")
print("\t\tself.temp = temp")
print("\t\tself.depth = depth")
print("\t\tself.nodes = nodes")
print("\t\tself.weights = torch.nn.ParameterList()")
print("\t\tself.rootlist =", root)
print("\t\tfor n in range(nodes):  # todo: topological init")
print("\t\t\tif n in self.rootlist:")
print("\t\t\t\tw = 10 * F.one_hot(torch.arange(0, BS) * 0, num_classes=depth).float()")
print("\t\t\telse:")
print("\t\t\t\tw = F.one_hot(torch.arange(0, BS) * 0, num_classes=depth).float()")
print("\t\t\tself.weights.append(torch.nn.Parameter(w))")
print("")

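# Emit forward(): unrolled in topological order, each node draws a hard
# stage assignment n_i and its cumulative sum d_i; non-root nodes condition
# on the d_j of their predecessors so they cannot precede them.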
print("\tdef forward(self, Latency, BS, size=%d):" % All_mem)
for i in root:
    i = int(i)
    print("\t\tn_%d = F.gumbel_softmax(self.weights[%d], tau=self.temp, hard=True)" % (i, i))  # root nodes
    print("\t\td_%d = n_%d.cumsum(dim=1)" % (i, i))
for i in topo:
    if int(i) in root:
        continue
    predecessors = DAG.predecessors(i)
    i = int(i)
    print("\t\tn_%d, d_%d = multiconditional_gumbel_softmax(self.weights[%d], [" % (i, i, i), end="")
    for s in predecessors:
        print("d_%d" % (int(s)), end=",")
    print("], BS, tau=self.temp, hard=True)")

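# Emit the memory-entropy call: each node contributes parameter_size * n_i,
# and entropy() scores the per-stage totals against the memory budget.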
print("\t\te, sol = entropy([", end="")
for i in range(n):
    param = int(DAG.nodes()[i]['parameter'])
    print("%d*n_%d," % (param, i), end="")
print("], Latency, %d, size, BS)\n" % (n))
print("\t\treturn e, sol,", end="")
for i in range(n):
    if i < n - 1:
        print("n_%d" % (i), end=",")
    else:
        print("n_%d" % (i), end="")
print("\n")

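# Emit the main section of the generated script: argument parsing, model and
# optimizer construction, then the unrolled training loop.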
print("batch = int(sys.argv[-1])")
print("Latency = int(sys.argv[-2])")
print("ilr = float(sys.argv[-3])")
print("init_T = float(sys.argv[-4])")
print("gmlfile = sys.argv[-5]")
print("num_epochs = 500")
print("best_resource = 1e20")
print("exclude_time = 0")
print("stan_tensor = torch.eye(Latency)[1:].cuda()")
print("st = time.time()")
print("m = ScheduleNet(init_T, Latency, batch).cuda()")
print("")
print("optimizer = optim.AdamW(m.weights, lr=ilr)")
print("learning_rate_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-7)")
print("")

print("for i in range(1, num_epochs+1):")
print("\tlog = []")
print("\tloss_l = torch.zeros(batch).cuda()")
for k in range(int(depth) - 1):
    print("\tm_%d = torch.zeros(batch).cuda()" % k)
print("\toptimizer.zero_grad()")
print("\twith torch.cuda.amp.autocast():")
print("\t\tloss, sol,", end="")
for i in range(n):
    if i < n - 1:
        print("n_%d" % (i), end=",")
    else:
        print("n_%d = m(Latency, batch)" % (i), end="")
print("\n")
print("\t\tloss_mean = loss.mean()")

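# Generation-time pass over the DAG edges: count communicating edges and
# emit one m_k update per (edge, stage boundary). Each update adds `output`
# when node1 is placed at or before boundary k and node2 after it.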
edge_output_sum = 0
for edge in DAG.edges:
    node1, node2 = edge
    output = int(DAG.edges()[int(node1), int(node2)]['parameter'])
    if output > 0:
        output = 1
    edge_output_sum += output
    for k in range(int(depth) - 1):
        print("\t\tm_%d += torch.sum(torch.mul(n_%d, (1-torch.cumsum(stan_tensor[%d],0).cuda())),-1).cuda() * torch.sum(torch.mul(n_%d, torch.cumsum(stan_tensor[%d],0).cuda()),-1).cuda() * %d" % (k, int(node1), k, int(node2), k, output))

print("\t\tloss_CC = (m_0", end="")
for r in range(int(depth) - 2):
    print(" + m_%d" % (r + 1), end="")
print(")/%d" % (edge_output_sum))

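# The generated total loss combines communication cost and the memory-entropy
# term, weighted by the generation-time constant `ratio`.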
print("\t\tloss_CC_mean = loss_CC.mean()")
print("\t\tloss_total = loss_CC + %f * loss" % (ratio))
print("\t\tloss_total_min = torch.min(loss_total)")
print("\t\tloss_total_mean = loss_CC_mean + %f * loss_mean" % (ratio))
print("\tloss_total_mean.backward()")
print('\tprint("Mean entropy_mem+comm: %.7f; Mean entropy_mem: %.7f; Mean comm: %.7f;" % (loss_total_mean.data.item(), loss_mean.data.item(), loss_CC_mean.data.item()))')
print("\tif i > 0:")
print("\t\tif best_resource >= sol[(loss_total == loss_total_min).nonzero(as_tuple=False)].max():")
print("\t\t\tst_exclude = time.time()")

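# Emit per-node logging of the argmax stage for the best batch element, then
# the legality check and best-solution bookkeeping. exclude_time and
# objective are emitted inside the branch that defines st_exclude and result,
# and objective is guarded because result can be the string 'illegal'.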
for k in range(n):
    print("\t\t\tlog.append(int(torch.argmax(n_%d[(loss_total == loss_total_min).nonzero(as_tuple=False)])))" % (k))
print("\t\t\tresult = legal_check(log, gmlfile)")
print("\t\t\tif result != 'illegal':")
print("\t\t\t\tprint('Legal Solution!')")
print("\t\t\t\tbest_resource = sol[(loss_total == loss_total_min).nonzero(as_tuple=False)].max()")
print("\t\t\telse:")
print("\t\t\t\tprint('Illegal Solution!')")
print("\t\t\tet_exclude = time.time()")
print("\t\t\texclude_time += et_exclude - st_exclude")
print("\t\t\tif result != 'illegal':")
print("\t\t\t\tobjective = %f*best_resource + result" % (ratio))
print('\t\t\t\tprint("epoch %d solution (resource): %d, (communication cost): %d, (objective): %d" % (i, best_resource, result, objective))')
print("\toptimizer.step()")
print("\tlearning_rate_scheduler.step()")
print("")

print("et = time.time()")
print("print('Total Time:', '{:.4f}'.format(et-st-exclude_time), ' s')")
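
# Example usage (hypothetical file names; the generator's own argv order is
# an assumption from the setup block above, and depth must equal Latency):
#   python generate.py model.gml 8 0.5 > train.py
#   python train.py model.gml 1.0 0.05 8 64
# train.py's trailing arguments are gmlfile, init_T, ilr, Latency, batch.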