diff --git a/constants/main.go b/constants/main.go index 41e2c24..51ea7c5 100644 --- a/constants/main.go +++ b/constants/main.go @@ -34,7 +34,7 @@ const ( CLIENT_CUSTOM_DNS_HELP = "Define a custom DNS server that the mmar client should use when accessing your local dev server. (eg: 8.8.8.8:53, defaults to DNS in OS)" CLIENT_CUSTOM_CERT_HELP = "Define path to file custom TLS certificate containing complete ASN.1 DER content (certificate, signature algorithm and signature). Currently used for testing, but may be used to allow mmar client to work with a dev server using custom TLS certificate setups. (eg: /path/to/cert)" - TUNNEL_MESSAGE_PROTOCOL_VERSION = 3 + TUNNEL_MESSAGE_PROTOCOL_VERSION = 4 TUNNEL_MESSAGE_DATA_DELIMITER = '\n' ID_CHARSET = "abcdefghijklmnopqrstuvwxyz0123456789" ID_LENGTH = 6 @@ -49,6 +49,7 @@ const ( HEARTBEAT_FROM_CLIENT_TIMEOUT = 2 READ_DEADLINE = 3 MAX_REQ_BODY_SIZE = 10000000 // 10mb + REQUEST_ID_BUFF_SIZE = 4 CLIENT_DISCONNECT_ERR_TEXT = "Tunnel is closed, cannot connect to mmar client." LOCALHOST_NOT_RUNNING_ERR_TEXT = "Tunneled successfully, but nothing is running on localhost." diff --git a/internal/client/main.go b/internal/client/main.go index bf4fe25..2eb8a92 100644 --- a/internal/client/main.go +++ b/internal/client/main.go @@ -98,7 +98,6 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { if certErr != nil { logger.Log(constants.YELLOW, "Warning: Could not load custom certificate") } else { - fmt.Println("adding cert dawg..") fwdClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{ RootCAs: x509.NewCertPool(), } @@ -107,8 +106,20 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { } reqReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData)) - req, reqErr := http.ReadRequest(reqReader) + // Extract RequestId + reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE) + _, err := io.ReadFull(reqReader, reqIdBuff) + if err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse RequestId for request: %v\n", err)) + return + } + + // Include RequestId in tunnel back message + msgData := []byte{} + msgData = append(msgData, reqIdBuff...) + + req, reqErr := http.ReadRequest(reqReader) if reqErr != nil { if errors.Is(reqErr, io.EOF) { logger.Log(constants.DEFAULT_COLOR, "Connection to mmar server closed or disconnected. Exiting...") @@ -128,20 +139,20 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { resp, fwdErr := fwdClient.Do(req) if fwdErr != nil { if errors.Is(fwdErr, syscall.ECONNREFUSED) || errors.Is(fwdErr, io.ErrUnexpectedEOF) || errors.Is(fwdErr, io.EOF) { - localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING} + localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING, MsgData: msgData} if err := mc.SendMessage(localhostNotRunningMsg); err != nil { log.Fatal(err) } return } else if errors.Is(fwdErr, context.DeadlineExceeded) { - destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT} + destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT, MsgData: msgData} if err := mc.SendMessage(destServerTimedoutMsg); err != nil { log.Fatal(err) } return } - invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.INVALID_RESP_FROM_DEST} + invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.INVALID_RESP_FROM_DEST, MsgData: msgData} if err := mc.SendMessage(invalidRespFromDestMsg); err != nil { log.Fatal(err) } @@ -151,8 +162,8 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { // Writing response to buffer to tunnel it back var responseBuff bytes.Buffer resp.Write(&responseBuff) - - respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()} + msgData = append(msgData, responseBuff.Bytes()...) + respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: msgData} if err := mc.SendMessage(respMessage); err != nil { log.Fatal(err) } @@ -178,6 +189,15 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) { continue } mc.Tunnel.Conn = conn + mc.Tunnel.Reader = bufio.NewReader(conn) + + // Try to reclaim the same subdomain + reclaimTunnelMsg := protocol.TunnelMessage{MsgType: protocol.RECLAIM_TUNNEL, MsgData: []byte(mc.subdomain)} + if err := mc.SendMessage(reclaimTunnelMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...") + os.Exit(0) + } + break } } @@ -228,21 +248,9 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) { } switch tunnelMsg.MsgType { - case protocol.CLIENT_CONNECT: + case protocol.TUNNEL_CREATED, protocol.TUNNEL_RECLAIMED: tunnelSubdomain := string(tunnelMsg.MsgData) - // If there is an existing subdomain, that means we are reconnecting with an - // existing mmar client, try to reclaim the same subdomain - if mc.subdomain != "" { - reconnectMsg := protocol.TunnelMessage{MsgType: protocol.CLIENT_RECLAIM_SUBDOMAIN, MsgData: []byte(tunnelSubdomain + ":" + mc.subdomain)} - mc.subdomain = "" - if err := mc.SendMessage(reconnectMsg); err != nil { - logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...") - os.Exit(0) - } - continue - } else { - mc.subdomain = tunnelSubdomain - } + mc.subdomain = tunnelSubdomain logger.LogTunnelCreated(tunnelSubdomain, mc.TunnelHost, mc.TunnelHttpPort, mc.LocalPort) case protocol.CLIENT_TUNNEL_LIMIT: limit := logger.ColorLogStr( @@ -298,7 +306,7 @@ func Run(config ConfigOptions) { } defer conn.Close() mmarClient := MmarClient{ - protocol.Tunnel{Conn: conn}, + protocol.Tunnel{Conn: conn, Reader: bufio.NewReader(conn)}, config, "", } @@ -309,6 +317,12 @@ func Run(config ConfigOptions) { // Process Tunnel Messages coming from mmar server go mmarClient.ProcessTunnelMessages(ctx) + createTunnelMsg := protocol.TunnelMessage{MsgType: protocol.CREATE_TUNNEL} + if err := mmarClient.SendMessage(createTunnelMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, "Failed to create Tunnel. Exiting...") + os.Exit(0) + } + // Wait for an interrupt signal, if received, terminate gracefully <-sigInt diff --git a/internal/protocol/main.go b/internal/protocol/main.go index 8e41559..1287734 100644 --- a/internal/protocol/main.go +++ b/internal/protocol/main.go @@ -19,8 +19,10 @@ import ( const ( REQUEST = uint8(iota + 1) RESPONSE - CLIENT_CONNECT - CLIENT_RECLAIM_SUBDOMAIN + CREATE_TUNNEL + RECLAIM_TUNNEL + TUNNEL_CREATED + TUNNEL_RECLAIMED CLIENT_DISCONNECT CLIENT_TUNNEL_LIMIT LOCALHOST_NOT_RUNNING @@ -76,6 +78,7 @@ type Tunnel struct { Id string Conn net.Conn CreatedOn time.Time + Reader *bufio.Reader } type TunnelInterface interface { @@ -178,6 +181,10 @@ func (tm *TunnelMessage) deserializeMessage(reader *bufio.Reader) error { return nil } +func (t *Tunnel) ReservedSubdomain() bool { + return t.Id != "" +} + func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error { // Serialize tunnel message data serializedMsg, serializeErr := tunnelMsg.serializeMessage() @@ -189,11 +196,9 @@ func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error { } func (t *Tunnel) ReceiveMessage() (TunnelMessage, error) { - msgReader := bufio.NewReader(t.Conn) - // Read and deserialize tunnel message data tunnelMessage := TunnelMessage{} - deserializeErr := tunnelMessage.deserializeMessage(msgReader) + deserializeErr := tunnelMessage.deserializeMessage(t.Reader) return tunnelMessage, deserializeErr } diff --git a/internal/server/main.go b/internal/server/main.go index 8731e1c..594ba88 100644 --- a/internal/server/main.go +++ b/internal/server/main.go @@ -4,19 +4,18 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "encoding/json" "errors" "fmt" "html" "io" "log" - "math/rand" "net" "net/http" "os" "os/signal" "slices" - "strings" "sync" "time" @@ -44,7 +43,6 @@ type IncomingRequest struct { responseWriter http.ResponseWriter request *http.Request cancel context.CancelCauseFunc - serializedReq []byte ctx context.Context } @@ -53,11 +51,14 @@ type OutgoingResponse struct { body []byte } +type RequestId uint32 + // Tunnel to Client type ClientTunnel struct { protocol.Tunnel - incomingChannel chan IncomingRequest - outgoingChannel chan protocol.TunnelMessage + incomingChannel chan IncomingRequest + outgoingChannel chan protocol.TunnelMessage + inflightRequests *sync.Map } func (ct *ClientTunnel) drainChannels() { @@ -119,6 +120,16 @@ func (ct *ClientTunnel) close(graceful bool) { ) } +// Generate unique request id for incoming request for client +func (ct *ClientTunnel) GenerateUniqueRequestID() RequestId { + var generatedReqId RequestId + + for _, exists := ct.inflightRequests.Load(generatedReqId); exists || generatedReqId == 0; { + generatedReqId = RequestId(GenerateRandomUint32()) + } + return generatedReqId +} + // Serves simple stats for mmar server behind Basic Authentication func (ms *MmarServer) handleServerStats(w http.ResponseWriter, r *http.Request) { // Check Basic Authentication @@ -193,19 +204,33 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Create response channel to receive response for tunneled request respChannel := make(chan OutgoingResponse) - // Tunnel the request - clientTunnel.incomingChannel <- IncomingRequest{ + // Add request to client's inflight requests + reqId := clientTunnel.GenerateUniqueRequestID() + incomingReq := IncomingRequest{ responseChannel: respChannel, responseWriter: w, request: r, cancel: cancel, - serializedReq: serializedRequest, ctx: ctx, } + clientTunnel.inflightRequests.Store(reqId, incomingReq) + + // Construct Request message data + reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE) + binary.LittleEndian.PutUint32(reqIdBuff, uint32(reqId)) + reqMsgData := append(reqIdBuff, serializedRequest...) + + // Tunnel the request to mmar client + reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: reqMsgData} + if err := clientTunnel.SendMessage(reqMessage); err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err)) + cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR) + } select { case <-ctx.Done(): // Request is canceled or Tunnel is closed if context is canceled handleCancel(context.Cause(ctx), w) + clientTunnel.inflightRequests.Delete(reqId) return case resp, _ := <-respChannel: // Await response for tunneled request // Add header to close the connection @@ -220,20 +245,15 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (ms *MmarServer) GenerateUniqueId() string { - reservedIDs := []string{"", "admin", "stats"} +func (ms *MmarServer) GenerateUniqueSubdomain() string { + reservedSubdomains := []string{"", "admin", "stats"} - generatedId := "" - for _, exists := ms.clients[generatedId]; exists || slices.Contains(reservedIDs, generatedId); { - var randSeed *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) - b := make([]byte, constants.ID_LENGTH) - for i := range b { - b[i] = constants.ID_CHARSET[randSeed.Intn(len(constants.ID_CHARSET))] - } - generatedId = string(b) + generatedSubdomain := "" + for _, exists := ms.clients[generatedSubdomain]; exists || slices.Contains(reservedSubdomains, generatedSubdomain); { + generatedSubdomain = GenerateRandomID() } - return generatedId + return generatedSubdomain } func (ms *MmarServer) TunnelLimitedIP(ip string) bool { @@ -247,31 +267,40 @@ func (ms *MmarServer) TunnelLimitedIP(ip string) bool { return len(tunnels) >= constants.MAX_TUNNELS_PER_IP } -func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) { +func (ms *MmarServer) newClientTunnel(tunnel protocol.Tunnel, subdomain string) (*ClientTunnel, error) { // Acquire lock to create new client tunnel data ms.mu.Lock() - // Generate unique ID for client - uniqueId := ms.GenerateUniqueId() - tunnel := protocol.Tunnel{ - Id: uniqueId, - Conn: conn, - CreatedOn: time.Now(), + var uniqueSubdomain string + var msgType uint8 + if subdomain != "" { + uniqueSubdomain = subdomain + msgType = protocol.TUNNEL_RECLAIMED + } else { + // Generate unique subdomain for client if not passed in + uniqueSubdomain = ms.GenerateUniqueSubdomain() + msgType = protocol.TUNNEL_CREATED } + tunnel.Id = uniqueSubdomain + // Create channels to tunnel requests to and recieve responses from incomingChannel := make(chan IncomingRequest) outgoingChannel := make(chan protocol.TunnelMessage) + // Initialize inflight requests map for client tunnel + var inflightRequests sync.Map + // Create client tunnel clientTunnel := ClientTunnel{ tunnel, incomingChannel, outgoingChannel, + &inflightRequests, } // Check if IP reached max tunnel limit - clientIP := utils.ExtractIP(conn.RemoteAddr().String()) + clientIP := utils.ExtractIP(tunnel.Conn.RemoteAddr().String()) limitedIP := ms.TunnelLimitedIP(clientIP) // If so, send limit message to client and close client tunnel if limitedIP { @@ -286,18 +315,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) { } // Add client tunnel to clients - ms.clients[uniqueId] = clientTunnel + ms.clients[uniqueSubdomain] = clientTunnel // Associate tunnel with client IP - ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueId) + ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueSubdomain) // Release lock once created ms.mu.Unlock() - // Send unique ID to client - connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(uniqueId)} + // Send unique subdomain to client + connMessage := protocol.TunnelMessage{MsgType: msgType, MsgData: []byte(uniqueSubdomain)} if err := clientTunnel.SendMessage(connMessage); err != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", err)) + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique subdomain msg to client: %v", err)) return nil, err } @@ -306,32 +335,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) { func (ms *MmarServer) handleTcpConnection(conn net.Conn) { - clientTunnel, err := ms.newClientTunnel(conn) - - if err != nil { - if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) { - // Close the connection when client max tunnels limit reached - conn.Close() - return - } - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err)) - return + tunnel := protocol.Tunnel{ + Conn: conn, + CreatedOn: time.Now(), + Reader: bufio.NewReader(conn), } - logger.Log( - constants.DEFAULT_COLOR, - fmt.Sprintf( - "[%s] Tunnel created: %s", - clientTunnel.Tunnel.Id, - conn.RemoteAddr().String(), - ), - ) - // Process Tunnel Messages coming from mmar client - go ms.processTunnelMessages(clientTunnel) + go ms.processTunnelMessages(tunnel) +} - // Start goroutine to process tunneled requests - go ms.processTunneledRequestsForClient(clientTunnel) +func (ms *MmarServer) closeTunnel(t *protocol.Tunnel) { + t.Conn.Close() } func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) { @@ -351,190 +366,216 @@ func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) { ct.close(true) } -func (ms *MmarServer) processTunneledRequestsForClient(ct *ClientTunnel) { - for { - // Read requests coming in tunnel channel - incomingReq, ok := <-ct.incomingChannel - if !ok { - // Channel closed, client disconencted, shutdown goroutine - return - } +func (ms *MmarServer) closeClientTunnelOrConn(ct *ClientTunnel, t protocol.Tunnel) { - // Forward the request to mmar client - reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: incomingReq.serializedReq} - if err := ct.SendMessage(reqMessage); err != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err)) - incomingReq.cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR) - continue - } + // If client has not reserved subdomain, just close the tcp connection + if !ct.ReservedSubdomain() { + ms.closeTunnel(&t) + return + } - // Wait for response for this request to come back from outgoing channel - respTunnelMsg, ok := <-ct.outgoingChannel - if !ok { - // Channel closed, client disconencted, shutdown goroutine - return - } + ms.closeClientTunnel(ct) +} - // Read response for forwarded request - respReader := bufio.NewReader(bytes.NewReader(respTunnelMsg.MsgData)) - resp, respErr := http.ReadResponse(respReader, incomingReq.request) +func (ms *MmarServer) handleResponseMessages(ct *ClientTunnel, tunnelMsg protocol.TunnelMessage) { + respReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData)) - if respErr != nil { - if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) { - incomingReq.cancel(CLIENT_DISCONNECTED_ERR) - ms.closeClientTunnel(ct) - return - } - failedReq := fmt.Sprintf("%s - %s%s", incomingReq.request.Method, html.EscapeString(incomingReq.request.URL.Path), incomingReq.request.URL.RawQuery) - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq)) - incomingReq.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR) - continue - } + // Extract RequestId + reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE) + _, err := io.ReadFull(respReader, reqIdBuff) + if err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] - Failed to parse RequestId for response: %v\n", ct.Tunnel.Id, err)) + return + } - respBody, respBodyErr := io.ReadAll(resp.Body) - if respBodyErr != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr)) - incomingReq.cancel(READ_RESP_BODY_ERR) - continue - } + // Get Inflight Request and remove it from inflight requests + reqId := RequestId(binary.LittleEndian.Uint32(reqIdBuff)) + inflight, loaded := ct.inflightRequests.LoadAndDelete(reqId) + if !loaded { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to identify inflight request: %v", ct.Tunnel.Id, reqId)) + return + } - // Set headers for response - for hKey, hVal := range resp.Header { - incomingReq.responseWriter.Header().Set(hKey, hVal[0]) - // Add remaining values for header if more than than one exists - for i := 1; i < len(hVal); i++ { - incomingReq.responseWriter.Header().Add(hKey, hVal[i]) - } + inflightRequest, ok := inflight.(IncomingRequest) + if !ok { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to parse inflight request: %v", ct.Tunnel.Id, reqId)) + return + } + + // Read response for forwarded request + resp, respErr := http.ReadResponse(respReader, inflightRequest.request) + + if respErr != nil { + if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) { + inflightRequest.cancel(CLIENT_DISCONNECTED_ERR) + ms.closeClientTunnel(ct) + return } + failedReq := fmt.Sprintf("%s - %s%s", inflightRequest.request.Method, html.EscapeString(inflightRequest.request.URL.Path), inflightRequest.request.URL.RawQuery) + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq)) + inflightRequest.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR) + return + } + + respBody, respBodyErr := io.ReadAll(resp.Body) + if respBodyErr != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr)) + inflightRequest.cancel(READ_RESP_BODY_ERR) + return + } - // Close response body - resp.Body.Close() + defer resp.Body.Close() - select { - case <-incomingReq.ctx.Done(): - // Request is canceled, on to the next request - continue - case incomingReq.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}: - // Send response data back + // Set headers for response + for hKey, hVal := range resp.Header { + inflightRequest.responseWriter.Header().Set(hKey, hVal[0]) + // Add remaining values for header if more than than one exists + for i := 1; i < len(hVal); i++ { + inflightRequest.responseWriter.Header().Add(hKey, hVal[i]) } } + + select { + case <-inflightRequest.ctx.Done(): + // Request is canceled, do nothing + return + case inflightRequest.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}: + // Send response data back + } } -func (ms *MmarServer) processTunnelMessages(ct *ClientTunnel) { +func (ms *MmarServer) processTunnelMessages(t protocol.Tunnel) { + var ct *ClientTunnel for { // Send heartbeat if nothing has been read for a while receiveMessageTimeout := time.AfterFunc( constants.HEARTBEAT_FROM_SERVER_TIMEOUT*time.Second, func() { heartbeatMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_FROM_SERVER} - if err := ct.SendMessage(heartbeatMsg); err != nil { + if err := t.SendMessage(heartbeatMsg); err != nil { logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send heartbeat: %v", err)) - ms.closeClientTunnel(ct) + ms.closeClientTunnelOrConn(ct, t) return } - // Set a read timeout, if no response to heartbeat is recieved within that period, + // Set a read timeout, if no response to heartbeat is received within that period, // that means the client has disconnected readDeadline := time.Now().Add((constants.READ_DEADLINE * time.Second)) - ct.Tunnel.Conn.SetReadDeadline(readDeadline) + t.Conn.SetReadDeadline(readDeadline) }, ) - tunnelMsg, err := ct.ReceiveMessage() + tunnelMsg, err := t.ReceiveMessage() // If a message is received, stop the receiveMessageTimeout and remove the ReadTimeout // as we do not need to send heartbeat or check connection health in this iteration receiveMessageTimeout.Stop() - ct.Tunnel.Conn.SetReadDeadline(time.Time{}) + t.Conn.SetReadDeadline(time.Time{}) if err != nil { logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Receive Message from client tunnel errored: %v", err)) if utils.NetworkError(err) { // If error with connection, stop processing messages - ms.closeClientTunnel(ct) + ms.closeClientTunnelOrConn(ct, t) return } continue } switch tunnelMsg.MsgType { - case protocol.RESPONSE: - ct.outgoingChannel <- tunnelMsg - case protocol.LOCALHOST_NOT_RUNNING: - // Create a response for Tunnel connected but localhost not running - errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING) - responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState) - notRunningMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()} - ct.outgoingChannel <- notRunningMsg - case protocol.DEST_REQUEST_TIMEDOUT: - // Create a response for Tunnel connected but localhost took too long to respond - errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT) - responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState) - destTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()} - ct.outgoingChannel <- destTimedoutMsg - case protocol.CLIENT_DISCONNECT: - ms.closeClientTunnel(ct) - return - case protocol.HEARTBEAT_FROM_CLIENT: - heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK} - if err := ct.SendMessage(heartbeatAckMsg); err != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to heartbeat ack to client: %v", err)) - ms.closeClientTunnel(ct) + case protocol.CREATE_TUNNEL: + // mmar client requesting new tunnel + ct, err = ms.newClientTunnel(t, "") + + if err != nil { + if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) { + // Close the connection when client max tunnels limit reached + t.Conn.Close() + return + } + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err)) return } - case protocol.HEARTBEAT_ACK: - // Got a heartbeat ack, that means the connection is healthy, - // we do not need to perform any action - case protocol.CLIENT_RECLAIM_SUBDOMAIN: - newAndExistingIDs := strings.Split(string(tunnelMsg.MsgData), ":") - newId := newAndExistingIDs[0] - existingId := newAndExistingIDs[1] + + logger.Log( + constants.DEFAULT_COLOR, + fmt.Sprintf( + "[%s] Tunnel created: %s", + ct.Tunnel.Id, + t.Conn.RemoteAddr().String(), + ), + ) + case protocol.RECLAIM_TUNNEL: + // mmar client reclaiming a previously created tunnel + existingId := string(tunnelMsg.MsgData) // Check if the subdomain has already been taken _, ok := ms.clients[existingId] if ok { // if so, close the tunnel, so the user can create a new one - ms.closeClientTunnel(ct) + ms.closeClientTunnelOrConn(ct, t) return } - ct.Tunnel.Id = existingId - - // Add existing client tunnel to clients - ms.clients[existingId] = *ct - - // Remove newId tunnel from clients - delete(ms.clients, newId) - - // Update the tunnels for the IP - clientIP := utils.ExtractIP(ct.Conn.RemoteAddr().String()) - newIdIndex := slices.Index(ms.tunnelsPerIP[clientIP], newId) - if newIdIndex == -1 { - ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], existingId) - } else { - ms.tunnelsPerIP[clientIP][newIdIndex] = existingId - } - - connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(existingId)} - if err := ct.SendMessage(connMessage); err != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", err)) - ms.closeClientTunnel(ct) + ct, err = ms.newClientTunnel(t, existingId) + if err != nil { + if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) { + // Close the connection when client max tunnels limit reached + t.Conn.Close() + return + } + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to reclaim ClientTunnel: %v", err)) return } logger.Log( constants.DEFAULT_COLOR, fmt.Sprintf( - "[%s] Tunnel reclaimed: %s -> %s", - newId, - ct.Conn.RemoteAddr().String(), + "[%s] Tunnel reclaimed: %s", existingId, + ct.Conn.RemoteAddr().String(), ), ) + case protocol.RESPONSE: + go ms.handleResponseMessages(ct, tunnelMsg) + case protocol.LOCALHOST_NOT_RUNNING: + // Create a response for Tunnel connected but localhost not running + errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING) + responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState) + notRunningMsg := protocol.TunnelMessage{ + MsgType: protocol.RESPONSE, + MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...), + } + go ms.handleResponseMessages(ct, notRunningMsg) + case protocol.DEST_REQUEST_TIMEDOUT: + // Create a response for Tunnel connected but localhost took too long to respond + errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT) + responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState) + destTimedoutMsg := protocol.TunnelMessage{ + MsgType: protocol.RESPONSE, + MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...), + } + go ms.handleResponseMessages(ct, destTimedoutMsg) + case protocol.CLIENT_DISCONNECT: + ms.closeClientTunnelOrConn(ct, t) + return + case protocol.HEARTBEAT_FROM_CLIENT: + heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK} + if err := t.SendMessage(heartbeatAckMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to heartbeat ack to client: %v", err)) + ms.closeClientTunnelOrConn(ct, t) + return + } + case protocol.HEARTBEAT_ACK: + // Got a heartbeat ack, that means the connection is healthy, + // we do not need to perform any action case protocol.INVALID_RESP_FROM_DEST: // Create a response for receiving invalid response from destination server errState := protocol.TunnelErrState(protocol.INVALID_RESP_FROM_DEST) responseBuff := createSerializedServerResp("500 Internal Server Error", http.StatusInternalServerError, errState) - invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()} - ct.outgoingChannel <- invalidRespFromDestMsg + invalidRespFromDestMsg := protocol.TunnelMessage{ + MsgType: protocol.RESPONSE, + MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...), + } + go ms.handleResponseMessages(ct, invalidRespFromDestMsg) } } } diff --git a/internal/server/utils.go b/internal/server/utils.go index 4d91dea..41d9b68 100644 --- a/internal/server/utils.go +++ b/internal/server/utils.go @@ -3,9 +3,12 @@ package server import ( "bytes" "context" + cryptoRand "crypto/rand" + "encoding/binary" "errors" "fmt" "io" + mathRand "math/rand" "net/http" "strconv" "time" @@ -134,3 +137,20 @@ func createSerializedServerResp(status string, statusCode int, body string) byte return responseBuff } + +// Generate a random ID from ID_CHARSET of length ID_LENGTH +func GenerateRandomID() string { + var randSeed *mathRand.Rand = mathRand.New(mathRand.NewSource(time.Now().UnixNano())) + b := make([]byte, constants.ID_LENGTH) + for i := range b { + b[i] = constants.ID_CHARSET[randSeed.Intn(len(constants.ID_CHARSET))] + } + return string(b) +} + +// Generate a random 32-bit unsigned integer +func GenerateRandomUint32() uint32 { + var randomUint32 uint32 + binary.Read(cryptoRand.Reader, binary.BigEndian, &randomUint32) + return randomUint32 +}