Skip to content

Commit 37fee6f

Browse files
authored
implement dolt_branch_status() (#9122)
1 parent cea4c36 commit 37fee6f

File tree

7 files changed

+559
-64
lines changed

7 files changed

+559
-64
lines changed

go/libraries/doltcore/sqle/dfunctions/has_ancestor.go

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,10 @@ func (a *HasAncestor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
6969
if err != nil {
7070
return nil, err
7171
}
72-
73-
cs, err := doltdb.NewCommitSpec(headStr.(string))
72+
headCommit, err = resolveRefSpec(ctx, headRef, ddb, headStr.(string))
7473
if err != nil {
7574
return nil, err
7675
}
77-
optCmt, err := ddb.Resolve(ctx, cs, headRef)
78-
if err != nil {
79-
return nil, fmt.Errorf("error during has_ancestor check: ref not found '%s'", headStr)
80-
}
81-
headCommit, ok = optCmt.ToCommit()
82-
if !ok {
83-
return nil, doltdb.ErrGhostCommitEncountered
84-
}
8576
}
8677

8778
var ancCommit *doltdb.Commit
@@ -94,19 +85,10 @@ func (a *HasAncestor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
9485
if err != nil {
9586
return nil, err
9687
}
97-
cs, err := doltdb.NewCommitSpec(ancStr.(string))
88+
ancCommit, err = resolveRefSpec(ctx, headRef, ddb, ancStr.(string))
9889
if err != nil {
9990
return nil, err
10091
}
101-
optCmt, err := ddb.Resolve(ctx, cs, headRef)
102-
if err != nil {
103-
return nil, fmt.Errorf("error during has_ancestor check: ref not found '%s'", ancStr)
104-
}
105-
ancCommit, ok = optCmt.ToCommit()
106-
if !ok {
107-
return nil, doltdb.ErrGhostCommitEncountered
108-
}
109-
11092
}
11193

11294
headHash, err := headCommit.HashOf()

go/libraries/doltcore/sqle/dfunctions/dolt_merge_base.go renamed to go/libraries/doltcore/sqle/dfunctions/merge_base.go

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
2626
"github.com/dolthub/dolt/go/libraries/doltcore/merge"
27+
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
2728
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
2829
)
2930

