File tree 1 file changed +6
-5
lines changed
1 file changed +6
-5
lines changed Original file line number Diff line number Diff line change 2
2
import numpy as np
3
3
import harmonic as hm
4
4
import getdist
5
- from harmonic import model_legacy
6
- from getdist import plots
7
5
import matplotlib as plt
8
6
9
7
@@ -263,7 +261,7 @@ def cross_validation(
263
261
domains : List ,
264
262
hyper_parameters : List ,
265
263
nfold = 2 ,
266
- modelClass = model_legacy . KernelDensityEstimate ,
264
+ modelClass = None ,
267
265
seed : int = - 1 ,
268
266
) -> List :
269
267
"""Perform n-fold validation for given model using chains to be split into
@@ -285,8 +283,8 @@ def cross_validation(
285
283
hyper_parameters (List): List of hyper_parameters where each entry is a
286
284
hyper_parameter list to be considered.
287
285
288
- modelClass (Model): Model that is being cross validated (default =
289
- KernelDensityEstimate).
286
+ modelClass (Model): Model that is being cross validated (defaults to
287
+ KernelDensityEstimate inside function ).
290
288
291
289
seed (int): Seed for random number generator when drawing the chains
292
290
(if this is negative the seed is not set).
@@ -301,6 +299,9 @@ def cross_validation(
301
299
302
300
"""
303
301
302
+ if modelClass is None :
303
+ modelClass = hm .model_legacy .KernelDensityEstimate
304
+
304
305
ln_validation_variances = np .zeros ((nfold , len (hyper_parameters )))
305
306
306
307
if seed > 0 :
You can’t perform that action at this time.
0 commit comments