Skip to content

Commit da2ca29

Browse files
Parallelize compute component comparisons (#1046)
1 parent 5ce7492 commit da2ca29

File tree

5 files changed

+370
-42
lines changed

5 files changed

+370
-42
lines changed

rollouts/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414
)
1515

1616
NEW_TA_TASKS = Feature("new_ta_tasks")
17+
18+
PARALLEL_COMPONENT_COMPARISON = Feature("parallel_component_comparison")

services/comparison_utils.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import sentry_sdk
2+
from shared.reports.readonly import ReadOnlyReport
3+
4+
from database.models import CompareCommit
5+
from services.comparison import ComparisonContext, ComparisonProxy
6+
from services.comparison.types import Comparison, FullCommit
7+
from services.report import ReportService
8+
9+
10+
@sentry_sdk.trace
11+
def get_comparison_proxy(
12+
comparison: CompareCommit,
13+
report_service: ReportService,
14+
):
15+
compare_commit = comparison.compare_commit
16+
base_commit = comparison.base_commit
17+
18+
base_report = report_service.get_existing_report_for_commit(
19+
base_commit, report_class=ReadOnlyReport
20+
)
21+
compare_report = report_service.get_existing_report_for_commit(
22+
compare_commit, report_class=ReadOnlyReport
23+
)
24+
# No access to the PR so we have to assume the base commit did not need
25+
# to be adjusted.
26+
patch_coverage_base_commitid = base_commit.commitid
27+
return ComparisonProxy(
28+
Comparison(
29+
head=FullCommit(commit=compare_commit, report=compare_report),
30+
project_coverage_base=FullCommit(commit=base_commit, report=base_report),
31+
patch_coverage_base_commitid=patch_coverage_base_commitid,
32+
enriched_pull=None,
33+
),
34+
context=ComparisonContext(
35+
gh_app_installation_name=report_service.gh_app_installation_name
36+
),
37+
)

tasks/compute_comparison.py

+34-42
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
import sentry_sdk
55
from asgiref.sync import async_to_sync
6+
from celery import group
67
from shared.celery_config import compute_comparison_task_name
78
from shared.components import Component
89
from shared.helpers.flag import Flag
9-
from shared.reports.readonly import ReadOnlyReport
1010
from shared.torngit.exceptions import TorngitRateLimitError
1111
from shared.yaml import UserYaml
1212

@@ -16,12 +16,14 @@
1616
from database.models.reports import ReportLevelTotals, RepositoryFlag
1717
from helpers.comparison import minimal_totals
1818
from helpers.github_installation import get_installation_name_for_owner_for_task
19+
from rollouts import PARALLEL_COMPONENT_COMPARISON
1920
from services.archive import ArchiveService
20-
from services.comparison import ComparisonContext, ComparisonProxy, FilteredComparison
21-
from services.comparison.types import Comparison, FullCommit
21+
from services.comparison import ComparisonProxy, FilteredComparison
22+
from services.comparison_utils import get_comparison_proxy
2223
from services.report import ReportService
2324
from services.yaml import get_current_yaml, get_repo_yaml
2425
from tasks.base import BaseCodecovTask
26+
from tasks.compute_component_comparison import compute_component_comparison_task
2527

2628
log = logging.getLogger(__name__)
2729

