Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gateway/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
ginzap "github.com/gin-contrib/zap"
"github.com/gin-gonic/gin"
"github.com/hoophq/hoop/gateway/proxyproto/ssmproxy"
"github.com/hoophq/hoop/gateway/rdp"
"go.uber.org/zap"

"github.com/hoophq/hoop/common/log"
Expand Down Expand Up @@ -139,6 +140,10 @@ func (a *Api) StartAPI(sentryInit bool) {
ssmInstance := ssmproxy.GetServerInstance()
ssmInstance.AttachHandlers(ssmGroup)

ironRdpGroup := route.Group(baseURL + "/rdpproxy")
ironRdpInstance := rdp.GetIronServerInstance()
ironRdpInstance.AttachHandlers(ironRdpGroup)

rg := route.Group(baseURL + "/api")
if sentryInit {
rg.Use(sentrygin.New(sentrygin.Options{
Expand Down
53 changes: 46 additions & 7 deletions gateway/broker/communicators.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
package broker

import (
"io"
"net"
"time"

"github.com/gorilla/websocket"
)

type ConnectionCommunicator interface {
Send(data []byte) error
Read() (int, []byte, error)
Close()
Close() error
WrapToConnection() net.Conn
}

type agentCommunicator struct{ conn *websocket.Conn }

func NewAgentCommunicator(conn *websocket.Conn) *agentCommunicator {
func NewAgentCommunicator(conn *websocket.Conn) ConnectionCommunicator {
return &agentCommunicator{conn: conn}
}

func (a *agentCommunicator) Send(data []byte) error {
return a.conn.WriteMessage(websocket.BinaryMessage, data)
}

func (a *agentCommunicator) Close() {
a.conn.Close()
func (a *agentCommunicator) Close() error {
return a.conn.Close()
}

func (a *agentCommunicator) Read() (int, []byte, error) {
Expand All @@ -34,9 +37,13 @@ func (a *agentCommunicator) Read() (int, []byte, error) {
return len(message), message, nil
}

func (a *agentCommunicator) WrapToConnection() net.Conn {
return &WSConnWrap{a.conn}
}

type clientCommunicator struct{ conn net.Conn }

func NewClientCommunicator(conn net.Conn) *clientCommunicator {
func NewClientCommunicator(conn net.Conn) ConnectionCommunicator {
return &clientCommunicator{conn: conn}
}

Expand All @@ -54,6 +61,38 @@ func (c *clientCommunicator) Send(data []byte) error {
return err
}

func (c *clientCommunicator) Close() {
c.conn.Close()
func (c *clientCommunicator) Close() error {
return c.conn.Close()
}

func (c *clientCommunicator) WrapToConnection() net.Conn {
return c.conn
}

// WSConnWrap wraps a websocket.Conn to implement the net.Conn interface.
type WSConnWrap struct {
*websocket.Conn
}

func (w *WSConnWrap) Read(b []byte) (n int, err error) {
_, message, err := w.ReadMessage()
if err != nil {
return 0, err
}
if len(message) > len(b) {
return 0, io.ErrShortBuffer
}
return copy(b, message), nil
}

func (w *WSConnWrap) Write(b []byte) (n int, err error) {
err = w.WriteMessage(websocket.BinaryMessage, b)
return len(b), err
}

func (w *WSConnWrap) SetDeadline(t time.Time) error {
if err := w.Conn.SetWriteDeadline(t); err != nil {
return err
}
return w.Conn.SetReadDeadline(t)
}
148 changes: 128 additions & 20 deletions gateway/broker/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package broker
import (
"context"
"io"
"net"
"sync"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
Expand Down Expand Up @@ -74,19 +76,19 @@ func (s *Session) Close() {

// Close consumer connection
if s.ClientCommunicator != nil {
s.ClientCommunicator.Close()
_ = s.ClientCommunicator.Close()
}

// Close agent connection
if s.AgentCommunicator != nil {
s.AgentCommunicator.Close()
_ = s.AgentCommunicator.Close()
}

// Remove from sessions map
// Remove from the sessions map
BrokerInstance.sessions.Delete(s.ID)
}

// forward data from agent to tcp
// ForwardToTCP forward data from agent to tcp
func (s *Session) ForwardToTCP(data []byte) {
s.mu.Lock()
if s.closed || s.dataChannel == nil {
Expand All @@ -96,7 +98,7 @@ func (s *Session) ForwardToTCP(data []byte) {
s.mu.Unlock()

select {
// the data is create with a buffer size buffer size of 1024
// the data is created with a buffer size of 1024
// Up to 1024 messages can be queued without blocking
//If the buffer is full, new data is dropped rather than blocking
case s.dataChannel <- data:
Expand All @@ -107,19 +109,25 @@ func (s *Session) ForwardToTCP(data []byte) {
}
}

// this will spam data from tcp to agent wsconn
func (s *Session) ForwardToAgent(data []byte) error {
// Send first RDP packet using simple header format (not WebSocketMessage)
func (s *Session) SendRawDataToAgent(data []byte) error {
header := &Header{
SID: s.ID,
Len: uint32(len(data)),
}

framedData := append(header.Encode(), data...)

if err := s.SendToAgent(framedData); err != nil {
log.Infof("Failed to send first RDP packet: %v", err)
return err
return s.SendToAgent(framedData)
}

// ForwardToAgent this will spam data from tcp to agent wsconn
func (s *Session) ForwardToAgent(data []byte) error {
if data != nil {
// Send first RDP packet using simple header format (not WebSocketMessage)
if err := s.SendRawDataToAgent(data); err != nil {
log.Infof("Failed to send first RDP packet: %v", err)
return err
}
}

// sending first packet done
Expand All @@ -134,14 +142,7 @@ func (s *Session) ForwardToAgent(data []byte) error {
}

if n > 0 {

header := &Header{
SID: s.ID,
Len: uint32(n),
}
framedData := append(header.Encode(), buffer[:n]...)

if err := s.SendToAgent(framedData); err != nil {
if err = s.SendRawDataToAgent(buffer[:n]); err != nil {
log.Infof("Failed to send RDP data to agent: %v", err)
break
}
Expand All @@ -150,7 +151,7 @@ func (s *Session) ForwardToAgent(data []byte) error {
return nil
}

// this will forward data from agent to tcp
// ForwardToClient this will forward data from agent to tcp
func (s *Session) ForwardToClient() {
for data := range s.dataChannel {

Expand All @@ -162,6 +163,18 @@ func (s *Session) ForwardToClient() {
}
}

// GetTCPDataChannel returns the channel that will be used to send data to the TCP connection
// Warn: do not use this when calling ForwardToClient()
func (s *Session) GetTCPDataChannel() chan []byte {
return s.dataChannel
}

// ToConn returns a net.Conn that can be used to read and write as a normal go connection
// Warn: do not use this when calling ForwardToClient()
func (s *Session) ToConn() net.Conn {
return &sessionConnWrapper{session: s}
}

func CreateAgent(agentID string, ws *websocket.Conn) error {
BrokerInstance.agents.Store(agentID, NewAgentCommunicator(ws))
return nil
Expand Down Expand Up @@ -197,3 +210,98 @@ func GetSessions() map[uuid.UUID]*Session {
})
return sessions
}

var _ net.Conn = (*sessionConnWrapper)(nil)

// sessionConnWrapper makes Session look like a normal net.Conn
type sessionConnWrapper struct {
session *Session
deadline *time.Time
buffer [16384]byte
bufferPos int // current position in buffer
bufferLen int // amount of valid data in a buffer
}

func (s *sessionConnWrapper) Read(b []byte) (n int, err error) {
ctx := context.Background()
cancel := func() {}
if s.deadline != nil {
ctx, cancel = context.WithDeadline(ctx, *s.deadline)
}
defer cancel()
defer func() {
s.deadline = nil
}()
Comment on lines +231 to +234
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a single defer statement:

defer func() {
        cancel()
		s.deadline = nil
}()


c := s.session.GetTCPDataChannel()

// First, serve any buffered data
if s.bufferLen > 0 {
n := copy(b, s.buffer[s.bufferPos:s.bufferPos+s.bufferLen])
s.bufferPos += n
s.bufferLen -= n
if s.bufferLen == 0 {
s.bufferPos = 0
}
return n, nil
}

// Wait for data from a channel or context done
select {
case data := <-c:
if data == nil {
// Channel closed
return 0, io.EOF
}

// Copy as much as we can into the provided buffer
remaining := len(b)
if len(data) > remaining {
// Buffer the excess data in the internal buffer
n := copy(b, data[:remaining])

// Store the rest in the internal buffer
s.bufferLen = copy(s.buffer[:], data[remaining:])
s.bufferPos = 0
return n, nil
}

// Data fits entirely in the provided buffer
n := copy(b, data)
return n, nil

case <-ctx.Done():
return 0, ctx.Err()
}
}

func (s *sessionConnWrapper) Write(b []byte) (n int, err error) {
err = s.session.SendRawDataToAgent(b)
return len(b), err
}

func (s *sessionConnWrapper) Close() error {
s.session.Close()
return nil
}

func (s *sessionConnWrapper) LocalAddr() net.Addr {
return nil
}

func (s *sessionConnWrapper) RemoteAddr() net.Addr {
return nil
}

func (s *sessionConnWrapper) SetDeadline(t time.Time) error {
s.deadline = &t
return nil
}

func (s *sessionConnWrapper) SetReadDeadline(t time.Time) error {
return s.SetDeadline(t)
}

func (s *sessionConnWrapper) SetWriteDeadline(t time.Time) error {
return s.SetDeadline(t)
}
Loading
Loading