Skip to content

Commit 062442f

Browse files
authored
feat: support check constraints (#335)
1 parent 0283fe8 commit 062442f

File tree

5 files changed

+111
-40
lines changed

5 files changed

+111
-40
lines changed

.github/workflows/mysql-copy-tests.yml

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,14 @@ jobs:
3535
with:
3636
python-version: '3.13'
3737

38-
- name: Install system packages
39-
uses: awalsh128/cache-apt-pkgs-action@latest
40-
with:
41-
packages: libnsl2 # required by MySQL Shell
42-
version: 1.1
43-
4438
- name: Install dependencies
4539
run: |
4640
go get .
4741
4842
pip3 install "sqlglot[rs]"
4943
5044
curl -LJO https://dev.mysql.com/get/Downloads/MySQL-Shell/mysql-shell_9.1.0-1debian12_amd64.deb
51-
sudo dpkg -i ./mysql-shell_9.1.0-1debian12_amd64.deb
45+
sudo apt-get install -y ./mysql-shell_9.1.0-1debian12_amd64.deb
5246
5347
- name: Setup test data in source MySQL
5448
run: |
@@ -67,10 +61,11 @@ jobs:
6761
-- A table with non-default starting auto_increment value
6862
CREATE TABLE items (
6963
id INT AUTO_INCREMENT PRIMARY KEY,
64+
v BIGINT check (v > 0),
7065
name VARCHAR(100)
7166
) AUTO_INCREMENT=1000;
7267
73-
INSERT INTO items (name) VALUES ('item1'), ('item2'), ('item3');
68+
INSERT INTO items (v, name) VALUES (1, 'item1'), (2, 'item2'), (3, 'item3');
7469
"
7570
7671
- name: Build and start MyDuck Server

backend/executor.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,12 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
8080
case *plan.InsertInto:
8181
insert := n.(*plan.InsertInto)
8282

83-
// For AUTO_INCREMENT column, we fallback to the framework if the column is specified.
84-
if dst, err := plan.GetInsertable(insert.Destination); err == nil && dst.Schema().HasAutoIncrement() {
85-
if len(insert.ColumnNames) == 0 || len(insert.ColumnNames) == len(dst.Schema()) {
86-
return b.base.Build(ctx, root, r)
87-
}
88-
}
89-
83+
// The handling of auto_increment reset and check constraints is not supported by DuckDB.
84+
// We need to fallback to the framework for these cases.
85+
// But we want to rewrite LOAD DATA to be handled by DuckDB,
86+
// as it is a common way to import data into the database.
87+
// Therefore, we ignoring auto_increment and check constraints for LOAD DATA.
88+
// So rewriting LOAD DATA is done eagerly here.
9089
src := insert.Source
9190
if proj, ok := src.(*plan.Project); ok {
9291
src = proj.Child
@@ -97,6 +96,20 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
9796
}
9897
return b.base.Build(ctx, root, r)
9998
}
99+
100+
if dst, err := plan.GetInsertable(insert.Destination); err == nil {
101+
// For AUTO_INCREMENT column, we fallback to the framework if the column is specified.
102+
// if dst.Schema().HasAutoIncrement() && (0 == len(insert.ColumnNames) || len(insert.ColumnNames) == len(dst.Schema())) {
103+
if dst.Schema().HasAutoIncrement() {
104+
return b.base.Build(ctx, root, r)
105+
}
106+
// For table with check constraints, we fallback to the framework.
107+
if ct, ok := dst.(sql.CheckTable); ok {
108+
if checks, err := ct.GetChecks(ctx); err == nil && len(checks) > 0 {
109+
return b.base.Build(ctx, root, r)
110+
}
111+
}
112+
}
100113
}
101114

102115
// Fallback to the base builder if the plan contains system/user variables or is not a pure data query.

catalog/database.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim
229229
b.WriteString(")")
230230

