-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathAMP_singlecell.m
More file actions
108 lines (105 loc) · 3.85 KB
/
Copy pathAMP_singlecell.m
File metadata and controls
108 lines (105 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
function [X,Pa] = AMP_singlecell(Y,S,gamma_w,lsfc,AMP_option,opts)
%AMP_SINGLECELL EM-style activity detection / channel estimation with (vector) AMP.
%   [X, Pa] = AMP_SINGLECELL(Y, S, gamma_w, lsfc, AMP_option) alternates
%   an E-step (one AMP or vector-AMP sweep over the M columns of Y) with an
%   M-step that re-estimates the per-user activity probabilities p.
%
%   The denoiser treats each entry of R as r = x + noise of precision
%   Gamma(m), with x drawn from a Bernoulli-Gaussian prior of variance
%   V(n,m) and activity probability p(n).
%
%   Inputs
%     Y          L-by-M observation matrix (column m = antenna m).
%     S          L-by-N matrix; column n is associated with user n.
%     gamma_w    Precision used in the LMMSE stage of the 'vector AMP'
%                update (presumably the noise precision -- confirm with
%                caller). Unused for 'AMP'.
%     lsfc       Prior variance of each user's coefficients: an N-element
%                vector, or a scalar shared by all users (presumably
%                large-scale fading coefficients -- confirm with caller).
%     AMP_option 'vector AMP' or 'AMP'. Anything else raises
%                'AMP_singlecell:badOption'.
%     opts       (optional) struct; any missing field takes its default:
%                  MaxIter   (200)  maximum number of EM iterations (>= 1)
%                  Damp      (0.03) damping weight on the previous iterate
%                  Threshold (1e-4) stop when the relative change of X
%                                   drops below this (checked after t > 5)
%                  NumWarmup (13)   number of initial iterations in which
%                                   the activity probability is computed
%                                   per antenna rather than jointly over
%                                   all antennas. The original author's
%                                   guidance: M=L=200 -> 1, M=L=120 -> 2,
%                                   M=L=100 -> 3, M=L=80 -> 4,
%                                   M=40,L=30 -> 13 (the old t<14).
%
%   Outputs
%     X          N-by-M posterior-mean estimate (damped).
%     Pa         N-by-1 estimated activity probabilities (floored at 1e-8).

%% Options
if nargin < 6 || isempty(opts)
    opts = struct();
end
defaults = struct('MaxIter', 200, 'Damp', 0.03, 'Threshold', 1e-4, ...
    'NumWarmup', 13);
default_names = fieldnames(defaults);
for k = 1:numel(default_names)
    if ~isfield(opts, default_names{k})
        opts.(default_names{k}) = defaults.(default_names{k});
    end
end
MAXITER = opts.MaxIter;
Damp = opts.Damp;
Threshold = opts.Threshold;
NumWarmup = opts.NumWarmup;

% Validate the update rule up front: previously an unknown option left R
% and Gamma untouched and the loop silently ran to MAXITER.
use_vamp = strcmp(AMP_option, 'vector AMP');
if ~use_vamp && ~strcmp(AMP_option, 'AMP')
    error('AMP_singlecell:badOption', ...
        'AMP_option must be ''vector AMP'' or ''AMP''.');
end

%% System Size Extraction
[L,M] = size(Y);
[~,N] = size(S);

%% Hyper-parameters Initialization
p = 0.5*ones(N,1);
if isscalar(lsfc)
    V = lsfc * ones(N,M);        % same prior variance for every user
else
    V = lsfc(:) * ones(1,M);     % N-by-M, row n repeats lsfc(n)
end

%% Variable Initialization
Pi = zeros(N,M);
X_hat = zeros(N,M);
R = S'*Y;
Vk = zeros(L,M);
% Initial per-antenna precision; sum(abs(Y).^2,1) equals diag(Y'*Y)
% without forming the M-by-M product.
Gamma = 1 ./ (L + L*sum(abs(Y).^2, 1).'/norm(S,'fro')^2);
relative_change = zeros(1,MAXITER);

%% SVD (needed by the vector-AMP update only)
if use_vamp
    [Bar_U,Bar_S,Bar_V] = svd(S,'econ');
    Rank_S = rank(Bar_S);
    % Keep only the numerically non-zero singular triplets so that Bar_S
    % is invertible and matches Rank_S even when S is rank deficient.
    Bar_U = Bar_U(:,1:Rank_S);
    Bar_V = Bar_V(:,1:Rank_S);
    Bar_S = Bar_S(1:Rank_S,1:Rank_S);
    s2 = diag(Bar_S).^2;         % squared singular values
    Y_tilde = Bar_S \ Bar_U' * Y;
end

%% Iteration Process
for t = 1:MAXITER
    %% E-step
    X_pre = X_hat;               % previous estimate, used for damping
    joint = t > NumWarmup;       % after warm-up, pool evidence over antennas

    if joint
        % log of prod_m (V*Gamma+1) * exp(-Gamma^2*V*|R|^2/(V*Gamma+1)),
        % accumulated as a sum of logs: the direct product over M antennas
        % can overflow to Inf, and Inf*0 then yields NaN.
        log_prod = zeros(N,1);
        for m = 1:M
            vg_m = V(:,m) * Gamma(m);
            log_prod = log_prod + log(vg_m + 1) ...
                - Gamma(m)^2 * V(:,m) .* abs(R(:,m)).^2 ./ (vg_m + 1);
        end
        log_prod = max(log_prod, log(1e-6));     % same 1e-6 floor as before
        % Depends only on p and R/Gamma from before this sweep, so it is the
        % same for every antenna in this iteration.
        Pi_joint = max(1 ./ (1 + exp(log((1 - p) ./ p) + log_prod)), 1e-8);
    end

    % AMP sweep across the M antennas
    for m = 1:M
        g = Gamma(m);            % equals the pre-sweep value for antenna m
        v = V(:,m);
        r = R(:,m);
        vg = v * g;
        r2 = abs(r).^2;

        if joint
            Pi(:,m) = Pi_joint;
        else
            % Warm-up: per-antenna activity probability (standard AMP),
            % used while the joint rule's conditions are not yet met.
            Pi(:,m) = 1 ./ (1 + ((1 - p) ./ p) .* (vg + 1) ...
                .* exp(-g^2 * v .* r2 ./ (vg + 1)));
        end
        pi_m = Pi(:,m);

        shrink = vg ./ (vg + 1);                 % Gaussian posterior-mean gain
        X_hat(:,m) = Damp * X_pre(:,m) + (1 - Damp) * pi_m .* (shrink .* r);

        % alpha_m: average over users of alpha_temp; it scales the
        % correction terms of the R/Gamma updates below.
        phi = 1 ./ pi_m;
        omega = 1 + (g^2 * v .* (phi - 1) .* r2) ./ ((vg + 1) .* phi);
        alpha_m = mean(shrink .* (omega ./ phi));

        if use_vamp
            R_tilde = (X_hat(:,m) - alpha_m * r) / (1 - alpha_m);
            gamma_tilde = real(g * (1 - alpha_m) / alpha_m);
            % Diagonal of gamma_w*(gamma_w*S^2 + gamma_tilde*I)^(-1)*S^2.
            d = gamma_w * s2 ./ (gamma_w * s2 + gamma_tilde);
            mean_d = mean(d);
            Gamma(m) = Damp * g + (1 - Damp) ...
                * real(gamma_tilde * mean_d / (N/Rank_S - mean_d));
            R(:,m) = R_tilde + (N/Rank_S) * Bar_V ...
                * ((d / mean_d) .* (Y_tilde(:,m) - Bar_V' * R_tilde));
        else
            Vk(:,m) = Y(:,m) - S*X_hat(:,m) + (N/L)*alpha_m*Vk(:,m);
            R(:,m) = X_hat(:,m) + S'*Vk(:,m);
            Gamma(m) = Damp * g + (1 - Damp) * L / norm(Vk(:,m))^2;
        end
    end

    %% M-step: update the activity probabilities p_n
    p = max(real(mean(Pi, 2)), 1e-8);

    %% Stop criteria (denominator guarded so an all-zero X cannot give NaN)
    relative_change(t) = norm(X_hat - X_pre, 'fro')^2 ...
        / max(norm(X_hat, 'fro')^2, realmin);
    if t > 5 && relative_change(t) < Threshold
        break;
    end
end
fprintf('Method: CVAMP, it %d: relative_change = %g\n', t, ...
    relative_change(t));

%% Generate output
X = X_hat;
Pa = p;