Skip to content

Commit d121db8

Browse files
author
Adam Bosák
committed
PBM now supports a second variant of the algorithm that takes more steps of the primal update
1 parent c7c68c3 commit d121db8

30 files changed

Lines changed: 1449 additions & 16 deletions

benchmark/demo_balls.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
from matplotlib.patches import Circle
2+
import torch
3+
from humancompatible.train.optim.PBM import PBM
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
from mpl_toolkits.axes_grid1 import make_axes_locatable
7+
8+
# Fix both random number generators so repeated runs of the demo draw the
# same sequence of samples.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)
10+
11+
def plot_balls_trajectory(trajectories, names):
    """
    Draw optimizer trajectories over the two feasible balls and save the figure.

    trajectories: sequence of array-likes of shape (N, 2), where each row is [x, y]
    names: sequence of legend labels, matched to trajectories by position;
        trajectories without a matching name are left out of the legend

    Writes ./demo_balls_pbm.pdf and returns nothing.
    """

    fig, ax = plt.subplots(figsize=(24, 16))

    # Feasible regions: two balls of radius sqrt(0.99).
    # The centers follow the label order: in balls(), g1 is built from (x - 2)
    # and therefore sits around (2, 0), while g2 sits around (-2, 0).
    ball_centers = [(2, 0), (-2, 0)]
    radius = np.sqrt(0.99)
    labels = [r"$\mathbb{E}[g_1(x,y,\xi)] \leq 0 $", r"$\mathbb{E}[g_2(x,y,\xi)] \leq 0$"]

    # Heatmap for x^2 + y^2
    x = np.linspace(-4, 4, 100)
    y = np.linspace(-2.5, 2.5, 100)
    X, Y = np.meshgrid(x, y)
    Z = X**2 + Y**2

    for center, label in zip(ball_centers, labels):
        ball = Circle(
            center,
            radius=radius,
            facecolor="lightgray",
            edgecolor="black",
            linewidth=1.5,
            alpha=0.6,
            zorder=1,
        )
        ax.add_patch(ball)
        # Label inside the ball
        ax.text(
            center[0], center[1],
            label,
            fontsize=18,
            fontweight='bold',
            color='black',
            ha='center',
            va='center',
            zorder=5
        )

    # Endpoint labels, matched to trajectories by position.
    # NOTE(review): these hard-coded rho values do not follow `names`;
    # confirm they match the runs that are actually plotted.
    end_labels = [
        r"$x_{\rho=0}^n$",
        r"$x_{\rho=1}^n$",
        r"$x_{\rho=2.5}^n$",
        r"$x_{SPBM}^n$",
    ]

    for i, traj in enumerate(trajectories):
        traj = np.asarray(traj)
        if traj.size == 0:
            # Nothing to draw; indexing the first/last point would fail.
            continue

        # Trajectory (label=None keeps it out of the legend)
        ax.plot(
            traj[:, 0],
            traj[:, 1],
            linewidth=2.0,
            zorder=3,
            alpha=1.0,
            label=names[i] if i < len(names) else None,
        )

        # Emphasize x_0 and x_n
        x0 = traj[0]
        xn = traj[-1]

        ax.scatter(
            x0[0], x0[1],
            s=80,
            marker="o",
            facecolor="white",
            edgecolor="black",
            linewidth=2,
            zorder=4,
        )
        ax.scatter(
            xn[0], xn[1],
            s=60,
            marker="s",
            facecolor="black",
            edgecolor="black",
            zorder=4,
        )

        # Labels for x_0 and x_n
        ax.annotate(
            r"$x^0$",
            xy=(x0[0], x0[1]),
            xytext=(6, 8),
            textcoords="offset points",
            fontsize=22,
            zorder=5,
        )
        # Fall back to a generic indexed label beyond the predefined ones
        # instead of raising IndexError.
        end_label = end_labels[i] if i < len(end_labels) else rf"$x_{{{i}}}^n$"
        ax.annotate(
            end_label,
            xy=(xn[0], xn[1]),
            # Second and fourth labels go to the right of the end point,
            # the others to the left.
            xytext=(8 if i in (1, 3) else -45, -15),
            textcoords="offset points",
            fontsize=22,
            zorder=5,
        )

    # Formatting
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel(r"$x$", fontsize=12)
    ax.set_ylabel(r"$y$", fontsize=12)
    ax.grid(True, linestyle=":", linewidth=0.8, alpha=0.7)
    ax.set_xlim(-3.2, 3.2)
    ax.set_ylim(-1.8, 1.8)

    # Only draw a legend when at least one trajectory received a name.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="upper right", fontsize=14)

    contour = ax.contourf(X, Y, Z, levels=100, cmap='viridis', alpha=0.5, zorder=0)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.08)

    cbar = fig.colorbar(contour, cax=cax)
    cbar.set_label(r"$x^2 + y^2$", fontsize=14)

    fig.savefig(
        "./demo_balls_pbm.pdf",
        bbox_inches="tight",
        pad_inches=0.05
    )
