diff --git a/hooks.go b/hooks.go index 4da709f7..4ff75777 100644 --- a/hooks.go +++ b/hooks.go @@ -55,6 +55,7 @@ const ( StoredInflightMessages StoredRetainedMessages StoredSysInfo + StoredClientByID ) var ( @@ -114,6 +115,7 @@ type Hook interface { StoredInflightMessages() ([]storage.Message, error) StoredRetainedMessages() ([]storage.Message, error) StoredSysInfo() (storage.SystemInfo, error) + StoredClientByID(id string, username []byte) (string, []storage.Subscription, []storage.Message, error) } // HookOptions contains values which are inherited from the server on initialisation. @@ -679,6 +681,25 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { return false } +// StoredClientByID returns the state of the stored client with the given session ID, if any. +func (h *Hooks) StoredClientByID(id string, username []byte) (oldRemote string, subs []storage.Subscription, msgs []storage.Message, err error) { + for _, hook := range h.GetAll() { + if hook.Provides(StoredClientByID) { + oldRemote, subs, msgs, err = hook.StoredClientByID(id, username) + if err != nil { + h.Log.Error("failed to load client by ID", "error", err, "hook", hook.ID()) + return + } + + if oldRemote != "" && err == nil { + return + } + } + } + + return +} + // HookBase provides a set of default methods for each hook. It should be embedded in // all hooks. type HookBase struct { @@ -859,3 +880,8 @@ func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) { func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) { return } + +// StoredClientByID returns the state of the stored client with the given session ID, if any. +func (h *HookBase) StoredClientByID(id string, username []byte) (oldRemote string, subs []storage.Subscription, msgs []storage.Message, err error) { + return +} diff --git a/server.go b/server.go index 4ad91822..0430b700 100644 --- a/server.go +++ b/server.go @@ -485,6 +485,11 @@ func (s *Server) attachClient(cl *Client, listener string) error { expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) s.hooks.OnDisconnect(cl, err, expire) + if s.hooks.Provides(StoredClientByID) { + // Hooks are capable of reloading a persistent client session, so I can forget it + expire = true + } + if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 { cl.ClearInflights() s.UnsubscribeClient(cl) @@ -596,6 +601,42 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { return true // [MQTT-3.2.2-3] } + // Look up a stored client that's not in memory yet: + if s.hooks.Provides(StoredClientByID) { + oldRemote, subs, msgs, err := s.hooks.StoredClientByID(cl.ID, cl.Properties.Username) + if err == nil && oldRemote != "" { + // Instantiate in-flight messages to deliver: + if len(msgs) > 0 { + inf := NewInflights() + for _, msg := range msgs { + inf.Set(msg.ToPacket()) + } + cl.State.Inflight = inf + } + + // Instantiate stored subscriptions: + for _, sub := range subs { + sb := packets.Subscription{ + Filter: sub.Filter, + RetainHandling: sub.RetainHandling, + Qos: sub.Qos, + RetainAsPublished: sub.RetainAsPublished, + NoLocal: sub.NoLocal, + Identifier: sub.Identifier, + } + existed := !s.Topics.Subscribe(cl.ID, sb) // [MQTT-3.8.4-3] + if !existed { + atomic.AddInt64(&s.Info.Subscriptions, 1) + } + cl.State.Subscriptions.Add(sb.Filter, sb) + } + + s.Log.Debug("session taken over (persistent)", "client", cl.ID, "old_remote", oldRemote, "new_remote", cl.Net.Remote) + + return true + } + } + if atomic.LoadInt64(&s.Info.ClientsConnected) > atomic.LoadInt64(&s.Info.ClientsMaximum) { atomic.AddInt64(&s.Info.ClientsMaximum, 1) } @@ -1014,6 +1055,7 @@ func (s *Server) publishToSubscribers(pk packets.Packet) { } } +// publishToClient delivers a published message to a single subscriber client. func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) { if sub.NoLocal && pk.Origin == cl.ID { return pk, nil // [MQTT-3.8.3-3] @@ -1636,24 +1678,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) { // loadClients restores clients from the datastore. func (s *Server) loadClients(v []storage.Client) { for _, c := range v { - cl := s.NewClient(nil, c.Listener, c.ID, false) - cl.Properties.Username = c.Username - cl.Properties.Clean = c.Clean - cl.Properties.ProtocolVersion = c.ProtocolVersion - cl.Properties.Props = packets.Properties{ - SessionExpiryInterval: c.Properties.SessionExpiryInterval, - SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag, - AuthenticationMethod: c.Properties.AuthenticationMethod, - AuthenticationData: c.Properties.AuthenticationData, - RequestProblemInfoFlag: c.Properties.RequestProblemInfoFlag, - RequestProblemInfo: c.Properties.RequestProblemInfo, - RequestResponseInfo: c.Properties.RequestResponseInfo, - ReceiveMaximum: c.Properties.ReceiveMaximum, - TopicAliasMaximum: c.Properties.TopicAliasMaximum, - User: c.Properties.User, - MaximumPacketSize: c.Properties.MaximumPacketSize, - } - cl.Properties.Will = Will(c.Will) + cl := s.newClientFromStorage(&c) // cancel the context, update cl.State such as disconnected time and stopCause. cl.Stop(packets.ErrServerShuttingDown) @@ -1669,6 +1694,29 @@ func (s *Server) loadClients(v []storage.Client) { } } +// newClientFromStorage creates a Client from a storage.Client. +func (s *Server) newClientFromStorage(c *storage.Client) *Client { + cl := s.NewClient(nil, c.Listener, c.ID, false) + cl.Properties.Username = c.Username + cl.Properties.Clean = c.Clean + cl.Properties.ProtocolVersion = c.ProtocolVersion + cl.Properties.Props = packets.Properties{ + SessionExpiryInterval: c.Properties.SessionExpiryInterval, + SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag, + AuthenticationMethod: c.Properties.AuthenticationMethod, + AuthenticationData: c.Properties.AuthenticationData, + RequestProblemInfoFlag: c.Properties.RequestProblemInfoFlag, + RequestProblemInfo: c.Properties.RequestProblemInfo, + RequestResponseInfo: c.Properties.RequestResponseInfo, + ReceiveMaximum: c.Properties.ReceiveMaximum, + TopicAliasMaximum: c.Properties.TopicAliasMaximum, + User: c.Properties.User, + MaximumPacketSize: c.Properties.MaximumPacketSize, + } + cl.Properties.Will = Will(c.Will) + return cl +} + // loadInflight restores inflight messages from the datastore. func (s *Server) loadInflight(v []storage.Message) { for _, msg := range v {