Skip to content
189 changes: 177 additions & 12 deletions .automation_scripts/pytorch-unit-test-scripts/detect_log_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
RE_INDIVIDUAL_TEST = re.compile(
r"(?P<test_path>\S+\.py::(?P<cls>\w+)::(?P<method>\w+))"
)
# Matches an individual pytest-style result line such as
# "test/test_ops.py::TestOps::test_add PASSED" (the leading "test/" is
# optional), capturing the file, class, and method names.
RE_INDIV_PASSED = re.compile(
    r"(?:test/)?(?P<file>\S+\.py)::(?P<cls>\w+)::(?P<method>\S+?)\s+PASSED"
)
# Marker line presumably emitted by the CI harness after it re-runs a failing
# test alone in a fresh subprocess; used to flag such tests as flaky.
RE_NEW_PROCESS_SUCCESS = re.compile(r"Test succeeded in new process")

CRASH_PATTERNS = [
(re.compile(r"Segmentation fault", re.IGNORECASE), "SEGFAULT"),
Expand Down Expand Up @@ -78,11 +82,20 @@ def classify_log_file(filename):


def parse_log_file(filepath):
"""Parse a single log file and return test file results and consistent failures."""
"""Parse a single log file and return test file results, consistent failures,
and flaky tests.

A flaky test is one that failed in its normal-process run but PASSED when the
CI harness re-ran it alone in a new subprocess (indicated by a PASSED line
for the specific test::class::method, followed by 'Test succeeded in new
process, continuing with the rest of the tests').
"""
results = {}
current_test = None
last_failed_test = None
consistent_failures = []
flaky_tests = []
last_passed_individual = None

with open(filepath, "r", errors="replace") as f:
for line in f:
Expand Down Expand Up @@ -111,7 +124,9 @@ def parse_log_file(filepath):
and "Aborted (core dumped)" not in line \
and "OutOfMemoryError" not in line \
and "bad_alloc" not in line \
and "stepcurrent" not in line:
and "stepcurrent" not in line \
and "PASSED" not in line \
and "new process" not in line:
continue

stripped = RE_TIMESTAMP.sub("", line).rstrip()
Expand Down Expand Up @@ -195,20 +210,76 @@ def parse_log_file(filepath):

m = RE_FAILED_CONSISTENTLY.search(stripped)
if m:
consistent_failures.append(m.group("test_path"))
shard_str = ""
if active and active in results:
info = results[active]
shard_str = f"{info['shard']}/{info['total']}"
consistent_failures.append((m.group("test_path"), shard_str))

# Detect individual PASSED lines for flaky-rerun tracking.
m = RE_INDIV_PASSED.search(stripped)
if m:
last_passed_individual = {
"file": m.group("file"),
"cls": m.group("cls"),
"method": m.group("method"),
"active": active,
}

# When we see 'Test succeeded in new process' after a PASSED
# individual test, that test was originally failing in the main
# process (CI only falls back to rerun-in-new-process for tests
# that crashed or failed) but passed on retry -> flaky.
if RE_NEW_PROCESS_SUCCESS.search(stripped) and last_passed_individual:
lp = last_passed_individual
lp_active = lp.get("active")
test_shard = ""
if lp_active and lp_active in results:
info = results[lp_active]
test_shard = f"{info['shard']}/{info['total']}"
flaky_tests.append({
"file": lp["file"],
"cls": lp["cls"],
"method": lp["method"],
"test_shard": test_shard,
})
last_passed_individual = None

if active and active in results:
for pattern, label in CRASH_PATTERNS:
if pattern.search(stripped):
if label not in results[active]["crashes"]:
results[active]["crashes"].append(label)

return results, consistent_failures
return results, consistent_failures, flaky_tests


def scan_logs(logs_dir):
"""Scan all log files and return all non-passing test file results."""
"""Scan all log files and return non-passing test file results plus a
test-level shard inventory.

Returns (all_failures, shard_inventory) where shard_inventory is a list
of dicts with one entry per (platform, test_config, job_shard, test_file)
combination seen in the logs, plus a sorted comma-separated list of the
test-level shards observed (e.g. "1/1" or "1/15,2/15,...,15/15"). This
lets downstream consumers look up the test-level shard for any XML-based
failure whose only shard info is the job-level shard."""
all_failures = []
all_flaky = []
shard_map = defaultdict(set)

# Pre-compute job-level shard totals per (platform, test_config) by
# counting how many log files belong to each group. Log files are
# 1-indexed (e.g. rocm1.txt..rocm6.txt for a 6-way sharded job), so
# the count == total shards for that CI job.
shard_totals = defaultdict(int)
for fname in os.listdir(logs_dir):
if not fname.endswith(".txt"):
continue
platform, test_config, shard_num = classify_log_file(fname)
if platform is None:
continue
shard_totals[(platform, test_config)] += 1

for fname in sorted(os.listdir(logs_dir)):
if not fname.endswith(".txt"):
Expand All @@ -218,8 +289,31 @@ def scan_logs(logs_dir):
if platform is None:
continue

job_total = shard_totals.get((platform, test_config), 0)
job_shard_str = f"{shard_num}/{job_total}" if job_total else str(shard_num)

filepath = os.path.join(logs_dir, fname)
results, consistent_failures = parse_log_file(filepath)
results, consistent_failures, flaky_tests = parse_log_file(filepath)

for ft in flaky_tests:
file_part = ft["file"].replace("test/", "").replace(".py", "")
all_flaky.append({
"log_file": fname,
"platform": platform,
"test_config": test_config,
"test_file": file_part,
"test_class": ft["cls"],
"test_name": ft["method"],
"job_shard": job_shard_str,
"test_shard": ft["test_shard"],
})

# Record every (test_file, test_shard) observed in this log file,
# including PASSED ones, so the inventory covers the full run.
for info in results.values():
shard_map[(platform, test_config, job_shard_str, info["test_file"])].add(
f"{info['shard']}/{info['total']}"
)

for key, info in results.items():
if info["status"] == "PASSED":
Expand Down Expand Up @@ -265,14 +359,15 @@ def scan_logs(logs_dir):
"platform": platform,
"test_config": test_config,
"test_file": info["test_file"],
"shard": f"{info['shard']}/{info['total']}",
"job_shard": job_shard_str,
"test_shard": f"{info['shard']}/{info['total']}",
"status": info["status"],
"category": "+".join(categories),
"reason": reason,
"exit_codes": ",".join(str(c) for c in info["exit_codes"]),
})

