Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion constants/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (
CLIENT_CUSTOM_DNS_HELP = "Define a custom DNS server that the mmar client should use when accessing your local dev server. (eg: 8.8.8.8:53, defaults to DNS in OS)"
CLIENT_CUSTOM_CERT_HELP = "Define path to file custom TLS certificate containing complete ASN.1 DER content (certificate, signature algorithm and signature). Currently used for testing, but may be used to allow mmar client to work with a dev server using custom TLS certificate setups. (eg: /path/to/cert)"

TUNNEL_MESSAGE_PROTOCOL_VERSION = 3
TUNNEL_MESSAGE_PROTOCOL_VERSION = 4
TUNNEL_MESSAGE_DATA_DELIMITER = '\n'
ID_CHARSET = "abcdefghijklmnopqrstuvwxyz0123456789"
ID_LENGTH = 6
Expand All @@ -49,6 +49,7 @@ const (
HEARTBEAT_FROM_CLIENT_TIMEOUT = 2
READ_DEADLINE = 3
MAX_REQ_BODY_SIZE = 10000000 // 10mb
REQUEST_ID_BUFF_SIZE = 4

CLIENT_DISCONNECT_ERR_TEXT = "Tunnel is closed, cannot connect to mmar client."
LOCALHOST_NOT_RUNNING_ERR_TEXT = "Tunneled successfully, but nothing is running on localhost."
Expand Down
58 changes: 36 additions & 22 deletions internal/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
if certErr != nil {
logger.Log(constants.YELLOW, "Warning: Could not load custom certificate")
} else {
fmt.Println("adding cert dawg..")
fwdClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{
RootCAs: x509.NewCertPool(),
}
Expand All @@ -107,8 +106,20 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
}

reqReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData))
req, reqErr := http.ReadRequest(reqReader)

// Extract RequestId
reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE)
_, err := io.ReadFull(reqReader, reqIdBuff)
if err != nil {
logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse RequestId for request: %v\n", err))
return
}

// Include RequestId in tunnel back message
msgData := []byte{}
msgData = append(msgData, reqIdBuff...)

req, reqErr := http.ReadRequest(reqReader)
if reqErr != nil {
if errors.Is(reqErr, io.EOF) {
logger.Log(constants.DEFAULT_COLOR, "Connection to mmar server closed or disconnected. Exiting...")
Expand All @@ -128,20 +139,20 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
resp, fwdErr := fwdClient.Do(req)
if fwdErr != nil {
if errors.Is(fwdErr, syscall.ECONNREFUSED) || errors.Is(fwdErr, io.ErrUnexpectedEOF) || errors.Is(fwdErr, io.EOF) {
localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING}
localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING, MsgData: msgData}
if err := mc.SendMessage(localhostNotRunningMsg); err != nil {
log.Fatal(err)
}
return
} else if errors.Is(fwdErr, context.DeadlineExceeded) {
destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT}
destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT, MsgData: msgData}
if err := mc.SendMessage(destServerTimedoutMsg); err != nil {
log.Fatal(err)
}
return
}

invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.INVALID_RESP_FROM_DEST}
invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.INVALID_RESP_FROM_DEST, MsgData: msgData}
if err := mc.SendMessage(invalidRespFromDestMsg); err != nil {
log.Fatal(err)
}
Expand All @@ -151,8 +162,8 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
// Writing response to buffer to tunnel it back
var responseBuff bytes.Buffer
resp.Write(&responseBuff)

respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()}
msgData = append(msgData, responseBuff.Bytes()...)
respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: msgData}
if err := mc.SendMessage(respMessage); err != nil {
log.Fatal(err)
}
Expand All @@ -178,6 +189,15 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) {
continue
}
mc.Tunnel.Conn = conn
mc.Tunnel.Reader = bufio.NewReader(conn)

