Skip to content

Commit 665b8bb

Browse files
committed
chore(scripts): add branch restack helper
1 parent 1f06142 commit 665b8bb

File tree

1 file changed

+216
-0
lines changed

1 file changed

+216
-0
lines changed

scripts/restack_branch_refs.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import argparse
6+
import subprocess
7+
import sys
8+
from collections import defaultdict
9+
from dataclasses import dataclass
10+
11+
12+
def git(*args: str, check: bool = True) -> str:
13+
result = subprocess.run(
14+
["git", *args],
15+
check=False,
16+
capture_output=True,
17+
text=True,
18+
)
19+
if check and result.returncode != 0:
20+
stderr = result.stderr.strip()
21+
raise SystemExit(stderr or f"git {' '.join(args)} failed")
22+
return result.stdout.strip()
23+
24+
25+
def git_ok(*args: str) -> bool:
26+
result = subprocess.run(
27+
["git", *args],
28+
check=False,
29+
capture_output=True,
30+
text=True,
31+
)
32+
return result.returncode == 0
33+
34+
35+
@dataclass(frozen=True)
36+
class BranchMove:
37+
branch: str
38+
old_sha: str
39+
new_sha: str
40+
distance_from_tip: int
41+
42+
43+
def parse_args() -> argparse.Namespace:
44+
parser = argparse.ArgumentParser(
45+
description=(
46+
"Restack local branch refs after rebasing the tip branch of a linear stack. "
47+
"By default this only prints the inferred mapping."
48+
)
49+
)
50+
parser.add_argument(
51+
"--base",
52+
default="main",
53+
help="Base branch name. Defaults to 'main'.",
54+
)
55+
parser.add_argument(
56+
"--tip",
57+
default=None,
58+
help="Tip branch name. Defaults to the currently checked out branch.",
59+
)
60+
parser.add_argument(
61+
"--old-tip",
62+
default=None,
63+
help="Old pre-rebase tip revision. Defaults to '<tip>@{1}'.",
64+
)
65+
parser.add_argument(
66+
"--remote",
67+
default="origin",
68+
help="Remote used by --push. Defaults to 'origin'.",
69+
)
70+
parser.add_argument(
71+
"--apply",
72+
action="store_true",
73+
help="Move local branch refs instead of only printing the plan.",
74+
)
75+
parser.add_argument(
76+
"--push",
77+
action="store_true",
78+
help="Force-push the updated branch refs after applying them.",
79+
)
80+
return parser.parse_args()
81+
82+
83+
def current_branch() -> str:
84+
branch = git("symbolic-ref", "--quiet", "--short", "HEAD")
85+
if not branch:
86+
raise SystemExit("Detached HEAD is not supported for auto-detecting the tip branch.")
87+
return branch
88+
89+
90+
def first_parent_chain(tip_sha: str) -> list[str]:
91+
chain = git("rev-list", "--first-parent", tip_sha).splitlines()
92+
if not chain:
93+
raise SystemExit(f"Could not read first-parent history for {tip_sha}.")
94+
return chain
95+
96+
97+
def branch_refs() -> list[tuple[str, str]]:
98+
rows = git("for-each-ref", "refs/heads", "--format=%(refname:short)\t%(objectname)").splitlines()
99+
refs: list[tuple[str, str]] = []
100+
for row in rows:
101+
branch, sha = row.split("\t", 1)
102+
refs.append((branch, sha))
103+
return refs
104+
105+
106+
def infer_moves(base: str, tip: str, old_tip_rev: str) -> tuple[list[BranchMove], str, str]:
107+
new_tip_sha = git("rev-parse", tip)
108+
old_tip_sha = git("rev-parse", old_tip_rev)
109+
old_chain = first_parent_chain(old_tip_sha)
110+
old_chain_set = set(old_chain)
111+
112+
candidates: list[tuple[str, str, int]] = []
113+
for branch, sha in branch_refs():
114+
if branch in {base, tip}:
115+
continue
116+
if sha not in old_chain_set:
117+
continue
118+
if not git_ok("merge-base", "--is-ancestor", sha, old_tip_sha):
119+
continue
120+
distance = int(git("rev-list", "--first-parent", "--count", f"{sha}..{old_tip_sha}"))
121+
candidates.append((branch, sha, distance))
122+
123+
if not candidates:
124+
raise SystemExit(
125+
f"No stacked branch refs found on the old {tip} first-parent chain ({old_tip_sha[:12]})."
126+
)
127+
128+
groups: dict[str, list[tuple[str, int]]] = defaultdict(list)
129+
for branch, sha, distance in candidates:
130+
groups[sha].append((branch, distance))
131+
132+
ordered_groups = sorted(
133+
groups.items(),
134+
key=lambda item: item[1][0][1],
135+
reverse=True,
136+
)
137+
138+
max_distance = ordered_groups[0][1][0][1]
139+
new_chain = first_parent_chain(new_tip_sha)
140+
if max_distance >= len(new_chain):
141+
raise SystemExit(
142+
f"New tip {new_tip_sha[:12]} does not have enough first-parent history for distance {max_distance}."
143+
)
144+
145+
moves: list[BranchMove] = []
146+
previous_distance: int | None = None
147+
for old_sha, branch_entries in ordered_groups:
148+
distance = branch_entries[0][1]
149+
if previous_distance is not None and distance >= previous_distance:
150+
raise SystemExit("Inferred branch stack is not strictly ordered by ancestry.")
151+
previous_distance = distance
152+
new_sha = new_chain[distance]
153+
for branch, _ in sorted(branch_entries):
154+
if branch == tip and old_sha == old_tip_sha:
155+
continue
156+
if git("rev-parse", branch) == new_sha:
157+
continue
158+
moves.append(
159+
BranchMove(
160+
branch=branch,
161+
old_sha=old_sha,
162+
new_sha=new_sha,
163+
distance_from_tip=distance,
164+
)
165+
)
166+
167+
if not moves:
168+
raise SystemExit("All inferred branch refs already point at the rebased commits.")
169+
170+
return moves, old_tip_sha, new_tip_sha
171+
172+
173+
def print_plan(moves: list[BranchMove], old_tip_sha: str, new_tip_sha: str, tip: str) -> None:
174+
print(f"Old {tip} tip: {old_tip_sha[:12]}")
175+
print(f"New {tip} tip: {new_tip_sha[:12]}")
176+
print("Planned branch updates:")
177+
for move in moves:
178+
print(
179+
f" {move.branch}: {move.old_sha[:12]} -> {move.new_sha[:12]}"
180+
f" (distance {move.distance_from_tip} from {tip})"
181+
)
182+
183+
184+
def apply_moves(moves: list[BranchMove]) -> None:
185+
for move in moves:
186+
git("branch", "-f", move.branch, move.new_sha)
187+
188+
189+
def push_moves(moves: list[BranchMove], remote: str) -> None:
190+
branches = [move.branch for move in moves]
191+
git("push", "--force-with-lease", remote, *branches)
192+
193+
194+
def main() -> int:
195+
args = parse_args()
196+
tip = args.tip or current_branch()
197+
old_tip_rev = args.old_tip or f"{tip}@{{1}}"
198+
moves, old_tip_sha, new_tip_sha = infer_moves(args.base, tip, old_tip_rev)
199+
200+
print_plan(moves, old_tip_sha, new_tip_sha, tip)
201+
if not args.apply and not args.push:
202+
print("\nDry run only. Re-run with --apply to move refs, and add --push to update the remote.")
203+
return 0
204+
205+
apply_moves(moves)
206+
print("\nUpdated local refs.")
207+
208+
if args.push:
209+
push_moves(moves, args.remote)
210+
print(f"Pushed updated refs to {args.remote}.")
211+
212+
return 0
213+
214+
215+
if __name__ == "__main__":
216+
sys.exit(main())

0 commit comments

Comments
 (0)