1+ from matplotlib .patches import Circle
2+ import torch
3+ from humancompatible .train .optim .PBM import PBM
4+ import numpy as np
5+ import matplotlib .pyplot as plt
6+ from mpl_toolkits .axes_grid1 import make_axes_locatable
7+
# Seed both random number generators so repeated runs of this script draw the
# same minibatch sequence and therefore produce the same figure.
_SEED = 1
torch.manual_seed(_SEED)
np.random.seed(_SEED)
10+
def plot_balls_trajectory(trajectories, names):
    """Plot optimisation trajectories over the two-disc feasible set and save a PDF.

    Draws the objective ``x^2 + y^2`` as a filled contour map, the two
    feasible discs centred at (-2, 0) and (2, 0), and every trajectory with
    its start point (white circle, labelled ``x^0``) and end point (black
    square).  The figure is written to ``./demo_balls_pbm.pdf``.

    Args:
        trajectories: sequence of array-likes, each of shape (N, 2) where
            every row is an ``[x, y]`` iterate.  Empty trajectories are
            skipped instead of raising.
        names: sequence of display names, one per trajectory.  The first
            four end points use built-in labels; ``names`` is only consulted
            for trajectories beyond those, so the figure for four
            trajectories is unaffected by it.

    Returns:
        None.  The only output is the saved PDF file.
    """
    fig, ax = plt.subplots(figsize=(24, 16))

    # Feasible regions: two discs.  The radius is sqrt(0.99) rather than 1 --
    # presumably 1 - E[xi^2] for the +/-0.1 noise used by the script; confirm.
    # NOTE(review): `balls` defines g1 around x = +2 and g2 around x = -2,
    # while g_1 labels the disc at x = -2 here -- confirm the numbering.
    ball_centers = [(-2, 0), (2, 0)]
    radius = np.sqrt(0.99)
    labels = [r"$\mathbb{E}[g_1(x,y,\xi)] \leq 0 $", r"$\mathbb{E}[g_2(x,y,\xi)] \leq 0$"]

    # Grid for the objective heatmap x^2 + y^2 (wider than the axis limits).
    x = np.linspace(-4, 4, 100)
    y = np.linspace(-2.5, 2.5, 100)
    X, Y = np.meshgrid(x, y)
    Z = X ** 2 + Y ** 2

    for center, label in zip(ball_centers, labels):
        ball = Circle(
            center,
            radius=radius,
            facecolor="lightgray",
            edgecolor="black",
            linewidth=1.5,
            alpha=0.6,
            zorder=1,
        )
        ax.add_patch(ball)
        # Label inside the ball
        ax.text(
            center[0], center[1],
            label,
            fontsize=18,
            fontweight='bold',
            color='black',
            ha='center',
            va='center',
            zorder=5,
        )

    # End-point labels, built once instead of on every loop iteration.
    # NOTE(review): the rho values written here (0, 1, 2.5) are hard-coded and
    # independent of the caller's penalty weights -- confirm they match the
    # runs that are actually plotted.
    end_labels = [
        r"$x_{\rho=0}^n$",
        r"$x_{\rho=1}^n$",
        r"$x_{\rho=2.5}^n$",
        r"$x_{SPBM}^n$",
    ]

    for i, traj in enumerate(trajectories):
        traj = np.asarray(traj)
        if traj.size == 0:
            # Nothing to draw; indexing traj[0] below would raise.
            continue

        # Trajectory
        ax.plot(
            traj[:, 0],
            traj[:, 1],
            linewidth=2.0,
            zorder=3,
            alpha=1.0,
        )

        # Emphasize x_0 and x_n
        x0 = traj[0]
        xn = traj[-1]

        ax.scatter(
            x0[0], x0[1],
            s=80,
            marker="o",
            facecolor="white",
            edgecolor="black",
            linewidth=2,
            zorder=4,
        )
        ax.scatter(
            xn[0], xn[1],
            s=60,
            marker="s",
            facecolor="black",
            edgecolor="black",
            zorder=4,
        )

        # Pick the end-point label: built-in label first, then the caller's
        # name, then a generic fallback, so extra trajectories never raise.
        if i < len(end_labels):
            end_label = end_labels[i]
        elif names is not None and i < len(names):
            end_label = names[i]
        else:
            end_label = rf"$x_{{{i}}}^n$"

        # Labels for x_0 and x_n
        ax.annotate(
            r"$x^0$",
            xy=(x0[0], x0[1]),
            xytext=(6, 8),
            textcoords="offset points",
            fontsize=22,
            zorder=5,
        )
        ax.annotate(
            end_label,
            xy=(xn[0], xn[1]),
            # Trajectories 1 and 3 get their label to the right of the end
            # point, all others to the left.
            xytext=(8 if i in (1, 3) else -45, -15),
            textcoords="offset points",
            fontsize=22,
            zorder=5,
        )

    # Formatting
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel(r"$x$", fontsize=12)
    ax.set_ylabel(r"$y$", fontsize=12)
    ax.grid(True, linestyle=":", linewidth=0.8, alpha=0.7)
    ax.set_xlim(-3.2, 3.2)
    ax.set_ylim(-1.8, 1.8)

    # Objective heatmap below everything else (zorder=0), with its own
    # colorbar axis attached to the right of the main axes.
    contour = ax.contourf(X, Y, Z, levels=100, cmap='viridis', alpha=0.5, zorder=0)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.08)

    cbar = fig.colorbar(contour, cax=cax)
    cbar.set_label(r"$x^2 + y^2$", fontsize=14)

    fig.savefig(
        "./demo_balls_pbm.pdf",
        bbox_inches="tight",
        pad_inches=0.05,
    )
