@@ -373,7 +373,7 @@ def general_search(params, sed_mod, lnprior,
373
373
test_sed_obs , test_sed_obs_err = None ,
374
374
test_vpi_obs = None , test_vpi_obs_err = None ,
375
375
Lvpi = 1.0 , Lprior = 1.0 , sed_err_typical = 0.1 , cost_order = 2 ,
376
- av_llim = 0. , return_est = False ):
376
+ av_llim = 0. , debug = False ):
377
377
"""
378
378
when p = [T, G, Av, DM],
379
379
given a set of SED,
@@ -440,7 +440,7 @@ def general_search(params, sed_mod, lnprior,
440
440
lnprob [av_est < av_llim ] = - np .inf
441
441
lnprob -= np .nanmax (lnprob )
442
442
443
- if return_est :
443
+ if debug :
444
444
return params , av_est , dm_est , cost_sed , lnprob
445
445
446
446
# normalization
@@ -477,3 +477,119 @@ def general_search(params, sed_mod, lnprior,
477
477
ind_mle = ind_mle ,
478
478
n_good = np .sum (ind_good_band )
479
479
)
480
+
481
+
482
+ def general_search_v2 (params , sed_mod , lnprior ,
483
+ Alambda ,
484
+ sed_obs , sed_obs_err = 0.1 ,
485
+ vpi_obs = None , vpi_obs_err = None ,
486
+ Lvpi = 1.0 , Lprior = 1.0 , sed_err_typical = 0.1 , cost_order = 2 ,
487
+ av_llim = 0. , debug = False ):
488
+ """
489
+ when p = [teff, logg, [M/H], Av, DM], theta = [teff, logg, [M/H]],
490
+ given a set of SED,
491
+ find the best theta and estimate the corresponding Av and DM
492
+ """
493
+
494
+ n_band = len (sed_obs )
495
+ n_mod = sed_mod .shape [0 ]
496
+
497
+ # cope with scalar sed_obs_err
498
+ if isinstance (sed_obs_err , np .float ):
499
+ sed_obs_err = np .ones_like (sed_obs , np .float ) * sed_obs_err
500
+
501
+ # select good bands
502
+ ind_good_band = np .isfinite (sed_obs ) & (sed_obs_err > 0 )
503
+ n_good_band = np .sum (ind_good_band )
504
+ if n_good_band < 4 :
505
+ # n_good_band = 3: unique solution
506
+ # so n_good_band should be at least 4
507
+ return [np .ones ((4 ,), ) * np .nan for i in range (3 )]
508
+
509
+ # use a subset of bands
510
+ sed_mod_select = sed_mod [:, ind_good_band ]
511
+ # observed SED
512
+ sed_obs_select = sed_obs [ind_good_band ]
513
+ sed_obs_err_select = sed_obs_err [ind_good_band ]
514
+ # extinction coefs
515
+ Alambda_select = Alambda [ind_good_band ]
516
+
517
+ # WLS to guess Av and DM
518
+ av_est , dm_est = guess_avdm_wls (
519
+ sed_mod_select , sed_obs_select , sed_obs_err_select , Alambda_select )
520
+
521
+ # cost(SED)
522
+ res_sed = sed_mod_select + av_est .reshape (- 1 , 1 ) * Alambda_select \
523
+ + dm_est .reshape (- 1 , 1 ) - sed_obs_select
524
+ lnprob_sed = - 0.5 * np .nansum (
525
+ np .abs (res_sed / sed_obs_err ) ** cost_order , axis = 1 )
526
+
527
+ # cost(VPI)
528
+ if vpi_obs is not None and vpi_obs_err is not None and Lvpi > 0 :
529
+ vpi_mod = 10 ** (2 - 0.2 * dm_est )
530
+ lnprob_vpi = - 0.5 * ((vpi_mod - vpi_obs ) / vpi_obs_err ) ** 2.
531
+ else :
532
+ lnprob_vpi = np .zeros ((n_mod ,), np .float )
533
+ lnprob_vpi = np .where (np .isfinite (lnprob_vpi ), lnprob_vpi , 0 ) * Lvpi
534
+
535
+ # lnprob = cost(SED) + cost(VPI) + prior
536
+ if Lprior > 0 :
537
+ lnprob_prior = lnprior * Lprior
538
+
539
+ # posterior probability
540
+ lnpost = lnprob_sed + lnprob_vpi + lnprob_prior
541
+ # eliminate neg Av
542
+ lnpost [av_est < av_llim ] = - np .inf
543
+ lnpost -= np .nanmax (lnpost )
544
+
545
+ # for debugging the code
546
+ if debug :
547
+ return dict (params = params ,
548
+ av_est = av_est ,
549
+ dm_est = dm_est ,
550
+ lnprob_sed = lnprob_sed ,
551
+ lnprob_vpi = lnprob_vpi ,
552
+ lnprior = lnprior )
553
+
554
+ # normalization
555
+ post = np .exp (lnpost )
556
+ L0 = np .sum (post )
557
+
558
+ # weighted mean
559
+ # ind_mle = np.argmax(lnpost)
560
+ # av_mle = av_est[ind_mle]
561
+ # dm_mle = dm_est[ind_mle]
562
+ # p_mle = params[ind_mle]
563
+
564
+ L1_av = np .sum (av_est * post )
565
+ L1_dm = np .sum (dm_est * post )
566
+ L1_p = np .sum (params * post .reshape (- 1 , 1 ), axis = 0 )
567
+
568
+ L2_av = np .sum (av_est ** 2 * post )
569
+ L2_dm = np .sum (dm_est ** 2 * post )
570
+ L2_p = np .sum (params ** 2 * post .reshape (- 1 , 1 ), axis = 0 )
571
+
572
+ sigma_av = np .sqrt (L2_av / L0 - L1_av ** 2 / L0 ** 2 )
573
+ sigma_dm = np .sqrt (L2_dm / L0 - L1_dm ** 2 / L0 ** 2 )
574
+ sigma_p = np .sqrt (L2_p / L0 - L1_p ** 2 / L0 ** 2 )
575
+
576
+ # MLE model
577
+ ind_mle = np .argmax (lnprob_sed + lnprob_vpi )
578
+ av_mle = av_est [ind_mle ]
579
+ dm_mle = dm_est [ind_mle ]
580
+ p_mle = params [ind_mle ]
581
+
582
+ p_mle = np .hstack ([p_mle , av_mle , dm_mle ])
583
+ p_mean = np .hstack ([L1_p / L0 , L1_av / L0 , L1_dm / L0 ])
584
+ p_err = np .hstack ([sigma_p , sigma_av , sigma_dm ])
585
+
586
+ rms_sed_mle = np .sqrt (np .nanmean (res_sed [ind_mle ] ** 2. ))
587
+ rms_sed_min = np .min (np .sqrt (np .nanmean (res_sed ** 2. , axis = 1 )))
588
+
589
+ return dict (p_mle = p_mle ,
590
+ p_mean = p_mean ,
591
+ p_err = p_err ,
592
+ rmsmle = rms_sed_mle ,
593
+ rmsmin = rms_sed_min ,
594
+ ind_mle = np .ind ,
595
+ n_good = np .sum (ind_good_band ))
0 commit comments