@@ -54,10 +56,11 @@ def run_impl(
5456
installation_name_to_use = get_installation_name_for_owner_for_task(
5557
self.name, repo.owner
5658
)
57-
58-
comparison_proxy = self.get_comparison_proxy(
59-
comparison, current_yaml, installation_name_to_use
59+
report_service = ReportService(
60+
current_yaml, gh_app_installation_name=installation_name_to_use
6061
)
62+
63+
comparison_proxy = get_comparison_proxy(comparison, report_service)
6164
if not comparison_proxy.has_head_report():
6265
comparison.error = CompareCommitError.missing_head_report.value
6366
comparison.state = CompareCommitState.error.value
@@ -241,10 +244,31 @@ def compute_component_comparisons(
241244
component_count=len(components),
242245
),
243246
)
244-
for component in components:
245-
self.compute_component_comparison(
246-
db_session, comparison, comparison_proxy, component
247-
)
247+
if PARALLEL_COMPONENT_COMPARISON.check_value(
248+
comparison.compare_commit.repoid, default=False
249+
):
250+
self.parallel_compute_component_comparison(comparison.id, components)
251+
else:
252+
for component in components:
253+
self.compute_component_comparison(
254+
db_session, comparison, comparison_proxy, component
255+
)
256+
257+
@sentry_sdk.trace
258+
def parallel_compute_component_comparison(
259+
self,
260+
comparison_id: int,
261+
components: list[Component],
262+
):
263+
task_group = group(
264+
[
265+
compute_component_comparison_task.s(
266+
comparison_id, component.component_id
267+
)
268+
for component in components
269+
]
270+
)
271+
task_group.apply_async()
248272

249273
def compute_component_comparison(
250274
self,
@@ -288,38 +312,6 @@ def compute_component_comparison(
288312
db_session.add(component_comparison)
289313
db_session.flush()
290314

291-
@sentry_sdk.trace
292-
def get_comparison_proxy(
293-
self, comparison, current_yaml, installation_name_to_use: str | None = None
294-
):
295-
compare_commit = comparison.compare_commit
296-
base_commit = comparison.base_commit
297-
report_service = ReportService(
298-
current_yaml, gh_app_installation_name=installation_name_to_use
299-
)
300-
base_report = report_service.get_existing_report_for_commit(
301-
base_commit, report_class=ReadOnlyReport
302-
)
303-
compare_report = report_service.get_existing_report_for_commit(
304-
compare_commit, report_class=ReadOnlyReport
305-
)
306-
# No access to the PR so we have to assume the base commit did not need
307-
# to be adjusted.
308-
patch_coverage_base_commitid = base_commit.commitid
309-
return ComparisonProxy(
310-
Comparison(
311-
head=FullCommit(commit=compare_commit, report=compare_report),
312-
project_coverage_base=FullCommit(
313-
commit=base_commit, report=base_report
314-
),
315-
patch_coverage_base_commitid=patch_coverage_base_commitid,
316-
enriched_pull=None,
317-
),
318-
context=ComparisonContext(
319-
gh_app_installation_name=installation_name_to_use
320-
),
321-
)
322-
323315
@sentry_sdk.trace
324316
def store_results(self, comparison, impacted_files):
325317
repository = comparison.compare_commit.repository

tasks/compute_component_comparison.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from asgiref.sync import async_to_sync
2+
from shared.components import Component
3+
from shared.yaml import UserYaml
4+
from sqlalchemy.orm import Session
5+
6+
from app import celery_app
7+
from database.models import CompareCommit, CompareComponent
8+
from helpers.github_installation import get_installation_name_for_owner_for_task
9+
from services.comparison import ComparisonProxy, FilteredComparison
10+
from services.comparison_utils import get_comparison_proxy
11+
from services.report import ReportService
12+
from services.yaml import get_current_yaml, get_repo_yaml
13+
from tasks.base import BaseCodecovTask
14+
15+
16+
def compute_component_comparison(
17+
db_session: Session,
18+
comparison: CompareCommit,
19+
comparison_proxy: ComparisonProxy,
20+
component: Component,
21+
):
22+
component_comparison = (
23+
db_session.query(CompareComponent)
24+
.filter_by(
25+
commit_comparison_id=comparison.id,
26+
component_id=component.component_id,
27+
)
28+
.first()
29+
)
30+
if not component_comparison:
31+
component_comparison = CompareComponent(
32+
commit_comparison=comparison,
33+
component_id=component.component_id,
34+
)
35+
36+
# filter comparison by component
37+
head_report = comparison_proxy.comparison.head.report
38+
flags = component.get_matching_flags(head_report.flags.keys())
39+
filtered: FilteredComparison = comparison_proxy.get_filtered_comparison(
40+
flags=flags, path_patterns=component.paths
41+
)
42+
43+
# component comparison totals
44+
component_comparison.base_totals = (
45+
filtered.project_coverage_base.report.totals.asdict()
46+
)
47+
component_comparison.head_totals = filtered.head.report.totals.asdict()
48+
diff = comparison_proxy.get_diff()
49+
if diff:
50+
patch_totals = filtered.head.report.apply_diff(diff)
51+
if patch_totals:
52+
component_comparison.patch_totals = patch_totals.asdict()
53+
54+
db_session.add(component_comparison)
55+
db_session.flush()
56+
57+
58+
class ComputeComponentComparisonTask(BaseCodecovTask):
59+
def run_impl(
60+
self,
61+
db_session: Session,
62+
comparison_id: int,
63+
component_id: str,
64+
*args,
65+
**kwargs,
66+
):
67+
comparison: CompareCommit = db_session.query(CompareCommit).get(comparison_id)
68+
repo = comparison.compare_commit.repository
69+
current_yaml = get_repo_yaml(repo)
70+
installation_name_to_use = get_installation_name_for_owner_for_task(
71+
self.name, repo.owner
72+
)
73+
report_service = ReportService(
74+
current_yaml, gh_app_installation_name=installation_name_to_use
75+
)
76+
comparison_proxy = get_comparison_proxy(comparison, report_service)
77+
head_commit = comparison_proxy.comparison.head.commit
78+
79+
yaml: UserYaml = async_to_sync(get_current_yaml)(
80+
head_commit, comparison_proxy.repository_service
81+
)
82+
83+
components = yaml.get_components()
84+
85+
component_dict = {c.component_id: c for c in components}
86+
compute_component_comparison(
87+
db_session, comparison, comparison_proxy, component_dict[component_id]
88+
)
89+
90+
91+
RegisteredComputeComponentComparisonTask = celery_app.register_task(
92+
ComputeComponentComparisonTask()
93+
)
94+
compute_component_comparison_task = celery_app.tasks[
95+
RegisteredComputeComponentComparisonTask.name
96+
]

0 commit comments

Comments
 (0)