-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
68 lines (54 loc) · 1.92 KB
/
run.py
File metadata and controls
68 lines (54 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# run.py
import time
from pathlib import Path

import numpy as np
import pandas as pd

from rosenbrock import rosenbrock
from sd import sd
from bfgs import bfgs
from lbfgs import lbfgs
# =====================================================================
# Experiment configuration
# =====================================================================
ALGORITHM = 'SD'   # optimizer to benchmark: 'SD', 'BFGS' or 'L_BFGS'
DIM = 200          # problem dimension
N_POINTS = 100     # number of random starting points to test
m = 10             # forwarded to lbfgs as ``m`` (only used for 'L_BFGS')
success_t = 1e-8   # success threshold, applied to both the distance from the
                   # all-ones vector and |f(x)|
# =====================================================================
# Extended-precision dtype for the starting points (the original author marks
# this as crucial -- presumably because a 1e-8 tolerance is tight for float64).
# NOTE(review): np.longdouble is platform-dependent; on Windows/MSVC it is the
# same as float64, so no extra precision is gained there.
floatX = np.longdouble
# Fixed seed so the random starting points are reproducible between runs.
np.random.seed(0)
# Select the optimizer. Each choice is called as ``opt(f, x0)``; the benchmark
# loop unpacks the first three items of its return value as (x, fval, iters).
if ALGORITHM == 'SD':
    opt = sd
elif ALGORITHM == 'BFGS':
    opt = bfgs
elif ALGORITHM == 'L_BFGS':
    def opt(f, x0):
        """Run lbfgs from ``x0`` on ``f`` with the configured ``m``."""
        return lbfgs(f, x0, m=m)
else:
    # Fail fast: without this branch ``opt`` would be undefined and the script
    # would crash with a NameError only after starting the benchmark loop.
    raise ValueError(
        f"Unknown ALGORITHM {ALGORITHM!r}; expected 'SD', 'BFGS' or 'L_BFGS'")
# Draw the random starting points, one row per test point. np.random.rand is
# uniform on [0, 1), so every coordinate lies in
# [1 - range_val, 1 + range_val): a box centred on the all-ones vector.
range_val = 3
X0 = np.asarray(
    1 + 2 * range_val * (np.random.rand(N_POINTS, DIM) - 0.5),
    dtype=floatX,
)
# Run the selected optimizer from every starting point and record the outcome
# as [id, dim, algo, iterations, seconds, dist, error, success].
records = []
print(f"开始测试 {ALGORITHM} | 维度 {DIM} | 测试点 {N_POINTS} 个")
print("-" * 70)
# Loop-invariant reference point: success is judged by the distance to the
# all-ones vector, so build it once instead of on every iteration.
x_star = np.ones(DIM, dtype=floatX)
for i, x0 in enumerate(X0):
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is coarser on some platforms.
    t0 = time.perf_counter()
    # The optimizer may return more values; only the first three are used.
    x, fval, it = opt(rosenbrock, x0)[:3]
    t = time.perf_counter() - t0
    dist = np.linalg.norm(x - x_star)
    # |f(x)| serves as the error measure (presumably the optimal value is 0).
    err = abs(fval)
    success = dist < success_t and err < success_t
    records.append([i + 1, DIM, ALGORITHM, it, round(t, 4), dist, err, success])
    status = "成功" if success else "失败"
    print(f"点 {i + 1:3d} → {it:5d} 次迭代 {t:6.3f} 秒 误差 {err:.2e} {status}")
# Save the per-point results to CSV and print a summary.
df = pd.DataFrame(records, columns=[
    'id', 'dim', 'algo', 'iter', 'time_s', 'dist', 'error', 'success'])
# Build the output path with pathlib so the file lands inside result/ on every
# OS (a literal backslash is only a separator on Windows, and "\{" is an
# invalid string escape). to_csv does not create missing directories.
out_dir = Path("result")
out_dir.mkdir(parents=True, exist_ok=True)
filename = out_dir / f"{ALGORITHM}_dim{DIM}_{N_POINTS}pts.csv"
df.to_csv(filename, index=False, float_format='%.16e')
success_rate = df['success'].mean()  # fraction of points that converged
print("-" * 70)
print("全部完成!")
print(f"成功率:{success_rate:.1%} ({df['success'].sum()}/{N_POINTS})")
print(f"平均迭代次数:{df['iter'].mean():.1f}")
print(f"平均时间:{df['time_s'].mean():.3f} 秒")
print(f"结果已保存 → {filename}")