diff --git a/.github/workflows/buildandtest.yml b/.github/workflows/buildandtest.yml index 20148a63..a137624c 100644 --- a/.github/workflows/buildandtest.yml +++ b/.github/workflows/buildandtest.yml @@ -27,6 +27,6 @@ jobs: - name: Build run: direnv exec . go build -v ./... - name: Build Examples - run: direnv exec . go build -v ./examples/... + run: direnv exec . go build -C ./examples -v ./... - name: Test run: direnv exec . go test -v ./... diff --git a/CHANGELOG.md b/CHANGELOG.md index 5acf4fa4..bbb0a48f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,27 @@ +## 2.0.0 + +This is a breaking-change release that updates ngrok-go to a new, simplified +API. + +Enhancements: +- Dramatically simplified the API to remove many overlapping options, options + that are now deprecated, unnecessary convenience functions, and more. +- Simplified the API by removing all protocol-specific behaviors (which have + all been moved to Traffic Policy). +- Removed the config package. All of its options are now folded into the + top-level package or removed because they were migrated into Traffic Policy. +- Updates the API to use new ngrok terminology of Agents, Endpoints Upstreams, + and Traffic Policy. +- Removes functionality that is now deprecated (like labeled tunnels). +- Added support for agent-based TLS termination and Mutual TLS termination. +- Added support for full TLS control over forwarding to the upstream. +- Added support for full control over dialing the upstream. +- Removed a bespoke logging interface in favor of `log/slog`. +- Removed the prototype policy package that was not well supported. +- Separated out a concept of an Agent from its Session which were previously + co-mingled. +- Added integration tests. + ## 1.12.1 Fixes: @@ -9,7 +33,6 @@ Fixes: Breaking changes: - Renames pre-release option `WithAllowsPooling` to `WithPoolingEnabled` -- ## 1.11.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 04f37839..85f3a6df 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,14 +16,10 @@ For any larger changes or features, please [open a new issue](https://github.com The library can be compiled with `go build`. -To run tests, `go test`. +To run tests, use `go test ./...`. Tests are split into a number of categories that can be enabled as desired via environment variables. By default, only offline tests run which validate tunnel protocol RPC messages generated from the `config` APIs. The other tests are gated behind the following environment variables: -* `NGROK_TEST_ONLINE`: All online tests require this variable to be set -* `NGROK_TEST_AUTHED`: Enables tests that require an ngrok account and that the authtoken is set in `NGROK_AUTHTOKEN`. -* `NGROK_TEST_PAID`: Enables online, authenticated tests that require access to paid features. If your subscription doesn't support a feature being tested, you should see error messages to that effect. -* `NGROK_TEST_LONG`: Enables online tests that may take longer than most. May also require the `AUTHED` and/or `PAID` groups enabled. -* `NGROK_TEST_FLAKEY`: Enable online tests that may be unreliable. Their success or failure may depend on network conditions, timing, solar flares, ghosts in the machine, etc. +* `NGROK_TEST_ONLINE`: All online tests require this variable to be set and an authtoken in `NGROK_AUTHTOKEN`. Tests that require paid features will fail with appropriate error messages if your subscription doesn't support them. -This list may be incomplete and drift slightly as we add more tests and granularity. See the tests in `online_test.go` for the most accurate list. \ No newline at end of file +This list may be incomplete and drift slightly as we add more tests and granularity. See the tests in `internal/legacy/online_test.go` and `internal/integration_tests/` for the most accurate implementations. \ No newline at end of file diff --git a/README.md b/README.md index 1900665e..dc11987f 100644 --- a/README.md +++ b/README.md @@ -1,43 +1,51 @@ # ngrok-go -[![Go Reference](https://pkg.go.dev/badge/golang.ngrok.com/ngrok.svg)](https://pkg.go.dev/golang.ngrok.com/ngrok) +[![Go Reference](https://pkg.go.dev/badge/golang.ngrok.com/ngrok/v2.svg)](https://pkg.go.dev/golang.ngrok.com/ngrok/v2) [![Go](https://github.com/ngrok/ngrok-go/actions/workflows/buildandtest.yml/badge.svg)](https://github.com/ngrok/ngrok-go/actions/workflows/buildandtest.yml) -[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/ngrok/ngrok-rust/blob/main/LICENSE-MIT) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/ngrok/ngrok-go/blob/main/LICENSE.txt) -[ngrok](https://ngrok.com) is a simplified API-first ingress-as-a-service that adds connectivity, security, and observability to your apps. +[ngrok](https://ngrok.com) is an API gateway cloud service that forwards to +applications running anywhere. -ngrok-go is an open source and idiomatic library for embedding ngrok networking directly into Go applications. If you’ve used ngrok before, you can think of ngrok-go as the ngrok agent packaged as a Go library. +ngrok-go is an open source and idiomatic Go package for embedding ngrok +networking directly into your Go applications. If you've used ngrok before, you +can think of ngrok-go as the ngrok agent packaged as a Go library. -ngrok-go lets developers serve Go apps on the internet in a single line of code without setting up low-level network primitives like IPs, certificates, load balancers and even ports! Applications using ngrok-go listen on ngrok’s global ingress network but they receive the same interface any Go app would expect (net.Listener) as if it listened on a local port by calling net.Listen(). This makes it effortless to integrate ngrok-go into any application that uses Go's net or net/http packages. - -See [`examples/http/main.go`](/examples/http/main.go) for example usage, or the tests in [`online_test.go`](/online_test.go). +ngrok-go enables you to serve Go apps on the internet in a single line of code +without setting up low-level network primitives like IPs, certificates, load +balancers and even ports! Applications using ngrok-go listen on ngrok's global +cloud service but, they receive connections using the same interface +(net.Listener) that any Go app would expect if it listened on a local port. For working with the [ngrok API](https://ngrok.com/docs/api/), check out the [ngrok Go API Client Library](https://github.com/ngrok/ngrok-api-go). ## Installation -The best way to install the ngrok agent SDK is through `go get`. +Install ngrok-go with `go get`. ```sh -go get golang.ngrok.com/ngrok +go get golang.ngrok.com/ngrok/v2 ``` ## Documentation -A full API reference is included in the [ngrok go sdk documentation on pkg.go.dev](https://pkg.go.dev/golang.ngrok.com/ngrok). Check out the [ngrok Documentation](https://ngrok.com/docs) for more information about what you can do with ngrok. - -For additional information, be sure to also check out the [ngrok-go launch announcement](https://ngrok.com/blog-post/ngrok-go)! +- [ngrok-go API Reference](https://pkg.go.dev/golang.ngrok.com/ngrok/v2) on pkg.go.dev. +- [ngrok Documentation](https://ngrok.com/docs) for what you can do with ngrok. +- [Examples](./examples) are another great way to get started. +- [ngrok-go launch announcement](https://ngrok.com/blog-post/ngrok-go) for more context on why we built it. The examples in the blog post may be out of date for the new API. ## Quickstart -For more examples of using ngrok-go, check out the [/examples](/examples) folder. - -The following example uses ngrok to start an http endpoint with a random url that will route traffic to the handler. The ngrok URL provided when running this example is accessible by anyone with an internet connection. +The following example starts a Go web server that receives traffic from an +endpoint on ngrok's cloud service with a randomly-assigned URL. The ngrok URL +provided when running this example is accessible by anyone with an internet +connection. -The ngrok authtoken is pulled from the `NGROK_AUTHTOKEN` environment variable. You can find your authtoken by logging into the [ngrok dashboard](https://dashboard.ngrok.com/get-started/your-authtoken). +You need an ngrok authtoken to run the following example, which you can get from +the [ngrok dashboard](https://dashboard.ngrok.com/get-started/your-authtoken). -You can run this example with the following command: +Run this example with the following command: ```sh NGROK_AUTHTOKEN=xxxx_xxxx go run examples/http/main.go @@ -52,8 +60,7 @@ import ( "log" "net/http" - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" + "golang.ngrok.com/ngrok/v2" ) func main() { @@ -63,16 +70,16 @@ func main() { } func run(ctx context.Context) error { - ln, err := ngrok.Listen(ctx, - config.HTTPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ) + // ngrok.Listen uses ngrok.DefaultAgent which uses the NGROK_AUTHTOKEN + // environment variable for auth + ln, err := ngrok.Listen(ctx) if err != nil { return err } - log.Println("Ingress established at:", ln.URL()) + log.Println("Endpoint online", ln.URL()) + // Serve HTTP traffic on the ngrok endpoint return http.Serve(ln, http.HandlerFunc(handler)) } @@ -81,6 +88,54 @@ func handler(w http.ResponseWriter, r *http.Request) { } ``` +## Traffic Policy + +You can use ngrok's [Traffic Policy](https://ngrok.com/docs/traffic-policy/) +engine to apply API Gateway behaviors at ngrok's cloud service to auth, route, +block and rate-limit the traffic. For example: + +```go +tp := ` +on_http_request: + - name: "rate limit by ip address" + actions: + - type: rate-limit + config: + name: client-ip-rate-limit + algorithm: sliding_window + capacity: 30 + rate: 60s + bucket_key: + - conn.client_ip + - name: "federate to google for auth" + actions: + - type: oauth + config: + provider: google + - name: "block users without an 'example.com' domain" + expressions: + - "!actions.ngrok.oauth.identity.email.endsWith('@example.com')" + actions: + - type: custom-response + config: + status_code: 403 + content: "${actions.ngrok.oauth.identity.name} is not allowed" +` + +ln, err := ngrok.Listen(ctx, ngrok.WithTrafficPolicy(tp)) +if err != nil { + return err +} +``` + +## Examples + +There are many more great examples you can reference to get started: + +- [Creating a TCP endpoint](./examples/tcp/) and handling TCP connections directly. +- [Forwarding to another URL](./examples/forward/) instead of handling connections yourself. +- [Adding Traffic Policy](./examples/traffic-policy/) in front of your app for authentication, rate limiting, etc. + ## Support The best place to get support using ngrok-go is through the [ngrok Slack Community](https://ngrok.com/slack). If you find bugs or would like to contribute code, please follow the instructions in the [contributing guide](/CONTRIBUTING.md). diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..3760c97d --- /dev/null +++ b/TODO.md @@ -0,0 +1,10 @@ +TODO + +- Fix #176 +- Add agent configuration file parsing with an AgentConfig struct +- Wrap `Conn` objects returned by the listener with a type that can be used to determine if they are TLS-terminated or not +- Remove the legacy package by folding all of its logic into the current package +- Add an RPC test +- Implement support for AgentSession.ID() +- Endpoint.ID() should return the endpoint's API resource identifier but right now it just returns a random unique identifier unrelated to the API resource +- Endpoint.Wait() which can be used to wait until an endpoint stops \ No newline at end of file diff --git a/VERSION b/VERSION deleted file mode 100644 index f8f4f03b..00000000 --- a/VERSION +++ /dev/null @@ -1 +0,0 @@ -1.12.1 diff --git a/agent.go b/agent.go new file mode 100644 index 00000000..63bc4fff --- /dev/null +++ b/agent.go @@ -0,0 +1,440 @@ +package ngrok + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "slices" + "sync" + "time" + + "golang.org/x/net/proxy" + + "golang.ngrok.com/ngrok/v2/internal/legacy" + "golang.ngrok.com/ngrok/v2/internal/legacy/config" + "golang.ngrok.com/ngrok/v2/rpc" +) + +// Agent is the main interface for interacting with the ngrok service. +type Agent interface { + // Connect begins a new Session by connecting and authenticating to the ngrok cloud service. + Connect(context.Context) error + + // Disconnect terminates the current Session which disconnects it from the ngrok cloud service. + Disconnect() error + + // Session returns an object describing the connection of the Agent to the ngrok cloud service. + Session() (AgentSession, error) + + // Endpoints returns the list of endpoints created by this Agent from calls to either Listen or Forward. + Endpoints() []Endpoint + + // Listen creates an Endpoint which returns received connections to the caller via an EndpointListener. + Listen(context.Context, ...EndpointOption) (EndpointListener, error) + + // Forward creates an Endpoint which forwards received connections to a target upstream URL. + Forward(context.Context, *Upstream, ...EndpointOption) (EndpointForwarder, error) +} + +// Dialer is an interface that is satisfied by net.Dialer or you can specify your +// own implementation. +type Dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// agent implements the Agent interface. +type agent struct { + mu sync.RWMutex + sess legacy.Session + agentSession *agentSession + opts *agentOpts + endpoints []Endpoint + // Event handlers registered with this agent + eventHandlers []EventHandler + eventMutex sync.RWMutex // Protects eventHandlers +} + +// NewAgent creates a new Agent object. +func NewAgent(agentOpts ...AgentOption) (Agent, error) { + opts := defaultAgentOpts() + for _, opt := range agentOpts { + opt(opts) + } + + return &agent{ + opts: opts, + endpoints: make([]Endpoint, 0), + eventHandlers: opts.eventHandlers, + }, nil +} + +// Connect begins a new Session by connecting and authenticating to the ngrok +// cloud service. +func (a *agent) Connect(ctx context.Context) error { + a.mu.Lock() + defer a.mu.Unlock() + + // If we're already connected, return an error + if a.sess != nil && a.agentSession != nil { + return errors.New("agent already connected") + } + + // Add legacy connect handlers for events + legacyOpts := append([]legacy.ConnectOption{}, a.opts.sessionOpts...) + + // Process proxy URL if provided + if a.opts.proxyURL != "" { + parsedURL, err := url.Parse(a.opts.proxyURL) + if err != nil { + return fmt.Errorf("invalid proxy URL: %w", err) + } + + // Determine the base dialer to use for connecting to the proxy + baseDialer := a.opts.dialer + if baseDialer == nil { + // If no custom dialer is provided, use a standard net.Dialer + baseDialer = &net.Dialer{} + } + + // Create a proxy dialer using the base dialer + proxyDialer, err := proxy.FromURL(parsedURL, baseDialer) + if err != nil { + return fmt.Errorf("failed to initialize proxy: %w", err) + } + + // We know FromURL returns a Dialer-compatible type + dialer, ok := proxyDialer.(Dialer) + if !ok { + return fmt.Errorf("proxy dialer is not compatible with ngrok Dialer interface") + } + + // Set the dialer in our options + a.opts.dialer = dialer + // Pass it to the legacy package + legacyOpts = append(legacyOpts, legacy.WithDialer(dialer)) + } + + // Hook up connect event + legacyOpts = append(legacyOpts, legacy.WithConnectHandler(func(_ context.Context, sess legacy.Session) { + a.emitEvent(newAgentConnectSucceeded(a, a.agentSession)) + })) + + // Hook up disconnect event + legacyOpts = append(legacyOpts, legacy.WithDisconnectHandler(func(_ context.Context, sess legacy.Session, err error) { + a.emitEvent(newAgentDisconnected(a, a.agentSession, err)) + })) + + // Hook up heartbeat event + legacyOpts = append(legacyOpts, legacy.WithHeartbeatHandler(func(_ context.Context, sess legacy.Session, latency time.Duration) { + a.emitEvent(newAgentHeartbeatReceived(a, a.agentSession, latency)) + })) + + // If an RPC handler is registered, hook up the command handlers + if a.opts.rpcHandler != nil { + // Register the command handlers that delegate to the RPC handler + legacyOpts = append(legacyOpts, + legacy.WithStopHandler(a.createCommandHandler(rpc.StopAgentMethod)), + legacy.WithRestartHandler(a.createCommandHandler(rpc.RestartAgentMethod)), + legacy.WithUpdateHandler(a.createCommandHandler(rpc.UpdateAgentMethod)), + ) + } + + // Create a new ngrok session + sess, err := legacy.Connect(ctx, legacyOpts...) + if err != nil { + return wrapError(err) + } + + // Create our AgentSession wrapper + a.sess = sess + a.agentSession = &agentSession{ + warnings: sess.Warnings(), + agent: a, + startedAt: time.Now(), + } + + return nil +} + +// Disconnect terminates the current Session which disconnects it from the ngrok +// cloud service. +func (a *agent) Disconnect() error { + // Get what we need under lock + a.mu.Lock() + sess := a.sess + endpoints := a.endpoints + a.sess = nil + a.agentSession = nil + a.endpoints = make([]Endpoint, 0) + a.mu.Unlock() + + if sess == nil { + return nil + } + + // Signal done for all endpoints (not holding the lock) + for _, endpoint := range endpoints { + // Only signal done, don't remove (already cleared the list) + if e, ok := endpoint.(interface{ signalDone() }); ok { + e.signalDone() + } + } + + // Close session (not holding the lock) + err := sess.Close() + return wrapError(err) +} + +// Session returns an object describing the connection of the Agent to the ngrok +// cloud service. +func (a *agent) Session() (AgentSession, error) { + a.mu.RLock() + defer a.mu.RUnlock() + + if a.sess == nil || a.agentSession == nil { + return nil, errors.New("agent not connected") + } + + return a.agentSession, nil +} + +// Endpoints returns the list of endpoints created by this Agent. +func (a *agent) Endpoints() []Endpoint { + a.mu.RLock() + defer a.mu.RUnlock() + + // Return a copy to avoid race conditions + return slices.Clone(a.endpoints) +} + +// createListener creates an endpointListener for internal use +func (a *agent) createListener(ctx context.Context, endpointOpts *endpointOpts) (*endpointListener, error) { + // Get the session + a.mu.RLock() + sess := a.sess + a.mu.RUnlock() + + // Determine URL scheme and configure endpoint + scheme, err := determineURLScheme(endpointOpts.url) + if err != nil { + return nil, err + } + tunnelConfig, err := configureEndpoint(scheme, endpointOpts) + if err != nil { + return nil, err + } + + // Create tunnel and parse URL + tunnel, err := sess.Listen(ctx, tunnelConfig) + if err != nil { + return nil, wrapError(err) + } + tunnelURL, err := url.Parse(tunnel.URL()) + if err != nil { + return nil, fmt.Errorf("failed to parse tunnel URL: %w", err) + } + + // Validate upstream URL format if provided + if endpointOpts.upstreamURL != "" { + _, err = url.Parse(endpointOpts.upstreamURL) + if err != nil { + return nil, fmt.Errorf("invalid upstream URL: %w", err) + } + } + + // Create endpoint listener + endpoint := &endpointListener{ + baseEndpoint: baseEndpoint{ + agent: a, + id: tunnel.ID(), + poolingEnabled: endpointOpts.poolingEnabled, + bindings: endpointOpts.bindings, + description: endpointOpts.description, + metadata: endpointOpts.metadata, + agentTLSConfig: endpointOpts.agentTLSConfig, + trafficPolicy: endpointOpts.trafficPolicy, + endpointURL: *tunnelURL, + doneChannel: make(chan struct{}), + doneOnce: &sync.Once{}, + }, + tunnel: tunnel, + } + + // Add the endpoint to our list + a.mu.Lock() + a.endpoints = append(a.endpoints, endpoint) + a.mu.Unlock() + + return endpoint, nil +} + +// Listen creates an EndpointListener. +func (a *agent) Listen(ctx context.Context, opts ...EndpointOption) (EndpointListener, error) { + // Apply all options + endpointOpts := defaultEndpointOpts() + for _, opt := range opts { + opt(endpointOpts) + } + + // Ensure we're connected + if err := a.ensureConnected(ctx); err != nil { + return nil, err + } + + // Create the listener using the helper method + listener, err := a.createListener(ctx, endpointOpts) + if err != nil { + return nil, err + } + + return listener, nil +} + +// ensureConnected handles automatic connection and verifies connection state +func (a *agent) ensureConnected(ctx context.Context) error { + // First check if we're already connected (with a read lock) + a.mu.RLock() + sessionExists := a.sess != nil + a.mu.RUnlock() + + // Only try to connect if needed and auto-connect is enabled + if !sessionExists && a.opts.autoConnect { + if err := a.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + } + + // Final verification that we're connected + a.mu.RLock() + defer a.mu.RUnlock() + + if a.sess == nil { + return errors.New("agent not connected, call Connect() first") + } + + return nil +} + +// removeEndpoint removes an endpoint from the agent's list +func (a *agent) removeEndpoint(endpoint Endpoint) { + // Remove the endpoint from our list under lock + a.mu.Lock() + for i, e := range a.endpoints { + if e == endpoint { + a.endpoints = append(a.endpoints[:i], a.endpoints[i+1:]...) + break + } + } + a.mu.Unlock() +} + +// emitEvent sends an event to all registered handlers +func (a *agent) emitEvent(evt Event) { + a.eventMutex.RLock() + handlers := make([]EventHandler, len(a.eventHandlers)) + copy(handlers, a.eventHandlers) + a.eventMutex.RUnlock() + + for _, handler := range handlers { + // Call the handler directly + // Note: The handler is responsible for not blocking + handler(evt) + } +} + +// createCommandHandler returns a legacy.ServerCommandHandler that delegates to the RPCHandler +// for the specified RPC method. +func (a *agent) createCommandHandler(method string) legacy.ServerCommandHandler { + return func(ctx context.Context, sess legacy.Session) error { + if a.opts.rpcHandler == nil { + return nil + } + + // Get the current agent session + agentSession, err := a.Session() + if err != nil { + return err + } + + // Create request object with the specified method + req := &rpcRequest{ + method: method, + payload: nil, // No payload for now + } + + // Call the RPC handler + _, err = a.opts.rpcHandler(ctx, agentSession, req) + // Ignore response payload for now + return err + } +} + +// Forward creates an EndpointForwarder that forwards traffic to the specified upstream. +// The upstream parameter is required and must be created using WithUpstream(). +// Additional endpoint options can be provided to configure the endpoint. +func (a *agent) Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (EndpointForwarder, error) { + // Apply all base options first + endpointOpts := defaultEndpointOpts() + + // Set upstream values directly from the Upstream object + endpointOpts.upstreamURL = upstream.addr + endpointOpts.upstreamProtocol = upstream.protocol + endpointOpts.upstreamTLSClientConfig = upstream.tlsClientConfig + + // Convert the proxy protocol to config.ProxyProtoVersion + if upstream.proxyProto != "" { + var proxyVersion config.ProxyProtoVersion + switch upstream.proxyProto { + case ProxyProtoV1: + proxyVersion = config.ProxyProtoV1 + case ProxyProtoV2: + proxyVersion = config.ProxyProtoV2 + default: + return nil, fmt.Errorf("unsupported proxy protocol: %s", upstream.proxyProto) + } + endpointOpts.proxyProtoVersion = proxyVersion + } + + // Apply additional options + for _, opt := range opts { + opt(endpointOpts) + } + + // Ensure we're connected + if err := a.ensureConnected(ctx); err != nil { + return nil, err + } + + // Create the listener using the helper method + listener, err := a.createListener(ctx, endpointOpts) + if err != nil { + return nil, err + } + + // Parse upstream URL - we know it exists and is valid from createListener + upstreamURL, _ := url.Parse(endpointOpts.upstreamURL) + + // Create the forwarder + endpoint := &endpointForwarder{ + baseEndpoint: listener.baseEndpoint, // reuse the baseEndpoint from listener + listener: listener, + upstreamURL: *upstreamURL, + upstreamProtocol: endpointOpts.upstreamProtocol, + upstreamTLSClientConfig: endpointOpts.upstreamTLSClientConfig, + proxyProtocol: upstream.proxyProto, + upstreamDialer: upstream.dialer, + } + + // Start the forwarding process + endpoint.start(ctx) + + // Add the endpoint to our list + a.mu.Lock() + a.endpoints = append(a.endpoints, endpoint) + a.mu.Unlock() + + return endpoint, nil +} diff --git a/agent_options.go b/agent_options.go new file mode 100644 index 00000000..7f0e3b99 --- /dev/null +++ b/agent_options.go @@ -0,0 +1,216 @@ +package ngrok + +import ( + "crypto/tls" + "crypto/x509" + "log/slog" + "time" + + "golang.ngrok.com/ngrok/v2/internal/legacy" +) + +// AgentOption is a functional option used to configure NewAgent. +type AgentOption func(*agentOpts) + +// agentOpts stores configuration for Agent. +type agentOpts struct { + authtoken string + logger *slog.Logger + connectURL string + autoConnect bool + clientInfo clientInfo + dialer Dialer + description string + metadata string + proxyURL string + connectCAs *x509.CertPool + tlsConfig func(*tls.Config) + multiLeg bool + heartbeatInterval time.Duration + heartbeatTolerance time.Duration + // Event handlers registered with the agent + eventHandlers []EventHandler + // RPC handler for server commands + rpcHandler RPCHandler + // Store ngrok SDK options + sessionOpts []legacy.ConnectOption +} + +type clientInfo struct { + clientType string + version string + comments []string +} + +// defaultAgentOpts returns the default options for an agent. +func defaultAgentOpts() *agentOpts { + return &agentOpts{ + autoConnect: true, + sessionOpts: []legacy.ConnectOption{}, + } +} + +// WithAgentConnectCAs defines the CAs used to validate the TLS certificate +// returned by the ngrok service when establishing a session. +// +// See https://ngrok.com/docs/agent/config/v3/#connect_cas +func WithAgentConnectCAs(pool *x509.CertPool) AgentOption { + return func(opts *agentOpts) { + opts.connectCAs = pool + opts.sessionOpts = append(opts.sessionOpts, legacy.WithCA(pool)) + } +} + +// WithAgentConnectURL defines the URL the agent connects to in order to +// establish a connection to the ngrok cloud service. +// +// See https://ngrok.com/docs/agent/config/v3/#connect_url +func WithAgentConnectURL(addr string) AgentOption { + return func(opts *agentOpts) { + opts.connectURL = addr + opts.sessionOpts = append(opts.sessionOpts, legacy.WithServer(addr)) + } +} + +// WithAuthtoken specifies the authtoken to use for authenticating to the +// ngrok cloud service during Connect. +// +// See https://ngrok.com/docs/agent/#authtokens +func WithAuthtoken(token string) AgentOption { + return func(opts *agentOpts) { + opts.authtoken = token + opts.sessionOpts = append(opts.sessionOpts, legacy.WithAuthtoken(token)) + } +} + +// WithAutoConnect controls whether the Agent will automatically call +// Connect(). When enabled, if an endpoint is created via Listen() or Connect() +// and the Agent does not have an active session, it will automatically Connect(). +func WithAutoConnect(auto bool) AgentOption { + return func(opts *agentOpts) { + opts.autoConnect = auto + } +} + +// WithClientInfo provides client information to the ngrok cloud service. +func WithClientInfo(clientType, version string, comments ...string) AgentOption { + return func(opts *agentOpts) { + opts.clientInfo = clientInfo{ + clientType: clientType, + version: version, + comments: comments, + } + opts.sessionOpts = append(opts.sessionOpts, legacy.WithClientInfo(clientType, version, comments...)) + } +} + +// WithDialer customizes how the Agent establishes connections to the ngrok +// cloud service. +func WithDialer(dialer Dialer) AgentOption { + return func(opts *agentOpts) { + opts.dialer = dialer + opts.sessionOpts = append(opts.sessionOpts, legacy.WithDialer(dialer)) + } +} + +// WithAgentDescription sets a human-readable description for the agent session. +func WithAgentDescription(desc string) AgentOption { + return func(opts *agentOpts) { + opts.description = desc + } +} + +// WithHeartbeatInterval sets how often the agent will send heartbeat +// messages to the ngrok service. +// +// See https://ngrok.com/docs/agent/#heartbeats +func WithHeartbeatInterval(interval time.Duration) AgentOption { + return func(opts *agentOpts) { + opts.heartbeatInterval = interval + opts.sessionOpts = append(opts.sessionOpts, legacy.WithHeartbeatInterval(interval)) + } +} + +// WithHeartbeatTolerance sets how long to wait for a heartbeat response +// before assuming the connection is dead. +// +// See https://ngrok.com/docs/agent/#heartbeats +func WithHeartbeatTolerance(tolerance time.Duration) AgentOption { + return func(opts *agentOpts) { + opts.heartbeatTolerance = tolerance + opts.sessionOpts = append(opts.sessionOpts, legacy.WithHeartbeatTolerance(tolerance)) + } +} + +// WithLogger sets the logger to use for the agent. +// Accepts a standard log/slog.Logger from the Go standard library. +func WithLogger(logger *slog.Logger) AgentOption { + return func(opts *agentOpts) { + opts.logger = logger + + // Convert slog logger to log15 for the legacy API + log15Logger := legacy.SlogToLog15(logger) + opts.sessionOpts = append(opts.sessionOpts, legacy.WithLogger(log15Logger)) + } +} + +// WithAgentMetadata sets opaque, machine-readable metadata for the agent session. +// +// See https://ngrok.com/docs/api/resources/tunnel-sessions/#response-1 +func WithAgentMetadata(meta string) AgentOption { + return func(opts *agentOpts) { + opts.metadata = meta + opts.sessionOpts = append(opts.sessionOpts, legacy.WithMetadata(meta)) + } +} + +// WithMultiLeg enables connecting to the ngrok service on secondary legs. This +// option is EXPERIMENTAL and may be removed without a breaking version change. +func WithMultiLeg(enable bool) AgentOption { + return func(opts *agentOpts) { + opts.multiLeg = enable + opts.sessionOpts = append(opts.sessionOpts, legacy.WithMultiLeg(enable)) + } +} + +// WithProxyURL sets the proxy URL to use when connecting to the ngrok service. +// The URL will be parsed and processed during Connect. +// +// If used with WithDialer, the custom dialer will be used to establish the +// connection to the proxy, which will then connect to the ngrok service. +// +// See https://ngrok.com/docs/agent/config/v3/#proxy_url +func WithProxyURL(urlSpec string) AgentOption { + return func(opts *agentOpts) { + opts.proxyURL = urlSpec + } +} + +// WithTLSConfig customizes the TLS configuration for connections to the ngrok +// service. +func WithTLSConfig(tlsCustomizer func(*tls.Config)) AgentOption { + return func(opts *agentOpts) { + opts.tlsConfig = tlsCustomizer + opts.sessionOpts = append(opts.sessionOpts, legacy.WithTLSConfig(tlsCustomizer)) + } +} + +// WithEventHandler registers a callback to receive events from the Agent. If +// called multiple times, each handler will receive callbacks. See +// [EventHandler] for details on correctly authoring handlers. +func WithEventHandler(handler EventHandler) AgentOption { + return func(opts *agentOpts) { + opts.eventHandlers = append(opts.eventHandlers, handler) + } +} + +// WithRPCHandler registers a handler for RPC commands from the ngrok service. +// This handler will be called when the agent receives RPC requests like StopAgent, +// RestartAgent, or UpdateAgent. +func WithRPCHandler(handler RPCHandler) AgentOption { + return func(opts *agentOpts) { + opts.rpcHandler = handler + // Note: Legacy handlers will be registered in the agent.Connect method + // to have access to the agent's session + } +} diff --git a/config/basic_auth.go b/config/basic_auth.go deleted file mode 100644 index 4c24582b..00000000 --- a/config/basic_auth.go +++ /dev/null @@ -1,33 +0,0 @@ -package config - -import "golang.ngrok.com/ngrok/internal/pb" - -// BasicAuth is a set of credentials for basic authentication. -type basicAuth struct { - // The username for basic authentication. - Username string - // The password for basic authentication. - // Must be at least eight characters. - Password string -} - -func (ba basicAuth) toProtoConfig() *pb.MiddlewareConfiguration_BasicAuthCredential { - return &pb.MiddlewareConfiguration_BasicAuthCredential{ - CleartextPassword: ba.Password, - Username: ba.Username, - } -} - -// WithBasicAuth adds the provided credentials to the list of basic -// authentication credentials. -// -// https://ngrok.com/docs/http/basic-auth/ -func WithBasicAuth(username, password string) HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - cfg.BasicAuth = append(cfg.BasicAuth, - basicAuth{ - Username: username, - Password: password, - }) - }) -} diff --git a/config/basic_auth_test.go b/config/basic_auth_test.go deleted file mode 100644 index b4365089..00000000 --- a/config/basic_auth_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestBasicAuth(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "single", - opts: HTTPEndpoint(WithBasicAuth("foo", "bar")), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.NotNil(t, opts.BasicAuth) - require.Len(t, opts.BasicAuth.Credentials, 1) - require.Contains(t, opts.BasicAuth.Credentials, &pb.MiddlewareConfiguration_BasicAuthCredential{ - Username: "foo", - CleartextPassword: "bar", - }) - }, - }, - { - name: "multiple", - opts: HTTPEndpoint( - WithBasicAuth("foo", "bar"), - WithBasicAuth("spam", "eggs"), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.NotNil(t, opts.BasicAuth) - require.Len(t, opts.BasicAuth.Credentials, 2) - require.Contains(t, opts.BasicAuth.Credentials, &pb.MiddlewareConfiguration_BasicAuthCredential{ - Username: "foo", - CleartextPassword: "bar", - }) - require.Contains(t, opts.BasicAuth.Credentials, &pb.MiddlewareConfiguration_BasicAuthCredential{ - Username: "spam", - CleartextPassword: "eggs", - }) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/cidr_restrictions.go b/config/cidr_restrictions.go deleted file mode 100644 index c76e80d4..00000000 --- a/config/cidr_restrictions.go +++ /dev/null @@ -1,101 +0,0 @@ -package config - -import ( - "net" - - "golang.ngrok.com/ngrok/internal/pb" -) - -// Restrictions placed on the origin of incoming connections to the edge. -type cidrRestrictions struct { - // Rejects connections that do not match the given CIDRs - Allowed []string - // Rejects connections that match the given CIDRs and allows all other CIDRs. - Denied []string -} - -// Add the provided CIDRS to the [CIDRRestriction].Allowed list. -// -// https://ngrok.com/docs/http/ip-restrictions/ -func WithAllowCIDRString(cidr ...string) interface { - HTTPEndpointOption - TCPEndpointOption - TLSEndpointOption -} { - return &cidrRestrictions{Allowed: cidr} -} - -// Add the provided [net.IPNet] to the [CIDRRestriction].Allowed list. -// -// https://ngrok.com/docs/http/ip-restrictions/ -func WithAllowCIDR(net ...*net.IPNet) interface { - HTTPEndpointOption - TCPEndpointOption - TLSEndpointOption -} { - cidrStrings := make([]string, 0, len(net)) - for _, n := range net { - cidrStrings = append(cidrStrings, n.String()) - } - return &cidrRestrictions{Allowed: cidrStrings} -} - -// Add the provided CIDRS to the [CIDRRestriction].Denied list. -// -// https://ngrok.com/docs/http/ip-restrictions/ -func WithDenyCIDRString(cidr ...string) interface { - HTTPEndpointOption - TCPEndpointOption - TLSEndpointOption -} { - return cidrRestrictions{Denied: cidr} -} - -// Add the provided [net.IPNet] to the [CIDRRestriction].Denied list. -// -// https://ngrok.com/docs/http/ip-restrictions/ -func WithDenyCIDR(net ...*net.IPNet) interface { - HTTPEndpointOption - TCPEndpointOption - TLSEndpointOption -} { - cidrStrings := make([]string, 0, len(net)) - for _, n := range net { - cidrStrings = append(cidrStrings, n.String()) - } - return cidrRestrictions{Denied: cidrStrings} -} - -func (base *cidrRestrictions) merge(set cidrRestrictions) *cidrRestrictions { - if base == nil { - base = &cidrRestrictions{} - } - - base.Allowed = append(base.Allowed, set.Allowed...) - base.Denied = append(base.Denied, set.Denied...) - - return base -} - -func (ir *cidrRestrictions) toProtoConfig() *pb.MiddlewareConfiguration_IPRestriction { - if ir == nil { - return nil - } - - return &pb.MiddlewareConfiguration_IPRestriction{ - AllowCidrs: ir.Allowed, - DenyCidrs: ir.Denied, - } -} - -func (opt cidrRestrictions) ApplyHTTP(opts *httpOptions) { - opts.CIDRRestrictions = opts.CIDRRestrictions.merge(opt) -} - -func (opt cidrRestrictions) ApplyTCP(opts *tcpOptions) { - opts.CIDRRestrictions = opts.CIDRRestrictions.merge(opt) -} - -func (opt cidrRestrictions) ApplyTLS(opts *tlsOptions) { - opts.CIDRRestrictions = opts.CIDRRestrictions.merge(opt) -} diff --git a/config/cidr_restrictions_test.go b/config/cidr_restrictions_test.go deleted file mode 100644 index fb45aa7d..00000000 --- a/config/cidr_restrictions_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package config - -import ( - "net" - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func mustParseCIDR(cidr string) *net.IPNet { - _, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - panic("TEST BUG: invalid CIDR: " + cidr) - } - return ipnet -} - -func testCIDRRestrictions[T tunnelConfigPrivate, O any, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, - getRestrictions func(*O) *pb.MiddlewareConfiguration_IPRestriction, -) { - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - cases := testCases[T, O]{ - { - name: "allow string", - opts: optsFunc(WithAllowCIDRString("127.0.0.0/8")), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.Equal(t, []string{"127.0.0.0/8"}, actual.AllowCidrs) - }, - }, - { - name: "deny string", - opts: optsFunc(WithDenyCIDRString("127.0.0.0/8")), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.Equal(t, []string{"127.0.0.0/8"}, actual.DenyCidrs) - }, - }, - { - name: "allow ipnet", - opts: optsFunc(WithAllowCIDR(mustParseCIDR("127.0.0.0/8"))), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.Equal(t, []string{"127.0.0.0/8"}, actual.AllowCidrs) - }, - }, - { - name: "deny ipnet", - opts: optsFunc(WithDenyCIDR(mustParseCIDR("127.0.0.0/8"))), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.Equal(t, []string{"127.0.0.0/8"}, actual.DenyCidrs) - }, - }, - { - name: "allow multi", - opts: optsFunc( - WithAllowCIDRString("127.0.0.0/8"), - WithAllowCIDRString("10.0.0.0/8"), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.ElementsMatch(t, []string{"127.0.0.0/8", "10.0.0.0/8"}, actual.AllowCidrs) - }, - }, - { - name: "deny multi", - opts: optsFunc( - WithDenyCIDRString("127.0.0.0/8"), - WithDenyCIDRString("10.0.0.0/8"), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.ElementsMatch(t, []string{"127.0.0.0/8", "10.0.0.0/8"}, actual.DenyCidrs) - }, - }, - { - name: "allow and deny multi", - opts: optsFunc( - WithAllowCIDRString("127.0.0.0/8"), - WithAllowCIDRString("10.0.0.0/8"), - WithDenyCIDRString("192.0.0.0/8"), - WithDenyCIDRString("172.0.0.0/8"), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.ElementsMatch(t, []string{"192.0.0.0/8", "172.0.0.0/8"}, actual.DenyCidrs) - require.ElementsMatch(t, []string{"127.0.0.0/8", "10.0.0.0/8"}, actual.AllowCidrs) - }, - }, - { - name: "allow and deny multi ipnet", - opts: optsFunc( - WithAllowCIDR(mustParseCIDR("127.0.0.0/8")), - WithAllowCIDR(mustParseCIDR("10.0.0.0/8")), - WithDenyCIDR(mustParseCIDR("192.0.0.0/8")), - WithDenyCIDR(mustParseCIDR("172.0.0.0/8")), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getRestrictions(opts) - require.NotNil(t, actual) - require.ElementsMatch(t, []string{"192.0.0.0/8", "172.0.0.0/8"}, actual.DenyCidrs) - require.ElementsMatch(t, []string{"127.0.0.0/8", "10.0.0.0/8"}, actual.AllowCidrs) - }, - }, - } - - cases.runAll(t) -} - -func TestCIDRRestrictions(t *testing.T) { - testCIDRRestrictions[*httpOptions](t, HTTPEndpoint, - func(h *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_IPRestriction { - return h.IPRestriction - }) - testCIDRRestrictions[*tcpOptions](t, TCPEndpoint, - func(h *proto.TCPEndpoint) *pb.MiddlewareConfiguration_IPRestriction { - return h.IPRestriction - }) - testCIDRRestrictions[*tlsOptions](t, TLSEndpoint, - func(h *proto.TLSEndpoint) *pb.MiddlewareConfiguration_IPRestriction { - return h.IPRestriction - }) -} diff --git a/config/circuit_breaker.go b/config/circuit_breaker.go deleted file mode 100644 index 625208ce..00000000 --- a/config/circuit_breaker.go +++ /dev/null @@ -1,11 +0,0 @@ -package config - -// WithCircuitBreaker sets the 5XX response ratio at which the ngrok edge will -// stop sending requests to this tunnel. -// -// https://ngrok.com/docs/http/circuit-breaker/ -func WithCircuitBreaker(ratio float64) HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - cfg.CircuitBreaker = ratio - }) -} diff --git a/config/circuit_breaker_test.go b/config/circuit_breaker_test.go deleted file mode 100644 index c4877527..00000000 --- a/config/circuit_breaker_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestCircuitBreaker(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "absent", - opts: HTTPEndpoint(), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.Nil(t, opts.CircuitBreaker) - }, - }, - { - name: "breakered", - opts: HTTPEndpoint(WithCircuitBreaker(0.5)), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.NotNil(t, opts.CircuitBreaker) - require.Equal(t, opts.CircuitBreaker.ErrorThreshold, 0.5) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/compression.go b/config/compression.go deleted file mode 100644 index 624acd9f..00000000 --- a/config/compression.go +++ /dev/null @@ -1,10 +0,0 @@ -package config - -// WithCompression enables gzip compression. -// -// https://ngrok.com/docs/http/compression/ -func WithCompression() HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - cfg.Compression = true - }) -} diff --git a/config/compression_test.go b/config/compression_test.go deleted file mode 100644 index 6a4b2883..00000000 --- a/config/compression_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestCompression(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "absent", - opts: HTTPEndpoint(), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.Nil(t, opts.Compression) - }, - }, - { - name: "compressed", - opts: HTTPEndpoint(WithCompression()), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.NotNil(t, opts.Compression) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/domain.go b/config/domain.go deleted file mode 100644 index a45018d9..00000000 --- a/config/domain.go +++ /dev/null @@ -1,61 +0,0 @@ -package config - -type domainOption string - -// WithDomain sets the fully-qualified domain name for this edge. -// -// https://ngrok.com/docs/network-edge/domains-and-tcp-addresses/#domains -func WithDomain(name string) interface { - HTTPEndpointOption - TLSEndpointOption -} { - return domainOption(name) -} - -func (opt domainOption) ApplyHTTP(opts *httpOptions) { - opts.Domain = string(opt) -} - -func (opt domainOption) ApplyTLS(opts *tlsOptions) { - opts.Domain = string(opt) -} - -type hostnameOption string - -// WithHostname sets the hostname for this edge. -// -// Deprecated: use WithDomain instead -func WithHostname(name string) interface { - HTTPEndpointOption - TLSEndpointOption -} { - return hostnameOption(name) -} - -func (opt hostnameOption) ApplyHTTP(opts *httpOptions) { - opts.Hostname = string(opt) -} - -func (opt hostnameOption) ApplyTLS(opts *tlsOptions) { - opts.Hostname = string(opt) -} - -type subdomainOption string - -// WithSubdomain sets the subdomain for this edge. -// -// Deprecated: use WithDomain instead -func WithSubdomain(name string) interface { - HTTPEndpointOption - TLSEndpointOption -} { - return subdomainOption(name) -} - -func (opt subdomainOption) ApplyHTTP(opts *httpOptions) { - opts.Subdomain = string(opt) -} - -func (opt subdomainOption) ApplyTLS(opts *tlsOptions) { - opts.Subdomain = string(opt) -} diff --git a/config/domain_test.go b/config/domain_test.go deleted file mode 100644 index 39277470..00000000 --- a/config/domain_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func testDomain[T tunnelConfigPrivate, O any, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, - getDomain func(*O) string, -) { - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - - cases := testCases[T, O]{ - { - name: "absent", - opts: optsFunc(), - expectOpts: func(t *testing.T, opts *O) { - actual := getDomain(opts) - require.Empty(t, actual) - }, - }, - { - name: "with domain", - opts: optsFunc(WithDomain("foo.ngrok.io")), - expectOpts: func(t *testing.T, opts *O) { - actual := getDomain(opts) - require.NotEmpty(t, actual) - require.Equal(t, "foo.ngrok.io", actual) - }, - }, - } - - cases.runAll(t) -} - -func TestDomain(t *testing.T) { - testDomain[*httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) string { - return opts.Domain - }) - testDomain[*tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) string { - return opts.Domain - }) -} diff --git a/config/forwards_to_test.go b/config/forwards_to_test.go deleted file mode 100644 index 429b2ef4..00000000 --- a/config/forwards_to_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package config - -import ( - "testing" -) - -func testForwardsTo[T tunnelConfigPrivate, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, -) { - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - - cases := testCases[T, any]{ - { - name: "absent", - opts: optsFunc(), - expectForwardsTo: ptr(defaultForwardsTo()), - }, - { - name: "with forwardsTo", - opts: optsFunc(WithForwardsTo("localhost:8080")), - expectForwardsTo: ptr("localhost:8080"), - }, - } - - cases.runAll(t) -} - -func TestForwardsTo(t *testing.T) { - testForwardsTo[*httpOptions](t, HTTPEndpoint) - testForwardsTo[*tlsOptions](t, TLSEndpoint) - testForwardsTo[*tcpOptions](t, TCPEndpoint) - testForwardsTo[*labeledOptions](t, LabeledTunnel) -} diff --git a/config/http.go b/config/http.go deleted file mode 100644 index 36f5bb88..00000000 --- a/config/http.go +++ /dev/null @@ -1,181 +0,0 @@ -package config - -import ( - "crypto/x509" - "net/http" - "net/url" - - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -type HTTPEndpointOption interface { - ApplyHTTP(cfg *httpOptions) -} - -type httpOptionFunc func(cfg *httpOptions) - -func (of httpOptionFunc) ApplyHTTP(cfg *httpOptions) { - of(cfg) -} - -// HTTPEndpoint constructs a new set options for a HTTP endpoint. -// -// https://ngrok.com/docs/http/ -func HTTPEndpoint(opts ...HTTPEndpointOption) Tunnel { - cfg := httpOptions{} - for _, opt := range opts { - opt.ApplyHTTP(&cfg) - } - return &cfg -} - -type httpOptions struct { - // Common tunnel configuration options. - commonOpts - - // The scheme that this edge should use. - // Defaults to [SchemeHTTPS]. - Scheme Scheme - - // The fqdn to request for this edge - Domain string - - // Note: these are "the old way", and shouldn't actually be used. Their - // setters are both deprecated. - Hostname string - Subdomain string - - // If non-nil, start a goroutine which runs this http server - // accepting connections from the http tunnel - // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. - httpServer *http.Server - - // Certificates to use for client authentication at the ngrok edge. - MutualTLSCA []*x509.Certificate - // Enable gzip compression for HTTP responses. - Compression bool - // Convert incoming websocket connections to TCP-like streams. - WebsocketTCPConversion bool - // Reject requests when 5XX responses exceed this ratio. - // Disabled when 0. - CircuitBreaker float64 - - // Headers to be added to or removed from all requests at the ngrok edge. - RequestHeaders *headers - // Headers to be added to or removed from all responses at the ngrok edge. - ResponseHeaders *headers - - // Auto-rewrite host header on ListenAndForward? - RewriteHostHeader bool - - // Credentials for basic authentication. - // If empty, basic authentication is disabled. - BasicAuth []basicAuth - // OAuth configuration. - // If nil, OAuth is disabled. - OAuth *oauthOptions - // OIDC configuration. - // If nil, OIDC is disabled. - OIDC *oidcOptions - // WebhookVerification configuration. - // If nil, WebhookVerification is disabled. - WebhookVerification *webhookVerification - // UserAgentFilter configuration - // If nil, UserAgentFilter is disabled - UserAgentFilter *userAgentFilter -} - -func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint { - opts := &proto.HTTPEndpoint{ - URL: cfg.URL, - Domain: cfg.Domain, - Hostname: cfg.Hostname, - Subdomain: cfg.Subdomain, - } - - if cfg.Compression { - opts.Compression = &pb.MiddlewareConfiguration_Compression{} - } - - if cfg.WebsocketTCPConversion { - opts.WebsocketTCPConverter = &pb.MiddlewareConfiguration_WebsocketTCPConverter{} - } - - if cfg.CircuitBreaker != 0 { - opts.CircuitBreaker = &pb.MiddlewareConfiguration_CircuitBreaker{ - ErrorThreshold: cfg.CircuitBreaker, - } - } - - opts.MutualTLSCA = mutualTLSEndpointOption(cfg.MutualTLSCA).toProtoConfig() - - opts.ProxyProto = proto.ProxyProto(cfg.commonOpts.ProxyProto) - - opts.RequestHeaders = cfg.RequestHeaders.toProtoConfig() - opts.ResponseHeaders = cfg.ResponseHeaders.toProtoConfig() - if len(cfg.BasicAuth) > 0 { - opts.BasicAuth = &pb.MiddlewareConfiguration_BasicAuth{} - for _, c := range cfg.BasicAuth { - opts.BasicAuth.Credentials = append(opts.BasicAuth.Credentials, c.toProtoConfig()) - } - } - opts.OAuth = cfg.OAuth.toProtoConfig() - opts.OIDC = cfg.OIDC.toProtoConfig() - opts.WebhookVerification = cfg.WebhookVerification.toProtoConfig() - opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig() - opts.UserAgentFilter = cfg.UserAgentFilter.toProtoConfig() - opts.TrafficPolicy = cfg.TrafficPolicy - - return opts -} - -func (cfg httpOptions) ForwardsProto() string { - return cfg.commonOpts.ForwardsProto -} - -func (cfg httpOptions) ForwardsTo() string { - return cfg.commonOpts.getForwardsTo() -} - -func (cfg *httpOptions) WithForwardsTo(url *url.URL) { - cfg.commonOpts.ForwardsTo = url.Host - if cfg.RewriteHostHeader { - WithRequestHeader("host", url.Host).ApplyHTTP(cfg) - } -} - -func (cfg httpOptions) Extra() proto.BindExtra { - return proto.BindExtra{ - Name: cfg.Name, - Metadata: cfg.Metadata, - Description: cfg.Description, - Bindings: cfg.Bindings, - PoolingEnabled: cfg.PoolingEnabled, - } -} - -func (cfg httpOptions) Proto() string { - if cfg.Scheme == "" { - return string(SchemeHTTPS) - } - return string(cfg.Scheme) -} - -func (cfg httpOptions) Opts() any { - return cfg.toProtoConfig() -} - -func (cfg httpOptions) Labels() map[string]string { - return nil -} - -func (cfg httpOptions) HTTPServer() *http.Server { - return cfg.httpServer -} - -// compile-time check that we're implementing the proper interfaces. -var _ interface { - tunnelConfigPrivate - Tunnel -} = (*httpOptions)(nil) diff --git a/config/http_handler.go b/config/http_handler.go deleted file mode 100644 index 611b91d1..00000000 --- a/config/http_handler.go +++ /dev/null @@ -1,51 +0,0 @@ -package config - -import ( - "net/http" -) - -type httpServerOption struct { - Server *http.Server -} - -type Options interface { - HTTPEndpointOption - TLSEndpointOption - TCPEndpointOption - LabeledTunnelOption - CommonOption -} - -func (opt *httpServerOption) ApplyCommon(cfg *commonOpts) { - -} - -func (opt *httpServerOption) ApplyHTTP(cfg *httpOptions) { - cfg.httpServer = opt.Server -} - -func (opt *httpServerOption) ApplyTCP(cfg *tcpOptions) { - cfg.httpServer = opt.Server -} - -func (opt *httpServerOption) ApplyTLS(cfg *tlsOptions) { - cfg.httpServer = opt.Server -} - -func (opt *httpServerOption) ApplyLabeled(cfg *labeledOptions) { - cfg.httpServer = opt.Server -} - -// WithHTTPHandler adds the provided credentials to the list of basic -// authentication credentials. -// Deprecated: Use session.ListenAndHandleHTTP instead. -func WithHTTPHandler(h http.Handler) Options { - return WithHTTPServer(&http.Server{Handler: h}) -} - -// WithHTTPServer adds the provided credentials to the list of basic -// authentication credentials. -// Deprecated: Use session.ListenAndServeHTTP instead. -func WithHTTPServer(srv *http.Server) Options { - return &httpServerOption{Server: srv} -} diff --git a/config/http_headers.go b/config/http_headers.go deleted file mode 100644 index bf812ab6..00000000 --- a/config/http_headers.go +++ /dev/null @@ -1,112 +0,0 @@ -package config - -import ( - "fmt" - "net/http" - - "golang.ngrok.com/ngrok/internal/pb" -) - -// HTTP Headers to modify at the ngrok edge. -type headers struct { - // Headers to add to requests or responses at the ngrok edge. - Added map[string]string - // Header names to remove from requests or responses at the ngrok edge. - Removed []string -} - -func (h *headers) toProtoConfig() *pb.MiddlewareConfiguration_Headers { - if h == nil { - return nil - } - - headers := &pb.MiddlewareConfiguration_Headers{ - Remove: h.Removed, - } - - for k, v := range h.Added { - headers.Add = append(headers.Add, fmt.Sprintf("%s:%s", k, v)) - } - - return headers -} - -func (h *headers) merge(other headers) *headers { - if h == nil { - h = &headers{ - Added: map[string]string{}, - Removed: []string{}, - } - } - - for k, v := range other.Added { - if existing, ok := h.Added[k]; ok { - v = fmt.Sprintf("%s;%s", existing, v) - } - h.Added[k] = v - } - - h.Removed = append(h.Removed, other.Removed...) - - return h -} - -type requestHeaders headers -type responseHeaders headers - -func (h requestHeaders) ApplyHTTP(cfg *httpOptions) { - cfg.RequestHeaders = cfg.RequestHeaders.merge(headers(h)) - -} - -func (h responseHeaders) ApplyHTTP(cfg *httpOptions) { - cfg.ResponseHeaders = cfg.ResponseHeaders.merge(headers(h)) -} - -// WithHostHeaderRewrite will automatically set the `Host` header to the one in -// the URL passed to `ListenAndForward`. Does nothing if using `Listen`. -// Defaults to `false`. -// -// If you need to set the host header to a specific value, use -// `WithRequestHeader("host", "some.host.com")` instead. -func WithHostHeaderRewrite(rewrite bool) HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - cfg.RewriteHostHeader = rewrite - }) -} - -// WithRequestHeader adds a header to all requests to this edge. -// -// https://ngrok.com/docs/http/request-headers/ -func WithRequestHeader(name, value string) HTTPEndpointOption { - return requestHeaders(headers{ - Added: map[string]string{http.CanonicalHeaderKey(name): value}, - }) -} - -// WithRequestHeader adds a header to all responses coming from this edge. -// -// https://ngrok.com/docs/http/response-headers/ -func WithResponseHeader(name, value string) HTTPEndpointOption { - return responseHeaders(headers{ - Added: map[string]string{http.CanonicalHeaderKey(name): value}, - }) -} - -// WithRemoveRequestHeader removes a header from requests to this edge. -// -// https://ngrok.com/docs/http/request-headers/ -func WithRemoveRequestHeader(name string) HTTPEndpointOption { - return requestHeaders(headers{ - Removed: []string{http.CanonicalHeaderKey(name)}, - }) -} - -// WithRemoveResponseHeader removes a header from responses from this edge. -// -// https://ngrok.com/docs/http/response-headers/ -func WithRemoveResponseHeader(name string) HTTPEndpointOption { - return responseHeaders(headers{ - Removed: []string{http.CanonicalHeaderKey(name)}, - }) -} diff --git a/config/http_headers_test.go b/config/http_headers_test.go deleted file mode 100644 index 28fea7a4..00000000 --- a/config/http_headers_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestHTTPHeaders(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "absent", - opts: HTTPEndpoint(), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - req := opts.RequestHeaders - resp := opts.RequestHeaders - - require.Nil(t, req) - require.Nil(t, resp) - }, - }, - { - name: "simple request", - opts: HTTPEndpoint( - WithRequestHeader("foo", "bar baz"), - WithRemoveRequestHeader("baz"), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - req := opts.RequestHeaders - resp := opts.ResponseHeaders - - require.NotNil(t, req) - require.Nil(t, resp) - - require.Equal(t, []string{"Foo:bar baz"}, req.Add) - require.Equal(t, []string{"Baz"}, req.Remove) - }, - }, - { - name: "multiple request", - opts: HTTPEndpoint( - WithRequestHeader("foo", "bar"), - WithRequestHeader("foo", "baz"), - WithRequestHeader("spam", "eggs"), - WithRemoveRequestHeader("qas"), - WithRemoveRequestHeader("wex"), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - req := opts.RequestHeaders - resp := opts.ResponseHeaders - - require.NotNil(t, req) - require.Nil(t, resp) - - require.ElementsMatch(t, []string{"Foo:bar;baz", "Spam:eggs"}, req.Add) - require.ElementsMatch(t, []string{"Qas", "Wex"}, req.Remove) - }, - }, - { - name: "simple response", - opts: HTTPEndpoint( - WithResponseHeader("foo", "bar baz"), - WithRemoveResponseHeader("baz"), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - req := opts.RequestHeaders - resp := opts.ResponseHeaders - - require.Nil(t, req) - require.NotNil(t, resp) - - require.Equal(t, []string{"Foo:bar baz"}, resp.Add) - require.Equal(t, []string{"Baz"}, resp.Remove) - }, - }, - { - name: "multiple response", - opts: HTTPEndpoint( - WithResponseHeader("foo", "bar baz"), - WithResponseHeader("spam", "eggs"), - WithRemoveResponseHeader("qas"), - WithRemoveResponseHeader("wex"), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - req := opts.RequestHeaders - resp := opts.ResponseHeaders - - require.Nil(t, req) - require.NotNil(t, resp) - - require.ElementsMatch(t, []string{"Foo:bar baz", "Spam:eggs"}, resp.Add) - require.ElementsMatch(t, []string{"Qas", "Wex"}, resp.Remove) - }, - }, - { - name: "multiple request response", - opts: HTTPEndpoint( - WithRequestHeader("foo", "bar baz"), - WithRequestHeader("spam", "eggs"), - WithRemoveRequestHeader("qas"), - WithRemoveRequestHeader("wex"), - WithResponseHeader("foo", "bar baz"), - WithResponseHeader("spam", "eggs"), - WithRemoveResponseHeader("qas"), - WithRemoveResponseHeader("wex"), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - req := opts.ResponseHeaders - resp := opts.ResponseHeaders - - require.NotNil(t, req) - require.NotNil(t, resp) - - require.ElementsMatch(t, []string{"Spam:eggs", "Foo:bar baz"}, resp.Add) - require.ElementsMatch(t, []string{"Qas", "Wex"}, resp.Remove) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/labeled.go b/config/labeled.go deleted file mode 100644 index 2f77cc7f..00000000 --- a/config/labeled.go +++ /dev/null @@ -1,95 +0,0 @@ -package config - -import ( - "net/http" - "net/url" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -type LabeledTunnelOption interface { - ApplyLabeled(cfg *labeledOptions) -} - -type labeledOptionFunc func(cfg *labeledOptions) - -func (of labeledOptionFunc) ApplyLabeled(cfg *labeledOptions) { - of(cfg) -} - -// LabeledTunnel constructs a new set options for a labeled Edge. -// -// https://ngrok.com/docs/network-edge/edges/#tunnel-group -func LabeledTunnel(opts ...LabeledTunnelOption) Tunnel { - cfg := labeledOptions{} - for _, opt := range opts { - opt.ApplyLabeled(&cfg) - } - return &cfg -} - -// Options for labeled tunnels. -type labeledOptions struct { - // Common tunnel configuration options. - commonOpts - - // A map of label, value pairs for this tunnel. - labels map[string]string - - // An HTTP Server to run traffic on - // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. - httpServer *http.Server -} - -// WithLabel adds a label to this tunnel's set of label, value pairs. -func WithLabel(label, value string) LabeledTunnelOption { - return labeledOptionFunc(func(cfg *labeledOptions) { - if cfg.labels == nil { - cfg.labels = map[string]string{} - } - - cfg.labels[label] = value - }) -} - -func (cfg labeledOptions) ForwardsProto() string { - return cfg.commonOpts.ForwardsProto -} - -func (cfg labeledOptions) ForwardsTo() string { - return cfg.commonOpts.getForwardsTo() -} - -func (cfg *labeledOptions) WithForwardsTo(url *url.URL) { - cfg.commonOpts.ForwardsTo = url.Host -} - -func (cfg labeledOptions) Extra() proto.BindExtra { - return proto.BindExtra{ - Name: cfg.Name, - Metadata: cfg.Metadata, - Description: cfg.Description, - } -} - -func (cfg labeledOptions) Proto() string { - return "" -} - -func (cfg labeledOptions) Opts() any { - return nil -} - -func (cfg labeledOptions) Labels() map[string]string { - return cfg.labels -} - -func (cfg labeledOptions) HTTPServer() *http.Server { - return cfg.httpServer -} - -// compile-time check that we're implementing the proper interfaces. -var _ interface { - tunnelConfigPrivate - Tunnel -} = (*labeledOptions)(nil) diff --git a/config/labeled_test.go b/config/labeled_test.go deleted file mode 100644 index 8fe8ff56..00000000 --- a/config/labeled_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package config - -import ( - "testing" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestLabeled(t *testing.T) { - cases := testCases[*labeledOptions, proto.LabelOptions]{ - { - name: "simple", - opts: LabeledTunnel(WithLabel("foo", "bar")), - expectLabels: ptr(map[string]*string{ - "foo": ptr("bar"), - }), - expectProto: ptr(""), - expectNilOpts: true, - }, - { - name: "multiple", - opts: LabeledTunnel( - WithLabel("foo", "bar"), - WithLabel("spam", "eggs"), - ), - expectProto: ptr(""), - expectLabels: ptr(map[string]*string{ - "foo": ptr("bar"), - "spam": ptr("eggs"), - }), - expectNilOpts: true, - }, - { - name: "withForwardsTo", - opts: LabeledTunnel(WithLabel("foo", "bar"), WithForwardsTo("localhost:8080")), - expectLabels: ptr(map[string]*string{ - "foo": ptr("bar"), - }), - expectForwardsTo: ptr("localhost:8080"), - expectProto: ptr(""), - expectNilOpts: true, - }, - { - name: "withMetadata", - opts: LabeledTunnel(WithLabel("foo", "bar"), WithMetadata("choochoo")), - expectLabels: ptr(map[string]*string{ - "foo": ptr("bar"), - }), - expectExtra: &matchBindExtra{ - Metadata: ptr("choochoo"), - }, - expectProto: ptr(""), - expectNilOpts: true, - }, - } - - cases.runAll(t) -} diff --git a/config/metadata_test.go b/config/metadata_test.go deleted file mode 100644 index f1ceb9a3..00000000 --- a/config/metadata_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package config - -import ( - "testing" -) - -func testMetadata[T tunnelConfigPrivate, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, -) { - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - - cases := testCases[T, any]{ - { - name: "absent", - opts: optsFunc(), - expectExtra: &matchBindExtra{ - Metadata: ptr(""), - }, - }, - { - name: "with metadata", - opts: optsFunc(WithMetadata("Hello, world!")), - expectExtra: &matchBindExtra{ - Metadata: ptr("Hello, world!"), - }, - }, - } - - cases.runAll(t) -} - -func TestMetadata(t *testing.T) { - testMetadata[*httpOptions](t, HTTPEndpoint) - testMetadata[*tlsOptions](t, TLSEndpoint) - testMetadata[*tcpOptions](t, TCPEndpoint) - testMetadata[*labeledOptions](t, LabeledTunnel) -} diff --git a/config/mutual_tls.go b/config/mutual_tls.go deleted file mode 100644 index 49c84338..00000000 --- a/config/mutual_tls.go +++ /dev/null @@ -1,42 +0,0 @@ -package config - -import ( - "crypto/x509" - "encoding/pem" - - "golang.ngrok.com/ngrok/internal/pb" -) - -type mutualTLSEndpointOption []*x509.Certificate - -// WithMutualTLSCA adds a list of [x509.Certificate]'s to use for mutual TLS -// authentication. -// These will be used to authenticate client certificates for requests at the -// ngrok edge. -// -// https://ngrok.com/docs/http/mutual-tls/ -func WithMutualTLSCA(certs ...*x509.Certificate) interface { - HTTPEndpointOption - TLSEndpointOption -} { - return mutualTLSEndpointOption(certs) -} - -func (opt mutualTLSEndpointOption) ApplyHTTP(opts *httpOptions) { - opts.MutualTLSCA = append(opts.MutualTLSCA, opt...) -} - -func (opt mutualTLSEndpointOption) ApplyTLS(opts *tlsOptions) { - opts.MutualTLSCA = append(opts.MutualTLSCA, opt...) -} - -func (cfg mutualTLSEndpointOption) toProtoConfig() *pb.MiddlewareConfiguration_MutualTLS { - if cfg == nil { - return nil - } - opts := &pb.MiddlewareConfiguration_MutualTLS{} - for _, cert := range cfg { - opts.MutualTlsCa = append(opts.MutualTlsCa, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})...) - } - return opts -} diff --git a/config/mutual_tls_test.go b/config/mutual_tls_test.go deleted file mode 100644 index b85e595c..00000000 --- a/config/mutual_tls_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package config - -import ( - "crypto/x509" - "encoding/pem" - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" - - _ "embed" -) - -//go:embed testdata/ngrok.ca.crt -var ngrokCA []byte - -func testMutualTLS[T tunnelConfigPrivate, O any, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, - getMTLS func(*O) *pb.MiddlewareConfiguration_MutualTLS, -) { - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - - certDer, _ := pem.Decode(ngrokCA) - cert, err := x509.ParseCertificate(certDer.Bytes) - if err != nil { - panic("failed to parse certificate: " + err.Error()) - } - - cases := testCases[T, O]{ - { - name: "absent", - opts: optsFunc(), - expectOpts: func(t *testing.T, opts *O) { - actual := getMTLS(opts) - require.Nil(t, actual) - }, - }, - { - name: "with mtls", - opts: optsFunc(WithMutualTLSCA(cert)), - expectOpts: func(t *testing.T, opts *O) { - actual := getMTLS(opts) - require.NotNil(t, actual) - require.Equal(t, ngrokCA, actual.MutualTlsCa) - }, - }, - } - - cases.runAll(t) -} - -func TestMutualTLS(t *testing.T) { - testMutualTLS[*httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_MutualTLS { - return opts.MutualTLSCA - }) - testMutualTLS[*tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) *pb.MiddlewareConfiguration_MutualTLS { - return opts.MutualTLSAtEdge - }) -} diff --git a/config/name.go b/config/name.go deleted file mode 100644 index 2db42991..00000000 --- a/config/name.go +++ /dev/null @@ -1,28 +0,0 @@ -package config - -type nameOption string - -func WithName(name string) interface { - HTTPEndpointOption - TCPEndpointOption - TLSEndpointOption - LabeledTunnelOption -} { - return nameOption(name) -} - -func (opt nameOption) ApplyHTTP(opts *httpOptions) { - opts.Name = string(opt) -} - -func (opt nameOption) ApplyTLS(opts *tlsOptions) { - opts.Name = string(opt) -} - -func (opt nameOption) ApplyTCP(opts *tcpOptions) { - opts.Name = string(opt) -} - -func (opt nameOption) ApplyLabeled(opts *labeledOptions) { - opts.Name = string(opt) -} diff --git a/config/oauth.go b/config/oauth.go deleted file mode 100644 index 1ed658b9..00000000 --- a/config/oauth.go +++ /dev/null @@ -1,95 +0,0 @@ -package config - -import ( - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -type OAuthOption func(cfg *oauthOptions) - -// oauthOptions configuration -type oauthOptions struct { - // The OAuth provider to use - Provider string - // Email addresses of users to authorize. - AllowEmails []string - // Email domains of users to authorize. - AllowDomains []string - // OAuth scopes to request from the provider. - Scopes []string - // OAuth custom app ID - ClientID string - // OAuth custom app secret - ClientSecret proto.ObfuscatedString -} - -// Construct a new OAuth provider with the given name. -func oauthProvider(name string) *oauthOptions { - return &oauthOptions{ - Provider: name, - } -} - -// WithOAuthClientID provides a client ID for custom OAuth apps. -func WithOAuthClientID(id string) OAuthOption { - return func(cfg *oauthOptions) { - cfg.ClientID = id - } -} - -// WithOAuthClientSecret provides a client secret for custom OAuth apps. -func WithOAuthClientSecret(secret string) OAuthOption { - return func(cfg *oauthOptions) { - cfg.ClientSecret = proto.ObfuscatedString(secret) - } -} - -// Append email addresses to the list of allowed emails. -func WithAllowOAuthEmail(addr ...string) OAuthOption { - return func(cfg *oauthOptions) { - cfg.AllowEmails = append(cfg.AllowEmails, addr...) - } -} - -// Append email domains to the list of allowed domains. -func WithAllowOAuthDomain(domain ...string) OAuthOption { - return func(cfg *oauthOptions) { - cfg.AllowDomains = append(cfg.AllowDomains, domain...) - } -} - -// Append scopes to the list of scopes to request. -func WithOAuthScope(scope ...string) OAuthOption { - return func(cfg *oauthOptions) { - cfg.Scopes = append(cfg.Scopes, scope...) - } -} - -func (oauth *oauthOptions) toProtoConfig() *pb.MiddlewareConfiguration_OAuth { - if oauth == nil { - return nil - } - - return &pb.MiddlewareConfiguration_OAuth{ - Provider: string(oauth.Provider), - ClientId: oauth.ClientID, - ClientSecret: oauth.ClientSecret.PlainText(), - AllowEmails: oauth.AllowEmails, - AllowDomains: oauth.AllowDomains, - Scopes: oauth.Scopes, - } -} - -// WithOAuth configures this edge with the the given OAuth provider. -// Overwrites any previously-set OAuth configuration. -// -// https://ngrok.com/docs/http/oauth/ -func WithOAuth(provider string, opts ...OAuthOption) HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - oauth := oauthProvider(provider) - for _, opt := range opts { - opt(oauth) - } - cfg.OAuth = oauth - }) -} diff --git a/config/oauth_test.go b/config/oauth_test.go deleted file mode 100644 index 8f181b90..00000000 --- a/config/oauth_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestOAuth(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "absent", - opts: HTTPEndpoint(), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.OAuth - require.Nil(t, actual) - }, - }, - { - name: "simple", - opts: HTTPEndpoint(WithOAuth("google")), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.OAuth - require.NotNil(t, actual) - require.Equal(t, "google", actual.Provider) - }, - }, - { - name: "with options", - opts: HTTPEndpoint( - WithOAuth("google", - WithOAuthScope("foo"), - WithOAuthScope("bar", "baz"), - WithAllowOAuthDomain("ngrok.com", "google.com"), - WithAllowOAuthDomain("github.com", "facebook.com"), - WithAllowOAuthEmail("user1@gmail.com", "user2@gmail.com"), - WithAllowOAuthEmail("user3@gmail.com"), - ), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.OAuth - require.NotNil(t, actual) - require.Equal(t, "google", actual.Provider) - require.ElementsMatch(t, []string{"foo", "bar", "baz"}, actual.Scopes) - require.ElementsMatch(t, []string{"user1@gmail.com", "user2@gmail.com", "user3@gmail.com"}, actual.AllowEmails) - require.ElementsMatch(t, []string{"ngrok.com", "google.com", "github.com", "facebook.com"}, actual.AllowDomains) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/oidc.go b/config/oidc.go deleted file mode 100644 index 9eff9d2d..00000000 --- a/config/oidc.go +++ /dev/null @@ -1,73 +0,0 @@ -package config - -import ( - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -type OIDCOption func(cfg *oidcOptions) - -type oidcOptions struct { - IssuerURL string - ClientID string - ClientSecret proto.ObfuscatedString - AllowEmails []string - AllowDomains []string - Scopes []string -} - -func (oidc *oidcOptions) toProtoConfig() *pb.MiddlewareConfiguration_OIDC { - if oidc == nil { - return nil - } - - return &pb.MiddlewareConfiguration_OIDC{ - IssuerUrl: oidc.IssuerURL, - ClientId: oidc.ClientID, - ClientSecret: oidc.ClientSecret.PlainText(), - AllowEmails: oidc.AllowEmails, - AllowDomains: oidc.AllowDomains, - Scopes: oidc.Scopes, - } -} - -// WithOIDC configures this edge with the the given OIDC provider. -// Overwrites any previously-set OIDC configuration. -// -// https://ngrok.com/docs/http/openid-connect/ -func WithOIDC(issuerURL string, clientID string, clientSecret string, opts ...OIDCOption) HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - oidc := &oidcOptions{ - IssuerURL: issuerURL, - ClientID: clientID, - ClientSecret: proto.ObfuscatedString(clientSecret), - } - - for _, opt := range opts { - opt(oidc) - } - - cfg.OIDC = oidc - }) -} - -// Append email addresses to the list of allowed emails. -func WithAllowOIDCEmail(addr ...string) OIDCOption { - return func(cfg *oidcOptions) { - cfg.AllowEmails = append(cfg.AllowEmails, addr...) - } -} - -// Append email domains to the list of allowed domains. -func WithAllowOIDCDomain(domain ...string) OIDCOption { - return func(cfg *oidcOptions) { - cfg.AllowDomains = append(cfg.AllowDomains, domain...) - } -} - -// Append scopes to the list of scopes to request. -func WithOIDCScope(scope ...string) OIDCOption { - return func(cfg *oidcOptions) { - cfg.Scopes = append(cfg.Scopes, scope...) - } -} diff --git a/config/oidc_test.go b/config/oidc_test.go deleted file mode 100644 index c796ca2e..00000000 --- a/config/oidc_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestOIDC(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "absent", - opts: HTTPEndpoint(), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.OAuth - require.Nil(t, actual) - }, - }, - { - name: "simple", - opts: HTTPEndpoint(WithOIDC("https://google.com", "foo", "bar")), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.OIDC - require.NotNil(t, actual) - require.Equal(t, "https://google.com", actual.IssuerUrl) - require.Equal(t, "foo", actual.ClientId) - require.Equal(t, "bar", actual.ClientSecret) - }, - }, - { - name: "with options", - opts: HTTPEndpoint( - WithOIDC("google", "foo", "bar", - WithOIDCScope("foo"), - WithOIDCScope("bar", "baz"), - WithAllowOIDCDomain("ngrok.com", "google.com"), - WithAllowOIDCDomain("github.com"), - WithAllowOIDCEmail("user1@gmail.com", "user2@gmail.com"), - WithAllowOIDCEmail("user3@gmail.com"), - ), - ), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.OIDC - require.NotNil(t, actual) - require.ElementsMatch(t, []string{"foo", "bar", "baz"}, actual.Scopes) - require.ElementsMatch(t, []string{"user1@gmail.com", "user2@gmail.com", "user3@gmail.com"}, actual.AllowEmails) - require.ElementsMatch(t, []string{"ngrok.com", "google.com", "github.com"}, actual.AllowDomains) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/policy.go b/config/policy.go deleted file mode 100644 index b6756b96..00000000 --- a/config/policy.go +++ /dev/null @@ -1,96 +0,0 @@ -package config - -import ( - "encoding/json" - "errors" - "fmt" - - "gopkg.in/yaml.v3" - - po "golang.ngrok.com/ngrok/policy" -) - -type policy po.Policy -type rule po.Rule -type action po.Action -type trafficPolicy string - -// WithTrafficPolicy configures this edge with the provided policy configuration -// passed as a json or yaml string and overwrites any previously-set traffic policy. -// https://ngrok.com/docs/http/traffic-policy -func WithTrafficPolicy(str string) interface { - HTTPEndpointOption - TLSEndpointOption - TCPEndpointOption -} { - if !isJsonString(str) && !isYamlStr(str) { - panic(errors.New("provided string is neither valid JSON nor valid YAML")) - } - return trafficPolicy(str) -} - -// WithPolicyString is deprecated, use WithTrafficPolicy instead. -// https://ngrok.com/docs/http/traffic-policy/ -func WithPolicyString(str string) interface { - HTTPEndpointOption - TLSEndpointOption - TCPEndpointOption -} { - return WithTrafficPolicy(str) -} - -func (p trafficPolicy) ApplyTLS(opts *tlsOptions) { - opts.TrafficPolicy = string(p) -} - -func (p trafficPolicy) ApplyHTTP(opts *httpOptions) { - opts.TrafficPolicy = string(p) -} - -func (p trafficPolicy) ApplyTCP(opts *tcpOptions) { - opts.TrafficPolicy = string(p) -} - -func isJsonString(jsonStr string) bool { - var js json.RawMessage - return json.Unmarshal([]byte(jsonStr), &js) == nil -} - -func isYamlStr(yamlStr string) bool { - var yml map[string]any - return yaml.Unmarshal([]byte(yamlStr), &yml) == nil -} - -// WithPolicy is deprecated, use WithTrafficPolicy instead. -// https://ngrok.com/docs/http/traffic-policy/ -func WithPolicy(p po.Policy) interface { - HTTPEndpointOption - TLSEndpointOption - TCPEndpointOption -} { - ret := policy(p) - - return &ret -} - -func (p *policy) ApplyTLS(opts *tlsOptions) { - opts.TrafficPolicy = policyToString(p) -} - -func (p *policy) ApplyHTTP(opts *httpOptions) { - opts.TrafficPolicy = policyToString(p) -} - -func (p *policy) ApplyTCP(opts *tcpOptions) { - opts.TrafficPolicy = policyToString(p) -} - -// policyToString converts the policy into a JSON string representation. This is to help remap Policy to TrafficPolicy. -func policyToString(p *policy) string { - val, err := json.Marshal(p) - if err != nil { - panic(errors.New(fmt.Sprintf("failed to parse action configuration due to error: %s", err.Error()))) - } - - return string(val) -} diff --git a/config/policy_test.go b/config/policy_test.go deleted file mode 100644 index 5ef3ffd1..00000000 --- a/config/policy_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" - po "golang.ngrok.com/ngrok/policy" -) - -func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, - getPolicies func(*O) string, -) { - - // putting yaml string up here as the formatting makes the test - // cases messy - yamlPolicy := `--- -inbound: - - name: DenyAll - actions: - - type: deny - config: - status_code: 446 -` - - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - - cases := testCases[T, O]{ - { - name: "absent", - opts: optsFunc(), - expectOpts: func(t *testing.T, opts *O) { - actual := getPolicies(opts) - require.Empty(t, actual) - }, - }, - { - name: "with policy", - opts: optsFunc( - WithPolicy( - po.Policy{ - Inbound: []po.Rule{ - { - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []po.Action{ - {Type: "deny"}, - }, - }, - { - Name: "logFooHeader", - Expressions: []string{"'foo' in req.Headers"}, - Actions: []po.Action{ - { - Type: "log", - Config: map[string]any{"metadata": map[string]any{"key": "val"}}, - }, - }, - }, - }, - Outbound: []po.Rule{ - { - Name: "InternalErrorWhenFailed", - Expressions: []string{ - "res.StatusCode <= '0'", - "res.StatusCode >= '300'", - }, - Actions: []po.Action{ - { - Type: "custom-response", - Config: map[string]any{"status_code": 500}, - }, - }, - }, - }, - }, - ), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getPolicies(opts) - require.NotEmpty(t, actual) - require.Equal(t, actual, "{\"inbound\":[{\"name\":\"denyPUT\",\"expressions\":[\"req.Method == 'PUT'\"],\"actions\":[{\"type\":\"deny\"}]},{\"name\":\"logFooHeader\",\"expressions\":[\"'foo' in req.Headers\"],\"actions\":[{\"type\":\"log\",\"config\":{\"metadata\":{\"key\":\"val\"}}}]}],\"outbound\":[{\"name\":\"InternalErrorWhenFailed\",\"expressions\":[\"res.StatusCode \\u003c= '0'\",\"res.StatusCode \\u003e= '300'\"],\"actions\":[{\"type\":\"custom-response\",\"config\":{\"status_code\":500}}]}]}") - }, - }, - { - name: "with valid JSON policy string", - opts: optsFunc( - WithTrafficPolicy(` - { - "inbound":[ - { - "name":"denyPut", - "expressions":["req.Method == 'PUT'"], - "actions":[{"type":"deny"}] - }, - { - "name":"logFooHeader", - "expressions":["'foo' in req.Headers"], - "actions":[ - {"type":"log","config":{"metadata":{"key":"val"}}} - ] - } - ], - "outbound":[ - { - "name":"500ForFailures", - "expressions":["res.StatusCode <= 0", "res.StatusCode >= 300"], - "actions":[{"type":"custom-response", "config":{"status_code":500}}] - } - ] - }`)), - expectOpts: func(t *testing.T, opts *O) { - actual := getPolicies(opts) - require.NotEmpty(t, actual) - require.Equal(t, actual, ` - { - "inbound":[ - { - "name":"denyPut", - "expressions":["req.Method == 'PUT'"], - "actions":[{"type":"deny"}] - }, - { - "name":"logFooHeader", - "expressions":["'foo' in req.Headers"], - "actions":[ - {"type":"log","config":{"metadata":{"key":"val"}}} - ] - } - ], - "outbound":[ - { - "name":"500ForFailures", - "expressions":["res.StatusCode <= 0", "res.StatusCode >= 300"], - "actions":[{"type":"custom-response", "config":{"status_code":500}}] - } - ] - }`) - }, - }, - { - name: "with valid YAML policy string", - opts: optsFunc( - WithTrafficPolicy(yamlPolicy)), - expectOpts: func(t *testing.T, opts *O) { - actual := getPolicies(opts) - require.NotEmpty(t, actual) - require.Equal(t, actual, yamlPolicy) - }, - }, - } - - cases.runAll(t) -} - -func TestPolicy(t *testing.T) { - testPolicy[*httpOptions](t, HTTPEndpoint, - func(h *proto.HTTPEndpoint) string { - return h.TrafficPolicy - }) - testPolicy[*tcpOptions](t, TCPEndpoint, - func(h *proto.TCPEndpoint) string { - return h.TrafficPolicy - }) - testPolicy[*tlsOptions](t, TLSEndpoint, - func(h *proto.TLSEndpoint) string { - return h.TrafficPolicy - }) -} diff --git a/config/tls_termination.go b/config/tls_termination.go deleted file mode 100644 index 6707688b..00000000 --- a/config/tls_termination.go +++ /dev/null @@ -1,85 +0,0 @@ -package config - -type TLSTerminationLocation int - -const ( - // Terminate TLS at the ngrok edge. The backend will receive a plaintext - // stream. - TLSAtEdge TLSTerminationLocation = iota - // Terminate TLS in the ngrok library. The library will receive the - // handshake and perform TLS termination, and the backend will receive the - // plaintext stream. - // TODO: export this once implemented - tlsAtLibrary -) - -type tlsTermination struct { - location TLSTerminationLocation - key []byte - cert []byte -} - -func (tt tlsTermination) ApplyTLS(cfg *tlsOptions) { - switch tt.location { - case tlsAtLibrary: - cfg.KeyPEM = nil - cfg.CertPEM = nil - // TODO: implement this in the tunnel `Accept` call. - panic("automatic tls termination in-app is not yet supported") - case TLSAtEdge: - cfg.terminateAtEdge = true - cfg.KeyPEM = tt.key - cfg.CertPEM = tt.cert - return - } -} - -type TLSTerminationOption func(tt *tlsTermination) - -// WithTLSTermination arranges for incoming TLS connections to be automatically terminated. -// The backend will then receive plaintext streams, rather than raw TLS connections. -// Defaults to terminating TLS at the ngrok edge with an automatically-provisioned keypair. -// -// https://ngrok.com/docs/tls/tls-termination/ -func WithTLSTermination(opts ...TLSTerminationOption) TLSEndpointOption { - tt := tlsTermination{ - location: TLSAtEdge, - key: []byte{}, - cert: []byte{}, - } - for _, opt := range opts { - opt(&tt) - } - return tt -} - -// WithTermination sets the key and certificate in PEM format for TLS termination at the ngrok -// edge. -// -// Deprecated: Use WithCustomEdgeTermination instead. -func WithTermination(certPEM, keyPEM []byte) TLSEndpointOption { - return tlsOptionFunc(func(cfg *tlsOptions) { - cfg.terminateAtEdge = true - cfg.CertPEM = certPEM - cfg.KeyPEM = keyPEM - }) -} - -// WithTLSTerminationAt determines where TLS termination should occur. -// Currently, only `TLSAtEdge` is supported. -func WithTLSTerminationAt(location TLSTerminationLocation) TLSTerminationOption { - return TLSTerminationOption(func(cfg *tlsTermination) { - cfg.location = location - }) -} - -// WithTLSTerminationKeyPair sets a custom key and certificate in PEM format for -// TLS termination. -// If terminating at the ngrok edge, this uploads the private key and -// certificate to the ngrok servers. -func WithTLSTerminationKeyPair(certPEM, keyPEM []byte) TLSTerminationOption { - return TLSTerminationOption(func(cfg *tlsTermination) { - cfg.cert = certPEM - cfg.key = keyPEM - }) -} diff --git a/config/tls_termination_test.go b/config/tls_termination_test.go deleted file mode 100644 index 6cbb510c..00000000 --- a/config/tls_termination_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestTLSTermination(t *testing.T) { - cases := testCases[*tlsOptions, proto.TLSEndpoint]{ - { - name: "absent", - opts: TLSEndpoint(), - expectOpts: func(t *testing.T, opts *proto.TLSEndpoint) { - require.Nil(t, opts.TLSTermination) - }, - }, - { - name: "with termination", - opts: TLSEndpoint(WithTermination([]byte("cert"), []byte("key"))), - expectOpts: func(t *testing.T, opts *proto.TLSEndpoint) { - actual := opts.TLSTermination - require.NotNil(t, actual) - require.Equal(t, []byte("cert"), actual.Cert) - require.Equal(t, []byte("key"), actual.Key) - }, - }, - { - name: "with new termination", - opts: TLSEndpoint(WithTLSTermination()), - expectOpts: func(t *testing.T, opts *proto.TLSEndpoint) { - actual := opts.TLSTermination - require.NotNil(t, actual) - require.Equal(t, []byte{}, actual.Cert) - require.Equal(t, []byte{}, actual.Key) - }, - }, - { - name: "with new nil termination", - opts: TLSEndpoint(WithTLSTermination(WithTLSTerminationKeyPair(nil, nil))), - expectOpts: func(t *testing.T, opts *proto.TLSEndpoint) { - actual := opts.TLSTermination - require.NotNil(t, actual) - require.Equal(t, []byte(nil), actual.Cert) - require.Equal(t, []byte(nil), actual.Key) - }, - }, - { - name: "with new custom termination", - opts: TLSEndpoint(WithTLSTermination(WithTLSTerminationKeyPair([]byte("cert"), []byte("key")))), - expectOpts: func(t *testing.T, opts *proto.TLSEndpoint) { - actual := opts.TLSTermination - require.NotNil(t, actual) - require.Equal(t, []byte("cert"), actual.Cert) - require.Equal(t, []byte("key"), actual.Key) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/user_agent_filter.go b/config/user_agent_filter.go deleted file mode 100644 index a3253683..00000000 --- a/config/user_agent_filter.go +++ /dev/null @@ -1,88 +0,0 @@ -package config - -import ( - "golang.ngrok.com/ngrok/internal/pb" -) - -// UserAgentFilter is a pair of strings slices that allow/deny traffic to an endpoint -type userAgentFilter struct { - // slice of regex strings for allowed user agents - Allow []string - // slice of regex strings for denied user agents - Deny []string -} - -// WithAllowUserAgentFilter is a deprecated alias for [WithAllowUserAgent]. -// -// Deprecated: use [WithAllowUserAgent] instead. -func WithAllowUserAgentFilter(allow ...string) HTTPEndpointOption { - return WithAllowUserAgent(allow...) -} - -// WithDenyUserAgentFilter is a deprecated alias for [WithDenyUserAgent]. -// -// Deprecated: use [WithDenyUserAgent] instead. -func WithDenyUserAgentFilter(allow ...string) HTTPEndpointOption { - return WithDenyUserAgent(allow...) -} - -// WithAllowUserAgent adds user agent filtering to the endpoint. -// -// The allow argument is a regular expressions for the user-agent -// header to allow -// -// Any invalid regular expression will result in an error when creating the tunnel. -// -// https://ngrok.com/docs/http/user-agent-filter/ -// ERR_NGROK_2090 for invalid allow/deny on connect. -// ERR_NGROK_3211 The server does not authorize requests from your user-agent -// ERR_NGROK_9022 Your account is not authorized to use user agent filtering. -func WithAllowUserAgent(allow ...string) HTTPEndpointOption { - return &userAgentFilter{ - // slice of regex strings for allowed user agents - Allow: allow, - } -} - -// WithDenyUserAgent adds user agent filtering to the endpoint. -// -// The deny argument is a regular expressions to -// deny, with allows taking precedence over denies. -// -// Any invalid regular expression will result in an error when creating the tunnel. -// -// https://ngrok.com/docs/http/user-agent-filter/ -// ERR_NGROK_2090 for invalid allow/deny on connect. -// ERR_NGROK_3211 The server does not authorize requests from your user-agent -// ERR_NGROK_9022 Your account is not authorized to use user agent filtering. -func WithDenyUserAgent(deny ...string) HTTPEndpointOption { - return &userAgentFilter{ - // slice of regex strings for denied user agents - Deny: deny, - } -} - -func (b *userAgentFilter) toProtoConfig() *pb.MiddlewareConfiguration_UserAgentFilter { - if b == nil { - return nil - } - return &pb.MiddlewareConfiguration_UserAgentFilter{ - Allow: b.Allow, - Deny: b.Deny, - } -} - -func (b *userAgentFilter) merge(set userAgentFilter) *userAgentFilter { - if b == nil { - b = &userAgentFilter{} - } - - b.Allow = append(b.Allow, set.Allow...) - b.Deny = append(b.Deny, set.Deny...) - - return b -} - -func (b userAgentFilter) ApplyHTTP(opts *httpOptions) { - opts.UserAgentFilter = opts.UserAgentFilter.merge(b) -} diff --git a/config/user_agent_filter_test.go b/config/user_agent_filter_test.go deleted file mode 100644 index 9976ccfe..00000000 --- a/config/user_agent_filter_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func testUserAgentFilter[T tunnelConfigPrivate, O any, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, - getUserAgentFilter func(*O) *pb.MiddlewareConfiguration_UserAgentFilter, -) { - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - cases := testCases[T, O]{ - { - name: "test empty", - opts: optsFunc(), - expectOpts: func(t *testing.T, opts *O) { - actual := getUserAgentFilter(opts) - require.Nil(t, actual) - }, - }, - { - name: "test allow", - opts: optsFunc( - WithAllowUserAgent(`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getUserAgentFilter(opts) - require.NotNil(t, actual) - require.Nil(t, actual.Deny) - require.Empty(t, actual.Deny) - require.Equal(t, []string{`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`}, actual.Allow) - }, - }, - { - name: "test deny", - opts: optsFunc( - WithDenyUserAgent(`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getUserAgentFilter(opts) - require.NotNil(t, actual) - require.Nil(t, actual.Allow) - require.Equal(t, []string{`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`}, actual.Deny) - }, - }, - { - name: "test allow and deny", - opts: optsFunc( - WithAllowUserAgent(`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`), - WithDenyUserAgent(`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getUserAgentFilter(opts) - require.NotNil(t, actual) - require.Equal(t, []string{`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`}, actual.Allow) - require.Equal(t, []string{`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`}, actual.Deny) - }, - }, - { - name: "test multiple", - opts: optsFunc( - WithAllowUserAgent(`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`), - WithDenyUserAgent(`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`), - WithAllowUserAgent(`(Pingdom2\.com_bot_version_)(\d+)\.(\d+)`), - WithDenyUserAgent(`(Pingdom2\.com_bot_version_)(\d+)\.(\d+)`), - ), - expectOpts: func(t *testing.T, opts *O) { - actual := getUserAgentFilter(opts) - require.NotNil(t, actual) - require.Equal(t, []string{`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`, `(Pingdom2\.com_bot_version_)(\d+)\.(\d+)`}, actual.Allow) - require.Equal(t, []string{`(Pingdom\.com_bot_version_)(\d+)\.(\d+)`, `(Pingdom2\.com_bot_version_)(\d+)\.(\d+)`}, actual.Deny) - }, - }, - } - - cases.runAll(t) -} - -func TestUserAgentFilter(t *testing.T) { - testUserAgentFilter[*httpOptions](t, HTTPEndpoint, - func(h *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_UserAgentFilter { - return h.UserAgentFilter - }) -} diff --git a/config/webhook_verification.go b/config/webhook_verification.go deleted file mode 100644 index 2a4f166a..00000000 --- a/config/webhook_verification.go +++ /dev/null @@ -1,36 +0,0 @@ -package config - -import ( - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -// Configuration for webhook verification. -type webhookVerification struct { - // The webhook provider - Provider string - // The secret for verifying webhooks from this provider. - Secret proto.ObfuscatedString -} - -func (wv *webhookVerification) toProtoConfig() *pb.MiddlewareConfiguration_WebhookVerification { - if wv == nil { - return nil - } - return &pb.MiddlewareConfiguration_WebhookVerification{ - Provider: wv.Provider, - Secret: wv.Secret.PlainText(), - } -} - -// WithWebhookVerification configures webhook verification for this edge. -// -// https://ngrok.com/docs/http/webhook-verification/ -func WithWebhookVerification(provider string, secret string) HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - cfg.WebhookVerification = &webhookVerification{ - Provider: provider, - Secret: proto.ObfuscatedString(secret), - } - }) -} diff --git a/config/webhook_verification_test.go b/config/webhook_verification_test.go deleted file mode 100644 index de8ce441..00000000 --- a/config/webhook_verification_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestWebhookVerification(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "absent", - opts: HTTPEndpoint(), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.WebhookVerification - require.Nil(t, actual) - }, - }, - { - name: "single", - opts: HTTPEndpoint(WithWebhookVerification("google", "domoarigato")), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - actual := opts.WebhookVerification - require.NotNil(t, actual) - require.Equal(t, "google", actual.Provider) - require.Equal(t, "domoarigato", actual.Secret) - }, - }, - } - - cases.runAll(t) -} diff --git a/config/websocket_tcp_conversion.go b/config/websocket_tcp_conversion.go deleted file mode 100644 index 2e764c36..00000000 --- a/config/websocket_tcp_conversion.go +++ /dev/null @@ -1,10 +0,0 @@ -package config - -// WithWebsocketTCPConversion enables the websocket-to-tcp converter. -// -// https://ngrok.com/docs/http/websocket-tcp-converter/ -func WithWebsocketTCPConversion() HTTPEndpointOption { - return httpOptionFunc(func(cfg *httpOptions) { - cfg.WebsocketTCPConversion = true - }) -} diff --git a/config/websocket_tcp_conversion_test.go b/config/websocket_tcp_conversion_test.go deleted file mode 100644 index 4c1092b1..00000000 --- a/config/websocket_tcp_conversion_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "golang.ngrok.com/ngrok/internal/tunnel/proto" -) - -func TestWebsocketTCPConversion(t *testing.T) { - cases := testCases[*httpOptions, proto.HTTPEndpoint]{ - { - name: "absent", - opts: HTTPEndpoint(), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.Nil(t, opts.WebsocketTCPConverter) - }, - }, - { - name: "converted", - opts: HTTPEndpoint(WithWebsocketTCPConversion()), - expectOpts: func(t *testing.T, opts *proto.HTTPEndpoint) { - require.NotNil(t, opts.WebsocketTCPConverter) - }, - }, - } - - cases.runAll(t) -} diff --git a/defaults.go b/defaults.go new file mode 100644 index 00000000..8045b1a1 --- /dev/null +++ b/defaults.go @@ -0,0 +1,21 @@ +package ngrok + +import ( + "context" + "os" +) + +// A default Agent instance to use when you don't need a custom one. +var DefaultAgent, _ = NewAgent( + WithAuthtoken(os.Getenv("NGROK_AUTHTOKEN")), +) + +// Listen is equivalent to DefaultAgent.Listen(). +func Listen(ctx context.Context, opts ...EndpointOption) (EndpointListener, error) { + return DefaultAgent.Listen(ctx, opts...) +} + +// Forward is sugar for DefaultAgent.Forward(). +func Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (EndpointForwarder, error) { + return DefaultAgent.Forward(ctx, upstream, opts...) +} diff --git a/endpoint.go b/endpoint.go new file mode 100644 index 00000000..f3d9aa6e --- /dev/null +++ b/endpoint.go @@ -0,0 +1,118 @@ +package ngrok + +import ( + "context" + "crypto/tls" + "net/url" + "sync" +) + +// Endpoint is the interface implemented by both EndpointListener and +// EndpointForwarder. +type Endpoint interface { + // Agent returns the Agent that created this Endpoint. + Agent() Agent + + // PoolingEnabled returns whether the endpoint supports pooling set by WithPoolingEnabled. + PoolingEnabled() bool + + // Bindings returns the endpoint's bindings set by WithBindings + Bindings() []string + + // Close() is equivalent to for CloseWithContext(context.Background()) + Close() error + + // CloseWithContext closes the endpoint with the provided context. + CloseWithContext(context.Context) error + + // Description returns the endpoint's human-readable description set by WithDescription. + Description() string + + // Done returns a channel that is closed when the endpoint stops. + Done() <-chan struct{} + + // ID returns the unique endpoint identifier assigned by the ngrok cloud service. + ID() string + + // Metadata returns the endpoint's opaque user-defined metadata set by WithMetadata. + Metadata() string + + // Protocol is sugar for URL().Scheme + Protocol() string + + // AgentTLSTermination returns the TLS config that the agent uses to terminate TLS connections. + AgentTLSTermination() *tls.Config + + // TrafficPolicy returns the traffic policy for the endpoint. + TrafficPolicy() string + + // URL returns the Endpoint's URL + URL() *url.URL +} + +// baseEndpoint implements the common functionality for both EndpointListener and +// EndpointForwarder. +type baseEndpoint struct { + agent Agent + poolingEnabled bool + bindings []string + description string + id string + metadata string + agentTLSConfig *tls.Config // TLS config for termination + trafficPolicy string + endpointURL url.URL + doneChannel chan struct{} + doneOnce *sync.Once +} + +func (e *baseEndpoint) Agent() Agent { + return e.agent +} + +func (e *baseEndpoint) PoolingEnabled() bool { + return e.poolingEnabled +} + +func (e *baseEndpoint) Bindings() []string { + return e.bindings +} + +func (e *baseEndpoint) Description() string { + return e.description +} + +func (e *baseEndpoint) Done() <-chan struct{} { + return e.doneChannel +} + +func (e *baseEndpoint) ID() string { + return e.id +} + +func (e *baseEndpoint) Metadata() string { + return e.metadata +} + +func (e *baseEndpoint) Protocol() string { + return e.endpointURL.Scheme +} + +func (e *baseEndpoint) AgentTLSTermination() *tls.Config { + return e.agentTLSConfig +} + +func (e *baseEndpoint) TrafficPolicy() string { + return e.trafficPolicy +} + +func (e *baseEndpoint) URL() *url.URL { + return &e.endpointURL +} + +// signalDone safely closes the done channel using sync.Once +func (e *baseEndpoint) signalDone() { + e.doneOnce.Do(func() { + close(e.doneChannel) + }) +} diff --git a/endpoint_options.go b/endpoint_options.go new file mode 100644 index 00000000..1449136f --- /dev/null +++ b/endpoint_options.go @@ -0,0 +1,286 @@ +package ngrok + +import ( + "crypto/tls" + "fmt" + "net/url" + + "golang.ngrok.com/ngrok/v2/internal/legacy/config" +) + +// EndpointOption is a functional option used to configure endpoints. +type EndpointOption func(*endpointOpts) + +// endpointOpts stores configuration for endpoints. +type endpointOpts struct { + poolingEnabled bool + bindings []string + description string + metadata string + agentTLSConfig *tls.Config + trafficPolicy string + url string + upstreamProtocol string + upstreamURL string + upstreamTLSClientConfig *tls.Config + proxyProtoVersion config.ProxyProtoVersion +} + +// defaultEndpointOpts returns the default options for an endpoint. +func defaultEndpointOpts() *endpointOpts { + return &endpointOpts{} +} + +// WithPoolingEnabled controls whether the endpoint supports connection pooling. +// +// See https://ngrok.com/docs/universal-gateway/endpoint-pooling/ +func WithPoolingEnabled(pool bool) EndpointOption { + return func(opts *endpointOpts) { + opts.poolingEnabled = pool + } +} + +// WithBindings sets the endpoint's bindings. +// +// See https://ngrok.com/docs/universal-gateway/bindings/ +func WithBindings(bindings ...string) EndpointOption { + return func(opts *endpointOpts) { + opts.bindings = bindings + } +} + +// WithDescription sets a human-readable description for the endpoint. +func WithDescription(desc string) EndpointOption { + return func(opts *endpointOpts) { + opts.description = desc + } +} + +// WithMetadata sets opaque, machine-readable metadata for the endpoint. +func WithMetadata(meta string) EndpointOption { + return func(opts *endpointOpts) { + opts.metadata = meta + } +} + +// WithAgentTLSTermination sets a TLS configuration that the agent will use to +// terminate connections received on the Endpoint. +// +// See https://ngrok.com/docs/agent/agent-tls-termination/ +func WithAgentTLSTermination(config *tls.Config) EndpointOption { + return func(opts *endpointOpts) { + opts.agentTLSConfig = config + } +} + +// WithTrafficPolicy defines the Endpoint's Traffic Policy. +// +// See https://ngrok.com/docs/traffic-policy/ +func WithTrafficPolicy(policy string) EndpointOption { + return func(opts *endpointOpts) { + opts.trafficPolicy = policy + } +} + +// WithURL defines the Endpoint's URL. +func WithURL(urlSpec string) EndpointOption { + return func(opts *endpointOpts) { + opts.url = urlSpec + } +} + +// temporary while we're wrapping the legacy api. remove this after we no longer +// call it so that new schemes +type endpointURLScheme string + +const ( + httpScheme endpointURLScheme = "http" + httpsScheme endpointURLScheme = "https" + tcpScheme endpointURLScheme = "tcp" + tlsScheme endpointURLScheme = "tls" +) + +// configureEndpoint creates the appropriate tunnel configuration based on the URL +// scheme and options +func configureEndpoint(scheme endpointURLScheme, endpointOpts *endpointOpts) (config.Tunnel, error) { + switch scheme { + case httpScheme, httpsScheme: + return configureHTTPEndpoint(endpointOpts) + case tcpScheme: + return configureTCPEndpoint(endpointOpts) + case tlsScheme: + return configureTLSEndpoint(endpointOpts) + default: + return nil, fmt.Errorf("unsupported endpoint URL scheme: %s", scheme) + } +} + +// configureHTTPEndpoint configures an HTTP/HTTPS endpoint with options +func configureHTTPEndpoint(endpointOpts *endpointOpts) (config.Tunnel, error) { + configOpts := []config.HTTPEndpointOption{} + + // Set URL and scheme if specified + if endpointOpts.url != "" { + configOpts = append(configOpts, config.WithURL(endpointOpts.url)) + + // Parse the URL and always set scheme explicitly + if parsedURL, err := url.Parse(endpointOpts.url); err == nil { + // Determine scheme - default to HTTPS if not specified or is https + scheme := config.SchemeHTTPS + if parsedURL.Scheme == "http" { + scheme = config.SchemeHTTP + } + configOpts = append(configOpts, config.WithScheme(scheme)) + } + } + + // Set pooling if enabled + if endpointOpts.poolingEnabled { + configOpts = append(configOpts, config.WithPoolingEnabled(endpointOpts.poolingEnabled)) + } + + // Add bindings if specified + if len(endpointOpts.bindings) > 0 { + configOpts = append(configOpts, config.WithBindings(endpointOpts.bindings...)) + } + + // Add metadata if specified + if len(endpointOpts.metadata) > 0 { + configOpts = append(configOpts, config.WithMetadata(endpointOpts.metadata)) + } + + // Add description if specified + if len(endpointOpts.description) > 0 { + configOpts = append(configOpts, config.WithDescription(endpointOpts.description)) + } + + // Set traffic policy if specified + if len(endpointOpts.trafficPolicy) > 0 { + configOpts = append(configOpts, config.WithTrafficPolicy(endpointOpts.trafficPolicy)) + } + + // Set proxy protocol if specified + if endpointOpts.proxyProtoVersion != config.ProxyProtoNone { + configOpts = append(configOpts, config.WithProxyProto(endpointOpts.proxyProtoVersion)) + } + + // Set upstream protocol if specified (maps to AppProtocol in the SDK) + if endpointOpts.upstreamProtocol != "" { + configOpts = append(configOpts, config.WithAppProtocol(endpointOpts.upstreamProtocol)) + } + + // Note: upstreamVerifyTLSCAs is not currently supported in the legacy SDK + // We'll need to implement this in a future version + + return config.HTTPEndpoint(configOpts...), nil +} + +// configureTCPEndpoint configures a TCP endpoint with options +func configureTCPEndpoint(endpointOpts *endpointOpts) (config.Tunnel, error) { + configOpts := []config.TCPEndpointOption{} + + // Set URL if specified + if endpointOpts.url != "" { + configOpts = append(configOpts, config.WithURL(endpointOpts.url)) + } + + // Set pooling if enabled + if endpointOpts.poolingEnabled { + configOpts = append(configOpts, config.WithPoolingEnabled(endpointOpts.poolingEnabled)) + } + + // Add bindings if specified + if len(endpointOpts.bindings) > 0 { + configOpts = append(configOpts, config.WithBindings(endpointOpts.bindings...)) + } + + // Add metadata if specified + if len(endpointOpts.metadata) > 0 { + configOpts = append(configOpts, config.WithMetadata(endpointOpts.metadata)) + } + + // Add description if specified + if len(endpointOpts.description) > 0 { + configOpts = append(configOpts, config.WithDescription(endpointOpts.description)) + } + + // Set traffic policy if specified + if len(endpointOpts.trafficPolicy) > 0 { + configOpts = append(configOpts, config.WithTrafficPolicy(endpointOpts.trafficPolicy)) + } + + // Set proxy protocol if specified + if endpointOpts.proxyProtoVersion != config.ProxyProtoNone { + configOpts = append(configOpts, config.WithProxyProto(endpointOpts.proxyProtoVersion)) + } + + return config.TCPEndpoint(configOpts...), nil +} + +// configureTLSEndpoint configures a TLS endpoint with options +func configureTLSEndpoint(endpointOpts *endpointOpts) (config.Tunnel, error) { + configOpts := []config.TLSEndpointOption{} + + // Set URL if specified + if endpointOpts.url != "" { + configOpts = append(configOpts, config.WithURL(endpointOpts.url)) + } + + // Set pooling if enabled + if endpointOpts.poolingEnabled { + configOpts = append(configOpts, config.WithPoolingEnabled(endpointOpts.poolingEnabled)) + } + + // Add bindings if specified + if len(endpointOpts.bindings) > 0 { + configOpts = append(configOpts, config.WithBindings(endpointOpts.bindings...)) + } + + // Add metadata if specified + if len(endpointOpts.metadata) > 0 { + configOpts = append(configOpts, config.WithMetadata(endpointOpts.metadata)) + } + + // Add description if specified + if len(endpointOpts.description) > 0 { + configOpts = append(configOpts, config.WithDescription(endpointOpts.description)) + } + + // Set traffic policy if specified + if len(endpointOpts.trafficPolicy) > 0 { + configOpts = append(configOpts, config.WithTrafficPolicy(endpointOpts.trafficPolicy)) + } + + // Set proxy protocol if specified + if endpointOpts.proxyProtoVersion != config.ProxyProtoNone { + configOpts = append(configOpts, config.WithProxyProto(endpointOpts.proxyProtoVersion)) + } + + return config.TLSEndpoint(configOpts...), nil +} + +// determineURLScheme examines the URL to determine what scheme to use +func determineURLScheme(urlStr string) (endpointURLScheme, error) { + if urlStr == "" { + // Default to HTTPS if no URL specified + return "https", nil + } + + parsedURL, err := url.Parse(urlStr) + if err != nil { + return "", fmt.Errorf("invalid URL format: %w", err) + } + + // If no scheme is specified, default to HTTPS + if parsedURL.Scheme == "" { + return "https", nil + } + + // Validate supported schemes + switch parsedURL.Scheme { + case "http", "https", "tcp", "tls": + return endpointURLScheme(parsedURL.Scheme), nil + default: + return "", fmt.Errorf("unsupported endpoint URL scheme: %s", parsedURL.Scheme) + } +} diff --git a/errors.go b/errors.go index 066fdab9..8452fc79 100644 --- a/errors.go +++ b/errors.go @@ -1,164 +1,43 @@ package ngrok import ( - "fmt" - "net/url" - "strings" + "errors" + + "golang.ngrok.com/ngrok/v2/internal/legacy" ) -// Error is an error enriched with a specific ErrorCode. -// All ngrok error codes are documented at https://ngrok.com/docs/errors. -// -// An [Error] can be extracted from a generic error using [errors.As]. -// -// Example: -// -// var nerr ngrok.Error -// if errors.As(err, &nerr) { -// fmt.Printf("%s: %s\n", nerr.ErrorCode(), nerr.Msg()) -// } +// Error is a custom error type that includes a unique ngrok error code. +// All errors that are returned by the ngrok cloud service are of this type. type Error interface { error - // Msg returns the error string without the error code. - Msg() string - // ErrorCode returns the ngrok error code, if one exists. - ErrorCode() string -} - -// Errors arising from authentication failure. -type errAuthFailed struct { - // Whether the error was generated by the remote server, or in the sending - // of the authentication request. - Remote bool - // The underlying error. - Inner error -} - -func (e errAuthFailed) Error() string { - var msg string - if e.Remote { - msg = "authentication failed" - } else { - msg = "failed to send authentication request" - } - - return fmt.Sprintf("%s: %v", msg, e.Inner) -} - -func (e errAuthFailed) Unwrap() error { - return e.Inner -} - -func (e errAuthFailed) Is(target error) bool { - _, ok := target.(errAuthFailed) - return ok -} - -// The error returned by [Tunnel]'s [net.Listener.Accept] method. -type errAcceptFailed struct { - // The underlying error. - Inner error -} - -func (e errAcceptFailed) Error() string { - return fmt.Sprintf("failed to accept connection: %v", e.Inner) -} - -func (e errAcceptFailed) Unwrap() error { - return e.Inner -} - -func (e errAcceptFailed) Is(target error) bool { - _, ok := target.(errAcceptFailed) - return ok -} - -// Errors arising from a failure to start a tunnel. -type errListen struct { - // The underlying error. - Inner error -} - -func (e errListen) Error() string { - return fmt.Sprintf("failed to start tunnel: %v", e.Inner) -} - -func (e errListen) Unwrap() error { - return e.Inner -} - -func (e errListen) Is(target error) bool { - _, ok := target.(errListen) - return ok -} - -// Errors arising from a failure to construct a [golang.org/x/net/proxy.Dialer] from a [url.URL]. -type errProxyInit struct { - // The provided proxy URL. - URL *url.URL - // The underlying error. - Inner error -} - -func (e errProxyInit) Error() string { - return fmt.Sprintf("failed to construct proxy dialer from \"%s\": %v", e.URL.String(), e.Inner) -} - -func (e errProxyInit) Unwrap() error { - return e.Inner + // The unique ngrok error code + Code() string } -func (e errProxyInit) Is(target error) bool { - _, ok := target.(errProxyInit) - return ok +// errorAdapter implements our Error interface by wrapping ngrok.Error +type errorAdapter struct { + ngrokErr legacy.Error } -// Error arising from a failure to dial the ngrok server. -type errSessionDial struct { - // The address to which a connection was attempted. - Addr string - // The underlying error. - Inner error +func (e *errorAdapter) Code() string { + return e.ngrokErr.ErrorCode() } -func (e errSessionDial) Error() string { - return fmt.Sprintf("failed to dial ngrok server with address \"%s\": %v", e.Addr, e.Inner) +func (e *errorAdapter) Error() string { + return e.ngrokErr.Error() } -func (e errSessionDial) Unwrap() error { - return e.Inner -} - -func (e errSessionDial) Is(target error) bool { - _, ok := target.(errSessionDial) - return ok -} - -// Generic ngrok error that requires no parsing -type ngrokError struct { - Message string - ErrCode string -} - -const errUrl = "https://ngrok.com/docs/errors" +// wrapError returns the original error or wraps it if it's a ngrok.Error -func (m ngrokError) Error() string { - out := m.Message - if m.ErrCode != "" { - out = fmt.Sprintf("%s\n\n%s/%s", out, errUrl, strings.ToLower(m.ErrCode)) +func wrapError(err error) error { + if err == nil { + return nil } - return out -} - -func (m ngrokError) Msg() string { - return m.Message -} -func (m ngrokError) ErrorCode() string { - return m.ErrCode -} + var ngrokErr legacy.Error + if errors.As(err, &ngrokErr) { + return &errorAdapter{ngrokErr: ngrokErr} + } -func (e ngrokError) Is(target error) bool { - _, ok := target.(ngrokError) - return ok + return err } diff --git a/events.go b/events.go new file mode 100644 index 00000000..8910552b --- /dev/null +++ b/events.go @@ -0,0 +1,102 @@ +package ngrok + +import "time" + +// EventType represents the type of event that occurred +type EventType int + +const ( + EventTypeAgentConnectSucceeded EventType = iota + EventTypeAgentDisconnected + EventTypeAgentHeartbeatReceived +) + +func (t EventType) String() string { + return [...]string{ + "AgentConnectSucceeded", + "AgentDisconnected", + "AgentHeartbeatReceived", + }[t] +} + +// Event is the interface implemented by all event types +type Event interface { + EventType() EventType + Timestamp() time.Time +} + +// baseEvent provides common functionality for all events +type baseEvent struct { + Type EventType + OccurredAt time.Time +} + +func (e baseEvent) EventType() EventType { return e.Type } +func (e baseEvent) Timestamp() time.Time { return e.OccurredAt } + +// EventHandler is the function type for event callbacks. EventHandlers must not +// block. If you would like to run operations on an event that will block or +// fail, instead write your handler to either non-blockingly push the Event onto +// a channel or spin up a goroutine. +type EventHandler func(Event) + +// EventAgentConnectSucceeded is emitted when an agent successfully establishes a connection +type EventAgentConnectSucceeded struct { + baseEvent + Agent Agent + Session AgentSession +} + +// EventAgentDisconnected is emitted when an agent disconnects +type EventAgentDisconnected struct { + baseEvent + Agent Agent + Session AgentSession + Error error +} + +// EventAgentHeartbeatReceived is emitted when a heartbeat is successful +type EventAgentHeartbeatReceived struct { + baseEvent + Agent Agent + Session AgentSession + Latency time.Duration +} + +// newAgentConnectSucceeded creates a new EventAgentConnectSucceeded event +func newAgentConnectSucceeded(agent Agent, session AgentSession) *EventAgentConnectSucceeded { + return &EventAgentConnectSucceeded{ + baseEvent: baseEvent{ + Type: EventTypeAgentConnectSucceeded, + OccurredAt: time.Now(), + }, + Agent: agent, + Session: session, + } +} + +// newAgentDisconnected creates a new EventAgentDisconnected event +func newAgentDisconnected(agent Agent, session AgentSession, err error) *EventAgentDisconnected { + return &EventAgentDisconnected{ + baseEvent: baseEvent{ + Type: EventTypeAgentDisconnected, + OccurredAt: time.Now(), + }, + Agent: agent, + Session: session, + Error: err, + } +} + +// newAgentHeartbeatReceived creates a new EventAgentHeartbeatReceived event +func newAgentHeartbeatReceived(agent Agent, session AgentSession, latency time.Duration) *EventAgentHeartbeatReceived { + return &EventAgentHeartbeatReceived{ + baseEvent: baseEvent{ + Type: EventTypeAgentHeartbeatReceived, + OccurredAt: time.Now(), + }, + Agent: agent, + Session: session, + Latency: latency, + } +} diff --git a/events_test.go b/events_test.go new file mode 100644 index 00000000..e16a861f --- /dev/null +++ b/events_test.go @@ -0,0 +1,64 @@ +package ngrok + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestEventTypeString(t *testing.T) { + tests := []struct { + eventType EventType + expected string + }{ + {EventTypeAgentConnectSucceeded, "AgentConnectSucceeded"}, + {EventTypeAgentDisconnected, "AgentDisconnected"}, + {EventTypeAgentHeartbeatReceived, "AgentHeartbeatReceived"}, + } + + for _, test := range tests { + assert.Equal(t, test.expected, test.eventType.String()) + } +} + +func TestBaseEvent(t *testing.T) { + now := time.Now() + be := baseEvent{ + Type: EventTypeAgentConnectSucceeded, + OccurredAt: now, + } + + assert.Equal(t, EventTypeAgentConnectSucceeded, be.EventType()) + assert.Equal(t, now, be.Timestamp()) +} + +func TestEventCreation(t *testing.T) { + // Create a mock agent and session for testing + agent := &agent{} + session := &agentSession{} + + // Test EventAgentConnectSucceeded creation + connectEvent := newAgentConnectSucceeded(agent, session) + assert.Equal(t, EventTypeAgentConnectSucceeded, connectEvent.EventType()) + assert.NotZero(t, connectEvent.Timestamp()) + assert.Equal(t, agent, connectEvent.Agent) + assert.Equal(t, session, connectEvent.Session) + + // Test EventAgentDisconnected creation + expectedErr := assert.AnError + disconnectEvent := newAgentDisconnected(agent, session, expectedErr) + assert.Equal(t, EventTypeAgentDisconnected, disconnectEvent.EventType()) + assert.NotZero(t, disconnectEvent.Timestamp()) + assert.Equal(t, agent, disconnectEvent.Agent) + assert.Equal(t, session, disconnectEvent.Session) + assert.Equal(t, expectedErr, disconnectEvent.Error) + + // Test EventAgentHeartbeatReceived creation + heartbeatEvent := newAgentHeartbeatReceived(agent, session, 100*time.Millisecond) + assert.Equal(t, EventTypeAgentHeartbeatReceived, heartbeatEvent.EventType()) + assert.NotZero(t, heartbeatEvent.Timestamp()) + assert.Equal(t, agent, heartbeatEvent.Agent) + assert.Equal(t, session, heartbeatEvent.Session) + assert.Equal(t, 100*time.Millisecond, heartbeatEvent.Latency) +} diff --git a/examples/fasthttp/main.go b/examples/fasthttp/main.go index 7df92a41..73d22db2 100644 --- a/examples/fasthttp/main.go +++ b/examples/fasthttp/main.go @@ -7,8 +7,7 @@ import ( "github.com/valyala/fasthttp" - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" + "golang.ngrok.com/ngrok/v2" ) func main() { @@ -18,14 +17,11 @@ func main() { } func run(ctx context.Context) error { - tun, err := ngrok.Listen(ctx, - config.HTTPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ) + ln, err := ngrok.Listen(ctx) if err != nil { return err } - log.Println("tunnel created:", tun.URL()) + log.Println("endpoint online", ln.URL()) var serv fasthttp.Server @@ -33,7 +29,7 @@ func run(ctx context.Context) error { fmt.Fprintf(ctx, "Hello! You're requesting %q", ctx.RequestURI()) } - err = serv.Serve(tun) + err = serv.Serve(ln) if err != nil { return err } diff --git a/examples/forward/main.go b/examples/forward/main.go new file mode 100644 index 00000000..0c2877a3 --- /dev/null +++ b/examples/forward/main.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + + "golang.ngrok.com/ngrok/v2" +) + +// how to invoke this example: +// ./forward https://mydomain.ngrok.app http://localhost:8080 +func main() { + if err := run(context.Background(), os.Args[1], os.Args[2]); err != nil { + log.Fatal(err) + } +} + +func run(ctx context.Context, from string, to string) error { + // Forward using the agent's Forward method + fwd, err := ngrok.Forward(ctx, + ngrok.WithUpstream(to), + ngrok.WithURL(from), + ) + if err != nil { + return err + } + + fmt.Println("endpoint online: forwarding from", fwd.URL(), "to", to) + + // forwarding lasts indefinitely unless you explicitly stop it so this will + // never return unless there's an unrecoverable error like running out of + // memory or file descriptors + // Wait for the forwarding to complete + <-fwd.Done() + return nil +} diff --git a/examples/go.mod b/examples/go.mod index a38b223f..01be45ec 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -4,8 +4,8 @@ go 1.21 require ( github.com/valyala/fasthttp v1.56.0 - golang.ngrok.com/ngrok v1.11.0 - golang.ngrok.com/ngrok/log/slog v0.0.0-20241014162652-57e91a614efd + golang.ngrok.com/ngrok/v2 v2.0.0 + golang.ngrok.com/ngrok/v2/log/slog v0.0.0-20241014162652-57e91a614efd ) require ( @@ -22,16 +22,14 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.ngrok.com/muxado/v2 v2.0.1 // indirect golang.org/x/net v0.30.0 // indirect - golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect google.golang.org/protobuf v1.35.1 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) replace ( - golang.ngrok.com/ngrok => ../ - golang.ngrok.com/ngrok/log/log15 => ../log/log15 - golang.ngrok.com/ngrok/log/slog => ../log/slog + golang.ngrok.com/ngrok/v2 => ../ + golang.ngrok.com/ngrok/v2/log/log15 => ../log/log15 + golang.ngrok.com/ngrok/v2/log/slog => ../log/slog + golang.ngrok.com/ngrok/v2/rpc => ../rpc ) diff --git a/examples/go.sum b/examples/go.sum index 8a720eaf..05b309df 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -39,8 +39,6 @@ golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= @@ -49,9 +47,5 @@ golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/http-full/main.go b/examples/http-full/main.go deleted file mode 100644 index 2834ddba..00000000 --- a/examples/http-full/main.go +++ /dev/null @@ -1,92 +0,0 @@ -package main - -// This example demonstrates how to create a secure HTTPS connection with all -// available configuration options illustrated. - -import ( - "context" - "fmt" - "log" - "net/http" - "time" - - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" -) - -func main() { - if err := run(context.Background()); err != nil { - log.Fatal(err) - } -} - -func run(ctx context.Context) error { - secureConnection, err := ngrok.Listen(ctx, - // secure connection configuration - config.HTTPEndpoint( - config.WithAllowCIDRString("0.0.0.0/0"), - config.WithAllowUserAgent("Mozilla/5.0.*"), - // config.WithBasicAuth("ngrok", "online1line"), - config.WithCircuitBreaker(0.5), - config.WithCompression(), - config.WithDenyCIDRString("10.1.1.1/32"), - config.WithDenyUserAgent("EvilCorp.*"), - // config.WithDomain(".ngrok.io"), - config.WithMetadata("example secure connection metadata from golang"), - // config.WithMutualTLSCA(), - // config.WithOAuth("google", - // config.WithAllowOAuthEmail("@"), - // config.WithAllowOAuthDomain(""), - // config.WithOAuthScope(""), - // ), - // config.WithOIDC("", "", "", - // config.WithAllowOIDCEmail("@"), - // config.WithAllowOIDCDomain(""), - // config.WithOIDCScope(""), - // ), - config.WithProxyProto(config.ProxyProtoNone), - config.WithRemoveRequestHeader("X-Req-Nope"), - config.WithRemoveResponseHeader("X-Res-Nope"), - config.WithRequestHeader("X-Req-Yup", "true"), - config.WithResponseHeader("X-Res-Yup", "true"), - config.WithScheme(config.SchemeHTTPS), - // config.WithWebsocketTCPConversion(), - // config.WithWebhookVerification("twilio", "asdf"), - ), - - // session configuration - // ngrok.WithAuthtoken(""), - ngrok.WithAuthtokenFromEnv(), - ngrok.WithClientInfo("go-example-full", "0.0.1"), - ngrok.WithDisconnectHandler(func(ctx context.Context, sess ngrok.Session, err error) { - log.Println("session disconnect:", sess, "error:", err) - }), - ngrok.WithHeartbeatHandler(func(ctx context.Context, sess ngrok.Session, latency time.Duration) { - log.Println("session heartbeat:", sess, "latency:", latency) - }), - ngrok.WithMetadata("go-example-full"), - ngrok.WithRestartHandler(func(ctx context.Context, sess ngrok.Session) error { - log.Println("session restart:", sess) - return nil - }), - ngrok.WithStopHandler(func(ctx context.Context, sess ngrok.Session) error { - log.Println("session stop:", sess) - return nil - }), - ngrok.WithUpdateHandler(func(ctx context.Context, sess ngrok.Session) error { - log.Println("session update:", sess) - return nil - }), - ) - if err != nil { - return err - } - - log.Println("secure connection created:", secureConnection.URL()) - - return http.Serve(secureConnection, http.HandlerFunc(handler)) -} - -func handler(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello from ngrok-go!\n\nThe time is now: ", time.Now().String()) -} diff --git a/examples/http/main.go b/examples/http/main.go index 4c14a289..e622928e 100644 --- a/examples/http/main.go +++ b/examples/http/main.go @@ -1,5 +1,3 @@ -// A simple HTTP service. - package main import ( @@ -8,8 +6,7 @@ import ( "log" "net/http" - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" + "golang.ngrok.com/ngrok/v2" ) func main() { @@ -19,16 +16,14 @@ func main() { } func run(ctx context.Context) error { - ln, err := ngrok.Listen(ctx, - config.HTTPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ) + ln, err := ngrok.Listen(ctx) if err != nil { return err } - log.Println("Ingress established at:", ln.URL()) + log.Println("endpoint online", ln.URL()) + // Serve HTTP traffic on the ngrok endpoint return http.Serve(ln, http.HandlerFunc(handler)) } diff --git a/examples/logging/main.go b/examples/logging/main.go deleted file mode 100644 index d674daa1..00000000 --- a/examples/logging/main.go +++ /dev/null @@ -1,66 +0,0 @@ -// Setting up a custom logger. -// Takes the desired log level as the first CLI argument. - -package main - -import ( - "context" - "fmt" - "log" - "net/http" - "os" - - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" - ngrok_log "golang.ngrok.com/ngrok/log" -) - -func usage(bin string) { - log.Fatalf("Usage: %s ", bin) -} - -func main() { - if len(os.Args) != 2 { - usage(os.Args[0]) - } - if err := run(context.Background(), os.Args[1]); err != nil { - log.Fatal(err) - } -} - -// Simple logger that forwards to the Go standard logger. -type logger struct { - lvl ngrok_log.LogLevel -} - -func (l *logger) Log(ctx context.Context, lvl ngrok_log.LogLevel, msg string, data map[string]interface{}) { - if lvl > l.lvl { - return - } - lvlName, _ := ngrok_log.StringFromLogLevel(lvl) - log.Printf("[%s] %s %v", lvlName, msg, data) -} - -func run(ctx context.Context, lvlName string) error { - lvl, err := ngrok_log.LogLevelFromString(lvlName) - if err != nil { - return err - } - - ln, err := ngrok.Listen(ctx, - config.HTTPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ngrok.WithLogger(&logger{lvl}), - ) - if err != nil { - return err - } - - log.Println("Ingress established at:", ln.URL()) - - return http.Serve(ln, http.HandlerFunc(handler)) -} - -func handler(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello from ngrok-go!") -} diff --git a/examples/multiple-endpoints/main.go b/examples/multiple-endpoints/main.go new file mode 100644 index 00000000..d93391c8 --- /dev/null +++ b/examples/multiple-endpoints/main.go @@ -0,0 +1,135 @@ +package main + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "log" + "net" + "net/http" + + "golang.ngrok.com/ngrok/v2" +) + +func main() { + if err := run(context.Background()); err != nil { + log.Fatal(err) + } +} + +func run(ctx context.Context) error { + // Create a single ngrok agent + agent, err := ngrok.NewAgent() + if err != nil { + return err + } + defer agent.Disconnect() + + // Start HTTPS endpoint + httpsListener, err := agent.Listen(ctx, ngrok.WithURL("https://")) + if err != nil { + return fmt.Errorf("HTTPS endpoint error: %v", err) + } + log.Println("HTTPS endpoint online:", httpsListener.URL()) + + // Start TLS endpoint + tlsListener, err := agent.Listen(ctx, ngrok.WithURL("tls://")) + if err != nil { + return fmt.Errorf("TLS endpoint error: %v", err) + } + log.Println("TLS endpoint online:", tlsListener.URL()) + + // Start TCP endpoint + tcpListener, err := agent.Listen(ctx, ngrok.WithURL("tcp://")) + if err != nil { + return fmt.Errorf("TCP endpoint error: %v", err) + } + log.Println("TCP endpoint online:", tcpListener.URL()) + + // Start HTTP server in a goroutine + go serveHTTP(httpsListener) + + // Start TLS server in a goroutine + go serveTLS(tlsListener) + + // Start TCP server in a goroutine + go serveTCP(ctx, tcpListener) + + // Display summary of all endpoints + log.Println("All endpoints are now online:") + for _, endpoint := range agent.Endpoints() { + log.Printf("- %s", endpoint.URL()) + } + + // Block forever + select {} +} + +func serveHTTP(listener net.Listener) { + log.Println("HTTP server exited:", http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello from ngrok-go HTTPS endpoint!") + }))) +} + +func serveTLS(listener net.Listener) { + config := &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // This would normally return a real certificate + // For this example, we'll let the ngrok edge handle TLS termination + return &tls.Certificate{}, nil + }, + } + + // Wrap listener with TLS + tlsListener := tls.NewListener(listener, config) + log.Println("TLS server started") + + // Accept connections + for { + conn, err := tlsListener.Accept() + if err != nil { + log.Println("TLS accept error:", err) + break + } + + log.Println("accepted TLS connection from", conn.RemoteAddr()) + + go func(c net.Conn) { + defer c.Close() + io.WriteString(c, "Hello from ngrok-go TLS endpoint!\n") + buf := make([]byte, 1024) + io.ReadFull(c, buf) + }(conn) + } +} + +func serveTCP(ctx context.Context, listener net.Listener) { + log.Println("TCP server started") + for { + conn, err := listener.Accept() + if err != nil { + log.Println("TCP accept error:", err) + break + } + + log.Println("accepted TCP connection from", conn.RemoteAddr()) + + go func(c net.Conn) { + defer c.Close() + + // Echo back to the client + _, err := fmt.Fprintln(c, "Hello from ngrok-go TCP endpoint!") + if err != nil { + log.Println("TCP write error:", err) + return + } + + // Copy data back to the client (echo server) + _, err = io.Copy(c, c) + if err != nil { + log.Println("TCP copy error:", err) + } + }(conn) + } +} diff --git a/examples/ngrok-forward-lite/main.go b/examples/ngrok-forward-lite/main.go deleted file mode 100644 index 4baa62c2..00000000 --- a/examples/ngrok-forward-lite/main.go +++ /dev/null @@ -1,88 +0,0 @@ -// Naïve ngrok agent implementation. -// Sets up a single listener and forwards it to another service. - -package main - -import ( - "context" - "fmt" - "log" - "net/url" - "os" - "strings" - - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" - ngrok_log "golang.ngrok.com/ngrok/log" -) - -func usage(bin string) { - log.Fatalf("Usage: %s ", bin) -} - -// Simple logger that forwards to the Go standard logger. -type logger struct { - lvl ngrok_log.LogLevel -} - -func (l *logger) Log(ctx context.Context, lvl ngrok_log.LogLevel, msg string, data map[string]interface{}) { - if lvl > l.lvl { - return - } - lvlName, _ := ngrok_log.StringFromLogLevel(lvl) - log.Printf("[%s] %s %v", lvlName, msg, data) -} - -var l *logger = &logger{ - lvl: ngrok_log.LogLevelDebug, -} - -func main() { - if len(os.Args) != 2 { - usage(os.Args[0]) - } - backend := os.Args[1] - if !strings.Contains(backend, "://") { - backend = fmt.Sprintf("tcp://%s", backend) - } - - backendUrl, err := url.Parse(backend) - if err != nil { - usage(os.Args[0]) - } - - if err := run(context.Background(), backendUrl); err != nil { - log.Fatal(err) - } -} - -func run(ctx context.Context, backend *url.URL) error { - sess, err := ngrok.Connect(ctx, - ngrok.WithAuthtokenFromEnv(), - ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}), - ) - if err != nil { - return err - } - - for { - fwd, err := sess.ListenAndForward(ctx, - backend, - config.HTTPEndpoint(), - ) - if err != nil { - return err - } - - l.Log(ctx, ngrok_log.LogLevelInfo, "ingress established", map[string]any{ - "url": fwd.URL(), - }) - - err = fwd.Wait() - if err == nil { - return nil - } - l.Log(ctx, ngrok_log.LogLevelWarn, "accept error. now setting up a new forwarder.", - map[string]any{"err": err}) - } -} diff --git a/examples/ngrok-http-lite/main.go b/examples/ngrok-http-lite/main.go deleted file mode 100644 index 20060ba2..00000000 --- a/examples/ngrok-http-lite/main.go +++ /dev/null @@ -1,67 +0,0 @@ -// Naïve ngrok agent implementation. -// Sets up a single listener and connects to an arbitrary HTTP server. - -package main - -import ( - "context" - "fmt" - "log" - "net/http" - "time" - - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" - ngrok_log "golang.ngrok.com/ngrok/log" -) - -// Simple logger that forwards to the Go standard logger. -type logger struct { - lvl ngrok_log.LogLevel -} - -func (l *logger) Log(ctx context.Context, lvl ngrok_log.LogLevel, msg string, data map[string]interface{}) { - if lvl > l.lvl { - return - } - lvlName, _ := ngrok_log.StringFromLogLevel(lvl) - log.Printf("[%s] %s %v", lvlName, msg, data) -} - -var l *logger = &logger{ - lvl: ngrok_log.LogLevelDebug, -} - -func main() { - // Spin up a simple HTTP server - server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello from ngrok-go!") - })} - - // Serve with listener backend - if err := run(context.Background(), server); err != nil { - log.Fatal(err) - } - - // Sleep main thread - for { - time.Sleep(5 * time.Second) - } -} - -func run(ctx context.Context, server *http.Server) error { - ln, err := ngrok.ListenAndServeHTTP(ctx, - server, - config.HTTPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}), - ) - - if err == nil { - l.Log(ctx, ngrok_log.LogLevelInfo, "ingress established", map[string]any{ - "url": ln.URL(), - }) - } - - return err -} diff --git a/examples/rpc/main.go b/examples/rpc/main.go new file mode 100644 index 00000000..f11e5005 --- /dev/null +++ b/examples/rpc/main.go @@ -0,0 +1,79 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "sync" + "time" + + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/rpc" +) + +func main() { + var stopOnce sync.Once + stopChan := make(chan struct{}, 1) + + // Create an ngrok agent with RPC handler + agent, err := ngrok.NewAgent( + ngrok.WithAuthtoken(os.Getenv("NGROK_AUTHTOKEN")), + ngrok.WithRPCHandler(func(ctx context.Context, session ngrok.AgentSession, req rpc.Request) ([]byte, error) { + // Handle different RPC methods + switch req.Method() { + case rpc.StopAgentMethod: + stopOnce.Do(func() { + go func() { + // wait a second to ensure that the process has time to + // respond to the server before shutting down + time.Sleep(time.Second) + close(stopChan) + }() + }) + // In a real application, you might want to do some cleanup + // Return nil error to acknowledge the command + return nil, nil + + case rpc.RestartAgentMethod: + log.Println("Received restart command") + // Typically you'd want to implement restart logic here + return nil, nil + + case rpc.UpdateAgentMethod: + log.Println("Received update command") + // Implement your update logic here + return nil, nil + + default: + err := fmt.Errorf("unsupported method: %s", req.Method()) + log.Println(err) + return nil, err + } + }), + ) + if err != nil { + log.Fatalf("Error creating agent: %v", err) + } + + err = agent.Connect(context.Background()) + if err != nil { + log.Fatalf("Error connecting: %v", err) + } + + log.Printf("Agent connected and ready to handle RPC commands") + + // Create an endpoint for demonstration + listener, err := agent.Listen(context.Background()) + if err != nil { + log.Fatalf("Error creating endpoint: %v", err) + } + log.Printf("Endpoint created: %s", listener.URL()) + + // wait for a stop RPC + <-stopChan + + // Disconnect agent when done + log.Println("Shutting down...") + agent.Disconnect() +} diff --git a/examples/slog/main.go b/examples/slog/main.go deleted file mode 100644 index 0a151d93..00000000 --- a/examples/slog/main.go +++ /dev/null @@ -1,70 +0,0 @@ -// Setting up a slog logger. -// Takes the desired log level as the first CLI argument. - -package main - -import ( - "context" - "fmt" - "net/http" - "os" - "strings" - - "log/slog" - - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" - slogadapter "golang.ngrok.com/ngrok/log/slog" -) - -func usage(bin string) { - fmt.Printf("Usage: %s \n", bin) - os.Exit(1) -} - -func main() { - if len(os.Args) != 2 { - usage(os.Args[0]) - } - if err := run(context.Background(), os.Args[1]); err != nil { - slog.Error("exited with error", err) - os.Exit(1) - } -} - -var programLevel = new(slog.LevelVar) // Info by default - -func run(ctx context.Context, lvlName string) error { - switch strings.ToUpper(lvlName) { - case "DEBUG": - programLevel.Set(slog.LevelDebug) - case "INFO": - programLevel.Set(slog.LevelInfo) - case "WARN": - programLevel.Set(slog.LevelWarn) - case "ERROR": - programLevel.Set(slog.LevelError) - default: - return fmt.Errorf("invalid log level: %s", lvlName) - } - opts := &slog.HandlerOptions{Level: programLevel} - logger := slog.New(slog.NewTextHandler(os.Stdout, opts)) - slog.SetDefault(logger) - - ln, err := ngrok.Listen(ctx, - config.HTTPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ngrok.WithLogger(slogadapter.NewLogger(slog.Default())), - ) - if err != nil { - return err - } - - slog.Info("Ingress established", "url", ln.URL()) - - return http.Serve(ln, http.HandlerFunc(handler)) -} - -func handler(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello from ngrok-go!") -} diff --git a/examples/tcp/main.go b/examples/tcp/main.go index 32b21d0c..e757288d 100644 --- a/examples/tcp/main.go +++ b/examples/tcp/main.go @@ -9,8 +9,7 @@ import ( "log" "net" - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" + "golang.ngrok.com/ngrok/v2" ) func main() { @@ -20,16 +19,11 @@ func main() { } func run(ctx context.Context) error { - ln, err := ngrok.Listen(ctx, - config.TCPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ) + ln, err := ngrok.Listen(ctx, ngrok.WithURL("tcp://")) if err != nil { return err } - - log.Println("Ingress established at:", ln.URL()) - + log.Println("Endpoint online", ln.URL()) return runListener(ctx, ln) } diff --git a/examples/traffic-policy/main.go b/examples/traffic-policy/main.go new file mode 100644 index 00000000..ecfc50ee --- /dev/null +++ b/examples/traffic-policy/main.go @@ -0,0 +1,59 @@ +// Example demonstrating how to use ngrok's traffic policy feature +// to implement rate limiting on an HTTP endpoint. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "golang.ngrok.com/ngrok/v2" +) + +func main() { + if err := run(context.Background()); err != nil { + log.Fatal(err) + } +} + +const trafficPolicy = ` +on_http_request: + - actions: + - type: rate-limit + config: + name: client-ip-limit + algorithm: sliding_window + capacity: 5 + rate: "60s" + bucket_key: + - conn.client_ip + - type: basic-auth + config: + credentials: + - username1:some-secret-1 + - username2:some-secret-2 + - type: add-headers + config: + headers: + authenticated-user: "${actions.ngrok.basic_auth.credentials.username}" +` + +func run(ctx context.Context) error { + // Create an HTTP listener with the traffic policy + ln, err := ngrok.Listen(ctx, + ngrok.WithTrafficPolicy(trafficPolicy), + ngrok.WithDescription("traffic policy example"), + ) + if err != nil { + return err + } + // Serve HTTP traffic on the ngrok endpoint + log.Println("Endpoint online", ln.URL()) + return http.Serve(ln, http.HandlerFunc(handler)) +} + +func handler(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %s\n", r.Header.Get("authenticated-user")) +} diff --git a/forward.go b/forward.go deleted file mode 100644 index 200fdcdb..00000000 --- a/forward.go +++ /dev/null @@ -1,202 +0,0 @@ -package ngrok - -import ( - "bytes" - "context" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - "net/url" - "strings" - "sync" - - "github.com/inconshreveable/log15/v3" - "golang.org/x/sync/errgroup" -) - -// Forwarder is a tunnel that has every connection forwarded to some URL. -type Forwarder interface { - // Information about the tunnel being forwarded - TunnelInfo - - // Close is a convenience method for calling Tunnel.CloseWithContext - // with a context that has a timeout of 5 seconds. This also allows the - // Tunnel to satisfy the io.Closer interface. - Close() error - - // CloseWithContext closes the Tunnel. Closing a tunnel is an operation - // that involves sending a "close" message over the parent session. - // Since this is a network operation, it is most correct to provide a - // context with a timeout. - CloseWithContext(context.Context) error - - // Session returns the tunnel's parent Session object that it - // was started on. - Session() Session - - // Wait blocks until the forwarding task exits (usually due to tunnel - // close), or the `context.Context` that it was started with is canceled. - Wait() error -} - -type forwarder struct { - Tunnel - mainGroup *errgroup.Group -} - -func (fwd *forwarder) Wait() error { - return fwd.mainGroup.Wait() -} - -// compile-time check that we're implementing the proper interface -var _ Forwarder = (*forwarder)(nil) - -func join(logger log15.Logger, left, right net.Conn) { - g := &sync.WaitGroup{} - g.Add(2) - go func() { - defer g.Done() - defer left.Close() - n, err := io.Copy(left, right) - logger.Debug("left join finished", "err", err, "bytes", n) - }() - go func() { - defer g.Done() - defer right.Close() - n, err := io.Copy(right, left) - logger.Debug("right join finished", "err", err, "bytes", n) - }() - g.Wait() -} - -func forwardTunnel(ctx context.Context, tun Tunnel, url *url.URL) Forwarder { - mainGroup, ctx := errgroup.WithContext(ctx) - fwdTasks := &sync.WaitGroup{} - - sess := tun.Session() - sessImpl := sess.(*sessionImpl) - logger := sessImpl.inner().Logger.New("task", "forward", "toUrl", url, "tunnelUrl", tun.URL()) - - mainGroup.Go(func() error { - for { - if ctxErr := ctx.Err(); ctxErr != nil { - return ctxErr - } - - conn, err := tun.Accept() - if err != nil { - return err - } - logger.Debug("accept connection from", "address", conn.RemoteAddr()) - fwdTasks.Add(1) - - go func() { - ngrokConn := conn.(Conn) - - backend, err := openBackend(ctx, logger, tun, ngrokConn, url) - if err != nil { - defer ngrokConn.Close() - logger.Warn("failed to connect to backend url", "error", err) - fwdTasks.Done() - return - } - - join(logger.New("url", url), ngrokConn, backend) - fwdTasks.Done() - }() - } - }) - - return &forwarder{ - Tunnel: tun, - mainGroup: mainGroup, - } -} - -// TODO: use an actual reverse proxy for http/s tunnels so that the host header gets set? -func openBackend(ctx context.Context, logger log15.Logger, tun Tunnel, tunnelConn Conn, url *url.URL) (net.Conn, error) { - host := url.Hostname() - port := url.Port() - if port == "" { - switch { - case usesTLS(url.Scheme): - port = "443" - case isHTTP(url.Scheme): - port = "80" - default: - return nil, fmt.Errorf("no default tcp port available for %s", url.Scheme) - } - logger.Debug("set default port", "port", port) - } - var appProto string - if fwdProto, ok := tun.(interface{ ForwardsProto() string }); ok { - appProto = fwdProto.ForwardsProto() - } - - // Create TLS config if necessary - var tlsConfig *tls.Config - if usesTLS(url.Scheme) { - tlsConfig = &tls.Config{ - ServerName: url.Hostname(), - Renegotiation: tls.RenegotiateOnceAsClient, - } - // If the backend is TLS and we've requested HTTP2, we'll need to - // make the backend aware of that via ALPN. - if appProto == "http2" { - logger.Debug("negotiating http/2 via alpn") - tlsConfig.NextProtos = append(tlsConfig.NextProtos, "h2", "http/1.1") - } - } - - dialer := &net.Dialer{} - address := fmt.Sprintf("%s:%s", host, port) - logger.Debug("dial backend tcp", "address", address) - - conn, err := dialer.DialContext(ctx, "tcp", address) - if err != nil { - defer tunnelConn.Close() - - // TODO: this http error is only valid for http/1.1. If the edge is - // expecting http/2, it'll end up being a proxy error instead. - // We should probably find a better way to do this that doesn't involve - // understanding http here. - if isHTTP(tunnelConn.Proto()) && appProto != "http2" { - _ = writeHTTPError(tunnelConn, err) - } - return nil, err - } - - if usesTLS(url.Scheme) && !tunnelConn.PassthroughTLS() { - logger.Debug("establishing TLS connection with backend") - return tls.Client(conn, tlsConfig), nil - } - - return conn, nil -} - -func writeHTTPError(w io.Writer, err error) error { - resp := &http.Response{} - resp.StatusCode = http.StatusBadGateway - resp.Body = io.NopCloser(bytes.NewBufferString(fmt.Sprintf("failed to connect to backend: %s", err.Error()))) - return resp.Write(w) -} - -func usesTLS(scheme string) bool { - switch strings.ToLower(scheme) { - case "https", "tls": - return true - default: - return false - } -} - -func isHTTP(scheme string) bool { - switch strings.ToLower(scheme) { - case "https", "http": - return true - default: - return false - } -} diff --git a/forward_test.go b/forward_test.go deleted file mode 100644 index b4cf8938..00000000 --- a/forward_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package ngrok - -import ( - "errors" - "io" - "net" - "testing" - - "github.com/inconshreveable/log15/v3" - "github.com/stretchr/testify/require" -) - -func TestHalfCloseJoin(t *testing.T) { - srv, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - waitSrvConn := make(chan net.Conn) - go func() { - srvConn, err := srv.Accept() - if err != nil { - panic(err) - } - waitSrvConn <- srvConn - }() - - browser, ngrokEndpoint := net.Pipe() - agent, userService := net.Pipe() - - waitJoinDone := make(chan struct{}) - go func() { - defer close(waitJoinDone) - join(log15.New(), ngrokEndpoint, agent) - }() - - _, err = browser.Write([]byte("hello world")) - require.NoError(t, err) - var b [len("hello world")]byte - _, err = userService.Read(b[:]) - require.NoError(t, err) - require.Equal(t, []byte("hello world"), b[:]) - browser.Close() - _, err = userService.Read(b[:]) - require.Truef(t, errors.Is(err, io.EOF), "io.EOF expected, got %v", err) - - <-waitJoinDone -} diff --git a/forwarder.go b/forwarder.go new file mode 100644 index 00000000..2605fcff --- /dev/null +++ b/forwarder.go @@ -0,0 +1,214 @@ +package ngrok + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/url" + "strings" + "sync" + "time" +) + +// EndpointForwarder is an Endpoint that forwards traffic to an upstream service. +type EndpointForwarder interface { + Endpoint + + // UpstreamProtocol returns the protocol used to communicate with the upstream server. + // This differs from UpstreamURL().Scheme if http2 is used. + UpstreamProtocol() string + + // UpstreamURL returns the URL that the endpoint forwards its traffic to. + UpstreamURL() url.URL + + // UpstreamTLSClientConfig returns the TLS client configuration used for upstream connections. + UpstreamTLSClientConfig() *tls.Config + + // ProxyProtocol returns the PROXY protocol version used for the endpoint. + // Returns a ProxyProtoVersion or empty string if not enabled. + ProxyProtocol() ProxyProtoVersion +} + +// endpointForwarder implements the EndpointForwarder interface. +type endpointForwarder struct { + baseEndpoint + listener *endpointListener + upstreamURL url.URL + upstreamTLSClientConfig *tls.Config + upstreamProtocol string + proxyProtocol ProxyProtoVersion + upstreamDialer Dialer +} + +// Start begins forwarding connections from the listener to the upstream URL +func (e *endpointForwarder) start(ctx context.Context) { + go e.forwardLoop(ctx) +} + +// forwardLoop is the main loop that forwards connections +func (e *endpointForwarder) forwardLoop(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + for { + select { + case <-ctx.Done(): + // Context cancelled, exit the loop + return + default: + // Accept connection with TLS termination already handled by the listener + conn, err := e.listener.Accept() + if err != nil { + // Signal done if accept fails + e.signalDone() + return + } + + // Handle the connection in a goroutine + go func() { + e.handleConnection(ctx, conn) + }() + } + } +} + +// handleConnection processes a single connection +func (e *endpointForwarder) handleConnection(ctx context.Context, conn net.Conn) { + defer conn.Close() + + // Connect to the backend server + backend, err := e.connectToBackend(ctx) + if err != nil { + // Could log the error here + return + } + defer backend.Close() + + // Copy data bidirectionally + e.join(conn, backend) +} + +// connectToBackend establishes a connection to the upstream URL +func (e *endpointForwarder) connectToBackend(ctx context.Context) (net.Conn, error) { + // Parse host and port from URL + host := e.upstreamURL.Hostname() + port := e.upstreamURL.Port() + if port == "" { + // Default ports based on scheme + switch { + case usesTLS(e.upstreamURL.Scheme): + port = "443" + case strings.ToLower(e.upstreamURL.Scheme) == "http": + port = "80" + default: + port = "80" // Default fallback + } + } + if host == "" { + host = "localhost" + } + + // Connect to the backend + address := net.JoinHostPort(host, port) + + // Use custom dialer if provided, otherwise use default dialer + dialer := e.upstreamDialer + if dialer == nil { + dialer = &net.Dialer{ + Timeout: 3 * time.Second, + } + } + + conn, err := dialer.DialContext(ctx, "tcp", address) + if err != nil { + return nil, err + } + + // For HTTPS/TLS upstreams, establish TLS + if usesTLS(e.upstreamURL.Scheme) { + config := &tls.Config{ + ServerName: e.upstreamURL.Hostname(), + } + + // Use custom TLS client config if provided + if e.upstreamTLSClientConfig != nil { + // Use the provided config as a base, but ensure ServerName is set + config = e.upstreamTLSClientConfig.Clone() + if config.ServerName == "" { + config.ServerName = e.upstreamURL.Hostname() + } + } + + // Add HTTP/2 support via ALPN if requested + if e.upstreamProtocol == "http2" { + config.NextProtos = append(config.NextProtos, "h2", "http/1.1") + } + + return tls.Client(conn, config), nil + } + + return conn, nil +} + +// join copies data bidirectionally between the two connections +func (e *endpointForwarder) join(left, right net.Conn) { + wg := &sync.WaitGroup{} + wg.Add(2) + + // Copy from left to right + go func() { + defer wg.Done() + defer right.Close() + _, _ = io.Copy(right, left) + }() + + // Copy from right to left + go func() { + defer wg.Done() + defer left.Close() + _, _ = io.Copy(left, right) + }() + + wg.Wait() +} + +func (e *endpointForwarder) Close() error { + return e.CloseWithContext(context.Background()) +} + +func (e *endpointForwarder) CloseWithContext(ctx context.Context) error { + // Close via the listener + err := e.listener.CloseWithContext(ctx) + + return wrapError(err) +} + +// UpstreamProtocol returns the protocol used to communicate with the upstream server. +func (e *endpointForwarder) UpstreamProtocol() string { + return e.upstreamProtocol +} + +// UpstreamURL returns the URL that the endpoint forwards its traffic to. +func (e *endpointForwarder) UpstreamURL() url.URL { + return e.upstreamURL +} + +// UpstreamTLSClientConfig returns the TLS client configuration used for upstream connections. +func (e *endpointForwarder) UpstreamTLSClientConfig() *tls.Config { + return e.upstreamTLSClientConfig +} + +// ProxyProtocol returns the PROXY protocol version used for the endpoint. +func (e *endpointForwarder) ProxyProtocol() ProxyProtoVersion { + return e.proxyProtocol +} + +// usesTLS checks if the provided scheme uses TLS +func usesTLS(scheme string) bool { + switch strings.ToLower(scheme) { + case "https", "tls": + return true + default: + return false + } +} diff --git a/go.mod b/go.mod index 250013a4..c3aee94d 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module golang.ngrok.com/ngrok +module golang.ngrok.com/ngrok/v2 go 1.21 @@ -9,10 +9,7 @@ require ( go.uber.org/multierr v1.11.0 golang.ngrok.com/muxado/v2 v2.0.1 golang.org/x/net v0.30.0 - golang.org/x/sync v0.8.0 google.golang.org/protobuf v1.35.1 - gopkg.in/yaml.v2 v2.4.0 - gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -24,4 +21,5 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 56a4a9d7..69c9d6f2 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,6 @@ golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= @@ -43,7 +41,5 @@ google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFyt google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go.work b/go.work deleted file mode 100644 index 1a1a1a2d..00000000 --- a/go.work +++ /dev/null @@ -1,22 +0,0 @@ -// General workspace notes: -// This is mostly here so that gopls/vscode will quit yelling at me. The modules -// seem to work just fine without it. The replace's are duplicated everywhere -// because `go mod tidy` & co seem to ignore the ones at the workspace level. -// See: https://github.com/golang/go/issues/50750. - -go 1.21 - -use ( - . - ./examples - ./log/log15 - ./log/logrus - ./log/slog - ./log/zap -) - -replace ( - golang.ngrok.com/ngrok v0.0.0 => ./ - golang.ngrok.com/ngrok/log/log15adapter v0.0.0 => ./log/log15adapter - golang.ngrok.com/ngrok/log/pgxadapter v0.0.0 => ./log/pgxadapter -) diff --git a/internal/integration_tests/agent_tls_termination_test.go b/internal/integration_tests/agent_tls_termination_test.go new file mode 100644 index 00000000..dcbe26e7 --- /dev/null +++ b/internal/integration_tests/agent_tls_termination_test.go @@ -0,0 +1,143 @@ +package integration_tests + +import ( + "crypto/tls" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// TestAgentTLSTerminationIntegration tests agent-based TLS termination with custom certificates +func TestAgentTLSTerminationIntegration(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + + // Generate test certificate + cert := CreateTestCertificate(t) + + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Setup synchronization primitives + handlerReady := testutil.NewSyncPoint() + requestComplete := testutil.NewSyncPoint() + messageChan := make(chan string, 1) + done := make(chan struct{}) + + // Create a TLS listener with TLS termination + config := &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + listener, err := agent.Listen(ctx, ngrok.WithURL("tls://"), ngrok.WithAgentTLSTermination(config)) + require.NoError(t, err, "Failed to create listener with TLS termination") + defer listener.Close() + + // Verify the agent is configured with our TLS config + require.NotNil(t, listener.AgentTLSTermination(), "AgentTLSTermination should return our TLS config") + + // Log the endpoint URL + endpointURL := listener.URL().String() + t.Logf("TLS endpoint URL: %s", endpointURL) + + // Start a goroutine to handle incoming connections + go func() { + defer close(done) + + // Signal that we're ready to accept connections + handlerReady.Signal() + + // Accept a connection - this should be already TLS terminated + conn, err := listener.Accept() + assert.NoError(t, err, "Failed to accept connection") + if err != nil { + return + } + defer conn.Close() + t.Log("Connection accepted") + + // Handle the TCP connection using our utility function + message, err := HandleTCPConnection(t, conn) + assert.NoError(t, err, "Failed to handle TCP connection") + if err != nil { + return + } + t.Logf("Received data from client: %q", message) + messageChan <- message + + // Note: The HandleTCPConnection function has already sent a response + + // Signal that the request is complete + requestComplete.Signal() + }() + + // Wait for the handler to be ready to accept connections + handlerReady.Wait(t) + + // Expected message + expectedMessage := "TLS test payload" + + // Parse the URL to get the host and port + u, err := url.Parse(endpointURL) + require.NoError(t, err, "Failed to parse URL") + + // Make sure we have a port (default to 443 for HTTPS) + host := u.Host + if !strings.Contains(host, ":") { + host = host + ":443" + } + + // Connect to the endpoint using TLS + clientConfig := &tls.Config{ + InsecureSkipVerify: true, // Skip verification for integration test + } + + t.Logf("Connecting to TLS endpoint: %s (host: %s)", endpointURL, host) + conn, err := tls.Dial("tcp", host, clientConfig) + require.NoError(t, err, "Failed to connect to TLS endpoint") + defer conn.Close() + + // Send the test message + _, err = conn.Write([]byte(expectedMessage)) + require.NoError(t, err, "Failed to send data") + + // Read the response + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err, "Failed to read response") + respBody := string(buf[:n]) + + // Wait for the message to be received with timeout + var actualMessage string + select { + case actualMessage = <-messageChan: + // Received the message + case <-time.After(1 * time.Second): + require.Fail(t, "Timed out waiting for request processing") + } + + // Check that the message received matches what was sent + assert.Equal(t, expectedMessage, actualMessage, "Message should match what was sent") + + // Wait for the request to complete + requestComplete.Wait(t) + + // Verify the response body + expectedResponse := "Message received" + assert.Equal(t, expectedResponse, respBody, "Response body should match expected") + + // Wait for the goroutine to finish with timeout + select { + case <-done: + // Handler finished + case <-time.After(1 * time.Second): + require.Fail(t, "Timed out waiting for handler to finish") + } +} diff --git a/internal/integration_tests/endpoint_closing_test.go b/internal/integration_tests/endpoint_closing_test.go new file mode 100644 index 00000000..e2b991d4 --- /dev/null +++ b/internal/integration_tests/endpoint_closing_test.go @@ -0,0 +1,65 @@ +package integration_tests + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEndpointClosingIntegration tests closing an endpoint while an agent session is live +// and verifies that cleanup works correctly (Done channel triggered, removed from endpoint list) +func TestEndpointClosingIntegration(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Create a listener endpoint + listener, err := agent.Listen(ctx) + require.NoError(t, err, "Failed to create listener") + + // Store the endpoint URL for verification + endpointURL := listener.URL().String() + t.Logf("Created endpoint: %s", endpointURL) + + // Verify the endpoint appears in the agent's endpoints list + endpoints := agent.Endpoints() + endpointFound := false + for _, ep := range endpoints { + if ep.URL().String() == endpointURL { + endpointFound = true + break + } + } + require.True(t, endpointFound, "Endpoint should be found in agent's endpoints list") + + // Create a channel to monitor the endpoint's Done channel + endpointClosed := make(chan struct{}) + go func() { + <-listener.Done() + close(endpointClosed) + }() + + // Close the endpoint + t.Log("Closing endpoint...") + listener.Close() + + // Wait for the Done channel to be triggered with timeout + select { + case <-endpointClosed: + t.Log("Endpoint Done channel was triggered successfully") + case <-time.After(1 * time.Second): + require.Fail(t, "Timeout waiting for endpoint Done channel to be triggered") + } + + // Verify the endpoint is removed from the agent's endpoints list + endpoints = agent.Endpoints() + for _, ep := range endpoints { + assert.NotEqual(t, endpointURL, ep.URL().String(), "Endpoint should not be found in agent's endpoints list after closing") + } +} diff --git a/internal/integration_tests/error_code_test.go b/internal/integration_tests/error_code_test.go new file mode 100644 index 00000000..4d231a7d --- /dev/null +++ b/internal/integration_tests/error_code_test.go @@ -0,0 +1,30 @@ +package integration_tests + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" +) + +// TestErrorCode tests that unique ngrok error codes are properly returned +func TestErrorCode(t *testing.T) { + SkipIfOffline(t) + t.Parallel() + + agent, ctx, cancel := SetupAgent(t) + defer cancel() + + // Create an endpoint with an invalid character ('@') in its URL + _, err := agent.Listen(ctx, + ngrok.WithURL("https://invalid@domain.com"), + ) + require.Error(t, err, "Expected an error when using invalid URL") + + var ngrokErr ngrok.Error + require.True(t, errors.As(err, &ngrokErr), "Expected error to be of type ngrok.Error, got %T", err) + + errCode := ngrokErr.Code() + require.Equal(t, "ERR_NGROK_9037", errCode, "Expected error code ERR_NGROK_9037 for invalid URL") +} diff --git a/internal/integration_tests/event_handling_test.go b/internal/integration_tests/event_handling_test.go new file mode 100644 index 00000000..631f62ff --- /dev/null +++ b/internal/integration_tests/event_handling_test.go @@ -0,0 +1,120 @@ +package integration_tests + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" +) + +// TestEventHandlingIntegration tests that events are properly emitted and received +// during real connection workflows, focusing on connect and disconnect events. +func TestEventHandlingIntegration(t *testing.T) { + // Skip if not running online tests + SkipIfOffline(t) + + // Mark this test for parallel execution + t.Parallel() + + // Create channels for capturing events + connectEventCh := make(chan *ngrok.EventAgentConnectSucceeded, 1) + disconnectEventCh := make(chan *ngrok.EventAgentDisconnected, 1) + + // Create a handler that categorizes events by type + handler := func(evt ngrok.Event) { + t.Logf("Received event: %s at %v", evt.EventType(), evt.Timestamp()) + + switch e := evt.(type) { + case *ngrok.EventAgentConnectSucceeded: + select { + case connectEventCh <- e: + // Successfully sent event + default: + t.Logf("Channel full, dropping connect event") + } + case *ngrok.EventAgentDisconnected: + select { + case disconnectEventCh <- e: + // Successfully sent event + default: + t.Logf("Channel full, dropping disconnect event") + } + default: + // Log other events but don't process them + t.Logf("Received other event type: %T", evt) + } + } + + // Get authentication token from environment + authToken := os.Getenv("NGROK_AUTHTOKEN") + require.NotEmpty(t, authToken, "NGROK_AUTHTOKEN environment variable is required but not set") + + // Create and connect an agent with the event handler + agent, err := ngrok.NewAgent( + ngrok.WithAuthtoken(authToken), + ngrok.WithEventHandler(handler)) + require.NoError(t, err, "Failed to create agent") + + // Create a context with timeout for the test + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Connect the agent (should trigger a connect event) + t.Log("Connecting agent...") + err = agent.Connect(ctx) + require.NoError(t, err, "Failed to connect agent") + + // Ensure agent is disconnected at end of test + defer func() { + t.Log("Disconnecting agent...") + _ = agent.Disconnect() + }() + + // Wait for the connect event with timeout + t.Log("Waiting for connect event...") + var connectEvent *ngrok.EventAgentConnectSucceeded + select { + case connectEvent = <-connectEventCh: + t.Log("Received connect event") + case <-time.After(5 * time.Second): + require.Fail(t, "Timeout waiting for connect event") + } + + // Verify the connect event details + // Note: We can't directly compare Session objects as they have different references + // but may represent the same session + assert.Equal(t, agent, connectEvent.Agent, "Connect event should have the correct agent") + assert.False(t, connectEvent.Timestamp().IsZero(), "Connect event should have non-zero timestamp") + + // Explicitly disconnect the agent to trigger disconnect event + t.Log("Disconnecting agent...") + err = agent.Disconnect() + assert.NoError(t, err, "Agent should disconnect without error") + + // Wait for the disconnect event with timeout + t.Log("Waiting for disconnect event...") + var disconnectEvent *ngrok.EventAgentDisconnected + select { + case disconnectEvent = <-disconnectEventCh: + t.Log("Received disconnect event") + case <-time.After(5 * time.Second): + require.Fail(t, "Timeout waiting for disconnect event") + } + + // Verify the disconnect event details + // Note: We can't directly compare Session objects as they have different references + assert.Equal(t, agent, disconnectEvent.Agent, "Disconnect event should have the correct agent") + assert.False(t, disconnectEvent.Timestamp().IsZero(), "Disconnect event should have non-zero timestamp") + // For client-triggered disconnect, the error may indicate "not reconnecting, session closed" + // which is expected behavior + t.Logf("Disconnect error: %v", disconnectEvent.Error) + + // Verify events are in chronological order + assert.True(t, connectEvent.Timestamp().Before(disconnectEvent.Timestamp()), + "Connect event (%v) should be before disconnect event (%v)", + connectEvent.Timestamp(), disconnectEvent.Timestamp()) +} diff --git a/internal/integration_tests/forward_test.go b/internal/integration_tests/forward_test.go new file mode 100644 index 00000000..7dcd9cab --- /dev/null +++ b/internal/integration_tests/forward_test.go @@ -0,0 +1,83 @@ +package integration_tests + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// TestForward tests forwarding to a local web server +func TestForward(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Create a channel to signal when the server is ready + serverReady := testutil.NewSyncPoint() + + // Start a local HTTP server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Read the request body + body, err := io.ReadAll(r.Body) + if err != nil { + assert.NoError(t, err, "Server failed to read body") + http.Error(w, "Failed to read body", http.StatusInternalServerError) + return + } + + // Echo back what was received + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("X-Received", string(body)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fmt.Sprintf("Received: %s", string(body)))) + })) + defer server.Close() + + // Signal that the server is ready + serverReady.Signal() + + // Create a channel to signal when forwarding is ready + forwarderReady := testutil.NewSyncPoint() + + // Forward to the local server + forwarder, err := agent.Forward(ctx, ngrok.WithUpstream(server.URL)) + require.NoError(t, err, "Failed to create forwarder") + defer forwarder.Close() + + // Get the ngrok URL + ngrokURL := forwarder.URL().String() + t.Logf("Forwarder URL: %s", ngrokURL) + + // Signal that forwarding is ready + forwarderReady.Signal() + + // Send a request to the ngrok URL + expectedMessage := "Hello from forward test!" + resp := MakeHTTPRequest(t, ctx, ngrokURL, expectedMessage) + defer resp.Body.Close() + + // Check the status code + assert.Equal(t, http.StatusOK, resp.StatusCode, "HTTP status should be 200 OK") + + // Check the received header + receivedHeader := resp.Header.Get("X-Received") + assert.Equal(t, expectedMessage, receivedHeader, "Header X-Received should match the message sent") + + // Read the response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Failed to read response body") + + // Verify the response body contains the expected message + expectedResponsePrefix := "Received: " + expectedMessage + assert.Contains(t, string(body), expectedResponsePrefix, "Response body should contain the expected message") +} diff --git a/internal/integration_tests/http2_test.go b/internal/integration_tests/http2_test.go new file mode 100644 index 00000000..a3eaf4ed --- /dev/null +++ b/internal/integration_tests/http2_test.go @@ -0,0 +1,159 @@ +package integration_tests + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" +) + +// TestUpstreamProtocolHTTP2 tests the WithUpstreamProtocol option +// to verify HTTP/2 connections to the upstream +func TestUpstreamProtocolHTTP2(t *testing.T) { + t.Parallel() + + // Test 1: Without specifying protocol - should default to HTTP/1.1 + t.Run("Without protocol specified", func(t *testing.T) { + t.Parallel() + + // Setup agent for this test + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Set up a test HTTP/2 server + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Report the protocol used + protoVer := "HTTP/1.1" + if r.ProtoMajor == 2 { + protoVer = "HTTP/2.0" + } + + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("X-Protocol-Version", protoVer) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fmt.Sprintf("Server used %s", protoVer))) + })) + + // Configure TLS with HTTP/2 support + srv.TLS = &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + } + + // Start the server with TLS and HTTP/2 enabled + srv.StartTLS() + defer srv.Close() + + // Create a forwarder without specifying protocol and skip cert verification + tlsPool := x509.NewCertPool() + tlsPool.AddCert(srv.Certificate()) + config := &tls.Config{ + RootCAs: tlsPool, + } + + forwarder, err := agent.Forward(ctx, + ngrok.WithUpstream(srv.URL, ngrok.WithUpstreamTLSClientConfig(config)), + ) + require.NoError(t, err, "Failed to create forwarder") + defer forwarder.Close() + + // Get the ngrok URL + ngrokURL := forwarder.URL().String() + t.Logf("Forwarder URL: %s", ngrokURL) + + // Send a request to the ngrok URL + message := "Testing HTTP version" + resp := MakeHTTPRequest(t, ctx, ngrokURL, message) + defer resp.Body.Close() + + // Check the status code + assert.Equal(t, http.StatusOK, resp.StatusCode, "HTTP status should be 200 OK") + + // Read the response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Failed to read response body") + + // Check protocol version - should be HTTP/1.1 when not specified + protoHeader := resp.Header.Get("X-Protocol-Version") + assert.Equal(t, "HTTP/1.1", protoHeader, "Protocol should be HTTP/1.1 when not specified") + + t.Logf("Response: %s", string(body)) + }) + + // Test 2: With HTTP/2 protocol specified - should use HTTP/2 + t.Run("With HTTP/2 protocol specified", func(t *testing.T) { + t.Parallel() + + // Setup agent for this test + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Set up a test HTTP/2 server + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Report the protocol used + protoVer := "HTTP/1.1" + if r.ProtoMajor == 2 { + protoVer = "HTTP/2.0" + } + + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("X-Protocol-Version", protoVer) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fmt.Sprintf("Server used %s", protoVer))) + })) + + // Configure TLS with HTTP/2 support + srv.TLS = &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + } + + // Start the server with TLS and HTTP/2 enabled + srv.StartTLS() + defer srv.Close() + + // Create a forwarder with HTTP/2 protocol and skip cert verification + tlsPool := x509.NewCertPool() + tlsPool.AddCert(srv.Certificate()) + config := &tls.Config{ + RootCAs: tlsPool, + } + + forwarder, err := agent.Forward(ctx, + ngrok.WithUpstream(srv.URL, + ngrok.WithUpstreamProtocol("http2"), + ngrok.WithUpstreamTLSClientConfig(config)), + ) + require.NoError(t, err, "Failed to create forwarder") + defer forwarder.Close() + + // Get the ngrok URL + ngrokURL := forwarder.URL().String() + t.Logf("Forwarder URL: %s", ngrokURL) + + // Send a request to the ngrok URL + message := "Testing HTTP2" + resp := MakeHTTPRequest(t, ctx, ngrokURL, message) + defer resp.Body.Close() + + // Check the status code + assert.Equal(t, http.StatusOK, resp.StatusCode, "HTTP status should be 200 OK") + + // Read the response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Failed to read response body") + + // Check protocol version - should be HTTP/2.0 when specified + protoHeader := resp.Header.Get("X-Protocol-Version") + assert.Equal(t, "HTTP/2.0", protoHeader, "Protocol should be HTTP/2.0 when specified") + + t.Logf("Response: %s", string(body)) + }) +} diff --git a/internal/integration_tests/listen_http_test.go b/internal/integration_tests/listen_http_test.go new file mode 100644 index 00000000..cdd1149f --- /dev/null +++ b/internal/integration_tests/listen_http_test.go @@ -0,0 +1,96 @@ +package integration_tests + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// TestListenAndHTTPRequest tests the basic functionality of listening for HTTP requests +func TestListenAndHTTPRequest(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Setup listener + listener := SetupListener(t, agent, ctx) + defer listener.Close() + + // Expected message + expectedMessage := "Hello, ngrok!" + + // Create synchronization points + handlerReady := testutil.NewSyncPoint() + requestComplete := testutil.NewSyncPoint() + messageChan := make(chan string, 1) + done := make(chan struct{}) + + // Start a goroutine to handle a single request + go func() { + defer close(done) + + // Accept a connection + t.Log("Waiting for connection...") + // Signal that we're ready to accept connections + handlerReady.Signal() + + conn, err := listener.Accept() + if err != nil { + assert.NoError(t, err, "Failed to accept connection") + return + } + defer conn.Close() + t.Log("Connection accepted") + + // Handle the HTTP request using the utility function + message, err := HandleHTTPRequest(t, conn) + assert.NoError(t, err, "Failed to handle HTTP request") + if err != nil { + return + } + messageChan <- message + + // Signal that the request is complete + requestComplete.Signal() + }() + + // Wait for the handler to be ready to accept connections + handlerReady.Wait(t) + + // Make HTTP request + resp := MakeHTTPRequest(t, ctx, listener.URL().String(), expectedMessage) + defer resp.Body.Close() + + // Wait for the message to be received with timeout + var actualMessage string + select { + case actualMessage = <-messageChan: + // Received the message + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for request processing") + } + + // Check that the message received matches what was sent + assert.Equal(t, expectedMessage, actualMessage, "Message should match what was sent") + + // Verify response status + assert.Equal(t, http.StatusOK, resp.StatusCode, "HTTP status should be 200 OK") + + // Wait for the request to complete + requestComplete.Wait(t) + + // Wait for the goroutine to finish with timeout + select { + case <-done: + // Handler finished + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for handler to finish") + } +} diff --git a/internal/integration_tests/listen_http_url_test.go b/internal/integration_tests/listen_http_url_test.go new file mode 100644 index 00000000..372329bc --- /dev/null +++ b/internal/integration_tests/listen_http_url_test.go @@ -0,0 +1,95 @@ +package integration_tests + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// TestListenWithHTTPURL tests using WithURL to specify an http URL +func TestListenWithHTTPURL(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Setup listener with HTTP URL + httpURL := "http://test-http.ngrok.io" + listener := SetupListener(t, agent, ctx, ngrok.WithURL(httpURL)) + defer listener.Close() + + // Verify the URL scheme is http + assert.Equal(t, "http", listener.URL().Scheme, "URL scheme should be http") + + // Expected message + expectedMessage := "HTTP Test" + + // Create synchronization points + handlerReady := testutil.NewSyncPoint() + requestComplete := testutil.NewSyncPoint() + messageChan := make(chan string, 1) + done := make(chan struct{}) + + // Start a goroutine to handle a single request + go func() { + defer close(done) + + // Accept a connection + t.Log("Waiting for connection...") + // Signal that we're ready to accept connections + handlerReady.Signal() + + conn, err := listener.Accept() + require.NoError(t, err, "Failed to accept connection") + defer conn.Close() + t.Log("Connection accepted") + + // Handle the HTTP request using the utility function + message, err := HandleHTTPRequest(t, conn) + require.NoError(t, err, "Failed to handle HTTP request") + messageChan <- message + + // Signal that the request is complete + requestComplete.Signal() + }() + + // Wait for the handler to be ready to accept connections + handlerReady.Wait(t) + + // Make HTTP request + resp := MakeHTTPRequest(t, ctx, listener.URL().String(), expectedMessage) + defer resp.Body.Close() + + // Wait for the message to be received with timeout + var actualMessage string + select { + case actualMessage = <-messageChan: + // Received the message + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for request processing") + } + + // Check that the message received matches what was sent + assert.Equal(t, expectedMessage, actualMessage, "Message should match what was sent") + + // Verify response status + assert.Equal(t, http.StatusOK, resp.StatusCode, "HTTP status should be 200 OK") + + // Wait for the request to complete + requestComplete.Wait(t) + + // Wait for the goroutine to finish with timeout + select { + case <-done: + // Handler finished + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for handler to finish") + } +} diff --git a/internal/integration_tests/listen_https_test.go b/internal/integration_tests/listen_https_test.go new file mode 100644 index 00000000..610713bd --- /dev/null +++ b/internal/integration_tests/listen_https_test.go @@ -0,0 +1,101 @@ +package integration_tests + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// TestListenWithHTTPSURL tests using WithURL to specify an https URL +func TestListenWithHTTPSURL(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Setup listener with HTTPS URL + httpsURL := "https://test-https.ngrok.io" + listener := SetupListener(t, agent, ctx, ngrok.WithURL(httpsURL)) + defer listener.Close() + + // Verify the URL scheme is https + assert.Equal(t, "https", listener.URL().Scheme, "URL scheme should be https") + + // Expected message + expectedMessage := "HTTPS Test" + + // Create synchronization points + handlerReady := testutil.NewSyncPoint() + requestComplete := testutil.NewSyncPoint() + messageChan := make(chan string, 1) + done := make(chan struct{}) + + // Start a goroutine to handle a single request + go func() { + defer close(done) + + // Accept a connection + t.Log("Waiting for connection...") + // Signal that we're ready to accept connections + handlerReady.Signal() + + conn, err := listener.Accept() + assert.NoError(t, err, "Failed to accept connection") + if err != nil { + return + } + defer conn.Close() + t.Log("Connection accepted") + + // Handle the HTTP request using the utility function + message, err := HandleHTTPRequest(t, conn) + assert.NoError(t, err, "Failed to handle HTTP request") + if err != nil { + return + } + messageChan <- message + + // Signal that the request is complete + requestComplete.Signal() + }() + + // Wait for the handler to be ready to accept connections + handlerReady.Wait(t) + + // Make HTTP request + resp := MakeHTTPRequest(t, ctx, listener.URL().String(), expectedMessage) + defer resp.Body.Close() + + // Wait for the message to be received with timeout + var actualMessage string + select { + case actualMessage = <-messageChan: + // Received the message + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for request processing") + } + + // Check that the message received matches what was sent + assert.Equal(t, expectedMessage, actualMessage, "Message should match what was sent") + + // Verify response status + assert.Equal(t, http.StatusOK, resp.StatusCode, "HTTP status should be 200 OK") + + // Wait for the request to complete + requestComplete.Wait(t) + + // Wait for the goroutine to finish with timeout + select { + case <-done: + // Handler finished + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for handler to finish") + } +} diff --git a/internal/integration_tests/listen_tcp_test.go b/internal/integration_tests/listen_tcp_test.go new file mode 100644 index 00000000..ac0be742 --- /dev/null +++ b/internal/integration_tests/listen_tcp_test.go @@ -0,0 +1,110 @@ +package integration_tests + +import ( + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// TestListenAndTCPConnection tests the basic functionality of listening for TCP connections +func TestListenAndTCPConnection(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Setup TCP listener using the TCP scheme + listener := SetupListener(t, agent, ctx, ngrok.WithURL("tcp://")) + defer listener.Close() + + // Expected message + expectedMessage := "Hello, TCP!" + + // Create synchronization points + handlerReady := testutil.NewSyncPoint() + requestComplete := testutil.NewSyncPoint() + messageChan := make(chan string, 1) + done := make(chan struct{}) + + // Start a goroutine to handle a single connection + go func() { + defer close(done) + + // Accept a connection + t.Log("Waiting for connection...") + // Signal that we're ready to accept connections + handlerReady.Signal() + + conn, err := listener.Accept() + assert.NoError(t, err, "Failed to accept connection") + if err != nil { + return + } + defer conn.Close() + t.Log("Connection accepted") + + // Handle TCP connection using utility function + message, err := HandleTCPConnection(t, conn) + assert.NoError(t, err, "Failed to handle TCP connection") + if err != nil { + return + } + messageChan <- message + + // Signal that the request is complete + requestComplete.Signal() + }() + + // Wait for the handler to be ready to accept connections + handlerReady.Wait(t) + + // Make TCP connection and send data + // Extract host and port from the URL + hostPort := listener.URL().Host + t.Logf("Connecting to TCP endpoint: %s", hostPort) + conn, err := MakeTCPConnection(t, ctx, hostPort) + require.NoError(t, err, "Failed to connect to TCP endpoint") + defer conn.Close() + + // Send test message + _, err = conn.Write([]byte(expectedMessage)) + require.NoError(t, err, "Failed to send data") + + // Wait for the message to be received with timeout + var actualMessage string + select { + case actualMessage = <-messageChan: + // Received the message + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for message processing") + } + + // Check that the message received matches what was sent + assert.Equal(t, expectedMessage, actualMessage, "Message should match what was sent") + + // Read response + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.True(t, err == nil || err == io.EOF, "Failed to read response: %v", err) + response := string(buf[:n]) + expectedResponse := "Message received" + assert.Equal(t, expectedResponse, response, "Response should match expected") + + // Wait for the request to complete + requestComplete.Wait(t) + + // Wait for the goroutine to finish with timeout + select { + case <-done: + // Handler finished + case <-time.After(500 * time.Millisecond): + require.Fail(t, "Timed out waiting for handler to finish") + } +} diff --git a/internal/integration_tests/proxy_proto_test.go b/internal/integration_tests/proxy_proto_test.go new file mode 100644 index 00000000..0fdb40bd --- /dev/null +++ b/internal/integration_tests/proxy_proto_test.go @@ -0,0 +1,432 @@ +package integration_tests + +import ( + "bufio" + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/url" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// parseProxyProtocolHeader extracts client and server information from a PROXY protocol header. +func parseProxyProtocolHeader(reader *bufio.Reader) (srcAddr, dstAddr net.Addr, err error) { + // Read the first line from the connection + header, err := reader.ReadString('\n') + if err != nil { + return nil, nil, fmt.Errorf("failed to read PROXY header: %v", err) + } + + // Trim trailing newline + header = strings.TrimSuffix(header, "\r\n") + header = strings.TrimSuffix(header, "\n") + + // Split the header into parts + parts := strings.Split(header, " ") + if len(parts) < 6 || parts[0] != "PROXY" { + return nil, nil, fmt.Errorf("invalid PROXY protocol header: %s", header) + } + + // Extract information + proto := parts[1] // TCP4 or TCP6 + srcIP := parts[2] // Source IP + dstIP := parts[3] // Destination IP + srcPort := parts[4] // Source port + dstPort := parts[5] // Destination port + + // Parse ports + srcPortInt, err := strconv.Atoi(srcPort) + if err != nil { + return nil, nil, fmt.Errorf("invalid source port: %v", err) + } + dstPortInt, err := strconv.Atoi(dstPort) + if err != nil { + return nil, nil, fmt.Errorf("invalid destination port: %v", err) + } + + // Create addresses + if proto == "TCP4" { + srcAddr = &net.TCPAddr{IP: net.ParseIP(srcIP), Port: srcPortInt} + dstAddr = &net.TCPAddr{IP: net.ParseIP(dstIP), Port: dstPortInt} + } else if proto == "TCP6" { + srcAddr = &net.TCPAddr{IP: net.ParseIP(srcIP), Port: srcPortInt} + dstAddr = &net.TCPAddr{IP: net.ParseIP(dstIP), Port: dstPortInt} + } else { + return nil, nil, fmt.Errorf("unsupported protocol: %s", proto) + } + + return srcAddr, dstAddr, nil +} + +// bufferedConn combines a net.Conn with a bufio.Reader to implement net.Conn interface. +type bufferedConn struct { + r *bufio.Reader + c net.Conn +} + +func (b *bufferedConn) Read(p []byte) (int, error) { + return b.r.Read(p) +} + +func (b *bufferedConn) Write(p []byte) (int, error) { + return b.c.Write(p) +} + +func (b *bufferedConn) Close() error { + return b.c.Close() +} + +// Required net.Conn interface methods that delegate to the underlying connection +func (b *bufferedConn) LocalAddr() net.Addr { + return b.c.LocalAddr() +} + +func (b *bufferedConn) RemoteAddr() net.Addr { + return b.c.RemoteAddr() +} + +func (b *bufferedConn) SetDeadline(t time.Time) error { + return b.c.SetDeadline(t) +} + +func (b *bufferedConn) SetReadDeadline(t time.Time) error { + return b.c.SetReadDeadline(t) +} + +func (b *bufferedConn) SetWriteDeadline(t time.Time) error { + return b.c.SetWriteDeadline(t) +} + +// verifyClientAddr checks that the client address received via PROXY protocol is valid. +func verifyClientAddr(t *testing.T, clientAddr net.Addr) { + require.NotNil(t, clientAddr, "Client address should not be nil") + + t.Logf("Received client address via PROXY protocol: %s", clientAddr.String()) + + // We can't verify exact IP matches in a public test environment, + // but we can verify that something reasonable came through + tcpAddr, ok := clientAddr.(*net.TCPAddr) + assert.True(t, ok, "Expected TCP address, got %T", clientAddr) + if !ok { + return + } + + // Log the client IP for manual verification + t.Logf("Client IP via PROXY protocol: %s", tcpAddr.IP.String()) + + // If we're testing locally, this might be a loopback, but in CI it should be a public IP + // For this test, we just verify we got something non-nil + assert.False(t, tcpAddr.IP == nil || tcpAddr.IP.String() == "", "Expected non-empty IP address") +} + +// handleTLSConnection handles a TLS connection with PROXY protocol already read +func handleTLSConnection(t *testing.T, conn net.Conn, reader *bufio.Reader, srcAddr net.Addr) { + // Create a server TLS certificate for the handshake + servCert := CreateTestCertificate(t) + + // Create TLS configuration for server + config := &tls.Config{ + Certificates: []tls.Certificate{*servCert}, + } + + // Use the remaining buffer as the source for the TLS connection + // Create a TLS server connection + tlsConn := tls.Server(&bufferedConn{reader, conn}, config) + defer tlsConn.Close() + + // Perform TLS handshake + if err := tlsConn.Handshake(); err != nil { + t.Logf("TLS handshake failed: %v", err) + return + } + + // Read data from the TLS connection + buffer := make([]byte, 1024) + n, err := tlsConn.Read(buffer) + if err != nil && err != io.EOF { + assert.NoError(t, err, "Error reading from TLS connection") + return + } + + if n > 0 { + clientMsg := string(buffer[:n]) + t.Logf("Received client message over TLS: %s", clientMsg) + + // Send a response back to the client over TLS + response := fmt.Sprintf("Received data with PROXY protocol from %s over TLS", srcAddr) + _, err := tlsConn.Write([]byte(response)) + assert.NoError(t, err, "Failed to write TLS response") + } +} + +// handleHTTPConnection handles an HTTP connection with PROXY protocol already read +func handleHTTPConnection(t *testing.T, conn net.Conn, reader *bufio.Reader, srcAddr net.Addr) { + // Read plaintext after the PROXY header + buffer := make([]byte, 1024) + n, err := reader.Read(buffer) + if err != nil && err != io.EOF { + assert.NoError(t, err, "Error reading from connection") + return + } + + if n > 0 { + clientMsg := string(buffer[:n]) + t.Logf("Received client message: %s", clientMsg) + + // For HTTP/HTTPS endpoints, send a proper HTTP response + message := fmt.Sprintf("Received data with PROXY protocol from %s", srcAddr) + response := fmt.Sprintf( + "HTTP/1.1 200 OK\r\n"+ + "Content-Type: text/plain\r\n"+ + "Content-Length: %d\r\n"+ + "Connection: close\r\n"+ + "\r\n"+ + "%s", + len(message), message) + _, err := conn.Write([]byte(response)) + assert.NoError(t, err, "Failed to write HTTP response") + t.Logf("Sent HTTP response with status 200 OK") + } +} + +// handleTCPConnection handles a TCP connection with PROXY protocol already read +func handleTCPConnection(t *testing.T, conn net.Conn, reader *bufio.Reader, srcAddr net.Addr) { + // Read plaintext after the PROXY header + buffer := make([]byte, 1024) + n, err := reader.Read(buffer) + if err != nil && err != io.EOF { + assert.NoError(t, err, "Error reading from connection") + return + } + + if n > 0 { + clientMsg := string(buffer[:n]) + t.Logf("Received client message: %s", clientMsg) + + // For TCP endpoints, send plain text response + response := fmt.Sprintf("Received data with PROXY protocol from %s", srcAddr) + _, err := conn.Write([]byte(response)) + assert.NoError(t, err, "Failed to write response") + } +} + +// connectHTTPSClient connects to an HTTPS endpoint using an HTTP client +func connectHTTPSClient(t *testing.T, endpointURL string) { + // Use MakeHTTPRequest to send test message + message := "Test message for PROXY protocol" + resp := MakeHTTPRequest(t, context.Background(), endpointURL, message) + defer resp.Body.Close() + + // Read the response + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Failed to read response body") + t.Logf("Response: %s", string(respBody)) +} + +// connectTCPClient connects to a TCP endpoint using a direct TCP connection +func connectTCPClient(t *testing.T, endpointURL string) { + // For TCP, use direct TCP connection + u, err := url.Parse(endpointURL) + require.NoError(t, err, "Failed to parse URL") + + // Connect to the endpoint using MakeTCPConnection + clientConn, err := MakeTCPConnection(t, context.Background(), u.Host) + require.NoError(t, err, "Failed to connect to TCP endpoint") + defer clientConn.Close() + + // Send test message + testMessage := "Test message for PROXY protocol" + _, err = clientConn.Write([]byte(testMessage)) + require.NoError(t, err, "Failed to send data") + + // Read response + buffer := make([]byte, 1024) + n, err := clientConn.Read(buffer) + require.NoError(t, err, "Failed to read response") + response := string(buffer[:n]) + t.Logf("Received response: %s", response) +} + +// connectTLSClient connects to a TLS endpoint using a TLS client +func connectTLSClient(t *testing.T, endpointURL string) { + // For TLS, use TLS client + u, err := url.Parse(endpointURL) + require.NoError(t, err, "Failed to parse URL") + + // Make sure we have a port + host := u.Host + if !strings.Contains(host, ":") { + host = host + ":443" + } + + // Connect using TLS as required for TLS endpoints + config := &tls.Config{ + InsecureSkipVerify: true, // Skip verification for testing + } + + // Establish a proper TLS connection + clientConn, err := tls.Dial("tcp", host, config) + require.NoError(t, err, "Failed to connect with TLS") + defer clientConn.Close() + + // Send test message over the TLS connection + testMessage := "Test message for PROXY protocol TLS endpoint" + _, err = clientConn.Write([]byte(testMessage)) + require.NoError(t, err, "Failed to send data over TLS") + + // Read response + buffer := make([]byte, 1024) + n, err := clientConn.Read(buffer) + require.NoError(t, err, "Failed to read response from TLS connection") + response := string(buffer[:n]) + t.Logf("Received response from TLS endpoint: %s", response) +} + +// TestProxyProtoIntegration tests PROXY protocol functionality with each supported protocol +func TestProxyProtoIntegration(t *testing.T) { + // Skip if not running online tests + SkipIfOffline(t) + + // Define the schemes to test + // Test TCP, TLS, and HTTPS with PROXY protocol + // Note: For HTTPS endpoints, ngrok terminates TLS at their edge, so our listener receives plain HTTP + schemes := []string{"tcp://", "tls://", "https://"} + + for _, s := range schemes { + // Create a subtest for each scheme + scheme := s // Local copy to avoid loop variable capture + t.Run(scheme, func(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Create synchronization points + handlerReady := testutil.NewSyncPoint() + clientConnected := testutil.NewSyncPoint() + requestComplete := testutil.NewSyncPoint() + + // Channel to pass client address information + clientAddrChan := make(chan net.Addr, 1) + + // Create a server listener + serverListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "Failed to create server listener") + defer serverListener.Close() + + // Extract server address to use for upstream + serverAddr := serverListener.Addr().String() + t.Logf("Local server running at: %s", serverAddr) + + // Start a goroutine to handle incoming connections on the local server + go func() { + // Signal that we're ready to accept connections + handlerReady.Signal() + + // Accept a connection + conn, err := serverListener.Accept() + assert.NoError(t, err, "Failed to accept connection") + if err != nil { + return + } + defer conn.Close() + t.Log("Connection accepted by local server") + + // Wait for client to connect before parsing PROXY header + clientConnected.Wait(t) + + // Create a buffered reader for the connection + reader := bufio.NewReader(conn) + + // Parse the PROXY protocol header + srcAddr, dstAddr, err := parseProxyProtocolHeader(reader) + assert.NoError(t, err, "Error parsing PROXY protocol header") + if err != nil { + return + } + + // Log header details + t.Logf("PROXY header parsed: src=%s, dst=%s", srcAddr, dstAddr) + + // Send the source address to the channel for verification + clientAddrChan <- srcAddr + + // Handle connection based on endpoint type + switch { + case strings.HasPrefix(scheme, "tls"): + handleTLSConnection(t, conn, reader, srcAddr) + case strings.HasPrefix(scheme, "https"): + handleHTTPConnection(t, conn, reader, srcAddr) + default: // TCP + handleTCPConnection(t, conn, reader, srcAddr) + } + + // Signal that the request processing is complete + requestComplete.Signal() + }() + + // Wait for the handler to be ready to accept connections + handlerReady.Wait(t) + + // Create a forwarder with PROXY protocol version 1 enabled + // Format the upstream URL properly + upstreamURL := fmt.Sprintf("tcp://%s", serverAddr) + upstream := ngrok.WithUpstream(upstreamURL, + ngrok.WithUpstreamProxyProto(ngrok.ProxyProtoV1), // Version 1 (text format) + ) + forwarder, err := agent.Forward(ctx, upstream, + ngrok.WithURL(scheme), + ) + require.NoError(t, err, "Failed to create forwarder with PROXY protocol") + defer forwarder.Close() + + // Verify the forwarder has PROXY protocol enabled + proxyProto := forwarder.ProxyProtocol() + require.Equal(t, ngrok.ProxyProtoV1, proxyProto, "ProxyProtocol should be ProxyProtoV1") + t.Logf("Proxy protocol enabled: %s", proxyProto) + + // Log the endpoint URL + endpointURL := forwarder.URL().String() + t.Logf("Endpoint URL: %s", endpointURL) + + // Signal that the client is about to connect + clientConnected.Signal() + + // Connect to the endpoint with appropriate client based on scheme + switch { + case strings.HasPrefix(scheme, "https"): + connectHTTPSClient(t, endpointURL) + case strings.HasPrefix(scheme, "tls"): + connectTLSClient(t, endpointURL) + default: // TCP + connectTCPClient(t, endpointURL) + } + + // Wait for the client address with timeout + var clientAddr net.Addr + select { + case clientAddr = <-clientAddrChan: + // Verify the client address + verifyClientAddr(t, clientAddr) + case <-time.After(2 * time.Second): + require.Fail(t, "Timed out waiting for client address") + } + + // Wait for request completion + requestComplete.Wait(t) + }) + } +} diff --git a/internal/integration_tests/test_utils.go b/internal/integration_tests/test_utils.go new file mode 100644 index 00000000..01985156 --- /dev/null +++ b/internal/integration_tests/test_utils.go @@ -0,0 +1,243 @@ +package integration_tests + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "io" + "math/big" + "net" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" +) + +// SkipIfOffline skips the test if NGROK_TEST_ONLINE environment variable is not set +func SkipIfOffline(t *testing.T) { + if os.Getenv("NGROK_TEST_ONLINE") == "" { + t.Skip("Skipping online test because NGROK_TEST_ONLINE is not set") + } +} + +// SetupAgent creates and connects a new agent for testing +func SetupAgent(t *testing.T) (ngrok.Agent, context.Context, context.CancelFunc) { + // Skip if not running online tests + SkipIfOffline(t) + + // Get authentication token from environment + authToken := os.Getenv("NGROK_AUTHTOKEN") + require.NotEmpty(t, authToken, "NGROK_AUTHTOKEN environment variable is required but not set") + + // Create a new agent for each test + agent, err := ngrok.NewAgent( + ngrok.WithAuthtoken(authToken), + ) + require.NoError(t, err, "Failed to create agent") + + // Start a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + + // Connect the agent + err = agent.Connect(ctx) + require.NoError(t, err, "Failed to connect agent") + + return agent, ctx, cancel +} + +// SetupListener sets up an ngrok listener with the specified options +func SetupListener(t *testing.T, agent ngrok.Agent, ctx context.Context, opts ...ngrok.EndpointOption) ngrok.EndpointListener { + // Create a listener endpoint + listener, err := agent.Listen(ctx, opts...) + require.NoError(t, err, "Failed to create listener") + + // Get the URL of the endpoint + endpointURL := listener.URL().String() + t.Logf("Endpoint URL: %s", endpointURL) + + return listener +} + +// MakeHTTPRequest makes an HTTP request to the specified URL with the given message +func MakeHTTPRequest(t *testing.T, ctx context.Context, url string, message string) *http.Response { + // Create a custom transport that doesn't reuse connections + transport := &http.Transport{ + DisableKeepAlives: true, + } + + // Create a client with the custom transport + client := &http.Client{Transport: transport} + + // Make the request + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(message)) + require.NoError(t, err, "Failed to create request") + + t.Logf("Making HTTP request to %s", url) + resp, err := client.Do(req) + require.NoError(t, err, "Failed to send request") + + return resp +} + +// WaitForForwarderReady polls the forwarder endpoint until it responds or times out +func WaitForForwarderReady(t *testing.T, url string) { + client := &http.Client{Timeout: 100 * time.Millisecond} + for start := time.Now(); time.Since(start) < 500*time.Millisecond; { + resp, err := client.Get(url) + if err == nil { + resp.Body.Close() + return + } + time.Sleep(10 * time.Millisecond) + } + t.Logf("Forwarder endpoint didn't become ready in expected time, continuing anyway") +} + +// CreateTestCertificate creates a certificate for testing +func CreateTestCertificate(t *testing.T) *tls.Certificate { + // Generate a self-signed certificate for testing + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err, "Failed to generate private key") + + templ := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &templ, &templ, &privKey.PublicKey, privKey) + require.NoError(t, err, "Failed to create certificate") + + cert := tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: privKey, + } + + return &cert +} + +// MakeTCPConnection establishes a TCP connection to the given address +func MakeTCPConnection(t *testing.T, ctx context.Context, address string) (io.ReadWriteCloser, error) { + t.Helper() + // Use a simple net.Dialer to connect to the TCP address + dialer := &net.Dialer{ + Timeout: 500 * time.Millisecond, + } + conn, err := dialer.DialContext(ctx, "tcp", address) + if err != nil { + return nil, err + } + return conn, nil +} + +// HandleHTTPRequest processes an HTTP request from a connection and sends a response +func HandleHTTPRequest(t *testing.T, conn net.Conn) (string, error) { + t.Helper() + // Create a buffered reader for the connection + reader := bufio.NewReader(conn) + + // Read the HTTP request + request, err := http.ReadRequest(reader) + if err != nil { + return "", fmt.Errorf("failed to read HTTP request: %w", err) + } + + // Read the request body + body, err := io.ReadAll(request.Body) + if err != nil { + return "", fmt.Errorf("failed to read request body: %w", err) + } + message := string(body) + + // Send a response + response := http.Response{ + StatusCode: http.StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + response.Header.Set("Content-Type", "text/plain") + response.Body = io.NopCloser(strings.NewReader("Request received")) + + if err := response.Write(conn); err != nil { + return message, fmt.Errorf("failed to write response: %w", err) + } + + return message, nil +} + +// HandleTCPConnection reads data from a TCP connection and sends a response +func HandleTCPConnection(t *testing.T, conn io.ReadWriteCloser) (string, error) { + t.Helper() + // Read data from the connection + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + return "", fmt.Errorf("failed to read data: %w", err) + } + + message := string(buf[:n]) + + // Send a response + response := "Message received" + if _, err := conn.Write([]byte(response)); err != nil { + return message, fmt.Errorf("failed to write response: %w", err) + } + + return message, nil +} + +// HandleTLSConnection handles a TLS server connection +func HandleTLSConnection(t *testing.T, conn net.Conn, cert *tls.Certificate) (string, error) { + t.Helper() + // Create TLS configuration for server + config := &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + + // Create a TLS server connection + tlsConn := tls.Server(conn, config) + defer tlsConn.Close() + + // Perform TLS handshake + if err := tlsConn.Handshake(); err != nil { + return "", fmt.Errorf("TLS handshake failed: %w", err) + } + + // Read data from the TLS connection + buffer := make([]byte, 1024) + n, err := tlsConn.Read(buffer) + if err != nil && err != io.EOF { + return "", fmt.Errorf("error reading from TLS connection: %w", err) + } + + message := "" + if n > 0 { + message = string(buffer[:n]) + + // Send a response back to the client over TLS + response := "TLS message received" + if _, err := tlsConn.Write([]byte(response)); err != nil { + return message, fmt.Errorf("failed to write TLS response: %w", err) + } + } + + return message, nil +} diff --git a/internal/integration_tests/upstream_dialer_test.go b/internal/integration_tests/upstream_dialer_test.go new file mode 100644 index 00000000..75379271 --- /dev/null +++ b/internal/integration_tests/upstream_dialer_test.go @@ -0,0 +1,88 @@ +package integration_tests + +import ( + "context" + "errors" + "net" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// erroringDialer implements the ngrok.Dialer interface for testing +// It returns an error and signals when it's called +type erroringDialer struct { + syncPoint *testutil.SyncPoint // Synchronization using testutil +} + +// newErroringDialer creates a new erroringDialer with synchronization +func newErroringDialer() *erroringDialer { + return &erroringDialer{ + syncPoint: testutil.NewSyncPoint(), + } +} + +// Dial implements the ngrok.Dialer interface +func (d *erroringDialer) Dial(network, address string) (net.Conn, error) { + // Signal that the dialer was called + d.syncPoint.Signal() + return nil, errors.New("custom dialer test error") +} + +// DialContext implements the ngrok.Dialer interface +func (d *erroringDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + // Signal that the dialer was called + d.syncPoint.Signal() + return nil, errors.New("custom dialer test error") +} + +// WaitForCall waits for the dialer to be called with a specified timeout +func (d *erroringDialer) WaitForCall(t testing.TB, timeout time.Duration) { + success := d.syncPoint.WaitTimeout(t, timeout) + require.True(t, success, "Timed out waiting for dialer to be called") +} + +// TestUpstreamDialer tests the WithUpstreamDialer functionality +func TestUpstreamDialer(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Create a custom dialer that returns an error and has synchronization + customDialer := newErroringDialer() + + // Use any arbitrary URL, the dialer will be called but will fail + // We're only testing that our dialer gets invoked + forwarder, err := agent.Forward(ctx, + ngrok.WithUpstream("http://example.com", ngrok.WithUpstreamDialer(customDialer)), + ) + require.NoError(t, err, "Failed to create forwarder") + defer forwarder.Close() + + // Get the ngrok URL + ngrokURL := forwarder.URL().String() + t.Logf("Forwarder URL: %s", ngrokURL) + + // Now make a request to trigger the dialer + t.Logf("Making request to trigger upstream connection...") + // The request will fail, but we'll ignore that since we expect it to fail + // We're just triggering the ngrok service to use our dialer + go func() { + _, _ = http.Get(ngrokURL) + }() + + // Wait for our dialer to be called with a timeout + t.Log("Waiting for dialer to be called...") + customDialer.WaitForCall(t, 3*time.Second) + + // If we got here, the test passed (WaitForCall would have failed if the dialer wasn't called) + t.Log("Custom dialer was successfully called") +} diff --git a/internal/integration_tests/url_pooling_test.go b/internal/integration_tests/url_pooling_test.go new file mode 100644 index 00000000..c8d85cb0 --- /dev/null +++ b/internal/integration_tests/url_pooling_test.go @@ -0,0 +1,320 @@ +package integration_tests + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testutil" +) + +// TestListenWithURLAndPooling tests load balancing across two endpoints with the same URL +func TestListenWithURLAndPooling(t *testing.T) { + // Mark this test for parallel execution + t.Parallel() + + // Setup agent + agent, ctx, cancel := SetupAgent(t) + defer cancel() + defer func() { _ = agent.Disconnect() }() + + // Common URL for both endpoints - IMPORTANT: the exact same string must be used for both listeners + sharedURL := "https://test-lb.ngrok.io" + + // Create sync points for coordination + listenersReady := testutil.NewSyncPoint() + requestedFinished := testutil.NewSyncPoint() + + // Setup first listener with pooling enabled + listener1 := SetupListener(t, agent, ctx, ngrok.WithURL(sharedURL), ngrok.WithPoolingEnabled(true)) + defer listener1.Close() + + // Setup second listener with the same URL and pooling enabled + listener2 := SetupListener(t, agent, ctx, ngrok.WithURL(sharedURL), ngrok.WithPoolingEnabled(true)) + defer listener2.Close() + + // Log URLs for debugging + t.Logf("Listener1 URL: %s, Pooling: %v", listener1.URL().String(), listener1.PoolingEnabled()) + t.Logf("Listener2 URL: %s, Pooling: %v", listener2.URL().String(), listener2.PoolingEnabled()) + + // Verify both have the same URL - but note that load balancing can work even with different returned URLs + // since the WithURL value is what's important for ngrok's backend pooling, not the returned URL + if listener1.URL().String() != listener2.URL().String() { + t.Logf("Warning: URLs don't match exactly, but load balancing may still work: %s and %s", + listener1.URL().String(), listener2.URL().String()) + } + + // Track which endpoint receives each request + var ( + mu sync.Mutex + endpoint1Requests int + endpoint2Requests int + wg sync.WaitGroup + endpoint1Ready = testutil.NewSyncPoint() + endpoint2Ready = testutil.NewSyncPoint() + processingDone = make(chan struct{}) + testFinished = make(chan struct{}) // Signal that test is finished so goroutines can ignore errors + ) + + // Start handlers for both listeners + // Handler for first listener + wg.Add(1) + go func() { + defer wg.Done() + defer close(processingDone) + + // Signal that we're ready to accept connections + endpoint1Ready.Signal() + + for i := 0; i < 5; i++ { // Handle up to 5 connections + conn, err := listener1.Accept() + if err != nil { + // Check if test is already finished before reporting errors + select { + case <-testFinished: + // Test is done, just return silently + return + default: + // Test still running, check the error type + if strings.Contains(err.Error(), "listener closed") { + t.Log("Listener1 closed") + return + } + t.Logf("Listener1 accept error: %v", err) + return + } + } + + // Process in a new goroutine + go func(conn net.Conn) { + defer conn.Close() + + // Track this request for listener1 + mu.Lock() + endpoint1Requests++ + mu.Unlock() + + // Handle the HTTP request + request, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + t.Errorf("Failed to read HTTP request: %v", err) + return + } + + // Read the request body + _, err = io.ReadAll(request.Body) + if err != nil { + t.Errorf("Failed to read request body: %v", err) + return + } + + // Send a response with endpoint identifier + response := http.Response{ + StatusCode: http.StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + response.Header.Set("Content-Type", "text/plain") + response.Header.Set("X-Endpoint", "endpoint1") + response.Body = io.NopCloser(strings.NewReader("Response from endpoint 1")) + + if err := response.Write(conn); err != nil { + t.Errorf("Failed to write response: %v", err) + } + }(conn) + } + }() + + // Handler for second listener + wg.Add(1) + go func() { + defer wg.Done() + + // Signal that we're ready to accept connections + endpoint2Ready.Signal() + + for i := 0; i < 5; i++ { // Handle up to 5 connections + conn, err := listener2.Accept() + if err != nil { + // Check if test is already finished before reporting errors + select { + case <-testFinished: + // Test is done, just return silently + return + default: + // Test still running, check the error type + if strings.Contains(err.Error(), "listener closed") { + t.Log("Listener2 closed") + return + } + t.Logf("Listener2 accept error: %v", err) + return + } + } + + // Process in a new goroutine + go func(conn net.Conn) { + defer conn.Close() + + // Track this request for listener2 + mu.Lock() + endpoint2Requests++ + mu.Unlock() + + // Handle the HTTP request + request, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + t.Errorf("Failed to read HTTP request: %v", err) + return + } + + // Read the request body + _, err = io.ReadAll(request.Body) + if err != nil { + t.Errorf("Failed to read request body: %v", err) + return + } + + // Send a response with endpoint identifier + response := http.Response{ + StatusCode: http.StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + response.Header.Set("Content-Type", "text/plain") + response.Header.Set("X-Endpoint", "endpoint2") + response.Body = io.NopCloser(strings.NewReader("Response from endpoint 2")) + + if err := response.Write(conn); err != nil { + t.Errorf("Failed to write response: %v", err) + } + }(conn) + } + }() + + // Wait for both endpoints to be ready to accept connections + endpoint1Ready.Wait(t) + endpoint2Ready.Wait(t) + + // Signal that listeners are ready + listenersReady.Signal() + + // Create a channel to signal when both endpoints have received at least one request + bothEndpointsHit := make(chan struct{}) + maxRequests := 20 // Safety limit to prevent infinite loop + requestCount := 0 + url := listener1.URL().String() // Both listeners have the same URL + + // Start a goroutine to monitor when both endpoints have been hit + go func() { + for { + mu.Lock() + ep1Hit := endpoint1Requests > 0 + ep2Hit := endpoint2Requests > 0 + mu.Unlock() + + if ep1Hit && ep2Hit { + close(bothEndpointsHit) + return + } + time.Sleep(50 * time.Millisecond) + } + }() + + // Send requests until both endpoints have been hit or we reach max requests + for requestCount < maxRequests { + select { + case <-bothEndpointsHit: + // Both endpoints have received at least one request + t.Log("Both endpoints have received requests") + goto testComplete + default: + // Send another request + requestCount++ + message := fmt.Sprintf("Request %d", requestCount) + + // Make HTTP request with a new connection each time + resp := MakeHTTPRequest(t, ctx, url, message) + + // Read the response to see which endpoint responded + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("Failed to read response body: %v", err) + } + t.Logf("Response %d: %s, Header: %s", requestCount, string(body), resp.Header.Get("X-Endpoint")) + + resp.Body.Close() + time.Sleep(50 * time.Millisecond) // Small delay between requests + } + } + + // If we reach here, we hit the max requests without both endpoints receiving traffic + require.Fail(t, fmt.Sprintf("Sent %d requests but both endpoints weren't hit", maxRequests)) + +testComplete: + + // Signal that all requests are finished + requestedFinished.Signal() + + // Signal that we're about to close listeners - this will help handle error reporting + doneProcessing := make(chan struct{}) + go func() { + // Close the listeners to stop the handler goroutines + listener1.Close() + listener2.Close() + close(doneProcessing) + }() + + // Wait for processing to finish with timeout + select { + case <-processingDone: + // Processing completed + case <-doneProcessing: + // Listeners closed + case <-time.After(500 * time.Millisecond): + t.Log("Processing timeout - continuing with verification") + } + + // Verify that both endpoints received requests + mu.Lock() + endpoint1Count := endpoint1Requests + endpoint2Count := endpoint2Requests + mu.Unlock() + + t.Logf("Endpoint 1 received %d requests", endpoint1Count) + t.Logf("Endpoint 2 received %d requests", endpoint2Count) + + // Both endpoints should have received at least one request + assert.NotZero(t, endpoint1Count, "Endpoint 1 should receive at least one request") + assert.NotZero(t, endpoint2Count, "Endpoint 2 should receive at least one request") + + // Wait for handlers to finish with timeout + c := make(chan struct{}) + go func() { + wg.Wait() + close(c) + }() + + select { + case <-c: + // Handlers finished + case <-time.After(500 * time.Millisecond): + t.Log("Timed out waiting for handlers to finish") + } + + // Signal that the test is completely finished so goroutines can clean up + close(testFinished) +} diff --git a/internal/legacy/VERSION b/internal/legacy/VERSION new file mode 100644 index 00000000..227cea21 --- /dev/null +++ b/internal/legacy/VERSION @@ -0,0 +1 @@ +2.0.0 diff --git a/config/app_protocol.go b/internal/legacy/config/app_protocol.go similarity index 70% rename from config/app_protocol.go rename to internal/legacy/config/app_protocol.go index a0714875..c6223b79 100644 --- a/config/app_protocol.go +++ b/internal/legacy/config/app_protocol.go @@ -6,19 +6,12 @@ func (ap appProtocol) ApplyHTTP(cfg *httpOptions) { cfg.commonOpts.ForwardsProto = string(ap) } -func (ap appProtocol) ApplyLabeled(cfg *labeledOptions) { - cfg.commonOpts.ForwardsProto = string(ap) -} - // WithAppProtocol declares the protocol that the upstream service speaks. // This may be used by the ngrok edge to make decisions regarding protocol // upgrades or downgrades. // // Currently, `http2` is the only valid string, and will cause connections // received from HTTP endpoints to always use HTTP/2. -func WithAppProtocol(proto string) interface { - HTTPEndpointOption - LabeledTunnelOption -} { +func WithAppProtocol(proto string) HTTPEndpointOption { return appProtocol(proto) } diff --git a/config/bindings.go b/internal/legacy/config/bindings.go similarity index 100% rename from config/bindings.go rename to internal/legacy/config/bindings.go diff --git a/config/bindings_test.go b/internal/legacy/config/bindings_test.go similarity index 100% rename from config/bindings_test.go rename to internal/legacy/config/bindings_test.go diff --git a/config/common.go b/internal/legacy/config/common.go similarity index 89% rename from config/common.go rename to internal/legacy/config/common.go index ee0ba3bb..f363d856 100644 --- a/config/common.go +++ b/internal/legacy/config/common.go @@ -1,8 +1,7 @@ package config type commonOpts struct { - // Restrictions placed on the origin of incoming connections to the edge. - CIDRRestrictions *cidrRestrictions + // The version of PROXY protocol to use with this tunnel, zero if not // using. ProxyProto ProxyProtoVersion @@ -28,8 +27,6 @@ type commonOpts struct { // change-of-protocol happening at our edge. ForwardsProto string - // DEPRECATED: use TrafficPolicy instead. - Policy *policy // Policy that define rules that should be applied to incoming or outgoing // connections to the edge. TrafficPolicy string diff --git a/config/config_test.go b/internal/legacy/config/config_test.go similarity index 97% rename from config/config_test.go rename to internal/legacy/config/config_test.go index c547b1eb..ccc874a5 100644 --- a/config/config_test.go +++ b/internal/legacy/config/config_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) // Helper to assert a whole slice to a different type. diff --git a/config/description.go b/internal/legacy/config/description.go similarity index 79% rename from config/description.go rename to internal/legacy/config/description.go index c8742718..eb0b4a8b 100644 --- a/config/description.go +++ b/internal/legacy/config/description.go @@ -6,7 +6,6 @@ func WithDescription(name string) interface { HTTPEndpointOption TCPEndpointOption TLSEndpointOption - LabeledTunnelOption } { return descriptionOption(name) } @@ -22,7 +21,3 @@ func (opt descriptionOption) ApplyTLS(opts *tlsOptions) { func (opt descriptionOption) ApplyTCP(opts *tcpOptions) { opts.Description = string(opt) } - -func (opt descriptionOption) ApplyLabeled(opts *labeledOptions) { - opts.Description = string(opt) -} diff --git a/config/forwards_to.go b/internal/legacy/config/forwards_to.go similarity index 92% rename from config/forwards_to.go rename to internal/legacy/config/forwards_to.go index d4ae12bd..227cd74b 100644 --- a/config/forwards_to.go +++ b/internal/legacy/config/forwards_to.go @@ -36,10 +36,6 @@ func (fwd forwardsToOption) ApplyTLS(cfg *tlsOptions) { fwd.ApplyCommon(&cfg.commonOpts) } -func (fwd forwardsToOption) ApplyLabeled(cfg *labeledOptions) { - fwd.ApplyCommon(&cfg.commonOpts) -} - func defaultForwardsTo() string { hostname, err := os.Hostname() if err != nil { diff --git a/internal/legacy/config/http.go b/internal/legacy/config/http.go new file mode 100644 index 00000000..ecaa1cfb --- /dev/null +++ b/internal/legacy/config/http.go @@ -0,0 +1,105 @@ +package config + +import ( + "net/http" + "net/url" + + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" +) + +type HTTPEndpointOption interface { + ApplyHTTP(cfg *httpOptions) +} + +type httpOptionFunc func(cfg *httpOptions) + +func (of httpOptionFunc) ApplyHTTP(cfg *httpOptions) { + of(cfg) +} + +// HTTPEndpoint constructs a new set options for a HTTP endpoint. +// +// https://ngrok.com/docs/http/ +func HTTPEndpoint(opts ...HTTPEndpointOption) Tunnel { + cfg := httpOptions{} + for _, opt := range opts { + opt.ApplyHTTP(&cfg) + } + return &cfg +} + +type httpOptions struct { + // Common tunnel configuration options. + commonOpts + + // The scheme that this edge should use. + // Defaults to [SchemeHTTPS]. + Scheme Scheme + + // If non-nil, start a goroutine which runs this http server + // accepting connections from the http tunnel + // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. + httpServer *http.Server + + // Auto-rewrite host header on ListenAndForward? + RewriteHostHeader bool +} + +func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint { + opts := &proto.HTTPEndpoint{ + URL: cfg.URL, + } + + opts.ProxyProto = proto.ProxyProto(cfg.commonOpts.ProxyProto) + + opts.TrafficPolicy = cfg.TrafficPolicy + + return opts +} + +func (cfg httpOptions) ForwardsProto() string { + return cfg.commonOpts.ForwardsProto +} + +func (cfg httpOptions) ForwardsTo() string { + return cfg.commonOpts.getForwardsTo() +} + +func (cfg *httpOptions) WithForwardsTo(url *url.URL) { + cfg.commonOpts.ForwardsTo = url.Host +} + +func (cfg httpOptions) Extra() proto.BindExtra { + return proto.BindExtra{ + Name: cfg.Name, + Metadata: cfg.Metadata, + Description: cfg.Description, + Bindings: cfg.Bindings, + PoolingEnabled: cfg.PoolingEnabled, + } +} + +func (cfg httpOptions) Proto() string { + if cfg.Scheme == "" { + return string(SchemeHTTPS) + } + return string(cfg.Scheme) +} + +func (cfg httpOptions) Opts() any { + return cfg.toProtoConfig() +} + +func (cfg httpOptions) Labels() map[string]string { + return nil +} + +func (cfg httpOptions) HTTPServer() *http.Server { + return cfg.httpServer +} + +// compile-time check that we're implementing the proper interfaces. +var _ interface { + tunnelConfigPrivate + Tunnel +} = (*httpOptions)(nil) diff --git a/internal/legacy/config/http_handler.go b/internal/legacy/config/http_handler.go new file mode 100644 index 00000000..b7635fd9 --- /dev/null +++ b/internal/legacy/config/http_handler.go @@ -0,0 +1,8 @@ +package config + +type Options interface { + HTTPEndpointOption + TLSEndpointOption + TCPEndpointOption + CommonOption +} diff --git a/config/http_test.go b/internal/legacy/config/http_test.go similarity index 88% rename from config/http_test.go rename to internal/legacy/config/http_test.go index 5e2cf3fe..08fad7c6 100644 --- a/config/http_test.go +++ b/internal/legacy/config/http_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) func TestHTTP(t *testing.T) { diff --git a/config/metadata.go b/internal/legacy/config/metadata.go similarity index 81% rename from config/metadata.go rename to internal/legacy/config/metadata.go index b8d8a972..f42798ab 100644 --- a/config/metadata.go +++ b/internal/legacy/config/metadata.go @@ -5,7 +5,6 @@ func WithMetadata(meta string) interface { HTTPEndpointOption TCPEndpointOption TLSEndpointOption - LabeledTunnelOption } { return metadataOption(meta) } @@ -23,7 +22,3 @@ func (meta metadataOption) ApplyTCP(cfg *tcpOptions) { func (meta metadataOption) ApplyTLS(cfg *tlsOptions) { cfg.Metadata = string(meta) } - -func (meta metadataOption) ApplyLabeled(cfg *labeledOptions) { - cfg.Metadata = string(meta) -} diff --git a/internal/legacy/config/policy.go b/internal/legacy/config/policy.go new file mode 100644 index 00000000..6360ee79 --- /dev/null +++ b/internal/legacy/config/policy.go @@ -0,0 +1,28 @@ +package config + +// No imports needed + +type trafficPolicy string + +// WithTrafficPolicy configures this edge with the provided policy configuration +// passed as a json or yaml string and overwrites any previously-set traffic policy. +// https://ngrok.com/docs/http/traffic-policy +func WithTrafficPolicy(str string) interface { + HTTPEndpointOption + TLSEndpointOption + TCPEndpointOption +} { + return trafficPolicy(str) +} + +func (p trafficPolicy) ApplyTLS(opts *tlsOptions) { + opts.TrafficPolicy = string(p) +} + +func (p trafficPolicy) ApplyHTTP(opts *httpOptions) { + opts.TrafficPolicy = string(p) +} + +func (p trafficPolicy) ApplyTCP(opts *tcpOptions) { + opts.TrafficPolicy = string(p) +} diff --git a/config/pooling_enabled.go b/internal/legacy/config/pooling_enabled.go similarity index 100% rename from config/pooling_enabled.go rename to internal/legacy/config/pooling_enabled.go diff --git a/config/proxy_proto.go b/internal/legacy/config/proxy_proto.go similarity index 100% rename from config/proxy_proto.go rename to internal/legacy/config/proxy_proto.go diff --git a/config/proxy_proto_test.go b/internal/legacy/config/proxy_proto_test.go similarity index 95% rename from config/proxy_proto_test.go rename to internal/legacy/config/proxy_proto_test.go index 9eb55b10..12033181 100644 --- a/config/proxy_proto_test.go +++ b/internal/legacy/config/proxy_proto_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" _ "embed" ) diff --git a/config/scheme.go b/internal/legacy/config/scheme.go similarity index 100% rename from config/scheme.go rename to internal/legacy/config/scheme.go diff --git a/config/scheme_test.go b/internal/legacy/config/scheme_test.go similarity index 91% rename from config/scheme_test.go rename to internal/legacy/config/scheme_test.go index c12b610c..8e812d9b 100644 --- a/config/scheme_test.go +++ b/internal/legacy/config/scheme_test.go @@ -3,7 +3,7 @@ package config import ( "testing" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) func TestScheme(t *testing.T) { diff --git a/config/tcp.go b/internal/legacy/config/tcp.go similarity index 82% rename from config/tcp.go rename to internal/legacy/config/tcp.go index f422964b..7af935a6 100644 --- a/config/tcp.go +++ b/internal/legacy/config/tcp.go @@ -4,7 +4,7 @@ import ( "net/http" "net/url" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) type TCPEndpointOption interface { @@ -33,25 +33,14 @@ type tcpOptions struct { // Common tunnel configuration options. commonOpts - // The TCP address to request for this edge. - RemoteAddr string // An HTTP Server to run traffic on // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. httpServer *http.Server } -// Set the TCP address to request for this edge. -func WithRemoteAddr(addr string) TCPEndpointOption { - return tcpOptionFunc(func(cfg *tcpOptions) { - cfg.RemoteAddr = addr - }) -} - func (cfg *tcpOptions) toProtoConfig() *proto.TCPEndpoint { return &proto.TCPEndpoint{ URL: cfg.URL, - Addr: cfg.RemoteAddr, - IPRestriction: cfg.commonOpts.CIDRRestrictions.toProtoConfig(), ProxyProto: proto.ProxyProto(cfg.commonOpts.ProxyProto), TrafficPolicy: cfg.commonOpts.TrafficPolicy, } diff --git a/config/tcp_test.go b/internal/legacy/config/tcp_test.go similarity index 77% rename from config/tcp_test.go rename to internal/legacy/config/tcp_test.go index 934856bd..857032c6 100644 --- a/config/tcp_test.go +++ b/internal/legacy/config/tcp_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) func TestTCP(t *testing.T) { @@ -22,12 +22,12 @@ func TestTCP(t *testing.T) { }, { name: "remote addr", - opts: TCPEndpoint(WithRemoteAddr("0.tcp.ngrok.io:1234")), + opts: TCPEndpoint(WithURL("tcp://0.tcp.ngrok.io:1234")), expectProto: ptr("tcp"), expectLabels: nil, expectOpts: func(t *testing.T, opts *proto.TCPEndpoint) { require.NotNil(t, opts) - require.Equal(t, "0.tcp.ngrok.io:1234", opts.Addr) + require.Equal(t, "tcp://0.tcp.ngrok.io:1234", opts.URL) }, }, } diff --git a/assets/ngrok.ca.crt b/internal/legacy/config/testdata/ngrok.ca.crt similarity index 100% rename from assets/ngrok.ca.crt rename to internal/legacy/config/testdata/ngrok.ca.crt diff --git a/config/tls.go b/internal/legacy/config/tls.go similarity index 79% rename from config/tls.go rename to internal/legacy/config/tls.go index 7e91bdec..4a5c4796 100644 --- a/config/tls.go +++ b/internal/legacy/config/tls.go @@ -1,12 +1,11 @@ package config import ( - "crypto/x509" "net/http" "net/url" - "golang.ngrok.com/ngrok/internal/pb" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/pb" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) type TLSEndpointOption interface { @@ -35,17 +34,6 @@ type tlsOptions struct { // Common tunnel options commonOpts - // The fqdn to request for this edge. - Domain string - - // Note: these are "the old way", and shouldn't actually be used. Their - // setters are both deprecated. - Hostname string - Subdomain string - - // Certificates to use for client authentication at the ngrok edge. - MutualTLSCA []*x509.Certificate - // True if the TLS connection should be terminated at the ngrok edge. terminateAtEdge bool // The key to use for TLS termination at the ngrok edge in PEM format. @@ -62,17 +50,11 @@ type tlsOptions struct { func (cfg *tlsOptions) toProtoConfig() *proto.TLSEndpoint { opts := &proto.TLSEndpoint{ URL: cfg.URL, - Domain: cfg.Domain, ProxyProto: proto.ProxyProto(cfg.ProxyProto), - Subdomain: cfg.Subdomain, - Hostname: cfg.Hostname, } - opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig() opts.TrafficPolicy = cfg.commonOpts.TrafficPolicy - opts.MutualTLSAtEdge = mutualTLSEndpointOption(cfg.MutualTLSCA).toProtoConfig() - // When terminate-at-edge is set the TLSTermination must be sent even if the key and cert are nil, // this will default to the ngrok edge's automatically provisioned keypair. if cfg.terminateAtEdge { diff --git a/config/tls_test.go b/internal/legacy/config/tls_test.go similarity index 83% rename from config/tls_test.go rename to internal/legacy/config/tls_test.go index 0c88fd96..2f06c36f 100644 --- a/config/tls_test.go +++ b/internal/legacy/config/tls_test.go @@ -3,7 +3,7 @@ package config import ( "testing" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) func TestTLS(t *testing.T) { diff --git a/config/tunnel_config.go b/internal/legacy/config/tunnel_config.go similarity index 93% rename from config/tunnel_config.go rename to internal/legacy/config/tunnel_config.go index 02a3682d..57331101 100644 --- a/config/tunnel_config.go +++ b/internal/legacy/config/tunnel_config.go @@ -3,7 +3,7 @@ package config import ( "net/url" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) // Tunnel is a marker interface for options that can be used to start diff --git a/config/url.go b/internal/legacy/config/url.go similarity index 100% rename from config/url.go rename to internal/legacy/config/url.go diff --git a/internal/legacy/errors.go b/internal/legacy/errors.go new file mode 100644 index 00000000..8a6c3da6 --- /dev/null +++ b/internal/legacy/errors.go @@ -0,0 +1,164 @@ +package legacy + +import ( + "fmt" + "net/url" + "strings" +) + +// Error is an error enriched with a specific ErrorCode. +// All ngrok error codes are documented at https://ngrok.com/docs/errors. +// +// An [Error] can be extracted from a generic error using [errors.As]. +// +// Example: +// +// var nerr ngrok.Error +// if errors.As(err, &nerr) { +// fmt.Printf("%s: %s\n", nerr.ErrorCode(), nerr.Msg()) +// } +type Error interface { + error + // Msg returns the error string without the error code. + Msg() string + // ErrorCode returns the ngrok error code, if one exists. + ErrorCode() string +} + +// Errors arising from authentication failure. +type errAuthFailed struct { + // Whether the error was generated by the remote server, or in the sending + // of the authentication request. + Remote bool + // The underlying error. + Inner error +} + +func (e errAuthFailed) Error() string { + var msg string + if e.Remote { + msg = "authentication failed" + } else { + msg = "failed to send authentication request" + } + + return fmt.Sprintf("%s: %v", msg, e.Inner) +} + +func (e errAuthFailed) Unwrap() error { + return e.Inner +} + +func (e errAuthFailed) Is(target error) bool { + _, ok := target.(errAuthFailed) + return ok +} + +// The error returned by [Tunnel]'s [net.Listener.Accept] method. +type errAcceptFailed struct { + // The underlying error. + Inner error +} + +func (e errAcceptFailed) Error() string { + return fmt.Sprintf("failed to accept connection: %v", e.Inner) +} + +func (e errAcceptFailed) Unwrap() error { + return e.Inner +} + +func (e errAcceptFailed) Is(target error) bool { + _, ok := target.(errAcceptFailed) + return ok +} + +// Errors arising from a failure to start a tunnel. +type errListen struct { + // The underlying error. + Inner error +} + +func (e errListen) Error() string { + return fmt.Sprintf("failed to start tunnel: %v", e.Inner) +} + +func (e errListen) Unwrap() error { + return e.Inner +} + +func (e errListen) Is(target error) bool { + _, ok := target.(errListen) + return ok +} + +// Errors arising from a failure to construct a [golang.org/x/net/proxy.Dialer] from a [url.URL]. +type errProxyInit struct { + // The provided proxy URL. + URL *url.URL + // The underlying error. + Inner error +} + +func (e errProxyInit) Error() string { + return fmt.Sprintf("failed to construct proxy dialer from \"%s\": %v", e.URL.String(), e.Inner) +} + +func (e errProxyInit) Unwrap() error { + return e.Inner +} + +func (e errProxyInit) Is(target error) bool { + _, ok := target.(errProxyInit) + return ok +} + +// Error arising from a failure to dial the ngrok server. +type errSessionDial struct { + // The address to which a connection was attempted. + Addr string + // The underlying error. + Inner error +} + +func (e errSessionDial) Error() string { + return fmt.Sprintf("failed to dial ngrok server with address \"%s\": %v", e.Addr, e.Inner) +} + +func (e errSessionDial) Unwrap() error { + return e.Inner +} + +func (e errSessionDial) Is(target error) bool { + _, ok := target.(errSessionDial) + return ok +} + +// Generic ngrok error that requires no parsing +type ngrokError struct { + Message string + ErrCode string +} + +const errUrl = "https://ngrok.com/docs/errors" + +func (m ngrokError) Error() string { + out := m.Message + if m.ErrCode != "" { + out = fmt.Sprintf("%s\n\n%s/%s", out, errUrl, strings.ToLower(m.ErrCode)) + } + return out +} + +func (m ngrokError) Msg() string { + return m.Message +} + +func (m ngrokError) ErrorCode() string { + return m.ErrCode +} + +func (e ngrokError) Is(target error) bool { + _, ok := target.(ngrokError) + return ok +} diff --git a/errors_test.go b/internal/legacy/errors_test.go similarity index 95% rename from errors_test.go rename to internal/legacy/errors_test.go index 1a562f8b..bad6fd3d 100644 --- a/errors_test.go +++ b/internal/legacy/errors_test.go @@ -1,4 +1,4 @@ -package ngrok +package legacy import ( "errors" @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) // Sanity check for the approach to error construction/wrapping diff --git a/internal/legacy/logging.go b/internal/legacy/logging.go new file mode 100644 index 00000000..23ed5fa2 --- /dev/null +++ b/internal/legacy/logging.go @@ -0,0 +1,56 @@ +package legacy + +import ( + "github.com/inconshreveable/log15/v3" + "log/slog" +) + +// SlogToLog15 converts a slog.Logger to a log15.Logger for use with the legacy package +func SlogToLog15(slogger *slog.Logger) log15.Logger { + logger := log15.New() + logger.SetHandler(&slogHandler{logger: slogger}) + return logger +} + +// slogHandler implements log15.Handler interface to adapt slog.Logger +type slogHandler struct { + logger *slog.Logger +} + +// Log implements log15.Handler interface +func (h *slogHandler) Log(r log15.Record) error { + // Convert log15 context to slog attributes + attrs := make([]any, 0, len(r.Ctx)) + for i := 0; i < len(r.Ctx); i += 2 { + if i+1 < len(r.Ctx) { + key, ok := r.Ctx[i].(string) + if !ok { + key = "unknown_key" + } + attrs = append(attrs, key, r.Ctx[i+1]) + } + } + + // Map log15 levels to slog levels + switch r.Lvl { + case log15.LvlCrit, log15.LvlError: + h.logger.Error(r.Msg, attrs...) + case log15.LvlWarn: + h.logger.Warn(r.Msg, attrs...) + case log15.LvlInfo: + h.logger.Info(r.Msg, attrs...) + case log15.LvlDebug, log15.LvlDebug + 1: // Handle trace level too + h.logger.Debug(r.Msg, attrs...) + default: + h.logger.Info(r.Msg, append(attrs, "original_level", r.Lvl)...) + } + + return nil +} + +// defaultLogger returns a no-op logger that discards all messages +func defaultLogger() log15.Logger { + logger := log15.New() + logger.SetHandler(log15.DiscardHandler()) + return logger +} diff --git a/config/testdata/ngrok.ca.crt b/internal/legacy/ngrok.ca.crt similarity index 100% rename from config/testdata/ngrok.ca.crt rename to internal/legacy/ngrok.ca.crt diff --git a/internal/legacy/online_test.go b/internal/legacy/online_test.go new file mode 100644 index 00000000..05e35179 --- /dev/null +++ b/internal/legacy/online_test.go @@ -0,0 +1,353 @@ +package legacy + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "testing" + "time" + + "github.com/inconshreveable/log15/v3" + "github.com/stretchr/testify/require" + "golang.org/x/net/websocket" + + "golang.ngrok.com/ngrok/v2/internal/legacy/config" +) + +// testLogger is a simple wrapper around log15.Logger that logs to the test's output +func newTestLogger(t *testing.T) log15.Logger { + logger := log15.New() + // Create a custom handler that writes to test output + handler := log15.FuncHandler(func(r log15.Record) error { + // Add test name to context + ctx := append(r.Ctx, "test", t.Name()) + + // Format and log to test output + t.Logf("%s [%s] %s %v", + time.Now().Format(time.RFC3339), + r.Lvl, + r.Msg, + fmtLogContext(ctx)) + return nil + }) + + logger.SetHandler(handler) + return logger +} + +// Helper to format log context as a map for display +func fmtLogContext(ctx []interface{}) map[string]interface{} { + result := make(map[string]interface{}) + for i := 0; i < len(ctx); i += 2 { + if i+1 < len(ctx) { + key, ok := ctx[i].(string) + if !ok { + key = fmt.Sprintf("%v", ctx[i]) + } + result[key] = ctx[i+1] + } + } + return result +} + +func expectChanError(t *testing.T, ch <-chan error, timeout time.Duration) { + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case err := <-ch: + require.Error(t, err) + case <-timer.C: + t.Error("timeout while waiting on error channel") + } +} + +func skipUnless(t *testing.T, varname string, message ...any) { + if os.Getenv(varname) == "" && os.Getenv("NGROK_TEST_ALL") == "" { + t.Skip(message...) + } +} + +func onlineTest(t *testing.T) { + skipUnless(t, "NGROK_TEST_ONLINE", "Skipping online test") + // This is an annoying quirk of the free account limitations. It looks like + // the tests run quickly enough in series that they trigger simultaneous + // session errors for free accounts. "Something something eventual + // consistency" most likely. + require.NotEmpty(t, os.Getenv("NGROK_AUTHTOKEN"), "Online tests require an authtoken.") +} + +func setupSession(ctx context.Context, t *testing.T, opts ...ConnectOption) Session { + onlineTest(t) + opts = append(opts, WithAuthtoken(os.Getenv("NGROK_AUTHTOKEN")), WithLogger(newTestLogger(t))) + sess, err := Connect(ctx, opts...) + require.NoError(t, err, "Session Connect") + return sess +} + +func startTunnel(ctx context.Context, t *testing.T, sess Session, opts config.Tunnel) Tunnel { + onlineTest(t) + tun, err := sess.Listen(ctx, opts) + require.NoError(t, err, "Listen") + return tun +} + +var helloHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + _ = r.Body.Close() + _, _ = fmt.Fprintln(rw, "Hello, world!") +}) + +func serveHTTP(ctx context.Context, t *testing.T, connectOpts []ConnectOption, opts config.Tunnel, handler http.Handler) (Tunnel, <-chan error) { + sess := setupSession(ctx, t, connectOpts...) + + tun := startTunnel(ctx, t, sess, opts) + exited := make(chan error) + + go func() { + exited <- http.Serve(tun, handler) + sess.Close() + }() + return tun, exited +} + +func TestTunnel(t *testing.T) { + ctx := context.Background() + sess := setupSession(ctx, t) + + tun := startTunnel(ctx, t, sess, config.HTTPEndpoint( + config.WithMetadata("Hello, world!"), + config.WithForwardsTo("some application"))) + + require.NotEmpty(t, tun.URL(), "Tunnel URL") + require.Equal(t, "Hello, world!", tun.Metadata()) + require.Equal(t, "some application", tun.ForwardsTo()) + tun.Close() + sess.Close() +} + +func TestTunnelConnMetadata(t *testing.T) { + ctx := context.Background() + sess := setupSession(ctx, t) + + tun := startTunnel(ctx, t, sess, config.HTTPEndpoint()) + + go func() { + resp, _ := http.Get(tun.URL()) + if resp != nil { + _ = resp.Body.Close() + } + }() + + conn, err := tun.Accept() + require.NoError(t, err) + + proxyconn, ok := conn.(Conn) + require.True(t, ok, "conn doesn't implement proxy conn interface") + + require.Equal(t, "https", proxyconn.Proto()) + tun.Close() + sess.Close() +} + +// *testing.T wrapper to force `require` to Fail() then panic() rather than +// FailNow(). Permits better flow control in test functions. +type failPanic struct { + t *testing.T +} + +func (f failPanic) Errorf(format string, args ...interface{}) { + f.t.Errorf(format, args...) +} + +func (f failPanic) FailNow() { + f.t.Fail() + panic("test failed") +} + +func TestTCP(t *testing.T) { + onlineTest(t) + ctx := context.Background() + + opts := config.TCPEndpoint() + + // Easier to test by pretending it's HTTP on this end. + tun, exited := serveHTTP(ctx, t, nil, opts, helloHandler) + + url, err := url.Parse(tun.URL()) + require.NoError(t, err) + url.Scheme = "http" + resp, err := http.Get(url.String()) + require.NoError(t, err, "GET tunnel url") + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Read response body") + + require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") + + require.NoError(t, tun.CloseWithContext(ctx)) + expectChanError(t, exited, 5*time.Second) +} + +func TestConnectionCallbacks(t *testing.T) { + // Don't run this one by default - it's timing-sensitive and prone to flakes + skipUnless(t, "NGROK_TEST_FLAKEY", "Skipping flakey network test") + + ctx := context.Background() + connects := 0 + disconnectErrs := 0 + disconnectNils := 0 + sess := setupSession(ctx, t, + WithConnectHandler(func(ctx context.Context, sess Session) { + connects++ + }), + WithDisconnectHandler(func(ctx context.Context, sess Session, err error) { + if err == nil { + disconnectNils++ + } else { + disconnectErrs++ + } + }), + WithDialer(&sketchyDialer{1 * time.Second})) + + time.Sleep(2*time.Second + 500*time.Millisecond) + + _ = sess.Close() + + time.Sleep(2 * time.Second) + + require.Equal(t, 3, connects, "should've seen some connect events") + require.Equal(t, 3, disconnectErrs, "should've seen some errors from disconnecting") + require.Equal(t, 1, disconnectNils, "should've seen a final nil from disconnecting") +} + +type sketchyDialer struct { + limit time.Duration +} + +func (sd *sketchyDialer) Dial(network, addr string) (net.Conn, error) { + return sd.DialContext(context.Background(), network, addr) +} + +func (sd *sketchyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + go func() { + time.Sleep(sd.limit) + conn.Close() + }() + return conn, err +} + +func TestHeartbeatCallback(t *testing.T) { + // Don't run this one by default - it's long + skipUnless(t, "NGROK_TEST_LONG", "Skipping long network test") + + ctx := context.Background() + heartbeats := 0 + sess := setupSession(ctx, t, + WithHeartbeatHandler(func(ctx context.Context, sess Session, latency time.Duration) { + heartbeats++ + }), + WithHeartbeatInterval(10*time.Second)) + + time.Sleep(20*time.Second + 500*time.Millisecond) + + _ = sess.Close() + + require.Equal(t, 2, heartbeats, "should've seen some heartbeats") +} + +func TestPermanentErrors(t *testing.T) { + onlineTest(t) + var err error + ctx := context.Background() + token := os.Getenv("NGROK_AUTHTOKEN") + + sess, err := Connect(ctx, WithAuthtoken(token)) + require.NoError(t, err) + sess.Close() + + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + _, err = Connect(timeoutCtx, WithServer("127.0.0.234:123"), WithAuthtoken(token)) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestRetryableErrors(t *testing.T) { + onlineTest(t) + var err error + // Set global context with a longer timeout just to prevent test from hanging forever + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + // Create a custom dialer with short timeout for invalid addresses + dialer := &net.Dialer{Timeout: 500 * time.Millisecond} + + // give up on connecting after first attempt + disconnect := WithDisconnectHandler(func(_ context.Context, sess Session, disconnectErr error) { + sess.Close() + }) + connect := WithConnectHandler(func(_ context.Context, sess Session) { + sess.Close() + }) + + _, err = Connect(ctx, WithServer("127.0.0.234:123"), WithDialer(dialer), connect, disconnect) + var dialErr errSessionDial + require.ErrorIs(t, err, dialErr) + require.ErrorAs(t, err, &dialErr) + + _, err = Connect(ctx, WithAuthtoken("invalid-token"), WithDialer(dialer), connect, disconnect) + var authErr errAuthFailed + require.ErrorIs(t, err, authErr) + require.ErrorAs(t, err, &authErr) + require.True(t, authErr.Remote) +} + +func TestNonExported(t *testing.T) { + ctx := context.Background() + + sess := setupSession(ctx, t) + + require.NotEmpty(t, sess.(interface{ Region() string }).Region()) +} + +func echo(ws *websocket.Conn) { + _, _ = io.Copy(ws, ws) +} + +func TestWebsockets(t *testing.T) { + onlineTest(t) + + ctx := context.Background() + + srv := &http.ServeMux{} + srv.Handle("/", helloHandler) + srv.Handle("/ws", websocket.Handler(echo)) + + tun, errCh := serveHTTP(ctx, t, nil, config.HTTPEndpoint(config.WithScheme(config.SchemeHTTPS)), srv) + + tunnelURL, err := url.Parse(tun.URL()) + require.NoError(t, err) + + conn, err := websocket.Dial(fmt.Sprintf("wss://%s/ws", tunnelURL.Hostname()), "", tunnelURL.String()) + require.NoError(t, err) + + go func() { + _, _ = fmt.Fprintln(conn, "Hello, world!") + }() + + bufConn := bufio.NewReader(conn) + out, err := bufConn.ReadString('\n') + require.NoError(t, err) + require.Equal(t, "Hello, world!\n", out) + + conn.Close() + tun.Close() + + require.Error(t, <-errCh) +} diff --git a/internal/legacy/session.go b/internal/legacy/session.go new file mode 100644 index 00000000..f3c51b3e --- /dev/null +++ b/internal/legacy/session.go @@ -0,0 +1,872 @@ +package legacy + +import ( + "context" + "crypto/tls" + "crypto/x509" + _ "embed" // nolint + "errors" + "fmt" + "net" + "net/http" + "net/url" + "regexp" + "runtime" + "strings" + "sync/atomic" + "time" + + "github.com/inconshreveable/log15/v3" + "go.uber.org/multierr" + "golang.org/x/net/proxy" + + "golang.ngrok.com/ngrok/v2/internal/legacy/config" + + "golang.ngrok.com/muxado/v2" + tunnel_client "golang.ngrok.com/ngrok/v2/internal/tunnel/client" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" +) + +// The ngrok library version. +// +//go:embed VERSION +var libraryAgentVersion string + +// Session encapsulates an established session with the ngrok service. Sessions +// recover from network failures by automatically reconnecting. +type Session interface { + // Listen creates a new Tunnel which will listen for new inbound + // connections. The returned Tunnel object is a net.Listener. + Listen(ctx context.Context, cfg config.Tunnel) (Tunnel, error) + + // Warnings returns a list of warnings generated for the session on connect/auth + Warnings() []error + + // Close ends the ngrok session. All Tunnel objects created by Listen + // on this session will be closed. + Close() error +} + +//go:embed ngrok.ca.crt +var defaultCACert []byte + +const defaultServer = "connect.ngrok-agent.com:443" + +var leastLatencyServer = regexp.MustCompile(`^connect\.([a-z]+?-)?ngrok-agent\.com(\.lan)?:443`) + +// Dialer is the interface a custom connection dialer must implement for use +// with the [WithDialer] option. +type Dialer interface { + // Connect to an address on the named network. + // See the documentation for net.Dial. + Dial(network, address string) (net.Conn, error) + // Connect to an address on the named network with the provided + // context. + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// SessionConnectHandler is the callback type for [WithConnectHandler] +type SessionConnectHandler func(ctx context.Context, sess Session) + +// SessionDisconnectHandler is the callback type for [WithDisconnectHandler] +type SessionDisconnectHandler func(ctx context.Context, sess Session, err error) + +// SessionHeartbeatHandler is the callback type for [WithHearbeatHandler] +type SessionHeartbeatHandler func(ctx context.Context, sess Session, latency time.Duration) + +// ServerCommandHandler is the callback type for [WithStopHandler] +type ServerCommandHandler func(ctx context.Context, sess Session) error + +// ConnectOption is passed to [Connect] to customize session connection and establishment. +type ConnectOption func(*connectConfig) + +type clientInfo struct { + Type string + Version string + Comments []string +} + +var bannedUAchar = regexp.MustCompile("[^!#$%&'*+-.^_`|~0-9a-zA-Z]") + +// Formats client info as a well-formed user agent string +func (c *clientInfo) ToUserAgent() string { + comment := "" + if len(c.Comments) > 0 { + comment = fmt.Sprintf(" (%s)", strings.Join(c.Comments, "; ")) + } + return sanitizeUserAgentString(c.Type) + "/" + sanitizeUserAgentString(c.Version) + comment +} + +func sanitizeUserAgentString(s string) string { + s = strings.ReplaceAll(s, "/", "-") + s = bannedUAchar.ReplaceAllString(s, "#") + return s +} + +// version, type, user-agent +func generateUserAgent(cs []clientInfo) string { + var uas []string + + for _, c := range cs { + uas = append(uas, c.ToUserAgent()) + } + + return strings.Join(uas, " ") +} + +// Options to use when establishing the ngrok session. +type connectConfig struct { + // Your ngrok Authtoken. + Authtoken proto.ObfuscatedString + // The address of the ngrok server to connect to. + // Defaults to `connect.ngrok-agent.com:443` + ServerAddr string + // The optional addresses of the additional ngrok servers to connect to. + AdditionalServerAddrs []string + // Enable using multiple session legs + EnableMultiLeg bool + // The [tls.Config] used when connecting to the ngrok server + TLSConfigCustomizer func(*tls.Config) + // The [x509.CertPool] used to authenticate the ngrok server certificate. + CAPool *x509.CertPool + + // The [Dialer] used to establish the initial TCP connection to the ngrok + // server. + // If set, takes precedence over the ProxyURL setting. + // If not set, defaults to a [net.Dialer]. + Dialer Dialer + + // The URL of a proxy to use when making the TCP connection to the ngrok + // server. + // Any proxy supported by [golang.org/x/net/proxy] may be used. + ProxyURL *url.URL + + // Opaque metadata string to be associated with the session. + // Viewable from the ngrok dashboard or API. + Metadata string + + // Child client types and versions used to identify specific applications + // using this library to the ngrok service. + ClientInfo []clientInfo + + // HeartbeatInterval determines how often we send application level + // heartbeats to the server go check connection liveness. + HeartbeatInterval time.Duration + // HeartbeatTolerance is the duration after which an unacknowledged + // heartbeat is determined to mean the connection is dead. + HeartbeatTolerance time.Duration + + ConnectHandler SessionConnectHandler + DisconnectHandler SessionDisconnectHandler + HeartbeatHandler SessionHeartbeatHandler + + StopHandler ServerCommandHandler + RestartHandler ServerCommandHandler + UpdateHandler ServerCommandHandler + + remoteStopErr *string + remoteRestartErr *string + remoteUpdateErr *string + + // The logger for the session to use. + Logger log15.Logger +} + +// WithMetadata configures the opaque, machine-readable metadata string for this +// session. Metadata is made available to you in the ngrok dashboard and the +// Agents API resource. It is a useful way to allow you to uniquely identify +// sessions. We suggest encoding the value in a structured format like JSON. +// +// See the [metadata parameter in the ngrok docs] for additional details. +// +// [metadata parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#metadata +func WithMetadata(meta string) ConnectOption { + return func(cfg *connectConfig) { + cfg.Metadata = meta + } +} + +// WithClientInfo configures client type and version information for applications +// built on this library. This is a way for consumers of this library to identify +// themselves to the ngrok service. +// +// This will add a new entry to the `User-Agent` field in the "most significant" +// (first) position. +func WithClientInfo(clientType, version string, comments ...string) ConnectOption { + return func(cfg *connectConfig) { + cfg.ClientInfo = append([]clientInfo{{clientType, version, comments}}, cfg.ClientInfo...) + } +} + +// WithDialer configures the session to use the provided [Dialer] when +// establishing a connection to the ngrok service. This option will cause +// [WithProxyURL] to be ignored. +func WithDialer(dialer Dialer) ConnectOption { + return func(cfg *connectConfig) { + cfg.Dialer = dialer + } +} + +// WithAuthtoken configures the session to authenticate with the provided +// authtoken. You can [find your existing authtoken] or [create a new one] in the ngrok dashboard. +// +// See the [authtoken parameter in the ngrok docs] for additional details. +// +// [find your existing authtoken]: https://dashboard.ngrok.com/get-started/your-authtoken +// [create a new one]: https://dashboard.ngrok.com/tunnels/authtokens +// [authtoken parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#authtoken +func WithAuthtoken(token string) ConnectOption { + return func(cfg *connectConfig) { + cfg.Authtoken = proto.ObfuscatedString(token) + } +} + +// WithServer configures the network address to dial to connect to the ngrok +// service. Use this option only if you are connecting to a custom agent +// ingress. +// +// See the [server_addr parameter in the ngrok docs] for additional details. +// +// [server_addr parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#server_addr +func WithServer(addr string) ConnectOption { + return func(cfg *connectConfig) { + cfg.ServerAddr = addr + } +} + +// WithMultiLeg as true allows connecting to the ngrok service on secondary legs. +func WithMultiLeg(enable bool) ConnectOption { + return func(cfg *connectConfig) { + cfg.EnableMultiLeg = enable + } +} + +// WithTLSConfig allows customization of the TLS connection made from the agent +// to the ngrok service. Customization is applied after the [WithServer] and +// [WithCA] options are applied. +func WithTLSConfig(tlsCustomizer func(*tls.Config)) ConnectOption { + return func(cfg *connectConfig) { + cfg.TLSConfigCustomizer = tlsCustomizer + } +} + +// WithCA configures the CAs used to validate the TLS certificate returned by +// the ngrok service while establishing the session. Use this option only if +// you are connecting through a man-in-the-middle or deep packet inspection +// proxy. +// +// See the [root_cas parameter in the ngrok docs] for additional details. +// +// [root_cas parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#root_cas +func WithCA(pool *x509.CertPool) ConnectOption { + return func(cfg *connectConfig) { + cfg.CAPool = pool + } +} + +// WithHeartbeatTolerance configures the duration to wait for a response to a heartbeat +// before assuming the session connection is dead and attempting to reconnect. +// +// See the [heartbeat_tolerance parameter in the ngrok docs] for additional details. +// +// [heartbeat_tolerance parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#heartbeat_tolerance +func WithHeartbeatTolerance(tolerance time.Duration) ConnectOption { + return func(cfg *connectConfig) { + cfg.HeartbeatTolerance = tolerance + } +} + +// WithHeartbeatInterval configures how often the session will send heartbeat +// messages to the ngrok service to check session liveness. +// +// See the [heartbeat_interval parameter in the ngrok docs] for additional details. +// +// [heartbeat_interval parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#heartbeat_interval +func WithHeartbeatInterval(interval time.Duration) ConnectOption { + return func(cfg *connectConfig) { + cfg.HeartbeatInterval = interval + } +} + +// WithLogger configures a logger to receive log messages from the [Session]. +// Accepts a log15.Logger directly. +func WithLogger(logger log15.Logger) ConnectOption { + return func(cfg *connectConfig) { + cfg.Logger = logger + } +} + +// WithConnectHandler configures a function which is called each time the ngrok +// [Session] successfully connects to the ngrok service. Use this option to +// receive events when the [Session] successfully connects or reconnects after +// a disconnection due to network failure. +func WithConnectHandler(handler SessionConnectHandler) ConnectOption { + return func(cfg *connectConfig) { + cfg.ConnectHandler = handler + } +} + +// WithDisconnectHandler configures a function which is called each time the +// ngrok [Session] disconnects from the ngrok service. Use this option to detect +// when the ngrok session has gone temporarily offline. +// +// This handler will be called every time the [Session] encounters an error during +// or after connection. It may be called multiple times in a row; it may be +// called before any Connect handler is called and before [Connect] returns. +// +// If this function is called with a nil error, the [Session] has stopped and will +// not reconnect, usually due to [Session.Close] being called. +func WithDisconnectHandler(handler SessionDisconnectHandler) ConnectOption { + return func(cfg *connectConfig) { + cfg.DisconnectHandler = handler + } +} + +// WithHeartbeatHandler configures a function which is called each time the +// [Session] successfully heartbeats the ngrok service. The callback receives +// the latency of the round trip time from initiating the heartbeat to +// receiving an acknowledgement back from the ngrok service. +func WithHeartbeatHandler(handler SessionHeartbeatHandler) ConnectOption { + return func(cfg *connectConfig) { + cfg.HeartbeatHandler = handler + } +} + +// WithStopHandler configures a function which is called when the ngrok service +// requests that this [Session] stops. Your application may choose to interpret +// this callback as a request to terminate the [Session] or the entire process. +// +// Errors returned by this function will be visible to the ngrok dashboard or +// API as the response to the Stop operation. +// +// Do not block inside this callback. It will cause the Dashboard or API Stop +// operation to hang. Do not call [Session].Close or [os.Exit] inside this +// callback, it will also cause the operation to hang. +// +// Instead, either return an error or if you intend to Stop, spawn a goroutine +// to asynchronously call [Session].Close or [os.Exit]. +func WithStopHandler(handler ServerCommandHandler) ConnectOption { + return func(cfg *connectConfig) { + cfg.StopHandler = handler + } +} + +// WithRestartHandler configures a function which is called when the ngrok service +// requests that this [Session] restarts. Your application may choose to interpret +// this callback as a request to reconnect the [Session] or restart the entire process. +// +// Errors returned by this function will be visible to the ngrok dashboard or +// API as the response to the Restart operation. +// +// Do not block inside this callback. It will cause the Dashboard or API Restart +// operation to hang. Do not call [Session].Close or [os.Exit] inside this +// callback, it will also cause the operation to hang. +// +// Instead, either spawn a goroutine to asynchronously restart, or return an error. +func WithRestartHandler(handler ServerCommandHandler) ConnectOption { + return func(cfg *connectConfig) { + cfg.RestartHandler = handler + } +} + +// WithUpdateHandler configures a function which is called when the ngrok service +// requests that the application running this [Session] updates. Your application +// may use this callback to trigger a check for a newer version followed by an update +// and restart if one exists. +// +// Errors returned by this function will be visible to the ngrok dashboard or +// API as the response to the Update operation. +// +// Do not block inside this callback. It will cause the Dashboard or API Update +// operation to hang. Do not call [Session].Close or [os.Exit] inside this +// callback, it will also cause the operation to hang. +// +// Instead, spawn a goroutine to asynchronously handle the update process +// or return an error if there is no newer version to update to. +func WithUpdateHandler(handler ServerCommandHandler) ConnectOption { + return func(cfg *connectConfig) { + cfg.UpdateHandler = handler + } +} + +// Connect begins a new ngrok [Session] by connecting to the ngrok service, +// retrying transient failures if they occur. +// +// Connect blocks until the session is successfully established or fails with +// an error that will not be retried. Customize session connection behavior +// with [ConnectOption] arguments. +func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) { + logger := defaultLogger() + + cfg := connectConfig{} + for _, o := range opts { + o(&cfg) + } + + if cfg.Logger != nil { + logger = cfg.Logger + } + + if cfg.CAPool == nil { + cfg.CAPool = x509.NewCertPool() + cfg.CAPool.AppendCertsFromPEM(defaultCACert) + } + + if cfg.ServerAddr == "" { + cfg.ServerAddr = defaultServer + } + + var dialer Dialer + + if cfg.Dialer != nil { + dialer = cfg.Dialer + } else { + netDialer := &net.Dialer{} + + if cfg.ProxyURL != nil { + proxied, err := proxy.FromURL(cfg.ProxyURL, netDialer) + if err != nil { + return nil, errProxyInit{cfg.ProxyURL, err} + } + dialer = proxied.(Dialer) + } else { + dialer = netDialer + } + } + + heartbeatConfig := muxado.NewHeartbeatConfig() + if cfg.HeartbeatTolerance != 0 { + heartbeatConfig.Tolerance = cfg.HeartbeatTolerance + } + if cfg.HeartbeatInterval != 0 { + heartbeatConfig.Interval = cfg.HeartbeatInterval + } + + session := new(sessionImpl) + + stateChanges := make(chan error, 32) + + callbackHandler := remoteCallbackHandler{ + Logger: logger, + sess: session, + stopHandler: cfg.StopHandler, + restartHandler: cfg.RestartHandler, + updateHandler: cfg.UpdateHandler, + } + + rawDialer := func(legNumber uint32) (tunnel_client.RawSession, error) { + serverAddr := cfg.ServerAddr + if legNumber > 0 && len(cfg.AdditionalServerAddrs) >= int(legNumber) { + serverAddr = cfg.AdditionalServerAddrs[legNumber-1] + } + tlsConfig := &tls.Config{ + RootCAs: cfg.CAPool, + ServerName: strings.Split(serverAddr, ":")[0], + MinVersion: tls.VersionTLS12, + } + if cfg.TLSConfigCustomizer != nil { + cfg.TLSConfigCustomizer(tlsConfig) + } + + conn, err := dialer.DialContext(ctx, "tcp", serverAddr) + if err != nil { + return nil, errSessionDial{serverAddr, err} + } + + conn = tls.Client(conn, tlsConfig) + + sess := muxado.Client(conn, &muxado.Config{}) + return tunnel_client.NewRawSession(logger, sess, heartbeatConfig, callbackHandler), nil + } + + empty := "" + notImplemented := "the agent has not defined a callback for this operation" + + if cfg.StopHandler != nil { + cfg.remoteStopErr = &empty + } + if cfg.RestartHandler != nil { + cfg.remoteRestartErr = &empty + } + if cfg.UpdateHandler != nil { + cfg.remoteUpdateErr = &empty + } + + if cfg.remoteStopErr == nil { + cfg.remoteStopErr = ¬Implemented + } + if cfg.remoteRestartErr == nil { + cfg.remoteRestartErr = ¬Implemented + } + if cfg.remoteUpdateErr == nil { + cfg.remoteUpdateErr = ¬Implemented + } + + cfg.ClientInfo = append( + cfg.ClientInfo, + clientInfo{Type: string(proto.LibraryOfficialGo), Version: strings.TrimSpace(libraryAgentVersion)}, + ) + + userAgent := generateUserAgent(cfg.ClientInfo) + + auth := proto.AuthExtra{ + Version: cfg.ClientInfo[0].Version, + ClientType: proto.ClientType(cfg.ClientInfo[0].Type), + UserAgent: userAgent, + Authtoken: proto.ObfuscatedString(cfg.Authtoken), + Metadata: cfg.Metadata, + OS: runtime.GOOS, + Arch: runtime.GOARCH, + HeartbeatInterval: int64(heartbeatConfig.Interval), + HeartbeatTolerance: int64(heartbeatConfig.Tolerance), + + RestartUnsupportedError: cfg.remoteRestartErr, + StopUnsupportedError: cfg.remoteStopErr, + UpdateUnsupportedError: cfg.remoteUpdateErr, + } + + reconnect := func(sess tunnel_client.Session, raw tunnel_client.RawSession, legNumber uint32) (int, error) { + auth.LegNumber = legNumber + resp, err := sess.Auth(auth) + if err != nil { + remote := false + if resp.Error != "" { + remote = true + } + return 0, errAuthFailed{remote, err} + } + + if resp.Extra.DeprecationWarning != nil { + warning := resp.Extra.DeprecationWarning + vars := make([]any, 0, 3) + if warning.NextMin != "" { + vars = append(vars, "min_version", warning.NextMin) + } + if !warning.NextDate.IsZero() { + vars = append(vars, "deadline", warning.NextDate) + } + if warning.Msg != "" { + vars = append(vars, "extra", warning.Msg) + } + logger.Warn(warning.Error(), vars...) + } + + sessionInner := &sessionInner{ + Session: sess, + Region: resp.Extra.Region, + ProtoVersion: resp.Version, + ServerVersion: resp.Extra.Version, + ClientID: resp.Extra.Region, + AccountName: resp.Extra.AccountName, + PlanName: resp.Extra.PlanName, + Banner: resp.Extra.Banner, + SessionDuration: resp.Extra.SessionDuration, + DeprecationWarning: resp.Extra.DeprecationWarning, + ConnectAddresses: resp.Extra.ConnectAddresses, + Logger: logger, + } + + if legNumber == 0 { + session.setInner(sessionInner) + } + + if cfg.HeartbeatHandler != nil { + // plumb a session with the proper region to the heartbeatHandler + heartbeatSession := new(sessionImpl) + heartbeatSession.setInner(sessionInner) + go func() { + // use the raw latency channel in case this is a multi-leg session + beats := raw.Latency() + for { + select { + case <-ctx.Done(): + return + case latency, ok := <-beats: + if !ok { + return + } + cfg.HeartbeatHandler(ctx, heartbeatSession, latency) + } + } + }() + } + + auth.Cookie = resp.Extra.Cookie + + // store any connect server addresses for use in subsequent legs + if cfg.EnableMultiLeg && legNumber == 0 && len(resp.Extra.ConnectAddresses) > 1 { + overrideAdditionalServers := len(cfg.AdditionalServerAddrs) == 0 + for i, ca := range resp.Extra.ConnectAddresses { + if i == 0 { + if leastLatencyServer.MatchString(cfg.ServerAddr) { + // lock in the leg 0 region + logger.Debug("first leg using region", "region", resp.Extra.Region, "server", ca.ServerAddr) + cfg.ServerAddr = ca.ServerAddr + } + } else if overrideAdditionalServers { + cfg.AdditionalServerAddrs = append(cfg.AdditionalServerAddrs, ca.ServerAddr) + } + } + } + + // if we are using multi-leg, we need to know how many legs to connect + desiredLegs := 1 + if cfg.EnableMultiLeg { + desiredLegs = 1 + len(cfg.AdditionalServerAddrs) + } + return desiredLegs, nil + } + + sess := tunnel_client.NewReconnectingSession(logger, rawDialer, stateChanges, reconnect) + // allow consumers to .Close() the session before a successful connect + session.setInner(&sessionInner{ + Session: sess, + }) + + // performs one "pump" of the session update channel + // returns true if there are more updates to handle + runSessionHandlers := func() (bool, error) { + select { + case <-ctx.Done(): + if cfg.DisconnectHandler != nil { + cfg.DisconnectHandler(ctx, session, ctx.Err()) + logger.Info("no more state changes") + cfg.DisconnectHandler(ctx, session, nil) + } + sess.Close() + return false, ctx.Err() + case err, ok := <-stateChanges: + switch { + case !ok: // session has given up on reconnecting + if cfg.DisconnectHandler != nil { + logger.Info("no more state changes") + cfg.DisconnectHandler(ctx, session, nil) + } + sess.Close() + return false, nil + case err != nil: // session encountered an error + if cfg.DisconnectHandler != nil { + cfg.DisconnectHandler(ctx, session, err) + } + return true, err + case err == nil: // session connected successfully + if cfg.ConnectHandler != nil { + cfg.ConnectHandler(ctx, session) + } + return true, nil + } + } + + panic("inexhaustive case match when handling session state change") + } + + var errs error + for again := true; again; { + var err error + again, err = runSessionHandlers() + switch { + case again && err == nil: // successfully connected, move to goroutine and return + again = false + case again && err != nil: // error on reconnect + errs = multierr.Append(errs, err) + case !again: // gave up trying to reconnect + errs = multierr.Append(errs, err) + return nil, errs + } + } + + go func() { + for again := true; again; again, _ = runSessionHandlers() { + } + }() + + return session, nil +} + +type sessionImpl struct { + raw atomic.Pointer[sessionInner] +} + +type sessionInner struct { + tunnel_client.Session + + Region string + ProtoVersion string + ServerVersion string + ClientID string + AccountName string + PlanName string + Banner string + SessionDuration int64 + DeprecationWarning *proto.AgentVersionDeprecated + ConnectAddresses []proto.ConnectAddress + + Logger log15.Logger +} + +func (s *sessionImpl) inner() *sessionInner { + return s.raw.Load() +} + +func (s *sessionImpl) setInner(raw *sessionInner) { + s.raw.Store(raw) +} + +func (s *sessionImpl) closeTunnel(clientID string, err error) error { + return s.inner().CloseTunnel(clientID, err) +} + +func (s *sessionImpl) Close() error { + return s.inner().Close() +} + +func (s *sessionImpl) Warnings() []error { + deprecated := s.inner().DeprecationWarning + if deprecated != nil { + return []error{deprecated} + } + return nil +} + +func (s *sessionImpl) Listen(ctx context.Context, cfg config.Tunnel) (Tunnel, error) { + var ( + tunnel tunnel_client.Tunnel + err error + ) + tunnelCfg, ok := cfg.(tunnelConfigPrivate) + if !ok { + return nil, errors.New("invalid tunnel config") + } + + extra := tunnelCfg.Extra() + if tunnelCfg.Proto() != "" { + tunnel, err = s.inner().Listen(tunnelCfg.Proto(), tunnelCfg.Opts(), extra, tunnelCfg.ForwardsTo(), tunnelCfg.ForwardsProto()) + } else { + tunnel, err = s.inner().ListenLabel(tunnelCfg.Labels(), extra.Metadata, tunnelCfg.ForwardsTo(), tunnelCfg.ForwardsProto()) + } + + impl := &tunnelImpl{ + Sess: s, + Tunnel: tunnel, + } + + // Legacy support for passing HTTP server via config options. + // TODO: Remove this after we feel HTTP options via config have been deprecated. + if serverCfg, ok := cfg.(interface{ HTTPServer() *http.Server }); ok { + server := serverCfg.HTTPServer() + if server != nil { + go func() { _ = server.Serve(impl) }() + impl.server = server + } + } + + if err == nil { + return impl, nil + } + return nil, errListen{err} +} + +// The rest of the `sessionImpl` methods are non-public, but can be +// interface-asserted if they're *really* needed. These are exempt from any +// stability guarantees and subject to change without notice. + +func (s *sessionImpl) ProtoVersion() string { + return s.inner().ProtoVersion +} +func (s *sessionImpl) ServerVersion() string { + return s.inner().ServerVersion +} +func (s *sessionImpl) ClientID() string { + return s.inner().ClientID +} +func (s *sessionImpl) AccountName() string { + return s.inner().AccountName +} +func (s *sessionImpl) PlanName() string { + return s.inner().PlanName +} +func (s *sessionImpl) Banner() string { + return s.inner().Banner +} +func (s *sessionImpl) SessionDuration() int64 { + return s.inner().SessionDuration +} +func (s *sessionImpl) Region() string { + return s.inner().Region +} +func (s *sessionImpl) Heartbeat() (time.Duration, error) { + return s.inner().Heartbeat() +} +func (s *sessionImpl) Latency() <-chan time.Duration { + return s.inner().Latency() +} +func (s *sessionImpl) ConnectAddresses() []struct{ Region, ServerAddr string } { + connectAddresses := make([]struct{ Region, ServerAddr string }, len(s.inner().ConnectAddresses)) + for i, addr := range s.inner().ConnectAddresses { + connectAddresses[i] = struct{ Region, ServerAddr string }{addr.Region, addr.ServerAddr} + } + return connectAddresses +} + +type remoteCallbackHandler struct { + log15.Logger + sess *sessionImpl + stopHandler ServerCommandHandler + restartHandler ServerCommandHandler + updateHandler ServerCommandHandler +} + +func (rc remoteCallbackHandler) OnStop(_ *proto.Stop, respond tunnel_client.HandlerRespFunc) { + if rc.stopHandler != nil { + resp := new(proto.StopResp) + close := true + if err := rc.stopHandler(context.TODO(), rc.sess); err != nil { + close = false + resp.Error = err.Error() + } + if err := respond(resp); err != nil { + rc.Warn("error responding to stop request", "error", err) + } + if close { + _ = rc.sess.Close() + } + } +} + +func (rc remoteCallbackHandler) OnRestart(_ *proto.Restart, respond tunnel_client.HandlerRespFunc) { + if rc.restartHandler != nil { + resp := new(proto.RestartResp) + close := true + if err := rc.restartHandler(context.TODO(), rc.sess); err != nil { + close = false + resp.Error = err.Error() + } + if err := respond(resp); err != nil { + rc.Warn("error responding to restart request", "error", err) + } + if close { + _ = rc.sess.Close() + } + } +} + +func (rc remoteCallbackHandler) OnUpdate(_ *proto.Update, respond tunnel_client.HandlerRespFunc) { + if rc.updateHandler != nil { + resp := new(proto.UpdateResp) + if err := rc.updateHandler(context.TODO(), rc.sess); err != nil { + resp.Error = err.Error() + } + if err := respond(resp); err != nil { + rc.Warn("error responding to restart request", "error", err) + } + } +} + +func (rc remoteCallbackHandler) OnStopTunnel(stopTunnel *proto.StopTunnel, respond tunnel_client.HandlerRespFunc) { + ngrokErr := &ngrokError{Message: stopTunnel.Message, ErrCode: stopTunnel.ErrorCode} + // close the tunnel and maintain the session + err := rc.sess.closeTunnel(stopTunnel.ClientID, ngrokErr) + if err != nil { + rc.Warn("error closing tunnel", "error", err) + } +} diff --git a/session_test.go b/internal/legacy/session_test.go similarity index 97% rename from session_test.go rename to internal/legacy/session_test.go index 409c6dcf..6d450eae 100644 --- a/session_test.go +++ b/internal/legacy/session_test.go @@ -1,4 +1,4 @@ -package ngrok +package legacy import ( "testing" diff --git a/tunnel.go b/internal/legacy/tunnel.go similarity index 54% rename from tunnel.go rename to internal/legacy/tunnel.go index ca7a6387..6e37a182 100644 --- a/tunnel.go +++ b/internal/legacy/tunnel.go @@ -1,15 +1,12 @@ -package ngrok +package legacy import ( "context" "net" "net/http" - "net/url" "time" - "golang.ngrok.com/ngrok/config" - tunnel_client "golang.ngrok.com/ngrok/internal/tunnel/client" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + tunnel_client "golang.ngrok.com/ngrok/v2/internal/tunnel/client" ) // Tunnel is a [net.Listener] created by a call to [Listen] or @@ -60,86 +57,6 @@ type TunnelInfo interface { URL() string } -// Listen creates a new [Tunnel] after connecting a new [Session]. This is a -// shortcut for calling [Connect] then [Session].Listen. -// -// Access to the underlying [Session] that was started automatically can be -// accessed via [Tunnel].Session. -// -// If an error is encountered during [Session].Listen, the [Session] object that -// was created will be closed automatically. -func Listen(ctx context.Context, tunnelConfig config.Tunnel, connectOpts ...ConnectOption) (Tunnel, error) { - sess, err := Connect(ctx, connectOpts...) - if err != nil { - return nil, err - } - tunnel, err := sess.Listen(ctx, tunnelConfig) - if err != nil { - _ = sess.Close() - return nil, err - } - return tunnel, nil -} - -// ListenAndForward creates a new [Forwarder] after connecting a new [Session], and -// then forwards all connections to the provided URL. -// This is a shortcut for calling [Connect] then [Session].ListenAndForward. -// -// Access to the underlying [Session] that was started automatically can be -// accessed via [Forwarder].Session. -// -// If an error is encountered during [Session].ListenAndForward, the [Session] -// object that was created will be closed automatically. -func ListenAndForward(ctx context.Context, backend *url.URL, tunnelConfig config.Tunnel, connectOpts ...ConnectOption) (Forwarder, error) { - sess, err := Connect(ctx, connectOpts...) - if err != nil { - return nil, err - } - fwd, err := sess.ListenAndForward(ctx, backend, tunnelConfig) - if err != nil { - _ = sess.Close() - return nil, err - } - - return fwd, nil -} - -// ListenAndServeHTTP creates a new [Forwarder] after connecting a new [Session], and -// then forwards all connections to the provided HTTP server. -// This is a shortcut for calling [Connect] then [Session].ListenAndForward. -// -// Access to the underlying [Session] that was started automatically can be -// accessed via [Tunnel].Session. -// -// If an error is encountered during [Session].ListenAndServeHTTP, the [Session] -// object that was created will be closed automatically. -func ListenAndServeHTTP(ctx context.Context, server *http.Server, tunnelConfig config.Tunnel, connectOpts ...ConnectOption) (Forwarder, error) { - sess, err := Connect(ctx, connectOpts...) - if err != nil { - return nil, err - } - - forwarder, err := sess.ListenAndServeHTTP(ctx, tunnelConfig, server) - if err != nil { - _ = sess.Close() - return nil, err - } - - return forwarder, nil -} - -// ListenAndHandleHTTP creates a new [Forwarder] after connecting a new [Session], and -// then forwards all connections to a new HTTP server and handles them with the provided HTTP handler. -// -// Access to the underlying [Session] that was started automatically can be -// accessed via [Tunnel].Session. -// -// If an error is encountered during [Session].ListenAndHandleHTTP, the [Session] -// object that was created will be closed automatically. -func ListenAndHandleHTTP(ctx context.Context, handler *http.Handler, tunnelConfig config.Tunnel, connectOpts ...ConnectOption) (Forwarder, error) { - return ListenAndServeHTTP(ctx, &http.Server{Handler: *handler}, tunnelConfig, connectOpts...) -} - type tunnelImpl struct { Sess Session Tunnel tunnel_client.Tunnel @@ -231,24 +148,12 @@ type Conn interface { net.Conn // Proto returns the tunnel protocol (http, https, tls, or tcp) for this connection. Proto() string - // EdgeType returns the type of the edge (https, tls, or tcp) that matched this tunnel. - EdgeType() EdgeType + // PassthroughTLS returns whether this connection contains an end-to-end tls // connection. PassthroughTLS() bool } -// EdgeType is the type of the edge (https, tls, or tcp) for this tunnel. -type EdgeType proto.EdgeType - -// All possible edge types. Currently only https, tls, and tcp are supported. -const ( - EdgeTypeUndefined EdgeType = 0 - EdgeTypeTCP EdgeType = 1 - EdgeTypeTLS EdgeType = 2 - EdgeTypeHTTPS EdgeType = 3 -) - type connImpl struct { net.Conn Proxy *tunnel_client.ProxyConn @@ -265,11 +170,6 @@ func (c *connImpl) Proto() string { return c.Proxy.Header.Proto } -func (c *connImpl) EdgeType() EdgeType { - et, _ := proto.ParseEdgeType(c.Proxy.Header.EdgeType) - return EdgeType(et) -} - func (c *connImpl) PassthroughTLS() bool { return c.Proxy.Header.PassthroughTLS } diff --git a/tunnel_config.go b/internal/legacy/tunnel_config.go similarity index 89% rename from tunnel_config.go rename to internal/legacy/tunnel_config.go index 2ed8c51a..dca09686 100644 --- a/tunnel_config.go +++ b/internal/legacy/tunnel_config.go @@ -1,9 +1,9 @@ -package ngrok +package legacy import ( "net/url" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) // This is the internal-only interface that all config.Tunnel implementations diff --git a/internal/pb/middleware.proto b/internal/pb/middleware.proto index 65f8c3ab..1fabb6ec 100644 --- a/internal/pb/middleware.proto +++ b/internal/pb/middleware.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package agent_internal; -option go_package = "golang.ngrok.com/ngrok/internal/pb"; +option go_package = "golang.ngrok.com/ngrok/v2/internal/pb"; message MiddlewareConfiguration { Compression compression = 1; diff --git a/internal/testutil/sync.go b/internal/testutil/sync.go new file mode 100644 index 00000000..840716f1 --- /dev/null +++ b/internal/testutil/sync.go @@ -0,0 +1,95 @@ +package testutil + +import ( + "sync" + "testing" + "time" +) + +// SyncPoint coordinates test execution points +type SyncPoint struct { + ch chan struct{} + called bool + mu sync.Mutex +} + +// NewSyncPoint creates a new synchronization point +func NewSyncPoint() *SyncPoint { + return &SyncPoint{ + ch: make(chan struct{}), + } +} + +// Signal marks the sync point as reached +func (s *SyncPoint) Signal() { + s.mu.Lock() + defer s.mu.Unlock() + if !s.called { + close(s.ch) + s.called = true + } +} + +// Wait blocks until the sync point is signaled or times out +func (s *SyncPoint) Wait(t testing.TB) { + t.Helper() + select { + case <-s.ch: + return + case <-time.After(5 * time.Second): // Safety timeout + t.Fatal("timeout waiting for sync point") + } +} + +// WaitTimeout waits for the sync point with a custom timeout +func (s *SyncPoint) WaitTimeout(t testing.TB, timeout time.Duration) bool { + t.Helper() + select { + case <-s.ch: + return true + case <-time.After(timeout): + return false + } +} + +// WaitGroup is a wrapper around sync.WaitGroup with timeouts +type WaitGroup struct { + wg sync.WaitGroup + done chan struct{} + started bool +} + +// NewWaitGroup creates a new wait group with timeout capability +func NewWaitGroup() *WaitGroup { + return &WaitGroup{ + done: make(chan struct{}), + } +} + +// Add adds delta to the WaitGroup counter +func (w *WaitGroup) Add(delta int) { + w.wg.Add(delta) + if !w.started { + w.started = true + go func() { + w.wg.Wait() + close(w.done) + }() + } +} + +// Done decrements the WaitGroup counter +func (w *WaitGroup) Done() { + w.wg.Done() +} + +// Wait waits for the WaitGroup counter to be zero +func (w *WaitGroup) Wait(t testing.TB) { + t.Helper() + select { + case <-w.done: + return + case <-time.After(10 * time.Second): // Safety timeout + t.Fatal("timeout waiting for wait group") + } +} diff --git a/internal/tunnel/client/raw_session.go b/internal/tunnel/client/raw_session.go index 7499d989..6df1478e 100644 --- a/internal/tunnel/client/raw_session.go +++ b/internal/tunnel/client/raw_session.go @@ -10,8 +10,8 @@ import ( "time" "golang.ngrok.com/muxado/v2" - "golang.ngrok.com/ngrok/internal/tunnel/netx" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/netx" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" log "github.com/inconshreveable/log15/v3" logext "github.com/inconshreveable/log15/v3/ext" diff --git a/internal/tunnel/client/raw_session_test.go b/internal/tunnel/client/raw_session_test.go index 8344161b..c753ee7f 100644 --- a/internal/tunnel/client/raw_session_test.go +++ b/internal/tunnel/client/raw_session_test.go @@ -46,7 +46,9 @@ testloop: } ctx, cancel := context.WithCancel(ctx) - r := NewRawSession(log15.New(), muxado.Client(&dummyStream{}, nil), nil, nil) + logger := log15.New() + logger.SetHandler(log15.LvlFilterHandler(log15.LvlError, log15.StdoutHandler)) + r := NewRawSession(logger, muxado.Client(&dummyStream{}, nil), nil, nil) wg := sync.WaitGroup{} wg.Add(1) diff --git a/internal/tunnel/client/reconnecting.go b/internal/tunnel/client/reconnecting.go index d2c45a2c..f1cf812f 100644 --- a/internal/tunnel/client/reconnecting.go +++ b/internal/tunnel/client/reconnecting.go @@ -9,8 +9,8 @@ import ( log "github.com/inconshreveable/log15/v3" "github.com/jpillora/backoff" - "golang.ngrok.com/ngrok/internal/tunnel/netx" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/netx" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) var ErrSessionNotReady = errors.New("an ngrok tunnel session has not yet been established") diff --git a/internal/tunnel/client/session.go b/internal/tunnel/client/session.go index d44231f2..b9e7b615 100644 --- a/internal/tunnel/client/session.go +++ b/internal/tunnel/client/session.go @@ -9,8 +9,8 @@ import ( "sync" "time" - "golang.ngrok.com/ngrok/internal/tunnel/netx" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/netx" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" log "github.com/inconshreveable/log15/v3" diff --git a/internal/tunnel/client/tunnel.go b/internal/tunnel/client/tunnel.go index 6e5ccfbf..729f3094 100644 --- a/internal/tunnel/client/tunnel.go +++ b/internal/tunnel/client/tunnel.go @@ -6,7 +6,7 @@ import ( "net/url" "sync/atomic" - "golang.ngrok.com/ngrok/internal/tunnel/proto" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) type Tunnel interface { diff --git a/internal/tunnel/proto/msg.go b/internal/tunnel/proto/msg.go index 1b70056b..4373ec76 100644 --- a/internal/tunnel/proto/msg.go +++ b/internal/tunnel/proto/msg.go @@ -7,7 +7,7 @@ import ( "time" "golang.ngrok.com/muxado/v2" - "golang.ngrok.com/ngrok/internal/pb" + "golang.ngrok.com/ngrok/v2/internal/pb" ) type ReqType muxado.StreamType diff --git a/listener.go b/listener.go new file mode 100644 index 00000000..97651db3 --- /dev/null +++ b/listener.go @@ -0,0 +1,75 @@ +package ngrok + +import ( + "context" + "crypto/tls" + "net" + + "golang.ngrok.com/ngrok/v2/internal/legacy" +) + +// EndpointListener is an endpoint that you may treat as a net.Listener. +type EndpointListener interface { + Endpoint + + // Accept returns the next connection received the Endpoint. + Accept() (net.Conn, error) + + // Addr() returns where the Endpoint is listening. + Addr() net.Addr +} + +// endpointListener implements the EndpointListener interface. +type endpointListener struct { + baseEndpoint + tunnel legacy.Tunnel +} + +// wrapConnWithTLS is a wrapper around a net.Conn that performs TLS termination +// without immediately performing the handshake +func wrapConnWithTLS(conn net.Conn, tlsConfig *tls.Config) net.Conn { + if tlsConfig == nil { + return conn + } + + // Create a TLS server connection without performing handshake + // The handshake will happen when the client first reads or writes + return tls.Server(conn, tlsConfig) +} + +func (e *endpointListener) Accept() (net.Conn, error) { + // Accept connection from the tunnel + conn, err := e.tunnel.Accept() + if err != nil { + return nil, wrapError(err) + } + + // Apply TLS termination if a config is provided + if e.agentTLSConfig != nil { + // Wrap the connection with TLS without performing handshake + return wrapConnWithTLS(conn, e.agentTLSConfig), nil + } + + // Return the raw connection if no TLS certificate is provided + return conn, nil +} + +func (e *endpointListener) Addr() net.Addr { + return e.tunnel.Addr() +} + +func (e *endpointListener) Close() error { + return e.CloseWithContext(context.Background()) +} + +func (e *endpointListener) CloseWithContext(ctx context.Context) error { + err := e.tunnel.CloseWithContext(ctx) + e.signalDone() + + // Remove from agent + if a, ok := e.agent.(*agent); ok { + a.removeEndpoint(e) + } + + return wrapError(err) +} diff --git a/listener_test.go b/listener_test.go new file mode 100644 index 00000000..9db4823c --- /dev/null +++ b/listener_test.go @@ -0,0 +1,114 @@ +package ngrok + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "testing" + "time" +) + +// createTestCertificate creates a self-signed certificate for testing +func createTestCertificate(t *testing.T) *tls.Certificate { + // Generate a private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate private key: %v", err) + } + + // Create a certificate template + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), // Valid for 24 hours + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Create a self-signed certificate + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("Failed to create certificate: %v", err) + } + + // Create a TLS certificate + cert := &tls.Certificate{ + Certificate: [][]byte{certBytes}, + PrivateKey: privateKey, + Leaf: &template, + } + + return cert +} + +// TestWrapConnWithTLS tests the TLS connection wrapper +func TestWrapConnWithTLS(t *testing.T) { + // Create a test certificate + cert := createTestCertificate(t) + + // Create a pipe for testing + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + // Apply TLS in a separate goroutine + go func() { + // Configure TLS client + config := &tls.Config{ + InsecureSkipVerify: true, // Skip verification for test + } + + // Create TLS client connection + clientTLS := tls.Client(clientConn, config) + + // Perform handshake - this will happen when we first use the connection + _, err := clientTLS.Write([]byte("hello")) + if err != nil { + t.Errorf("Failed to write to TLS connection: %v", err) + return + } + + // Close the client connection + clientTLS.Close() + }() + + // Wrap the server connection with TLS + config := &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + serverTLS := wrapConnWithTLS(serverConn, config) + + // Read the test data - this will trigger the handshake on first use + buf := make([]byte, 10) + n, err := serverTLS.Read(buf) + if err != nil { + t.Fatalf("Failed to read from TLS connection: %v", err) + } + + // Verify the data + if string(buf[:n]) != "hello" { + t.Fatalf("Unexpected data: %s", string(buf[:n])) + } +} + +// TestWrapConnWithTLSNil tests that wrapConnWithTLS returns the original connection when no certificate is provided +func TestWrapConnWithTLSNil(t *testing.T) { + // Create a pipe for testing + conn1, conn2 := net.Pipe() + defer conn1.Close() + defer conn2.Close() + + // Call wrapConnWithTLS with nil config + result := wrapConnWithTLS(conn1, nil) + + // Verify that the original connection was returned + if result != conn1 { + t.Fatalf("wrapConnWithTLS didn't return the original connection") + } +} diff --git a/log/interface.go b/log/interface.go deleted file mode 100644 index 0efc93b8..00000000 --- a/log/interface.go +++ /dev/null @@ -1,74 +0,0 @@ -package log - -import ( - "context" - "fmt" -) - -type ErrInvalidLogLevel struct { - Level any -} - -func (e ErrInvalidLogLevel) Error() string { - return fmt.Sprintf("invalid log level: %v", e.Level) -} - -type LogLevel = int - -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -// Logging interface, heavily inspired by github.com/jackc/pgx's logger. -// The primary difference is that `LogLevel` is a type alias rather than a -// newtype. This makes it easier for other libraries to support the interface (in theory), -// as they don't need to depend on this package directly. -// -// Adapters are provided for pgx and log15. -type Logger interface { - // Log a message at the given level with data key/value pairs. data may be nil. - Log(context context.Context, level LogLevel, msg string, data map[string]interface{}) -} - -func StringFromLogLevel(lvl LogLevel) (string, error) { - switch lvl { - case LogLevelTrace: - return "trace", nil - case LogLevelDebug: - return "debug", nil - case LogLevelInfo: - return "info", nil - case LogLevelWarn: - return "warn", nil - case LogLevelError: - return "error", nil - case LogLevelNone: - return "none", nil - default: - return "invalid", ErrInvalidLogLevel{lvl} - } -} - -func LogLevelFromString(s string) (LogLevel, error) { - switch s { - case "trace": - return LogLevelTrace, nil - case "debug": - return LogLevelDebug, nil - case "info": - return LogLevelInfo, nil - case "warn": - return LogLevelWarn, nil - case "error": - return LogLevelError, nil - case "none": - return LogLevelNone, nil - default: - return 0, ErrInvalidLogLevel{s} - } -} diff --git a/log/log15/LICENSE.txt b/log/log15/LICENSE.txt deleted file mode 100644 index 7fcd822a..00000000 --- a/log/log15/LICENSE.txt +++ /dev/null @@ -1,7 +0,0 @@ -Copyright 2022 ngrok, Inc. - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/log/log15/adapter.go b/log/log15/adapter.go deleted file mode 100644 index 5d27f5df..00000000 --- a/log/log15/adapter.go +++ /dev/null @@ -1,57 +0,0 @@ -// Package log15 provides a logger that writes to a -// github.com/inconshreveable/log15.Logger and implements the -// golang.ngrok.com/ngrok/log.Logger interface. -// -// Adapted from the github.com/jackc/pgx log15 adapter. -package log15 - -import ( - "context" - - "github.com/inconshreveable/log15/v3" -) - -type LogLevel = int - -// Log level constants matching the ones in golang.ngrok.com/ngrok/log -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -// Wrapper for a log15.Logger to add the ngrok logging interface. -// Also exposes the log15.Logger interface directly so that it can be downcast -// to the log15.Logger. -type Logger struct { - log15.Logger -} - -func NewLogger(l log15.Logger) *Logger { - return &Logger{l} -} - -func (l *Logger) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, len(data)) - for k, v := range data { - logArgs = append(logArgs, k, v) - } - - switch level { - case LogLevelTrace: - l.Debug(msg, append(logArgs, "LOG_LEVEL", level)...) - case LogLevelDebug: - l.Debug(msg, logArgs...) - case LogLevelInfo: - l.Info(msg, logArgs...) - case LogLevelWarn: - l.Warn(msg, logArgs...) - case LogLevelError: - l.Error(msg, logArgs...) - default: - l.Error(msg, append(logArgs, "INVALID_LOG_LEVEL", level)...) - } -} diff --git a/log/log15/go.mod b/log/log15/go.mod deleted file mode 100644 index 3be1db48..00000000 --- a/log/log15/go.mod +++ /dev/null @@ -1,12 +0,0 @@ -module golang.ngrok.com/ngrok/log/log15 - -go 1.21 - -require github.com/inconshreveable/log15/v3 v3.0.0-testing.1 - -require ( - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.16 // indirect - golang.org/x/sys v0.2.0 // indirect - golang.org/x/term v0.2.0 // indirect -) diff --git a/log/log15/go.sum b/log/log15/go.sum deleted file mode 100644 index 039c9564..00000000 --- a/log/log15/go.sum +++ /dev/null @@ -1,11 +0,0 @@ -github.com/inconshreveable/log15/v3 v3.0.0-testing.1 h1:z0j/cq0uyipqKQU6RxLmK0yLhebUG91QcQ/yXxlurRI= -github.com/inconshreveable/log15/v3 v3.0.0-testing.1/go.mod h1:6ilUJWAP6U5r72HEvjhxyHiEx7OD5EwLXEDrOahAzIE= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.2.0 h1:z85xZCsEl7bi/KwbNADeBYoOP0++7W1ipu+aGnpwzRM= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= diff --git a/log/logrus/LICENSE.txt b/log/logrus/LICENSE.txt deleted file mode 100644 index 7fcd822a..00000000 --- a/log/logrus/LICENSE.txt +++ /dev/null @@ -1,7 +0,0 @@ -Copyright 2022 ngrok, Inc. - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/log/logrus/adapter.go b/log/logrus/adapter.go deleted file mode 100644 index 679287ef..00000000 --- a/log/logrus/adapter.go +++ /dev/null @@ -1,56 +0,0 @@ -// Package logrus provides a logger that writes to a -// github.com/sirupsen/logrus.Logger and implements the -// golang.ngrok.com/ngrok/log.Logger interface. -// -// Adapted from the github.com/jackc/pgx logrus adapter. -package logrus - -import ( - "context" - - "github.com/sirupsen/logrus" -) - -type LogLevel = int - -// Log level constants matching the ones in golang.ngrok.com/ngrok/log -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -type Logger struct { - l logrus.FieldLogger -} - -func NewLogger(l logrus.FieldLogger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { - var logger logrus.FieldLogger - if data != nil { - logger = l.l.WithFields(data) - } else { - logger = l.l - } - - switch level { - case LogLevelTrace: - logger.WithField("LOG_LEVEL", level).Debug(msg) - case LogLevelDebug: - logger.Debug(msg) - case LogLevelInfo: - logger.Info(msg) - case LogLevelWarn: - logger.Warn(msg) - case LogLevelError: - logger.Error(msg) - default: - logger.WithField("INVALID_LOG_LEVEL", level).Error(msg) - } -} diff --git a/log/logrus/go.mod b/log/logrus/go.mod deleted file mode 100644 index 2b773716..00000000 --- a/log/logrus/go.mod +++ /dev/null @@ -1,10 +0,0 @@ -module golang.ngrok.com/ngrok/log/logrus - -go 1.21 - -require github.com/sirupsen/logrus v1.9.0 - -require ( - github.com/stretchr/testify v1.8.0 // indirect - golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect -) diff --git a/log/logrus/go.sum b/log/logrus/go.sum deleted file mode 100644 index b13c28b8..00000000 --- a/log/logrus/go.sum +++ /dev/null @@ -1,20 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/log/slog/LICENSE.txt b/log/slog/LICENSE.txt deleted file mode 100644 index 7fcd822a..00000000 --- a/log/slog/LICENSE.txt +++ /dev/null @@ -1,7 +0,0 @@ -Copyright 2022 ngrok, Inc. - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/log/slog/adapter.go b/log/slog/adapter.go deleted file mode 100644 index 76cc395a..00000000 --- a/log/slog/adapter.go +++ /dev/null @@ -1,55 +0,0 @@ -// Package slog provides a logger that writes -// to a log/slog.Logger and implements the -// golang.ngrok.com/ngrok/log.Logger interface. -package slog - -import ( - "context" - - "log/slog" -) - -type LogLevel = int - -// Log level constants matching the ones in golang.ngrok.com/ngrok/log -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -// Wrapper for a slog.Logger to add the ngrok logging interface. -// Also exposes the slog.Logger interface directly so that it can be downcast -// to the slog.Logger. -type Logger struct { - inner *slog.Logger -} - -func NewLogger(l *slog.Logger) *Logger { - return &Logger{l} -} - -func (l *Logger) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, len(data)) - for k, v := range data { - logArgs = append(logArgs, k, v) - } - - switch level { - case LogLevelTrace: - l.inner.Debug(msg, append(logArgs, "LOG_LEVEL", level)...) - case LogLevelDebug: - l.inner.Debug(msg, logArgs...) - case LogLevelInfo: - l.inner.Info(msg, logArgs...) - case LogLevelWarn: - l.inner.Warn(msg, logArgs...) - case LogLevelError: - l.inner.Error(msg, logArgs...) - default: - l.inner.Error(msg, append(logArgs, "INVALID_LOG_LEVEL", level)...) - } -} diff --git a/log/slog/go.mod b/log/slog/go.mod deleted file mode 100644 index d29c593b..00000000 --- a/log/slog/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module golang.ngrok.com/ngrok/log/slog - -go 1.21 diff --git a/log/slog/go.sum b/log/slog/go.sum deleted file mode 100644 index e69de29b..00000000 diff --git a/log/zap/LICENSE.txt b/log/zap/LICENSE.txt deleted file mode 100644 index 7fcd822a..00000000 --- a/log/zap/LICENSE.txt +++ /dev/null @@ -1,7 +0,0 @@ -Copyright 2022 ngrok, Inc. - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/log/zap/adapter.go b/log/zap/adapter.go deleted file mode 100644 index b61002a4..00000000 --- a/log/zap/adapter.go +++ /dev/null @@ -1,56 +0,0 @@ -// Package zap provides a logger that writes to a go.uber.org/zap.Logger and -// implements the golang.ngrok.com/ngrok/log.Logger interface. -// -// Adapted from the github.com/jackc/pgx zap adapter. -package zap - -import ( - "context" - - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type LogLevel = int - -// Log level constants matching the ones in golang.ngrok.com/ngrok/log -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -type Logger struct { - logger *zap.Logger -} - -func NewLogger(logger *zap.Logger) *Logger { - return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} -} - -func (pl *Logger) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { - fields := make([]zapcore.Field, len(data)) - i := 0 - for k, v := range data { - fields[i] = zap.Any(k, v) - i++ - } - - switch level { - case LogLevelTrace: - pl.logger.Debug(msg, append(fields, zap.Any("LOG_LEVEL", level))...) - case LogLevelDebug: - pl.logger.Debug(msg, fields...) - case LogLevelInfo: - pl.logger.Info(msg, fields...) - case LogLevelWarn: - pl.logger.Warn(msg, fields...) - case LogLevelError: - pl.logger.Error(msg, fields...) - default: - pl.logger.Error(msg, append(fields, zap.Any("INVALID_LOG_LEVEL", level))...) - } -} diff --git a/log/zap/go.mod b/log/zap/go.mod deleted file mode 100644 index b74fefdd..00000000 --- a/log/zap/go.mod +++ /dev/null @@ -1,10 +0,0 @@ -module golang.ngrok.com/ngrok/log/zap - -go 1.21 - -require go.uber.org/zap v1.23.0 - -require ( - go.uber.org/atomic v1.7.0 // indirect - go.uber.org/multierr v1.6.0 // indirect -) diff --git a/log/zap/go.sum b/log/zap/go.sum deleted file mode 100644 index 1772f3d4..00000000 --- a/log/zap/go.sum +++ /dev/null @@ -1,18 +0,0 @@ -github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= -go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= -go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= -go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= -go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/logging.go b/logging.go deleted file mode 100644 index ceeb0081..00000000 --- a/logging.go +++ /dev/null @@ -1,78 +0,0 @@ -package ngrok - -import ( - "context" - "fmt" - - "github.com/inconshreveable/log15/v3" - - "golang.ngrok.com/ngrok/log" -) - -type log15Handler struct { - log.Logger -} - -// The internals all use log15, so we need to convert the public logging -// interface to log15. -// If the provided Logger also implements the log15 interface, downcast and use -// that instead of wrapping again. This is the case for the Logger constructed -// by the log15adapter module. -// Otherwise, a new log15.Logger is constructed and the provided Logger used as -// its Handler. -func toLog15(l log.Logger) log15.Logger { - if logger, ok := l.(log15.Logger); ok { - return logger - } - - logger := log15.New() - logger.SetHandler(&log15Handler{l}) - - return logger -} - -func (l *log15Handler) Log(r log15.Record) error { - lvl := log.LogLevelNone - switch r.Lvl { - case log15.LvlCrit: - lvl = log.LogLevelError - case log15.LvlError: - lvl = log.LogLevelError - case log15.LvlWarn: - lvl = log.LogLevelWarn - case log15.LvlInfo: - lvl = log.LogLevelInfo - case log15.LvlDebug: - lvl = log.LogLevelDebug - case log15.LvlDebug + 1: - // Also support trace, if someone happens to hack - // it in. - lvl = log.LogLevelTrace - } - - data := make(map[string]interface{}, len(r.Ctx)/2) - for i := 0; i < len(r.Ctx); i += 2 { - var ( - k string - ok bool - v interface{} - ) - // The default upstream log15 formatter chooses to treat non-strings as - // errors. We'll be a bit nicer and Sprint it instead if we find one. - k, ok = r.Ctx[i].(string) - if !ok { - k = fmt.Sprint(r.Ctx[i]) - } - // I think log15 guarantees an even number of context values, but just - // in case. - if len(r.Ctx) > i+1 { - v = r.Ctx[i+1] - } else { - v = "MISSING_VALUE" - } - data[k] = v - } - - l.Logger.Log(context.Background(), lvl, r.Msg, data) - return nil -} diff --git a/online_test.go b/online_test.go deleted file mode 100644 index 13590c3d..00000000 --- a/online_test.go +++ /dev/null @@ -1,798 +0,0 @@ -package ngrok - -import ( - "bufio" - "compress/gzip" - "context" - "encoding/binary" - "encoding/hex" - "fmt" - "io" - "math/rand" - "net" - "net/http" - "net/url" - "os" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" - "golang.org/x/net/websocket" - - "golang.ngrok.com/ngrok/config" - "golang.ngrok.com/ngrok/log" -) - -type testLogger struct { - t *testing.T - testName string -} - -func newTestLogger(t *testing.T) *testLogger { - return &testLogger{ - t: t, - testName: t.Name(), - } -} - -func (tl *testLogger) Log(context context.Context, level log.LogLevel, msg string, data map[string]interface{}) { - cpy := map[string]any{} - for k, v := range data { - cpy[k] = v - } - cpy["test"] = tl.testName - lvl, err := log.StringFromLogLevel(level) - if err != nil { - lvl = "UKWN" - } - lvl = strings.ToUpper(lvl) - tl.t.Logf("%s [%s] %s %v", time.Now().Format(time.RFC3339), lvl, msg, cpy) -} - -func expectChanError(t *testing.T, ch <-chan error, timeout time.Duration) { - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case err := <-ch: - require.Error(t, err) - case <-timer.C: - t.Error("timeout while waiting on error channel") - } -} - -func skipUnless(t *testing.T, varname string, message ...any) { - if os.Getenv(varname) == "" && os.Getenv("NGROK_TEST_ALL") == "" { - t.Skip(message...) - } -} - -func onlineTest(t *testing.T) { - skipUnless(t, "NGROK_TEST_ONLINE", "Skipping online test") - // This is an annoying quirk of the free account limitations. It looks like - // the tests run quickly enough in series that they trigger simultaneous - // session errors for free accounts. "Something something eventual - // consistency" most likely. - require.NotEmpty(t, os.Getenv("NGROK_AUTHTOKEN"), "Online tests require an authtoken.") -} - -func setupSession(ctx context.Context, t *testing.T, opts ...ConnectOption) Session { - onlineTest(t) - opts = append(opts, WithAuthtokenFromEnv(), WithLogger(newTestLogger(t))) - sess, err := Connect(ctx, opts...) - require.NoError(t, err, "Session Connect") - return sess -} - -func startTunnel(ctx context.Context, t *testing.T, sess Session, opts config.Tunnel) Tunnel { - onlineTest(t) - tun, err := sess.Listen(ctx, opts) - require.NoError(t, err, "Listen") - return tun -} - -var helloHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - _, _ = io.ReadAll(r.Body) - _ = r.Body.Close() - _, _ = fmt.Fprintln(rw, "Hello, world!") -}) - -func serveHTTP(ctx context.Context, t *testing.T, connectOpts []ConnectOption, opts config.Tunnel, handler http.Handler) (Tunnel, <-chan error) { - sess := setupSession(ctx, t, connectOpts...) - - tun := startTunnel(ctx, t, sess, opts) - exited := make(chan error) - - go func() { - exited <- http.Serve(tun, handler) - sess.Close() - }() - return tun, exited -} - -func TestListen(t *testing.T) { - onlineTest(t) - tun, err := Listen(context.Background(), - config.HTTPEndpoint(), - WithAuthtokenFromEnv(), - WithLogger(newTestLogger(t)), - ) - require.NoError(t, err, "Session Connect") - tun.Close() -} - -func TestTunnel(t *testing.T) { - ctx := context.Background() - sess := setupSession(ctx, t) - - tun := startTunnel(ctx, t, sess, config.HTTPEndpoint( - config.WithMetadata("Hello, world!"), - config.WithForwardsTo("some application"))) - - require.NotEmpty(t, tun.URL(), "Tunnel URL") - require.Equal(t, "Hello, world!", tun.Metadata()) - require.Equal(t, "some application", tun.ForwardsTo()) - tun.Close() - sess.Close() -} - -func TestTunnelConnMetadata(t *testing.T) { - ctx := context.Background() - sess := setupSession(ctx, t) - - tun := startTunnel(ctx, t, sess, config.HTTPEndpoint()) - - go func() { - resp, _ := http.Get(tun.URL()) - if resp != nil { - _ = resp.Body.Close() - } - }() - - conn, err := tun.Accept() - require.NoError(t, err) - - proxyconn, ok := conn.(Conn) - require.True(t, ok, "conn doesn't implement proxy conn interface") - - require.Equal(t, "https", proxyconn.Proto()) - require.Equal(t, EdgeTypeUndefined, proxyconn.EdgeType()) - tun.Close() - sess.Close() -} - -func TestWithHTTPHandler(t *testing.T) { - ctx := context.Background() - tun, _ := serveHTTP(ctx, t, nil, config.HTTPEndpoint( - config.WithMetadata("Hello, world!"), - config.WithForwardsTo("some application"), - ), helloHandler) - - resp, err := http.Get(tun.URL()) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.NotNil(t, resp.TLS, "TLS established") - - // Closing the tunnel should be fine - require.NoError(t, tun.CloseWithContext(ctx)) -} - -func TestHTTPS(t *testing.T) { - ctx := context.Background() - tun, exited := serveHTTP(ctx, t, nil, - config.HTTPEndpoint(), - helloHandler, - ) - - resp, err := http.Get(tun.URL()) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.NotNil(t, resp.TLS, "TLS established") - - // Closing the tunnel should be fine - require.NoError(t, tun.CloseWithContext(ctx)) - - // The http server should exit with a "closed" error - expectChanError(t, exited, 5*time.Second) -} - -func TestHTTP(t *testing.T) { - ctx := context.Background() - tun, exited := serveHTTP(ctx, t, nil, - config.HTTPEndpoint( - config.WithScheme(config.SchemeHTTP)), - helloHandler, - ) - - resp, err := http.Get(tun.URL()) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.Nil(t, resp.TLS, "No TLS") - - // Closing the tunnel should be fine - require.NoError(t, tun.CloseWithContext(ctx)) - - // The http server should exit with a "closed" error - expectChanError(t, exited, 5*time.Second) -} - -func TestHTTPCompression(t *testing.T) { - onlineTest(t) - ctx := context.Background() - opts := config.HTTPEndpoint(config.WithCompression()) - tun, exited := serveHTTP(ctx, t, nil, opts, helloHandler) - - req, err := http.NewRequest(http.MethodGet, tun.URL(), nil) - require.NoError(t, err, "Create request") - req.Header.Add("Accept-Encoding", "gzip") - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - require.Equal(t, http.StatusOK, resp.StatusCode) - - gzReader, err := gzip.NewReader(resp.Body) - require.NoError(t, err, "gzip reader") - - body, err := io.ReadAll(gzReader) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -// *testing.T wrapper to force `require` to Fail() then panic() rather than -// FailNow(). Permits better flow control in test functions. -type failPanic struct { - t *testing.T -} - -func (f failPanic) Errorf(format string, args ...interface{}) { - f.t.Errorf(format, args...) -} - -func (f failPanic) FailNow() { - f.t.Fail() - panic("test failed") -} - -func TestHTTPHeaders(t *testing.T) { - onlineTest(t) - ctx := context.Background() - opts := config.HTTPEndpoint( - config.WithRequestHeader("foo", "bar"), - config.WithRemoveRequestHeader("baz"), - config.WithResponseHeader("spam", "eggs"), - config.WithRemoveResponseHeader("python")) - - tun, exited := serveHTTP(ctx, t, nil, opts, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - defer func() { _ = recover() }() - t := failPanic{t} - - require.NotContains(t, r.Header, "Baz", "Baz Removed") - require.Contains(t, r.Header, "Foo", "Foo added") - require.Equal(t, "bar", r.Header.Get("Foo"), "Foo=bar") - - rw.Header().Add("Python", "bad header") - _, _ = fmt.Fprintln(rw, "Hello, world!") - })) - - req, err := http.NewRequest(http.MethodGet, tun.URL(), nil) - require.NoError(t, err, "Create request") - req.Header.Add("Baz", "bad header") - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - require.Equal(t, http.StatusOK, resp.StatusCode) - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.NotContains(t, resp.Header, "Python", "Python removed") - require.Contains(t, resp.Header, "Spam", "Spam added") - require.Equal(t, "eggs", resp.Header.Get("Spam"), "Spam=eggs") - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestBasicAuth(t *testing.T) { - onlineTest(t) - ctx := context.Background() - - opts := config.HTTPEndpoint(config.WithBasicAuth("user", "foobarbaz")) - - tun, exited := serveHTTP(ctx, t, nil, opts, helloHandler) - - req, err := http.NewRequest(http.MethodGet, tun.URL(), nil) - require.NoError(t, err, "Create request") - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err, "GET tunnel url") - - require.Equal(t, http.StatusUnauthorized, resp.StatusCode) - _ = resp.Body.Close() - - req.SetBasicAuth("user", "foobarbaz") - - resp, err = http.DefaultClient.Do(req) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - require.Equal(t, http.StatusOK, resp.StatusCode) - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestCircuitBreaker(t *testing.T) { - // Don't run this one by default - it has to make ~50 requests. - skipUnless(t, "NGROK_TEST_LONG", "Skipping long circuit breaker test") - onlineTest(t) - ctx := context.Background() - - opts := config.HTTPEndpoint(config.WithCircuitBreaker(0.1)) - - n := 0 - tun, exited := serveHTTP(ctx, t, nil, opts, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n = n + 1 - w.WriteHeader(http.StatusServiceUnavailable) - })) - - var ( - resp *http.Response - err error - ) - - for i := 0; i < 50; i++ { - resp, err = http.Get(tun.URL()) - require.NoError(t, err) - _ = resp.Body.Close() - } - - // Should see fewer than 50 requests come through. - require.Less(t, n, 50) - - require.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestProxyProto(t *testing.T) { - onlineTest(t) - ctx := context.Background() - - type testCase struct { - name string - optsFunc func(config.ProxyProtoVersion) config.Tunnel - reqFunc func(*testing.T, string) - version config.ProxyProtoVersion - shouldContain string - } - - base := []testCase{ - { - version: config.ProxyProtoV1, - shouldContain: "PROXY TCP4", - }, - { - version: config.ProxyProtoV2, - shouldContain: "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A", - }, - } - - var cases []testCase - - for _, c := range base { - cases = append(cases, - testCase{ - name: fmt.Sprintf("HTTP/Version%d", c.version), - optsFunc: func(v config.ProxyProtoVersion) config.Tunnel { - return config.HTTPEndpoint(config.WithProxyProto(v)) - }, - reqFunc: func(t *testing.T, url string) { - resp, _ := http.Get(url) - if resp != nil { - _ = resp.Body.Close() - } - }, - version: c.version, - shouldContain: c.shouldContain, - }, - testCase{ - name: fmt.Sprintf("TCP/Version%d", c.version), - optsFunc: func(v config.ProxyProtoVersion) config.Tunnel { - return config.TCPEndpoint(config.WithProxyProto(v)) - }, - reqFunc: func(t *testing.T, u string) { - url, err := url.Parse(u) - require.NoError(t, err) - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", url.Hostname(), url.Port())) - require.NoError(t, err) - _, _ = fmt.Fprint(conn, "Hello, world!") - }, - version: c.version, - shouldContain: c.shouldContain, - }, - ) - } - - for _, tcase := range cases { - t.Run(tcase.name, func(t *testing.T) { - sess := setupSession(ctx, t) - tun := startTunnel(ctx, t, sess, tcase.optsFunc(tcase.version)) - - go tcase.reqFunc(t, tun.URL()) - - conn, err := tun.Accept() - require.NoError(t, err, "Accept connection") - - buf := make([]byte, 12) - _, err = io.ReadAtLeast(conn, buf, 12) - require.NoError(t, err, "Read connection contents") - - conn.Close() - - require.Contains(t, string(buf), tcase.shouldContain) - }) - } -} - -func TestSubdomain(t *testing.T) { - onlineTest(t) - ctx := context.Background() - - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, rand.Uint64()) - - subdomain := hex.EncodeToString(buf) - - tun, exited := serveHTTP(ctx, t, nil, - config.HTTPEndpoint(config.WithDomain(subdomain+".ngrok.io")), - helloHandler, - ) - - require.Contains(t, tun.URL(), subdomain) - - resp, err := http.Get(tun.URL()) - require.NoError(t, err) - defer resp.Body.Close() - - content, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "Hello, world!\n", string(content)) - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestOAuth(t *testing.T) { - onlineTest(t) - ctx := context.Background() - - opts := config.HTTPEndpoint(config.WithOAuth("google")) - - tun, exited := serveHTTP(ctx, t, nil, opts, helloHandler) - - resp, err := http.Get(tun.URL()) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - content, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.NotContains(t, string(content), "Hello, world!") - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestHTTPIPRestriction(t *testing.T) { - onlineTest(t) - ctx := context.Background() - - _, cidr, err := net.ParseCIDR("0.0.0.0/0") - require.NoError(t, err) - - opts := config.HTTPEndpoint( - config.WithAllowCIDRString("127.0.0.1/32"), - config.WithDenyCIDR(cidr)) - - tun, exited := serveHTTP(ctx, t, nil, opts, helloHandler) - - resp, err := http.Get(tun.URL()) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - require.Equal(t, http.StatusForbidden, resp.StatusCode) - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestTCP(t *testing.T) { - onlineTest(t) - ctx := context.Background() - - opts := config.TCPEndpoint() - - // Easier to test by pretending it's HTTP on this end. - tun, exited := serveHTTP(ctx, t, nil, opts, helloHandler) - - url, err := url.Parse(tun.URL()) - require.NoError(t, err) - url.Scheme = "http" - resp, err := http.Get(url.String()) - require.NoError(t, err, "GET tunnel url") - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestTCPIPRestriction(t *testing.T) { - onlineTest(t) - ctx := context.Background() - - _, cidr, err := net.ParseCIDR("127.0.0.1/32") - require.NoError(t, err) - - opts := config.TCPEndpoint( - config.WithAllowCIDR(cidr), - config.WithDenyCIDRString("0.0.0.0/0")) - - // Easier to test by pretending it's HTTP on this end. - tun, exited := serveHTTP(ctx, t, nil, opts, helloHandler) - - url, err := url.Parse(tun.URL()) - require.NoError(t, err) - url.Scheme = "http" - resp, err := http.Get(url.String()) //nolint:bodyclose // resp is expected to be nil - - // Rather than layer-7 error, we should see it at the connection level - require.Nil(t, resp) - require.Error(t, err, "GET Tunnel URL") - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestWebsocketConversion(t *testing.T) { - onlineTest(t) - ctx := context.Background() - sess := setupSession(ctx, t) - tun := startTunnel(ctx, t, sess, - config.HTTPEndpoint( - config.WithWebsocketTCPConversion()), - ) - - // HTTP over websockets? suuuure lol - exited := make(chan error) - go func() { - exited <- http.Serve(tun, helloHandler) - }() - - resp, err := http.Get(tun.URL()) - require.NoError(t, err) - - require.Equal(t, http.StatusBadRequest, resp.StatusCode, "Normal http should be rejected") - _ = resp.Body.Close() - - url, err := url.Parse(tun.URL()) - require.NoError(t, err) - - url.Scheme = "wss" - - client := http.Client{ - Transport: &http.Transport{ - Dial: func(network, addr string) (net.Conn, error) { - return websocket.Dial(url.String(), "", tun.URL()) - }, - }, - } - - resp, err = client.Get("http://example.com") - require.NoError(t, err) - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Read response body") - - require.Equal(t, "Hello, world!\n", string(body), "HTTP Body Contents") - - require.NoError(t, tun.CloseWithContext(ctx)) - expectChanError(t, exited, 5*time.Second) -} - -func TestConnectionCallbacks(t *testing.T) { - // Don't run this one by default - it's timing-sensitive and prone to flakes - skipUnless(t, "NGROK_TEST_FLAKEY", "Skipping flakey network test") - - ctx := context.Background() - connects := 0 - disconnectErrs := 0 - disconnectNils := 0 - sess := setupSession(ctx, t, - WithConnectHandler(func(ctx context.Context, sess Session) { - connects++ - }), - WithDisconnectHandler(func(ctx context.Context, sess Session, err error) { - if err == nil { - disconnectNils++ - } else { - disconnectErrs++ - } - }), - WithDialer(&sketchyDialer{1 * time.Second})) - - time.Sleep(2*time.Second + 500*time.Millisecond) - - _ = sess.Close() - - time.Sleep(2 * time.Second) - - require.Equal(t, 3, connects, "should've seen some connect events") - require.Equal(t, 3, disconnectErrs, "should've seen some errors from disconnecting") - require.Equal(t, 1, disconnectNils, "should've seen a final nil from disconnecting") -} - -type sketchyDialer struct { - limit time.Duration -} - -func (sd *sketchyDialer) Dial(network, addr string) (net.Conn, error) { - return sd.DialContext(context.Background(), network, addr) -} - -func (sd *sketchyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := net.Dial(network, addr) - go func() { - time.Sleep(sd.limit) - conn.Close() - }() - return conn, err -} - -func TestHeartbeatCallback(t *testing.T) { - // Don't run this one by default - it's long - skipUnless(t, "NGROK_TEST_LONG", "Skipping long network test") - - ctx := context.Background() - heartbeats := 0 - sess := setupSession(ctx, t, - WithHeartbeatHandler(func(ctx context.Context, sess Session, latency time.Duration) { - heartbeats++ - }), - WithHeartbeatInterval(10*time.Second)) - - time.Sleep(20*time.Second + 500*time.Millisecond) - - _ = sess.Close() - - require.Equal(t, 2, heartbeats, "should've seen some heartbeats") -} - -func TestPermanentErrors(t *testing.T) { - onlineTest(t) - var err error - ctx := context.Background() - u, _ := url.Parse("notarealscheme://example.com") - - _, err = Connect(ctx, WithProxyURL(u), WithAuthtokenFromEnv()) - var proxyErr errProxyInit - require.ErrorIs(t, err, proxyErr) - require.ErrorAs(t, err, &proxyErr) - - sess, err := Connect(ctx, WithAuthtokenFromEnv()) - require.NoError(t, err) - sess.Close() - - timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - _, err = Connect(timeoutCtx, WithServer("127.0.0.234:123"), WithAuthtokenFromEnv()) - require.ErrorIs(t, err, context.DeadlineExceeded) -} - -func TestRetryableErrors(t *testing.T) { - onlineTest(t) - var err error - ctx := context.Background() - - // give up on connecting after first attempt - disconnect := WithDisconnectHandler(func(_ context.Context, sess Session, disconnectErr error) { - sess.Close() - }) - connect := WithConnectHandler(func(_ context.Context, sess Session) { - sess.Close() - }) - - _, err = Connect(ctx, WithServer("127.0.0.234:123"), connect, disconnect) - var dialErr errSessionDial - require.ErrorIs(t, err, dialErr) - require.ErrorAs(t, err, &dialErr) - - _, err = Connect(ctx, WithAuthtoken("lolnope"), connect, disconnect) - var authErr errAuthFailed - require.ErrorIs(t, err, authErr) - require.ErrorAs(t, err, &authErr) - require.True(t, authErr.Remote) -} - -func TestNonExported(t *testing.T) { - ctx := context.Background() - - sess := setupSession(ctx, t) - - require.NotEmpty(t, sess.(interface{ Region() string }).Region()) -} - -func echo(ws *websocket.Conn) { - _, _ = io.Copy(ws, ws) -} - -func TestWebsockets(t *testing.T) { - onlineTest(t) - - ctx := context.Background() - - srv := &http.ServeMux{} - srv.Handle("/", helloHandler) - srv.Handle("/ws", websocket.Handler(echo)) - - tun, errCh := serveHTTP(ctx, t, nil, config.HTTPEndpoint(config.WithScheme(config.SchemeHTTPS)), srv) - - tunnelURL, err := url.Parse(tun.URL()) - require.NoError(t, err) - - conn, err := websocket.Dial(fmt.Sprintf("wss://%s/ws", tunnelURL.Hostname()), "", tunnelURL.String()) - require.NoError(t, err) - - go func() { - _, _ = fmt.Fprintln(conn, "Hello, world!") - }() - - bufConn := bufio.NewReader(conn) - out, err := bufConn.ReadString('\n') - require.NoError(t, err) - require.Equal(t, "Hello, world!\n", out) - - conn.Close() - tun.Close() - - require.Error(t, <-errCh) -} diff --git a/policy/policy.go b/policy/policy.go deleted file mode 100644 index 7de5af54..00000000 --- a/policy/policy.go +++ /dev/null @@ -1,124 +0,0 @@ -package policy - -import ( - "bytes" - "encoding/json" - "fmt" - - "gopkg.in/yaml.v2" -) - -type Policy struct { - // the ordered set of rules that apply to inbound traffic - Inbound []Rule `json:"inbound,omitempty" yaml:"inbound,omitempty"` - // the ordered set of rules that apply to outbound traffic - Outbound []Rule `json:"outbound,omitempty" yaml:"outbound,omitempty"` -} - -type Rule struct { - // the name of the traffic policy rule - Name string `json:"name,omitempty" yaml:"name,omitempty"` - // the set of CEL expressions used to determine if this rule is applicable - Expressions []string `json:"expressions,omitempty" yaml:"expressions,omitempty"` - // the ordered set of actions that should take effect against the traffic - Actions []Action `json:"actions" yaml:"actions"` -} - -type Action struct { - // the type of action that should be used - Type string `json:"type" yaml:"type"` - // the configuration for the specified action type written as a json string - Config map[string]any `json:"config,omitempty" yaml:"config,omitempty"` -} - -// converts the policy to a json string -func (p Policy) JSON() (string, error) { - return marshalJSON(p) -} - -// converts the policy to a yaml string -func (p Policy) YAML() (string, error) { - return marshalYAML(p) -} - -// creates a rule from the specified string in json or yaml format -func NewRuleFromString(input string) (Rule, error) { - r := Rule{} - err := unmarshal(input, &r) - - return r, err -} - -// creates a rule from the specified string in json or yaml format and panics if invalid -func MustRuleFromString(input string) Rule { - r := Rule{} - if err := unmarshal(input, &r); err != nil { - panic(fmt.Sprintf("failed to create rule from specified string due to error: %s", err.Error())) - } - - return r -} - -// converts the rule to a json string -func (p Rule) JSON() (string, error) { - return marshalJSON(p) -} - -// converts the rule to a yaml string -func (p Rule) YAML() (string, error) { - return marshalYAML(p) -} - -// creates an action from the specified string in json or yaml format -func NewActionFromString(input string) (Action, error) { - a := Action{} - err := unmarshal(input, &a) - - return a, err -} - -// creates an action from the specified string in json or yaml format and panics if invalid -func MustActionFromString(input string) Action { - a := Action{} - if err := unmarshal(input, &a); err != nil { - panic(fmt.Sprintf("failed to create action from specified string due to error: %s", err.Error())) - } - - return a -} - -// converts the action to a json string -func (p Action) JSON() (string, error) { - return marshalJSON(p) -} - -// converts the action to a yaml string -func (p Action) YAML() (string, error) { - return marshalYAML(p) -} - -func marshalJSON(o any) (string, error) { - b := new(bytes.Buffer) - enc := json.NewEncoder(b) - enc.SetEscapeHTML(false) - - if err := enc.Encode(o); err != nil { - return "", err - } - - return b.String(), nil -} - -func marshalYAML(o any) (string, error) { - bytes, err := yaml.Marshal(o) - - if err != nil { - return "", err - } - - return string(bytes), nil -} - -func unmarshal(input string, typ any) error { - return yaml.UnmarshalStrict([]byte(input), typ) -} diff --git a/policy/policy_test.go b/policy/policy_test.go deleted file mode 100644 index f6638e27..00000000 --- a/policy/policy_test.go +++ /dev/null @@ -1,347 +0,0 @@ -package policy - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestPolicyToJSON(t *testing.T) { - t.Run("Convert whole policy to json", func(t *testing.T) { - expected := ` - { - "inbound": [ - { - "name":"denyPUT", - "expressions":["req.Method == 'PUT'"], - "actions":[ - { - "type":"deny" - } - ] - }, - { - "name":"logFooHeader", - "expressions":["'foo' in req.Headers"], - "actions":[ - { - "type":"log", - "config":{ - "metadata":{ - "key":"val" - } - } - } - ] - } - ], - "outbound": [ - { - "name":"InternalErrorWhenFailed", - "expressions":["res.StatusCode <= '0'", "res.StatusCode >= '300'"], - "actions":[ - { - "type":"custom-response", - "config":{ - "status_code":500 - } - } - ] - } - ] - }` - cfg := Policy{ - Inbound: []Rule{ - { - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []Action{ - {Type: "deny"}, - }, - }, - { - Name: "logFooHeader", - Expressions: []string{"'foo' in req.Headers"}, - Actions: []Action{ - { - Type: "log", - Config: map[string]any{"metadata": map[string]any{"key": "val"}}, - }, - }, - }, - }, - Outbound: []Rule{ - { - Name: "InternalErrorWhenFailed", - Expressions: []string{ - "res.StatusCode <= '0'", - "res.StatusCode >= '300'", - }, - Actions: []Action{ - { - Type: "custom-response", - Config: map[string]any{"status_code": 500}, - }, - }, - }, - }, - } - - json, err := cfg.JSON() - require.NoError(t, err) - require.JSONEq(t, expected, json) - }) - - t.Run("convert policy rule to json", func(t *testing.T) { - expected := `{"name":"denyPUT","expressions":["req.Method == 'PUT'"],"actions":[{"type":"deny","config":{"status_code":401}}]}` - - policy := Rule{ - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []Action{ - { - Type: "deny", - Config: map[string]any{"status_code": 401}, - }, - }, - } - - result, err := policy.JSON() - require.NoError(t, err) - require.JSONEq(t, expected, result) - }) - - t.Run("convert action to json", func(t *testing.T) { - expected := `{"type":"deny","config":{"status_code":401}}` - action := Action{ - - Type: "deny", - Config: map[string]any{"status_code": 401}, - } - - result, err := action.JSON() - require.NoError(t, err) - require.JSONEq(t, expected, result) - }) -} - -func TestPolicyToYAML(t *testing.T) { - t.Run("Convert whole policy to yaml", func(t *testing.T) { - expected := ` - inbound: - - name: "denyPUT" - expressions: ["req.Method == 'PUT'"] - actions: - - type: "deny" - - name: "logFooHeader" - expressions: ["'foo' in req.Headers"] - actions: - - type: "log" - config: - metadata: - key: "val" - outbound: - - name: "InternalErrorWhenFailed" - expressions: - - "res.StatusCode <= '0'" - - "res.StatusCode >= '300'" - actions: - - type: "custom-response" - config: - status_code: 500` - cfg := Policy{ - Inbound: []Rule{ - { - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []Action{ - {Type: "deny"}, - }, - }, - { - Name: "logFooHeader", - Expressions: []string{"'foo' in req.Headers"}, - Actions: []Action{ - { - Type: "log", - Config: map[string]any{"metadata": map[string]any{"key": "val"}}, - }, - }, - }, - }, - Outbound: []Rule{ - { - Name: "InternalErrorWhenFailed", - Expressions: []string{ - "res.StatusCode <= '0'", - "res.StatusCode >= '300'", - }, - Actions: []Action{ - { - Type: "custom-response", - Config: map[string]any{"status_code": 500}, - }, - }, - }, - }, - } - - yaml, err := cfg.YAML() - require.NoError(t, err) - require.YAMLEq(t, expected, yaml) - }) - - t.Run("convert policy rule to json", func(t *testing.T) { - expected := ` - name: "denyPUT" - expressions: ["req.Method == 'PUT'"] - actions: - - type: "deny" - config: - status_code: 401` - - policy := Rule{ - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []Action{ - { - Type: "deny", - Config: map[string]any{"status_code": 401}, - }, - }, - } - - result, err := policy.YAML() - require.NoError(t, err) - require.YAMLEq(t, expected, result) - }) - - t.Run("convert action to json", func(t *testing.T) { - expected := ` - type: "deny" - config: - status_code: 401` - action := Action{ - - Type: "deny", - Config: map[string]any{"status_code": 401}, - } - - result, err := action.YAML() - require.NoError(t, err) - require.YAMLEq(t, expected, result) - }) -} - -func TestFromString(t *testing.T) { - t.Run("rule from json", func(t *testing.T) { - input := `{"name":"denyPUT","expressions":["req.Method == 'PUT'"],"actions":[{"type":"deny","config":{"status_code":401}}]}` - expected := Rule{ - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []Action{ - { - Type: "deny", - Config: map[string]any{"status_code": 401}, - }, - }, - } - - result, err := NewRuleFromString(input) - - require.NoError(t, err) - require.Equal(t, expected, result) - }) - - t.Run("new rule from yaml", func(t *testing.T) { - input := ` - name: "denyPUT" - expressions: ["req.Method == 'PUT'"] - actions: - - type: "deny" - config: - status_code: 401` - - expected := Rule{ - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []Action{ - { - Type: "deny", - Config: map[string]any{"status_code": 401}, - }, - }, - } - - result, err := NewRuleFromString(input) - - require.NoError(t, err) - require.Equal(t, expected, result) - }) - - t.Run("convert action to json", func(t *testing.T) { - input := `{"type":"deny","config":{"status_code":401}}` - expected := Action{ - - Type: "deny", - Config: map[string]any{"status_code": 401}, - } - - result, err := NewActionFromString(input) - - require.NoError(t, err) - require.Equal(t, expected, result) - }) - - t.Run("action from yaml", func(t *testing.T) { - input := ` - type: "deny" - config: - status_code: 401` - expected := Action{ - - Type: "deny", - Config: map[string]any{"status_code": 401}, - } - - result, err := NewActionFromString(input) - - require.NoError(t, err) - require.Equal(t, expected, result) - }) - - t.Run("must action to json", func(t *testing.T) { - input := `{"type":"deny","config":{"status_code":401}}` - expected := Action{ - - Type: "deny", - Config: map[string]any{"status_code": 401}, - } - - result := MustActionFromString(input) - - require.Equal(t, expected, result) - }) - - t.Run("must action from yaml", func(t *testing.T) { - input := ` - type: "deny" - config: - status_code: 401` - expected := Action{ - - Type: "deny", - Config: map[string]any{"status_code": 401}, - } - - result := MustActionFromString(input) - - require.Equal(t, expected, result) - }) - - t.Run("must action from invalid", func(t *testing.T) { - input := `invalid: val` - - require.Panics(t, func() { MustActionFromString(input) }) - }) -} diff --git a/rpc/request.go b/rpc/request.go new file mode 100644 index 00000000..20837d3f --- /dev/null +++ b/rpc/request.go @@ -0,0 +1,15 @@ +package rpc + +// Method constants defining standard RPC methods +const ( + StopAgentMethod = "StopAgent" + RestartAgentMethod = "RestartAgent" + UpdateAgentMethod = "UpdateAgent" +) + +// Request defines an interface of RPC messages received from the ngrok cloud +// service. +type Request interface { + // Method returns the RPC method name being called. + Method() string +} diff --git a/rpc_handler.go b/rpc_handler.go new file mode 100644 index 00000000..88af77bb --- /dev/null +++ b/rpc_handler.go @@ -0,0 +1,20 @@ +package ngrok + +import ( + "context" + + "golang.ngrok.com/ngrok/v2/rpc" +) + +// RPCHandler is a function that processes RPC requests from the ngrok service. +// It receives the context, agent session, and request, and returns an optional +// response payload and error. +type RPCHandler func(context.Context, AgentSession, rpc.Request) ([]byte, error) + +// Private request implementation that satisfies the rpc.Request interface +type rpcRequest struct { + method string + payload []byte +} + +func (r *rpcRequest) Method() string { return r.method } diff --git a/scripts/Makefile b/scripts/Makefile index b21acf4d..0df2bfd1 100644 --- a/scripts/Makefile +++ b/scripts/Makefile @@ -1,7 +1,7 @@ .PHONY: fmt fmt: go install golang.org/x/tools/cmd/goimports@v0.3.0 - goimports -format-only -w -local golang.ngrok.com *.go config examples + goimports -format-only -w -local golang.ngrok.com *.go examples .PHONY: lint lint: diff --git a/session.go b/session.go index 5688a9bc..8c45e458 100644 --- a/session.go +++ b/session.go @@ -1,1040 +1,43 @@ package ngrok import ( - "context" - "crypto/tls" - "crypto/x509" - _ "embed" // nolint - "errors" - "fmt" - "net" - "net/http" - "net/url" - "os" - "regexp" - "runtime" - "strings" - "sync/atomic" "time" - - "github.com/inconshreveable/log15/v3" - "go.uber.org/multierr" - "golang.org/x/net/proxy" - "golang.org/x/sync/errgroup" - - "golang.ngrok.com/ngrok/config" - - "golang.ngrok.com/muxado/v2" - tunnel_client "golang.ngrok.com/ngrok/internal/tunnel/client" - "golang.ngrok.com/ngrok/internal/tunnel/proto" - "golang.ngrok.com/ngrok/log" ) -// The ngrok library version. -// -//go:embed VERSION -var libraryAgentVersion string - -// AgentVersionDeprecated is a type wrapper for [proto.AgentVersionDeprecated] -type AgentVersionDeprecated proto.AgentVersionDeprecated - -func (avd *AgentVersionDeprecated) Error() string { - return (*proto.AgentVersionDeprecated)(avd).Error() -} - -// Session encapsulates an established session with the ngrok service. Sessions -// recover from network failures by automatically reconnecting. -type Session interface { - // Listen creates a new Tunnel which will listen for new inbound - // connections. The returned Tunnel object is a net.Listener. - Listen(ctx context.Context, cfg config.Tunnel) (Tunnel, error) - - // Warnings returns a list of warnings generated for the session on connect/auth +// AgentSession represents an active connection from an Agent to the ngrok cloud +// service. +type AgentSession interface { + // ID returns the server-assigned ID of the agent session + // TODO(alan): implement when the server begins setting this value + // ID() string + // Warnings is a list of warnings returned by the ngrok cloud service after the Agent has connected Warnings() []error - - // ListenAndForward creates a new Tunnel which will listen for new inbound - // connections. Connections on this tunnel are automatically forwarded to - // the provided URL. - ListenAndForward(ctx context.Context, backend *url.URL, cfg config.Tunnel) (Forwarder, error) - - // ListenAndServeHTTP creates a new Tunnel to serve as a backend for an HTTP server. Connections will be - // forwarded to the provided HTTP server. - ListenAndServeHTTP(ctx context.Context, cfg config.Tunnel, server *http.Server) (Forwarder, error) - - // ListenAndHandleHTTP creates a new Tunnel to serve as a backend for an HTTP handler. Connections will be - // forwarded to a new HTTP server and handled by the provided HTTP handler. - ListenAndHandleHTTP(ctx context.Context, cfg config.Tunnel, handler *http.Handler) (Forwarder, error) - - // Close ends the ngrok session. All Tunnel objects created by Listen - // on this session will be closed. - Close() error -} - -//go:embed assets/ngrok.ca.crt -var defaultCACert []byte - -const defaultServer = "connect.ngrok-agent.com:443" - -var leastLatencyServer = regexp.MustCompile(`^connect\.([a-z]+?-)?ngrok-agent\.com(\.lan)?:443`) - -// Dialer is the interface a custom connection dialer must implement for use -// with the [WithDialer] option. -type Dialer interface { - // Connect to an address on the named network. - // See the documentation for net.Dial. - Dial(network, address string) (net.Conn, error) - // Connect to an address on the named network with the provided - // context. - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - -// SessionConnectHandler is the callback type for [WithConnectHandler] -type SessionConnectHandler func(ctx context.Context, sess Session) - -// SessionDisconnectHandler is the callback type for [WithDisconnectHandler] -type SessionDisconnectHandler func(ctx context.Context, sess Session, err error) - -// SessionHeartbeatHandler is the callback type for [WithHearbeatHandler] -type SessionHeartbeatHandler func(ctx context.Context, sess Session, latency time.Duration) - -// ServerCommandHandler is the callback type for [WithStopHandler] -type ServerCommandHandler func(ctx context.Context, sess Session) error - -// ConnectOption is passed to [Connect] to customize session connection and establishment. -type ConnectOption func(*connectConfig) - -type clientInfo struct { - Type string - Version string - Comments []string -} - -var bannedUAchar = regexp.MustCompile("[^!#$%&'*+-.^_`|~0-9a-zA-Z]") - -// Formats client info as a well-formed user agent string -func (c *clientInfo) ToUserAgent() string { - comment := "" - if len(c.Comments) > 0 { - comment = fmt.Sprintf(" (%s)", strings.Join(c.Comments, "; ")) - } - return sanitizeUserAgentString(c.Type) + "/" + sanitizeUserAgentString(c.Version) + comment -} - -func sanitizeUserAgentString(s string) string { - s = strings.ReplaceAll(s, "/", "-") - s = bannedUAchar.ReplaceAllString(s, "#") - return s -} - -// version, type, user-agent -func generateUserAgent(cs []clientInfo) string { - var uas []string - - for _, c := range cs { - uas = append(uas, c.ToUserAgent()) - } - - return strings.Join(uas, " ") -} - -// Options to use when establishing the ngrok session. -type connectConfig struct { - // Your ngrok Authtoken. - Authtoken proto.ObfuscatedString - // The address of the ngrok server to connect to. - // Defaults to `connect.ngrok-agent.com:443` - ServerAddr string - // The optional addresses of the additional ngrok servers to connect to. - AdditionalServerAddrs []string - // Enable using multiple session legs - EnableMultiLeg bool - // The [tls.Config] used when connecting to the ngrok server - TLSConfigCustomizer func(*tls.Config) - // The [x509.CertPool] used to authenticate the ngrok server certificate. - CAPool *x509.CertPool - - // The [Dialer] used to establish the initial TCP connection to the ngrok - // server. - // If set, takes precedence over the ProxyURL setting. - // If not set, defaults to a [net.Dialer]. - Dialer Dialer - - // The URL of a proxy to use when making the TCP connection to the ngrok - // server. - // Any proxy supported by [golang.org/x/net/proxy] may be used. - ProxyURL *url.URL - - // Opaque metadata string to be associated with the session. - // Viewable from the ngrok dashboard or API. - Metadata string - - // Child client types and versions used to identify specific applications - // using this library to the ngrok service. - ClientInfo []clientInfo - - // HeartbeatInterval determines how often we send application level - // heartbeats to the server go check connection liveness. - HeartbeatInterval time.Duration - // HeartbeatTolerance is the duration after which an unacknowledged - // heartbeat is determined to mean the connection is dead. - HeartbeatTolerance time.Duration - - ConnectHandler SessionConnectHandler - DisconnectHandler SessionDisconnectHandler - HeartbeatHandler SessionHeartbeatHandler - - StopHandler ServerCommandHandler - RestartHandler ServerCommandHandler - UpdateHandler ServerCommandHandler - - remoteStopErr *string - remoteRestartErr *string - remoteUpdateErr *string - - // The logger for the session to use. - Logger log.Logger -} - -// WithMetadata configures the opaque, machine-readable metadata string for this -// session. Metadata is made available to you in the ngrok dashboard and the -// Agents API resource. It is a useful way to allow you to uniquely identify -// sessions. We suggest encoding the value in a structured format like JSON. -// -// See the [metadata parameter in the ngrok docs] for additional details. -// -// [metadata parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#metadata -func WithMetadata(meta string) ConnectOption { - return func(cfg *connectConfig) { - cfg.Metadata = meta - } -} - -// WithClientInfo configures client type and version information for applications -// built on this library. This is a way for consumers of this library to identify -// themselves to the ngrok service. -// -// This will add a new entry to the `User-Agent` field in the "most significant" -// (first) position. -func WithClientInfo(clientType, version string, comments ...string) ConnectOption { - return func(cfg *connectConfig) { - cfg.ClientInfo = append([]clientInfo{{clientType, version, comments}}, cfg.ClientInfo...) - } -} - -// WithDialer configures the session to use the provided [Dialer] when -// establishing a connection to the ngrok service. This option will cause -// [WithProxyURL] to be ignored. -func WithDialer(dialer Dialer) ConnectOption { - return func(cfg *connectConfig) { - cfg.Dialer = dialer - } -} - -// WithProxyURL configures the session to connect to ngrok through an outbound -// HTTP or SOCKS5 proxy. This parameter is ignored if you override the dialer -// with [WithDialer]. -// -// See the [proxy url parameter in the ngrok docs] for additional details. -// -// [proxy url parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#proxy_url -func WithProxyURL(url *url.URL) ConnectOption { - return func(cfg *connectConfig) { - cfg.ProxyURL = url - } -} - -// WithAuthtoken configures the session to authenticate with the provided -// authtoken. You can [find your existing authtoken] or [create a new one] in the ngrok dashboard. -// -// See the [authtoken parameter in the ngrok docs] for additional details. -// -// [find your existing authtoken]: https://dashboard.ngrok.com/get-started/your-authtoken -// [create a new one]: https://dashboard.ngrok.com/tunnels/authtokens -// [authtoken parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#authtoken -func WithAuthtoken(token string) ConnectOption { - return func(cfg *connectConfig) { - cfg.Authtoken = proto.ObfuscatedString(token) - } -} - -// WithAuthtokenFromEnv is a shortcut for calling [WithAuthtoken] with the -// value of the NGROK_AUTHTOKEN environment variable. -func WithAuthtokenFromEnv() ConnectOption { - return WithAuthtoken(os.Getenv("NGROK_AUTHTOKEN")) -} - -// WithRegion configures the session to connect to a specific ngrok region. -// If unspecified, ngrok will connect to the fastest region, which is usually what you want. -// The [full list of ngrok regions] can be found in the ngrok documentation. -// -// See the [region parameter in the ngrok docs] for additional details. -// -// [full list of ngrok regions]: https://ngrok.com/docs/platform/pops -// [region parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#region -func WithRegion(region string) ConnectOption { - return func(cfg *connectConfig) { - if region != "" { - cfg.ServerAddr = fmt.Sprintf("connect.%s.ngrok-agent.com:443", region) - } - } -} - -// WithServer configures the network address to dial to connect to the ngrok -// service. Use this option only if you are connecting to a custom agent -// ingress. -// -// See the [server_addr parameter in the ngrok docs] for additional details. -// -// [server_addr parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#server_addr -func WithServer(addr string) ConnectOption { - return func(cfg *connectConfig) { - cfg.ServerAddr = addr - } -} - -// WithAdditionalServers configures the network address to dial to connect to the ngrok -// service on secondary legs. Use this option only if you are connecting to a custom agent -// ingress, and have enabled multi leg. -// -// See the [server_addr parameter in the ngrok docs] for additional details. -// -// [server_addr parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#server_addr -func WithAdditionalServers(addrs []string) ConnectOption { - return func(cfg *connectConfig) { - cfg.AdditionalServerAddrs = addrs - } -} - -// WithMultiLeg as true allows connecting to the ngrok service on secondary legs. -// -// See [WithAdditionalServers] if connecting to a custom agent ingress. -func WithMultiLeg(enable bool) ConnectOption { - return func(cfg *connectConfig) { - cfg.EnableMultiLeg = enable - } -} - -// WithTLSConfig allows customization of the TLS connection made from the agent -// to the ngrok service. Customization is applied after the [WithServer] and -// [WithCA] options are applied. -func WithTLSConfig(tlsCustomizer func(*tls.Config)) ConnectOption { - return func(cfg *connectConfig) { - cfg.TLSConfigCustomizer = tlsCustomizer - } -} - -// WithCA configures the CAs used to validate the TLS certificate returned by -// the ngrok service while establishing the session. Use this option only if -// you are connecting through a man-in-the-middle or deep packet inspection -// proxy. -// -// See the [root_cas parameter in the ngrok docs] for additional details. -// -// [root_cas parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#root_cas -func WithCA(pool *x509.CertPool) ConnectOption { - return func(cfg *connectConfig) { - cfg.CAPool = pool - } -} - -// WithHeartbeatTolerance configures the duration to wait for a response to a heartbeat -// before assuming the session connection is dead and attempting to reconnect. -// -// See the [heartbeat_tolerance parameter in the ngrok docs] for additional details. -// -// [heartbeat_tolerance parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#heartbeat_tolerance -func WithHeartbeatTolerance(tolerance time.Duration) ConnectOption { - return func(cfg *connectConfig) { - cfg.HeartbeatTolerance = tolerance - } -} - -// WithHeartbeatInterval configures how often the session will send heartbeat -// messages to the ngrok service to check session liveness. -// -// See the [heartbeat_interval parameter in the ngrok docs] for additional details. -// -// [heartbeat_interval parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#heartbeat_interval -func WithHeartbeatInterval(interval time.Duration) ConnectOption { - return func(cfg *connectConfig) { - cfg.HeartbeatInterval = interval - } -} - -// WithLogger configures a logger to receive log messages from the [Session]. The -// log subpackage contains adapters for both [logrus] and [zap]. -// -// [logrus]: https://pkg.go.dev/github.com/sirupsen/logrus -// [zap]: https://pkg.go.dev/go.uber.org/zap -func WithLogger(logger log.Logger) ConnectOption { - return func(cfg *connectConfig) { - cfg.Logger = logger - } -} - -// WithConnectHandler configures a function which is called each time the ngrok -// [Session] successfully connects to the ngrok service. Use this option to -// receive events when the [Session] successfully connects or reconnects after -// a disconnection due to network failure. -func WithConnectHandler(handler SessionConnectHandler) ConnectOption { - return func(cfg *connectConfig) { - cfg.ConnectHandler = handler - } -} - -// WithDisconnectHandler configures a function which is called each time the -// ngrok [Session] disconnects from the ngrok service. Use this option to detect -// when the ngrok session has gone temporarily offline. -// -// This handler will be called every time the [Session] encounters an error during -// or after connection. It may be called multiple times in a row; it may be -// called before any Connect handler is called and before [Connect] returns. -// -// If this function is called with a nil error, the [Session] has stopped and will -// not reconnect, usually due to [Session.Close] being called. -func WithDisconnectHandler(handler SessionDisconnectHandler) ConnectOption { - return func(cfg *connectConfig) { - cfg.DisconnectHandler = handler - } -} - -// WithHeartbeatHandler configures a function which is called each time the -// [Session] successfully heartbeats the ngrok service. The callback receives -// the latency of the round trip time from initiating the heartbeat to -// receiving an acknowledgement back from the ngrok service. -func WithHeartbeatHandler(handler SessionHeartbeatHandler) ConnectOption { - return func(cfg *connectConfig) { - cfg.HeartbeatHandler = handler - } -} - -// WithStopHandler configures a function which is called when the ngrok service -// requests that this [Session] stops. Your application may choose to interpret -// this callback as a request to terminate the [Session] or the entire process. -// -// Errors returned by this function will be visible to the ngrok dashboard or -// API as the response to the Stop operation. -// -// Do not block inside this callback. It will cause the Dashboard or API Stop -// operation to hang. Do not call [Session].Close or [os.Exit] inside this -// callback, it will also cause the operation to hang. -// -// Instead, either return an error or if you intend to Stop, spawn a goroutine -// to asynchronously call [Session].Close or [os.Exit]. -func WithStopHandler(handler ServerCommandHandler) ConnectOption { - return func(cfg *connectConfig) { - cfg.StopHandler = handler - } -} - -// WithRestartHandler configures a function which is called when the ngrok service -// requests that this [Session] restarts. Your application may choose to interpret -// this callback as a request to reconnect the [Session] or restart the entire process. -// -// Errors returned by this function will be visible to the ngrok dashboard or -// API as the response to the Restart operation. -// -// Do not block inside this callback. It will cause the Dashboard or API Restart -// operation to hang. Do not call [Session].Close or [os.Exit] inside this -// callback, it will also cause the operation to hang. -// -// Instead, either spawn a goroutine to asynchronously restart, or return an error. -func WithRestartHandler(handler ServerCommandHandler) ConnectOption { - return func(cfg *connectConfig) { - cfg.RestartHandler = handler - } -} - -// WithUpdateHandler configures a function which is called when the ngrok service -// requests that the application running this [Session] updates. Your application -// may use this callback to trigger a check for a newer version followed by an update -// and restart if one exists. -// -// Errors returned by this function will be visible to the ngrok dashboard or -// API as the response to the Update operation. -// -// Do not block inside this callback. It will cause the Dashboard or API Update -// operation to hang. Do not call [Session].Close or [os.Exit] inside this -// callback, it will also cause the operation to hang. -// -// Instead, spawn a goroutine to asynchronously handle the update process -// or return an error if there is no newer version to update to. -func WithUpdateHandler(handler ServerCommandHandler) ConnectOption { - return func(cfg *connectConfig) { - cfg.UpdateHandler = handler - } -} - -// WithStopCommandDisabled specifies a user-friendly error message to be reported -// by the ngrok dashboard or API when a user attempts to issue a Stop command for -// this [Session]. -// -// Set this error only if you wish to provide a more detailed reason for entirely -// disabling the Stop command for your application. If you wish to report an error -// while attempting to handle a Stop command, instead return that error from the -// handler function set by [WithStopHandler]. -func WithStopCommandDisabled(err string) ConnectOption { - return func(cfg *connectConfig) { - cfg.remoteStopErr = &err - } -} - -// WithRestartCommandDisabled specifies a user-friendly error message to be reported -// by the ngrok dashboard or API when a user attempts to issue a Restart command for -// this [Session]. -// -// Set this error only if you wish to provide a more detailed reason for entirely -// disabling the Restart command for your application. If you wish to report an error -// while attempting to handle a Restart command, instead return that error from the -// handler function set by [WithRestartHandler]. -func WithRestartCommandDisabled(err string) ConnectOption { - return func(cfg *connectConfig) { - cfg.remoteRestartErr = &err - } -} - -// WithUpdateCommandDisabled specifies a user-friendly error message to be reported -// by the ngrok dashboard or API when a user attempts to issue a Update command for -// this [Session]. -// -// Set this error only if you wish to provide a more detailed reason for entirely -// disabling the Update command for your application. If you wish to report an error -// while attempting to handle a Update command, instead return that error from the -// handler function set by [WithUpdateHandler]. -func WithUpdateCommandDisabled(err string) ConnectOption { - return func(cfg *connectConfig) { - cfg.remoteUpdateErr = &err - } -} - -// Connect begins a new ngrok [Session] by connecting to the ngrok service, -// retrying transient failures if they occur. -// -// Connect blocks until the session is successfully established or fails with -// an error that will not be retried. Customize session connection behavior -// with [ConnectOption] arguments. -func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) { - logger := log15.New() - logger.SetHandler(log15.DiscardHandler()) - - cfg := connectConfig{} - for _, o := range opts { - o(&cfg) - } - - if cfg.Logger != nil { - logger = toLog15(cfg.Logger) - } - - if cfg.CAPool == nil { - cfg.CAPool = x509.NewCertPool() - cfg.CAPool.AppendCertsFromPEM(defaultCACert) - } - - if cfg.ServerAddr == "" { - cfg.ServerAddr = defaultServer - } - - var dialer Dialer - - if cfg.Dialer != nil { - dialer = cfg.Dialer - } else { - netDialer := &net.Dialer{} - - if cfg.ProxyURL != nil { - proxied, err := proxy.FromURL(cfg.ProxyURL, netDialer) - if err != nil { - return nil, errProxyInit{cfg.ProxyURL, err} - } - dialer = proxied.(Dialer) - } else { - dialer = netDialer - } - } - - heartbeatConfig := muxado.NewHeartbeatConfig() - if cfg.HeartbeatTolerance != 0 { - heartbeatConfig.Tolerance = cfg.HeartbeatTolerance - } - if cfg.HeartbeatInterval != 0 { - heartbeatConfig.Interval = cfg.HeartbeatInterval - } - - session := new(sessionImpl) - - stateChanges := make(chan error, 32) - - callbackHandler := remoteCallbackHandler{ - Logger: logger, - sess: session, - stopHandler: cfg.StopHandler, - restartHandler: cfg.RestartHandler, - updateHandler: cfg.UpdateHandler, - } - - rawDialer := func(legNumber uint32) (tunnel_client.RawSession, error) { - serverAddr := cfg.ServerAddr - if legNumber > 0 && len(cfg.AdditionalServerAddrs) >= int(legNumber) { - serverAddr = cfg.AdditionalServerAddrs[legNumber-1] - } - tlsConfig := &tls.Config{ - RootCAs: cfg.CAPool, - ServerName: strings.Split(serverAddr, ":")[0], - MinVersion: tls.VersionTLS12, - } - if cfg.TLSConfigCustomizer != nil { - cfg.TLSConfigCustomizer(tlsConfig) - } - - conn, err := dialer.DialContext(ctx, "tcp", serverAddr) - if err != nil { - return nil, errSessionDial{serverAddr, err} - } - - conn = tls.Client(conn, tlsConfig) - - sess := muxado.Client(conn, &muxado.Config{}) - return tunnel_client.NewRawSession(logger, sess, heartbeatConfig, callbackHandler), nil - } - - empty := "" - notImplemented := "the agent has not defined a callback for this operation" - - if cfg.StopHandler != nil { - cfg.remoteStopErr = &empty - } - if cfg.RestartHandler != nil { - cfg.remoteRestartErr = &empty - } - if cfg.UpdateHandler != nil { - cfg.remoteUpdateErr = &empty - } - - if cfg.remoteStopErr == nil { - cfg.remoteStopErr = ¬Implemented - } - if cfg.remoteRestartErr == nil { - cfg.remoteRestartErr = ¬Implemented - } - if cfg.remoteUpdateErr == nil { - cfg.remoteUpdateErr = ¬Implemented - } - - cfg.ClientInfo = append( - cfg.ClientInfo, - clientInfo{Type: string(proto.LibraryOfficialGo), Version: strings.TrimSpace(libraryAgentVersion)}, - ) - - userAgent := generateUserAgent(cfg.ClientInfo) - - auth := proto.AuthExtra{ - Version: cfg.ClientInfo[0].Version, - ClientType: proto.ClientType(cfg.ClientInfo[0].Type), - UserAgent: userAgent, - Authtoken: proto.ObfuscatedString(cfg.Authtoken), - Metadata: cfg.Metadata, - OS: runtime.GOOS, - Arch: runtime.GOARCH, - HeartbeatInterval: int64(heartbeatConfig.Interval), - HeartbeatTolerance: int64(heartbeatConfig.Tolerance), - - RestartUnsupportedError: cfg.remoteRestartErr, - StopUnsupportedError: cfg.remoteStopErr, - UpdateUnsupportedError: cfg.remoteUpdateErr, - } - - reconnect := func(sess tunnel_client.Session, raw tunnel_client.RawSession, legNumber uint32) (int, error) { - auth.LegNumber = legNumber - resp, err := sess.Auth(auth) - if err != nil { - remote := false - if resp.Error != "" { - remote = true - } - return 0, errAuthFailed{remote, err} - } - - if resp.Extra.DeprecationWarning != nil { - warning := resp.Extra.DeprecationWarning - vars := make([]any, 0, 3) - if warning.NextMin != "" { - vars = append(vars, "min_version", warning.NextMin) - } - if !warning.NextDate.IsZero() { - vars = append(vars, "deadline", warning.NextDate) - } - if warning.Msg != "" { - vars = append(vars, "extra", warning.Msg) - } - logger.Warn(warning.Error(), vars...) - } - - sessionInner := &sessionInner{ - Session: sess, - Region: resp.Extra.Region, - ProtoVersion: resp.Version, - ServerVersion: resp.Extra.Version, - ClientID: resp.Extra.Region, - AccountName: resp.Extra.AccountName, - PlanName: resp.Extra.PlanName, - Banner: resp.Extra.Banner, - SessionDuration: resp.Extra.SessionDuration, - DeprecationWarning: resp.Extra.DeprecationWarning, - ConnectAddresses: resp.Extra.ConnectAddresses, - Logger: logger, - } - - if legNumber == 0 { - session.setInner(sessionInner) - } - - if cfg.HeartbeatHandler != nil { - // plumb a session with the proper region to the heartbeatHandler - heartbeatSession := new(sessionImpl) - heartbeatSession.setInner(sessionInner) - go func() { - // use the raw latency channel in case this is a multi-leg session - beats := raw.Latency() - for { - select { - case <-ctx.Done(): - return - case latency, ok := <-beats: - if !ok { - return - } - cfg.HeartbeatHandler(ctx, heartbeatSession, latency) - } - } - }() - } - - auth.Cookie = resp.Extra.Cookie - - // store any connect server addresses for use in subsequent legs - if cfg.EnableMultiLeg && legNumber == 0 && len(resp.Extra.ConnectAddresses) > 1 { - overrideAdditionalServers := len(cfg.AdditionalServerAddrs) == 0 - for i, ca := range resp.Extra.ConnectAddresses { - if i == 0 { - if leastLatencyServer.MatchString(cfg.ServerAddr) { - // lock in the leg 0 region - logger.Debug("first leg using region", "region", resp.Extra.Region, "server", ca.ServerAddr) - cfg.ServerAddr = ca.ServerAddr - } - } else if overrideAdditionalServers { - cfg.AdditionalServerAddrs = append(cfg.AdditionalServerAddrs, ca.ServerAddr) - } - } - } - - // if we are using multi-leg, we need to know how many legs to connect - desiredLegs := 1 - if cfg.EnableMultiLeg { - desiredLegs = 1 + len(cfg.AdditionalServerAddrs) - } - return desiredLegs, nil - } - - sess := tunnel_client.NewReconnectingSession(logger, rawDialer, stateChanges, reconnect) - // allow consumers to .Close() the session before a successful connect - session.setInner(&sessionInner{ - Session: sess, - }) - - // performs one "pump" of the session update channel - // returns true if there are more updates to handle - runSessionHandlers := func() (bool, error) { - select { - case <-ctx.Done(): - if cfg.DisconnectHandler != nil { - cfg.DisconnectHandler(ctx, session, ctx.Err()) - logger.Info("no more state changes") - cfg.DisconnectHandler(ctx, session, nil) - } - sess.Close() - return false, ctx.Err() - case err, ok := <-stateChanges: - switch { - case !ok: // session has given up on reconnecting - if cfg.DisconnectHandler != nil { - logger.Info("no more state changes") - cfg.DisconnectHandler(ctx, session, nil) - } - sess.Close() - return false, nil - case err != nil: // session encountered an error - if cfg.DisconnectHandler != nil { - cfg.DisconnectHandler(ctx, session, err) - } - return true, err - case err == nil: // session connected successfully - if cfg.ConnectHandler != nil { - cfg.ConnectHandler(ctx, session) - } - return true, nil - } - } - - panic("inexhaustive case match when handling session state change") - } - - var errs error - for again := true; again; { - var err error - again, err = runSessionHandlers() - switch { - case again && err == nil: // successfully connected, move to goroutine and return - again = false - case again && err != nil: // error on reconnect - errs = multierr.Append(errs, err) - case !again: // gave up trying to reconnect - errs = multierr.Append(errs, err) - return nil, errs - } - } - - go func() { - for again := true; again; again, _ = runSessionHandlers() { - } - }() - - return session, nil -} - -type sessionImpl struct { - raw atomic.Pointer[sessionInner] -} - -type sessionInner struct { - tunnel_client.Session - - Region string - ProtoVersion string - ServerVersion string - ClientID string - AccountName string - PlanName string - Banner string - SessionDuration int64 - DeprecationWarning *proto.AgentVersionDeprecated - ConnectAddresses []proto.ConnectAddress - - Logger log15.Logger -} - -func (s *sessionImpl) inner() *sessionInner { - return s.raw.Load() -} - -func (s *sessionImpl) setInner(raw *sessionInner) { - s.raw.Store(raw) -} - -func (s *sessionImpl) closeTunnel(clientID string, err error) error { - return s.inner().CloseTunnel(clientID, err) -} - -func (s *sessionImpl) Close() error { - return s.inner().Close() -} - -func (s *sessionImpl) Warnings() []error { - deprecated := s.inner().DeprecationWarning - if deprecated != nil { - return []error{(*AgentVersionDeprecated)(deprecated)} - } - return nil -} - -func (s *sessionImpl) Listen(ctx context.Context, cfg config.Tunnel) (Tunnel, error) { - var ( - tunnel tunnel_client.Tunnel - err error - ) - tunnelCfg, ok := cfg.(tunnelConfigPrivate) - if !ok { - return nil, errors.New("invalid tunnel config") - } - - extra := tunnelCfg.Extra() - if tunnelCfg.Proto() != "" { - tunnel, err = s.inner().Listen(tunnelCfg.Proto(), tunnelCfg.Opts(), extra, tunnelCfg.ForwardsTo(), tunnelCfg.ForwardsProto()) - } else { - tunnel, err = s.inner().ListenLabel(tunnelCfg.Labels(), extra.Metadata, tunnelCfg.ForwardsTo(), tunnelCfg.ForwardsProto()) - } - - impl := &tunnelImpl{ - Sess: s, - Tunnel: tunnel, - } - - // Legacy support for passing HTTP server via config options. - // TODO: Remove this after we feel HTTP options via config have been deprecated. - if serverCfg, ok := cfg.(interface{ HTTPServer() *http.Server }); ok { - server := serverCfg.HTTPServer() - if server != nil { - go func() { _ = server.Serve(impl) }() - impl.server = server - } - } - - if err == nil { - return impl, nil - } - return nil, errListen{err} -} - -func (s *sessionImpl) ListenAndForward(ctx context.Context, url *url.URL, cfg config.Tunnel) (Forwarder, error) { - tunnelCfg, ok := cfg.(tunnelConfigPrivate) - if !ok { - return nil, errors.New("invalid tunnel config") - } - - // Set 'Forwards To' - tunnelCfg.WithForwardsTo(url) - - tun, err := s.Listen(ctx, cfg) - if err != nil { - return nil, err - } - - return forwardTunnel(ctx, tun, url), nil -} - -func (s *sessionImpl) ListenAndServeHTTP(ctx context.Context, cfg config.Tunnel, server *http.Server) (Forwarder, error) { - tun, err := s.Listen(ctx, cfg) - if err != nil { - return nil, err - } - - mainGroup, _ := errgroup.WithContext(ctx) - if server != nil { - // Store server ref to close when tunnel closes - impl, _ := tun.(*tunnelImpl) - - // Check if tunnel is already serving an HTTP server - // TODO: Remove this once we feel HTTP options via config have been deprecated. - if impl.server == nil { - mainGroup.Go(func() error { return server.Serve(tun) }) - impl.server = server - } else { - // Inform end user that they're using a deprecated option. - s.inner().Logger.Warn("Tunnel is serving an HTTP server via HTTP options. This has been deprecated. Please use Session.ListenAndServeHTTP instead.") - } - } - - return &forwarder{ - Tunnel: tun, - mainGroup: mainGroup, - }, nil -} - -func (s *sessionImpl) ListenAndHandleHTTP(ctx context.Context, cfg config.Tunnel, handler *http.Handler) (Forwarder, error) { - return s.ListenAndServeHTTP(ctx, cfg, &http.Server{Handler: *handler}) -} - -// The rest of the `sessionImpl` methods are non-public, but can be -// interface-asserted if they're *really* needed. These are exempt from any -// stability guarantees and subject to change without notice. - -func (s *sessionImpl) ProtoVersion() string { - return s.inner().ProtoVersion -} -func (s *sessionImpl) ServerVersion() string { - return s.inner().ServerVersion -} -func (s *sessionImpl) ClientID() string { - return s.inner().ClientID -} -func (s *sessionImpl) AccountName() string { - return s.inner().AccountName -} -func (s *sessionImpl) PlanName() string { - return s.inner().PlanName -} -func (s *sessionImpl) Banner() string { - return s.inner().Banner -} -func (s *sessionImpl) SessionDuration() int64 { - return s.inner().SessionDuration -} -func (s *sessionImpl) Region() string { - return s.inner().Region -} -func (s *sessionImpl) Heartbeat() (time.Duration, error) { - return s.inner().Heartbeat() -} -func (s *sessionImpl) Latency() <-chan time.Duration { - return s.inner().Latency() -} -func (s *sessionImpl) ConnectAddresses() []struct{ Region, ServerAddr string } { - connectAddresses := make([]struct{ Region, ServerAddr string }, len(s.inner().ConnectAddresses)) - for i, addr := range s.inner().ConnectAddresses { - connectAddresses[i] = struct{ Region, ServerAddr string }{addr.Region, addr.ServerAddr} - } - return connectAddresses + // Agent returns the agent that started this session + Agent() Agent + // StartedAt returns the time that the AgentSession was connected + StartedAt() time.Time } -type remoteCallbackHandler struct { - log15.Logger - sess *sessionImpl - stopHandler ServerCommandHandler - restartHandler ServerCommandHandler - updateHandler ServerCommandHandler +// agentSession implements the AgentSession interface. +type agentSession struct { + id string + warnings []error + agent Agent + startedAt time.Time } -func (rc remoteCallbackHandler) OnStop(_ *proto.Stop, respond tunnel_client.HandlerRespFunc) { - if rc.stopHandler != nil { - resp := new(proto.StopResp) - close := true - if err := rc.stopHandler(context.TODO(), rc.sess); err != nil { - close = false - resp.Error = err.Error() - } - if err := respond(resp); err != nil { - rc.Warn("error responding to stop request", "error", err) - } - if close { - _ = rc.sess.Close() - } - } +func (s *agentSession) ID() string { + return s.id } -func (rc remoteCallbackHandler) OnRestart(_ *proto.Restart, respond tunnel_client.HandlerRespFunc) { - if rc.restartHandler != nil { - resp := new(proto.RestartResp) - close := true - if err := rc.restartHandler(context.TODO(), rc.sess); err != nil { - close = false - resp.Error = err.Error() - } - if err := respond(resp); err != nil { - rc.Warn("error responding to restart request", "error", err) - } - if close { - _ = rc.sess.Close() - } - } +func (s *agentSession) Warnings() []error { + return s.warnings } -func (rc remoteCallbackHandler) OnUpdate(_ *proto.Update, respond tunnel_client.HandlerRespFunc) { - if rc.updateHandler != nil { - resp := new(proto.UpdateResp) - if err := rc.updateHandler(context.TODO(), rc.sess); err != nil { - resp.Error = err.Error() - } - if err := respond(resp); err != nil { - rc.Warn("error responding to restart request", "error", err) - } - } +func (s *agentSession) Agent() Agent { + return s.agent } -func (rc remoteCallbackHandler) OnStopTunnel(stopTunnel *proto.StopTunnel, respond tunnel_client.HandlerRespFunc) { - ngrokErr := &ngrokError{Message: stopTunnel.Message, ErrCode: stopTunnel.ErrorCode} - // close the tunnel and maintain the session - err := rc.sess.closeTunnel(stopTunnel.ClientID, ngrokErr) - if err != nil { - rc.Warn("error closing tunnel", "error", err) - } +func (s *agentSession) StartedAt() time.Time { + return s.startedAt } diff --git a/upstream.go b/upstream.go new file mode 100644 index 00000000..a020891f --- /dev/null +++ b/upstream.go @@ -0,0 +1,73 @@ +package ngrok + +import ( + "crypto/tls" +) + +// Upstream represents configuration for forwarding to an upstream service. +type Upstream struct { + addr string + protocol string + proxyProto ProxyProtoVersion + tlsClientConfig *tls.Config + dialer Dialer +} + +// UpstreamOption configures an Upstream instance. +type UpstreamOption func(*Upstream) + +// WithUpstream creates an Upstream configuration with a required address. +// The address can be in various formats such as: +// - "80" (a port number for local services) +// - "example.com:8080" (a host:port combination) +// - "http://example.com" (a full URL) +func WithUpstream(addr string, opts ...UpstreamOption) *Upstream { + opt := &Upstream{addr: addr} + for _, o := range opts { + o(opt) + } + return opt +} + +// WithUpstreamProtocol sets the protocol to use when forwarding to the upstream. +// This is typically used to specify "http2" when communicating with an +// upstream HTTP/2 server. +func WithUpstreamProtocol(proto string) UpstreamOption { + return func(o *Upstream) { + o.protocol = proto + } +} + +// WithUpstreamTLSClientConfig sets the TLS client configuration to use when connecting +// to the upstream server over TLS. +func WithUpstreamTLSClientConfig(config *tls.Config) UpstreamOption { + return func(o *Upstream) { + o.tlsClientConfig = config + } +} + +// ProxyProtoVersion represents PROXY protocol version +type ProxyProtoVersion string + +const ( + ProxyProtoV1 ProxyProtoVersion = "v1" + ProxyProtoV2 ProxyProtoVersion = "v2" +) + +// WithUpstreamProxyProto sets the PROXY protocol version to use when connecting +// to the upstream server. Valid values are ProxyProtoV1 or ProxyProtoV2. +// +// See https://ngrok.com/docs/agent/config/v3/#upstreamproxy_protocol +func WithUpstreamProxyProto(proxyProto ProxyProtoVersion) UpstreamOption { + return func(o *Upstream) { + o.proxyProto = proxyProto + } +} + +// WithUpstreamDialer sets a custom dialer to use when connecting to the upstream server. +// This allows for custom network configurations or connection methods when reaching the upstream. +func WithUpstreamDialer(dialer Dialer) UpstreamOption { + return func(o *Upstream) { + o.dialer = dialer + } +}