Skip to content

Commit d3a2563

Browse files
committed
Read and write checks to/from attestable data
1 parent 133ab83 commit d3a2563

File tree

21 files changed

+254
-136
lines changed

21 files changed

+254
-136
lines changed

atr/attestable.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import atr.log as log
2929
import atr.models.attestable as models
3030
import atr.util as util
31+
from atr.models.attestable import AttestableChecksV1
3132

3233
if TYPE_CHECKING:
3334
import pathlib
@@ -45,6 +46,10 @@ def github_tp_payload_path(project_name: str, version_name: str, revision_number
4546
return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.github-tp.json"
4647

4748

49+
def attestable_checks_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path:
50+
return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.checks.json"
51+
52+
4853
async def github_tp_payload_write(
4954
project_name: str, version_name: str, revision_number: str, github_payload: dict[str, Any]
5055
) -> None:
@@ -89,6 +94,22 @@ async def load_paths(
8994
return None
9095

9196

97+
async def load_checks(
98+
project_name: str,
99+
version_name: str,
100+
revision_number: str,
101+
) -> list[int] | None:
102+
file_path = attestable_checks_path(project_name, version_name, revision_number)
103+
if await aiofiles.os.path.isfile(file_path):
104+
try:
105+
async with aiofiles.open(file_path, encoding="utf-8") as f:
106+
data = json.loads(await f.read())
107+
return models.AttestableChecksV1.model_validate(data).checks
108+
except (json.JSONDecodeError, pydantic.ValidationError) as e:
109+
log.warning(f"Could not parse {file_path}: {e}")
110+
return []
111+
112+
92113
def migrate_to_paths_files() -> int:
93114
attestable_dir = util.get_attestable_dir()
94115
if not attestable_dir.is_dir():
@@ -135,7 +156,7 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str,
135156
return path_to_hash, path_to_size
136157

137158

138-
async def write(
159+
async def write_files_data(
139160
project_name: str,
140161
version_name: str,
141162
revision_number: str,
@@ -145,12 +166,37 @@ async def write(
145166
path_to_hash: dict[str, str],
146167
path_to_size: dict[str, int],
147168
) -> None:
148-
result = _generate(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous)
169+
result = _generate_files_data(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous)
149170
file_path = attestable_path(project_name, version_name, revision_number)
150171
await util.atomic_write_file(file_path, result.model_dump_json(indent=2))
151172
paths_result = models.AttestablePathsV1(paths=result.paths)
152173
paths_file_path = attestable_paths_path(project_name, version_name, revision_number)
153174
await util.atomic_write_file(paths_file_path, paths_result.model_dump_json(indent=2))
175+
checks_file_path = attestable_checks_path(project_name, version_name, revision_number)
176+
if not checks_file_path.exists():
177+
async with aiofiles.open(checks_file_path, "w", encoding="utf-8") as f:
178+
await f.write(models.AttestableChecksV1().model_dump_json(indent=2))
179+
180+
181+
async def write_checks_data(
182+
project_name: str,
183+
version_name: str,
184+
revision_number: str,
185+
checks: list[int],
186+
) -> None:
187+
log.info(f"Writing checks for {project_name}/{version_name}/{revision_number}: {checks}")
188+
189+
def modify(content: str) -> str:
190+
try:
191+
current = AttestableChecksV1.model_validate_json(content).checks
192+
except pydantic.ValidationError:
193+
current = []
194+
new_checks = set(current or [])
195+
new_checks.update(checks)
196+
result = models.AttestableChecksV1(checks=sorted(new_checks))
197+
return result.model_dump_json(indent=2)
198+
199+
await util.atomic_modify_file(attestable_checks_path(project_name, version_name, revision_number), modify)
154200

155201

156202
def _compute_hashes_with_attribution(
@@ -188,7 +234,7 @@ def _compute_hashes_with_attribution(
188234
return new_hashes
189235

190236

191-
def _generate(
237+
def _generate_files_data(
192238
path_to_hash: dict[str, str],
193239
path_to_size: dict[str, int],
194240
revision_number: str,

atr/db/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ async def begin_immediate(self) -> None:
155155
def check_result(
156156
self,
157157
id: Opt[int] = NOT_SET,
158+
id_in: Opt[list[int]] = NOT_SET,
158159
release_name: Opt[str] = NOT_SET,
159160
revision_number: Opt[str] = NOT_SET,
160161
checker: Opt[str] = NOT_SET,
@@ -169,8 +170,12 @@ def check_result(
169170
) -> Query[sql.CheckResult]:
170171
query = sqlmodel.select(sql.CheckResult)
171172

173+
via = sql.validate_instrumented_attribute
174+
172175
if is_defined(id):
173176
query = query.where(sql.CheckResult.id == id)
177+
if is_defined(id_in):
178+
query = query.where(via(sql.CheckResult.id).in_(id_in))
174179
if is_defined(release_name):
175180
query = query.where(sql.CheckResult.release_name == release_name)
176181
if is_defined(revision_number):

atr/models/attestable.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ class HashEntry(schema.Strict):
2727
uploaders: list[Annotated[tuple[str, str], pydantic.BeforeValidator(tuple)]]
2828

2929

30+
class AttestableChecksV1(schema.Strict):
31+
version: Literal[1] = 1
32+
checks: list[int] = schema.factory(list)
33+
34+
3035
class AttestablePathsV1(schema.Strict):
3136
version: Literal[1] = 1
3237
paths: dict[str, str] = schema.factory(dict)

atr/storage/readers/checks.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import importlib
2222
from typing import TYPE_CHECKING
2323

24+
import atr.attestable as attestable
2425
import atr.db as db
2526
import atr.hashes as hashes
2627
import atr.models.sql as sql
@@ -44,22 +45,22 @@ async def _filter_check_results_by_hash(
4445
if release.latest_revision_number is None:
4546
raise ValueError("Release has no revision - Invalid state")
4647
for cr in all_check_results:
47-
module_path = cr.checker.rsplit(".", 1)[0]
48-
if module_path not in input_hash_by_module:
48+
if cr.checker not in input_hash_by_module:
49+
module_path = cr.checker.rsplit(".", 1)[0]
4950
try:
5051
module = importlib.import_module(module_path)
5152
policy_keys: list[str] = module.INPUT_POLICY_KEYS
5253
extra_arg_names: list[str] = getattr(module, "INPUT_EXTRA_ARGS", [])
5354
except (ImportError, AttributeError):
5455
policy_keys = []
5556
extra_arg_names = []
56-
extra_args = checks.resolve_extra_args(extra_arg_names, release)
57+
extra_args = await checks.resolve_extra_args(extra_arg_names, release)
5758
cache_key = await checks.resolve_cache_key(
5859
cr.checker, policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name
5960
)
60-
input_hash_by_module[module_path] = hashes.compute_dict_hash(cache_key) if cache_key else None
61+
input_hash_by_module[cr.checker] = hashes.compute_dict_hash(cache_key) if cache_key else None
6162

62-
if cr.inputs_hash == input_hash_by_module[module_path]:
63+
if cr.inputs_hash == input_hash_by_module[cr.checker]:
6364
filtered_check_results.append(cr)
6465
return filtered_check_results
6566

@@ -81,31 +82,26 @@ async def by_release_path(self, release: sql.Release, rel_path: pathlib.Path) ->
8182
if release.latest_revision_number is None:
8283
raise ValueError("Release has no revision - Invalid state")
8384

84-
# TODO: Is this potentially too much data? Within a revision I hope it's not too bad?
85-
86-
query = self.__data.check_result(
87-
release_name=release.name,
88-
primary_rel_path=str(rel_path),
89-
).order_by(
90-
sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(),
91-
sql.validate_instrumented_attribute(sql.CheckResult.created).desc(),
92-
)
93-
all_check_results = await query.all()
94-
95-
# Filter to checks for the current file version / policy
96-
# Cache the computed input hash per checker module, since all results here share the same file and release
97-
input_hash_by_module: dict[str, str | None] = {}
98-
# TODO: This has a bug - create an archive, it'll scan with a hash and show missing checksum.
99-
# Then generate a checksum. It'll re-scan the file with the same hash, but now has one. Two checks shown.
100-
filtered_check_results = await _filter_check_results_by_hash(
101-
all_check_results, rel_path, input_hash_by_module, release
85+
check_ids = await attestable.load_checks(release.project_name, release.version, release.latest_revision_number)
86+
all_check_results = (
87+
[
88+
a
89+
for a in await self.__data.check_result(id_in=check_ids)
90+
.order_by(
91+
sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(),
92+
sql.validate_instrumented_attribute(sql.CheckResult.created).desc(),
93+
)
94+
.all()
95+
]
96+
if check_ids
97+
else []
10298
)
10399

104100
# Filter out any results that are ignored
105101
unignored_checks = []
106102
ignored_checks = []
107103
match_ignore = await self.ignores_matcher(release.project_name)
108-
for cr in filtered_check_results:
104+
for cr in all_check_results:
109105
if not match_ignore(cr):
110106
unignored_checks.append(cr)
111107
else:

atr/storage/readers/releases.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import dataclasses
2222
import pathlib
2323

24+
import atr.attestable as attestable
2425
import atr.classify as classify
2526
import atr.db as db
2627
import atr.models.sql as sql
@@ -122,10 +123,11 @@ async def __successes_errors_warnings(
122123
self, release: sql.Release, latest_revision_number: str, info: types.PathInfo
123124
) -> None:
124125
match_ignore = await self.__read_as.checks.ignores_matcher(release.project_name)
126+
check_ids = await attestable.load_checks(release.project_name, release.version, latest_revision_number)
127+
attestable_checks = [a for a in await self.__data.check_result(id_in=check_ids).all()] if check_ids else []
125128

126129
cs = types.ChecksSubset(
127-
release=release,
128-
latest_revision_number=latest_revision_number,
130+
checks=attestable_checks,
129131
info=info,
130132
match_ignore=match_ignore,
131133
)
@@ -137,23 +139,13 @@ async def __successes_errors_warnings(
137139
await self.__blocker(cs)
138140

139141
async def __blocker(self, cs: types.ChecksSubset) -> None:
140-
blocker = await self.__data.check_result(
141-
release_name=cs.release.name,
142-
revision_number=cs.latest_revision_number,
143-
member_rel_path=None,
144-
status=sql.CheckResultStatus.BLOCKER,
145-
).all()
142+
blocker = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.BLOCKER]
146143
for result in blocker:
147144
if primary_rel_path := result.primary_rel_path:
148145
cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(result)
149146

150147
async def __errors(self, cs: types.ChecksSubset) -> None:
151-
errors = await self.__data.check_result(
152-
release_name=cs.release.name,
153-
revision_number=cs.latest_revision_number,
154-
member_rel_path=None,
155-
status=sql.CheckResultStatus.FAILURE,
156-
).all()
148+
errors = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.FAILURE]
157149
for error in errors:
158150
if cs.match_ignore(error):
159151
cs.info.ignored_errors.append(error)
@@ -162,24 +154,14 @@ async def __errors(self, cs: types.ChecksSubset) -> None:
162154
cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(error)
163155

164156
async def __successes(self, cs: types.ChecksSubset) -> None:
165-
successes = await self.__data.check_result(
166-
release_name=cs.release.name,
167-
revision_number=cs.latest_revision_number,
168-
member_rel_path=None,
169-
status=sql.CheckResultStatus.SUCCESS,
170-
).all()
157+
successes = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.SUCCESS]
171158
for success in successes:
172159
# Successes cannot be ignored
173160
if primary_rel_path := success.primary_rel_path:
174161
cs.info.successes.setdefault(pathlib.Path(primary_rel_path), []).append(success)
175162

176163
async def __warnings(self, cs: types.ChecksSubset) -> None:
177-
warnings = await self.__data.check_result(
178-
release_name=cs.release.name,
179-
revision_number=cs.latest_revision_number,
180-
member_rel_path=None,
181-
status=sql.CheckResultStatus.WARNING,
182-
).all()
164+
warnings = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.WARNING]
183165
for warning in warnings:
184166
if cs.match_ignore(warning):
185167
cs.info.ignored_warnings.append(warning)

atr/storage/types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ class PathInfo(schema.Strict):
7575

7676
@dataclasses.dataclass
7777
class ChecksSubset:
78-
release: sql.Release
79-
latest_revision_number: str
78+
checks: list[sql.CheckResult]
8079
info: PathInfo
8180
match_ignore: Callable[[sql.CheckResult], bool]
8281

atr/storage/writers/revision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ async def create_and_manage( # noqa: C901
245245

246246
policy = release.release_policy or release.project.release_policy
247247

248-
await attestable.write(
248+
await attestable.write_files_data(
249249
project_name,
250250
version_name,
251251
new_revision.number,

0 commit comments

Comments
 (0)