diff --git a/notify.go b/notify.go index ff0b04c0..5c421fdb 100644 --- a/notify.go +++ b/notify.go @@ -4,6 +4,7 @@ package pq // This module contains support for Postgres LISTEN/NOTIFY. import ( + "context" "database/sql/driver" "errors" "fmt" @@ -40,6 +41,51 @@ func SetNotificationHandler(c driver.Conn, handler func(*Notification)) { c.(*conn).notificationHandler = handler } +// NotificationHandlerConnector wraps a regular connector and sets a notification handler +// on it. +type NotificationHandlerConnector struct { + driver.Connector + notificationHandler func(*Notification) +} + +// Connect calls the underlying connector's connect method and then sets the +// notification handler. +func (n *NotificationHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { + c, err := n.Connector.Connect(ctx) + if err == nil { + SetNotificationHandler(c, n.notificationHandler) + } + return c, err +} + +// ConnectorNotificationHandler returns the currently set notification handler, if any. If +// the given connector is not a result of ConnectorWithNotificationHandler, nil is +// returned. +func ConnectorNotificationHandler(c driver.Connector) func(*Notification) { + if c, ok := c.(*NotificationHandlerConnector); ok { + return c.notificationHandler + } + return nil +} + +// ConnectorWithNotificationHandler creates or sets the given handler for the given +// connector. If the given connector is a result of calling this function +// previously, it is simply set on the given connector and returned. Otherwise, +// this returns a new connector wrapping the given one and setting the notification +// handler. A nil notification handler may be used to unset it. +// +// The returned connector is intended to be used with database/sql.OpenDB. +// +// Note: Notification handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func ConnectorWithNotificationHandler(c driver.Connector, handler func(*Notification)) *NotificationHandlerConnector { + if c, ok := c.(*NotificationHandlerConnector); ok { + c.notificationHandler = handler + return c + } + return &NotificationHandlerConnector{Connector: c, notificationHandler: handler} +} + const ( connStateIdle int32 = iota connStateExpectResponse diff --git a/notify_test.go b/notify_test.go index 075666dd..ed980710 100644 --- a/notify_test.go +++ b/notify_test.go @@ -1,6 +1,8 @@ package pq import ( + "database/sql" + "database/sql/driver" "errors" "fmt" "io" @@ -568,3 +570,43 @@ func TestListenerPing(t *testing.T) { t.Fatalf("expected errListenerClosed; got %v", err) } } + +func TestConnectorWithNotificationHandler_Simple(t *testing.T) { + b, err := NewConnector("") + if err != nil { + t.Fatal(err) + } + var notification *Notification + // Make connector w/ handler to set the local var + c := ConnectorWithNotificationHandler(b, func(n *Notification) { notification = n }) + sendNotification(c, t, "Test notification #1") + if notification == nil || notification.Extra != "Test notification #1" { + t.Fatalf("Expected notification w/ message, got %v", notification) + } + // Unset the handler on the same connector + prevC := c + if c = ConnectorWithNotificationHandler(c, nil); c != prevC { + t.Fatalf("Expected to not create new connector but did") + } + sendNotification(c, t, "Test notification #2") + if notification == nil || notification.Extra != "Test notification #1" { + t.Fatalf("Expected notification to not change, got %v", notification) + } + // Set it back on the same connector + if c = ConnectorWithNotificationHandler(c, func(n *Notification) { notification = n }); c != prevC { + t.Fatal("Expected to not create new connector but did") + } + sendNotification(c, t, "Test notification #3") + if notification == nil || notification.Extra != "Test notification #3" { + t.Fatalf("Expected notification w/ message, got %v", notification) + } +} + +func sendNotification(c driver.Connector, t *testing.T, escapedNotification string) { + db := sql.OpenDB(c) + defer db.Close() + sql := fmt.Sprintf("LISTEN foo; NOTIFY foo, '%s';", escapedNotification) + if _, err := db.Exec(sql); err != nil { + t.Fatal(err) + } +}