@@ -502,9 +502,13 @@ <h1>Source code for botorch.acquisition.logei</h1><div class="highlight"><pre>
502
502
< span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _cache_root</ span > < span class ="p "> :</ span >
503
503
< span class ="n "> samples_full</ span > < span class ="o "> =</ span > < span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="n "> get_posterior_samples</ span > < span class ="p "> (</ span > < span class ="n "> posterior</ span > < span class ="p "> )</ span >
504
504
< span class ="n "> obj_full</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> objective</ span > < span class ="p "> (</ span > < span class ="n "> samples_full</ span > < span class ="p "> ,</ span > < span class ="n "> X</ span > < span class ="o "> =</ span > < span class ="n "> X_full</ span > < span class ="p "> )</ span >
505
+ < span class ="c1 "> # Calculate the positive index for splitting the samples & objective values.</ span >
506
+ < span class ="n "> split_dim</ span > < span class ="o "> =</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> obj_full</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> )</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span >
505
507
< span class ="c1 "> # assigning baseline buffers so `best_f` can be computed in _sample_forward</ span >
506
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> baseline_samples</ span > < span class ="p "> ,</ span > < span class ="n "> samples</ span > < span class ="o "> =</ span > < span class ="n "> samples_full</ span > < span class ="o "> .</ span > < span class ="n "> split</ span > < span class ="p "> ([</ span > < span class ="n "> n_baseline</ span > < span class ="p "> ,</ span > < span class ="n "> q</ span > < span class ="p "> ],</ span > < span class ="n "> dim</ span > < span class ="o "> =-</ span > < span class ="mi "> 2</ span > < span class ="p "> )</ span >
507
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> baseline_obj</ span > < span class ="p "> ,</ span > < span class ="n "> obj</ span > < span class ="o "> =</ span > < span class ="n "> obj_full</ span > < span class ="o "> .</ span > < span class ="n "> split</ span > < span class ="p "> ([</ span > < span class ="n "> n_baseline</ span > < span class ="p "> ,</ span > < span class ="n "> q</ span > < span class ="p "> ],</ span > < span class ="n "> dim</ span > < span class ="o "> =-</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
508
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> baseline_samples</ span > < span class ="p "> ,</ span > < span class ="n "> samples</ span > < span class ="o "> =</ span > < span class ="n "> samples_full</ span > < span class ="o "> .</ span > < span class ="n "> split</ span > < span class ="p "> (</ span >
509
+ < span class ="p "> [</ span > < span class ="n "> n_baseline</ span > < span class ="p "> ,</ span > < span class ="n "> q</ span > < span class ="p "> ],</ span > < span class ="n "> dim</ span > < span class ="o "> =</ span > < span class ="n "> split_dim</ span >
510
+ < span class ="p "> )</ span >
511
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> baseline_obj</ span > < span class ="p "> ,</ span > < span class ="n "> obj</ span > < span class ="o "> =</ span > < span class ="n "> obj_full</ span > < span class ="o "> .</ span > < span class ="n "> split</ span > < span class ="p "> ([</ span > < span class ="n "> n_baseline</ span > < span class ="p "> ,</ span > < span class ="n "> q</ span > < span class ="p "> ],</ span > < span class ="n "> dim</ span > < span class ="o "> =</ span > < span class ="n "> split_dim</ span > < span class ="p "> )</ span >
508
512
< span class ="k "> return</ span > < span class ="n "> samples</ span > < span class ="p "> ,</ span > < span class ="n "> obj</ span >
509
513
510
514
< span class ="c1 "> # handle one-to-many input transforms</ span >
0 commit comments