
Commit fd9feeb

Update check caching to use hash keys of inputs

1 parent 0ec0992

26 files changed: +694 −216 lines

atr/attestable.py

55 additions, 16 deletions

@@ -18,22 +18,21 @@
 from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, Any, Final
+from typing import TYPE_CHECKING, Any
 
 import aiofiles
 import aiofiles.os
-import blake3
 import pydantic
 
+import atr.hashes as hashes
 import atr.log as log
 import atr.models.attestable as models
 import atr.util as util
+from atr.models.attestable import AttestableChecksV1
 
 if TYPE_CHECKING:
     import pathlib
 
-_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024
-
 
 def attestable_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path:
     return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.json"
@@ -43,18 +42,14 @@ def attestable_paths_path(project_name: str, version_name: str, revision_number:
     return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.paths.json"
 
 
-async def compute_file_hash(path: pathlib.Path) -> str:
-    hasher = blake3.blake3()
-    async with aiofiles.open(path, "rb") as f:
-        while chunk := await f.read(_HASH_CHUNK_SIZE):
-            hasher.update(chunk)
-    return f"blake3:{hasher.hexdigest()}"
-
-
 def github_tp_payload_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path:
     return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.github-tp.json"
 
 
+def attestable_checks_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path:
+    return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.checks.json"
+
+
 async def github_tp_payload_write(
     project_name: str, version_name: str, revision_number: str, github_payload: dict[str, Any]
 ) -> None:
@@ -99,6 +94,22 @@ async def load_paths(
     return None
 
 
+async def load_checks(
+    project_name: str,
+    version_name: str,
+    revision_number: str,
+) -> list[int] | None:
+    file_path = attestable_checks_path(project_name, version_name, revision_number)
+    if await aiofiles.os.path.isfile(file_path):
+        try:
+            async with aiofiles.open(file_path, encoding="utf-8") as f:
+                data = json.loads(await f.read())
+            return models.AttestableChecksV1.model_validate(data).checks
+        except (json.JSONDecodeError, pydantic.ValidationError) as e:
+            log.warning(f"Could not parse {file_path}: {e}")
+    return []
+
+
 def migrate_to_paths_files() -> int:
     attestable_dir = util.get_attestable_dir()
     if not attestable_dir.is_dir():
@@ -140,26 +151,52 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str,
         if "\\" in path_key:
             # TODO: We should centralise this, and forbid some other characters too
             raise ValueError(f"Backslash in path is forbidden: {path_key}")
-        path_to_hash[path_key] = await compute_file_hash(full_path)
+        path_to_hash[path_key] = await hashes.compute_file_hash(full_path)
         path_to_size[path_key] = (await aiofiles.os.stat(full_path)).st_size
     return path_to_hash, path_to_size
 
 
-async def write(
+async def write_files_data(
     project_name: str,
     version_name: str,
     revision_number: str,
+    release_policy: dict[str, Any] | None,
     uploader_uid: str,
     previous: models.AttestableV1 | None,
    path_to_hash: dict[str, str],
    path_to_size: dict[str, int],
 ) -> None:
-    result = _generate(path_to_hash, path_to_size, revision_number, uploader_uid, previous)
+    result = _generate_files_data(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous)
     file_path = attestable_path(project_name, version_name, revision_number)
     await util.atomic_write_file(file_path, result.model_dump_json(indent=2))
     paths_result = models.AttestablePathsV1(paths=result.paths)
     paths_file_path = attestable_paths_path(project_name, version_name, revision_number)
     await util.atomic_write_file(paths_file_path, paths_result.model_dump_json(indent=2))
+    checks_file_path = attestable_checks_path(project_name, version_name, revision_number)
+    if not checks_file_path.exists():
+        async with aiofiles.open(checks_file_path, "w", encoding="utf-8") as f:
+            await f.write(models.AttestableChecksV1().model_dump_json(indent=2))
+
+
+async def write_checks_data(
+    project_name: str,
+    version_name: str,
+    revision_number: str,
+    checks: list[int],
+) -> None:
+    log.info(f"Writing checks for {project_name}/{version_name}/{revision_number}: {checks}")
+
+    def modify(content: str) -> str:
+        try:
+            current = AttestableChecksV1.model_validate_json(content).checks
+        except pydantic.ValidationError:
+            current = []
+        new_checks = set(current or [])
+        new_checks.update(checks)
+        result = models.AttestableChecksV1(checks=sorted(new_checks))
+        return result.model_dump_json(indent=2)
+
+    await util.atomic_modify_file(attestable_checks_path(project_name, version_name, revision_number), modify)
 
 
 def _compute_hashes_with_attribution(
@@ -197,10 +234,11 @@ def _compute_hashes_with_attribution(
     return new_hashes
 
 
-def _generate(
+def _generate_files_data(
     path_to_hash: dict[str, str],
     path_to_size: dict[str, int],
     revision_number: str,
+    release_policy: dict[str, Any] | None,
     uploader_uid: str,
     previous: models.AttestableV1 | None,
 ) -> models.AttestableV1:
@@ -215,4 +253,5 @@ def _generate(
     return models.AttestableV1(
         paths=dict(path_to_hash),
         hashes=dict(new_hashes),
+        policy=release_policy or {},
     )
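
The new `atr.hashes` module is not part of this diff, but `compute_file_hash` was evidently moved there from the code removed above. A minimal sketch under that assumption; `compute_dict_hash`, called later in `atr/storage/readers/checks.py`, is a guess, with hashing a canonical JSON serialization being one plausible implementation:

```python
# Hypothetical reconstruction of atr/hashes.py. compute_file_hash mirrors the
# implementation removed from atr/attestable.py above; compute_dict_hash is an
# assumption based on its call site in atr/storage/readers/checks.py.
from __future__ import annotations

import json
import pathlib
from typing import Any, Final

import aiofiles
import blake3

_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024


async def compute_file_hash(path: pathlib.Path) -> str:
    # Stream the file in 4 MiB chunks so large artifacts are never fully in memory
    hasher = blake3.blake3()
    async with aiofiles.open(path, "rb") as f:
        while chunk := await f.read(_HASH_CHUNK_SIZE):
            hasher.update(chunk)
    return f"blake3:{hasher.hexdigest()}"


def compute_dict_hash(data: dict[str, Any]) -> str:
    # Assumed: sorted keys give a canonical form, so the hash does not
    # depend on dict insertion order
    payload = json.dumps(data, sort_keys=True, separators=(",", ":"))
    return f"blake3:{blake3.blake3(payload.encode()).hexdigest()}"
```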

atr/db/__init__.py

8 additions, 0 deletions

@@ -155,6 +155,7 @@ async def begin_immediate(self) -> None:
     def check_result(
         self,
         id: Opt[int] = NOT_SET,
+        id_in: Opt[list[int]] = NOT_SET,
         release_name: Opt[str] = NOT_SET,
         revision_number: Opt[str] = NOT_SET,
         checker: Opt[str] = NOT_SET,
@@ -164,12 +165,17 @@ def check_result(
         status: Opt[sql.CheckResultStatus] = NOT_SET,
         message: Opt[str] = NOT_SET,
         data: Opt[Any] = NOT_SET,
+        inputs_hash: Opt[str] = NOT_SET,
         _release: bool = False,
     ) -> Query[sql.CheckResult]:
         query = sqlmodel.select(sql.CheckResult)
 
+        via = sql.validate_instrumented_attribute
+
         if is_defined(id):
             query = query.where(sql.CheckResult.id == id)
+        if is_defined(id_in):
+            query = query.where(via(sql.CheckResult.id).in_(id_in))
         if is_defined(release_name):
             query = query.where(sql.CheckResult.release_name == release_name)
         if is_defined(revision_number):
@@ -188,6 +194,8 @@ def check_result(
             query = query.where(sql.CheckResult.message == message)
         if is_defined(data):
             query = query.where(sql.CheckResult.data == data)
+        if is_defined(inputs_hash):
+            query = query.where(sql.CheckResult.inputs_hash == inputs_hash)
 
         if _release:
             query = query.options(joined_load(sql.CheckResult.release))
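
A brief usage sketch of the two new filters. The `check_result(...)` builder and its keywords are from this diff; the surrounding accessor and the id and hash values are illustrative placeholders:

```python
# Sketch: fetch only known check results and test a stored cache key.
# `data` stands in for whatever holds the query builder in practice.
results = await data.check_result(
    id_in=[101, 102, 103],                   # new filter: restrict to known ids
    inputs_hash="blake3:7f83b1657ff1fc...",  # new filter: match the inputs hash
).all()
```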

atr/docs/checks.md

1 addition, 1 deletion

@@ -52,7 +52,7 @@ This check records separate checker keys for errors, warnings, and success. Use
 
 For each `.sha256` or `.sha512` file, ATR computes the hash of the referenced artifact and compares it with the expected value. It supports files that contain just the hash as well as files that include a filename and hash on the same line. If the suffix does not indicate `sha256` or `sha512`, the check fails.
 
-The checker key is `atr.tasks.checks.hashing.check`.
+The checker key is `atr.tasks.checks.file_hash.check`.
 
 ### Signature verification
 
atr/docs/tasks.md

1 addition, 1 deletion

@@ -41,7 +41,7 @@ In `atr/tasks/checks` you will find several modules that perform these check tasks
 In `atr/tasks/__init__.py` you will see imports for existing modules where you can add an import for new check task, for example:
 
 ```python
-import atr.tasks.checks.hashing as hashing
+import atr.tasks.checks.hashing as file_hash
 import atr.tasks.checks.license as license
 ```

atr/get/report.py

9 additions, 2 deletions

@@ -40,10 +40,17 @@ async def selected_path(session: web.Committer, project_name: str, version_name:
 
     # If the draft is not found, we try to get the release candidate
     try:
-        release = await session.release(project_name, version_name, with_committee=True)
+        release = await session.release(
+            project_name, version_name, with_committee=True, with_release_policy=True, with_project_release_policy=True
+        )
     except base.ASFQuartException:
         release = await session.release(
-            project_name, version_name, phase=sql.ReleasePhase.RELEASE_CANDIDATE, with_committee=True
+            project_name,
+            version_name,
+            phase=sql.ReleasePhase.RELEASE_CANDIDATE,
+            with_committee=True,
+            with_release_policy=True,
+            with_project_release_policy=True,
         )
 
     if release.committee is None:

atr/merge.py

4 additions, 3 deletions

@@ -24,6 +24,7 @@
 import aiofiles.os
 
 import atr.attestable as attestable
+import atr.hashes as hashes
 import atr.util as util
 
 if TYPE_CHECKING:
@@ -131,7 +132,7 @@ async def _add_from_prior(
         if (prior_hashes is not None) and (path in prior_hashes):
             n_hashes[path] = prior_hashes[path]
         else:
-            n_hashes[path] = await attestable.compute_file_hash(target)
+            n_hashes[path] = await hashes.compute_file_hash(target)
         stat_result = await aiofiles.os.stat(target)
         n_sizes[path] = stat_result.st_size
     return prior_hashes
@@ -211,7 +212,7 @@ async def _merge_all_present(
     if (prior_hashes is not None) and (path in prior_hashes):
         p_hash = prior_hashes[path]
     else:
-        p_hash = await attestable.compute_file_hash(prior_dir / path)
+        p_hash = await hashes.compute_file_hash(prior_dir / path)
     if p_hash != b_hash:
         # Case 11 via hash: base and new have the same content but prior differs
         return await _replace_with_prior(
@@ -250,7 +251,7 @@ async def _replace_with_prior(
        if (prior_hashes is not None) and (path in prior_hashes):
            n_hashes[path] = prior_hashes[path]
        else:
-            n_hashes[path] = await attestable.compute_file_hash(file_path)
+            n_hashes[path] = await hashes.compute_file_hash(file_path)
        stat_result = await aiofiles.os.stat(file_path)
        n_sizes[path] = stat_result.st_size
     return prior_hashes

atr/models/attestable.py

7 additions, 1 deletion

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Annotated, Literal
+from typing import Annotated, Any, Literal
 
 import pydantic
 
@@ -27,6 +27,11 @@ class HashEntry(schema.Strict):
     uploaders: list[Annotated[tuple[str, str], pydantic.BeforeValidator(tuple)]]
 
 
+class AttestableChecksV1(schema.Strict):
+    version: Literal[1] = 1
+    checks: list[int] = schema.factory(list)
+
+
 class AttestablePathsV1(schema.Strict):
     version: Literal[1] = 1
     paths: dict[str, str] = schema.factory(dict)
@@ -36,3 +41,4 @@ class AttestableV1(schema.Strict):
     version: Literal[1] = 1
     paths: dict[str, str] = schema.factory(dict)
     hashes: dict[str, HashEntry] = schema.factory(dict)
+    policy: dict[str, Any] = schema.factory(dict)
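
A quick round-trip sketch of the new `AttestableChecksV1` model, mirroring the merge performed by `modify()` in `write_checks_data` above (assuming `schema.Strict` behaves like a strict pydantic model with default factories):

```python
# Merge new check ids into an existing checks-file payload, deduplicated
# and sorted, as write_checks_data does.
existing = AttestableChecksV1.model_validate_json('{"version": 1, "checks": [3, 1]}')
merged = set(existing.checks)
merged.update([2, 3])
updated = AttestableChecksV1(checks=sorted(merged))
print(updated.model_dump_json(indent=2))  # checks become [1, 2, 3]
```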

atr/models/sql.py

1 addition, 1 deletion

@@ -946,7 +946,7 @@ class CheckResult(sqlmodel.SQLModel, table=True):
     data: Any = sqlmodel.Field(
         sa_column=sqlalchemy.Column(sqlalchemy.JSON), **example({"expected": "...", "found": "..."})
     )
-    input_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc..."))
+    inputs_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc..."))
     cached: bool = sqlmodel.Field(default=False, **example(False))
 

atr/storage/readers/checks.py

48 additions, 9 deletions

@@ -18,17 +18,51 @@
 # Removing this will cause circular imports
 from __future__ import annotations
 
+import importlib
 from typing import TYPE_CHECKING
 
+import atr.attestable as attestable
 import atr.db as db
+import atr.hashes as hashes
 import atr.models.sql as sql
 import atr.storage as storage
 import atr.storage.types as types
+import atr.tasks.checks as checks
 import atr.util as util
 
 if TYPE_CHECKING:
     import pathlib
-    from collections.abc import Callable
+    from collections.abc import Callable, Sequence
+
+
+async def _filter_check_results_by_hash(
+    all_check_results: Sequence[sql.CheckResult],
+    rel_path: pathlib.Path,
+    input_hash_by_module: dict[str, str | None],
+    release: sql.Release,
+) -> Sequence[sql.CheckResult]:
+    filtered_check_results = []
+    if release.latest_revision_number is None:
+        raise ValueError("Release has no revision - Invalid state")
+    for cr in all_check_results:
+        if cr.checker not in input_hash_by_module:
+            module_path = cr.checker.rsplit(".", 1)[0]
+            try:
+                module = importlib.import_module(module_path)
+                policy_keys: list[str] = module.INPUT_POLICY_KEYS
+                extra_arg_names: list[str] = getattr(module, "INPUT_EXTRA_ARGS", [])
+            except (ImportError, AttributeError):
+                policy_keys = []
+                extra_arg_names = []
+            extra_args = await checks.resolve_extra_args(extra_arg_names, release)
+            cache_key = await checks.resolve_cache_key(
+                cr.checker, policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name
+            )
+            input_hash_by_module[cr.checker] = hashes.compute_dict_hash(cache_key) if cache_key else None
+
+        if cr.inputs_hash == input_hash_by_module[cr.checker]:
+            filtered_check_results.append(cr)
+    return filtered_check_results
 
 
 class GeneralPublic:
@@ -48,15 +82,20 @@ async def by_release_path(self, release: sql.Release, rel_path: pathlib.Path) ->
         if release.latest_revision_number is None:
             raise ValueError("Release has no revision - Invalid state")
 
-        query = self.__data.check_result(
-            release_name=release.name,
-            revision_number=release.latest_revision_number,
-            primary_rel_path=str(rel_path),
-        ).order_by(
-            sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(),
-            sql.validate_instrumented_attribute(sql.CheckResult.created).desc(),
+        check_ids = await attestable.load_checks(release.project_name, release.version, release.latest_revision_number)
+        all_check_results = (
+            [
+                a
+                for a in await self.__data.check_result(id_in=check_ids)
+                .order_by(
+                    sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(),
+                    sql.validate_instrumented_attribute(sql.CheckResult.created).desc(),
+                )
+                .all()
+            ]
+            if check_ids
+            else []
         )
-        all_check_results = await query.all()
 
         # Filter out any results that are ignored
         unignored_checks = []
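
The filter above rebuilds each checker's cache key by importing the checker module and reading `INPUT_POLICY_KEYS` (and optionally `INPUT_EXTRA_ARGS`) from it. A hedged sketch of the per-checker cache-hit test: the functions and the `cr` check-result row come from this diff, while the policy key, extra-arg list, and file name are illustrative assumptions:

```python
# Sketch of one iteration of _filter_check_results_by_hash, with placeholder
# values. The checker key is real (see atr/docs/checks.md above); the policy
# key and artifact name are hypothetical.
policy_keys = ["checks.file_hash"]
extra_args = await checks.resolve_extra_args([], release)
cache_key = await checks.resolve_cache_key(
    "atr.tasks.checks.file_hash.check",
    policy_keys,
    release,
    release.latest_revision_number,
    extra_args,
    file="apache-example-1.0.tar.gz",
)
current_hash = hashes.compute_dict_hash(cache_key) if cache_key else None
cache_hit = cr.inputs_hash == current_hash  # reuse the stored result on a hit
```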