231231
// Add comment to the table
232-
info := ExtraTableInfo{schema.PkOrdinals, withoutIndex, fullSequenceName}
232+
info := ExtraTableInfo{schema.PkOrdinals, withoutIndex, fullSequenceName, nil}
233233
b.WriteString(fmt.Sprintf(
234234
"; COMMENT ON TABLE %s IS '%s'",
235235
fullTableName,

catalog/table.go

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type ExtraTableInfo struct {
2727
PkOrdinals []int
2828
Replicated bool
2929
Sequence string
30+
Checks []sql.CheckDefinition
3031
}
3132

3233
type ColumnInfo struct {
@@ -37,6 +38,7 @@ type ColumnInfo struct {
3738
ColumnDefault stdsql.NullString
3839
Comment stdsql.NullString
3940
}
41+
4042
type IndexedTable struct {
4143
*Table
4244
Lookup sql.IndexLookup
@@ -54,6 +56,8 @@ var _ sql.TruncateableTable = (*Table)(nil)
5456
var _ sql.ReplaceableTable = (*Table)(nil)
5557
var _ sql.CommentedTable = (*Table)(nil)
5658
var _ sql.AutoIncrementTable = (*Table)(nil)
59+
var _ sql.CheckTable = (*Table)(nil)
60+
var _ sql.CheckAlterableTable = (*Table)(nil)
5761

5862
func NewTable(name string, db *Database) *Table {
5963
return &Table{
@@ -707,6 +711,9 @@ func (t *Table) PreciseMatch() bool {
707711

708712
// Comment implements sql.CommentedTable.
709713
func (t *Table) Comment() string {
714+
t.mu.RLock()
715+
defer t.mu.RUnlock()
716+
710717
return t.comment.Text
711718
}
712719

@@ -761,10 +768,16 @@ func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup
761768

762769
// PeekNextAutoIncrementValue implements sql.AutoIncrementTable.
763770
func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
771+
t.mu.RLock()
772+
defer t.mu.RUnlock()
773+
764774
if t.comment.Meta.Sequence == "" {
765775
return 0, sql.ErrNoAutoIncrementCol
766776
}
777+
return t.getNextAutoIncrementValue(ctx)
778+
}
767779

780+
func (t *Table) getNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
768781
// For PeekNextAutoIncrementValue, we want to see what the next value would be
769782
// without actually incrementing. We can do this by getting currval + 1.
770783
var val uint64
@@ -788,12 +801,20 @@ func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
788801
}
789802

790803
// GetNextAutoIncrementValue implements sql.AutoIncrementTable.
791-
func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
804+
func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal any) (uint64, error) {
805+
t.mu.Lock()
806+
defer t.mu.Unlock()
807+
792808
if t.comment.Meta.Sequence == "" {
793809
return 0, sql.ErrNoAutoIncrementCol
794810
}
795811

796-
// If insertVal is provided and greater than current sequence value, update sequence
812+
nextVal, err := t.getNextAutoIncrementValue(ctx)
813+
if err != nil {
814+
return 0, err
815+
}
816+
817+
// If insertVal is provided and greater than the next sequence value, update sequence
797818
if insertVal != nil {
798819
var start uint64
799820
switch v := insertVal.(type) {
@@ -804,7 +825,7 @@ func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{
804825
start = uint64(v)
805826
}
806827
}
807-
if start > 0 {
828+
if start > 0 && start > nextVal {
808829
err := t.setAutoIncrementValue(ctx, start)
809830
if err != nil {
810831
return 0, err
@@ -815,7 +836,7 @@ func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{
815836

816837
// Get next value from sequence
817838
var val uint64
818-
err := adapter.QueryRowCatalog(ctx, `SELECT nextval('`+t.comment.Meta.Sequence+`')`).Scan(&val)
839+
err = adapter.QueryRowCatalog(ctx, `SELECT nextval('`+t.comment.Meta.Sequence+`')`).Scan(&val)
819840
if err != nil {
820841
return 0, ErrDuckDB.New(err)
821842
}
@@ -885,14 +906,12 @@ func (t *Table) setAutoIncrementValue(ctx *sql.Context, value uint64) error {
885906
// }
886907

887908
// Update the table comment with the new sequence name
888-
tableInfo := t.comment.Meta
889-
tableInfo.Sequence = fullSequenceName
890-
comment := NewCommentWithMeta(t.comment.Text, tableInfo)
891-
if _, err = adapter.Exec(ctx, `COMMENT ON TABLE `+FullTableName(t.db.catalog, t.db.name, t.name)+` IS '`+comment.Encode()+`'`); err != nil {
892-
return ErrDuckDB.New(err)
909+
if err = t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) {
910+
info.Sequence = fullSequenceName
911+
}); err != nil {
912+
return err
893913
}
894914

895-
t.comment.Meta.Sequence = fullSequenceName
896915
return t.withSchema(ctx)
897916
}
898917

@@ -910,6 +929,62 @@ func (s *autoIncrementSetter) Close(ctx *sql.Context) error {
910929
}
911930

912931
func (s *autoIncrementSetter) AcquireAutoIncrementLock(ctx *sql.Context) (func(), error) {
913-
// DuckDB handles sequence synchronization internally
914-
return func() {}, nil
932+
s.t.mu.Lock()
933+
return s.t.mu.Unlock, nil
934+
}
935+
936+
func (t *Table) updateExtraTableInfo(ctx *sql.Context, updater func(*ExtraTableInfo)) error {
937+
tableInfo := t.comment.Meta
938+
updater(&tableInfo)
939+
comment := NewCommentWithMeta(t.comment.Text, tableInfo)
940+
_, err := adapter.Exec(ctx, `COMMENT ON TABLE `+FullTableName(t.db.catalog, t.db.name, t.name)+` IS '`+comment.Encode()+`'`)
941+
if err != nil {
942+
return ErrDuckDB.New(err)
943+
}
944+
t.comment.Meta = tableInfo // Update the in-memory metadata
945+
return nil
946+
}
947+
948+
// CheckConstraints implements sql.CheckTable.
949+
func (t *Table) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) {
950+
t.mu.RLock()
951+
defer t.mu.RUnlock()
952+
953+
return t.comment.Meta.Checks, nil
954+
}
955+
956+
// AddCheck implements sql.CheckAlterableTable.
957+
func (t *Table) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
958+
t.mu.Lock()
959+
defer t.mu.Unlock()
960+
961+
// TODO(fan): Implement this once DuckDB supports modifying check constraints.
962+
// https://duckdb.org/docs/sql/statements/alter_table.html#add--drop-constraint
963+
// https://github.com/duckdb/duckdb/issues/57
964+
// Just record the check constraint for now.
965+
return t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) {
966+
info.Checks = append(info.Checks, *check)
967+
})
968+
}
969+
970+
// DropCheck implements sql.CheckAlterableTable.
971+
func (t *Table) DropCheck(ctx *sql.Context, checkName string) error {
972+
t.mu.Lock()
973+
defer t.mu.Unlock()
974+
975+
checks := make([]sql.CheckDefinition, 0, max(len(t.comment.Meta.Checks)-1, 0))
976+
found := false
977+
for i, check := range t.comment.Meta.Checks {
978+
if check.Name == checkName {
979+
found = true
980+
continue
981+
}
982+
checks = append(checks, t.comment.Meta.Checks[i])
983+
}
984+
if !found {
985+
return sql.ErrUnknownConstraint.New(checkName)
986+
}
987+
return t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) {
988+
info.Checks = checks
989+
})
915990
}

main_test.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,15 +1130,10 @@ func TestCreateTable(t *testing.T) {
11301130
"CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable",
11311131
"display_width_for_numeric_types",
11321132
"SHOW_FULL_FIELDS_FROM_numericDisplayWidthTest;",
1133-
"Validate_that_CREATE_LIKE_preserves_checks",
11341133
"datetime_precision",
11351134
"CREATE_TABLE_tt_(pk_int_primary_key,_d_datetime(6)_default_current_timestamp(6))",
11361135
"Identifier_lengths",
1137-
"create_table_b_(a_int_primary_key,_constraint_abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijkl_check_(a_>_0))",
1138-
"create_table_d_(a_int_primary_key,_constraint_abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijkl_foreign_key_(a)_references_parent(a))",
11391136
"table_charset_options",
1140-
"show_create_table_t1",
1141-
"show_create_table_t2",
11421137
"show_create_table_t3",
11431138
"show_create_table_t4",
11441139
"create_table_with_select_preserves_default",
@@ -1158,17 +1153,10 @@ func TestCreateTable(t *testing.T) {
11581153
"CREATE_EVENT_foo_ON_SCHEDULE_EVERY_1_YEAR_DO_CREATE_TABLE_bar_AS_SELECT_1;",
11591154
"trigger_contains_CREATE_TABLE_AS",
11601155
"CREATE_TRIGGER_foo_AFTER_UPDATE_ON_t_FOR_EACH_ROW_BEGIN_CREATE_TABLE_bar_AS_SELECT_1;_END;",
1161-
"insert_into_t1_(b)_values_(1),_(2)",
1162-
"show_create_table_t1",
1163-
"select_*_from_t1_order_by_b",
1164-
"insert_into_t1_(b)_values_(1),_(2)",
1165-
"show_create_table_t1",
1166-
"select_*_from_t1_order_by_b",
11671156
}
11681157

11691158
// Patch auto-generated queries that are known to fail
11701159
waitForFixQueries = append(waitForFixQueries, []string{
1171-
"CREATE TABLE t1 (pk int primary key, test_score int, height int CHECK (height < 10) , CONSTRAINT mycheck CHECK (test_score >= 50))",
11721160
"create table a (i int primary key, j int default 100);", // skip the case "create table with select preserves default" since there is no support for CREATE TABLE SELECT
11731161
}...)
11741162

0 commit comments

Comments
 (0)