Skip to content

Commit d47e5d0

Browse files
committed
added general_search_v2
1 parent c787b91 commit d47e5d0

File tree

1 file changed

+118
-2
lines changed

1 file changed

+118
-2
lines changed

berliner/utils/machine.py

+118-2
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def general_search(params, sed_mod, lnprior,
373373
test_sed_obs, test_sed_obs_err=None,
374374
test_vpi_obs=None, test_vpi_obs_err=None,
375375
Lvpi=1.0, Lprior=1.0, sed_err_typical=0.1, cost_order=2,
376-
av_llim=0., return_est=False):
376+
av_llim=0., debug=False):
377377
"""
378378
when p = [T, G, Av, DM],
379379
given a set of SED,
@@ -440,7 +440,7 @@ def general_search(params, sed_mod, lnprior,
440440
lnprob[av_est < av_llim] = -np.inf
441441
lnprob -= np.nanmax(lnprob)
442442

443-
if return_est:
443+
if debug:
444444
return params, av_est, dm_est, cost_sed, lnprob
445445

446446
# normalization
@@ -477,3 +477,119 @@ def general_search(params, sed_mod, lnprior,
477477
ind_mle=ind_mle,
478478
n_good=np.sum(ind_good_band)
479479
)
480+
481+
482+
def general_search_v2(params, sed_mod, lnprior,
483+
Alambda,
484+
sed_obs, sed_obs_err=0.1,
485+
vpi_obs=None, vpi_obs_err=None,
486+
Lvpi=1.0, Lprior=1.0, sed_err_typical=0.1, cost_order=2,
487+
av_llim=0., debug=False):
488+
"""
489+
when p = [teff, logg, [M/H], Av, DM], theta = [teff, logg, [M/H]],
490+
given a set of SED,
491+
find the best theta and estimate the corresponding Av and DM
492+
"""
493+
494+
n_band = len(sed_obs)
495+
n_mod = sed_mod.shape[0]
496+
497+
# cope with scalar sed_obs_err
498+
if isinstance(sed_obs_err, np.float):
499+
sed_obs_err = np.ones_like(sed_obs, np.float) * sed_obs_err
500+
501+
# select good bands
502+
ind_good_band = np.isfinite(sed_obs) & (sed_obs_err > 0)
503+
n_good_band = np.sum(ind_good_band)
504+
if n_good_band < 4:
505+
# n_good_band = 3: unique solution
506+
# so n_good_band should be at least 4
507+
return [np.ones((4,), ) * np.nan for i in range(3)]
508+
509+
# use a subset of bands
510+
sed_mod_select = sed_mod[:, ind_good_band]
511+
# observed SED
512+
sed_obs_select = sed_obs[ind_good_band]
513+
sed_obs_err_select = sed_obs_err[ind_good_band]
514+
# extinction coefs
515+
Alambda_select = Alambda[ind_good_band]
516+
517+
# WLS to guess Av and DM
518+
av_est, dm_est = guess_avdm_wls(
519+
sed_mod_select, sed_obs_select, sed_obs_err_select, Alambda_select)
520+
521+
# cost(SED)
522+
res_sed = sed_mod_select + av_est.reshape(-1, 1) * Alambda_select \
523+
+ dm_est.reshape(-1, 1) - sed_obs_select
524+
lnprob_sed = -0.5 * np.nansum(
525+
np.abs(res_sed / sed_obs_err) ** cost_order, axis=1)
526+
527+
# cost(VPI)
528+
if vpi_obs is not None and vpi_obs_err is not None and Lvpi > 0:
529+
vpi_mod = 10 ** (2 - 0.2 * dm_est)
530+
lnprob_vpi = -0.5 * ((vpi_mod - vpi_obs) / vpi_obs_err) ** 2.
531+
else:
532+
lnprob_vpi = np.zeros((n_mod,), np.float)
533+
lnprob_vpi = np.where(np.isfinite(lnprob_vpi), lnprob_vpi, 0) * Lvpi
534+
535+
# lnprob = cost(SED) + cost(VPI) + prior
536+
if Lprior > 0:
537+
lnprob_prior = lnprior * Lprior
538+
539+
# posterior probability
540+
lnpost = lnprob_sed + lnprob_vpi + lnprob_prior
541+
# eliminate neg Av
542+
lnpost[av_est < av_llim] = -np.inf
543+
lnpost -= np.nanmax(lnpost)
544+
545+
# for debugging the code
546+
if debug:
547+
return dict(params=params,
548+
av_est=av_est,
549+
dm_est=dm_est,
550+
lnprob_sed=lnprob_sed,
551+
lnprob_vpi=lnprob_vpi,
552+
lnprior=lnprior)
553+
554+
# normalization
555+
post = np.exp(lnpost)
556+
L0 = np.sum(post)
557+
558+
# weighted mean
559+
# ind_mle = np.argmax(lnpost)
560+
# av_mle = av_est[ind_mle]
561+
# dm_mle = dm_est[ind_mle]
562+
# p_mle = params[ind_mle]
563+
564+
L1_av = np.sum(av_est * post)
565+
L1_dm = np.sum(dm_est * post)
566+
L1_p = np.sum(params * post.reshape(-1, 1), axis=0)
567+
568+
L2_av = np.sum(av_est ** 2 * post)
569+
L2_dm = np.sum(dm_est ** 2 * post)
570+
L2_p = np.sum(params ** 2 * post.reshape(-1, 1), axis=0)
571+
572+
sigma_av = np.sqrt(L2_av / L0 - L1_av ** 2 / L0 ** 2)
573+
sigma_dm = np.sqrt(L2_dm / L0 - L1_dm ** 2 / L0 ** 2)
574+
sigma_p = np.sqrt(L2_p / L0 - L1_p ** 2 / L0 ** 2)
575+
576+
# MLE model
577+
ind_mle = np.argmax(lnprob_sed + lnprob_vpi)
578+
av_mle = av_est[ind_mle]
579+
dm_mle = dm_est[ind_mle]
580+
p_mle = params[ind_mle]
581+
582+
p_mle = np.hstack([p_mle, av_mle, dm_mle])
583+
p_mean = np.hstack([L1_p/L0, L1_av/L0, L1_dm/L0])
584+
p_err = np.hstack([sigma_p, sigma_av, sigma_dm])
585+
586+
rms_sed_mle = np.sqrt(np.nanmean(res_sed[ind_mle] ** 2.))
587+
rms_sed_min = np.min(np.sqrt(np.nanmean(res_sed ** 2., axis=1)))
588+
589+
return dict(p_mle=p_mle,
590+
p_mean=p_mean,
591+
p_err=p_err,
592+
rmsmle=rms_sed_mle,
593+
rmsmin=rms_sed_min,
594+
ind_mle=np.ind,
595+
n_good=np.sum(ind_good_band))

0 commit comments

Comments
 (0)