// Try to reclaim the same subdomain
reclaimTunnelMsg := protocol.TunnelMessage{MsgType: protocol.RECLAIM_TUNNEL, MsgData: []byte(mc.subdomain)}
if err := mc.SendMessage(reclaimTunnelMsg); err != nil {
logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...")
os.Exit(0)
}

break
}
}
Expand Down Expand Up @@ -228,21 +248,9 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) {
}

switch tunnelMsg.MsgType {
case protocol.CLIENT_CONNECT:
case protocol.TUNNEL_CREATED, protocol.TUNNEL_RECLAIMED:
tunnelSubdomain := string(tunnelMsg.MsgData)
// If there is an existing subdomain, that means we are reconnecting with an
// existing mmar client, try to reclaim the same subdomain
if mc.subdomain != "" {
reconnectMsg := protocol.TunnelMessage{MsgType: protocol.CLIENT_RECLAIM_SUBDOMAIN, MsgData: []byte(tunnelSubdomain + ":" + mc.subdomain)}
mc.subdomain = ""
if err := mc.SendMessage(reconnectMsg); err != nil {
logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...")
os.Exit(0)
}
continue
} else {
mc.subdomain = tunnelSubdomain
}
mc.subdomain = tunnelSubdomain
logger.LogTunnelCreated(tunnelSubdomain, mc.TunnelHost, mc.TunnelHttpPort, mc.LocalPort)
case protocol.CLIENT_TUNNEL_LIMIT:
limit := logger.ColorLogStr(
Expand Down Expand Up @@ -298,7 +306,7 @@ func Run(config ConfigOptions) {
}
defer conn.Close()
mmarClient := MmarClient{
protocol.Tunnel{Conn: conn},
protocol.Tunnel{Conn: conn, Reader: bufio.NewReader(conn)},
config,
"",
}
Expand All @@ -309,6 +317,12 @@ func Run(config ConfigOptions) {
// Process Tunnel Messages coming from mmar server
go mmarClient.ProcessTunnelMessages(ctx)

createTunnelMsg := protocol.TunnelMessage{MsgType: protocol.CREATE_TUNNEL}
if err := mmarClient.SendMessage(createTunnelMsg); err != nil {
logger.Log(constants.DEFAULT_COLOR, "Failed to create Tunnel. Exiting...")
os.Exit(0)
}

// Wait for an interrupt signal, if received, terminate gracefully
<-sigInt

Expand Down
15 changes: 10 additions & 5 deletions internal/protocol/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ import (
const (
REQUEST = uint8(iota + 1)
RESPONSE
CLIENT_CONNECT
CLIENT_RECLAIM_SUBDOMAIN
CREATE_TUNNEL
RECLAIM_TUNNEL
TUNNEL_CREATED
TUNNEL_RECLAIMED
CLIENT_DISCONNECT
CLIENT_TUNNEL_LIMIT
LOCALHOST_NOT_RUNNING
Expand Down Expand Up @@ -76,6 +78,7 @@ type Tunnel struct {
Id string
Conn net.Conn
CreatedOn time.Time
Reader *bufio.Reader
}

type TunnelInterface interface {
Expand Down Expand Up @@ -178,6 +181,10 @@ func (tm *TunnelMessage) deserializeMessage(reader *bufio.Reader) error {
return nil
}

func (t *Tunnel) ReservedSubdomain() bool {
return t.Id != ""
}

func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error {
// Serialize tunnel message data
serializedMsg, serializeErr := tunnelMsg.serializeMessage()
Expand All @@ -189,11 +196,9 @@ func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error {
}

func (t *Tunnel) ReceiveMessage() (TunnelMessage, error) {
msgReader := bufio.NewReader(t.Conn)

// Read and deserialize tunnel message data
tunnelMessage := TunnelMessage{}
deserializeErr := tunnelMessage.deserializeMessage(msgReader)
deserializeErr := tunnelMessage.deserializeMessage(t.Reader)

return tunnelMessage, deserializeErr
}
Loading
Loading