Skip to content

Commit 7ae88db

Browse files
authored
Merge pull request #754 from EducationalTestingService/fix-yaml
Fix calls to yaml.load
2 parents b10ce39 + 6566dbc commit 7ae88db

File tree

4 files changed

+22
-17
lines changed

4 files changed

+22
-17
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ select = ["D", "E", "F", "I"]
88
ignore = ["D212"]
99
line-length = 100
1010
target-version = "py38"
11+
fix = true

skll/config/__init__.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Any, Dict, List, Optional, Set, Tuple, Union
1919

2020
import numpy as np
21-
import ruamel.yaml as yaml
21+
from ruamel.yaml import YAML
2222

2323
from skll.data.readers import safe_float
2424
from skll.types import ClassMap, FoldMapping, LabelType, PathOrStr
@@ -610,7 +610,9 @@ def parse_config_file(
610610
raise ValueError(
611611
"Configuration file does not contain list of learners " "in [Input] section."
612612
)
613-
learners = yaml.safe_load(fix_json(learners_string))
613+
614+
yaml = YAML(typ="safe", pure=True)
615+
learners = yaml.load(fix_json(learners_string))
614616

615617
if len(learners) == 0:
616618
raise ValueError(
@@ -630,7 +632,7 @@ def parse_config_file(
630632
custom_metric_path = locate_file(config.get("Input", "custom_metric_path"), config_dir)
631633

632634
# get the featuresets
633-
featuresets = yaml.safe_load(config.get("Input", "featuresets"))
635+
featuresets = yaml.load(config.get("Input", "featuresets"))
634636

635637
# ensure that featuresets is either a list of features or a list of lists
636638
# of features
@@ -641,7 +643,7 @@ def parse_config_file(
641643
f"specified: {featuresets}"
642644
)
643645

644-
featureset_names = yaml.safe_load(fix_json(config.get("Input", "featureset_names")))
646+
featureset_names = yaml.load(fix_json(config.get("Input", "featureset_names")))
645647

646648
# ensure that featureset_names is a list of strings, if specified
647649
if featureset_names:
@@ -658,7 +660,7 @@ def parse_config_file(
658660
# learners. If it's not specified, then we just assume
659661
# that we are using 10 folds for each learner.
660662
learning_curve_cv_folds_list_string = config.get("Input", "learning_curve_cv_folds_list")
661-
learning_curve_cv_folds_list = yaml.safe_load(fix_json(learning_curve_cv_folds_list_string))
663+
learning_curve_cv_folds_list = yaml.load(fix_json(learning_curve_cv_folds_list_string))
662664
if len(learning_curve_cv_folds_list) == 0:
663665
learning_curve_cv_folds_list = [10] * len(learners)
664666
else:
@@ -679,7 +681,7 @@ def parse_config_file(
679681
# floats (proportions). If it's not specified, then we just
680682
# assume that we are using np.linspace(0.1, 1.0, 5).
681683
learning_curve_train_sizes_string = config.get("Input", "learning_curve_train_sizes")
682-
learning_curve_train_sizes = yaml.safe_load(fix_json(learning_curve_train_sizes_string))
684+
learning_curve_train_sizes = yaml.load(fix_json(learning_curve_train_sizes_string))
683685
if len(learning_curve_train_sizes) == 0:
684686
learning_curve_train_sizes = np.linspace(0.1, 1.0, 5).tolist()
685687
else:
@@ -698,9 +700,9 @@ def parse_config_file(
698700
# do we need to shuffle the training data
699701
do_shuffle = config.getboolean("Input", "shuffle")
700702

701-
fixed_parameter_list = yaml.safe_load(fix_json(config.get("Input", "fixed_parameters")))
702-
fixed_sampler_parameters = yaml.safe_load(fix_json(config.get("Input", "sampler_parameters")))
703-
param_grid_list = yaml.safe_load(fix_json(config.get("Tuning", "param_grids")))
703+
fixed_parameter_list = yaml.load(fix_json(config.get("Input", "fixed_parameters")))
704+
fixed_sampler_parameters = yaml.load(fix_json(config.get("Input", "sampler_parameters")))
705+
param_grid_list = yaml.load(fix_json(config.get("Tuning", "param_grids")))
704706

705707
# read and normalize the value of `pos_label`
706708
pos_label_string = safe_float(config.get("Tuning", "pos_label"))
@@ -804,7 +806,8 @@ def parse_config_file(
804806

805807
# Get class mapping dictionary if specified
806808
class_map_string = config.get("Input", "class_map")
807-
original_class_map = yaml.safe_load(fix_json(class_map_string))
809+
yaml = YAML(typ="safe", pure=True)
810+
original_class_map = yaml.load(fix_json(class_map_string))
808811
if original_class_map:
809812
# Change class_map to map from originals to replacements instead of
810813
# from replacement to list of originals

skll/config/utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pathlib import Path
1414
from typing import Iterable, List, Union
1515

16-
import ruamel.yaml as yaml
16+
from ruamel.yaml import YAML
1717

1818
from skll.types import FoldMapping, PathOrStr
1919

@@ -186,7 +186,8 @@ def _parse_and_validate_metrics(metrics: str, option_name: str, logger=None) ->
186186

187187
# make sure the given metrics data type is a list
188188
# and parse it correctly
189-
metrics = yaml.safe_load(fix_json(metrics))
189+
yaml = YAML(typ="safe", pure=True)
190+
metrics = yaml.load(fix_json(metrics))
190191
if not isinstance(metrics, list):
191192
raise TypeError(f"{option_name} should be a list, not a " f"{type(metrics)}.")
192193

skll/experiments/output.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import matplotlib.pyplot as plt
2323
import numpy as np
2424
import pandas as pd
25-
import ruamel.yaml as yaml
2625
import seaborn as sns
26+
from ruamel.yaml import YAML
2727

2828
from skll.types import FoldMapping, PathOrStr
2929
from skll.utils.logging import get_skll_logger
@@ -638,6 +638,8 @@ def _write_summary_file(result_json_paths: List[str], output_file: IO[str], abla
638638
# Map from feature set names to all features in them
639639
all_features = defaultdict(set)
640640
logger = get_skll_logger("experiment")
641+
yaml = YAML(typ="safe", pure=True)
642+
641643
for json_path_str in result_json_paths:
642644
json_path = Path(json_path_str)
643645
if not json_path.exists():
@@ -654,7 +656,7 @@ def _write_summary_file(result_json_paths: List[str], output_file: IO[str], abla
654656
featureset_name = obj[0]["featureset_name"]
655657
if ablation != 0 and "_minus_" in featureset_name:
656658
parent_set = featureset_name.split("_minus_", 1)[0]
657-
all_features[parent_set].update(yaml.safe_load(obj[0]["featureset"]))
659+
all_features[parent_set].update(yaml.load(obj[0]["featureset"]))
658660
learner_result_dicts.extend(obj)
659661

660662
# Build and write header
@@ -670,9 +672,7 @@ def _write_summary_file(result_json_paths: List[str], output_file: IO[str], abla
670672
featureset_name = lrd["featureset_name"]
671673
if ablation != 0:
672674
parent_set = featureset_name.split("_minus_", 1)[0]
673-
ablated_features = all_features[parent_set].difference(
674-
yaml.safe_load(lrd["featureset"])
675-
)
675+
ablated_features = all_features[parent_set].difference(yaml.load(lrd["featureset"]))
676676
lrd["ablated_features"] = ""
677677
if ablated_features:
678678
lrd["ablated_features"] = json.dumps(sorted(ablated_features))

0 commit comments

Comments (0)