Skip to content

Commit 4f5fcb1

Browse files
authored
Merge pull request #6 from entireio/soph/branch-optimizations
Trunk-aware batched bootstrap planning
2 parents 396abb4 + dd85566 commit 4f5fcb1

9 files changed

Lines changed: 536 additions & 60 deletions

File tree

internal/gitproto/fetch.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,15 @@ func (s *RefService) FetchPack(
8888
}
8989

9090
// FetchCommitGraph fetches only the commit graph (tree:0 filter) for a ref.
91-
// Requires v2 with filter support.
91+
// Requires v2 with filter support. Optional haves let the source skip commits
92+
// already reachable from those hashes, which is valuable when planning later
93+
// branches that share history with an already-planned trunk.
9294
func (s *RefService) FetchCommitGraph(
9395
ctx context.Context,
9496
store storer.Storer,
9597
conn *Conn,
9698
ref DesiredRef,
99+
haves []plumbing.Hash,
97100
) error {
98101
if s.Protocol != "v2" {
99102
return errors.New("commit graph fetch requires protocol v2")
@@ -102,13 +105,19 @@ func (s *RefService) FetchCommitGraph(
102105
return errors.New("source does not advertise fetch filter support")
103106
}
104107

105-
cmdArgs := []string{
108+
sortedHaves := SortedUniqueHashes(haves)
109+
cmdArgs := make([]string, 0, 4+len(sortedHaves))
110+
cmdArgs = append(cmdArgs,
106111
"ofs-delta",
107112
"no-progress",
108113
"filter tree:0",
109-
"want " + ref.SourceHash.String(),
110-
"done",
114+
"want "+ref.SourceHash.String(),
115+
)
116+
for _, h := range sortedHaves {
117+
cmdArgs = append(cmdArgs, "have "+h.String())
111118
}
119+
cmdArgs = append(cmdArgs, "done")
120+
112121
body, err := EncodeCommand("fetch", s.V2Caps.RequestCapabilities(), cmdArgs)
113122
if err != nil {
114123
return err

internal/gitproto/fetch_test.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020
"github.com/stretchr/testify/require"
2121
)
2222

23+
const refsHeadsMain = "refs/heads/main"
24+
2325
func TestCapabilities(t *testing.T) {
2426
// v2 protocol
2527
v2Caps := &V2Capabilities{
@@ -221,28 +223,52 @@ func TestDecodeV2LSRefs(t *testing.T) {
221223
FormatPktLine("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb refs/heads/dev\n") +
222224
"0000" // flush
223225

224-
refs, err := decodeV2LSRefs(bytes.NewReader([]byte(wire)))
226+
refs, head, err := decodeV2LSRefs(bytes.NewReader([]byte(wire)))
225227
if err != nil {
226228
t.Fatalf("decodeV2LSRefs: %v", err)
227229
}
228230
if len(refs) != 2 {
229231
t.Fatalf("expected 2 refs, got %d", len(refs))
230232
}
231-
if refs[0].Name().String() != "refs/heads/main" {
232-
t.Errorf("refs[0].Name() = %q, want %q", refs[0].Name(), "refs/heads/main")
233+
if refs[0].Name().String() != refsHeadsMain {
234+
t.Errorf("refs[0].Name() = %q, want %q", refs[0].Name(), refsHeadsMain)
233235
}
234236
if refs[0].Hash().String() != "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" {
235237
t.Errorf("refs[0].Hash() = %q", refs[0].Hash())
236238
}
237239
if refs[1].Name().String() != "refs/heads/dev" {
238240
t.Errorf("refs[1].Name() = %q, want %q", refs[1].Name(), "refs/heads/dev")
239241
}
242+
if head != "" {
243+
t.Errorf("head target = %q, want empty (no HEAD advertised)", head)
244+
}
245+
}
246+
247+
func TestDecodeV2LSRefsHeadSymref(t *testing.T) {
248+
wire := "" +
249+
FormatPktLine("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa HEAD symref-target:refs/heads/main\n") +
250+
FormatPktLine("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa refs/heads/main\n") +
251+
"0000"
252+
refs, head, err := decodeV2LSRefs(bytes.NewReader([]byte(wire)))
253+
if err != nil {
254+
t.Fatalf("decodeV2LSRefs: %v", err)
255+
}
256+
// HEAD is consumed for its symref target and not emitted as a ref.
257+
if len(refs) != 1 {
258+
t.Fatalf("expected 1 ref (HEAD filtered), got %d", len(refs))
259+
}
260+
if refs[0].Name().String() != refsHeadsMain {
261+
t.Errorf("refs[0].Name() = %q, want refs/heads/main", refs[0].Name())
262+
}
263+
if head.String() != refsHeadsMain {
264+
t.Errorf("head target = %q, want refs/heads/main", head)
265+
}
240266
}
241267

242268
func TestDecodeV2LSRefsMalformed(t *testing.T) {
243269
// Line with only one field (no refname).
244270
wire := FormatPktLine("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\n") + "0000"
245-
_, err := decodeV2LSRefs(bytes.NewReader([]byte(wire)))
271+
_, _, err := decodeV2LSRefs(bytes.NewReader([]byte(wire)))
246272
if err == nil {
247273
t.Fatal("expected error for malformed ls-refs line, got nil")
248274
}
@@ -251,7 +277,7 @@ func TestDecodeV2LSRefsMalformed(t *testing.T) {
251277
func TestDecodeV2LSRefsEmpty(t *testing.T) {
252278
// Empty response (just flush).
253279
wire := "0000"
254-
refs, err := decodeV2LSRefs(bytes.NewReader([]byte(wire)))
280+
refs, _, err := decodeV2LSRefs(bytes.NewReader([]byte(wire)))
255281
if err != nil {
256282
t.Fatalf("decodeV2LSRefs: %v", err)
257283
}
@@ -287,7 +313,7 @@ func TestFetchPackUnsupportedProtocol(t *testing.T) {
287313

288314
func TestFetchCommitGraphRequiresV2(t *testing.T) {
289315
rs := &RefService{Protocol: "v1"}
290-
err := rs.FetchCommitGraph(t.Context(), nil, nil, DesiredRef{})
316+
err := rs.FetchCommitGraph(t.Context(), nil, nil, DesiredRef{}, nil)
291317
if err == nil {
292318
t.Fatal("expected error for non-v2 protocol")
293319
}
@@ -300,7 +326,7 @@ func TestFetchCommitGraphRequiresFilter(t *testing.T) {
300326
},
301327
}
302328
rs := &RefService{Protocol: "v2", V2Caps: caps}
303-
err := rs.FetchCommitGraph(t.Context(), nil, nil, DesiredRef{})
329+
err := rs.FetchCommitGraph(t.Context(), nil, nil, DesiredRef{}, nil)
304330
if err == nil {
305331
t.Fatal("expected error when filter not supported")
306332
}

internal/gitproto/refs.go

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/go-git/go-git/v6/plumbing"
1313
"github.com/go-git/go-git/v6/plumbing/format/pktline"
1414
"github.com/go-git/go-git/v6/plumbing/protocol/packp"
15+
"github.com/go-git/go-git/v6/plumbing/protocol/packp/capability"
1516
"github.com/go-git/go-git/v6/plumbing/transport"
1617
)
1718

@@ -21,6 +22,10 @@ type RefService struct {
2122
Protocol string // "v1" or "v2"
2223
V1Adv *packp.AdvRefs
2324
V2Caps *V2Capabilities
25+
// HeadTarget is the branch that HEAD points to on the source, when
26+
// advertised as a symref. Empty for detached HEAD or when the source
27+
// does not advertise symref information.
28+
HeadTarget plumbing.ReferenceName
2429
// Verbose, when true, streams source-side sideband progress ("Counting
2530
// objects", "Compressing objects", ...) to stderr and asks the source
2631
// upload-pack to emit progress by not sending the no-progress option.
@@ -36,7 +41,7 @@ func ListSourceRefs(ctx context.Context, conn *Conn, protocolMode string, refPre
3641
if err != nil {
3742
return nil, nil, err
3843
}
39-
return refs, &RefService{Protocol: "v1", V1Adv: adv}, nil
44+
return refs, &RefService{Protocol: "v1", V1Adv: adv, HeadTarget: headTargetFromAdv(adv)}, nil
4045

4146
case "auto", "v2":
4247
data, err := RequestInfoRefs(ctx, conn, transport.UploadPackService, "version=2")
@@ -47,11 +52,11 @@ func ListSourceRefs(ctx context.Context, conn *Conn, protocolMode string, refPre
4752
if !caps.Supports("ls-refs") || !caps.Supports("fetch") {
4853
return nil, nil, errors.New("source does not advertise required protocol v2 commands")
4954
}
50-
refs, err := listSourceRefsV2(ctx, conn, caps, refPrefixes)
55+
refs, headTarget, err := listSourceRefsV2(ctx, conn, caps, refPrefixes)
5156
if err != nil {
5257
return nil, nil, err
5358
}
54-
return refs, &RefService{Protocol: "v2", V2Caps: caps}, nil
59+
return refs, &RefService{Protocol: "v2", V2Caps: caps, HeadTarget: headTarget}, nil
5560
}
5661
if protocolMode == "v2" {
5762
return nil, nil, errors.New("source did not negotiate protocol v2")
@@ -65,7 +70,7 @@ func ListSourceRefs(ctx context.Context, conn *Conn, protocolMode string, refPre
6570
if err != nil {
6671
return nil, nil, err
6772
}
68-
return refs, &RefService{Protocol: "v1", V1Adv: adv}, nil
73+
return refs, &RefService{Protocol: "v1", V1Adv: adv, HeadTarget: headTargetFromAdv(adv)}, nil
6974

7075
default:
7176
return nil, nil, fmt.Errorf("unsupported protocol mode %q", protocolMode)
@@ -136,46 +141,81 @@ func listSourceRefsV1(ctx context.Context, conn *Conn) (*packp.AdvRefs, []*plumb
136141
return adv, refs, nil
137142
}
138143

139-
func listSourceRefsV2(ctx context.Context, conn *Conn, caps *V2Capabilities, prefixes []string) ([]*plumbing.Reference, error) {
140-
args := []string{"peel"}
144+
func listSourceRefsV2(ctx context.Context, conn *Conn, caps *V2Capabilities, prefixes []string) ([]*plumbing.Reference, plumbing.ReferenceName, error) {
145+
// Always include "HEAD" so the server returns the symref-target attribute
146+
// for HEAD. Without this, callers that pass only "refs/heads/" or
147+
// "refs/tags/" prefixes filter HEAD out of the response and lose the
148+
// default-branch hint that bootstrap planning uses as a trunk cutoff.
149+
args := []string{"peel", "symrefs", "ref-prefix HEAD"}
141150
for _, prefix := range prefixes {
142151
args = append(args, "ref-prefix "+prefix)
143152
}
144153
body, err := EncodeCommand("ls-refs", caps.RequestCapabilities(), args)
145154
if err != nil {
146-
return nil, err
155+
return nil, "", err
147156
}
148157
data, err := PostRPC(ctx, conn, transport.UploadPackService, body, true, "upload-pack ls-refs")
149158
if err != nil {
150-
return nil, err
159+
return nil, "", err
151160
}
152161
return decodeV2LSRefs(bytes.NewReader(data))
153162
}
154163

155-
func decodeV2LSRefs(r *bytes.Reader) ([]*plumbing.Reference, error) {
164+
func decodeV2LSRefs(r *bytes.Reader) ([]*plumbing.Reference, plumbing.ReferenceName, error) {
156165
reader := NewPacketReader(r)
157166
var refs []*plumbing.Reference
167+
var headTarget plumbing.ReferenceName
158168
for {
159169
kind, payload, err := reader.ReadPacket()
160170
if err != nil {
161-
return nil, err
171+
return nil, "", err
162172
}
163173
if kind == PacketFlush {
164-
return refs, nil
174+
return refs, headTarget, nil
165175
}
166176
if kind != PacketData {
167-
return nil, fmt.Errorf("unexpected packet type %v in ls-refs response", kind)
177+
return nil, "", fmt.Errorf("unexpected packet type %v in ls-refs response", kind)
168178
}
169179
fields := strings.Fields(strings.TrimSpace(string(payload)))
170180
if len(fields) < 2 {
171-
return nil, fmt.Errorf("malformed ls-refs response line %q", payload)
181+
return nil, "", fmt.Errorf("malformed ls-refs response line %q", payload)
172182
}
173183
hash := plumbing.NewHash(fields[0])
174184
name := plumbing.ReferenceName(fields[1])
185+
if name == plumbing.HEAD {
186+
// HEAD is surfaced via headTarget only; not appended to the ref
187+
// slice because it is a symbolic ref, matching v1 behavior where
188+
// symrefs are filtered out by downstream RefHashMap.
189+
for _, attr := range fields[2:] {
190+
if target, ok := strings.CutPrefix(attr, "symref-target:"); ok {
191+
headTarget = plumbing.ReferenceName(target)
192+
break
193+
}
194+
}
195+
continue
196+
}
175197
refs = append(refs, plumbing.NewHashReference(name, hash))
176198
}
177199
}
178200

201+
// headTargetFromAdv extracts the branch HEAD points to from v1 advertised
202+
// capabilities. Returns empty when HEAD is detached or no symref is advertised.
203+
func headTargetFromAdv(adv *packp.AdvRefs) plumbing.ReferenceName {
204+
if adv == nil || adv.Capabilities == nil {
205+
return ""
206+
}
207+
for _, value := range adv.Capabilities.Get(capability.SymRef) {
208+
parts := strings.SplitN(value, ":", 2)
209+
if len(parts) != 2 {
210+
continue
211+
}
212+
if plumbing.ReferenceName(parts[0]) == plumbing.HEAD {
213+
return plumbing.ReferenceName(parts[1])
214+
}
215+
}
216+
return ""
217+
}
218+
179219
func decodeV1AdvRefs(data []byte) (*packp.AdvRefs, error) {
180220
rd := bufio.NewReader(bytes.NewReader(data))
181221
consumedSmartHeader, err := consumeSmartInfoRefsHeader(rd)

internal/gitproto/refs_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,30 @@ func TestRefHashMap(t *testing.T) {
4343
}
4444
}
4545

46+
func TestHeadTargetFromAdv(t *testing.T) {
47+
// nil returns empty.
48+
if got := headTargetFromAdv(nil); got != "" {
49+
t.Errorf("headTargetFromAdv(nil) = %q, want empty", got)
50+
}
51+
52+
adv := packp.NewAdvRefs()
53+
if err := adv.Capabilities.Add(capability.SymRef, "HEAD:refs/heads/main"); err != nil {
54+
t.Fatalf("Capabilities.Add: %v", err)
55+
}
56+
if got := headTargetFromAdv(adv); got.String() != "refs/heads/main" {
57+
t.Errorf("headTargetFromAdv = %q, want refs/heads/main", got)
58+
}
59+
60+
// Symref pointing at something other than HEAD is ignored.
61+
adv = packp.NewAdvRefs()
62+
if err := adv.Capabilities.Add(capability.SymRef, "refs/remotes/origin/HEAD:refs/heads/main"); err != nil {
63+
t.Fatalf("Capabilities.Add: %v", err)
64+
}
65+
if got := headTargetFromAdv(adv); got != "" {
66+
t.Errorf("headTargetFromAdv ignored non-HEAD symref = %q, want empty", got)
67+
}
68+
}
69+
4670
func TestAdvRefsCaps(t *testing.T) {
4771
// nil AdvRefs should return nil.
4872
if got := AdvRefsCaps(nil); got != nil {

internal/planner/checkpoint.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ type BootstrapBatch struct {
2020
// FirstParentChain walks the first-parent chain from tip back to root,
2121
// returning the chain in root-to-tip order.
2222
func FirstParentChain(store storer.EncodedObjectStorer, tip plumbing.Hash) ([]plumbing.Hash, error) {
23+
return FirstParentChainStoppingAt(store, tip, nil)
24+
}
25+
26+
// FirstParentChainStoppingAt walks the first-parent chain from tip back to
27+
// root, stopping early when a commit's hash is in stopAt. Returns the chain
28+
// in root-to-tip order, excluding any stopAt commits. When tip itself is in
29+
// stopAt, the returned chain is empty. A nil stopAt behaves like the plain
30+
// FirstParentChain walk.
31+
//
32+
// This supports trunk-aware planning: once trunk's ancestry is known, other
33+
// branches only need their divergence chain.
34+
func FirstParentChainStoppingAt(store storer.EncodedObjectStorer, tip plumbing.Hash, stopAt map[plumbing.Hash]struct{}) ([]plumbing.Hash, error) {
35+
if _, stop := stopAt[tip]; stop {
36+
return nil, nil
37+
}
2338
commit, err := object.GetCommit(store, tip)
2439
if err != nil {
2540
return nil, fmt.Errorf("load tip commit %s: %w", tip, err)
@@ -30,7 +45,11 @@ func FirstParentChain(store storer.EncodedObjectStorer, tip plumbing.Hash) ([]pl
3045
if len(commit.ParentHashes) == 0 {
3146
break
3247
}
33-
commit, err = object.GetCommit(store, commit.ParentHashes[0])
48+
parent := commit.ParentHashes[0]
49+
if _, stop := stopAt[parent]; stop {
50+
break
51+
}
52+
commit, err = object.GetCommit(store, parent)
3453
if err != nil {
3554
return nil, fmt.Errorf("load parent commit %s: %w", commit.ParentHashes[0], err)
3655
}

0 commit comments

Comments
 (0)