Skip to content

Commit 795c0c3

Browse files
chore(refactor): Add Config class (#74)
1 parent f92052a commit 795c0c3

File tree

7 files changed

+734
-229
lines changed

7 files changed

+734
-229
lines changed

src/rapids_dependency_file_generator/cli.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22
import os
33
import warnings
44

5-
import yaml
6-
75
from ._version import __version__ as version
8-
from .constants import OutputTypes, default_channels, default_dependency_file_path
6+
from .config import Output, load_config_from_file
7+
from .constants import default_dependency_file_path
98
from .rapids_dependency_file_generator import (
109
delete_existing_files,
1110
make_dependency_files,
1211
)
13-
from .rapids_dependency_file_validator import validate_dependencies
1412

1513

1614
def validate_args(argv):
@@ -49,11 +47,11 @@ def validate_args(argv):
4947
"--output",
5048
help="The output file type to generate.",
5149
choices=[
52-
str(x)
50+
x.value
5351
for x in [
54-
OutputTypes.CONDA,
55-
OutputTypes.PYPROJECT,
56-
OutputTypes.REQUIREMENTS,
52+
Output.CONDA,
53+
Output.PYPROJECT,
54+
Output.REQUIREMENTS,
5755
]
5856
],
5957
)
@@ -73,7 +71,7 @@ def validate_args(argv):
7371
help=(
7472
"A string representing a conda channel to prepend to the list of "
7573
"channels. This option is only valid with --output "
76-
f"{OutputTypes.CONDA} or no --output. May be specified multiple times."
74+
f"{Output.CONDA.value} or no --output. May be specified multiple times."
7775
),
7876
)
7977
parser.add_argument(
@@ -83,7 +81,7 @@ def validate_args(argv):
8381
"A string representing a list of conda channels to prepend to the list of "
8482
"channels. Channels should be separated by a semicolon, such as "
8583
'`--prepend-channels "my_channel;my_other_channel"`. This option is '
86-
f"only valid with --output {OutputTypes.CONDA} or no --output. "
84+
f"only valid with --output {Output.CONDA.value} or no --output. "
8785
"DEPRECATED: Use --prepend-channel instead."
8886
),
8987
)
@@ -117,9 +115,9 @@ def validate_args(argv):
117115
"The use of --prepend-channels is deprecated. Use --prepend-channel instead."
118116
)
119117
args.prepend_channels = args.prepend_channels_deprecated.split(";")
120-
if args.prepend_channels and args.output and args.output != str(OutputTypes.CONDA):
118+
if args.prepend_channels and args.output and args.output != Output.CONDA.value:
121119
raise ValueError(
122-
f"--prepend-channel is only valid with --output {OutputTypes.CONDA}"
120+
f"--prepend-channel is only valid with --output {Output.CONDA.value}"
123121
)
124122

125123
# If --clean was passed without arguments, default to cleaning from the root of the
@@ -132,7 +130,7 @@ def validate_args(argv):
132130

133131
def generate_matrix(matrix_arg):
134132
if not matrix_arg:
135-
return {}
133+
return None
136134
matrix = {}
137135
for matrix_column in matrix_arg.split(";"):
138136
key, val = matrix_column.split("=")
@@ -143,29 +141,21 @@ def generate_matrix(matrix_arg):
143141
def main(argv=None):
144142
args = validate_args(argv)
145143

146-
with open(args.config) as f:
147-
parsed_config = yaml.load(f, Loader=yaml.FullLoader)
148-
149-
validate_dependencies(parsed_config)
144+
parsed_config = load_config_from_file(args.config)
150145

151146
matrix = generate_matrix(args.matrix)
152147
to_stdout = all([args.file_key, args.output, args.matrix is not None])
153148

154149
if to_stdout:
155-
parsed_config["files"] = {
156-
args.file_key: {
157-
**parsed_config["files"][args.file_key],
158-
"matrix": matrix,
159-
"output": args.output,
160-
}
161-
}
162-
163-
if args.prepend_channels:
164-
parsed_config["channels"] = args.prepend_channels + parsed_config.get(
165-
"channels", default_channels
166-
)
150+
file_keys = [args.file_key]
151+
output = {Output(args.output)}
152+
else:
153+
file_keys = list(parsed_config.files.keys())
154+
output = {Output.PYPROJECT, Output.CONDA, Output.REQUIREMENTS}
167155

168156
if args.clean:
169157
delete_existing_files(args.clean)
170158