@@ -40,6 +41,21 @@ func NewMergeBase(left, right sql.Expression) sql.Expression {
4041

4142
// Eval implements the sql.Expression interface.
4243
func (d MergeBase) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
44+
sess := dsess.DSessFromSess(ctx.Session)
45+
dbName := ctx.GetCurrentDatabase()
46+
dbData, ok := sess.GetDbData(ctx, dbName)
47+
if !ok {
48+
return nil, sql.ErrDatabaseNotFound.New(dbName)
49+
}
50+
doltDB, ok := sess.GetDoltDB(ctx, dbName)
51+
if !ok {
52+
return nil, sql.ErrDatabaseNotFound.New(dbName)
53+
}
54+
headRef, err := dbData.Rsr.CWBHeadRef(ctx)
55+
if err != nil {
56+
return nil, err
57+
}
58+
4359
leftSpec, err := d.Left().Eval(ctx, row)
4460
if err != nil {
4561
return nil, err
@@ -48,7 +64,6 @@ func (d MergeBase) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
4864
if err != nil {
4965
return nil, err
5066
}
51-
5267
if leftSpec == nil || rightSpec == nil {
5368
return nil, nil
5469
}
@@ -57,13 +72,16 @@ func (d MergeBase) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
5772
if !ok {
5873
return nil, errors.New("left value is not a string")
5974
}
60-
6175
rightStr, ok := rightSpec.(string)
6276
if !ok {
6377
return nil, errors.New("right value is not a string")
6478
}
6579

66-
left, right, err := resolveRefSpecs(ctx, leftStr, rightStr)
80+
left, err := resolveRefSpec(ctx, headRef, doltDB, leftStr)
81+
if err != nil {
82+
return nil, err
83+
}
84+
right, err := resolveRefSpec(ctx, headRef, doltDB, rightStr)
6785
if err != nil {
6886
return nil, err
6987
}
@@ -76,52 +94,20 @@ func (d MergeBase) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
7694
return mergeBase.String(), nil
7795
}
7896

79-
func resolveRefSpecs(ctx *sql.Context, leftSpec, rightSpec string) (left, right *doltdb.Commit, err error) {
80-
lcs, err := doltdb.NewCommitSpec(leftSpec)
81-
if err != nil {
82-
return nil, nil, err
83-
}
84-
rcs, err := doltdb.NewCommitSpec(rightSpec)
97+
func resolveRefSpec(ctx *sql.Context, headRef ref.DoltRef, doltDB *doltdb.DoltDB, spec string) (*doltdb.Commit, error) {
98+
cs, err := doltdb.NewCommitSpec(spec)
8599
if err != nil {
86-
return nil, nil, err
87-
}
88-
89-
sess := dsess.DSessFromSess(ctx.Session)
90-
dbName := ctx.GetCurrentDatabase()
91-
92-
dbData, ok := sess.GetDbData(ctx, dbName)
93-
if !ok {
94-
return nil, nil, sql.ErrDatabaseNotFound.New(dbName)
95-
}
96-
doltDB, ok := sess.GetDoltDB(ctx, dbName)
97-
if !ok {
98-
return nil, nil, sql.ErrDatabaseNotFound.New(dbName)
99-
}
100-
101-
headRef, err := dbData.Rsr.CWBHeadRef(ctx)
102-
if err != nil {
103-
return nil, nil, err
104-
}
105-
106-
optCmt, err := doltDB.Resolve(ctx, lcs, headRef)
107-
if err != nil {
108-
return nil, nil, err
109-
}
110-
left, ok = optCmt.ToCommit()
111-
if !ok {
112-
return nil, nil, doltdb.ErrGhostCommitEncountered
100+
return nil, err
113101
}
114-
115-
optCmt, err = doltDB.Resolve(ctx, rcs, headRef)
102+
optCmt, err := doltDB.Resolve(ctx, cs, headRef)
116103
if err != nil {
117-
return nil, nil, err
104+
return nil, err
118105
}
119-
right, ok = optCmt.ToCommit()
106+
commit, ok := optCmt.ToCommit()
120107
if !ok {
121-
return nil, nil, doltdb.ErrGhostCommitEncountered
108+
return nil, doltdb.ErrGhostCommitEncountered
122109
}
123-
124-
return
110+
return commit, err
125111
}
126112

127113
// String implements the sql.Expression interface.
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dtablefunctions
16+
17+
import (
18+
"fmt"
19+
"strings"
20+
21+
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/types"
23+
24+
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
25+
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
26+
)
27+
28+
var _ sql.TableFunction = (*BranchStatusTableFunction)(nil)
29+
30+
type BranchStatusTableFunction struct {
31+
db sql.Database
32+
exprs []sql.Expression
33+
}
34+
35+
// NewInstance creates a new instance of TableFunction interface
36+
func (b *BranchStatusTableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
37+
if len(args) == 0 {
38+
return nil, sql.ErrInvalidArgumentNumber.New(b.Name(), "at least 1", len(args))
39+
}
40+
return &BranchStatusTableFunction{
41+
db: db,
42+
exprs: args,
43+
}, nil
44+
}
45+
46+
// Name implements the sql.Node interface
47+
func (b *BranchStatusTableFunction) Name() string {
48+
return "DOLT_BRANCH_STATUS"
49+
}
50+
51+
// String implements the Stringer interface
52+
func (b *BranchStatusTableFunction) String() string {
53+
exprStrs := make([]string, len(b.exprs))
54+
for i, expr := range b.exprs {
55+
exprStrs[i] = expr.String()
56+
}
57+
return fmt.Sprintf("%s(%s)", b.Name(), strings.Join(exprStrs, ", "))
58+
}
59+
60+
// Resolved implements the sql.Resolvable interface
61+
func (b *BranchStatusTableFunction) Resolved() bool {
62+
for _, expr := range b.exprs {
63+
if !expr.Resolved() {
64+
return false
65+
}
66+
}
67+
return true
68+
}
69+
70+
// Expressions implements the sql.Expressioner interface
71+
func (b *BranchStatusTableFunction) Expressions() []sql.Expression {
72+
return b.exprs
73+
}
74+
75+
// WithExpressions implements the sql.Expressioner interface
76+
func (b *BranchStatusTableFunction) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
77+
nd := *b
78+
nd.exprs = exprs
79+
return &nd, nil
80+
}
81+
82+
// Database implements the sql.Databaser interface
83+
func (b *BranchStatusTableFunction) Database() sql.Database {
84+
return b.db
85+
}
86+
87+
// WithDatabase implements the sql.Databaser interface
88+
func (b *BranchStatusTableFunction) WithDatabase(db sql.Database) (sql.Node, error) {
89+
nd := *b
90+
nd.db = db
91+
return &nd, nil
92+
}
93+
94+
// IsReadOnly implements the sql.Node interface
95+
func (b *BranchStatusTableFunction) IsReadOnly() bool {
96+
return true
97+
}
98+
99+
// Schema implements the sql.Node interface
100+
func (b *BranchStatusTableFunction) Schema() sql.Schema {
101+
return sql.Schema{
102+
&sql.Column{Name: "branch", Type: types.Text, Nullable: false},
103+
&sql.Column{Name: "commits_ahead", Type: types.Uint64, Nullable: false},
104+
&sql.Column{Name: "commits_behind", Type: types.Uint64, Nullable: false},
105+
}
106+
}
107+
108+
// Children implements the sql.Node interface
109+
func (b *BranchStatusTableFunction) Children() []sql.Node {
110+
return nil
111+
}
112+
113+
// WithChildren implements the sql.Node interface
114+
func (b *BranchStatusTableFunction) WithChildren(children ...sql.Node) (sql.Node, error) {
115+
return b, nil
116+
}
117+
118+
// RowIter implements the sql.Node interface
119+
func (b *BranchStatusTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
120+
sqlDb, ok := b.db.(dsess.SqlDatabase)
121+
if !ok {
122+
return nil, fmt.Errorf("unable to get dolt database")
123+
}
124+
ddb := sqlDb.DbData().Ddb
125+
126+
sess := dsess.DSessFromSess(ctx.Session)
127+
dbName := sess.Session.GetCurrentDatabase()
128+
headRef, err := sess.CWBHeadRef(ctx, dbName)
129+
if err != nil {
130+
return nil, err
131+
}
132+
133+
specs, err := mustExpressionsToString(ctx, b.exprs)
134+
if err != nil {
135+
return nil, err
136+
}
137+
if len(specs) == 0 {
138+
return nil, sql.ErrInvalidArgumentNumber.New(b.Name(), "at least 1", 0)
139+
}
140+
if len(specs) == 1 {
141+
return sql.RowsToRowIter(), nil
142+
}
143+
144+
commits := make([]*doltdb.Commit, len(specs))
145+
for i, spec := range specs {
146+
cs, cErr := doltdb.NewCommitSpec(spec)
147+
if cErr != nil {
148+
return nil, cErr
149+
}
150+
optCmt, oErr := ddb.Resolve(ctx, cs, headRef)
151+
if oErr != nil {
152+
return nil, oErr
153+
}
154+
commit, optCommitOk := optCmt.ToCommit()
155+
if !optCommitOk {
156+
return nil, doltdb.ErrGhostCommitEncountered
157+
}
158+
commits[i] = commit
159+
}
160+
161+
baseCommit := commits[0]
162+
branchCommits := commits[1:]
163+
164+
baseHash, err := baseCommit.HashOf()
165+
if err != nil {
166+
return nil, err
167+
}
168+
baseCommitClosure, err := baseCommit.GetCommitClosure(ctx)
169+
if err != nil {
170+
return nil, err
171+
}
172+
baseAncestors, err := baseCommitClosure.AsHashSet(ctx)
173+
if err != nil {
174+
return nil, err
175+
}
176+
baseAncestors.Insert(baseHash)
177+
178+
var rows []sql.Row
179+
for i, branchCommit := range branchCommits {
180+
branchHash, hErr := branchCommit.HashOf()
181+
if hErr != nil {
182+
return nil, hErr
183+
}
184+
185+
// same commit will have no differences
186+
var ahead, behind uint64
187+
if branchHash.Equal(baseHash) {
188+
rows = append(rows, sql.Row{specs[i+1], ahead, behind})
189+
continue
190+
}
191+
192+
branchCommitClosure, bErr := branchCommit.GetCommitClosure(ctx)
193+
if bErr != nil {
194+
return nil, bErr
195+
}
196+
branchAncestors, bErr := branchCommitClosure.AsHashSet(ctx)
197+
if bErr != nil {
198+
return nil, bErr
199+
}
200+
branchAncestors.Insert(branchHash)
201+
for branchAncestor := range branchAncestors {
202+
if !baseAncestors.Has(branchAncestor) {
203+
ahead++
204+
}
205+
}
206+
for baseAncestor := range baseAncestors {
207+
if !branchAncestors.Has(baseAncestor) {
208+
behind++
209+
}
210+
}
211+
rows = append(rows, sql.Row{specs[i+1], ahead, behind})
212+
}
213+
214+
return sql.RowsToRowIter(rows...), nil
215+
}

go/libraries/doltcore/sqle/dtablefunctions/init.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ var DoltTableFunctions = []sql.TableFunction{
2020
&DiffTableFunction{},
2121
&DiffStatTableFunction{},
2222
&DiffSummaryTableFunction{},
23+
&BranchStatusTableFunction{},
2324
&LogTableFunction{},
2425
&PatchTableFunction{},
2526
&SchemaDiffTableFunction{},

0 commit comments

Comments
 (0)