Skip to content

feat: add BeforeFind hook #7370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
createCallback.Clauses = config.CreateClauses

queryCallback := db.Callback().Query()
queryCallback.Register("gorm:before_query", BeforeQuery)

Check failure on line 51 in callbacks/callbacks.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 Error return value of `queryCallback.Register` is not checked (errcheck) Raw Output: callbacks/callbacks.go:51:24: Error return value of `queryCallback.Register` is not checked (errcheck) queryCallback.Register("gorm:before_query", BeforeQuery) ^
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
Expand Down
4 changes: 4 additions & 0 deletions callbacks/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}

type BeforeFindInterface interface {
BeforeFind(*gorm.DB) error
}

type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}
12 changes: 12 additions & 0 deletions callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ import (
"gorm.io/gorm/utils"
)

func BeforeQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeFind {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(BeforeFindInterface); ok {
db.AddError(i.BeforeFind(tx))
return true
}
return false
})
}
}

func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
Expand Down
2 changes: 1 addition & 1 deletion schema/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestCallback(t *testing.T) {
}
}

for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "BeforeFind", "AfterFind"} {
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be false", str)
}
Expand Down
7 changes: 5 additions & 2 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeBeforeFind callbackType = "BeforeFind"
callbackTypeAfterFind callbackType = "AfterFind"
)

Expand Down Expand Up @@ -52,7 +53,7 @@ type Schema struct {
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
BeforeFind, AfterFind bool
err error
initialized chan struct{}
namer Namer
Expand Down Expand Up @@ -308,7 +309,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
callbackTypeBeforeFind, callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
Expand Down Expand Up @@ -396,6 +397,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeBeforeFind:
return modelType.MethodByName(string(callbackTypeBeforeFind))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
Expand Down
96 changes: 96 additions & 0 deletions tests/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,99 @@ func TestPropagateUnscoped(t *testing.T) {
t.Fatalf("unscoped did not propagate")
}
}

type Product7 struct {
gorm.Model
Code string
Price float64
BeforeFindCallTimes int64 `gorm:"-"`
}

func (s *Product7) BeforeFind(tx *gorm.DB) error {
s.BeforeFindCallTimes++
return nil
}

// Modifies transient field
func TestBeforeFindHookCallCount(t *testing.T) {
DB.Migrator().DropTable(&Product7{})
DB.AutoMigrate(&Product7{})

p := Product7{Code: "before_find_count", Price: 100}
DB.Save(&p)

var result Product7

DB.First(&result, "code = ?", "before_find_count")
if result.BeforeFindCallTimes != 1 {
t.Errorf("Expected 1, got %d", result.BeforeFindCallTimes)
}

DB.First(&result, "code = ?", "before_find_count")
if result.BeforeFindCallTimes != 2 {
t.Errorf("Expected 2, got %d", result.BeforeFindCallTimes)
}
}

type Product8 struct {
gorm.Model
Code string
Price float64
}

func (s *Product8) BeforeFind(tx *gorm.DB) error {
tx.Statement.Where("price > ?", 50)

return nil
}

func TestBeforeFindModifiesQuery(t *testing.T) {
DB.Migrator().DropTable(&Product8{})
DB.AutoMigrate(&Product8{})

p1 := Product8{Code: "A", Price: 30}
DB.Create(&p1)

var result Product8

DB.Find(&result)

if (result != Product8{}) {
t.Errorf("BeforeFind should filter results, got %v", result)
}

p2 := Product8{Code: "B", Price: 100}
DB.Create(&p2)

DB.Find(&result)

if result.Code != "B" {
t.Errorf("BeforeFind should filter results, got %v", result)
}
}

type Product9 struct {
gorm.Model
Code string
Price float64
}

func (s *Product9) BeforeFind(tx *gorm.DB) error {
s.Price = 200
return nil
}

func TestDatabaseOverwritesBeforeFindChanges(t *testing.T) {
DB.Migrator().DropTable(&Product9{})
DB.AutoMigrate(&Product9{})

p := Product9{Code: "price_overwrite", Price: 100}
DB.Save(&p)

var result Product9
DB.First(&result, "code = ?", "price_overwrite")

if result.Price != 100 {
t.Errorf("Price should be loaded from database, got %f", result.Price)
}
}
Loading