-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlflPredictor.m
44 lines (38 loc) · 1.57 KB
/
lflPredictor.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
% LFLPREDICTOR Predictions from the LFL model using the given weight vector.
%
% Scores are computed for every (u, v) pair listed element-wise in
% w.usersU / w.usersV. Output entries for pairs that are not listed end up
% NaN (see the final normalisation step).
%
% Inputs (field layout inferred from the indexing below; confirm against the
% training code):
%   w.userUW           - per-user weights for the first user, sliced (:,:,u);
%                        size(w.userUW, 2) is the number of labels Y and
%                        size(w.userUW, 3) the number of users U
%   w.userVW           - per-user weights for the second user, sliced (:,:,v)
%   w.lambdaW          - interaction matrix, used as uW' * lambdaW * vW
%   w.usersU, w.usersV - equal-length index vectors of the pairs to score
%   w.withSideInfo     - if true, add a linear side-information term
%   w.sideInfo         - per-user side-information rows (only if withSideInfo)
%   w.sideInfoW        - side-information weights (only if withSideInfo)
%
% Outputs:
%   predictions       - U x U real-valued predictions: the expected label
%                       sum_y y * Pr[y | u, v; w]
%   argmaxPredictions - U x U discrete predictions: the most likely label
%   probabilities     - U x U x Y array; probabilities(u, v, :) holds
%                       Pr[y | u, v; w] for the labels y = 1..Y
function [predictions, argmaxPredictions, probabilities] = lflPredictor(w)
Y = size(w.userUW, 2);
U = size(w.userUW, 3);
n = length(w.usersU);
% NOTE(review): allocated U x U, i.e. assumes the V users index into the same
% range as the U users; a larger v would silently grow the array.
probabilities = zeros(U, U, Y);

% Loop-invariant weights.
lW = w.lambdaW;
if w.withSideInfo
    sW = w.sideInfoW;
end

for index = 1 : n
    u = w.usersU(index);
    v = w.usersV(index);
    uW = w.userUW(:,:,u);
    vW = w.userVW(:,:,v);
    % Unnormalised log-probability of each label y for the pair (u, v).
    % Both branches use the (u, v) bilinear term; the side-info term is added
    % to the diagonal vector rather than to the full matrix, so no implicit
    % expansion is needed.
    scores = diag(uW' * lW * vW);
    if w.withSideInfo
        s = [w.sideInfo(u,:)'; w.sideInfo(v,:)'];
        scores = scores + sW' * s;
    end
    % Softmax, shifted by the max score so exp cannot overflow to Inf
    % (mathematically identical to exp(scores) / sum(exp(scores))).
    p = exp(scores - max(scores));
    probabilities(u, v, :) = p / sum(p);
end

% Pairs that were never scored are still all-zero here, so this division
% turns them into NaN (0/0); scored pairs already sum to one and are
% essentially unchanged.
probabilities = bsxfun(@rdivide, probabilities, sum(probabilities, 3));
% Expected label value: sum over y of y * Pr[y | u, v; w].
predictions = sum(bsxfun(@times, reshape(1:Y, [1 1 Y]), probabilities), 3);
% Most likely label (ties resolve to the smallest label index).
[~, argmaxPredictions] = max(probabilities, [], 3);