18
18
from typing import Any , Dict , List , Optional , Set , Tuple , Union
19
19
20
20
import numpy as np
21
- import ruamel .yaml as yaml
21
+ from ruamel .yaml import YAML
22
22
23
23
from skll .data .readers import safe_float
24
24
from skll .types import ClassMap , FoldMapping , LabelType , PathOrStr
@@ -610,7 +610,9 @@ def parse_config_file(
610
610
raise ValueError (
611
611
"Configuration file does not contain list of learners " "in [Input] section."
612
612
)
613
- learners = yaml .safe_load (fix_json (learners_string ))
613
+
614
+ yaml = YAML (typ = "safe" , pure = True )
615
+ learners = yaml .load (fix_json (learners_string ))
614
616
615
617
if len (learners ) == 0 :
616
618
raise ValueError (
@@ -630,7 +632,7 @@ def parse_config_file(
630
632
custom_metric_path = locate_file (config .get ("Input" , "custom_metric_path" ), config_dir )
631
633
632
634
# get the featuresets
633
- featuresets = yaml .safe_load (config .get ("Input" , "featuresets" ))
635
+ featuresets = yaml .load (config .get ("Input" , "featuresets" ))
634
636
635
637
# ensure that featuresets is either a list of features or a list of lists
636
638
# of features
@@ -641,7 +643,7 @@ def parse_config_file(
641
643
f"specified: { featuresets } "
642
644
)
643
645
644
- featureset_names = yaml .safe_load (fix_json (config .get ("Input" , "featureset_names" )))
646
+ featureset_names = yaml .load (fix_json (config .get ("Input" , "featureset_names" )))
645
647
646
648
# ensure that featureset_names is a list of strings, if specified
647
649
if featureset_names :
@@ -658,7 +660,7 @@ def parse_config_file(
658
660
# learners. If it's not specified, then we just assume
659
661
# that we are using 10 folds for each learner.
660
662
learning_curve_cv_folds_list_string = config .get ("Input" , "learning_curve_cv_folds_list" )
661
- learning_curve_cv_folds_list = yaml .safe_load (fix_json (learning_curve_cv_folds_list_string ))
663
+ learning_curve_cv_folds_list = yaml .load (fix_json (learning_curve_cv_folds_list_string ))
662
664
if len (learning_curve_cv_folds_list ) == 0 :
663
665
learning_curve_cv_folds_list = [10 ] * len (learners )
664
666
else :
@@ -679,7 +681,7 @@ def parse_config_file(
679
681
# floats (proportions). If it's not specified, then we just
680
682
# assume that we are using np.linspace(0.1, 1.0, 5).
681
683
learning_curve_train_sizes_string = config .get ("Input" , "learning_curve_train_sizes" )
682
- learning_curve_train_sizes = yaml .safe_load (fix_json (learning_curve_train_sizes_string ))
684
+ learning_curve_train_sizes = yaml .load (fix_json (learning_curve_train_sizes_string ))
683
685
if len (learning_curve_train_sizes ) == 0 :
684
686
learning_curve_train_sizes = np .linspace (0.1 , 1.0 , 5 ).tolist ()
685
687
else :
@@ -698,9 +700,9 @@ def parse_config_file(
698
700
# do we need to shuffle the training data
699
701
do_shuffle = config .getboolean ("Input" , "shuffle" )
700
702
701
- fixed_parameter_list = yaml .safe_load (fix_json (config .get ("Input" , "fixed_parameters" )))
702
- fixed_sampler_parameters = yaml .safe_load (fix_json (config .get ("Input" , "sampler_parameters" )))
703
- param_grid_list = yaml .safe_load (fix_json (config .get ("Tuning" , "param_grids" )))
703
+ fixed_parameter_list = yaml .load (fix_json (config .get ("Input" , "fixed_parameters" )))
704
+ fixed_sampler_parameters = yaml .load (fix_json (config .get ("Input" , "sampler_parameters" )))
705
+ param_grid_list = yaml .load (fix_json (config .get ("Tuning" , "param_grids" )))
704
706
705
707
# read and normalize the value of `pos_label`
706
708
pos_label_string = safe_float (config .get ("Tuning" , "pos_label" ))
@@ -804,7 +806,8 @@ def parse_config_file(
804
806
805
807
# Get class mapping dictionary if specified
806
808
class_map_string = config .get ("Input" , "class_map" )
807
- original_class_map = yaml .safe_load (fix_json (class_map_string ))
809
+ yaml = YAML (typ = "safe" , pure = True )
810
+ original_class_map = yaml .load (fix_json (class_map_string ))
808
811
if original_class_map :
809
812
# Change class_map to map from originals to replacements instead of
810
813
# from replacement to list of originals
0 commit comments