From d296f7bb4e1ffaf70cc8e323bdd5d23186a2c406 Mon Sep 17 00:00:00 2001 From: Yusuf Musleh Date: Wed, 20 Aug 2025 00:29:33 +0300 Subject: [PATCH 1/3] feat: Add custom dns and cert support mmar client These flags allow users to configure a custom DNS or TLS certificates for the mmar client to use when making request to the dev server. They are mainly used in simulation tests, however they could be utilized in more complicated setups. --- cmd/mmar/main.go | 12 +++++++++++ constants/main.go | 12 +++++++---- internal/client/main.go | 48 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/cmd/mmar/main.go b/cmd/mmar/main.go index 724d9cf..edac2bb 100644 --- a/cmd/mmar/main.go +++ b/cmd/mmar/main.go @@ -45,6 +45,16 @@ func main() { utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_HOST, constants.TUNNEL_HOST), constants.TUNNEL_HOST_HELP, ) + clientCustomDns := clientCmd.String( + "custom-dns", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_DNS, ""), + constants.CLIENT_CUSTOM_DNS_HELP, + ) + clientCustomCert := clientCmd.String( + "custom-cert", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_CERT, ""), + constants.CLIENT_CUSTOM_CERT_HELP, + ) versionCmd := flag.NewFlagSet(constants.VERSION_CMD, flag.ExitOnError) versionCmd.Usage = utils.MmarVersionUsage @@ -71,6 +81,8 @@ func main() { TunnelHttpPort: *clientTunnelHttpPort, TunnelTcpPort: *clientTunnelTcpPort, TunnelHost: *clientTunnelHost, + CustomDns: *clientCustomDns, + CustomCert: *clientCustomCert, } client.Run(mmarClientConfig) case constants.VERSION_CMD: diff --git a/constants/main.go b/constants/main.go index 698f9fb..41e2c24 100644 --- a/constants/main.go +++ b/constants/main.go @@ -18,6 +18,8 @@ const ( MMAR_ENV_VAR_TUNNEL_HTTP_PORT = "MMAR__TUNNEL_HTTP_PORT" MMAR_ENV_VAR_TUNNEL_TCP_PORT = "MMAR__TUNNEL_TCP_PORT" MMAR_ENV_VAR_TUNNEL_HOST = "MMAR__TUNNEL_HOST" + MMAR_ENV_VAR_CUSTOM_DNS = "MMAR__CUSTOM_DNS" + MMAR_ENV_VAR_CUSTOM_CERT = "MMAR__CUSTOM_CERT" SERVER_STATS_DEFAULT_USERNAME = "admin" SERVER_STATS_DEFAULT_PASSWORD = "admin" @@ -25,10 +27,12 @@ const ( SERVER_HTTP_PORT_HELP = "Define port where mmar will bind to and run on server for HTTP requests." SERVER_TCP_PORT_HELP = "Define port where mmar will bind to and run on server for TCP connections." - CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar." - CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel." - CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel." - TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to." + CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar." + CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel." + CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel." + TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to." + CLIENT_CUSTOM_DNS_HELP = "Define a custom DNS server that the mmar client should use when accessing your local dev server. (eg: 8.8.8.8:53, defaults to DNS in OS)" + CLIENT_CUSTOM_CERT_HELP = "Define path to file custom TLS certificate containing complete ASN.1 DER content (certificate, signature algorithm and signature). Currently used for testing, but may be used to allow mmar client to work with a dev server using custom TLS certificate setups. (eg: /path/to/cert)" TUNNEL_MESSAGE_PROTOCOL_VERSION = 3 TUNNEL_MESSAGE_DATA_DELIMITER = '\n' diff --git a/internal/client/main.go b/internal/client/main.go index 1f61ca3..68dcc3e 100644 --- a/internal/client/main.go +++ b/internal/client/main.go @@ -4,6 +4,8 @@ import ( "bufio" "bytes" "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -26,6 +28,8 @@ type ConfigOptions struct { TunnelHttpPort string TunnelTcpPort string TunnelHost string + CustomDns string + CustomCert string } type MmarClient struct { @@ -58,6 +62,50 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { }, } + // Use custom DNS if set + if mc.CustomDns != "" { + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial("udp", mc.CustomDns) + }, + } + dialer := &net.Dialer{ + Resolver: r, + } + + tp := &http.Transport{ + DialContext: dialer.DialContext, + } + + fwdClient.Transport = tp + } + + // Use custom TLS certificate if setup + if mc.CustomCert != "" { + certData, certFileErr := os.ReadFile(mc.CustomCert) + if certFileErr != nil { + logger.Log( + constants.RED, + fmt.Sprintf( + "Could not read certificate from file: %v", + certFileErr, + )) + os.Exit(1) + } + + cert, certErr := x509.ParseCertificate(certData) + if certErr != nil { + logger.Log(constants.YELLOW, "Warning: Could not load custom certificate") + } else { + fmt.Println("adding cert dawg..") + fwdClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{ + RootCAs: x509.NewCertPool(), + } + fwdClient.Transport.(*http.Transport).TLSClientConfig.RootCAs.AddCert(cert) + } + } + reqReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData)) req, reqErr := http.ReadRequest(reqReader) From c25660ecf799ba4f331b5dd44ebf9ddd8dd0a0de Mon Sep 17 00:00:00 2001 From: Yusuf Musleh Date: Wed, 20 Aug 2025 00:32:34 +0300 Subject: [PATCH 2/3] test: Add mmar client flags + TLS dev server --- simulations/devserver/main.go | 12 ++++++-- simulations/simulation_test.go | 54 +++++++++++++++++++++++++++++----- 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/simulations/devserver/main.go b/simulations/devserver/main.go index 59c957b..973d42a 100644 --- a/simulations/devserver/main.go +++ b/simulations/devserver/main.go @@ -25,11 +25,19 @@ type DevServer struct { *httptest.Server } -func NewDevServer() *DevServer { +func NewDevServer(proto string, addr string) *DevServer { mux := setupMux() + var httpServer *httptest.Server + switch proto { + case "https": + httpServer = httptest.NewTLSServer(mux) + case "http": + httpServer = httptest.NewServer(mux) + } + return &DevServer{ - httptest.NewServer(mux), + httpServer, } } diff --git a/simulations/simulation_test.go b/simulations/simulation_test.go index 1fc2173..26f9bb5 100644 --- a/simulations/simulation_test.go +++ b/simulations/simulation_test.go @@ -43,7 +43,15 @@ func StartMmarServer(ctx context.Context) { } } -func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort string) { +func StartMmarClient( + ctx context.Context, + urlCh chan string, + localDevServerPort string, + localDevServerHost string, + localDevServerProto string, + customDns string, + customCert string, +) { cmd := exec.CommandContext( ctx, "./mmar", @@ -54,6 +62,26 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort localDevServerPort, ) + if localDevServerHost != "" { + cmd.Args = append(cmd.Args, "--local-host", localDevServerHost) + } + + if localDevServerProto != "" { + cmd.Args = append(cmd.Args, "--local-proto", localDevServerProto) + } + + if customDns != "" { + cmd.Args = append(cmd.Args, "--custom-dns", customDns) + } + + if customCert != "" { + cmd.Args = append(cmd.Args, "--custom-cert", customCert) + } + + cmd.Args = append(cmd.Args, "") + + cmd.Stdout = os.Stdout + // Pipe Stderr To capture logs for extracting the tunnel url pipe, _ := cmd.StderrPipe() @@ -77,7 +105,6 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort tunnelUrl := extractTunnelURL(line) if tunnelUrl != "" { urlCh <- tunnelUrl - break } line, readErr = stdoutReader.ReadString('\n') } @@ -91,9 +118,9 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort } } -func StartLocalDevServer() *devserver.DevServer { - ds := devserver.NewDevServer() - log.Printf("Started local dev server on: http://localhost:%v", ds.Port()) +func StartLocalDevServer(proto string, addr string) *devserver.DevServer { + ds := devserver.NewDevServer(proto, addr) + log.Printf("Started local dev server on: %v://%v:%v", proto, addr, ds.Port()) return ds } @@ -735,16 +762,29 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu func TestSimulation(t *testing.T) { simulationCtx, simulationCancel := context.WithCancel(context.Background()) - localDevServer := StartLocalDevServer() + // Start a local dev server with http + localDevServer := StartLocalDevServer("http", "localhost") defer localDevServer.Close() + // Start a local dev server with https + localDevTLSServer := StartLocalDevServer("https", "example.com") + defer localDevTLSServer.Close() + + // Write cert to file so we are able to pass it into mmar client + certErr := os.WriteFile("./temp-cert", localDevTLSServer.Certificate().Raw, 0644) // 0644 is file permissions + if certErr != nil { + log.Fatal(certErr) + } + go dnsserver.StartDnsServer() go StartMmarServer(simulationCtx) wait := time.NewTimer(2 * time.Second) <-wait.C clientUrlCh := make(chan string) - go StartMmarClient(simulationCtx, clientUrlCh, localDevServer.Port()) + + // Start a basic mmar client + go StartMmarClient(simulationCtx, clientUrlCh, localDevServer.Port(), "", "", "", "") // Wait for tunnel url tunnelUrl := <-clientUrlCh From 6214a0f2dfbe54db796e51cf1ed8c5ac512288c5 Mon Sep 17 00:00:00 2001 From: Yusuf Musleh Date: Sat, 23 Aug 2025 17:23:11 +0300 Subject: [PATCH 3/3] test: Support testing multiple mmar clients --- internal/client/main.go | 4 +- simulations/simulation_test.go | 79 +++++++++++++++++++++++---------- simulations/simulation_utils.go | 3 +- 3 files changed, 59 insertions(+), 27 deletions(-) diff --git a/internal/client/main.go b/internal/client/main.go index 68dcc3e..bf4fe25 100644 --- a/internal/client/main.go +++ b/internal/client/main.go @@ -148,8 +148,6 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { return } - logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true) - // Writing response to buffer to tunnel it back var responseBuff bytes.Buffer resp.Write(&responseBuff) @@ -158,6 +156,8 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { if err := mc.SendMessage(respMessage); err != nil { log.Fatal(err) } + + logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true) } // Keep attempting to reconnect the existing tunnel until successful diff --git a/simulations/simulation_test.go b/simulations/simulation_test.go index 26f9bb5..174efec 100644 --- a/simulations/simulation_test.go +++ b/simulations/simulation_test.go @@ -142,12 +142,13 @@ func verifyGetRequestSuccess(t *testing.T, client *http.Client, tunnelUrl string resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-get-request-success"}, } @@ -186,12 +187,13 @@ func verifyGetRequestFail(t *testing.T, client *http.Client, tunnelUrl string, w resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-get-request-fail"}, } @@ -237,12 +239,13 @@ func verifyPostRequestSuccess(t *testing.T, client *http.Client, tunnelUrl strin resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-post-request-success"}, "Content-Length": {strconv.Itoa(len(serializedReqBody))}, } @@ -292,12 +295,13 @@ func verifyPostRequestFail(t *testing.T, client *http.Client, tunnelUrl string, resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-post-request-fail"}, "Content-Length": {strconv.Itoa(len(serializedReqBody))}, } @@ -346,12 +350,13 @@ func verifyRedirectsHandled(t *testing.T, client *http.Client, tunnelUrl string, resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-redirect-request"}, "Referer": {tunnelUrl + "/redirect"}, // Include referer header since it redirects } @@ -405,6 +410,7 @@ func verifyInvalidMethodRequestHandled(t *testing.T, client *http.Client, tunnel expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-invalid-method-request"}, } @@ -477,7 +483,6 @@ func verifyInvalidHttpVersionRequestHandled(t *testing.T, tunnelUrl string, wg * if respErr != nil { t.Errorf("%v: Failed to get response %v", "verifyInvalidHttpVersionRequestHandled", respErr) } - if resp.StatusCode != http.StatusBadRequest { t.Errorf( "%v: resp.StatusCode = %v; want %v", @@ -573,7 +578,6 @@ func verifyContentLengthWithNoBodyRequestHandled(t *testing.T, tunnelUrl string, if respErr != nil { t.Errorf("%v: Failed to get response %v", "verifyContentLengthWithNoBodyRequestHandled", respErr) } - expectedBody := constants.READ_BODY_CHUNK_TIMEOUT_ERR_TEXT expectedResp := expectedResponse{ @@ -607,12 +611,13 @@ func verifyRequestWithLargeBody(t *testing.T, client *http.Client, tunnelUrl str resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-large-post-request-success"}, "Content-Length": {strconv.Itoa(len(serializedReqBody))}, } @@ -661,7 +666,11 @@ func verifyRequestWithVeryLargeBody(t *testing.T, client *http.Client, tunnelUrl resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + // Check if connection was closed in the middle of writing, that's also valid behavior + if !strings.Contains(respErr.Error(), "write: connection reset by peer") { + t.Errorf("Failed to get response: %v", respErr) + } + return } expectedBody := constants.MAX_REQ_BODY_SIZE_ERR_TEXT @@ -688,7 +697,7 @@ func verifyDevServerReturningInvalidRespHandled(t *testing.T, client *http.Clien resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedBody := constants.READ_RESP_BODY_ERR_TEXT @@ -715,7 +724,7 @@ func verifyDevServerLongRunningReqHandledGradefully(t *testing.T, client *http.C resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedBody := constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT @@ -742,7 +751,7 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedBody := constants.LOCALHOST_NOT_RUNNING_ERR_TEXT @@ -781,13 +790,26 @@ func TestSimulation(t *testing.T) { go StartMmarServer(simulationCtx) wait := time.NewTimer(2 * time.Second) <-wait.C - clientUrlCh := make(chan string) // Start a basic mmar client - go StartMmarClient(simulationCtx, clientUrlCh, localDevServer.Port(), "", "", "", "") - - // Wait for tunnel url - tunnelUrl := <-clientUrlCh + basicClientUrlCh := make(chan string) + go StartMmarClient(simulationCtx, basicClientUrlCh, localDevServer.Port(), "", "", "", "") + + // Start another basic mmar client + basicClientUrlCh2 := make(chan string) + go StartMmarClient(simulationCtx, basicClientUrlCh2, localDevServer.Port(), "", "", "", "") + + // Wait for all tunnel urls + mmarClientsCount := 2 + tunnelUrls := []string{} + for range mmarClientsCount { + select { + case tunnelUrl := <-basicClientUrlCh: + tunnelUrls = append(tunnelUrls, tunnelUrl) + case tunnelUrl := <-basicClientUrlCh2: + tunnelUrls = append(tunnelUrls, tunnelUrl) + } + } // Initialize http client client := httpClient() @@ -823,18 +845,27 @@ func TestSimulation(t *testing.T) { verifyContentLengthWithNoBodyRequestHandled, } - for _, simTest := range simulationTests { - wg.Add(1) - go simTest(t, client, tunnelUrl, &wg) - } + // Loop through all tunnel urls and run simulation tests + for _, tunnelUrl := range tunnelUrls { + + for _, simTest := range simulationTests { + wg.Add(1) + go simTest(t, client, tunnelUrl, &wg) + } - for _, manualClientSimTest := range manualClientSimulationTests { - wg.Add(1) - go manualClientSimTest(t, tunnelUrl, &wg) + for _, manualClientSimTest := range manualClientSimulationTests { + wg.Add(1) + go manualClientSimTest(t, tunnelUrl, &wg) + } } wg.Wait() + // Delete cert file + if rmErr := os.Remove("./temp-cert"); rmErr != nil { + log.Fatal(rmErr) + } + // Stop simulation tests simulationCancel() diff --git a/simulations/simulation_utils.go b/simulations/simulation_utils.go index d6ae1bf..0a13fab 100644 --- a/simulations/simulation_utils.go +++ b/simulations/simulation_utils.go @@ -95,7 +95,8 @@ func httpClient() *http.Client { dialer := initCustomDialer() tp := &http.Transport{ - DialContext: dialer.DialContext, + DialContext: dialer.DialContext, + DisableKeepAlives: true, } client := &http.Client{Transport: tp} return client