132+
133+
134+
def balls(x, sample):
    """Return the smaller of the two noisy disc constraints evaluated at ``x``.

    ``g1`` is the unit-disc constraint centred at ``(2 - sample, 0)`` and
    ``g2`` the one centred at ``(-2 - sample, 0)``; the result is non-positive
    exactly when ``x`` lies in at least one of the two shifted discs.

    Args:
        x: indexable with at least two entries ``[x, y]``.
        sample: noise realisation that shifts both disc centres along x.
            An alternative formulation adds ``sample`` to the constraint
            value instead of shifting the centre; it is not used here.

    Returns:
        ``g1`` if ``g1 <= g2`` else ``g2``.  The comparison is used as a
        truth value, so tensor inputs must yield single-element results.
    """
    y_term = x[1] ** 2
    g1 = (x[0] - 2 + sample) ** 2 + y_term - 1
    g2 = (x[0] + 2 + sample) ** 2 + y_term - 1
    return g1 if g1 <= g2 else g2
144+
def parabola(x):
    """Return ``x[0]^2 + x[1]^2``, the objective minimised by the demo."""
    first, second = x[0], x[1]
    return first ** 2 + second ** 2
147+
148+
# The two noise realisations the optimisation loops choose between, each
# stored as a one-element tensor.  Replacing both values with 0 gives the
# noise-free variant of the problem.
samples = [torch.tensor([noise]) for noise in (-0.1, 0.1)]
155+
156+
157+ ################## SGD #########################
158+
def run_sgd(rho: float):
    """Run 200 SGD steps on the quadratic-penalty version of the problem.

    Starting from ``(0, 1)``, each step draws one of the two module-level
    ``samples`` (chosen by a uniform draw against 0.5), evaluates the
    constraint ``c = balls(xy, sample)`` and takes an SGD step on
    ``parabola(xy) + rho * ||c||_2^2``.

    Args:
        rho: weight of the squared-constraint penalty term.

    Returns:
        A tuple ``(iterates, constraints)``: ``iterates`` holds a numpy copy
        of the point taken before each step, ``constraints`` the constraint
        value (as a float) evaluated at that same point.
    """
    point = torch.nn.Parameter(torch.tensor([0.0, 1.0]))

    # NOTE(review): dampening applies to the momentum buffer; momentum is
    # left at its default here, so it presumably has no effect -- confirm.
    optimizer = torch.optim.SGD([point], lr=0.05, dampening=0.1)
    lr_schedule = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 0.99 ** step,
    )

    iterates = []
    constraints = []

    for _ in range(200):
        # Snapshot the iterate before it is updated.
        iterates.append(point.detach().clone().numpy())

        draw = np.random.uniform()
        minibatch = samples[0] if draw > 0.5 else samples[1]

        c = balls(point, minibatch)
        penalty = torch.norm(c, p=2).square()
        loss = parabola(point) + rho * penalty

        loss.backward()
        optimizer.step()
        lr_schedule.step()
        optimizer.zero_grad()

        # NOTE(review): LambdaLR recomputes the rate from the base lr on every
        # step, so this extra factor is applied once per iteration and does
        # not compound -- confirm that is the intended schedule.
        for group in optimizer.param_groups:
            group['lr'] *= 0.97

        constraints.append(c.detach().item())

    return iterates, constraints
198+
199+
# Penalty weights to compare; one SGD run is performed per weight, in order,
# so the runs consume the shared numpy random stream sequentially.
rhos = [0.1, 1, 2]

sgd_param_logs = []
sgd_con_logs = []

for rho in rhos:
    param_log_sgd, con_log_sgd = run_sgd(rho)
    sgd_param_logs += [param_log_sgd]
    sgd_con_logs += [con_log_sgd]
208+
209+
210+ ################## PBM #########################
211+
# Start the PBM run from the same point (0, 1) as the SGD runs.
xy = torch.nn.Parameter(data=torch.ones(2, requires_grad=True))
with torch.no_grad():
    xy[0] = 0
    xy[1] = 1

# Project-local optimiser; the meaning of its arguments is defined by
# humancompatible.train.optim.PBM and is not visible in this file.
pbm = PBM([xy], m=1, lr=0.01, dual_bounds=(1e-3, 1e3), penalty_update_m='CONST', epoch_len=2, mu=0, opt_method="Adam")

iters = 200

param_log = []   # iterate before each step
con_log = []     # sampled constraint value at that iterate
dual_log = []    # dual variable after each dual update
c_grad_log = []  # never filled in this script; kept for compatibility

for _ in range(iters):
    param_log.append(xy.detach().numpy().copy())

    # Pick one of the two noise realisations from a uniform draw.
    r = np.random.uniform()
    minibatch = samples[0 if r > 0.5 else 1]

    c = balls(xy, minibatch)

    # Presumably updates the dual variable of constraint 0 from the sampled
    # constraint value -- verify against the PBM implementation.
    pbm.dual_step(0, c)
    dual_log.append(pbm._dual_vars.detach().numpy().copy().item())

    obj = parabola(xy)

    pbm.step(obj)
    # Geometric learning-rate decay applied by hand to every parameter group.
    for gr in pbm.param_groups:
        gr['lr'] *= 0.99

    con_log.append(c.detach().numpy().copy().item())

# Build a new list rather than aliasing sgd_param_logs, so that appending the
# PBM trajectory does not silently modify the SGD results.
trajectories = [*sgd_param_logs, param_log]

plot_balls_trajectory(
    trajectories,
    [r"$\rho= $" + str(rho) for rho in rhos] + ['spbm']
)
261+
262+ # print(param_log)
263+ # print(np.array(dual_log))
264+ # print(np.array(con_log))
265+ # print(np.array(con_log))
0 commit comments