Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,16 @@ func (s *Session) addConnection(connID int64, conn *connection) {
defer s.Unlock()

s.conns[connID] = conn

// only increment and only on client
if s.clientKey == "client" && s.nextConnID < connID {
s.nextConnID = connID
}

if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}

}

// removeConnection safely removes a connection by ID, returning the connection object
Expand All @@ -111,6 +118,7 @@ func (s *Session) removeConnection(connID int64) *connection {
// The session lock must be held by the caller when calling this method
func (s *Session) removeConnectionLocked(connID int64) *connection {
conn := s.conns[connID]

delete(s.conns, connID)
return conn
}
Expand All @@ -124,7 +132,9 @@ func (s *Session) getConnection(connID int64) *connection {
}

// activeConnectionIDs returns an ordered list of IDs for the currently active connections
func (s *Session) activeConnectionIDs() []int64 {
// it also returns s.nextConnID (a hack to have both the list of conns and nextConnID read in the same lock)
// TODO: seperate latter functionality
func (s *Session) activeConnectionIDs() ([]int64, int64) {
s.RLock()
defer s.RUnlock()

Expand All @@ -133,7 +143,7 @@ func (s *Session) activeConnectionIDs() []int64 {
res = append(res, id)
}
sort.Slice(res, func(i, j int) bool { return res[i] < res[j] })
return res
return res, s.nextConnID
}

// addSessionKey registers a new session key for a given client key
Expand Down Expand Up @@ -190,6 +200,7 @@ func (s *Session) startPings(rootCtx context.Context) {
if err := s.sendSyncConnections(); err != nil {
logrus.WithError(err).Error("Error syncing connections")
}

case <-t.C:
if err := s.sendPing(); err != nil {
logrus.WithError(err).Error("Error writing ping")
Expand Down
4 changes: 2 additions & 2 deletions session_serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ func (s *Session) syncConnections(r io.Reader) error {
if err != nil {
return fmt.Errorf("reading message body: %w", err)
}
clientActiveConnections, err := decodeConnectionIDs(payload)
clientActiveConnections, top, err := decodeConnectionIDs(payload)
if err != nil {
return fmt.Errorf("decoding sync connections payload: %w", err)
}

s.compareAndCloseStaleConnections(clientActiveConnections)
s.compareAndCloseStaleConnections(clientActiveConnections, top)
return nil
}

Expand Down
43 changes: 29 additions & 14 deletions session_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,70 @@ import (

var errCloseSyncConnections = errors.New("sync from client")

// encodeConnectionIDs serializes a slice of connection IDs
func encodeConnectionIDs(ids []int64) []byte {
payload := make([]byte, 0, 8*len(ids))
// encodeConnectionIDs serializes a slice of connection IDs and the topmost connection ID seen
func encodeConnectionIDs(top int64, ids []int64) []byte {
payload := make([]byte, 0, 8*(len(ids)+1))

// send top to denote the latest ID this packet was send with knowledge of
payload = binary.LittleEndian.AppendUint64(payload, uint64(top))

for _, id := range ids {
payload = binary.LittleEndian.AppendUint64(payload, uint64(id))
}
return payload
}

// decodeConnectionIDs deserializes a slice of connection IDs
func decodeConnectionIDs(payload []byte) ([]int64, error) {
// decodeConnectionIDs deserializes a slice of connection IDs along with the highest seen
func decodeConnectionIDs(payload []byte) ([]int64, int64, error) {
if len(payload)%8 != 0 {
return nil, fmt.Errorf("incorrect data format")
return nil, 0, fmt.Errorf("incorrect data format")
}
result := make([]int64, 0, len(payload)/8)
for x := 0; x < len(payload); x += 8 {
top := int64(binary.LittleEndian.Uint64(payload[0 : 0+8]))

result := make([]int64, 0, (len(payload)/8)-1)
for x := 8; x < len(payload); x += 8 {
id := binary.LittleEndian.Uint64(payload[x : x+8])
result = append(result, int64(id))
}
return result, nil
return result, top, nil
}

func newSyncConnectionsMessage(connectionIDs []int64) *message {
func newSyncConnectionsMessage(top int64, connectionIDs []int64) *message {
return &message{
id: nextid(),
messageType: SyncConnections,
bytes: encodeConnectionIDs(connectionIDs),
bytes: encodeConnectionIDs(top, connectionIDs),
}
}

// sendSyncConnections sends a binary message of type SyncConnections, whose payload is a list of the active connection IDs for this session
func (s *Session) sendSyncConnections() error {
_, err := s.writeMessage(time.Now().Add(SyncConnectionsTimeout), newSyncConnectionsMessage(s.activeConnectionIDs()))
act, top := s.activeConnectionIDs()

_, err := s.writeMessage(time.Now().Add(SyncConnectionsTimeout), newSyncConnectionsMessage(top, act))
return err
}

// compareAndCloseStaleConnections compares the Session's activeConnectionIDs with the provided list from the client, then closing every connection not present in it
func (s *Session) compareAndCloseStaleConnections(clientIDs []int64) {
serverIDs := s.activeConnectionIDs()
func (s *Session) compareAndCloseStaleConnections(clientIDs []int64, top int64) {
serverIDs, _ := s.activeConnectionIDs()
toClose := diffSortedSetsGetRemoved(serverIDs, clientIDs)
if len(toClose) == 0 {
return
}

s.Lock()
defer s.Unlock()

for _, id := range toClose {
// dont close connection if packet contains id not ever seen by client
if id > top {
break // not continue as toClose is sorted
}

// Connection no longer active in the client, close it server-side
conn := s.removeConnectionLocked(id)

if conn != nil {
// Using doTunnelClose directly instead of tunnelClose, omitting unnecessarily sending an Error message
conn.doTunnelClose(errCloseSyncConnections)
Expand Down
28 changes: 20 additions & 8 deletions session_sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ func Test_encodeConnectionIDs(t *testing.T) {
tt := tests[x]
t.Run(fmt.Sprintf("%d_ids", tt.size), func(t *testing.T) {
t.Parallel()
ids := generateIDs(tt.size)
encoded := encodeConnectionIDs(ids)
decoded, err := decodeConnectionIDs(encoded)
ids, top := generateIDs(tt.size)
encoded := encodeConnectionIDs(top, ids)
decoded, decodedTop, err := decodeConnectionIDs(encoded)
if err != nil {
t.Error(err)
}
if got, want := decodedTop, top; !reflect.DeepEqual(got, want) {
t.Errorf("encoding and decoding differs from original data, got: %v, want: %v", got, want)
}

if got, want := decoded, ids; !reflect.DeepEqual(got, want) {
t.Errorf("encoding and decoding differs from original data, got: %v, want: %v", got, want)
}
Expand Down Expand Up @@ -93,7 +97,7 @@ func TestSession_sendSyncConnections(t *testing.T) {
session := newSession(rand.Int63(), "sync-test", newWSConn(conn))

for _, n := range []int{0, 5, 20} {
ids := generateIDs(n)
ids, _ := generateIDs(n)
for _, id := range ids {
session.conns[id] = nil
}
Expand All @@ -114,20 +118,28 @@ func TestSession_sendSyncConnections(t *testing.T) {
if got, want := message.messageType, SyncConnections; got != want {
t.Errorf("incorrect message type, got: %v, want: %v", got, want)
}
if decoded, err := decodeConnectionIDs(payload); err != nil {

decoded, decodedTop, err := decodeConnectionIDs(payload)
if err != nil {
t.Fatal(err)
} else if got, want := decoded, session.activeConnectionIDs(); !reflect.DeepEqual(got, want) {
return
}
returnedIDs, returnedTop := session.activeConnectionIDs()
if got, want := decodedTop, returnedTop; !reflect.DeepEqual(got, want) {
t.Errorf("incorrect connections IDs, got: %v, want: %v", got, want)
}
if got, want := decoded, returnedIDs; !reflect.DeepEqual(got, want) {
t.Errorf("incorrect connections IDs, got: %v, want: %v", got, want)
}
}
}

func generateIDs(n int) []int64 {
func generateIDs(n int) ([]int64, int64) {
ids := make([]int64, n)
for x := range ids {
ids[x] = rand.Int63()
}
return ids
return ids, rand.Int63()
}

func testServerWS(t *testing.T, data chan<- []byte) *websocket.Conn {
Expand Down
4 changes: 3 additions & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ func TestSession_activeConnectionIDs(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
session := Session{conns: tt.conns}
if got, want := session.activeConnectionIDs(), tt.expected; !reflect.DeepEqual(got, want) {

returnedIDs, _ := session.activeConnectionIDs()
if got, want := returnedIDs, tt.expected; !reflect.DeepEqual(got, want) {
t.Errorf("incorrect result, got: %v, want: %v", got, want)
}
})
Expand Down