-
Notifications
You must be signed in to change notification settings - Fork 101
Expand file tree
/
Copy pathagent.go
More file actions
452 lines (379 loc) · 12.6 KB
/
agent.go
File metadata and controls
452 lines (379 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
package ngrok
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"slices"
"sync"
"time"
"golang.org/x/net/proxy"
"golang.ngrok.com/ngrok/v2/internal/legacy"
"golang.ngrok.com/ngrok/v2/internal/legacy/config"
"golang.ngrok.com/ngrok/v2/rpc"
)
// Agent is the main interface for interacting with the ngrok service.
//
// An Agent manages a connection (Session) to the ngrok cloud service and the
// Endpoints created over it through Listen or Forward.
type Agent interface {
	// Connect begins a new Session by connecting and authenticating to the
	// ngrok cloud service.
	Connect(ctx context.Context) error

	// Disconnect terminates the current Session, disconnecting the Agent from
	// the ngrok cloud service.
	Disconnect() error

	// Session returns an object describing the Agent's connection to the
	// ngrok cloud service.
	Session() (AgentSession, error)

	// Endpoints returns the endpoints this Agent created via Listen or
	// Forward.
	Endpoints() []Endpoint

	// Listen creates an Endpoint whose received connections are handed back
	// to the caller through an EndpointListener.
	Listen(ctx context.Context, opts ...EndpointOption) (EndpointListener, error)

	// Forward creates an Endpoint that forwards received connections to the
	// given upstream target.
	Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (EndpointForwarder, error)
}
// Dialer abstracts how outbound connections are established. *net.Dialer
// satisfies it, or callers may supply their own implementation.
type Dialer interface {
	// Dial connects to address over the named network.
	Dial(network string, address string) (net.Conn, error)

	// DialContext is like Dial but additionally receives a context.
	DialContext(ctx context.Context, network string, address string) (net.Conn, error)
}
// agent implements the Agent interface.
//
// Two locks are used: mu guards the connection state (sess, agentSession)
// and the endpoint list, while eventMutex guards eventHandlers so events can
// be emitted without taking mu.
type agent struct {
	// mu guards sess, agentSession and endpoints.
	mu sync.RWMutex
	// sess is the underlying legacy session; nil while disconnected
	// (Disconnect resets it to nil).
	sess legacy.Session
	// agentSession is the public AgentSession wrapper for sess; it is set
	// and cleared together with sess.
	agentSession *agentSession
	// opts is the configuration built by NewAgent from the AgentOptions.
	opts *agentOpts
	// endpoints tracks the endpoints created via Listen/Forward.
	endpoints []Endpoint
	// Event handlers registered with this agent
	eventHandlers []EventHandler
	eventMutex sync.RWMutex // Protects eventHandlers
}
// NewAgent constructs an Agent configured by agentOpts, which are applied in
// order on top of the default options. The returned Agent is not yet
// connected: call Connect, or rely on Listen/Forward to connect
// automatically when auto-connect is enabled. The error is currently always
// nil.
func NewAgent(agentOpts ...AgentOption) (Agent, error) {
	cfg := defaultAgentOpts()
	for i := range agentOpts {
		agentOpts[i](cfg)
	}

	a := &agent{
		opts:          cfg,
		endpoints:     make([]Endpoint, 0),
		eventHandlers: cfg.eventHandlers,
	}
	return a, nil
}
// Connect begins a new Session by connecting and authenticating to the ngrok
// cloud service.
//
// The agent's write lock is held for the entire call, including the call to
// legacy.Connect. An error is returned when the agent is already connected,
// when the configured proxy URL cannot be parsed or turned into a compatible
// dialer, or when legacy.Connect fails (passed through wrapError).
func (a *agent) Connect(ctx context.Context) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.sess != nil && a.agentSession != nil {
		return errors.New("agent already connected")
	}

	// Start from a private copy of the configured session options so the
	// appends below never write into a.opts.sessionOpts' backing array.
	legacyOpts := append([]legacy.ConnectOption{}, a.opts.sessionOpts...)

	if a.opts.proxyURL != "" {
		parsedURL, err := url.Parse(a.opts.proxyURL)
		if err != nil {
			return fmt.Errorf("invalid proxy URL: %w", err)
		}

		// Reach the proxy through the user's dialer if one was configured,
		// otherwise through a plain net.Dialer.
		baseDialer := a.opts.dialer
		if baseDialer == nil {
			baseDialer = &net.Dialer{}
		}

		proxyDialer, err := proxy.FromURL(parsedURL, baseDialer)
		if err != nil {
			return fmt.Errorf("failed to initialize proxy: %w", err)
		}

		// proxy.FromURL is only typed as returning proxy.Dialer; the ngrok
		// Dialer interface additionally requires DialContext.
		dialer, ok := proxyDialer.(Dialer)
		if !ok {
			return errors.New("proxy dialer is not compatible with ngrok Dialer interface")
		}

		// Deliberately NOT written back into a.opts.dialer: a.opts outlives
		// this session, so overwriting it made the next Connect (after a
		// Disconnect) use this proxy dialer as its base and wrap it in a
		// second proxy hop.
		legacyOpts = append(legacyOpts, legacy.WithDialer(dialer))
	}

	// Create the AgentSession wrapper up front so the event handlers below
	// can capture it.
	agentSession := &agentSession{
		agent:     a,
		startedAt: time.Now(),
	}

	// Connect event.
	// NOTE(review): agentSession.id and .warnings are only filled in after
	// legacy.Connect returns below; if this handler fires before that, the
	// emitted event sees an incomplete session — confirm the call order.
	legacyOpts = append(legacyOpts, legacy.WithConnectHandler(func(_ context.Context, _ legacy.Session) {
		a.emitEvent(newAgentConnectSucceeded(a, agentSession))
	}))
	// Disconnect event.
	legacyOpts = append(legacyOpts, legacy.WithDisconnectHandler(func(_ context.Context, _ legacy.Session, err error) {
		a.emitEvent(newAgentDisconnected(a, agentSession, err))
	}))
	// Heartbeat event.
	legacyOpts = append(legacyOpts, legacy.WithHeartbeatHandler(func(_ context.Context, _ legacy.Session, latency time.Duration) {
		a.emitEvent(newAgentHeartbeatReceived(a, agentSession, latency))
	}))

	// Server-issued stop/restart/update commands are delegated to the RPC
	// handler, but only when one is registered.
	if a.opts.rpcHandler != nil {
		legacyOpts = append(legacyOpts,
			legacy.WithStopHandler(a.createCommandHandler(rpc.StopAgentMethod)),
			legacy.WithRestartHandler(a.createCommandHandler(rpc.RestartAgentMethod)),
			legacy.WithUpdateHandler(a.createCommandHandler(rpc.UpdateAgentMethod)),
		)
	}

	sess, err := legacy.Connect(ctx, legacyOpts...)
	if err != nil {
		return wrapError(err)
	}

	// Complete the wrapper with data only available from the live session.
	agentSession.id = sess.AgentSessionID()
	agentSession.warnings = sess.Warnings()

	a.sess = sess
	a.agentSession = agentSession
	return nil
}
// Disconnect terminates the current Session, disconnecting the agent from the
// ngrok cloud service.
//
// The session, its AgentSession wrapper and the endpoint list are always
// cleared. If no session was open, nil is returned. Otherwise every
// previously tracked endpoint that supports it is signalled done, and then
// the session is closed; Close's error is returned through wrapError. The
// signalling and closing happen without holding the agent lock.
func (a *agent) Disconnect() error {
	a.mu.Lock()
	sess, tracked := a.sess, a.endpoints
	a.sess, a.agentSession = nil, nil
	a.endpoints = make([]Endpoint, 0)
	a.mu.Unlock()

	if sess == nil {
		return nil
	}

	// The list has already been reset above, so endpoints are only told they
	// are done here, not removed individually.
	type doneSignaler interface{ signalDone() }
	for _, ep := range tracked {
		if s, ok := ep.(doneSignaler); ok {
			s.signalDone()
		}
	}

	return wrapError(sess.Close())
}
// Session returns the AgentSession describing the agent's current connection
// to the ngrok cloud service, or an error if the agent is not connected.
func (a *agent) Session() (AgentSession, error) {
	a.mu.RLock()
	defer a.mu.RUnlock()

	connected := a.sess != nil && a.agentSession != nil
	if connected {
		return a.agentSession, nil
	}
	return nil, errors.New("agent not connected")
}
// Endpoints returns a snapshot of the endpoints created by this Agent. The
// slice is a copy taken under the read lock, so callers may modify it without
// affecting the agent.
func (a *agent) Endpoints() []Endpoint {
	a.mu.RLock()
	snapshot := slices.Clone(a.endpoints)
	a.mu.RUnlock()
	return snapshot
}
// createListener opens a tunnel on the current session as described by
// endpointOpts, wraps it in an endpointListener and records that listener in
// the agent's endpoint list.
//
// It returns an error if the agent is not connected, if the endpoint options
// are invalid (URL scheme, endpoint configuration or upstream URL), if the
// session fails to open the tunnel (passed through wrapError), or if the
// tunnel reports an unparsable URL. All option validation happens before the
// tunnel is opened, and a tunnel that was opened is closed again on a later
// failure, so no tunnel is leaked on error.
func (a *agent) createListener(ctx context.Context, endpointOpts *endpointOpts) (*endpointListener, error) {
	a.mu.RLock()
	sess := a.sess
	a.mu.RUnlock()

	// A concurrent Disconnect may have cleared the session after the caller's
	// connection check; fail cleanly instead of calling through a nil
	// interface.
	if sess == nil {
		return nil, errors.New("agent not connected, call Connect() first")
	}

	scheme, err := determineURLScheme(endpointOpts.url)
	if err != nil {
		return nil, err
	}
	tunnelConfig, err := configureEndpoint(scheme, endpointOpts)
	if err != nil {
		return nil, err
	}

	// Validate the upstream URL before opening the tunnel; validating it
	// afterwards left an open tunnel behind on failure.
	if endpointOpts.upstreamURL != "" {
		if _, err := url.Parse(endpointOpts.upstreamURL); err != nil {
			return nil, fmt.Errorf("invalid upstream URL: %w", err)
		}
	}

	tunnel, err := sess.Listen(ctx, tunnelConfig)
	if err != nil {
		return nil, wrapError(err)
	}

	tunnelURL, err := url.Parse(tunnel.URL())
	if err != nil {
		// Best effort cleanup: the parse error is what the caller needs.
		_ = tunnel.Close()
		return nil, fmt.Errorf("failed to parse tunnel URL: %w", err)
	}

	endpoint := &endpointListener{
		baseEndpoint: baseEndpoint{
			agent:           a,
			id:              tunnel.ID(),
			name:            tunnel.Name(),
			poolingEnabled:  endpointOpts.poolingEnabled,
			bindings:        endpointOpts.bindings,
			description:     endpointOpts.description,
			metadata:        endpointOpts.metadata,
			agentTLSConfig:  endpointOpts.agentTLSConfig,
			trafficPolicy:   endpointOpts.trafficPolicy,
			endpointURL:     *tunnelURL,
			doneChannel:     make(chan struct{}),
			doneOnce:        &sync.Once{},
			region:          tunnel.Region(),
			createdAt:       tunnel.CreatedAt(),
			updatedAt:       tunnel.UpdatedAt(),
			tunnelSessionID: tunnel.TunnelSessionID(),
			tunnelID:        tunnel.TunnelID(),
		},
		tunnel: tunnel,
	}

	a.mu.Lock()
	a.endpoints = append(a.endpoints, endpoint)
	a.mu.Unlock()

	return endpoint, nil
}
// Listen creates an Endpoint and returns an EndpointListener through which the
// caller receives its connections. The options are applied over the endpoint
// defaults. If the agent has no session and auto-connect is enabled, it
// connects first.
func (a *agent) Listen(ctx context.Context, opts ...EndpointOption) (EndpointListener, error) {
	cfg := defaultEndpointOpts()
	for i := range opts {
		opts[i](cfg)
	}

	if err := a.ensureConnected(ctx); err != nil {
		return nil, err
	}

	l, err := a.createListener(ctx, cfg)
	if err != nil {
		// Return an untyped nil, not a nil *endpointListener inside the interface.
		return nil, err
	}
	return l, nil
}
// ensureConnected makes sure the agent has a live session. When no session
// exists and auto-connect is enabled, it calls Connect first.
//
// It returns an error if the automatic Connect fails and still no session
// exists, or if the agent is not connected at the end of the call.
func (a *agent) ensureConnected(ctx context.Context) error {
	a.mu.RLock()
	sessionExists := a.sess != nil
	a.mu.RUnlock()

	if !sessionExists && a.opts.autoConnect {
		if err := a.Connect(ctx); err != nil {
			// Concurrent callers can both observe "no session" and race into
			// Connect; the loser gets "agent already connected". That is not
			// a failure for this caller, so only report the error when no
			// session exists now.
			a.mu.RLock()
			sessionExists = a.sess != nil
			a.mu.RUnlock()
			if !sessionExists {
				return fmt.Errorf("failed to connect: %w", err)
			}
		}
	}

	// Final verification that we're connected.
	a.mu.RLock()
	defer a.mu.RUnlock()
	if a.sess == nil {
		return errors.New("agent not connected, call Connect() first")
	}
	return nil
}
// removeEndpoint drops the first occurrence of endpoint from the agent's
// endpoint list, comparing with ==. It does nothing if the endpoint is not
// tracked, for example after Disconnect has already reset the list.
func (a *agent) removeEndpoint(endpoint Endpoint) {
	a.mu.Lock()
	defer a.mu.Unlock()

	if i := slices.Index(a.endpoints, endpoint); i >= 0 {
		// slices.Delete (Go 1.22+) zeroes the vacated tail slot, so the
		// backing array no longer keeps a stale endpoint reference alive the
		// way the manual append-based delete did.
		a.endpoints = slices.Delete(a.endpoints, i, i+1)
	}
}
// emitEvent delivers evt to every registered handler, synchronously and in
// registration order. The handler list is copied under eventMutex so the
// handlers themselves run without the lock held. Handlers are expected not
// to block, since they run on the caller's goroutine.
func (a *agent) emitEvent(evt Event) {
	a.eventMutex.RLock()
	snapshot := append([]EventHandler(nil), a.eventHandlers...)
	a.eventMutex.RUnlock()

	for i := range snapshot {
		snapshot[i](evt)
	}
}
// createCommandHandler builds a legacy.ServerCommandHandler that forwards a
// server-issued command to the configured RPC handler as a request for
// method.
//
// The handler behaves as follows:
//   - With no RPC handler configured, it returns nil.
//   - If the agent has no current session, it returns the error from Session.
//   - Otherwise the request carries no payload, the RPC handler's response
//     payload is discarded, and only its error is returned.
func (a *agent) createCommandHandler(method string) legacy.ServerCommandHandler {
	return func(ctx context.Context, _ legacy.Session) error {
		handler := a.opts.rpcHandler
		if handler == nil {
			return nil
		}

		current, err := a.Session()
		if err != nil {
			return err
		}

		// Payloads are not supported yet, so only the method is set.
		_, err = handler(ctx, current, &rpcRequest{method: method})
		return err
	}
}
// Forward creates an EndpointForwarder that forwards traffic to the specified
// upstream. The upstream parameter is required and must be created using
// WithUpstream(). Additional endpoint options are applied after the upstream
// settings, so they may override them.
//
// It returns an error if upstream is nil, if its proxy protocol is not
// supported, if the agent cannot be connected (see ensureConnected), or if
// the underlying endpoint cannot be created.
func (a *agent) Forward(ctx context.Context, upstream *Upstream, opts ...EndpointOption) (EndpointForwarder, error) {
	// Previously a nil upstream panicked on the field accesses below.
	if upstream == nil {
		return nil, errors.New("upstream is required")
	}

	endpointOpts := defaultEndpointOpts()

	// Seed the options from the Upstream.
	endpointOpts.upstreamURL = upstream.addr
	endpointOpts.upstreamProtocol = upstream.protocol
	endpointOpts.upstreamTLSClientConfig = upstream.tlsClientConfig

	// Translate the public proxy-protocol value into the config type.
	if upstream.proxyProto != "" {
		var proxyVersion config.ProxyProtoVersion
		switch upstream.proxyProto {
		case ProxyProtoV1:
			proxyVersion = config.ProxyProtoV1
		case ProxyProtoV2:
			proxyVersion = config.ProxyProtoV2
		default:
			return nil, fmt.Errorf("unsupported proxy protocol: %s", upstream.proxyProto)
		}
		endpointOpts.proxyProtoVersion = proxyVersion
	}

	for _, opt := range opts {
		opt(endpointOpts)
	}

	if err := a.ensureConnected(ctx); err != nil {
		return nil, err
	}

	listener, err := a.createListener(ctx, endpointOpts)
	if err != nil {
		return nil, err
	}

	// createListener has already rejected an unparsable non-empty upstream
	// URL, and url.Parse("") succeeds, so the error cannot occur here.
	upstreamURL, _ := url.Parse(endpointOpts.upstreamURL)

	endpoint := &endpointForwarder{
		baseEndpoint:            listener.baseEndpoint, // reuse the baseEndpoint from listener
		listener:                listener,
		upstreamURL:             *upstreamURL,
		upstreamProtocol:        endpointOpts.upstreamProtocol,
		upstreamTLSClientConfig: endpointOpts.upstreamTLSClientConfig,
		proxyProtocol:           upstream.proxyProto,
		upstreamDialer:          upstream.dialer,
	}

	endpoint.start(ctx)

	// NOTE(review): createListener already appended listener to a.endpoints,
	// so after this append Endpoints() reports both the listener and this
	// forwarder (same ID) for a single Forward call — confirm whether that
	// is intended.
	a.mu.Lock()
	a.endpoints = append(a.endpoints, endpoint)
	a.mu.Unlock()

	return endpoint, nil
}