Skip to content
8 changes: 6 additions & 2 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (association *Association) Count() int {
}

if err := query.Model(fieldValue).Count(&count).Error; err != nil {
association.Error = err
association.Error = NewGormError(err, query.SQL)
}
return count
}
Expand Down Expand Up @@ -371,7 +371,11 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
// setErr set error when the error is not nil. And return Association.
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
if association.scope != nil {
association.Error = NewGormError(err, association.scope.SQL)
} else {
association.Error = err
}
}
return association
}
4 changes: 4 additions & 0 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func createCallback(scope *Scope) {
))
}

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

// execute create sql: no primaryField
if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
Expand Down
4 changes: 4 additions & 0 deletions callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func queryCallback(scope *Scope) {

scope.prepareQuerySQL()

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

if !scope.HasError() {
scope.db.RowsAffected = 0

Expand Down
12 changes: 11 additions & 1 deletion callback_row_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,24 @@ func rowQueryCallback(scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL()

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

if str, ok := scope.Get("gorm:query_hint"); ok {
scope.SQL = fmt.Sprint(str) + scope.SQL
}

if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
rowsResult.Rows = rows
if rowsResult.Error != nil {
rowsResult.Error = NewGormError(err, scope.SQL)
} else {
rowsResult.Error = err
}
}
}
}
4 changes: 2 additions & 2 deletions create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,11 @@ func TestFixFullTableScanWhenInsertIgnore(t *testing.T) {

DB.Callback().Query().Register("gorm:fix_full_table_scan", func(scope *gorm.Scope) {
if strings.Contains(scope.SQL, "SELECT") && strings.Contains(scope.SQL, "pandas") && len(scope.SQLVars) == 0 {
t.Error("Should skip force reload when ignore duplicate panda insert")
t.Error("Should skip force reload when ignore duplicate panda insert")
}
})

if DB.Dialect().GetName() == "mysql" && DB.Set("gorm:insert_modifier", "IGNORE").Create(&pandaYuanYuan).Error != nil {
t.Error("Should ignore duplicate panda insert by insert modifier:IGNORE ")
}
}
}
29 changes: 27 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
)

var (
// ErrNoRecordsInResultSetSQL sql native error on querying with .row() function or similar
ErrNoRecordsInResultSetSQL = errors.New("sql: no rows in result set")
// ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL occurs when you attempt a query with invalid SQL
Expand All @@ -23,14 +25,17 @@ type Errors []error

// IsRecordNotFoundError returns true if error contains a RecordNotFound error
func IsRecordNotFoundError(err error) bool {
if err == nil {
return false
}
if errs, ok := err.(Errors); ok {
for _, err := range errs {
if err == ErrRecordNotFound {
if err.Error() == ErrRecordNotFound.Error() || err.Error() == ErrNoRecordsInResultSetSQL.Error() {
return true
}
}
}
return err == ErrRecordNotFound
return err.Error() == ErrRecordNotFound.Error() || err.Error() == ErrNoRecordsInResultSetSQL.Error()
}

// GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
Expand Down Expand Up @@ -70,3 +75,23 @@ func (errs Errors) Error() string {
}
return strings.Join(errors, "; ")
}

// GormError is a custom error with the error and the SQL executed.
type GormError struct {
Err error
SQL string
}

// New is a construtor of custom error.
func NewGormError(err error, sql string) GormError {
return GormError{err, sql}
}

// Error return the error message.
func (e GormError) Error() string {
if e.Err != nil {
return e.Err.Error()
} else {
return "unexpected error"
}
}
75 changes: 75 additions & 0 deletions formatter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package gorm

import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)

func FormatSQL(sql string, values ...interface{}) string {
if len(values) > 0 {

formattedValues := []string{}

// duration
for _, value := range values {
indirectValue := reflect.Indirect(reflect.ValueOf(value))
if indirectValue.IsValid() {
value = indirectValue.Interface()
if t, ok := value.(time.Time); ok {
if t.IsZero() {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
} else {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
}
} else if b, ok := value.([]byte); ok {
if str := string(b); isPrintable(str) {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
} else {
formattedValues = append(formattedValues, "'<binary>'")
}
} else if r, ok := value.(driver.Valuer); ok {
if value, err := r.Value(); err == nil && value != nil {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
} else {
formattedValues = append(formattedValues, "NULL")
}
} else {
switch value.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
formattedValues = append(formattedValues, fmt.Sprintf("%v", value))
default:
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
}
}
} else {
formattedValues = append(formattedValues, "NULL")
}
}

// differentiate between $n placeholders or else treat like ?
if numericPlaceHolderRegexp.MatchString(sql) {
for index, value := range formattedValues {
placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1)
sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1")
}
} else {
formattedValuesLength := len(formattedValues)
for index, value := range sqlRegexp.Split(sql, -1) {
sql += value
if index < formattedValuesLength {
sql += formattedValues[index]
}
}
}

}

sql = strings.ReplaceAll(sql, "\n", "")
sql = strings.ReplaceAll(sql, "\t", "")

return sql
}
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA=
github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
12 changes: 7 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (

// DB contains information for current db connection
type DB struct {
SQL string

sync.RWMutex
Value interface{}
Error error
Expand Down Expand Up @@ -83,9 +85,9 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}

db = &DB{
db: dbSQL,
logger: defaultLogger,
db: dbSQL,
logger: defaultLogger,

// Create a clone of the default logger to avoid mutating a shared object when
// multiple gorm connections are created simultaneously.
callbacks: DefaultCallback.clone(defaultLogger),
Expand Down Expand Up @@ -627,7 +629,7 @@ func (s *DB) NewRecord(value interface{}) bool {
// RecordNotFound check if returning ErrRecordNotFound error
func (s *DB) RecordNotFound() bool {
for _, err := range s.GetErrors() {
if err == ErrRecordNotFound {
if err != nil && err.Error() == ErrRecordNotFound.Error() {
return true
}
}
Expand Down Expand Up @@ -825,7 +827,7 @@ func (s *DB) AddError(err error) error {
}
}

s.Error = err
s.Error = NewGormError(err, s.SQL)
}
return err
}
Expand Down
2 changes: 1 addition & 1 deletion main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ func TestRaw(t *testing.T) {
}

DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
if err := DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error("Raw sql to update records")
}
}
Expand Down
4 changes: 2 additions & 2 deletions migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ func getPreparedUser(name string, role string) *User {
}

type Panda struct {
Number int64 `gorm:"unique_index:number"`
Name string `gorm:"column:name;type:varchar(255);default:null"`
Number int64 `gorm:"unique_index:number"`
Name string `gorm:"column:name;type:varchar(255);default:null"`
}

func runMigration() {
Expand Down
6 changes: 3 additions & 3 deletions preload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func TestNestedPreload1(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}

if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error(err)
}
}
Expand Down Expand Up @@ -1104,7 +1104,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}

if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error(err)
}
}
Expand Down Expand Up @@ -1161,7 +1161,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}

if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != nil && err.Error() != gorm.ErrRecordNotFound.Error() {
t.Error(err)
}
}
Expand Down
4 changes: 4 additions & 0 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,10 @@ func (scope *Scope) Raw(sql string) *Scope {
func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc())

if scope.DB() != nil {
scope.DB().SQL = FormatSQL(scope.SQL, scope.SQLVars...)
}

if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
Expand Down