-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.m
More file actions
executable file
·150 lines (122 loc) · 3.61 KB
/
main.m
File metadata and controls
executable file
·150 lines (122 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
% ----------------------- TIME AND PRECISION ----------------------------
% For each grid size, build C = A .* B (elementwise product of two kernel
% matrices sampled on a grid with step 0.1) and compare two ways of
% obtaining a truncated SVD of C:
%   1. truncated_svd applied to C directly;
%   2. Lanczos on C*C' driven by mvMult_times, followed by recovering the
%      right singular vectors with mvMult_transpose.
% Run time and Frobenius-norm reconstruction error are recorded for each
% size and plotted. The truncated SVDs of A and B are computed before either
% timer starts, so both timings treat them as precomputed.
sizes = [50, 100, 150, 200, 250, 300];
tol = 1e-8;           % truncation tolerance for truncated_svd (named eps
                      % originally, which shadowed MATLAB's built-in eps)
lanczos_tol = 1e-8;   % tolerance passed to lanczos
lanczos_maxit = 30;   % 4th lanczos argument, presumably the max number of
                      % Lanczos steps -- TODO confirm against lanczos.m
% Factors of A and B; presumably read by mvMult_times / mvMult_transpose,
% which are given no other handle to A and B.
global Ua Ub Sa Sb Va Vb;

% Preallocate so results from a previous run in the same workspace cannot
% leak into the plots and so the arrays are not grown inside the loop.
n_sizes = numel(sizes);
time_svd = zeros(1, n_sizes);
svd_errors = zeros(1, n_sizes);
time_approx = zeros(1, n_sizes);
approx_errors = zeros(1, n_sizes);

for ii = 1:n_sizes
    % create matrices A and B (C has 10*N rows and 10*M columns)
    M = sizes(ii);
    N = sizes(ii);
    xs = 0.1:0.1:M;
    ys = 0.1:0.1:N;
    [X, Y] = meshgrid(xs, ys);
    A = 1 ./ (X + Y);
    B = 1 ./ sqrt(X.^2 + Y.^2);
    C = A .* B;

    % truncated SVDs of the factors (not timed)
    [Ua, Sa, Va] = truncated_svd(A, tol);
    [Ub, Sb, Vb] = truncated_svd(B, tol);

    % reference: truncated SVD of C computed directly
    t = tic;
    [Uc, Sc, Vc] = truncated_svd(C, tol);
    time_svd(ii) = toc(t);
    % Frobenius norm, so both curves in the error plot use the same measure
    svd_errors(ii) = norm(C - Uc * Sc * Vc', 'fro');

    % approximation of C by Lanczos + fast matrix-vector products
    x0 = rand(size(C, 1), 1);   % random start vector, one entry per row of C
    t = tic;
    % || C*C' - Q*H*Q' || < tol
    [Q, H] = lanczos(@mvMult_times, x0, lanczos_tol, lanczos_maxit);
    % || H - V*L*V' || < tol
    [V, L] = eig(H);
    lambdas = diag(L);
    % C*C' is positive semidefinite, so s_i = sqrt(l_i). Rounding can give
    % zero or slightly negative Ritz values; sqrt/division would then yield
    % complex or NaN columns, so keep only the positive ones.
    keep = lambdas > 0;
    Uc = Q * V(:, keep);
    sigmas = sqrt(lambdas(keep));
    % right singular vectors: v_i = C'*u_i / s_i
    m = size(Uc, 2);
    Vc = zeros(size(C, 2), m);
    for j = 1:m
        Vc(:, j) = mvMult_transpose(Uc(:, j)) ./ sigmas(j);
    end
    Sc = diag(sigmas);
    time_approx(ii) = toc(t);
    % error of our approximation
    approx_errors(ii) = norm(C - Uc * Sc * Vc', 'fro');
end

% plot times
figure()
plot(sizes, time_svd, 'bx-', sizes, time_approx, 'gx-')
set(gca,'fontsize',10)
xlabel('matrix size')
ylabel('time')
legend({'truncated SVD', 'lanczos + fast mvMult'}, 'Location', 'NorthWest');
grid on;
% plot error of approximations (both in Frobenius norm)
figure()
semilogy(sizes, svd_errors, 'bx-', sizes, approx_errors, 'gx-')
set(gca,'fontsize',10)
xlabel('matrix size')
ylabel('error')
legend({'truncated SVD', 'lanczos + fast mvMult'}, 'Location', 'NorthWest');
grid on;
% ------------------------ RANK COMPARISON -----------------------------
% On a small grid (C is 20 x 10), compare rank and Frobenius-norm error of
% three representations of C = A .* B:
%   1. truncated SVD of C computed directly;
%   2. the representation built from the truncated SVDs of A and B with kr
%      (presumably a Khatri-Rao product -- confirm against kr.m) and kron;
%   3. the Lanczos + fast matrix-vector product approximation.
% This section does not read or modify any results of the timing section.
tol = 1e-4;           % truncation tolerance for truncated_svd (named eps
                      % originally, which shadowed MATLAB's built-in eps)
lanczos_tol = 1e-8;   % tolerance passed to lanczos
lanczos_maxit = 30;   % 4th lanczos argument, presumably the max number of
                      % Lanczos steps -- TODO confirm against lanczos.m
% Factors of A and B; presumably read by mvMult_times / mvMult_transpose.
% Declared here too so this section also works when run on its own.
global Ua Ub Sa Sb Va Vb;

% create matrices A and B
M = 1;
N = 2;
xs = 0.1:0.1:M;
ys = 0.1:0.1:N;
[X, Y] = meshgrid(xs, ys);
A = 1 ./ (X + Y);
B = 1 ./ sqrt(X.^2 + Y.^2);
C = A .* B;
% truncated SVDs of A and B
[Ua, Sa, Va] = truncated_svd(A, tol);
[Ub, Sb, Vb] = truncated_svd(B, tol);

% 1. truncated SVD of C computed directly
[Uc, Sc, Vc] = truncated_svd(C, tol);
C1 = Uc * Sc * Vc';
rankC = rank(C1);
errorC = norm(C - C1, 'fro');

% 2. representation assembled from the factors of A and B
C2 = kr(Ua', Ub')' * kron(Sa, Sb) * kr(Va', Vb');
rankC2 = rank(C2);
errorC2 = norm(C - C2, 'fro');

% 3. approximation of C by Lanczos + fast matrix-vector products
x0 = rand(size(C, 1), 1);   % random start vector, one entry per row of C
% || C*C' - Q*H*Q' || < tol
[Q, H] = lanczos(@mvMult_times, x0, lanczos_tol, lanczos_maxit);
% || H - V*L*V' || < tol
[V, L] = eig(H);
lambdas = diag(L);
% C*C' is positive semidefinite, so s_i = sqrt(l_i). Rounding can give zero
% or slightly negative Ritz values; sqrt/division would then yield complex
% or NaN columns, so keep only the positive ones.
keep = lambdas > 0;
Uc = Q * V(:, keep);
sigmas = sqrt(lambdas(keep));
% right singular vectors: v_i = C'*u_i / s_i
m = size(Uc, 2);
Vc = zeros(size(C, 2), m);
for j = 1:m
    Vc(:, j) = mvMult_transpose(Uc(:, j)) ./ sigmas(j);
end
Sc = diag(sigmas);
% (The original also wrote time_approx(ii)/approx_errors(ii) here, using the
% leftover loop index and timer from the timing section; that overwrote the
% last timing result and failed when this section ran alone.)
C3 = Uc * Sc * Vc';
rankC3 = rank(C3);
errorC3 = norm(C - C3, 'fro');

fprintf("C truncSVD rank: %d\n", rankC);
fprintf("C approximation rank: %d\n", rankC2);
fprintf("C our approx rank: %d\n", rankC3);
fprintf("Error of truncSVD: %f\n", errorC);
fprintf("Error of approximation: %f\n", errorC2);
fprintf("Error of our approximation: %f\n", errorC3);