Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gateway/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
ginzap "github.com/gin-contrib/zap"
"github.com/gin-gonic/gin"
"github.com/hoophq/hoop/gateway/proxyproto/ssmproxy"
"github.com/hoophq/hoop/gateway/rdp"
"go.uber.org/zap"

"github.com/hoophq/hoop/common/log"
Expand Down Expand Up @@ -139,6 +140,10 @@ func (a *Api) StartAPI(sentryInit bool) {
ssmInstance := ssmproxy.GetServerInstance()
ssmInstance.AttachHandlers(ssmGroup)

ironRdpGroup := route.Group(baseURL + "/rdpproxy")
ironRdpInstance := rdp.GetIronServerInstance()
ironRdpInstance.AttachHandlers(ironRdpGroup)

rg := route.Group(baseURL + "/api")
if sentryInit {
rg.Use(sentrygin.New(sentrygin.Options{
Expand Down
53 changes: 46 additions & 7 deletions gateway/broker/communicators.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
package broker

import (
"io"
"net"
"time"

"github.com/gorilla/websocket"
)

type ConnectionCommunicator interface {
Send(data []byte) error
Read() (int, []byte, error)
Close()
Close() error
WrapToConnection() net.Conn
}

type agentCommunicator struct{ conn *websocket.Conn }

func NewAgentCommunicator(conn *websocket.Conn) *agentCommunicator {
func NewAgentCommunicator(conn *websocket.Conn) ConnectionCommunicator {
return &agentCommunicator{conn: conn}
}

func (a *agentCommunicator) Send(data []byte) error {
return a.conn.WriteMessage(websocket.BinaryMessage, data)
}

func (a *agentCommunicator) Close() {
a.conn.Close()
func (a *agentCommunicator) Close() error {
return a.conn.Close()
}

func (a *agentCommunicator) Read() (int, []byte, error) {
Expand All @@ -34,9 +37,13 @@ func (a *agentCommunicator) Read() (int, []byte, error) {
return len(message), message, nil
}

func (a *agentCommunicator) WrapToConnection() net.Conn {
return &WSConnWrap{a.conn}
}

type clientCommunicator struct{ conn net.Conn }

func NewClientCommunicator(conn net.Conn) *clientCommunicator {
func NewClientCommunicator(conn net.Conn) ConnectionCommunicator {
return &clientCommunicator{conn: conn}
}

Expand All @@ -54,6 +61,38 @@ func (c *clientCommunicator) Send(data []byte) error {
return err
}

func (c *clientCommunicator) Close() {
c.conn.Close()
func (c *clientCommunicator) Close() error {
return c.conn.Close()
}

func (c *clientCommunicator) WrapToConnection() net.Conn {
return c.conn
}

// WSConnWrap wraps a websocket.Conn to implement the net.Conn interface.
type WSConnWrap struct {
*websocket.Conn
}

func (w *WSConnWrap) Read(b []byte) (n int, err error) {
_, message, err := w.ReadMessage()
if err != nil {
return 0, err
}
if len(message) > len(b) {
return 0, io.ErrShortBuffer
}
return copy(b, message), nil
}

func (w *WSConnWrap) Write(b []byte) (n int, err error) {
err = w.WriteMessage(websocket.BinaryMessage, b)
return len(b), err
}

func (w *WSConnWrap) SetDeadline(t time.Time) error {
if err := w.Conn.SetWriteDeadline(t); err != nil {
return err
}
return w.Conn.SetReadDeadline(t)
}
148 changes: 128 additions & 20 deletions gateway/broker/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package broker
import (
"context"
"io"
"net"
"sync"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
Expand Down Expand Up @@ -74,19 +76,19 @@ func (s *Session) Close() {

// Close consumer connection
if s.ClientCommunicator != nil {
s.ClientCommunicator.Close()
_ = s.ClientCommunicator.Close()
}

// Close agent connection
if s.AgentCommunicator != nil {
s.AgentCommunicator.Close()
_ = s.AgentCommunicator.Close()
}

// Remove from sessions map
// Remove from the sessions map
BrokerInstance.sessions.Delete(s.ID)
}

// forward data from agent to tcp
// ForwardToTCP forward data from agent to tcp
func (s *Session) ForwardToTCP(data []byte) {
s.mu.Lock()
if s.closed || s.dataChannel == nil {
Expand All @@ -96,7 +98,7 @@ func (s *Session) ForwardToTCP(data []byte) {
s.mu.Unlock()

select {
// the data is create with a buffer size buffer size of 1024
// the data is created with a buffer size of 1024
// Up to 1024 messages can be queued without blocking
//If the buffer is full, new data is dropped rather than blocking
case s.dataChannel <- data:
Expand All @@ -107,19 +109,25 @@ func (s *Session) ForwardToTCP(data []byte) {
}
}

// this will spam data from tcp to agent wsconn
func (s *Session) ForwardToAgent(data []byte) error {
// Send first RDP packet using simple header format (not WebSocketMessage)
func (s *Session) SendRawDataToAgent(data []byte) error {
header := &Header{
SID: s.ID,
Len: uint32(len(data)),
}

framedData := append(header.Encode(), data...)

if err := s.SendToAgent(framedData); err != nil {
log.Infof("Failed to send first RDP packet: %v", err)
return err
return s.SendToAgent(framedData)
}

// ForwardToAgent this will spam data from tcp to agent wsconn
func (s *Session) ForwardToAgent(data []byte) error {
if data != nil {
// Send first RDP packet using simple header format (not WebSocketMessage)
if err := s.SendRawDataToAgent(data); err != nil {
log.Infof("Failed to send first RDP packet: %v", err)
return err
}
}

// sending first packet done
Expand All @@ -134,14 +142,7 @@ func (s *Session) ForwardToAgent(data []byte) error {
}

if n > 0 {

header := &Header{
SID: s.ID,
Len: uint32(n),
}
framedData := append(header.Encode(), buffer[:n]...)

if err := s.SendToAgent(framedData); err != nil {
if err = s.SendRawDataToAgent(buffer[:n]); err != nil {
log.Infof("Failed to send RDP data to agent: %v", err)
break
}
Expand All @@ -150,7 +151,7 @@ func (s *Session) ForwardToAgent(data []byte) error {
return nil
}

// this will forward data from agent to tcp
// ForwardToClient this will forward data from agent to tcp
func (s *Session) ForwardToClient() {
for data := range s.dataChannel {

Expand All @@ -162,6 +163,18 @@ func (s *Session) ForwardToClient() {
}
}

// GetTCPDataChannel returns the channel that will be used to send data to the TCP connection
// Warn: do not use this when calling ForwardToClient()
func (s *Session) GetTCPDataChannel() chan []byte {
return s.dataChannel
}

// ToConn returns a net.Conn that can be used to read and write as a normal go connection
// Warn: do not use this when calling ForwardToClient()
func (s *Session) ToConn() net.Conn {
return &sessionConnWrapper{session: s}
}

func CreateAgent(agentID string, ws *websocket.Conn) error {
BrokerInstance.agents.Store(agentID, NewAgentCommunicator(ws))
return nil
Expand Down Expand Up @@ -197,3 +210,98 @@ func GetSessions() map[uuid.UUID]*Session {
})
return sessions
}

var _ net.Conn = (*sessionConnWrapper)(nil)

// sessionConnWrapper makes Session look like a normal net.Conn
type sessionConnWrapper struct {
session *Session
deadline *time.Time
buffer [16384]byte
bufferPos int // current position in buffer
bufferLen int // amount of valid data in a buffer
}

func (s *sessionConnWrapper) Read(b []byte) (n int, err error) {
ctx := context.Background()
cancel := func() {}
if s.deadline != nil {
ctx, cancel = context.WithDeadline(ctx, *s.deadline)
}
defer cancel()
defer func() {
s.deadline = nil
}()
Comment on lines +231 to +234
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a single defer statement:

defer func() {
        cancel()
		s.deadline = nil
}()


c := s.session.GetTCPDataChannel()

// First, serve any buffered data
if s.bufferLen > 0 {
n := copy(b, s.buffer[s.bufferPos:s.bufferPos+s.bufferLen])
s.bufferPos += n
s.bufferLen -= n
if s.bufferLen == 0 {
s.bufferPos = 0
}
return n, nil
}

// Wait for data from a channel or context done
select {
case data := <-c:
if data == nil {
// Channel closed
return 0, io.EOF
}

// Copy as much as we can into the provided buffer
remaining := len(b)
if len(data) > remaining {
// Buffer the excess data in the internal buffer
n := copy(b, data[:remaining])

// Store the rest in the internal buffer
s.bufferLen = copy(s.buffer[:], data[remaining:])
s.bufferPos = 0
return n, nil
}

// Data fits entirely in the provided buffer
n := copy(b, data)
return n, nil

case <-ctx.Done():
return 0, ctx.Err()
}
}

func (s *sessionConnWrapper) Write(b []byte) (n int, err error) {
err = s.session.SendRawDataToAgent(b)
return len(b), err
}

func (s *sessionConnWrapper) Close() error {
s.session.Close()
return nil
}

func (s *sessionConnWrapper) LocalAddr() net.Addr {
return nil
}

func (s *sessionConnWrapper) RemoteAddr() net.Addr {
return nil
}

func (s *sessionConnWrapper) SetDeadline(t time.Time) error {
s.deadline = &t
return nil
}

func (s *sessionConnWrapper) SetReadDeadline(t time.Time) error {
return s.SetDeadline(t)
}

func (s *sessionConnWrapper) SetWriteDeadline(t time.Time) error {
return s.SetDeadline(t)
}
Loading