Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 28 additions & 49 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,18 @@ import (
"golang.ngrok.com/ngrok/v2/rpc"
)

// Agent is the main interface for interacting with the ngrok service.
type Agent interface {
// Connect begins a new Session by connecting and authenticating to the ngrok cloud service.
Connect(context.Context) error

// Disconnect terminates the current Session which disconnects it from the ngrok cloud service.
Disconnect() error

// Session returns an object describing the connection of the Agent to the ngrok cloud service.
Session() (AgentSession, error)

// Endpoints returns the list of endpoints created by this Agent from calls to either Listen or Forward.
Endpoints() []Endpoint

// Listen creates an Endpoint which returns received connections to the caller via an EndpointListener.
Listen(context.Context, ...EndpointOption) (EndpointListener, error)

// Forward creates an Endpoint which forwards received connections to a target upstream URL.
Forward(context.Context, *Upstream, ...EndpointOption) (EndpointForwarder, error)
}

// Dialer is an interface that is satisfied by net.Dialer or you can specify your
// own implementation.
type Dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// agent implements the Agent interface.
type agent struct {
// Agent is the main interface for interacting with the ngrok service.
type Agent struct {
mu sync.RWMutex
sess legacy.Session
agentSession *agentSession
agentSession *AgentSession
opts *agentOpts
endpoints []Endpoint
// Event handlers registered with this agent
Expand All @@ -58,13 +37,13 @@ type agent struct {
}

// NewAgent creates a new Agent object.
func NewAgent(agentOpts ...AgentOption) (Agent, error) {
func NewAgent(agentOpts ...AgentOption) (*Agent, error) {
opts := defaultAgentOpts()
for _, opt := range agentOpts {
opt(opts)
}

return &agent{
return &Agent{
opts: opts,
endpoints: make([]Endpoint, 0),
eventHandlers: opts.eventHandlers,
Expand All @@ -73,7 +52,7 @@ func NewAgent(agentOpts ...AgentOption) (Agent, error) {

// Connect begins a new Session by connecting and authenticating to the ngrok
// cloud service.
func (a *agent) Connect(ctx context.Context) error {
func (a *Agent) Connect(ctx context.Context) error {
a.mu.Lock()
defer a.mu.Unlock()

Expand Down Expand Up @@ -118,9 +97,9 @@ func (a *agent) Connect(ctx context.Context) error {
}

// Create our AgentSession wrapper early so we can capture it in closures
agentSession := &agentSession{
agent: a,
startedAt: time.Now(),
agentSession := &AgentSession{
Agent: a,
StartedAt: time.Now(),
}

// Hook up connect event
Expand Down Expand Up @@ -155,8 +134,8 @@ func (a *agent) Connect(ctx context.Context) error {
}

// Complete the AgentSession wrapper with session-specific data
agentSession.id = sess.AgentSessionID()
agentSession.warnings = sess.Warnings()
agentSession.ID = sess.AgentSessionID()
agentSession.Warnings = sess.Warnings()

// Store in agent
a.sess = sess
Expand All @@ -167,7 +146,7 @@ func (a *agent) Connect(ctx context.Context) error {

// Disconnect terminates the current Session which disconnects it from the ngrok
// cloud service.
func (a *agent) Disconnect() error {
func (a *Agent) Disconnect() error {
// Get what we need under lock
a.mu.Lock()
sess := a.sess
Expand Down Expand Up @@ -196,19 +175,18 @@ func (a *agent) Disconnect() error {

// Session returns an object describing the connection of the Agent to the ngrok
// cloud service.
func (a *agent) Session() (AgentSession, error) {
func (a *Agent) Session() (*AgentSession, error) {
a.mu.RLock()
defer a.mu.RUnlock()

if a.sess == nil || a.agentSession == nil {
return nil, errors.New("agent not connected")
}

return a.agentSession, nil
return a.agentSession.clone(), nil
}

// Endpoints returns the list of endpoints created by this Agent.
func (a *agent) Endpoints() []Endpoint {
// Endpoints returns the list of endpoints created by this Agent
// from calls to either [*Agent.Listen] or [*Agent.Forward].
func (a *Agent) Endpoints() []Endpoint {
a.mu.RLock()
defer a.mu.RUnlock()

Expand All @@ -217,7 +195,7 @@ func (a *agent) Endpoints() []Endpoint {
}

// createListener creates an endpointListener for internal use
func (a *agent) createListener(ctx context.Context, endpointOpts *endpointOpts) (*endpointListener, error) {
func (a *Agent) createListener(ctx context.Context, endpointOpts *endpointOpts) (*EndpointListener, error) {
// Get the session
a.mu.RLock()
sess := a.sess
Expand Down Expand Up @@ -255,7 +233,7 @@ func (a *agent) createListener(ctx context.Context, endpointOpts *endpointOpts)
now := time.Now()

// Create endpoint listener
endpoint := &endpointListener{
endpoint := &EndpointListener{
baseEndpoint: baseEndpoint{
agent: a,
id: tunnel.ID(),
Expand Down Expand Up @@ -285,8 +263,9 @@ func (a *agent) createListener(ctx context.Context, endpointOpts *endpointOpts)
return endpoint, nil
}

// Listen creates an EndpointListener.
func (a *agent) Listen(ctx context.Context, opts ...EndpointOption) (EndpointListener, error) {
// Listen creates an endpoint which returns received connections to the caller
// via an [*EndpointListener].
func (a *Agent) Listen(ctx context.Context, opts ...EndpointOption) (*EndpointListener, error) {
// Apply all options
endpointOpts := defaultEndpointOpts()
for _, opt := range opts {
Expand All @@ -308,7 +287,7 @@ func (a *agent) Listen(ctx context.Context, opts ...EndpointOption) (EndpointLis
}

// ensureConnected handles automatic connection and verifies connection state
func (a *agent) ensureConnected(ctx context.Context) error {
func (a *Agent) ensureConnected(ctx context.Context) error {
// First check if we're already connected (with a read lock)
a.mu.RLock()
sessionExists := a.sess != nil
Expand All @@ -333,7 +312,7 @@ func (a *agent) ensureConnected(ctx context.Context) error {
}

// removeEndpoint removes an endpoint from the agent's list
func (a *agent) removeEndpoint(endpoint Endpoint) {
func (a *Agent) removeEndpoint(endpoint Endpoint) {
// Remove the endpoint from our list under lock
a.mu.Lock()
for i, e := range a.endpoints {
Expand All @@ -346,7 +325,7 @@ func (a *agent) removeEndpoint(endpoint Endpoint) {
}

// emitEvent sends an event to all registered handlers
func (a *agent) emitEvent(evt Event) {
func (a *Agent) emitEvent(evt Event) {
a.eventMutex.RLock()
handlers := make([]EventHandler, len(a.eventHandlers))
copy(handlers, a.eventHandlers)
Expand All @@ -361,7 +340,7 @@ func (a *agent) emitEvent(evt Event) {

// createCommandHandler returns a legacy.ServerCommandHandler that delegates to the RPCHandler
// for the specified RPC method.
func (a *agent) createCommandHandler(method string) legacy.ServerCommandHandler {
func (a *Agent) createCommandHandler(method string) legacy.ServerCommandHandler {
return func(ctx context.Context, sess legacy.Session) error {
if a.opts.rpcHandler == nil {
return nil
Expand Down Expand Up @@ -389,7 +368,7 @@ func (a *agent) createCommandHandler(method string) legacy.ServerCommandHandler
// Forward creates an EndpointForwarder that forwards traffic to the specified upstream.
// The upstream parameter is required and must be created using WithUpstream().
// Additional endpoint options can be provided to configure the endpoint.
func (a *agent) Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (EndpointForwarder, error) {
func (a *Agent) Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (*EndpointForwarder, error) {
// Apply all base options first
endpointOpts := defaultEndpointOpts()

Expand Down Expand Up @@ -432,7 +411,7 @@ func (a *agent) Forward(ctx context.Context, upstream *Upstream, opts ...Endpoin
upstreamURL, _ := url.Parse(endpointOpts.upstreamURL)

// Create the forwarder
endpoint := &endpointForwarder{
endpoint := &EndpointForwarder{
baseEndpoint: listener.baseEndpoint, // reuse the baseEndpoint from listener
listener: listener,
upstreamURL: *upstreamURL,
Expand Down
4 changes: 2 additions & 2 deletions defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ var DefaultAgent, _ = NewAgent(
)

// Listen is equivalent to DefaultAgent.Listen().
func Listen(ctx context.Context, opts ...EndpointOption) (EndpointListener, error) {
func Listen(ctx context.Context, opts ...EndpointOption) (*EndpointListener, error) {
return DefaultAgent.Listen(ctx, opts...)
}

// Forward is sugar for DefaultAgent.Forward().
func Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (EndpointForwarder, error) {
func Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (*EndpointForwarder, error) {
return DefaultAgent.Forward(ctx, upstream, opts...)
}
31 changes: 10 additions & 21 deletions diagnose.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,15 @@ func IsMuxadoDiagnoseFailure(err error) bool {
return errors.As(err, &de) && de.Step == "muxado"
}

// Diagnoser is implemented by Agent types that support pre-connection
// diagnostic probing. Use a type assertion to access it:
// Diagnose tests connectivity to addr by probing TCP, TLS, and the Muxado
// tunnel protocol. It uses the Agent's configured TLS settings, CA roots,
// and proxy/dialer settings.
//
// d, ok := agent.(ngrok.Diagnoser)
type Diagnoser interface {
Agent

// Diagnose tests connectivity to addr by probing TCP, TLS, and the Muxado
// tunnel protocol. It uses the Agent's configured TLS settings, CA roots,
// and proxy/dialer settings.
//
// If addr is empty, the configured server address is probed.
//
// This method does NOT establish a persistent session or call Auth. It is
// safe to call without affecting any existing connection.
Diagnose(ctx context.Context, addr string) (DiagnoseResult, error)
}

// Diagnose implements Diagnoser.
func (a *agent) Diagnose(ctx context.Context, addr string) (DiagnoseResult, error) {
// If addr is empty, the configured server address is probed.
//
// This method does NOT establish a persistent session or call Auth. It is
// safe to call without affecting any existing connection.
func (a *Agent) Diagnose(ctx context.Context, addr string) (DiagnoseResult, error) {
connectAddr := cmp.Or(a.opts.connectURL, "connect.ngrok-agent.com:443")
if addr == "" {
addr = connectAddr
Expand All @@ -110,7 +99,7 @@ func (a *agent) Diagnose(ctx context.Context, addr string) (DiagnoseResult, erro

// buildDiagnosticDialer returns the effective dialer for probes, applying
// proxy configuration without mutating agent state.
func (a *agent) buildDiagnosticDialer() (Dialer, error) {
func (a *Agent) buildDiagnosticDialer() (Dialer, error) {
baseDialer := cmp.Or(a.opts.dialer, Dialer(&net.Dialer{}))
if a.opts.proxyURL == "" {
return baseDialer, nil
Expand All @@ -132,7 +121,7 @@ func (a *agent) buildDiagnosticDialer() (Dialer, error) {

// probeAddr runs TCP → TLS → Muxado → SrvInfo for addr and returns a
// DiagnoseResult on success, or a *DiagnoseError indicating which step failed.
func (a *agent) probeAddr(ctx context.Context, logger *slog.Logger, dialer Dialer, serverName, addr string) (DiagnoseResult, error) {
func (a *Agent) probeAddr(ctx context.Context, logger *slog.Logger, dialer Dialer, serverName, addr string) (DiagnoseResult, error) {
result := DiagnoseResult{Addr: addr}

// TCP
Expand Down
18 changes: 4 additions & 14 deletions diagnose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ func TestDiagnoseTCPFailure(t *testing.T) {
a, err := NewAgent()
require.NoError(t, err)

d, ok := a.(Diagnoser)
require.True(t, ok, "agent should implement Diagnoser")

result, err := d.Diagnose(testcontext.ForTB(t), addr)
result, err := a.Diagnose(testcontext.ForTB(t), addr)
require.Error(t, err)
assert.True(t, IsTCPDiagnoseFailure(err))
assert.Equal(t, addr, result.Addr)
Expand All @@ -57,9 +54,7 @@ func TestDiagnoseTLSFailure(t *testing.T) {
a, err := NewAgent(WithAgentConnectURL(l.Addr().String()))
require.NoError(t, err)

d := a.(Diagnoser)

result, err := d.Diagnose(testcontext.ForTB(t), l.Addr().String())
result, err := a.Diagnose(testcontext.ForTB(t), l.Addr().String())
require.Error(t, err)
assert.True(t, IsTLSDiagnoseFailure(err))
assert.Equal(t, l.Addr().String(), result.Addr)
Expand Down Expand Up @@ -118,9 +113,7 @@ func TestDiagnoseMuxadoSuccess(t *testing.T) {
)
require.NoError(t, err)

d := a.(Diagnoser)

result, err := d.Diagnose(testcontext.ForTB(t), l.Addr().String())
result, err := a.Diagnose(testcontext.ForTB(t), l.Addr().String())
require.NoError(t, err)
assert.Equal(t, l.Addr().String(), result.Addr)
assert.Equal(t, testRegion, result.Region)
Expand Down Expand Up @@ -149,12 +142,9 @@ func TestDiagnoseOnline(t *testing.T) {
a, err := NewAgent(agentOpts...)
require.NoError(t, err)

d, ok := a.(Diagnoser)
require.True(t, ok)

ctx := testcontext.ForTB(t)

result, err := d.Diagnose(ctx, serverAddr)
result, err := a.Diagnose(ctx, serverAddr)
require.NoError(t, err)
t.Logf("addr=%s region=%s latency=%s", result.Addr, result.Region, result.Latency)
assert.NotEmpty(t, result.Region)
Expand Down
Loading
Loading