Skip to content

Commit 945eaf5

Browse files
bringing correction
1 parent a1f4a05 commit 945eaf5

4 files changed

Lines changed: 238 additions & 181 deletions

File tree

R/EB_poisson_mean_routines.R

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
#copied from https://github.com/DongyueXie/vebpm/
23

34
#'@title Solve Gaussian approximation to Poisson mean problem
@@ -38,17 +39,17 @@ ebpm_normal = function(x,
3839
vga_tol=1e-5,
3940
conv_type='sigma2abs',
4041
return_sigma2_trace=FALSE){
41-
42+
4243
# init the posterior mean and variance?
4344
n = length(x)
44-
45+
4546
if(is.null(s)){
4647
s = 1
4748
}
4849
if(length(s)==1){
4950
s = rep(s,n)
5051
}
51-
52+
5253
if(is.null(q_init)){
5354
m = log(x/s+1)
5455
v = rep(1/n,n)
@@ -67,19 +68,19 @@ ebpm_normal = function(x,
6768
if(length(v)==1){
6869
v = rep(v,n)
6970
}
70-
71+
7172
const = sum((x-1)*log(s)) - sum(lfactorial(x))
7273
#
7374
t_start = Sys.time()
74-
75+
7576
if(is.null(g_init)){
7677
prior_mean = NULL
7778
prior_var = NULL
7879
}else{
7980
prior_mean = g_init$mean
8081
prior_var = g_init$var
8182
}
82-
83+
8384
if(length(fix_g)==1){
8485
est_prior_mean = !fix_g
8586
est_prior_var = !fix_g
@@ -89,10 +90,10 @@ ebpm_normal = function(x,
8990
}else{
9091
stop('fix_g can be either length 1 or 2')
9192
}
92-
93+
9394
sigma2_trace = prior_var
9495
if(est_prior_mean | est_prior_var){
95-
96+
9697
if(is.null(prior_mean)){
9798
est_prior_mean = TRUE
9899
beta = mean(m)
@@ -105,7 +106,7 @@ ebpm_normal = function(x,
105106
}else{
106107
sigma2=prior_var
107108
}
108-
109+
109110
obj = rep(0,maxiter+1)
110111
obj[1] = -Inf
111112
sigma2_trace = sigma2
@@ -114,7 +115,7 @@ ebpm_normal = function(x,
114115
m = vga_pois_solver(m,x,s,beta,sigma2,tol=vga_tol)
115116
v = m$v
116117
m = m$m
117-
118+
118119
if(est_prior_mean){
119120
beta = mean(m)
120121
}
@@ -124,7 +125,7 @@ ebpm_normal = function(x,
124125
sigma2_trace[iter+1] = sigma2
125126
}
126127
}
127-
128+
128129
if(conv_type=='elbo'){
129130
obj[iter+1] = ebpm_normal_obj(x,s,beta,sigma2,m,v,const)
130131
if((obj[iter+1] - obj[iter])/n <tol){
@@ -142,20 +143,20 @@ ebpm_normal = function(x,
142143
break
143144
}
144145
}
145-
146+
146147
}
147-
148+
148149
}else{
149150
beta = prior_mean
150151
sigma2 = prior_var
151152
m = vga_pois_solver(m,x,s,beta,sigma2,tol=vga_tol)
152153
v = m$v
153154
m = m$m
154155
obj = ebpm_normal_obj(x,s,prior_mean,prior_var,m,v,const)
155-
156+
156157
}
157158
t_end = Sys.time()
158-
159+
159160
return(list(posterior = list(mean_log = m,
160161
var_log = v,
161162
mean = exp(m + v/2)),
@@ -164,14 +165,14 @@ ebpm_normal = function(x,
164165
obj_trace = obj,
165166
sigma2_trace=sigma2_trace,
166167
run_time = difftime(t_end,t_start,units='secs')))
167-
168+
168169
}
169170

170171

171172
ebpm_normal_obj = function(x,s,beta,sigma2,m,v,const){
172173
return(sum(x*m-s*exp(m+v/2)-log(sigma2)/2-(m^2+v-2*m*beta+beta^2)/2/sigma2+log(v)/2)+const)
173174
}
174-
175+
175176

176177

177178

@@ -185,7 +186,7 @@ ebpm_normal_obj = function(x,s,beta,sigma2,m,v,const){
185186
#'@param beta,sigma2 prior mean and variance. Their length should be equal to n=length(x)
186187
#'@export
187188
vga_pois_solver = function(init_val,x,s,beta,sigma2,maxiter=1000,tol=1e-5,method = 'newton'){
188-
189+
189190
n = length(x)
190191
if(length(sigma2)==1){
191192
sigma2 = rep(sigma2,n)
@@ -222,8 +223,8 @@ vga_pois_solver = function(init_val,x,s,beta,sigma2,maxiter=1000,tol=1e-5,method
222223
}else{
223224
stop('Only Newton and bisection are supported.')
224225
}
225-
226-
226+
227+
227228
}
228229

229230
#'@title Optimize vga poisson problem 1 iteration.
@@ -291,22 +292,22 @@ vga_pois_solver_bisection = function(x,s,beta,sigma2,maxiter=1000,tol=1e-5){
291292

292293
#'@export
293294
vga_pois_solver_Newton = function(m,x,s,beta,sigma2,maxiter=1000,tol=1e-5){
294-
295+
295296
const0 = sigma2*x+beta + 1
296297
const1 = 1/sigma2
297298
const2 = sigma2/2
298299
const3 = beta/sigma2
299-
300+
300301
# make sure m < sigma2*x+beta
301302
m = pmin(m,const0-1)
302303
# idx = (m>(const0-1))
303304
# if(sum(idx)>0){
304305
# m[idx] =suppressWarnings(vga_pois_solver_bisection(x[idx],s[idx],beta[idx],sigma2[idx],maxiter = 10)$m)
305306
# }
306-
307-
307+
308+
308309
for(i in 1:maxiter){
309-
310+
310311
temp = (const0-m)
311312
sexp = s*exp(m+const2/temp)
312313
# f = x - sexp - (m-beta)/sigma2
@@ -321,9 +322,9 @@ vga_pois_solver_Newton = function(m,x,s,beta,sigma2,maxiter=1000,tol=1e-5){
321322
warnings('Newton method not converged yet.')
322323
}
323324
return(list(m=m,v=sigma2/temp))
324-
325+
325326
}
326-
327+
327328

328329

329330
#'@title Solve Gaussian approximation to Poisson mean problem
@@ -346,7 +347,7 @@ vga_pois_solver_Newton = function(m,x,s,beta,sigma2,maxiter=1000,tol=1e-5){
346347
#' n = 10000
347348
#' mu = rnorm(n)
348349
#' x = rpois(n,exp(mu))
349-
#' pois_mean_GG(x)
350+
#' pois_mean_GP(x)
350351
#'@details The problem is
351352
#'\deqn{x_i\sim Poisson(\exp(\mu_i)),}
352353
#'\deqn{\mu_i\sim N(\beta,\sigma^2).}
@@ -358,7 +359,7 @@ pois_mean_GP = function(x,
358359
optim_method = 'L-BFGS-B',
359360
maxiter = 1000,
360361
tol = 1e-5){
361-
362+
362363
# init the posterior mean and variance?
363364
n = length(x)
364365
m = log(x+0.1)
@@ -371,7 +372,7 @@ pois_mean_GP = function(x,
371372
}
372373
#
373374
if(is.null(prior_mean) | is.null(prior_var)){
374-
375+
375376
if(is.null(prior_mean)){
376377
est_beta = TRUE
377378
}else{
@@ -384,7 +385,7 @@ pois_mean_GP = function(x,
384385
est_sigma2 = FALSE
385386
sigma2=prior_var
386387
}
387-
388+
388389
obj = rep(0,maxiter+1)
389390
obj[1] = -Inf
390391
for(iter in 1:maxiter){
@@ -395,7 +396,7 @@ pois_mean_GP = function(x,
395396
sigma2 = mean(m^2+v-2*m*beta+beta^2)
396397
}
397398
# for(i in 1:n){
398-
# temp = pois_mean_GG1(x[i],s[i],beta,sigma2,optim_method,m[i],v[i])
399+
# temp = pois_mean_GP1(x[i],s[i],beta,sigma2,optim_method,m[i],v[i])
399400
# m[i] = temp$m
400401
# v[i] = temp$v
401402
# }
@@ -410,18 +411,18 @@ pois_mean_GP = function(x,
410411
method = optim_method)
411412
m = opt$par[1:n]
412413
v = exp(opt$par[(n+1):(2*n)])
413-
obj[iter+1] = pois_mean_GG_obj(x,s,beta,sigma2,m,v)
414+
obj[iter+1] = pois_mean_GP_obj(x,s,beta,sigma2,m,v)
414415
if((obj[iter+1] - obj[iter])<tol){
415416
obj = obj[1:(iter+1)]
416417
break
417418
}
418419
}
419-
420+
420421
}else{
421422
beta = prior_mean
422423
sigma2 = prior_var
423424
# for(i in 1:n){
424-
# temp = pois_mean_GG1(x[i],s[i],prior_mean,prior_var,optim_method,m[i],v[i])
425+
# temp = pois_mean_GP1(x[i],s[i],prior_mean,prior_var,optim_method,m[i],v[i])
425426
# m[i] = temp$m
426427
# v[i] = temp$v
427428
# }
@@ -436,18 +437,18 @@ pois_mean_GP = function(x,
436437
method = optim_method)
437438
m = opt$par[1:n]
438439
v = exp(opt$par[(n+1):(2*n)])
439-
obj = pois_mean_GG_obj(x,s,prior_mean,prior_var,m,v)
440-
440+
obj = pois_mean_GP_obj(x,s,prior_mean,prior_var,m,v)
441+
441442
}
442-
443+
443444
return(list(posterior = list(posteriorMean_latent = m,
444445
posteriorVar_latent = v,
445446
posteriorMean_mean = exp(m + v/2)),
446447
fitted_g = list(mean = beta, var=sigma2),
447448
obj_value=obj))
448-
449+
449450
#return(list(posteriorMean=m,priorMean=beta,priorVar=sigma2,posteriorVar=v,obj_value=obj))
450-
451+
451452
}
452453
#'calculate objective function
453454
pois_mean_GP_opt_obj = function(theta,x,s,beta,sigma2,n){
@@ -465,6 +466,6 @@ pois_mean_GP_opt_obj_gradient = function(theta,x,s,beta,sigma2,n){
465466
}
466467

467468

468-
pois_mean_GG_obj = function(x,s,beta,sigma2,m,v){
469+
pois_mean_GP_obj = function(x,s,beta,sigma2,m,v){
469470
return(sum(x*m-s*exp(m+v/2)-log(sigma2)/2-(m^2+v-2*m*beta+beta^2)/2/sigma2+log(v)/2))
470471
}

0 commit comments

Comments
 (0)