@@ -13,8 +13,9 @@ import (
1313
1414	"cloud.google.com/go/spanner" 
1515	sdb "cloud.google.com/go/spanner/admin/database/apiv1" 
16- 	"cloud.google.com/go/spanner/spansql" 
1716
17+ 	"github.com/cloudspannerecosystem/memefish" 
18+ 	"github.com/cloudspannerecosystem/memefish/token" 
1819	"github.com/golang-migrate/migrate/v4" 
1920	"github.com/golang-migrate/migrate/v4/database" 
2021
@@ -60,11 +61,9 @@ type Config struct {
6061
6162// Spanner implements database.Driver for Google Cloud Spanner 
6263type  Spanner  struct  {
63- 	db  * DB 
64- 
64+ 	db      * DB 
6565	config  * Config 
66- 
67- 	lock  * uatomic.Uint32 
66+ 	lock    * uatomic.Uint32 
6867}
6968
7069type  DB  struct  {
@@ -179,26 +178,65 @@ func (s *Spanner) Run(migration io.Reader) error {
179178		return  err 
180179	}
181180
182- 	stmts  :=  []string {string (migr )}
183- 	if  s .config .CleanStatements  {
184- 		stmts , err  =  cleanStatements (migr )
185- 		if  err  !=  nil  {
186- 			return  err 
181+ 	ctx  :=  context .Background ()
182+ 
183+ 	if  ! s .config .CleanStatements  {
184+ 		return  s .runDdl (ctx , []string {string (migr )})
185+ 	}
186+ 
187+ 	stmtGroups , err  :=  statementGroups (migr )
188+ 	if  err  !=  nil  {
189+ 		return  err 
190+ 	}
191+ 
192+ 	for  _ , group  :=  range  stmtGroups  {
193+ 		switch  group .typ  {
194+ 		case  statementTypeDDL :
195+ 			if  err  :=  s .runDdl (ctx , group .stmts ); err  !=  nil  {
196+ 				return  err 
197+ 			}
198+ 		case  statementTypeDML :
199+ 			if  err  :=  s .runDml (ctx , group .stmts ); err  !=  nil  {
200+ 				return  err 
201+ 			}
202+ 		default :
203+ 			return  fmt .Errorf ("unknown statement type: %s" , group .typ )
187204		}
188205	}
189206
190- 	ctx  :=  context .Background ()
207+ 	return  nil 
208+ }
209+ 
210+ func  (s  * Spanner ) runDdl (ctx  context.Context , stmts  []string ) error  {
191211	op , err  :=  s .db .admin .UpdateDatabaseDdl (ctx , & adminpb.UpdateDatabaseDdlRequest {
192212		Database :   s .config .DatabaseName ,
193213		Statements : stmts ,
194214	})
195215
196216	if  err  !=  nil  {
197- 		return  & database.Error {OrigErr : err , Err : "migration failed" , Query : migr }
217+ 		return  & database.Error {OrigErr : err , Err : "migration failed" , Query : [] byte ( strings . Join ( stmts ,  "; \n " )) }
198218	}
199219
200220	if  err  :=  op .Wait (ctx ); err  !=  nil  {
201- 		return  & database.Error {OrigErr : err , Err : "migration failed" , Query : migr }
221+ 		return  & database.Error {OrigErr : err , Err : "migration failed" , Query : []byte (strings .Join (stmts , ";\n " ))}
222+ 	}
223+ 
224+ 	return  nil 
225+ }
226+ 
227+ func  (s  * Spanner ) runDml (ctx  context.Context , stmts  []string ) error  {
228+ 	_ , err  :=  s .db .data .ReadWriteTransaction (ctx ,
229+ 		func (ctx  context.Context , txn  * spanner.ReadWriteTransaction ) error  {
230+ 			for  _ , s  :=  range  stmts  {
231+ 				_ , err  :=  txn .Update (ctx , spanner.Statement {SQL : s })
232+ 				if  err  !=  nil  {
233+ 					return  err 
234+ 				}
235+ 			}
236+ 			return  nil 
237+ 		})
238+ 	if  err  !=  nil  {
239+ 		return  & database.Error {OrigErr : err , Err : "migration failed" , Query : []byte (strings .Join (stmts , ";\n " ))}
202240	}
203241
204242	return  nil 
@@ -345,17 +383,80 @@ func (s *Spanner) ensureVersionTable() (err error) {
345383	return  nil 
346384}
347385
348- func  cleanStatements (migration  []byte ) ([]string , error ) {
349- 	// The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC 
350- 	// (see https://issuetracker.google.com/issues/159730604) we use 
351- 	// spansql to parse the DDL and output valid stamements without comments 
352- 	ddl , err  :=  spansql .ParseDDL ("" , string (migration ))
353- 	if  err  !=  nil  {
354- 		return  nil , err 
386+ type  statementType  string 
387+ 
388+ const  (
389+ 	statementTypeUnknown  statementType  =  "" 
390+ 	statementTypeDDL      statementType  =  "DDL" 
391+ 	statementTypeDML      statementType  =  "DML" 
392+ )
393+ 
394+ type  statementGroup  struct  {
395+ 	typ    statementType 
396+ 	stmts  []string 
397+ }
398+ 
399+ func  statementGroups (migr  []byte ) (groups  []* statementGroup , err  error ) {
400+ 	lex  :=  & memefish.Lexer {
401+ 		File : & token.File {Buffer : string (migr )},
355402	}
356- 	stmts  :=  make ([]string , 0 , len (ddl .List ))
357- 	for  _ , stmt  :=  range  ddl .List  {
358- 		stmts  =  append (stmts , stmt .SQL ())
403+ 
404+ 	group  :=  & statementGroup {}
405+ 	var  stmtTyp  statementType 
406+ 	var  stmt  strings.Builder 
407+ 	for  {
408+ 		if  err  :=  lex .NextToken (); err  !=  nil  {
409+ 			return  nil , err 
410+ 		}
411+ 
412+ 		if  stmtTyp  ==  statementTypeUnknown  {
413+ 			switch  {
414+ 			case  lex .Token .IsKeywordLike ("INSERT" ) ||  lex .Token .IsKeywordLike ("DELETE" ) ||  lex .Token .IsKeywordLike ("UPDATE" ):
415+ 				stmtTyp  =  statementTypeDML 
416+ 			default :
417+ 				stmtTyp  =  statementTypeDDL 
418+ 			}
419+ 			if  group .typ  !=  stmtTyp  {
420+ 				if  len (group .stmts ) >  0  {
421+ 					groups  =  append (groups , group )
422+ 				}
423+ 				group  =  & statementGroup {typ : stmtTyp }
424+ 			}
425+ 		}
426+ 
427+ 		if  lex .Token .Kind  ==  token .TokenEOF  ||  lex .Token .Kind  ==  ";"  {
428+ 			if  stmt .Len () >  0  {
429+ 				group .stmts  =  append (group .stmts , stmt .String ())
430+ 			}
431+ 			stmtTyp  =  statementTypeUnknown 
432+ 			stmt .Reset ()
433+ 
434+ 			if  lex .Token .Kind  ==  token .TokenEOF  {
435+ 				if  len (group .stmts ) >  0  {
436+ 					groups  =  append (groups , group )
437+ 				}
438+ 
439+ 				break 
440+ 			}
441+ 
442+ 			continue 
443+ 		}
444+ 
445+ 		if  len (lex .Token .Comments ) >  0  {
446+ 			// preserve newline where comments are removed 
447+ 			if  _ , err  :=  stmt .WriteString ("\n " ); err  !=  nil  {
448+ 				return  nil , err 
449+ 			}
450+ 		}
451+ 		if  stmt .Len () >  0  {
452+ 			if  _ , err  :=  stmt .WriteString (lex .Token .Space ); err  !=  nil  {
453+ 				return  nil , err 
454+ 			}
455+ 		}
456+ 		if  _ , err  :=  stmt .WriteString (lex .Token .Raw ); err  !=  nil  {
457+ 			return  nil , err 
458+ 		}
359459	}
360- 	return  stmts , nil 
460+ 
461+ 	return  groups , nil 
361462}
0 commit comments