-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstrong_wolfe.py
More file actions
62 lines (54 loc) · 2.36 KB
/
strong_wolfe.py
File metadata and controls
62 lines (54 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# 至关重要
import numpy as np
# Working floating-point type for the line search. np.longdouble is the
# platform's extended-precision type (on some platforms, e.g. MSVC builds,
# it is identical to float64).
floatX = np.longdouble


def strong_wolfe(fhandle, x, p, *, c1=1e-4, c2=0.9, alpha_max=1e6,
                 max_iter=20, max_zoom=20):
    """Find a step length along ``p`` satisfying the strong Wolfe conditions.

    A bracketing phase (doubling the trial step) is followed by a bisection
    zoom, in the spirit of Nocedal & Wright, *Numerical Optimization*,
    Algorithms 3.5 and 3.6.

    Parameters
    ----------
    fhandle : callable
        ``fhandle(z, 'f')`` returns the objective value at ``z``;
        ``fhandle(z, 'grad')`` returns its gradient (array-like, same length
        as ``z``).
    x : array_like
        Current point; flattened and converted to ``floatX``.
    p : array_like
        Search direction; flattened and converted to ``floatX``. Expected to
        be a descent direction (``grad(x) @ p < 0``); otherwise no step can
        satisfy the conditions and a fallback step is returned.
    c1, c2 : float
        Sufficient-decrease (Armijo) and curvature constants, expected to
        satisfy ``0 < c1 < c2 < 1``.
    alpha_max : float
        Upper bound on the step length.
    max_iter : int
        Maximum number of bracketing (expansion) iterations.
    max_zoom : int
        Maximum number of bisection steps inside the zoom phase.

    Returns
    -------
    float
        A step satisfying both strong Wolfe conditions when one is found.
        Otherwise a fallback: the low end of the zoom bracket (a step
        meeting the Armijo condition, or the initial tiny step ``1e-8``),
        ``alpha_max`` if f is still decreasing there, or the last trial step
        that met the Armijo condition when ``max_iter`` is exhausted.
    """
    x = np.asarray(x, dtype=floatX).ravel()
    p = np.asarray(p, dtype=floatX).ravel()
    f0 = fhandle(x, 'f')
    gp0 = float(np.dot(fhandle(x, 'grad'), p))  # slope phi'(0) along p

    # A tiny positive step stands in for alpha = 0 as the initial low end of
    # the bracket, so a fallback result still moves the iterate.
    alpha_prev, f_prev = 1e-8, f0
    alpha = min(1.0, alpha_max)
    for i in range(max_iter):
        f1 = fhandle(x + alpha * p, 'f')
        # Armijo fails, or f rose relative to the previous trial: a strong
        # Wolfe step lies between the previous trial and this one.
        if f1 > f0 + c1 * alpha * gp0 or (i > 0 and f1 >= f_prev):
            return _zoom(fhandle, x, p, alpha_prev, alpha, f0, gp0, c1, c2,
                         max_zoom, f_lo=f_prev)
        gp1 = float(np.dot(fhandle(x + alpha * p, 'grad'), p))
        if abs(gp1) <= -c2 * gp0:
            return alpha  # both strong Wolfe conditions hold
        if gp1 >= 0:
            # Armijo holds but the slope has turned positive: we overshot a
            # minimiser along p, so zoom back towards the previous step.
            return _zoom(fhandle, x, p, alpha, alpha_prev, f0, gp0, c1, c2,
                         max_zoom, f_lo=f1)
        if alpha >= alpha_max:
            # Still descending at the upper bound: accept it.
            return alpha
        alpha_prev, f_prev = alpha, f1
        alpha = min(alpha * 2.0, alpha_max)
    # Expansion budget exhausted while f kept decreasing: return the largest
    # step known to satisfy the Armijo condition.
    return alpha_prev


def _zoom(fhandle, x, p, alpha0, alpha1, f0, gp0, c1, c2, max_zoom=20,
          f_lo=None):
    """Bisect the bracket ``alpha0``/``alpha1`` for a strong Wolfe step.

    ``alpha0`` is the low end: the step with the lowest known objective
    value that satisfies the Armijo condition (or the tiny initial step).
    ``alpha1`` is the other end of the bracket and may lie on either side of
    ``alpha0``.

    Parameters
    ----------
    fhandle, x, p, c1, c2 :
        As in :func:`strong_wolfe`.
    alpha0, alpha1 : float
        Low end and far end of the bracket.
    f0, gp0 : float
        Objective value and slope along ``p`` at the original point.
    max_zoom : int
        Maximum number of bisection steps.
    f_lo : float, optional
        Objective value at ``alpha0``; evaluated here if not given.

    Returns
    -------
    float
        The first bisection point satisfying both strong Wolfe conditions,
        or the current low end after ``max_zoom`` bisections.
    """
    x = np.asarray(x, dtype=floatX).ravel()
    p = np.asarray(p, dtype=floatX).ravel()
    lo, hi = alpha0, alpha1
    if f_lo is None:
        f_lo = fhandle(x + lo * p, 'f')
    for _ in range(max_zoom):
        alpha_t = 0.5 * (lo + hi)  # bisection trial point
        f_t = fhandle(x + alpha_t * p, 'f')
        # Not enough decrease (or no better than lo): the Wolfe point lies
        # between lo and alpha_t, so shrink the far end.
        if f_t > f0 + c1 * alpha_t * gp0 or f_t >= f_lo:
            hi = alpha_t
            continue
        gp_t = float(np.dot(fhandle(x + alpha_t * p, 'grad'), p))
        if abs(gp_t) <= -c2 * gp0:
            return alpha_t
        if gp_t * (hi - lo) >= 0:
            # The slope points back towards lo, so a minimiser lies between
            # lo and alpha_t; the old lo becomes the far end of the bracket.
            hi = lo
        lo, f_lo = alpha_t, f_t
    return lo