Skip to content

Commit c1463df

Browse files
committed
fix(training-operator)
harden HuggingFace training_parameters parsing Signed-off-by: Ayush-kathil <kathilshiva@gmail.com>
1 parent e4705d7 commit c1463df

2 files changed

Lines changed: 243 additions & 0 deletions

File tree

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright 2024 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import argparse
18+
import json
19+
import logging
20+
from typing import Any, Optional
21+
22+
try:
23+
from transformers import TrainingArguments
24+
except ImportError: # pragma: no cover - exercised only when transformers is absent.
25+
26+
class TrainingArguments: # type: ignore[no-redef]
27+
def __init__(self, *args: Any, **kwargs: Any) -> None:
28+
raise ImportError(
29+
"transformers is required to construct HuggingFace TrainingArguments."
30+
)
31+
32+
33+
logger = logging.getLogger(__name__)
34+
35+
DEFAULT_OUTPUT_DIR = "./output"
36+
37+
38+
def parse_training_args(raw: Optional[str]) -> dict[str, Any]:
39+
"""Parse a JSON string into a TrainingArguments configuration."""
40+
41+
if raw is None:
42+
return {}
43+
44+
if not isinstance(raw, str):
45+
raise ValueError(
46+
"training_parameters must be a JSON string or None; got "
47+
f"{type(raw).__name__}."
48+
)
49+
50+
normalized = raw.strip()
51+
if not normalized:
52+
return {}
53+
54+
try:
55+
parsed = json.loads(normalized)
56+
except json.JSONDecodeError as exc:
57+
raise ValueError(
58+
"Invalid JSON in training_parameters. Provide a JSON object string, for "
59+
f'example \'{{"output_dir": "./output"}}\'. Received: {raw!r}. '
60+
f"JSON error: {exc.msg} at line {exc.lineno}, column {exc.colno}."
61+
) from exc
62+
63+
if not isinstance(parsed, dict):
64+
raise ValueError(
65+
"training_parameters must decode to a JSON object. Received "
66+
f"{type(parsed).__name__}: {parsed!r}."
67+
)
68+
69+
invalid_keys = [
70+
key for key in parsed.keys() if not isinstance(key, str) or not key.strip()
71+
]
72+
if invalid_keys:
73+
raise ValueError(
74+
"training_parameters contains invalid keys. JSON object keys must be non-empty "
75+
f"strings. Invalid keys: {invalid_keys!r}."
76+
)
77+
78+
return parsed
79+
80+
81+
def build_training_arguments(raw: Optional[str]) -> TrainingArguments:
82+
logger.info("Raw training_parameters payload: %r", raw)
83+
parsed_config = parse_training_args(raw)
84+
85+
if not parsed_config:
86+
logger.info(
87+
"training_parameters is empty or missing; using default "
88+
"TrainingArguments with output_dir=%s",
89+
DEFAULT_OUTPUT_DIR,
90+
)
91+
return TrainingArguments(output_dir=DEFAULT_OUTPUT_DIR)
92+
93+
logger.info(
94+
"Parsed training_parameters config: %s",
95+
json.dumps(parsed_config, sort_keys=True),
96+
)
97+
try:
98+
return TrainingArguments(**parsed_config)
99+
except Exception as exc:
100+
logger.error(
101+
"Failed to create TrainingArguments from parsed training_parameters: %s",
102+
json.dumps(parsed_config, sort_keys=True),
103+
exc_info=True,
104+
)
105+
raise ValueError(
106+
"Failed to initialize TrainingArguments from training_parameters. "
107+
"Check the JSON keys and values, and ensure they match the HuggingFace "
108+
"TrainingArguments signature. Parsed config: "
109+
f"{json.dumps(parsed_config, sort_keys=True)}"
110+
) from exc
111+
112+
113+
def _build_parser() -> argparse.ArgumentParser:
114+
parser = argparse.ArgumentParser(description="Run a HuggingFace training job.")
115+
parser.add_argument(
116+
"--training_parameters",
117+
type=str,
118+
default="{}",
119+
help="JSON object used to initialize HuggingFace TrainingArguments.",
120+
)
121+
return parser
122+
123+
124+
def main() -> None:
125+
logging.basicConfig(
126+
level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s"
127+
)
128+
parser = _build_parser()
129+
args = parser.parse_args()
130+
131+
training_args = build_training_arguments(args.training_parameters)
132+
logger.info("TrainingArguments initialized successfully: %s", training_args)
133+
134+
# Replace this with the actual training workflow used by the example.
135+
logger.info("Trainer entrypoint completed parsing and initialization only.")
136+
137+
138+
if __name__ == "__main__":
139+
main()
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2024 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import importlib.util
18+
from pathlib import Path
19+
20+
import pytest
21+
22+
SCRIPT_PATH = (
23+
Path(__file__).resolve().parents[4]
24+
/ "examples"
25+
/ "v1beta1"
26+
/ "kubeflow-training-operator"
27+
/ "hf_llm_training.py"
28+
)
29+
30+
31+
class DummyTrainingArguments:
32+
def __init__(self, **kwargs):
33+
self.kwargs = kwargs
34+
35+
36+
def load_module():
37+
spec = importlib.util.spec_from_file_location("hf_llm_training", SCRIPT_PATH)
38+
module = importlib.util.module_from_spec(spec)
39+
assert spec.loader is not None
40+
spec.loader.exec_module(module)
41+
return module
42+
43+
44+
def test_parse_training_args_empty_string_returns_empty_dict():
45+
module = load_module()
46+
47+
assert module.parse_training_args("") == {}
48+
49+
50+
def test_parse_training_args_none_returns_empty_dict():
51+
module = load_module()
52+
53+
assert module.parse_training_args(None) == {}
54+
55+
56+
def test_parse_training_args_whitespace_returns_empty_dict():
57+
module = load_module()
58+
59+
assert module.parse_training_args(" \n\t ") == {}
60+
61+
62+
def test_parse_training_args_valid_json_returns_dict():
63+
module = load_module()
64+
65+
assert module.parse_training_args(
66+
'{"output_dir": "./output", "learning_rate": 0.0001}'
67+
) == {
68+
"output_dir": "./output",
69+
"learning_rate": 0.0001,
70+
}
71+
72+
73+
def test_parse_training_args_invalid_json_raises_value_error():
74+
module = load_module()
75+
76+
with pytest.raises(ValueError, match="Invalid JSON in training_parameters"):
77+
module.parse_training_args("{invalid-json")
78+
79+
80+
def test_parse_training_args_malformed_keys_raises_value_error():
81+
module = load_module()
82+
83+
with pytest.raises(ValueError, match="invalid keys"):
84+
module.parse_training_args('{"": 1, " ": 2}')
85+
86+
87+
def test_build_training_arguments_uses_default_when_empty(monkeypatch):
88+
module = load_module()
89+
monkeypatch.setattr(module, "TrainingArguments", DummyTrainingArguments)
90+
91+
training_args = module.build_training_arguments("")
92+
93+
assert training_args.kwargs == {"output_dir": "./output"}
94+
95+
96+
def test_build_training_arguments_passes_valid_config(monkeypatch):
97+
module = load_module()
98+
monkeypatch.setattr(module, "TrainingArguments", DummyTrainingArguments)
99+
100+
training_args = module.build_training_arguments(
101+
'{"output_dir": "./tmp", "num_train_epochs": 3}'
102+
)
103+
104+
assert training_args.kwargs == {"output_dir": "./tmp", "num_train_epochs": 3}

0 commit comments

Comments
 (0)