@@ -18,41 +18,48 @@ import (
1818 "bufio"
1919 "encoding/json"
2020 "fmt"
21+ "io"
2122 "net"
2223 "runtime/debug"
2324 "sync"
2425 "sync/atomic"
2526 "time"
2627
27- "github.com/emitter-io/emitter/internal/broker/keygen"
2828 "github.com/emitter-io/emitter/internal/errors"
2929 "github.com/emitter-io/emitter/internal/event"
3030 "github.com/emitter-io/emitter/internal/message"
3131 "github.com/emitter-io/emitter/internal/network/mqtt"
3232 "github.com/emitter-io/emitter/internal/provider/contract"
3333 "github.com/emitter-io/emitter/internal/provider/logging"
3434 "github.com/emitter-io/emitter/internal/security"
35+ "github.com/emitter-io/emitter/internal/service/keygen"
3536 "github.com/emitter-io/stats"
37+ "github.com/kelindar/binary"
3638 "github.com/kelindar/binary/nocopy"
3739 "github.com/kelindar/rate"
3840)
3941
4042const defaultReadRate = 100000
4143
44+ type response interface {
45+ ForRequest (uint16 )
46+ }
47+
4248// Conn represents an incoming connection.
4349type Conn struct {
4450 sync.Mutex
4551 tracked uint32 // Whether the connection was already tracked or not.
4652 socket net.Conn // The transport used to read and write messages.
47- username string // The username provided by the client during MQTT connect.
4853 luid security.ID // The locally unique id of the connection.
4954 guid string // The globally unique id of the connection.
5055 service * Service // The service for this connection.
5156 subs * message.Counters // The subscriptions for this connection.
5257 measurer stats.Measurer // The measurer to use for monitoring.
53- links map [string ]string // The map of all pre-authorized links.
5458 limit * rate.Limiter // The read rate limiter.
55- keys * keygen.Provider // The key generation provider.
59+ keys * keygen.Service // The key generation provider.
60+ connect * event.Connection // The associated connection event.
61+ username string // The username provided by the client during MQTT connect.
62+ links map [string ]string // The map of all pre-authorized links.
5663}
5764
5865// NewConn creates a new connection.
@@ -65,7 +72,7 @@ func (s *Service) newConn(t net.Conn, readRate int) *Conn {
6572 subs : message .NewCounters (),
6673 measurer : s .measurer ,
6774 links : map [string ]string {},
68- keys : s .Keygen ,
75+ keys : s .keygen ,
6976 }
7077
7178 // Generate a globally unique id as well
@@ -86,6 +93,34 @@ func (c *Conn) ID() string {
8693 return c .guid
8794}
8895
96+ // LocalID returns the local connection identifier.
97+ func (c * Conn ) LocalID () security.ID {
98+ return c .luid
99+ }
100+
101+ // Username returns the associated username.
102+ func (c * Conn ) Username () string {
103+ return c .username
104+ }
105+
106+ // GetLink checks if the topic is a registered shortcut and expands it.
107+ func (c * Conn ) GetLink (topic []byte ) []byte {
108+ if len (topic ) <= 2 && c .links != nil {
109+ return []byte (c .links [binary .ToString (& topic )])
110+ }
111+ return topic
112+ }
113+
114+ // AddLink adds a link alias for a channel.
115+ func (c * Conn ) AddLink (alias string , channel * security.Channel ) {
116+ c .links [alias ] = channel .String ()
117+ }
118+
119+ // Links returns a map of all links registered.
120+ func (c * Conn ) Links () map [string ]string {
121+ return c .links
122+ }
123+
89124// Type returns the type of the subscriber
90125func (c * Conn ) Type () message.SubscriberType {
91126 return message .SubscriberDirect
@@ -96,8 +131,8 @@ func (c *Conn) MeasureElapsed(name string, since time.Time) {
96131 c .measurer .MeasureElapsed (name , time .Now ())
97132}
98133
99- // track tracks the connection by adding it to the metering.
100- func (c * Conn ) track (contract contract.Contract ) {
134+ // Track tracks the connection by adding it to the metering.
135+ func (c * Conn ) Track (contract contract.Contract ) {
101136 if atomic .LoadUint32 (& c .tracked ) == 0 {
102137
103138 // We keep only the IP address for fair tracking
@@ -112,6 +147,16 @@ func (c *Conn) track(contract contract.Contract) {
112147 }
113148}
114149
150+ // Increment increments the subscription counter.
151+ func (c * Conn ) Increment (ssid message.Ssid , channel []byte ) bool {
152+ return c .subs .Increment (ssid , channel )
153+ }
154+
155+ // Decrement decrements a subscription counter.
156+ func (c * Conn ) Decrement (ssid message.Ssid ) bool {
157+ return c .subs .Decrement (ssid )
158+ }
159+
115160// Process processes the messages.
116161func (c * Conn ) Process () error {
117162 defer c .Close ()
@@ -166,7 +211,7 @@ func (c *Conn) onReceive(msg mqtt.Message) error {
166211
167212 // Subscribe for each subscription
168213 for _ , sub := range packet .Subscriptions {
169- if err := c .onSubscribe ( sub .Topic ); err != nil {
214+ if err := c .service . pubsub . OnSubscribe ( c , sub .Topic ); err != nil {
170215 ack .Qos = append (ack .Qos , 0x80 ) // 0x80 indicate subscription failure
171216 c .notifyError (err , packet .MessageID )
172217 continue
@@ -188,7 +233,7 @@ func (c *Conn) onReceive(msg mqtt.Message) error {
188233
189234 // Unsubscribe from each subscription
190235 for _ , sub := range packet .Topics {
191- if err := c .onUnsubscribe ( sub .Topic ); err != nil {
236+ if err := c .service . pubsub . OnUnsubscribe ( c , sub .Topic ); err != nil {
192237 c .notifyError (err , packet .MessageID )
193238 }
194239 }
@@ -206,11 +251,11 @@ func (c *Conn) onReceive(msg mqtt.Message) error {
206251 }
207252
208253 case mqtt .TypeOfDisconnect :
209- return nil
254+ return io . EOF
210255
211256 case mqtt .TypeOfPublish :
212257 packet := msg .(* mqtt.Publish )
213- if err := c .onPublish ( packet ); err != nil {
258+ if err := c .service . pubsub . OnPublish ( c , packet ); err != nil {
214259 logging .LogError ("conn" , "publish received" , err )
215260 c .notifyError (err , packet .MessageID )
216261 }
@@ -264,38 +309,41 @@ func (c *Conn) sendResponse(topic string, resp response, requestID uint16) {
264309 return
265310}
266311
267- // Subscribe subscribes to a particular channel.
268- func (c * Conn ) Subscribe (ssid message.Ssid , channel []byte ) {
312+ // CanSubscribe increments the internal counters and checks if the cluster
313+ // needs to be notified.
314+ func (c * Conn ) CanSubscribe (ssid message.Ssid , channel []byte ) bool {
269315 c .Lock ()
270316 defer c .Unlock ()
271-
272- // Add the subscription
273- if first := c .subs .Increment (ssid , channel ); first {
274- c .service .Subscribe (c , & event.Subscription {
275- Peer : c .service .ID (),
276- Conn : c .luid ,
277- User : nocopy .String (c .username ),
278- Ssid : ssid ,
279- Channel : channel ,
280- })
281- }
317+ return c .subs .Increment (ssid , channel )
282318}
283319
284- // Unsubscribe unsubscribes this client from a particular channel.
285- func (c * Conn ) Unsubscribe (ssid message.Ssid , channel []byte ) {
320+ // CanUnsubscribe decrements the internal counters and checks if the cluster
321+ // needs to be notified.
322+ func (c * Conn ) CanUnsubscribe (ssid message.Ssid , channel []byte ) bool {
286323 c .Lock ()
287324 defer c .Unlock ()
325+ return c .subs .Decrement (ssid )
326+ }
288327
289- // Decrement the counter and if there's no more subscriptions, notify everyone.
290- if last := c .subs .Decrement (ssid ); last {
291- c .service .Unsubscribe (c , & event.Subscription {
292- Peer : c .service .ID (),
293- Conn : c .luid ,
294- User : nocopy .String (c .username ),
295- Ssid : ssid ,
296- Channel : channel ,
297- })
328+ // onConnect handles the connection authorization
329+ func (c * Conn ) onConnect (packet * mqtt.Connect ) bool {
330+ c .username = string (packet .Username )
331+ c .connect = & event.Connection {
332+ Peer : c .service .ID (),
333+ Conn : c .luid ,
334+ WillFlag : packet .WillFlag ,
335+ WillRetain : packet .WillRetainFlag ,
336+ WillQoS : packet .WillQOS ,
337+ WillTopic : packet .WillTopic ,
338+ WillMessage : packet .WillMessage ,
339+ ClientID : packet .ClientID ,
340+ Username : packet .Username ,
341+ }
342+
343+ if c .service .cluster != nil {
344+ c .service .cluster .Notify (c .connect , true )
298345 }
346+ return true
299347}
300348
301349// Close terminates the connection.
@@ -308,15 +356,18 @@ func (c *Conn) Close() error {
308356 // Unsubscribe from everything, no need to lock since each Unsubscribe is
309357 // already locked. Locking the 'Close()' would result in a deadlock.
310358 for _ , counter := range c .subs .All () {
311- c .service .Unsubscribe (c , & event.Subscription {
359+ c .service .pubsub . Unsubscribe (c , & event.Subscription {
312360 Peer : c .service .ID (),
313361 Conn : c .luid ,
314- User : nocopy .String (c .username ),
362+ User : nocopy .String (c .Username () ),
315363 Ssid : counter .Ssid ,
316364 Channel : counter .Channel ,
317365 })
318366 }
319367
368+ // Publish last will
369+ c .service .pubsub .OnLastWill (c , c .connect )
370+
320371 //logging.LogTarget("conn", "closed", c.guid)
321372 return c .socket .Close ()
322373}
0 commit comments