171-
make_dependency_files(parsed_config, args.config, to_stdout)
159+
make_dependency_files(
160+
parsed_config, file_keys, output, matrix, args.prepend_channels, to_stdout
161+
)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from dataclasses import dataclass, field
2+
from enum import Enum
3+
from os import PathLike
4+
from pathlib import Path
5+
6+
import yaml
7+
8+
from . import constants
9+
from .rapids_dependency_file_validator import validate_dependencies
10+
11+
12+
class Output(Enum):
13+
PYPROJECT = "pyproject"
14+
REQUIREMENTS = "requirements"
15+
CONDA = "conda"
16+
17+
18+
@dataclass
19+
class FileExtras:
20+
table: str
21+
key: str | None = None
22+
23+
24+
@dataclass
25+
class File:
26+
output: set[Output]
27+
includes: list[str]
28+
extras: FileExtras | None = None
29+
matrix: dict[str, list[str]] = field(default_factory=dict)
30+
requirements_dir: Path = Path(constants.default_requirements_dir)
31+
conda_dir: Path = Path(constants.default_conda_dir)
32+
pyproject_dir: Path = Path(constants.default_pyproject_dir)
33+
34+
35+
@dataclass
36+
class PipRequirements:
37+
pip: list[str]
38+
39+
40+
@dataclass
41+
class CommonDependencies:
42+
output_types: set[Output]
43+
packages: list[str | PipRequirements]
44+
45+
46+
@dataclass
47+
class MatrixMatcher:
48+
matrix: dict[str, str]
49+
packages: list[str | PipRequirements]
50+
51+
52+
@dataclass
53+
class SpecificDependencies:
54+
output_types: set[Output]
55+
matrices: list[MatrixMatcher]
56+
57+
58+
@dataclass
59+
class Dependencies:
60+
common: list[CommonDependencies] = field(default_factory=list)
61+
specific: list[SpecificDependencies] = field(default_factory=list)
62+
63+
64+
@dataclass
65+
class Config:
66+
path: Path
67+
files: dict[str, File] = field(default_factory=dict)
68+
channels: list[str] = field(
69+
default_factory=lambda: list(constants.default_channels)
70+
)
71+
dependencies: dict[str, Dependencies] = field(default_factory=dict)
72+
73+
74+
def _parse_outputs(outputs: str | list[str]) -> set[Output]:
75+
if isinstance(outputs, str):
76+
outputs = [outputs]
77+
if outputs == ["none"]:
78+
outputs = []
79+
return {Output(o) for o in outputs}
80+
81+
82+
def _parse_extras(extras: dict[str, str]) -> FileExtras:
83+
return FileExtras(
84+
table=extras["table"],
85+
key=extras.get("key", None),
86+
)
87+
88+
89+
def _parse_file(file_config: dict[str, object]) -> File:
90+
def get_extras():
91+
try:
92+
extras = file_config["extras"]
93+
except KeyError:
94+
return None
95+
96+
return _parse_extras(extras)
97+
98+
return File(
99+
output=_parse_outputs(file_config["output"]),
100+
extras=get_extras(),
101+
includes=list(file_config["includes"]),
102+
matrix={
103+
key: list(value) for key, value in file_config.get("matrix", {}).items()
104+
},
105+
requirements_dir=Path(
106+
file_config.get("requirements_dir", constants.default_requirements_dir)
107+
),
108+
conda_dir=Path(file_config.get("conda_dir", constants.default_conda_dir)),
109+
pyproject_dir=Path(
110+
file_config.get("pyproject_dir", constants.default_pyproject_dir)
111+
),
112+
)
113+
114+
115+
def _parse_requirement(requirement: str | dict[str, str]) -> str | PipRequirements:
116+
if isinstance(requirement, str):
117+
return requirement
118+
119+
return PipRequirements(pip=requirement["pip"])
120+
121+
122+
def _parse_dependencies(dependencies: dict[str, object]) -> Dependencies:
123+
return Dependencies(
124+
common=[
125+
CommonDependencies(
126+
output_types=_parse_outputs(d["output_types"]),
127+
packages=[_parse_requirement(p) for p in d["packages"]],
128+
)
129+
for d in dependencies.get("common", [])
130+
],
131+
specific=[
132+
SpecificDependencies(
133+
output_types=_parse_outputs(d["output_types"]),
134+
matrices=[
135+
MatrixMatcher(
136+
matrix=dict(m.get("matrix", {}) or {}),
137+
packages=[
138+
_parse_requirement(p) for p in m.get("packages", []) or []
139+
],
140+
)
141+
for m in d["matrices"]
142+
],
143+
)
144+
for d in dependencies.get("specific", [])
145+
],
146+
)
147+
148+
149+
def _parse_channels(channels) -> list[str]:
150+
if isinstance(channels, str):
151+
return [channels]
152+
153+
return list(channels)
154+
155+
156+
def parse_config(config: dict[str, object], path: PathLike) -> Config:
157+
validate_dependencies(config)
158+
return Config(
159+
path=Path(path),
160+
files={key: _parse_file(value) for key, value in config["files"].items()},
161+
channels=_parse_channels(config.get("channels", [])),
162+
dependencies={
163+
key: _parse_dependencies(value)
164+
for key, value in config["dependencies"].items()
165+
},
166+
)
167+
168+
169+
def load_config_from_file(path: PathLike) -> Config:
170+
with open(path) as f:
171+
return parse_config(yaml.safe_load(f), path)

src/rapids_dependency_file_generator/constants.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
from enum import Enum
2-
3-
4-
class OutputTypes(Enum):
5-
CONDA = "conda"
6-
REQUIREMENTS = "requirements"
7-
PYPROJECT = "pyproject"
8-
NONE = "none"
9-
10-
def __str__(self):
11-
return self.value
12-
13-
141
cli_name = "rapids-dependency-file-generator"
152

163
default_channels = [

0 commit comments

Comments
 (0)