1+
12# copied from https://github.com/DongyueXie/vebpm/
23
34# '@title Solve Gaussian approximation to Poisson mean problem
@@ -38,17 +39,17 @@ ebpm_normal = function(x,
3839 vga_tol = 1e-5 ,
3940 conv_type = ' sigma2abs' ,
4041 return_sigma2_trace = FALSE ){
41-
42+
4243 # init the posterior mean and variance?
4344 n = length(x )
44-
45+
4546 if (is.null(s )){
4647 s = 1
4748 }
4849 if (length(s )== 1 ){
4950 s = rep(s ,n )
5051 }
51-
52+
5253 if (is.null(q_init )){
5354 m = log(x / s + 1 )
5455 v = rep(1 / n ,n )
@@ -67,19 +68,19 @@ ebpm_normal = function(x,
6768 if (length(v )== 1 ){
6869 v = rep(v ,n )
6970 }
70-
71+
7172 const = sum((x - 1 )* log(s )) - sum(lfactorial(x ))
7273 #
7374 t_start = Sys.time()
74-
75+
7576 if (is.null(g_init )){
7677 prior_mean = NULL
7778 prior_var = NULL
7879 }else {
7980 prior_mean = g_init $ mean
8081 prior_var = g_init $ var
8182 }
82-
83+
8384 if (length(fix_g )== 1 ){
8485 est_prior_mean = ! fix_g
8586 est_prior_var = ! fix_g
@@ -89,10 +90,10 @@ ebpm_normal = function(x,
8990 }else {
9091 stop(' fix_g can be either length 1 or 2' )
9192 }
92-
93+
9394 sigma2_trace = prior_var
9495 if (est_prior_mean | est_prior_var ){
95-
96+
9697 if (is.null(prior_mean )){
9798 est_prior_mean = TRUE
9899 beta = mean(m )
@@ -105,7 +106,7 @@ ebpm_normal = function(x,
105106 }else {
106107 sigma2 = prior_var
107108 }
108-
109+
109110 obj = rep(0 ,maxiter + 1 )
110111 obj [1 ] = - Inf
111112 sigma2_trace = sigma2
@@ -114,7 +115,7 @@ ebpm_normal = function(x,
114115 m = vga_pois_solver(m ,x ,s ,beta ,sigma2 ,tol = vga_tol )
115116 v = m $ v
116117 m = m $ m
117-
118+
118119 if (est_prior_mean ){
119120 beta = mean(m )
120121 }
@@ -124,7 +125,7 @@ ebpm_normal = function(x,
124125 sigma2_trace [iter + 1 ] = sigma2
125126 }
126127 }
127-
128+
128129 if (conv_type == ' elbo' ){
129130 obj [iter + 1 ] = ebpm_normal_obj(x ,s ,beta ,sigma2 ,m ,v ,const )
130131 if ((obj [iter + 1 ] - obj [iter ])/ n < tol ){
@@ -142,20 +143,20 @@ ebpm_normal = function(x,
142143 break
143144 }
144145 }
145-
146+
146147 }
147-
148+
148149 }else {
149150 beta = prior_mean
150151 sigma2 = prior_var
151152 m = vga_pois_solver(m ,x ,s ,beta ,sigma2 ,tol = vga_tol )
152153 v = m $ v
153154 m = m $ m
154155 obj = ebpm_normal_obj(x ,s ,prior_mean ,prior_var ,m ,v ,const )
155-
156+
156157 }
157158 t_end = Sys.time()
158-
159+
159160 return (list (posterior = list (mean_log = m ,
160161 var_log = v ,
161162 mean = exp(m + v / 2 )),
@@ -164,14 +165,14 @@ ebpm_normal = function(x,
164165 obj_trace = obj ,
165166 sigma2_trace = sigma2_trace ,
166167 run_time = difftime(t_end ,t_start ,units = ' secs' )))
167-
168+
168169}
169170
170171
171172ebpm_normal_obj = function (x ,s ,beta ,sigma2 ,m ,v ,const ){
172173 return (sum(x * m - s * exp(m + v / 2 )- log(sigma2 )/ 2 - (m ^ 2 + v - 2 * m * beta + beta ^ 2 )/ 2 / sigma2 + log(v )/ 2 )+ const )
173174}
174-
175+
175176
176177
177178
@@ -185,7 +186,7 @@ ebpm_normal_obj = function(x,s,beta,sigma2,m,v,const){
185186# '@param beta,sigma2 prior mean and variance. Their length should be equal to n=length(x)
186187# '@export
187188vga_pois_solver = function (init_val ,x ,s ,beta ,sigma2 ,maxiter = 1000 ,tol = 1e-5 ,method = ' newton' ){
188-
189+
189190 n = length(x )
190191 if (length(sigma2 )== 1 ){
191192 sigma2 = rep(sigma2 ,n )
@@ -222,8 +223,8 @@ vga_pois_solver = function(init_val,x,s,beta,sigma2,maxiter=1000,tol=1e-5,method
222223 }else {
223224 stop(' Only Newton and bisection are supported.' )
224225 }
225-
226-
226+
227+
227228}
228229
229230# '@title Optimize vga poisson problem 1 iteration.
@@ -291,22 +292,22 @@ vga_pois_solver_bisection = function(x,s,beta,sigma2,maxiter=1000,tol=1e-5){
291292
292293# '@export
293294vga_pois_solver_Newton = function (m ,x ,s ,beta ,sigma2 ,maxiter = 1000 ,tol = 1e-5 ){
294-
295+
295296 const0 = sigma2 * x + beta + 1
296297 const1 = 1 / sigma2
297298 const2 = sigma2 / 2
298299 const3 = beta / sigma2
299-
300+
300301 # make sure m < sigma2*x+beta
301302 m = pmin(m ,const0 - 1 )
302303 # idx = (m>(const0-1))
303304 # if(sum(idx)>0){
304305 # m[idx] =suppressWarnings(vga_pois_solver_bisection(x[idx],s[idx],beta[idx],sigma2[idx],maxiter = 10)$m)
305306 # }
306-
307-
307+
308+
308309 for (i in 1 : maxiter ){
309-
310+
310311 temp = (const0 - m )
311312 sexp = s * exp(m + const2 / temp )
312313 # f = x - sexp - (m-beta)/sigma2
@@ -321,9 +322,9 @@ vga_pois_solver_Newton = function(m,x,s,beta,sigma2,maxiter=1000,tol=1e-5){
321322 warnings(' Newton method not converged yet.' )
322323 }
323324 return (list (m = m ,v = sigma2 / temp ))
324-
325+
325326}
326-
327+
327328
328329
329330# '@title Solve Gaussian approximation to Poisson mean problem
@@ -346,7 +347,7 @@ vga_pois_solver_Newton = function(m,x,s,beta,sigma2,maxiter=1000,tol=1e-5){
346347# ' n = 10000
347348# ' mu = rnorm(n)
348349# ' x = rpois(n,exp(mu))
349- # ' pois_mean_GG (x)
350+ # ' pois_mean_GP (x)
350351# '@details The problem is
351352# '\deqn{x_i\sim Poisson(\exp(\mu_i)),}
352353# '\deqn{\mu_i\sim N(\beta,\sigma^2).}
@@ -358,7 +359,7 @@ pois_mean_GP = function(x,
358359 optim_method = ' L-BFGS-B' ,
359360 maxiter = 1000 ,
360361 tol = 1e-5 ){
361-
362+
362363 # init the posterior mean and variance?
363364 n = length(x )
364365 m = log(x + 0.1 )
@@ -371,7 +372,7 @@ pois_mean_GP = function(x,
371372 }
372373 #
373374 if (is.null(prior_mean ) | is.null(prior_var )){
374-
375+
375376 if (is.null(prior_mean )){
376377 est_beta = TRUE
377378 }else {
@@ -384,7 +385,7 @@ pois_mean_GP = function(x,
384385 est_sigma2 = FALSE
385386 sigma2 = prior_var
386387 }
387-
388+
388389 obj = rep(0 ,maxiter + 1 )
389390 obj [1 ] = - Inf
390391 for (iter in 1 : maxiter ){
@@ -395,7 +396,7 @@ pois_mean_GP = function(x,
395396 sigma2 = mean(m ^ 2 + v - 2 * m * beta + beta ^ 2 )
396397 }
397398 # for(i in 1:n){
398- # temp = pois_mean_GG1 (x[i],s[i],beta,sigma2,optim_method,m[i],v[i])
399+ # temp = pois_mean_GP1 (x[i],s[i],beta,sigma2,optim_method,m[i],v[i])
399400 # m[i] = temp$m
400401 # v[i] = temp$v
401402 # }
@@ -410,18 +411,18 @@ pois_mean_GP = function(x,
410411 method = optim_method )
411412 m = opt $ par [1 : n ]
412413 v = exp(opt $ par [(n + 1 ): (2 * n )])
413- obj [iter + 1 ] = pois_mean_GG_obj (x ,s ,beta ,sigma2 ,m ,v )
414+ obj [iter + 1 ] = pois_mean_GP_obj (x ,s ,beta ,sigma2 ,m ,v )
414415 if ((obj [iter + 1 ] - obj [iter ])< tol ){
415416 obj = obj [1 : (iter + 1 )]
416417 break
417418 }
418419 }
419-
420+
420421 }else {
421422 beta = prior_mean
422423 sigma2 = prior_var
423424 # for(i in 1:n){
424- # temp = pois_mean_GG1 (x[i],s[i],prior_mean,prior_var,optim_method,m[i],v[i])
425+ # temp = pois_mean_GP1 (x[i],s[i],prior_mean,prior_var,optim_method,m[i],v[i])
425426 # m[i] = temp$m
426427 # v[i] = temp$v
427428 # }
@@ -436,18 +437,18 @@ pois_mean_GP = function(x,
436437 method = optim_method )
437438 m = opt $ par [1 : n ]
438439 v = exp(opt $ par [(n + 1 ): (2 * n )])
439- obj = pois_mean_GG_obj (x ,s ,prior_mean ,prior_var ,m ,v )
440-
440+ obj = pois_mean_GP_obj (x ,s ,prior_mean ,prior_var ,m ,v )
441+
441442 }
442-
443+
443444 return (list (posterior = list (posteriorMean_latent = m ,
444445 posteriorVar_latent = v ,
445446 posteriorMean_mean = exp(m + v / 2 )),
446447 fitted_g = list (mean = beta , var = sigma2 ),
447448 obj_value = obj ))
448-
449+
449450 # return(list(posteriorMean=m,priorMean=beta,priorVar=sigma2,posteriorVar=v,obj_value=obj))
450-
451+
451452}
452453# 'calculate objective function
453454pois_mean_GP_opt_obj = function (theta ,x ,s ,beta ,sigma2 ,n ){
@@ -465,6 +466,6 @@ pois_mean_GP_opt_obj_gradient = function(theta,x,s,beta,sigma2,n){
465466}
466467
467468
468- pois_mean_GG_obj = function (x ,s ,beta ,sigma2 ,m ,v ){
469+ pois_mean_GP_obj = function (x ,s ,beta ,sigma2 ,m ,v ){
469470 return (sum(x * m - s * exp(m + v / 2 )- log(sigma2 )/ 2 - (m ^ 2 + v - 2 * m * beta + beta ^ 2 )/ 2 / sigma2 + log(v )/ 2 ))
470471}
0 commit comments