132+
133+
134+
135+
def balls(x, sample):
    """
    Constraint value for the union of two unit balls, shifted by `sample`.

    Both ball constraints are evaluated at the point (x[0], x[1]); the
    smaller value is returned, with ties going to the first ball.
    """
    # Horizontal offsets from the two ball centers, perturbed by the sample.
    shift_right = x[0] - 2 + sample
    shift_left = x[0] + 2 + sample

    right_ball = shift_right**2 + x[1]**2 - 1
    left_ball = shift_left**2 + x[1]**2 - 1

    return right_ball if right_ball <= left_ball else left_ball
144+
145+
def parabola(x):
    """Objective: squared distance of the point (x[0], x[1]) from the origin."""
    first, second = x[0], x[1]
    return first**2 + second**2
147+
148+
149+
# The two possible noise realisations fed to balls(); one of them is drawn
# at random in every optimisation step (use 0 for both to switch noise off).
samples = [torch.tensor([noise]) for noise in (-0.1, 0.1)]
155+
156+
157+
################## SGD #########################
158+
159+
def run_sgd(rho: float):
    """
    Run 200 SGD steps on the quadratic-penalty objective and log the path.

    Every step draws one of the two module-level `samples` at random,
    evaluates the constraint c = balls(xy, sample) and descends on
    parabola(xy) + rho * ||c||_2^2, starting from the point (0, 1).

    rho: weight of the squared-constraint penalty

    Returns (param_log, con_log): the iterate recorded before each step and
    the constraint value evaluated at that same iterate.
    """
    xy = torch.nn.Parameter(torch.tensor([0.0, 1.0]))

    optimizer = torch.optim.SGD([xy], lr=0.05, dampening=0.1)
    decay = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 0.99 ** step,
    )

    visited = []
    constraint_values = []

    for _ in range(200):
        # Snapshot the iterate before it is updated.
        visited.append(xy.detach().numpy().copy())

        pick = 0 if np.random.uniform() > 0.5 else 1
        c = balls(xy, samples[pick])

        penalised = parabola(xy) + rho * torch.square(torch.norm(c, p=2))

        penalised.backward()
        optimizer.step()
        decay.step()
        optimizer.zero_grad()

        # NOTE(review): the scheduler above also sets the learning rate on
        # every step; confirm how this manual 0.97 factor is meant to
        # interact with it.
        for group in optimizer.param_groups:
            group['lr'] *= 0.97

        constraint_values.append(c.detach().numpy().copy().item())

    return visited, constraint_values
198+
199+
200+
# Penalty weights to sweep; one SGD run is performed per value, in order.
rhos = [0.1,1,2]

sgd_runs = [run_sgd(rho) for rho in rhos]

# Split the (parameter log, constraint log) pairs into two parallel lists.
sgd_param_logs = [param_trace for param_trace, _ in sgd_runs]
sgd_con_logs = [con_trace for _, con_trace in sgd_runs]
208+
209+
210+
################## PBM #########################
211+
212+
# Start the PBM run from the same point as the SGD runs: (x, y) = (0, 1).
xy = torch.nn.Parameter(data=torch.ones(2, requires_grad=True))
with torch.no_grad():
    xy[0] = 0
    xy[1] = 1

pbm = PBM([xy], m=1, lr=0.01, dual_bounds=(1e-3, 1e3), penalty_update_m='CONST', epoch_len=2, mu=0, opt_method="Adam")

iters = 200

param_log = []   # iterate recorded before each step
con_log = []     # constraint value at that iterate
dual_log = []    # dual variable read back after each dual step
c_grad_log = []  # never filled in this script

for i in range(iters):
    param_log.append(
        xy.detach().numpy().copy()
    )

    # Pick one of the two noise samples with equal probability.
    r = np.random.uniform()
    minibatch = samples[
        0 if r > 0.5 else 1
    ]

    c = balls(xy, minibatch)

    # presumably updates the dual variable of constraint 0 from the current
    # constraint value; verify against the PBM implementation
    pbm.dual_step(0, c)
    dual_log.append(pbm._dual_vars.detach().numpy().copy().item())

    obj = parabola(xy)

    pbm.step(obj)
    # Geometric learning-rate decay applied by hand after every step.
    for gr in pbm.param_groups:
        gr['lr'] *= 0.99

    con_log.append(c.detach().numpy().copy().item())


# Build a fresh list: assigning sgd_param_logs directly and appending to it
# would also add the PBM trajectory to the SGD logs.
trajectories = sgd_param_logs + [param_log]

plot_balls_trajectory(
    trajectories,
    [r"$\rho= $" + str(rho) for rho in rhos] + ['spbm']
)
261+
262+
# print(param_log)
263+
# print(np.array(dual_log))
264+
# print(np.array(con_log))
265+
# print(np.array(con_log))

0 commit comments

Comments
 (0)