Commit 38c46f2

[autorevert] Add support for job and test filtering in workflow restarts (#7595)
Adds support for more granular dispatches (test- and job-level filters, see pytorch/pytorch#168201) to autorevert.

- Job/test filtering: when restarting workflows, only re-run the specific failed jobs and tests instead of the entire workflow (uses `workflow_dispatch` inputs when supported)
- Workflow resolver: parses workflow YAML files to detect which inputs are available for filtering
- New CLI subcommand: `restart-workflow` for manually triggering filtered workflow restarts
- Fallback behavior: workflows without input support (e.g., inductor) fall back to a full workflow restart

## Testing

Links lead to the runs per commit; the runs were issued from my account as results of local testing.

**0. Manual dispatch testing:**

```
python -m pytorch_auto_revert restart-workflow pull 4816fd912210162bea4cdf34f7a39d2909477549 --jobs "linux-jammy-py3.10-gcc11" --tests "distributed/test_functional_differentials"
```

Runs: https://github.com/pytorch/pytorch/actions/workflows/pull.yml?query=branch%3Atrunk%2F4816fd912210162bea4cdf34f7a39d2909477549

**1. Granular restart on trunk:**

```
python -m pytorch_auto_revert autorevert-checker pull --hours 12 --as-of "2025-12-18 06:25" --hud-html
```

Log: P2090410466

Runs (filtered by job and test):
- https://github.com/pytorch/pytorch/actions/workflows/pull.yml?query=branch%3Atrunk%2F9fe21ba6d0583790c1857485ede8e17c89ab9afd
- https://github.com/pytorch/pytorch/actions/workflows/pull.yml?query=branch%3Atrunk%2F3fc6a055e09174135cd839e723c4f0bdab9589b3

**2. Many restarts:**

```
python -m pytorch_auto_revert --dry-run autorevert-checker pull --hours 18 --hud-html --as-of "2025-12-19 22:00"
```

Log: P2090444414

Runs:
- https://github.com/pytorch/pytorch/actions/workflows/pull.yml?query=branch%3Atrunk%2Feafa4f67d2afdca606eebbca50571b0ba1ab922b
- https://github.com/pytorch/pytorch/actions/workflows/pull.yml?query=branch%3Atrunk%2F96b3e7d78914f5db043e8b9ae3b3f72498abca4e
- https://github.com/pytorch/pytorch/actions/workflows/pull.yml?query=branch%3Atrunk%2F7d49bd5060925055724d8976794cc1fd328066aa

**3. Workflow without input support (inductor):**

```
python -m pytorch_auto_revert autorevert-checker inductor --hours 64 --hud-html --as-of "2025-12-18 17:31"
```

Log: P2090462639

Run: https://github.com/pytorch/pytorch/actions/workflows/inductor.yml?query=branch%3Atrunk%2Fa79fbc97065538f756418e6e3bde02a708e893b5
1 parent a01969d commit 38c46f2

File tree

11 files changed, +546 −13 lines changed

Lines changed: 63 additions & 0 deletions

# Workflow Dispatch Filters

## Overview

PyTorch CI workflows (`trunk.yml`, `pull.yml`) support optional filtering inputs for `workflow_dispatch` events. This allows autorevert to re-run only specific failed jobs and tests instead of the full CI suite.

## Workflow Dispatch Inputs

| Input | Type | Description |
|-------|------|-------------|
| `jobs-to-include` | string | Space-separated list of job display names to run (empty = all jobs) |
| `tests-to-include` | string | Space-separated list of test modules to run (empty = all tests) |
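For orientation, here is a minimal sketch of issuing such a filtered dispatch directly with PyGithub. Autorevert goes through its own restart client rather than this call path; the token and filter values below are illustrative, and the `trunk/<sha>` ref pattern matches the restart branches visible in the testing links above.

```python
from github import Github

gh = Github("<token>")  # placeholder token
workflow = gh.get_repo("pytorch/pytorch").get_workflow("pull.yml")

# Empty input strings would mean "no filter", i.e. run all jobs/tests.
workflow.create_dispatch(
    ref="trunk/4816fd912210162bea4cdf34f7a39d2909477549",
    inputs={
        "jobs-to-include": "linux-jammy-py3.10-gcc11",
        "tests-to-include": "distributed/test_c10d",
    },
)
```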
## Filter Value Derivation

Filter values are derived from Signal metadata during signal extraction.

### Job Names (`jobs-to-include`)

Derived from `Signal.job_base_name`. Job names follow two patterns:

| Pattern | Example | Filter Value |
|---------|---------|--------------|
| With ` / ` separator | `linux-jammy-cuda12.8-py3.10-gcc11 / test` | `linux-jammy-cuda12.8-py3.10-gcc11` |
| Without separator | `inductor-build` | `inductor-build` |

**More examples:**

- `linux-jammy-cuda12.8-py3.10-gcc11 / build` → `linux-jammy-cuda12.8-py3.10-gcc11`
- `linux-jammy-py3.10-gcc11` → `linux-jammy-py3.10-gcc11`
- `job-filter` → `job-filter`
- `get-label-type` → `get-label-type`
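This rule is implemented by the `_derive_job_filter` helper added to `signal_actions.py` in this commit (shown further down); a condensed, behaviorally equivalent sketch:

```python
def derive_job_filter(job_base_name: str | None) -> str | None:
    # "linux-jammy-cuda12.8-py3.10-gcc11 / test" -> "linux-jammy-cuda12.8-py3.10-gcc11"
    # "inductor-build"                           -> "inductor-build"
    if not job_base_name:
        return None
    # split(" / ")[0] is a no-op when the separator is absent
    return job_base_name.split(" / ")[0].strip()
```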
### Test Modules (`tests-to-include`)

Derived from `Signal.test_module` (set during signal extraction from the test file path, without the `.py` extension).

**Examples:**

- `test_torch`
- `test_nn`
- `distributed/elastic/multiprocessing/api_test`
- `distributed/test_c10d`
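This mirrors the derivation added in `signal_extraction.py` (see the diff below), which strips the `::` test selector and the `.py` extension from the test id; a condensed sketch:

```python
def derive_test_module(test_id: str) -> str:
    # "distributed/test_c10d.py::test_allreduce" -> "distributed/test_c10d"
    return test_id.split("::")[0].replace(".py", "")
```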
## Input Format Rules

### `jobs-to-include`

- Space-separated exact job **display names**
- Case-sensitive, must match exactly
- Examples:
  - Build/test jobs: `"linux-jammy-cuda12.8-py3.10-gcc11 linux-jammy-py3.10-gcc11"`
  - Standalone jobs: `"inductor-build job-filter get-label-type"`

### `tests-to-include`

- Space-separated test module paths (no `.py` extension)
- Module-level only (no `::TestClass::test_method`)
- Example: `"test_torch test_nn distributed/elastic/multiprocessing/api_test"`
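The exact serialization into dispatch inputs lives in the restart helper, which is not part of this diff; the following is only a sketch of what it plausibly looks like given the format rules above, with a hypothetical function name:

```python
def build_dispatch_inputs(
    jobs_to_include: frozenset[str], tests_to_include: frozenset[str]
) -> dict[str, str]:
    # Empty strings mean "no filter" (run everything), per the rules above.
    return {
        "jobs-to-include": " ".join(sorted(jobs_to_include)),
        "tests-to-include": " ".join(sorted(tests_to_include)),
    }


build_dispatch_inputs(
    frozenset({"linux-jammy-py3.10-gcc11"}),
    frozenset({"test_torch", "distributed/test_c10d"}),
)
# -> {'jobs-to-include': 'linux-jammy-py3.10-gcc11',
#     'tests-to-include': 'distributed/test_c10d test_torch'}
```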
## Behavior Notes

1. **Empty inputs** = run all jobs/tests (normal CI behavior)
2. **Filtered dispatch** = only matching jobs run; within those jobs, only matching tests run
3. **Test sharding** is preserved - distributed tests still run on distributed shards
4. **TD compatibility** - TD is disabled for filtered test runs; only the specified tests run
5. **Workflow support detection** - autorevert parses the workflow YAML to check whether the inputs are supported before dispatching (see the sketch below)
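The workflow resolver that performs the support detection in note 5 is not part of the excerpted diff; a minimal sketch of the idea, assuming PyYAML and using hypothetical names:

```python
import yaml

FILTER_INPUTS = {"jobs-to-include", "tests-to-include"}


def supports_filtered_dispatch(workflow_yaml_text: str) -> bool:
    """Return True if the workflow declares both filter inputs under workflow_dispatch."""
    doc = yaml.safe_load(workflow_yaml_text)
    # PyYAML parses the bare `on:` key as the boolean True in workflow files.
    triggers = doc.get("on") or doc.get(True) or {}
    dispatch = triggers.get("workflow_dispatch") or {}
    inputs = dispatch.get("inputs") or {}
    return FILTER_INPUTS.issubset(inputs)
```

Workflows like `inductor.yml` that do not declare these inputs fall back to a full restart, as in test 3 of the commit message.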

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py

Lines changed: 43 additions & 2 deletions

```diff
@@ -29,7 +29,7 @@
 from .github_client_helper import GHClientFactory
 from .testers.autorevert_v2 import autorevert_v2
 from .testers.hud import render_hud_html_from_clickhouse, write_hud_html_from_cli
-from .testers.restart_checker import workflow_restart_checker
+from .testers.restart_checker import dispatch_workflow_restart, workflow_restart_checker
 from .utils import parse_datetime, RestartAction, RetryWithBackoff, RevertAction
 
 
@@ -356,6 +356,35 @@ def get_opts(default_config: DefaultConfig) -> argparse.Namespace:
         help="If no `--commit` specified, look back days for bulk query (default: 7)",
     )
 
+    # restart-workflow subcommand: dispatch a workflow restart with optional filters
+    restart_workflow_parser = subparsers.add_parser(
+        "restart-workflow",
+        help="Dispatch a workflow restart with optional job/test filters",
+    )
+    restart_workflow_parser.add_argument(
+        "workflow",
+        help="Workflow name (e.g., trunk or trunk.yml)",
+    )
+    restart_workflow_parser.add_argument(
+        "commit",
+        help="Commit SHA to restart",
+    )
+    restart_workflow_parser.add_argument(
+        "--jobs",
+        default=None,
+        help="Space-separated job display names to filter (e.g., 'linux-jammy-cuda12.8-py3.10-gcc11')",
+    )
+    restart_workflow_parser.add_argument(
+        "--tests",
+        default=None,
+        help="Space-separated test module paths to filter (e.g., 'test_torch distributed/test_c10d')",
+    )
+    restart_workflow_parser.add_argument(
+        "--repo-full-name",
+        default=default_config.repo_full_name,
+        help="Repository in owner/repo format (default: pytorch/pytorch)",
+    )
+
     # hud subcommand: generate local HTML report for signals/detections
     hud_parser = subparsers.add_parser(
         "hud", help="Render HUD HTML from a logged autorevert run state"
@@ -437,10 +466,13 @@ def _get(attr: str, default=None):
         log_level=_get("log_level", DEFAULT_LOG_LEVEL),
         dry_run=_get("dry_run", False),
         subcommand=_get("subcommand", "autorevert-checker"),
-        # Subcommand: workflow-restart-checker
+        # Subcommand: workflow-restart-checker and restart-workflow
         workflow=_get("workflow", None),
         commit=_get("commit", None),
         days=_get("days", DEFAULT_WORKFLOW_RESTART_DAYS),
+        # Subcommand: restart-workflow (filter inputs)
+        jobs=_get("jobs", None),
+        tests=_get("tests", None),
         # Subcommand: hud
         timestamp=_get("timestamp", None),
         hud_html=_get("hud_html", None),
@@ -693,6 +725,15 @@ def main_run(
         workflow_restart_checker(
             config.workflow, commit=config.commit, days=config.days
         )
+    elif config.subcommand == "restart-workflow":
+        dispatch_workflow_restart(
+            workflow=config.workflow,
+            commit=config.commit,
+            jobs=config.jobs,
+            tests=config.tests,
+            repo=config.repo_full_name,
+            dry_run=config.dry_run,
+        )
     elif config.subcommand == "hud":
         out_path: Optional[str] = (
             None if config.hud_html is HUD_HTML_NO_VALUE_FLAG else config.hud_html
```
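For reference, the manual invocation from test 0 in the commit message maps onto the new code path roughly like this (same keyword arguments as the `main_run` branch above; the `dry_run` value is illustrative):

```python
from pytorch_auto_revert.testers.restart_checker import dispatch_workflow_restart

dispatch_workflow_restart(
    workflow="pull",
    commit="4816fd912210162bea4cdf34f7a39d2909477549",
    jobs="linux-jammy-py3.10-gcc11",
    tests="distributed/test_functional_differentials",
    repo="pytorch/pytorch",
    dry_run=True,  # set to False to actually dispatch
)
```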

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/config.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -82,12 +82,18 @@ class AutorevertConfig:
     subcommand: str = "autorevert-checker"
 
     # -------------------------------------------------------------------------
-    # Subcommand: workflow-restart-checker
+    # Subcommand: workflow-restart-checker and restart-workflow
     # -------------------------------------------------------------------------
     workflow: Optional[str] = None
     commit: Optional[str] = None
     days: int = DEFAULT_WORKFLOW_RESTART_DAYS
 
+    # -------------------------------------------------------------------------
+    # Subcommand: restart-workflow (filter inputs)
+    # -------------------------------------------------------------------------
+    jobs: Optional[str] = None  # Space-separated job display names
+    tests: Optional[str] = None  # Space-separated test module paths
+
     # -------------------------------------------------------------------------
     # Subcommand: hud
     # -------------------------------------------------------------------------
```

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -288,6 +288,7 @@ class Signal:
     - workflow_name: source workflow this signal is derived from
     - commits: newest → older list of SignalCommit objects for this signal
     - job_base_name: optional job base name for job-level signals (recorded when signal is created)
+    - test_module: optional test module path for test-level signals (e.g., "test_torch" or "distributed/test_c10d")
     """
 
     def __init__(
@@ -296,13 +297,16 @@ def __init__(
         workflow_name: str,
         commits: List[SignalCommit],
         job_base_name: Optional[str] = None,
+        test_module: Optional[str] = None,
         source: SignalSource = SignalSource.TEST,
     ):
         self.key = key
         self.workflow_name = workflow_name
         # commits are ordered from newest to oldest
         self.commits = commits
         self.job_base_name = job_base_name
+        # Test module path without .py extension (e.g., "test_torch", "distributed/test_c10d")
+        self.test_module = test_module
         # Track the origin of the signal (test-track or job-track).
         self.source = source
```
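With the new parameter, a test-track signal carrying its module might be constructed like this (values illustrative; `commits` is normally the newest-to-oldest list of `SignalCommit` objects built during extraction):

```python
signal = Signal(
    key="distributed/test_c10d.py::test_allreduce",
    workflow_name="pull",
    commits=[],  # illustrative; normally populated during extraction
    job_base_name="linux-jammy-cuda12.8-py3.10-gcc11 / test",
    test_module="distributed/test_c10d",
    source=SignalSource.TEST,
)
```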

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py

Lines changed: 56 additions & 4 deletions

```diff
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from enum import Enum
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple, Union
 
 import github
 
@@ -40,10 +40,33 @@ class SignalMetadata:
     workflow_name: str
     key: str
     job_base_name: Optional[str] = None
+    test_module: Optional[str] = None
     wf_run_id: Optional[int] = None
     job_id: Optional[int] = None
 
 
+def _derive_job_filter(job_base_name: Optional[str]) -> Optional[str]:
+    """Extract job display name for jobs-to-include filter.
+
+    For jobs with " / " separator (e.g., "linux-jammy-cuda12.8 / test"),
+    returns the prefix before the separator.
+
+    For jobs without separator (e.g., "linux-jammy-py3.10-gcc11", "inductor-build"),
+    returns the full job_base_name as the display name.
+
+    Examples:
+        "linux-jammy-cuda12.8 / test" -> "linux-jammy-cuda12.8"
+        "linux-jammy-py3.10-gcc11" -> "linux-jammy-py3.10-gcc11"
+        "inductor-build" -> "inductor-build"
+        "job-filter" -> "job-filter"
+    """
+    if not job_base_name:
+        return None
+    if " / " in job_base_name:
+        return job_base_name.split(" / ")[0].strip()
+    return job_base_name.strip()
+
+
 @dataclass(frozen=True)
 class ActionGroup:
     """A coalesced action candidate built from one or more signals.
@@ -52,12 +75,16 @@ class ActionGroup:
     - commit_sha: target commit
     - workflow_target: workflow to restart (restart only); None/'' for revert
     - sources: contributing signals (workflow_name, key, outcome)
+    - jobs_to_include: job display names to filter for restart (empty = all jobs)
+    - tests_to_include: test module paths to filter for restart (empty = all tests)
     """
 
     type: str  # 'revert' | 'restart'
     commit_sha: str
     workflow_target: str | None  # restart-only; None/'' for revert
     sources: List[SignalMetadata]
+    jobs_to_include: FrozenSet[str] = frozenset()
+    tests_to_include: FrozenSet[str] = frozenset()
 
 
 class ActionLogger:
@@ -229,6 +256,7 @@ def group_actions(
                 workflow_name=sig.workflow_name,
                 key=sig.key,
                 job_base_name=sig.job_base_name,
+                test_module=sig.test_module,
                 wf_run_id=wf_run_id,
                 job_id=job_id,
             )
@@ -251,12 +279,18 @@ def group_actions(
             )
         )
     for (wf, sha), sources in restart_map.items():
+        jobs = [_derive_job_filter(src.job_base_name) for src in sources]
+
         groups.append(
             ActionGroup(
                 type="restart",
                 commit_sha=sha,
                 workflow_target=wf,
                 sources=sources,
+                jobs_to_include=frozenset(j for j in jobs if j is not None),
+                tests_to_include=frozenset(
+                    src.test_module for src in sources if src.test_module
+                ),
             )
         )
     return groups
@@ -279,6 +313,8 @@ def execute(self, group: ActionGroup, ctx: RunContext) -> bool:
                 commit_sha=group.commit_sha,
                 sources=group.sources,
                 ctx=ctx,
+                jobs_to_include=group.jobs_to_include,
+                tests_to_include=group.tests_to_include,
             )
         return False
 
@@ -330,6 +366,8 @@ def execute_restart(
         commit_sha: str,
         sources: List[SignalMetadata],
         ctx: RunContext,
+        jobs_to_include: FrozenSet[str] = frozenset(),
+        tests_to_include: FrozenSet[str] = frozenset(),
     ) -> bool:
         """Dispatch a workflow restart subject to pacing, cap, and backoff; always logs the event."""
         if ctx.restart_action == RestartAction.SKIP:
@@ -374,18 +412,32 @@ def execute_restart(
             )
             return False
 
-        notes = ""
+        # Build notes incrementally
+        notes_parts: list[str] = []
+        if jobs_to_include:
+            notes_parts.append(f"jobs_filter={','.join(jobs_to_include)}")
+        if tests_to_include:
+            notes_parts.append(f"tests_filter={','.join(tests_to_include)}")
+
         ok = True
         if not dry_run:
             try:
-                self._restart.restart_workflow(workflow_target, commit_sha)
+                self._restart.restart_workflow(
+                    workflow_target,
+                    commit_sha,
+                    jobs_to_include=jobs_to_include,
+                    tests_to_include=tests_to_include,
+                )
             except Exception as exc:
                 ok = False
-                notes = str(exc) or repr(exc)
+                notes_parts.append(str(exc) or repr(exc))
                 logging.exception(
                     "[v2][action] restart for sha %s: exception while dispatching",
                     commit_sha[:8],
                 )
+
+        notes = "; ".join(notes_parts)
+
         self._logger.insert_event(
             repo=ctx.repo_full_name,
             ts=ctx.ts,
```
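To make the grouping concrete, here is what a coalesced restart group with filters looks like when built by hand from two illustrative signals (values are made up; `group_actions` derives the frozensets exactly as in the hunk above):

```python
sources = [
    SignalMetadata(
        workflow_name="pull",
        key="distributed/test_c10d.py::test_allreduce",
        job_base_name="linux-jammy-cuda12.8-py3.10-gcc11 / test",
        test_module="distributed/test_c10d",
    ),
    SignalMetadata(
        workflow_name="pull",
        key="inductor-build",
        job_base_name="inductor-build",
    ),
]

group = ActionGroup(
    type="restart",
    commit_sha="4816fd912210162bea4cdf34f7a39d2909477549",
    workflow_target="pull",
    sources=sources,
    jobs_to_include=frozenset({"linux-jammy-cuda12.8-py3.10-gcc11", "inductor-build"}),
    tests_to_include=frozenset({"distributed/test_c10d"}),
)
```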

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -133,6 +133,7 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
                     workflow_name=s.workflow_name,
                     commits=new_commits,
                     job_base_name=s.job_base_name,
+                    test_module=s.test_module,
                     source=s.source,
                 )
             )
@@ -218,6 +219,7 @@ def _inject_pending_workflow_events(
                     workflow_name=s.workflow_name,
                     commits=new_commits,
                     job_base_name=s.job_base_name,
+                    test_module=s.test_module,
                     source=s.source,
                 )
             )
@@ -434,12 +436,17 @@ def _build_test_signals(
             )
 
             if has_any_events:
+                # Extract test module from test_id (format: "file.py::test_name")
+                # Result: "file" or "path/to/file" without .py extension
+                test_module = test_id.split("::")[0].replace(".py", "")
+
                 signals.append(
                     Signal(
                         key=test_id,
                         workflow_name=wf_name,
                         commits=commit_objs,
                         job_base_name=str(job_base_name),
+                        test_module=test_module,
                         source=SignalSource.TEST,
                     )
                 )
```
