
Commit 4145be1

feat: support MySQL CTAS (#351)
* Enable more CTAS tests
* Add type mapping for hugeint and varint
* Add fallback for simple CTAS
1 parent 052e742 commit 4145be1

File tree: backend/executor.go, catalog/database.go, catalog/table.go, catalog/type_mapping.go, main_test.go

5 files changed (+104, -77 lines)

backend/executor.go

Lines changed: 38 additions & 15 deletions
@@ -61,10 +61,13 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
     }
 
     n := root
-    ctx.GetLogger().WithFields(logrus.Fields{
-        "Query":    ctx.Query(),
-        "NodeType": fmt.Sprintf("%T", n),
-    }).Traceln("Building node:", n)
+
+    if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+        log.WithFields(logrus.Fields{
+            "Query":    ctx.Query(),
+            "NodeType": fmt.Sprintf("%T", n),
+        }).Traceln("Building node:", n)
+    }
 
     // TODO; find a better way to fallback to the base builder
     switch n.(type) {
@@ -114,7 +117,13 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
     }
 
     // Fallback to the base builder if the plan contains system/user variables or is not a pure data query.
-    if containsVariable(n) || !IsPureDataQuery(n) {
+    tree := n
+    switch n := n.(type) {
+    case *plan.TableCopier:
+        tree = n.Source
+    }
+    if containsVariable(tree) || !IsPureDataQuery(tree) {
+        ctx.GetLogger().Traceln("Falling back to the base builder")
         return b.base.Build(ctx, root, r)
     }
 
@@ -133,11 +142,21 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
             return nil, err
         }
         return b.base.Build(ctx, root, r)
-    // SubqueryAlias is for select * from view
+    // ResolvedTable is for `SELECT * FROM table` and `TABLE table`
+    // SubqueryAlias is for `SELECT * FROM view`
     case *plan.ResolvedTable, *plan.SubqueryAlias, *plan.TableAlias:
         return b.executeQuery(ctx, node, conn)
     case *plan.Distinct, *plan.OrderedDistinct:
         return b.executeQuery(ctx, node, conn)
+    case *plan.TableCopier:
+        // We preserve the table schema in a best-effort manner.
+        // For simple `CREATE TABLE t AS SELECT * FROM t`,
+        // we fall back to the framework to create the table and copy the data.
+        // For more complex cases, we directly execute the CTAS statement in DuckDB.
+        if _, ok := node.Source.(*plan.ResolvedTable); ok {
+            return b.base.Build(ctx, root, r)
+        }
+        return b.executeDML(ctx, node, conn)
     case sql.Expressioner:
         return b.executeExpressioner(ctx, node, conn)
     case *plan.DeleteFrom:
@@ -174,7 +193,7 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
     case *plan.ShowTables:
         duckSQL = ctx.Query()
     case *plan.ResolvedTable:
-        // SQLGlot cannot translate MySQL's `TABLE t` into DuckDB's `FROM t` - it produces `"table" AS t` instead.
+        // SQLGlot cannot translate MySQL's `TABLE t` into DuckDB's `FROM t` - it produces `"table" AS t` instead.
         duckSQL = `FROM ` + catalog.ConnectIdentifiersANSI(n.Database().Name(), n.Name())
     default:
         duckSQL, err = transpiler.TranslateWithSQLGlot(ctx.Query())
@@ -183,10 +202,12 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
         return nil, catalog.ErrTranspiler.New(err)
     }
 
-    ctx.GetLogger().WithFields(logrus.Fields{
-        "Query":   ctx.Query(),
-        "DuckSQL": duckSQL,
-    }).Trace("Executing Query...")
+    if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+        log.WithFields(logrus.Fields{
+            "Query":   ctx.Query(),
+            "DuckSQL": duckSQL,
+        }).Trace("Executing Query...")
+    }
 
     // Execute the DuckDB query
     rows, err := conn.QueryContext(ctx.Context, duckSQL)
@@ -204,10 +225,12 @@ func (b *DuckBuilder) executeDML(ctx *sql.Context, n sql.Node, conn *stdsql.Conn
         return nil, catalog.ErrTranspiler.New(err)
     }
 
-    ctx.GetLogger().WithFields(logrus.Fields{
-        "Query":   ctx.Query(),
-        "DuckSQL": duckSQL,
-    }).Trace("Executing DML...")
+    if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+        log.WithFields(logrus.Fields{
+            "Query":   ctx.Query(),
+            "DuckSQL": duckSQL,
+        }).Trace("Executing DML...")
+    }
 
     // Execute the DuckDB query
     result, err := conn.ExecContext(ctx.Context, duckSQL)
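The *plan.TableCopier branch above is the core of the new CTAS handling: when the source of `CREATE TABLE ... AS` is a plain resolved table, the framework creates the table and copies the rows (via CopyTableData below), while any more complex SELECT is pushed down to DuckDB as a single CTAS statement. The following is a minimal sketch of that routing decision only, using hypothetical stand-in node types (planNode, resolvedTable, subqueryNode, tableCopier) rather than the real go-mysql-server plan package:

package main

import "fmt"

// planNode, resolvedTable and subqueryNode are hypothetical stand-ins for
// go-mysql-server plan nodes; they exist only to make this sketch self-contained.
type planNode interface{ describe() string }

type resolvedTable struct{ table string }

func (r resolvedTable) describe() string { return "ResolvedTable(" + r.table + ")" }

type subqueryNode struct{ query string }

func (s subqueryNode) describe() string { return "Subquery(" + s.query + ")" }

// tableCopier mirrors the shape of plan.TableCopier: a CTAS node whose Source is the SELECT part.
type tableCopier struct{ Source planNode }

// routeCTAS sketches the decision in DuckBuilder.Build: a bare
// `CREATE TABLE t AS SELECT * FROM src` (Source is a resolved table) is left to the
// framework, which creates the table and then copies the rows; anything more complex
// is executed as one CTAS statement inside DuckDB.
func routeCTAS(n tableCopier) string {
    if _, ok := n.Source.(resolvedTable); ok {
        return "framework: create table, then copy rows"
    }
    return "duckdb: execute the CTAS statement directly"
}

func main() {
    simple := tableCopier{Source: resolvedTable{table: "mytable"}}
    grouped := tableCopier{Source: subqueryNode{query: "select s, sum(i) from mytable group by s"}}
    fmt.Println(simple.Source.describe(), "->", routeCTAS(simple))
    fmt.Println(grouped.Source.describe(), "->", routeCTAS(grouped))
}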

catalog/database.go

Lines changed: 22 additions & 0 deletions
@@ -446,3 +446,25 @@ func (d *Database) GetCollation(ctx *sql.Context) sql.CollationID {
 func (d *Database) SetCollation(ctx *sql.Context, collation sql.CollationID) error {
     return nil
 }
+
+// CopyTableData implements sql.TableCopierDatabase interface.
+func (d *Database) CopyTableData(ctx *sql.Context, sourceTable string, destinationTable string) (uint64, error) {
+    d.mu.Lock()
+    defer d.mu.Unlock()
+
+    // Use INSERT INTO ... SELECT to copy data
+    sql := `INSERT INTO ` + FullTableName(d.catalog, d.name, destinationTable) + ` FROM ` + FullTableName(d.catalog, d.name, sourceTable)
+
+    res, err := adapter.Exec(ctx, sql)
+    if err != nil {
+        return 0, ErrDuckDB.New(err)
+    }
+
+    // Get count of affected rows
+    count, err := res.RowsAffected()
+    if err != nil {
+        return 0, ErrDuckDB.New(err)
+    }
+
+    return uint64(count), nil
+}
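CopyTableData backs the simple-CTAS path: after the framework has created the destination table, all rows are moved inside DuckDB with a single `INSERT INTO ... FROM ...` statement (DuckDB accepts a bare FROM clause as shorthand for `SELECT *`), and the affected-row count is reported back. A rough sketch of the same idea over a plain database/sql handle; quoteIdent and copyTableData here are illustrative stand-ins for the project's FullTableName and adapter.Exec helpers, not its actual API:

package example

import (
    "context"
    "database/sql"
    "fmt"
)

// quoteIdent is a hypothetical stand-in for FullTableName: it joins and quotes
// catalog.schema.table as ANSI identifiers, the way DuckDB expects.
func quoteIdent(catalog, schema, table string) string {
    return fmt.Sprintf(`"%s"."%s"."%s"`, catalog, schema, table)
}

// copyTableData copies every row from src into dst with one INSERT ... FROM
// statement and returns how many rows were inserted.
func copyTableData(ctx context.Context, db *sql.DB, catalog, schema, src, dst string) (uint64, error) {
    query := `INSERT INTO ` + quoteIdent(catalog, schema, dst) + ` FROM ` + quoteIdent(catalog, schema, src)
    res, err := db.ExecContext(ctx, query)
    if err != nil {
        return 0, err
    }
    n, err := res.RowsAffected()
    if err != nil {
        return 0, err
    }
    return uint64(n), nil
}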

catalog/table.go

Lines changed: 4 additions & 1 deletion
@@ -754,7 +754,10 @@ func queryColumns(ctx *sql.Context, catalogName, schemaName, tableName string) (
         }
 
         decodedComment := DecodeComment[MySQLType](comment.String)
-        dataType := mysqlDataType(AnnotatedDuckType{dataTypes, decodedComment.Meta}, uint8(numericPrecision.Int32), uint8(numericScale.Int32))
+        dataType, err := mysqlDataType(AnnotatedDuckType{dataTypes, decodedComment.Meta}, uint8(numericPrecision.Int32), uint8(numericScale.Int32))
+        if err != nil {
+            return nil, err
+        }
 
         columnInfo := &ColumnInfo{
             ColumnName: columnName,
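This is the caller-side half of the type-mapping change in the next file: mysqlDataType now returns an error instead of relying on panicking Must* constructors, so queryColumns propagates it and a single unmappable column fails the whole lookup cleanly. A small illustrative sketch of that propagate-instead-of-panic pattern; the column struct, mapColumn, and buildColumns are hypothetical placeholders, not the project's types:

package example

import "fmt"

// column is a trimmed-down, hypothetical stand-in for a catalog column record.
type column struct {
    Name     string
    DuckType string
}

// mapColumn is a placeholder for the error-returning mysqlDataType.
func mapColumn(duckType string) (string, error) {
    switch duckType {
    case "VARCHAR":
        return "TEXT", nil
    case "HUGEINT":
        return "DECIMAL(39,0)", nil
    default:
        return "", fmt.Errorf("encountered unknown DuckDB type(%v)", duckType)
    }
}

// buildColumns mirrors the queryColumns change: mapping errors are returned to
// the caller instead of panicking deep inside a Must* helper.
func buildColumns(cols []column) (map[string]string, error) {
    out := make(map[string]string, len(cols))
    for _, c := range cols {
        mysqlType, err := mapColumn(c.DuckType)
        if err != nil {
            return nil, err
        }
        out[c.Name] = mysqlType
    }
    return out, nil
}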

catalog/type_mapping.go

Lines changed: 39 additions & 30 deletions
@@ -198,7 +198,7 @@ func DuckdbDataType(mysqlType sql.Type) (AnnotatedDuckType, error) {
     }
 }
 
-func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericScale uint8) sql.Type {
+func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericScale uint8) (sql.Type, error) {
     // TODO: The current type mappings are not lossless. We need to store the original type in the column comments.
     duckName := strings.TrimSpace(strings.ToUpper(duckType.name))
 
@@ -219,7 +219,7 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc
         intBaseType = sqltypes.Uint8
     case "SMALLINT":
         if mysqlName == "YEAR" {
-            return types.Year
+            return types.Year, nil
         }
         intBaseType = sqltypes.Int16
     case "USMALLINT":
@@ -240,13 +240,13 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc
         intBaseType = sqltypes.Int64
     case "UBIGINT":
         if mysqlName == "BIT" {
-            return types.MustCreateBitType(duckType.mysql.Precision)
+            return types.CreateBitType(duckType.mysql.Precision)
         }
         intBaseType = sqltypes.Uint64
     }
 
     if intBaseType != sqltypes.Null {
-        return types.MustCreateNumberTypeWithDisplayWidth(intBaseType, int(duckType.mysql.Display))
+        return types.CreateNumberTypeWithDisplayWidth(intBaseType, int(duckType.mysql.Display))
     }
 
     length := int64(duckType.mysql.Length)
@@ -255,70 +255,79 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc
 
     switch duckName {
     case "FLOAT":
-        return types.Float32
+        return types.Float32, nil
     case "DOUBLE":
-        return types.Float64
+        return types.Float64, nil
 
     case "TIMESTAMP", "TIMESTAMP_S", "TIMESTAMP_MS":
         if mysqlName == "DATETIME" {
-            return types.MustCreateDatetimeType(sqltypes.Datetime, precision)
+            return types.CreateDatetimeType(sqltypes.Datetime, precision)
         }
-        return types.MustCreateDatetimeType(sqltypes.Timestamp, precision)
+        return types.CreateDatetimeType(sqltypes.Timestamp, precision)
 
     case "DATE":
-        return types.Date
+        return types.Date, nil
     case "INTERVAL", "TIME":
-        return types.Time
+        return types.Time, nil
 
     case "DECIMAL":
-        return types.MustCreateDecimalType(numericPrecision, numericScale)
+        return types.CreateDecimalType(numericPrecision, numericScale)
+
+    case "UHUGEINT", "HUGEINT":
+        // MySQL does not have these types. We store them as DECIMAL.
+        return types.CreateDecimalType(39, 0)
+
+    case "VARINT":
+        // MySQL does not have this type. We store it as DECIMAL.
+        // Here we use the maximum supported precision for DECIMAL in MySQL.
+        return types.CreateDecimalType(65, 0)
 
     case "VARCHAR":
         if mysqlName == "TEXT" {
             if length <= types.TinyTextBlobMax {
-                return types.TinyText
+                return types.TinyText, nil
             } else if length <= types.TextBlobMax {
-                return types.Text
+                return types.Text, nil
             } else if length <= types.MediumTextBlobMax {
-                return types.MediumText
+                return types.MediumText, nil
             } else {
-                return types.LongText
+                return types.LongText, nil
             }
         } else if mysqlName == "VARCHAR" {
-            return types.MustCreateString(sqltypes.VarChar, length, collation)
+            return types.CreateString(sqltypes.VarChar, length, collation)
         } else if mysqlName == "CHAR" {
-            return types.MustCreateString(sqltypes.Char, length, collation)
+            return types.CreateString(sqltypes.Char, length, collation)
        } else if mysqlName == "SET" {
-            return types.MustCreateSetType(duckType.mysql.Values, collation)
+            return types.CreateSetType(duckType.mysql.Values, collation)
         }
-        return types.Text
+        return types.Text, nil
 
     case "BLOB":
         if mysqlName == "BLOB" {
             if length <= types.TinyTextBlobMax {
-                return types.TinyBlob
+                return types.TinyBlob, nil
             } else if length <= types.TextBlobMax {
-                return types.Blob
+                return types.Blob, nil
             } else if length <= types.MediumTextBlobMax {
-                return types.MediumBlob
+                return types.MediumBlob, nil
             } else {
-                return types.LongBlob
+                return types.LongBlob, nil
             }
         } else if mysqlName == "VARBINARY" {
-            return types.MustCreateBinary(sqltypes.VarBinary, length)
+            return types.CreateBinary(sqltypes.VarBinary, length)
         } else if mysqlName == "BINARY" {
-            return types.MustCreateBinary(sqltypes.Binary, length)
+            return types.CreateBinary(sqltypes.Binary, length)
         }
-        return types.Blob
+        return types.Blob, nil
 
     case "JSON":
-        return types.JSON
+        return types.JSON, nil
     case "ENUM":
-        return types.MustCreateEnumType(duckType.mysql.Values, collation)
+        return types.CreateEnumType(duckType.mysql.Values, collation)
     case "SET":
-        return types.MustCreateSetType(duckType.mysql.Values, collation)
+        return types.CreateSetType(duckType.mysql.Values, collation)
     default:
-        panic(fmt.Sprintf("encountered unknown DuckDB type(%v). This is likely a bug - please check the duckdbDataType function for missing type mappings", duckType))
+        return nil, fmt.Errorf("encountered unknown DuckDB type(%v)", duckType)
     }
 }
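The new DECIMAL fallbacks are what let CTAS results containing DuckDB-only integer types surface to MySQL clients: (U)HUGEINT values need at most 39 decimal digits (2^128 is a 39-digit number), and VARINT is capped at DECIMAL(65, 0), the widest precision MySQL supports. A standalone sketch of just that fallback, assuming a minimal decimalSpec struct in place of the go-mysql-server DECIMAL constructor:

package example

import "fmt"

// decimalSpec is a minimal, hypothetical stand-in for a MySQL DECIMAL(p, s) type.
type decimalSpec struct {
    Precision, Scale uint8
}

// decimalFallback sketches the new mappings in mysqlDataType: DuckDB's 128-bit
// integers fit in DECIMAL(39, 0), VARINT is capped at MySQL's maximum DECIMAL(65, 0),
// and unknown names now return an error instead of panicking.
func decimalFallback(duckName string) (decimalSpec, error) {
    switch duckName {
    case "HUGEINT", "UHUGEINT":
        return decimalSpec{Precision: 39, Scale: 0}, nil
    case "VARINT":
        return decimalSpec{Precision: 65, Scale: 0}, nil
    default:
        return decimalSpec{}, fmt.Errorf("no DECIMAL fallback for DuckDB type %s", duckName)
    }
}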

main_test.go

Lines changed: 1 addition & 31 deletions
@@ -1110,45 +1110,15 @@ func TestCreateTable(t *testing.T) {
         "create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_unique_index(b1(123),_b2(456)))",
         "create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_index(b1(10)),_index(b2(20)),_index(b1(123),_b2(456)))",
         "create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_index(b1(10)),_index(b2(20)),_index(b1(123),_b2(456)))",
-        "CREATE_TABLE_t1_as_select_*_from_mytable",
-        "CREATE_TABLE_t1_as_select_*_from_mytable",
-        "CREATE_TABLE_t1_as_select_*_from_mytable#01",
-        "CREATE_TABLE_t1_as_select_*_from_mytable",
-        "CREATE_TABLE_t1_as_select_s,_i_from_mytable",
-        "CREATE_TABLE_t1_as_select_s,_i_from_mytable",
-        "CREATE_TABLE_t1_as_select_distinct_s,_i_from_mytable",
-        "CREATE_TABLE_t1_as_select_distinct_s,_i_from_mytable",
-        "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s",
-        "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s",
+        // SUM(VARCHAR) is not supported by DuckDB
         "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s",
-        "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s",
-        "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s_having_sum(i)_>_2",
         "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s_having_sum(i)_>_2",
-        "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s_limit_1",
-        "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s_limit_1",
-        "CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable",
-        "CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable",
         "display_width_for_numeric_types",
         "SHOW_FULL_FIELDS_FROM_numericDisplayWidthTest;",
         "datetime_precision",
         "CREATE_TABLE_tt_(pk_int_primary_key,_d_datetime(6)_default_current_timestamp(6))",
         "Identifier_lengths",
         "table_charset_options",
-        "show_create_table_t3",
-        "show_create_table_t4",
-        "create_table_with_select_preserves_default",
-        "create_table_t1_select_*_from_a;",
-        "create_table_t2_select_j_from_a;",
-        "create_table_t3_select_j_as_i_from_a;",
-        "create_table_t4_select_j_+_1_from_a;",
-        "create_table_t5_select_a.j_from_a;",
-        "create_table_t6_select_sqa.j_from_(select_i,_j_from_a)_sqa;",
-        "show_create_table_t7;",
-        "create_table_t8_select_*_from_(select_*_from_a)_a_join_(select_*_from_b)_b;",
-        "show_create_table_t9;",
-        "create_table_t11_select_sum(j)_over()_as_jj_from_a;",
-        "create_table_t12_select_j_from_a_group_by_j;",
-        "create_table_t13_select_*_from_c;",
         "event_contains_CREATE_TABLE_AS",
         "CREATE_EVENT_foo_ON_SCHEDULE_EVERY_1_YEAR_DO_CREATE_TABLE_bar_AS_SELECT_1;",
         "trigger_contains_CREATE_TABLE_AS",
