@@ -40,9 +40,13 @@ func New(config Config) gorm.Dialector {
4040}
4141
4242func (dialector Dialector ) Initialize (db * gorm.DB ) (err error ) {
43-
4443 // register callbacks
45- callbacks .RegisterDefaultCallbacks (db , & callbacks.Config {})
44+ callbacks .RegisterDefaultCallbacks (db , & callbacks.Config {
45+ CreateClauses : []string {"INSERT" , "VALUES" , "ON CONFLICT" },
46+ QueryClauses : []string {"SELECT" , "FROM" , "WHERE" , "GROUP BY" , "ORDER BY" , "LIMIT" , "FOR" },
47+ UpdateClauses : []string {"UPDATE" , "SET" , "RETURNING" , "FROM" , "WHERE" },
48+ DeleteClauses : []string {"DELETE" , "FROM" , "RETURNING" , "WHERE" },
49+ })
4650 db .Callback ().Create ().Replace ("gorm:create" , Create )
4751 db .Callback ().Update ().Replace ("gorm:update" , Update )
4852
@@ -97,6 +101,34 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
97101 }
98102 }
99103 },
104+ "RETURNING" : func (c clause.Clause , builder clause.Builder ) {
105+ if returning , ok := c .Expression .(clause.Returning ); ok {
106+ if stmt , ok := builder .(* gorm.Statement ); ok {
107+ var outputTable string
108+ if _ , ok := stmt .Clauses ["UPDATE" ]; ok {
109+ outputTable = "INSERTED"
110+ } else if _ , ok := stmt .Clauses ["DELETE" ]; ok {
111+ outputTable = "DELETED"
112+ }
113+
114+ if outputTable != "" {
115+ stmt .WriteString ("OUTPUT " )
116+
117+ if len (returning .Columns ) > 0 {
118+ columns := []clause.Column {}
119+ for _ , column := range returning .Columns {
120+ column .Table = outputTable
121+ columns = append (columns , column )
122+ }
123+ returning .Columns = columns
124+ returning .Build (stmt )
125+ } else {
126+ stmt .WriteString (outputTable + ".*" )
127+ }
128+ }
129+ }
130+ }
131+ },
100132 }
101133}
102134
0 commit comments