forked from rsagroup/rsaModelComparison
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: rsa_marglNonlin.m
More file actions
110 lines (97 loc) · 4.64 KB
/
rsa_marglNonlin.m
File metadata and controls
110 lines (97 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
function [nlmlSum,dnlmlSum,wN,VN,nlml] = rsa_marglNonlin(theta,Model, Y, Sigma);
% function [nlmlSum,dnlmlSum,wN,VN,nlml] = rsa_marglNonlin(theta,Model, Y, Sigma);
% Returns the negative log marginal likelihood (and derivatives) for nonlinear
% models with 2 levels of parameters:
% 1st level: omega parameters that are subject-specific. These are
%            marginalised out analytically.
% 2nd level: group "hyper"-parameters that are optimized directly.
%            The first Model.numPrior entries of theta are the log prior
%            variances on the omegas; the next Model.numNonlin entries are
%            the nonlinear (mixing) parameters.
% INPUT:
%   -theta: Hyperparameters, consisting of linear and nonlinear parts
%   -Model: Model structure determining
%      Model.numComp       : Number of linearly separable components (at least 1)
%      Model.numPrior      : Number of prior parameters on the component coefficients
%      Model.numNonlin     : Number of nonlinear parameters
%      Model.nonlinP0      : Starting value of nonlinear (mixing) parameters
%      Model.constantParams: Cell array of additional parameters to function
%      Model.fcn           : Function returning RDM and derivatives
%   -Y: Data (i.e., distances) for all subjects
%       Size: (number of distance pairs) x (number of subjects)
%   -Sigma: Covariance matrix of distances for all subjects
%       Size: (number of distance pairs) x (number of distance pairs) x (number of subjects)
% OUTPUT:
%   nlmlSum : negative log marginal likelihood summed over subjects
%   dnlmlSum: (column) vector of partial derivatives of nlmlSum with respect
%             to each hyperparameter
%   wN      : posterior mean of the regression coefficients, one column per subject
%   VN      : posterior variance of the regression coefficients, one slice per subject
%   nlml    : per-subject negative log marginal likelihood (1 x numSubj)
%
% Joern Diedrichsen

% Split the hyperparameter vector: log prior variances first, then the
% nonlinear parameters
logtheta = theta(1:Model.numPrior);
nonlinP  = theta([1:Model.numNonlin]+Model.numPrior);

% Prior variance-covariance matrix of the regression coefficients
V0 = diag(exp(logtheta));

% Get the linear design matrix and the derivatives of the design matrix
% with respect to the nonlinear parameters
M  = feval(Model.fcn,nonlinP,Model.constantParams{:});
X  = permute(M.RDM,[2 1 3]);
dX = permute(M.dRDMdTheta,[2 1 3 4]);

[N, numSubj]        = size(Y);
[N, numReg, depthX] = size(X);
if (depthX==1)
    X  = repmat(X,1,1,numSubj);         % use the same model for all subjects
    dX = repmat(dX,1,1,1,numSubj);
elseif (depthX~=numSubj)
    error('X must be a matrix or have a size in the 3rd dimension of numSubj');
end;

% Preallocate outputs so they do not grow inside the subject loop
nlml = zeros(1,numSubj);
if (nargout>1)
    dnlml = zeros(Model.numPrior+Model.numNonlin,numSubj);
end;
if (nargout>2)
    wN = zeros(numReg,numSubj);
end;
if (nargout>3)
    VN = zeros(numReg,numReg,numSubj);
end;

for s=1:numSubj
    % Precompute outer products of the X-columns (reused for the prior derivatives)
    XX = zeros(N,N,numReg);
    for i=1:numReg
        XX(:,:,i)=X(:,i,s)*X(:,i,s)';
    end;
    % Precompute V0*X(:,:,s)' once; it is reused for S, wN, and VN below
    V0XT = V0*X(:,:,s)';                        % = (X(:,:,s)*V0)' since V0 is diagonal
    S = Sigma(:,:,s) + X(:,:,s)*V0XT;           % Training set covariance matrix
    L = chol(S)';                               % Cholesky factorization of the covariance
    alpha = solve_chol(L',Y(:,s));              % = S^-1*Y (helper from the GPML toolbox)
    % Negative log marginal likelihood for this subject:
    % 0.5*y'*S^-1*y + 0.5*log(det(S)) + 0.5*N*log(2*pi)
    nlml(s) = 0.5*(alpha'*Y(:,s)) + sum(log(diag(L))) + 0.5*N*log(2*pi);
    % Derivatives
    if (nargout>1)
        invS = (L'\(L\eye(N)));
        W = alpha*alpha'-invS;                  % 2 times derivative of negative log-likelihood wrt S (=2*dl/dS)
        % Derivative wrt the log prior parameters (=(dl/dS)x(dS/dV0)xV0);
        % the trailing exp(logtheta(i)) is the chain-rule factor for the log
        for i=1:numReg
            thetai = exp(logtheta(i));
            dnlml(i,s) = -1/2*sum(sum(W.*(XX(:,:,i))))*thetai;
        end;
        % Derivative wrt the nonlinear parameters
        for i=1:Model.numNonlin
            tmp = dX(:,:,i,s)*V0XT;
            dSdTheta = tmp+tmp';                % because dX*V0*X' = (X*V0*dX')'
            dnlml(Model.numPrior+i,s) = -1/2*sum(sum(W.*dSdTheta));
        end
    end;
    % Posterior mean of the regression coefficients
    if (nargout>2)
        wN(:,s) = V0XT*alpha;                   % = V0*X'*S^-1*y
    end;
    % Posterior variance of the regression coefficients
    if (nargout>3)
        VN(:,:,s) = V0 - V0XT*invS*V0XT';       % = V0 - V0*X'*S^-1*X*V0
        VN(:,:,s) = (VN(:,:,s)+VN(:,:,s)')/2;   % Prevent asymmetry through roundoff
    end;
end;

% Sum marginal likelihoods (and derivatives) over subjects for optimization
nlmlSum = sum(nlml);
if (nargout>1)
    dnlmlSum = sum(dnlml,2);
end;