Skip to content

Commit 5191dc8

Browse files
Move model legacy import to inside function.
1 parent 86e4f1f commit 5191dc8

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

harmonic/utils.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import numpy as np
33
import harmonic as hm
44
import getdist
5-
from harmonic import model_legacy
6-
from getdist import plots
75
import matplotlib as plt
86

97

@@ -263,7 +261,7 @@ def cross_validation(
263261
domains: List,
264262
hyper_parameters: List,
265263
nfold=2,
266-
modelClass=model_legacy.KernelDensityEstimate,
264+
modelClass=None,
267265
seed: int = -1,
268266
) -> List:
269267
"""Perform n-fold validation for given model using chains to be split into
@@ -285,8 +283,8 @@ def cross_validation(
285283
hyper_parameters (List): List of hyper_parameters where each entry is a
286284
hyper_parameter list to be considered.
287285
288-
modelClass (Model): Model that is being cross validated (default =
289-
KernelDensityEstimate).
286+
modelClass (Model): Model that is being cross validated (defaults to
287+
KernelDensityEstimate inside function).
290288
291289
seed (int): Seed for random number generator when drawing the chains
292290
(if this is negative the seed is not set).
@@ -301,6 +299,9 @@ def cross_validation(
301299
302300
"""
303301

302+
if modelClass is None:
303+
modelClass = hm.model_legacy.KernelDensityEstimate
304+
304305
ln_validation_variances = np.zeros((nfold, len(hyper_parameters)))
305306

306307
if seed > 0:

0 commit comments

Comments
 (0)