|
| 1 | +// Copyright 2025 Dolthub, Inc. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +package dtablefunctions |
| 16 | + |
| 17 | +import ( |
| 18 | + "fmt" |
| 19 | + "strings" |
| 20 | + |
| 21 | + "github.com/dolthub/go-mysql-server/sql" |
| 22 | + "github.com/dolthub/go-mysql-server/sql/types" |
| 23 | + |
| 24 | + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" |
| 25 | + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" |
| 26 | +) |
| 27 | + |
| 28 | +var _ sql.TableFunction = (*BranchStatusTableFunction)(nil) |
| 29 | + |
| 30 | +type BranchStatusTableFunction struct { |
| 31 | + db sql.Database |
| 32 | + exprs []sql.Expression |
| 33 | +} |
| 34 | + |
| 35 | +// NewInstance creates a new instance of TableFunction interface |
| 36 | +func (b *BranchStatusTableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { |
| 37 | + if len(args) == 0 { |
| 38 | + return nil, sql.ErrInvalidArgumentNumber.New(b.Name(), "at least 1", len(args)) |
| 39 | + } |
| 40 | + return &BranchStatusTableFunction{ |
| 41 | + db: db, |
| 42 | + exprs: args, |
| 43 | + }, nil |
| 44 | +} |
| 45 | + |
| 46 | +// Name implements the sql.Node interface |
| 47 | +func (b *BranchStatusTableFunction) Name() string { |
| 48 | + return "DOLT_BRANCH_STATUS" |
| 49 | +} |
| 50 | + |
| 51 | +// String implements the Stringer interface |
| 52 | +func (b *BranchStatusTableFunction) String() string { |
| 53 | + exprStrs := make([]string, len(b.exprs)) |
| 54 | + for i, expr := range b.exprs { |
| 55 | + exprStrs[i] = expr.String() |
| 56 | + } |
| 57 | + return fmt.Sprintf("%s(%s)", b.Name(), strings.Join(exprStrs, ", ")) |
| 58 | +} |
| 59 | + |
| 60 | +// Resolved implements the sql.Resolvable interface |
| 61 | +func (b *BranchStatusTableFunction) Resolved() bool { |
| 62 | + for _, expr := range b.exprs { |
| 63 | + if !expr.Resolved() { |
| 64 | + return false |
| 65 | + } |
| 66 | + } |
| 67 | + return true |
| 68 | +} |
| 69 | + |
| 70 | +// Expressions implements the sql.Expressioner interface |
| 71 | +func (b *BranchStatusTableFunction) Expressions() []sql.Expression { |
| 72 | + return b.exprs |
| 73 | +} |
| 74 | + |
| 75 | +// WithExpressions implements the sql.Expressioner interface |
| 76 | +func (b *BranchStatusTableFunction) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { |
| 77 | + nd := *b |
| 78 | + nd.exprs = exprs |
| 79 | + return &nd, nil |
| 80 | +} |
| 81 | + |
| 82 | +// Database implements the sql.Databaser interface |
| 83 | +func (b *BranchStatusTableFunction) Database() sql.Database { |
| 84 | + return b.db |
| 85 | +} |
| 86 | + |
| 87 | +// WithDatabase implements the sql.Databaser interface |
| 88 | +func (b *BranchStatusTableFunction) WithDatabase(db sql.Database) (sql.Node, error) { |
| 89 | + nd := *b |
| 90 | + nd.db = db |
| 91 | + return &nd, nil |
| 92 | +} |
| 93 | + |
| 94 | +// IsReadOnly implements the sql.Node interface |
| 95 | +func (b *BranchStatusTableFunction) IsReadOnly() bool { |
| 96 | + return true |
| 97 | +} |
| 98 | + |
| 99 | +// Schema implements the sql.Node interface |
| 100 | +func (b *BranchStatusTableFunction) Schema() sql.Schema { |
| 101 | + return sql.Schema{ |
| 102 | + &sql.Column{Name: "branch", Type: types.Text, Nullable: false}, |
| 103 | + &sql.Column{Name: "commits_ahead", Type: types.Uint64, Nullable: false}, |
| 104 | + &sql.Column{Name: "commits_behind", Type: types.Uint64, Nullable: false}, |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +// Children implements the sql.Node interface |
| 109 | +func (b *BranchStatusTableFunction) Children() []sql.Node { |
| 110 | + return nil |
| 111 | +} |
| 112 | + |
| 113 | +// WithChildren implements the sql.Node interface |
| 114 | +func (b *BranchStatusTableFunction) WithChildren(children ...sql.Node) (sql.Node, error) { |
| 115 | + return b, nil |
| 116 | +} |
| 117 | + |
| 118 | +// RowIter implements the sql.Node interface |
| 119 | +func (b *BranchStatusTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { |
| 120 | + sqlDb, ok := b.db.(dsess.SqlDatabase) |
| 121 | + if !ok { |
| 122 | + return nil, fmt.Errorf("unable to get dolt database") |
| 123 | + } |
| 124 | + ddb := sqlDb.DbData().Ddb |
| 125 | + |
| 126 | + sess := dsess.DSessFromSess(ctx.Session) |
| 127 | + dbName := sess.Session.GetCurrentDatabase() |
| 128 | + headRef, err := sess.CWBHeadRef(ctx, dbName) |
| 129 | + if err != nil { |
| 130 | + return nil, err |
| 131 | + } |
| 132 | + |
| 133 | + specs, err := mustExpressionsToString(ctx, b.exprs) |
| 134 | + if err != nil { |
| 135 | + return nil, err |
| 136 | + } |
| 137 | + if len(specs) == 0 { |
| 138 | + return nil, sql.ErrInvalidArgumentNumber.New(b.Name(), "at least 1", 0) |
| 139 | + } |
| 140 | + if len(specs) == 1 { |
| 141 | + return sql.RowsToRowIter(), nil |
| 142 | + } |
| 143 | + |
| 144 | + commits := make([]*doltdb.Commit, len(specs)) |
| 145 | + for i, spec := range specs { |
| 146 | + cs, cErr := doltdb.NewCommitSpec(spec) |
| 147 | + if cErr != nil { |
| 148 | + return nil, cErr |
| 149 | + } |
| 150 | + optCmt, oErr := ddb.Resolve(ctx, cs, headRef) |
| 151 | + if oErr != nil { |
| 152 | + return nil, oErr |
| 153 | + } |
| 154 | + commit, optCommitOk := optCmt.ToCommit() |
| 155 | + if !optCommitOk { |
| 156 | + return nil, doltdb.ErrGhostCommitEncountered |
| 157 | + } |
| 158 | + commits[i] = commit |
| 159 | + } |
| 160 | + |
| 161 | + baseCommit := commits[0] |
| 162 | + branchCommits := commits[1:] |
| 163 | + |
| 164 | + baseHash, err := baseCommit.HashOf() |
| 165 | + if err != nil { |
| 166 | + return nil, err |
| 167 | + } |
| 168 | + baseCommitClosure, err := baseCommit.GetCommitClosure(ctx) |
| 169 | + if err != nil { |
| 170 | + return nil, err |
| 171 | + } |
| 172 | + baseAncestors, err := baseCommitClosure.AsHashSet(ctx) |
| 173 | + if err != nil { |
| 174 | + return nil, err |
| 175 | + } |
| 176 | + baseAncestors.Insert(baseHash) |
| 177 | + |
| 178 | + var rows []sql.Row |
| 179 | + for i, branchCommit := range branchCommits { |
| 180 | + branchHash, hErr := branchCommit.HashOf() |
| 181 | + if hErr != nil { |
| 182 | + return nil, hErr |
| 183 | + } |
| 184 | + |
| 185 | + // same commit will have no differences |
| 186 | + var ahead, behind uint64 |
| 187 | + if branchHash.Equal(baseHash) { |
| 188 | + rows = append(rows, sql.Row{specs[i+1], ahead, behind}) |
| 189 | + continue |
| 190 | + } |
| 191 | + |
| 192 | + branchCommitClosure, bErr := branchCommit.GetCommitClosure(ctx) |
| 193 | + if bErr != nil { |
| 194 | + return nil, bErr |
| 195 | + } |
| 196 | + branchAncestors, bErr := branchCommitClosure.AsHashSet(ctx) |
| 197 | + if bErr != nil { |
| 198 | + return nil, bErr |
| 199 | + } |
| 200 | + branchAncestors.Insert(branchHash) |
| 201 | + for branchAncestor := range branchAncestors { |
| 202 | + if !baseAncestors.Has(branchAncestor) { |
| 203 | + ahead++ |
| 204 | + } |
| 205 | + } |
| 206 | + for baseAncestor := range baseAncestors { |
| 207 | + if !branchAncestors.Has(baseAncestor) { |
| 208 | + behind++ |
| 209 | + } |
| 210 | + } |
| 211 | + rows = append(rows, sql.Row{specs[i+1], ahead, behind}) |
| 212 | + } |
| 213 | + |
| 214 | + return sql.RowsToRowIter(rows...), nil |
| 215 | +} |
0 commit comments