for test_path in consistent_failures:
for test_path, shard_str in consistent_failures:
parts = test_path.split("::")
file_part = parts[0].replace("test/", "").replace(".py", "")
test_class = parts[1] if len(parts) > 1 else ""
Expand All @@ -283,19 +378,47 @@ def scan_logs(logs_dir):
"platform": platform,
"test_config": test_config,
"test_file": file_part,
"shard": "",
"job_shard": job_shard_str,
"test_shard": shard_str,
"status": "FAILED_CONSISTENTLY",
"category": "CONSISTENT_FAILURE",
"reason": f"{test_class}::{test_name}" if test_class else "",
"exit_codes": "",
})

return all_failures
def _sort_shards(vals):
def key(v):
try:
a, b = v.split("/", 1)
return (int(b), int(a))
except (ValueError, AttributeError):
return (0, 0)
return sorted(vals, key=key)

shard_inventory = [
{
"platform": platform,
"test_config": test_config,
"job_shard": job_shard_str,
"test_file": test_file,
"test_shards": ",".join(_sort_shards(shards)),
}
for (platform, test_config, job_shard_str, test_file), shards in shard_map.items()
]
shard_inventory.sort(key=lambda r: (r["platform"], r["test_config"],
r["job_shard"], r["test_file"]))

all_flaky.sort(key=lambda r: (r["platform"], r["test_config"],
r["job_shard"], r["test_file"],
r["test_class"], r["test_name"]))

return all_failures, shard_inventory, all_flaky


def write_csv_report(failures, output_path):
fieldnames = [
"log_file", "platform", "test_config", "test_file", "shard",
"log_file", "platform", "test_config", "test_file",
"job_shard", "test_shard",
"status", "category", "reason", "exit_codes",
]
with open(output_path, "w", newline="") as f:
Expand All @@ -305,6 +428,46 @@ def write_csv_report(failures, output_path):
print(f"Log failure report: {output_path} ({len(failures)} entries)")


def write_shards_report(inventory, output_path):
    """Write the test-shard inventory to a CSV file at *output_path*.

    Each row of *inventory* is a dict keyed by platform, test_config,
    job_shard, test_file and test_shards; rows are written in the order
    given, preceded by a header row.
    """
    columns = [
        "platform",
        "test_config",
        "job_shard",
        "test_file",
        "test_shards",
    ]
    with open(output_path, "w", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=columns)
        writer.writeheader()
        for row in inventory:
            writer.writerow(row)
    print(f"Log shard inventory: {output_path} ({len(inventory)} entries)")


def write_flaky_report(flaky, output_path):
    """Write the flaky-test report to a CSV file at *output_path*.

    Each row of *flaky* is a dict describing one flaky test (log file,
    platform/config, test identity, and both shard strings); rows are
    written in the order given, preceded by a header row.
    """
    columns = (
        "log_file", "platform", "test_config", "test_file",
        "test_class", "test_name", "job_shard", "test_shard",
    )
    with open(output_path, "w", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=columns)
        writer.writeheader()
        for entry in flaky:
            writer.writerow(entry)
    print(f"Flaky test report: {output_path} ({len(flaky)} entries)")


def _derive_sibling_path(output_path, new_prefix):
    """Build a sibling CSV path by swapping the ``log_failures`` prefix.

    For a path like ``.../log_failures_mi355.csv`` and
    ``new_prefix='log_shards'``, return ``.../log_shards_mi355.csv``.
    When the filename does not start with ``log_failures``, fall back to
    inserting ``.{new_prefix}`` before the extension (using ``.csv`` when
    the path has no extension at all).
    """
    directory, filename = os.path.split(output_path)
    prefix = "log_failures"
    if filename.startswith(prefix):
        sibling = new_prefix + filename[len(prefix):]
        return os.path.join(directory, sibling)
    stem, ext = os.path.splitext(filename)
    return os.path.join(directory, f"{stem}.{new_prefix}{ext or '.csv'}")


def _derive_shards_path(output_path):
    """Return the shard-inventory CSV path sibling to *output_path*
    (e.g. ``.../log_failures_x.csv`` -> ``.../log_shards_x.csv``)."""
    return _derive_sibling_path(output_path, "log_shards")


def _derive_flaky_path(output_path):
    """Return the flaky-test CSV path sibling to *output_path*
    (e.g. ``.../log_failures_x.csv`` -> ``.../flaky_tests_x.csv``)."""
    return _derive_sibling_path(output_path, "flaky_tests")


def print_summary(failures):
if not failures:
print("No log-based failures detected.")
Expand Down Expand Up @@ -343,9 +506,11 @@ def main():
)
args = parser.parse_args()

failures = scan_logs(args.logs_dir)
failures, shard_inventory, flaky_tests = scan_logs(args.logs_dir)
print_summary(failures)
write_csv_report(failures, args.output)
write_shards_report(shard_inventory, _derive_shards_path(args.output))
write_flaky_report(flaky_tests, _derive_flaky_path(args.output))
return 0 if not failures else 1


Expand Down
Loading