Skip to content

Commit d16bb8d

Browse files
committed
Add direct_url model and validator
1 parent 2df7bdd commit d16bb8d

File tree

2 files changed

+467
-0
lines changed

2 files changed

+467
-0
lines changed

src/packaging/direct_url.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
import sys
5+
from collections.abc import Mapping
6+
from dataclasses import dataclass
7+
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
8+
9+
if TYPE_CHECKING: # pragma: no cover
10+
if sys.version_info >= (3, 11):
11+
from typing import Self
12+
else:
13+
from typing_extensions import Self
14+
15+
__all__ = [
16+
"ArchiveInfo",
17+
"DirInfo",
18+
"DirectUrl",
19+
"DirectUrlValidationError",
20+
"VcsInfo",
21+
]
22+
23+
_T = TypeVar("_T")
24+
25+
26+
class _FromMappingProtocol(Protocol): # pragma: no cover
27+
@classmethod
28+
def _from_dict(cls, d: Mapping[str, Any]) -> Self: ...
29+
30+
31+
_FromMappingProtocolT = TypeVar("_FromMappingProtocolT", bound=_FromMappingProtocol)
32+
33+
34+
def _json_dict_factory(data: list[tuple[str, Any]]) -> dict[str, Any]:
35+
return {key: value for key, value in data if value is not None}
36+
37+
38+
def _get(d: Mapping[str, Any], expected_type: type[_T], key: str) -> _T | None:
39+
"""Get a value from the dictionary and verify it's the expected type."""
40+
if (value := d.get(key)) is None:
41+
return None
42+
if not isinstance(value, expected_type):
43+
raise DirectUrlValidationError(
44+
f"Unexpected type {type(value).__name__} "
45+
f"(expected {expected_type.__name__})",
46+
context=key,
47+
)
48+
return value
49+
50+
51+
def _get_required(d: Mapping[str, Any], expected_type: type[_T], key: str) -> _T:
52+
"""Get a required value from the dictionary and verify it's the expected type."""
53+
if (value := _get(d, expected_type, key)) is None:
54+
raise _DirectUrlRequiredKeyError(key)
55+
return value
56+
57+
58+
def _get_object(
59+
d: Mapping[str, Any], target_type: type[_FromMappingProtocolT], key: str
60+
) -> _FromMappingProtocolT | None:
61+
"""Get a dictionary value from the dictionary and convert it to a dataclass."""
62+
if (value := _get(d, Mapping, key)) is None: # type: ignore[type-abstract]
63+
return None
64+
try:
65+
return target_type._from_dict(value)
66+
except Exception as e:
67+
raise DirectUrlValidationError(e, context=key) from e
68+
69+
70+
class DirectUrlValidationError(Exception):
71+
"""Raised when when input data is not spec-compliant."""
72+
73+
context: str | None = None
74+
message: str
75+
76+
def __init__(
77+
self,
78+
cause: str | Exception,
79+
*,
80+
context: str | None = None,
81+
) -> None:
82+
if isinstance(cause, DirectUrlValidationError):
83+
if cause.context:
84+
self.context = (
85+
f"{context}.{cause.context}" if context else cause.context
86+
)
87+
else:
88+
self.context = context # pragma: no cover
89+
self.message = cause.message
90+
else:
91+
self.context = context
92+
self.message = str(cause)
93+
94+
def __str__(self) -> str:
95+
if self.context:
96+
return f"{self.message} in {self.context!r}"
97+
return self.message
98+
99+
100+
class _DirectUrlRequiredKeyError(DirectUrlValidationError):
101+
def __init__(self, key: str) -> None:
102+
super().__init__("Missing required value", context=key)
103+
104+
105+
@dataclass(frozen=True, init=False)
106+
class VcsInfo:
107+
vcs: str
108+
commit_id: str
109+
requested_revision: str | None = None
110+
111+
def __init__(
112+
self,
113+
*,
114+
vcs: str,
115+
commit_id: str,
116+
requested_revision: str | None = None,
117+
) -> None:
118+
object.__setattr__(self, "vcs", vcs)
119+
object.__setattr__(self, "commit_id", commit_id)
120+
object.__setattr__(self, "requested_revision", requested_revision)
121+
122+
@classmethod
123+
def _from_dict(cls, d: Mapping[str, Any]) -> Self:
124+
# We can't validate vcs value because is not closed.
125+
return cls(
126+
vcs=_get_required(d, str, "vcs"),
127+
requested_revision=_get(d, str, "requested_revision"),
128+
commit_id=_get_required(d, str, "commit_id"),
129+
)
130+
131+
132+
@dataclass(frozen=True, init=False)
133+
class ArchiveInfo:
134+
hashes: Mapping[str, str] | None = None
135+
hash: str | None = None # Deprecated, use `hashes` instead
136+
137+
def __init__(
138+
self,
139+
*,
140+
hashes: Mapping[str, str] | None = None,
141+
hash: str | None = None,
142+
) -> None:
143+
object.__setattr__(self, "hashes", hashes)
144+
object.__setattr__(self, "hash", hash)
145+
146+
@classmethod
147+
def _from_dict(cls, d: Mapping[str, Any]) -> Self:
148+
archive_info = cls(
149+
hashes=_get(d, Mapping, "hashes"), # type: ignore[type-abstract]
150+
hash=_get(d, str, "hash"),
151+
)
152+
hashes = archive_info.hashes or {}
153+
if not all(isinstance(hash, str) for hash in hashes.values()):
154+
raise DirectUrlValidationError(
155+
"Hash values must be strings", context="hashes"
156+
)
157+
if archive_info.hash is not None:
158+
if "=" not in archive_info.hash:
159+
raise DirectUrlValidationError(
160+
"Invalid hash format (expected '<algorithm>=<hash>')",
161+
context="hash",
162+
)
163+
if archive_info.hashes is not None:
164+
# if `hashes` are present, the legacy `hash` must match one of them
165+
hash_algorithm, hash_value = archive_info.hash.split("=", 1)
166+
if hash_algorithm not in hashes:
167+
raise DirectUrlValidationError(
168+
f"Algorithm {hash_algorithm!r} used in hash field "
169+
f"is not present in hashes field",
170+
context="hashes",
171+
)
172+
if hashes[hash_algorithm] != hash_value:
173+
raise DirectUrlValidationError(
174+
f"Algorithm {hash_algorithm!r} used in hash field "
175+
f"has different value in hashes field",
176+
context="hash",
177+
)
178+
return archive_info
179+
180+
181+
@dataclass(frozen=True, init=False)
182+
class DirInfo:
183+
editable: bool | None = None
184+
185+
def __init__(
186+
self,
187+
*,
188+
editable: bool | None = None,
189+
) -> None:
190+
object.__setattr__(self, "editable", editable)
191+
192+
@classmethod
193+
def _from_dict(cls, d: Mapping[str, Any]) -> Self:
194+
return cls(
195+
editable=_get(d, bool, "editable"),
196+
)
197+
198+
199+
@dataclass(frozen=True, init=False)
200+
class DirectUrl:
201+
url: str
202+
archive_info: ArchiveInfo | None = None
203+
vcs_info: VcsInfo | None = None
204+
dir_info: DirInfo | None = None
205+
subdirectory: str | None = None # XXX Path or str?
206+
207+
def __init__(
208+
self,
209+
*,
210+
url: str,
211+
archive_info: ArchiveInfo | None = None,
212+
vcs_info: VcsInfo | None = None,
213+
dir_info: DirInfo | None = None,
214+
subdirectory: str | None = None,
215+
) -> None:
216+
object.__setattr__(self, "url", url)
217+
object.__setattr__(self, "archive_info", archive_info)
218+
object.__setattr__(self, "vcs_info", vcs_info)
219+
object.__setattr__(self, "dir_info", dir_info)
220+
object.__setattr__(self, "subdirectory", subdirectory)
221+
222+
@classmethod
223+
def _from_dict(cls, d: Mapping[str, Any]) -> Self:
224+
direct_url = cls(
225+
url=_get_required(d, str, "url"),
226+
archive_info=_get_object(d, ArchiveInfo, "archive_info"),
227+
vcs_info=_get_object(d, VcsInfo, "vcs_info"),
228+
dir_info=_get_object(d, DirInfo, "dir_info"),
229+
subdirectory=_get(d, str, "subdirectory"),
230+
)
231+
if (
232+
bool(direct_url.vcs_info)
233+
+ bool(direct_url.archive_info)
234+
+ bool(direct_url.dir_info)
235+
) != 1:
236+
raise DirectUrlValidationError(
237+
"Exactly one of vcs_info, archive_info, dir_info must be present"
238+
)
239+
if direct_url.dir_info is not None and not direct_url.url.startswith("file://"):
240+
raise DirectUrlValidationError(
241+
"URL scheme must be file:// when dir_info is present",
242+
context="url",
243+
)
244+
# XXX subdirectory must be relative, can we, should we validate that here?
245+
# XXX url MUST be stripped of any sensitive authentication information.
246+
# We can't validate it here because it MAY contain git or other non security
247+
# sensitive auth strings.
248+
return direct_url
249+
250+
@classmethod
251+
def from_dict(cls, d: Mapping[str, Any], /) -> Self:
252+
return cls._from_dict(d)
253+
254+
def to_dict(self) -> Mapping[str, Any]:
255+
return dataclasses.asdict(self, dict_factory=_json_dict_factory)
256+
257+
def validate(self) -> None:
258+
"""Validate the DirectUrl instance against the specification.
259+
260+
Raises :class:`DirectUrlValidationError` otherwise.
261+
"""
262+
self.from_dict(self.to_dict())

0 commit comments

Comments
 (0)