Commit 9d0bc48

committed: trying to reproduce gumbel-softmax implementation
1 parent 4abdc23 commit 9d0bc48

File tree

1 file changed: +210 -0 lines changed

test_gs.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import torch
from torch import Tensor
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
import networkx as nx
import time
import sys
import warnings

from torch.overrides import (
    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
    handle_torch_function)

seed = 30
torch.manual_seed(seed)

def legal_check(L, gmlfile):
    # Check a candidate stage assignment L against the DAG read from gmlfile.
    # NOTE: `depth` (the number of pipeline stages) is assumed to be defined
    # globally elsewhere; the print() calls below emit unrolled checking code
    # in the same code-generation style as the rest of this script.
    DAG = nx.read_gml(gmlfile)
    mapping = {str(node): int(node) for node in DAG.nodes()}
    graph = nx.relabel_nodes(DAG, mapping)

    for k in range(int(depth) - 1):
        print("\tm_%d = 0" % (k))

    illegal_edges = []
    cost = 0
    for edge in graph.edges():
        node1, node2 = edge
        output = int(graph.edges()[int(node1), int(node2)]['parameter'])
        if output > 0:
            output = 1
        if L[int(node1)] - L[int(node2)] > 0:
            # A node placed after its successor violates the topological order.
            return 'illegal'
        elif L[int(node1)] - L[int(node2)] < 0:
            for k in range(int(depth) - 1):
                print("\t\t\ta = %d - L[int(node1)]" % (k))
                print("\t\t\tb = %d - L[int(node2)]" % (k + 1))
                print("\t\t\tif a >= 0 and b <= 0:")
                print("\t\t\t\tm_%d += output" % (k))

    for r in range(int(depth) - 1):
        print("\tcost += m_%d" % (r))

    return cost
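
# Illustration (assumed intent): with depth = 3, the unrolling loops above
# emit checking code of the form
#
#   m_0 = 0
#   m_1 = 0
#   ...
#   a = 0 - L[int(node1)]
#   b = 1 - L[int(node2)]
#   if a >= 0 and b <= 0:
#       m_0 += output
#   ...
#   cost += m_0
#   cost += m_1
#
# so m_k accumulates `output` for each edge that crosses the boundary between
# stage k and stage k+1.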

def multiconditional_gumbel_softmax(logits: Tensor, D: list, batch: int = 16, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> tuple:
    # Gumbel-softmax sampling conditioned on the cumulative stage assignments
    # D of a node's predecessors; appears adapted from
    # torch.nn.functional.gumbel_softmax. Returns (sample, sample.cumsum(1)).
    if has_torch_function_unary(logits):
        return handle_torch_function(multiconditional_gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim)
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    gumbels = gumbels.softmax(dim)
    # Multiply by a per-stage bias and by each predecessor's cumulative
    # assignment D[i], which zeroes out stages before that predecessor's stage.
    bias = torch.arange(logits.shape[1] + 1, 1, -1).log().repeat(batch, 1).float().cuda()
    for i in range(len(D)):
        gumbels = gumbels.mul(bias).mul(D[i])
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight-through: forward the one-hot argmax, backprop through y_soft.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret, ret.cumsum(dim=1)
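
# Usage sketch (illustrative, mirroring the code emitted by the generator
# below): a non-root node i whose predecessors have cumulative assignments
# d_0 and d_1 would be sampled as
#
#   n_i, d_i = multiconditional_gumbel_softmax(self.weights[i], [d_0, d_1],
#                                              BS, tau=self.temp, hard=True)
#
# With hard=True the forward value is one-hot while gradients flow through
# the soft sample (the straight-through estimator above).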

def entropy(list_nextT, dim, V, mem, bs, bias=1e10):  # dim = pipeline stage, V = # of nodes
    n_all = torch.cat(list_nextT, dim=1).view(bs, V, dim) + torch.ones(bs, V, dim).cuda() / bias  # prevent log(0) NaN overflow
    sum_per_pipeline = torch.sum(n_all, 1)
    entropy = (-sum_per_pipeline / mem) * torch.log(sum_per_pipeline / mem)
    return -torch.sum(entropy, -1), sum_per_pipeline


def entropy_CC(list_nextT, depth, bs, C, bias=1e10):
    all = torch.stack(list_nextT).t() + torch.ones(bs, depth - 1).cuda() / bias
    entropy = (-all / C) * torch.log(all / C)
    return -torch.sum(entropy, -1)
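
# Reading of the objective (derived from entropy() above): with
# p_k = sum_per_pipeline[..., k] / mem, the returned value is
# sum_k p_k * log(p_k), the negative Shannon entropy of the per-stage memory
# distribution, so minimizing it spreads memory evenly across the pipeline
# stages. The additive 1/bias term keeps log() away from zero.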

# Everything below emits (via print) the actual training script.
# NOTE: `n` (node count), `root` (root-node ids), `topo` (topological order),
# `DAG`, `All_mem`, and `ratio` are assumed to be defined globally elsewhere
# (e.g. parsed from the input GML file); they are not defined in this file.
print("class ScheduleNet(torch.nn.Module):")
print("\tdef __init__(self, temp, depth, BS, nodes = %d):" % n)
print("\t\tsuper(ScheduleNet, self).__init__()")
print("\t\tself.temp = temp")
print("\t\tself.depth = depth")
print("\t\tself.nodes = nodes")
print("\t\tself.weights = torch.nn.ParameterList()")
print("\t\tself.rootlist = ", root)
print("\t\tfor n in range((nodes)): # todo: topological init")
print("\t\t\tif n in self.rootlist:")
print("\t\t\t\tw = 10*F.one_hot(torch.arange(0, BS) * 0, num_classes=depth).float()")
print("\t\t\telse:")
print("\t\t\t\tw = F.one_hot(torch.arange(0, BS) * 0, num_classes=depth).float()")
print("\t\t\tself.weights.append(torch.nn.Parameter(w))")

print("\tdef forward(self, Latency, BS, size=%d):" % All_mem)
for i in root:
    # Root nodes: unconditional Gumbel-softmax over pipeline stages.
    print("\t\tn_%d = F.gumbel_softmax(self.weights[%d], tau = self.temp, hard = True)" % (i, i))
    print("\t\td_%d = n_%d.cumsum(dim=1)" % (i, i))
for i in topo:
    if int(i) in root:
        continue
    # Non-root nodes: condition on the cumulative assignments of predecessors.
    predecessors = DAG.predecessors(i)
    i = int(i)
    print("\t\tn_%d, d_%d = multiconditional_gumbel_softmax( self.weights[%d], [" % (i, i, i), end="")
    for s in predecessors:
        print("d_%d" % (int(s)), end=",")
    print("] , BS, tau = self.temp, hard = True)")

print("\t\te, sol = entropy([", end="")
for i in range(n):
    param = int(DAG.nodes()[i]['parameter'])
    print("%d*n_%d," % (param, i), end="")
print("], Latency, %d, size, BS)\n" % (n))
print("\t\treturn e, sol,", end="")
for i in range(n):
    if i < n - 1:
        print("n_%d" % (i), end=",")
    else:
        print("n_%d" % (i), end="")
print("\n")
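
# Shape of the emitted forward() (assumed, for reference) with root = [0] and
# a node 2 whose predecessors are 0 and 1:
#
#   def forward(self, Latency, BS, size=<All_mem>):
#       n_0 = F.gumbel_softmax(self.weights[0], tau = self.temp, hard = True)
#       d_0 = n_0.cumsum(dim=1)
#       n_2, d_2 = multiconditional_gumbel_softmax( self.weights[2], [d_0,d_1,] , BS, tau = self.temp, hard = True)
#       e, sol = entropy([<p_0>*n_0,<p_1>*n_1,<p_2>*n_2,], Latency, 3, size, BS)
#       return e, sol, n_0,n_1,n_2
#
# where <All_mem> and <p_i> (each node's 'parameter') are filled in by the
# generator.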

# Emit the prologue of the generated script: hyperparameters read from argv,
# then model and optimizer setup.
print("batch = int(sys.argv[-1])")
print("Latency = int(sys.argv[-2])")
print("ilr = float(sys.argv[-3])")
print("init_T = float(sys.argv[-4])")
print("gmlfile = sys.argv[-5]")
print("num_epochs = 500")
print("best_resource = 1e20")
print("exclude_time = 0")
print("stan_tensor = torch.eye(Latency)[1:].cuda()")
print("st = time.time()")
print("m = ScheduleNet(init_T, Latency, batch).cuda()")

print("optimizer = optim.AdamW(m.weights, lr=ilr)")
print("learning_rate_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=50, eta_min=1e-7)")

# Emit the generated training loop.
print("for i in range(1, num_epochs+1):")
print("\tlog = []")
print("\tloss_l = torch.zeros(batch).cuda()")
for k in range(int(depth) - 1):
    print("\tm_%d = torch.zeros(batch).cuda()" % k)
print("\toptimizer.zero_grad()")
print("\twith torch.cuda.amp.autocast():")
print("\t\tloss, sol,", end="")
for i in range(n):
    if i < n - 1:
        print("n_%d" % (i), end=",")
    else:
        print("n_%d = m(Latency, batch)" % (i), end="")
print("\n")
print("\t\tloss_mean = loss.mean()")

# Unroll the communication-cost terms: one per stage boundary and DAG edge.
edge_output_sum = 0
for edge in DAG.edges:
    node1, node2 = edge
    output = int(DAG.edges()[int(node1), int(node2)]['parameter'])
    if output > 0:
        output = 1
    edge_output_sum += output
    for k in range(int(depth) - 1):
        print("\t\tm_%d += torch.sum(torch.mul(n_%d, (1-torch.cumsum(stan_tensor[%d],0).cuda())),-1).cuda() * torch.sum(torch.mul(n_%d, torch.cumsum(stan_tensor[%d],0).cuda()),-1).cuda() * %d" % (k, int(node1), k, int(node2), k, output))

print("\t\tloss_CC = (m_0", end="")
for r in range(int(depth) - 2):
    print(" + m_%d" % (r + 1), end="")
print(")/%d" % (edge_output_sum))

print("\t\tloss_CC_mean = loss_CC.mean()")
print("\t\tloss_total = loss_CC + %f * loss" % (ratio))
print("\t\tloss_total_min = torch.min(loss_total)")
print("\t\tloss_total_mean = loss_CC_mean + %f * loss_mean" % (ratio))
print("\tloss_total_mean.backward()")

print('\tprint("Mean entropy_mem+comm: %.7f; Mean entropy_mem: %.7f; Mean comm: %.7f;" %(loss_total_mean.data.item(), loss_mean.data.item(), loss_CC_mean.data.item()))')
print("\tif i > 0:")
print("\t\tif best_resource >= sol[(loss_total == loss_total_min).nonzero(as_tuple=False)].max():")
print("\t\t\tst_exclude = time.time()")

# Decode the best batch member into a stage assignment and legality-check it
# (this bookkeeping time is excluded from the reported total).
for k in range(n):
    print("\t\t\tlog.append(int(torch.argmax(n_%d[(loss_total == loss_total_min).nonzero(as_tuple=False)])))" % (k))
print("\t\t\tresult = legal_check(log, gmlfile)")
print("\t\t\tif result != 'illegal':")
print("\t\t\t\tprint('Legal Solution!')")
print("\t\t\t\tbest_resource = sol[(loss_total == loss_total_min).nonzero(as_tuple=False)].max()")
print("\t\t\telse:")
print("\t\t\t\tprint('Illegal Solution!')")
print("\t\t\tet_exclude = time.time()")
print("\t\texclude_time += et_exclude - st_exclude")
print("\t\tobjective=%f*best_resource+result" % (ratio))
print('\t\tprint("epoch %d solution (resource): %d, (communication cost): %d, (objective): %d" % (i, best_resource, result, objective))')
print("\toptimizer.step()")
print("\tlearning_rate_scheduler.step()")

print("et = time.time()")
print("print('Total Time:', '{:.4f}'.format(et-st-exclude_time), ' s')")
