Skip to content

Commit c91e81d

Browse files
Fix workspace connection tracker and heartbeats
1 parent 05469d0 commit c91e81d

File tree

6 files changed

+69
-35
lines changed

6 files changed

+69
-35
lines changed
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,36 @@
11
package network
22

3-
import "sync"
3+
import (
4+
"sync"
5+
6+
"github.com/loft-sh/log"
7+
)
48

59
// ConnTracker is a simple connection counter used by several services.
610
type ConnTracker struct {
711
mu sync.Mutex
812
count int
13+
14+
logger log.Logger
915
}
1016

11-
func (c *ConnTracker) Add() {
17+
func (c *ConnTracker) Add(serviceName string) {
1218
c.mu.Lock()
1319
defer c.mu.Unlock()
1420
c.count++
21+
c.logger.Debugf("%s: Added new connection, connection count %d\n", serviceName, c.count)
1522
}
1623

17-
func (c *ConnTracker) Remove() {
24+
func (c *ConnTracker) Remove(serviceName string) {
1825
c.mu.Lock()
1926
defer c.mu.Unlock()
2027
c.count--
28+
c.logger.Debugf("%s: Removed connection, connection count %d\n", serviceName, c.count)
2129
}
2230

23-
func (c *ConnTracker) Count() int {
31+
func (c *ConnTracker) Count(serviceName string) int {
2432
c.mu.Lock()
2533
defer c.mu.Unlock()
34+
c.logger.Debugf("%s: Get connection count %d\n", serviceName, c.count)
2635
return c.count
2736
}

pkg/daemon/workspace/network/heartbeat.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,24 @@ func NewHeartbeatService(config *WorkspaceServerConfig, tsServer *tsnet.Server,
3737

3838
// Start begins the heartbeat loop.
3939
func (s *HeartbeatService) Start(ctx context.Context) {
40+
s.log.Info("HeartbeatService: Start")
4041
transport := &http.Transport{DialContext: s.tsServer.Dial}
4142
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
4243
ticker := time.NewTicker(10 * time.Second)
4344
defer ticker.Stop()
4445
for {
4546
select {
4647
case <-ctx.Done():
48+
s.log.Info("HeartbeatService: Exit")
4749
return
4850
case <-ticker.C:
49-
if s.tracker.Count() > 0 {
51+
s.log.Debugf("HeartbeatService: checking connection count")
52+
if s.tracker.Count("HeartbeatService") > 0 {
5053
if err := s.sendHeartbeat(ctx, client); err != nil {
5154
s.log.Errorf("HeartbeatService: failed to send heartbeat: %v", err)
5255
}
56+
} else {
57+
s.log.Debugf("HeartbeatService: No active connections, skipping heartbeat.") // Added for clarity
5358
}
5459
}
5560
}
@@ -58,25 +63,33 @@ func (s *HeartbeatService) Start(ctx context.Context) {
5863
func (s *HeartbeatService) sendHeartbeat(ctx context.Context, client *http.Client) error {
5964
hbCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
6065
defer cancel()
66+
6167
discoveredRunner, err := discoverRunner(hbCtx, s.lc, s.log)
6268
if err != nil {
69+
s.log.Errorf("HeartbeatService: failed to discover runner: %v", err)
6370
return fmt.Errorf("failed to discover runner: %w", err)
6471
}
72+
6573
heartbeatURL := fmt.Sprintf("http://%s.ts.loft/devpod/%s/%s/heartbeat", discoveredRunner, s.projectName, s.workspaceName)
66-
s.log.Infof("HeartbeatService: sending heartbeat to %s, active connections: %d", heartbeatURL, s.tracker.Count())
74+
s.log.Infof("HeartbeatService: sending heartbeat to %s, active connections: %d", heartbeatURL, s.tracker.Count("HeartbeatService"))
6775
req, err := http.NewRequestWithContext(hbCtx, "GET", heartbeatURL, nil)
6876
if err != nil {
77+
s.log.Errorf("HeartbeatService: failed to create request for %s: %v", heartbeatURL, err)
6978
return fmt.Errorf("failed to create request for %s: %w", heartbeatURL, err)
7079
}
7180
req.Header.Set("Authorization", "Bearer "+s.config.AccessKey)
7281
resp, err := client.Do(req)
7382
if err != nil {
83+
s.log.Errorf("HeartbeatService: request to %s failed: %v", heartbeatURL, err)
7484
return fmt.Errorf("request to %s failed: %w", heartbeatURL, err)
7585
}
7686
defer resp.Body.Close()
87+
7788
if resp.StatusCode != http.StatusOK {
89+
s.log.Errorf("HeartbeatService: received non-OK response from %s - Status: %d", heartbeatURL, resp.StatusCode)
7890
return fmt.Errorf("received response from %s - Status: %d", heartbeatURL, resp.StatusCode)
7991
}
80-
s.log.Infof("HeartbeatService: received response from %s - Status: %d", heartbeatURL, resp.StatusCode)
92+
93+
s.log.Debugf("HeartbeatService: received response from %s - Status: %d", heartbeatURL, resp.StatusCode)
8194
return nil
8295
}

pkg/daemon/workspace/network/netmap.go

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,19 @@ func NewNetmapWatcherService(rootDir string, lc *tailscale.LocalClient, log log.
3131

3232
// Start begins watching the netmap.
3333
func (s *NetmapWatcherService) Start(ctx context.Context) {
34-
go func() {
35-
lastUpdate := time.Now()
36-
if err := ts.WatchNetmap(ctx, s.lc, func(netMap *netmap.NetworkMap) {
37-
if time.Since(lastUpdate) < netMapCooldown {
38-
return
39-
}
40-
lastUpdate = time.Now()
41-
nm, err := json.Marshal(netMap)
42-
if err != nil {
43-
s.log.Errorf("NetmapWatcherService: failed to marshal netmap: %v", err)
44-
} else {
45-
_ = os.WriteFile(filepath.Join(s.rootDir, "netmap.json"), nm, 0644)
46-
}
47-
}); err != nil {
48-
s.log.Errorf("NetmapWatcherService: failed to watch netmap: %v", err)
34+
lastUpdate := time.Now()
35+
if err := ts.WatchNetmap(ctx, s.lc, func(netMap *netmap.NetworkMap) {
36+
if time.Since(lastUpdate) < netMapCooldown {
37+
return
4938
}
50-
}()
39+
lastUpdate = time.Now()
40+
nm, err := json.Marshal(netMap)
41+
if err != nil {
42+
s.log.Errorf("NetmapWatcherService: failed to marshal netmap: %v", err)
43+
} else {
44+
_ = os.WriteFile(filepath.Join(s.rootDir, "netmap.json"), nm, 0644)
45+
}
46+
}); err != nil {
47+
s.log.Errorf("NetmapWatcherService: failed to watch netmap: %v", err)
48+
}
5149
}

pkg/daemon/workspace/network/port_forward.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ func (s *HTTPPortForwardService) Start(ctx context.Context) {
5858
}
5959

6060
func (s *HTTPPortForwardService) portForwardHandler(w http.ResponseWriter, r *http.Request) {
61-
s.tracker.Add()
62-
defer s.tracker.Remove()
61+
s.tracker.Add("PortForward")
62+
defer s.tracker.Remove("PortForward")
6363
s.log.Debugf("HTTPPortForwardService: received request")
6464

6565
targetPort := r.Header.Get("X-Loft-Forward-Port")

pkg/daemon/workspace/network/server.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ type WorkspaceServer struct {
4646
// NewWorkspaceServer creates a new WorkspaceServer.
4747
func NewWorkspaceServer(config *WorkspaceServerConfig, logger log.Logger) *WorkspaceServer {
4848
return &WorkspaceServer{
49-
config: config,
50-
log: logger,
51-
connTracker: &ConnTracker{},
49+
config: config,
50+
log: logger,
51+
connTracker: &ConnTracker{
52+
logger: logger,
53+
},
5254
}
5355
}
5456

@@ -72,7 +74,6 @@ func (s *WorkspaceServer) Start(ctx context.Context) error {
7274
}
7375
s.sshSvc.Start(ctx)
7476

75-
// Create and start the HTTP port forward service.
7677
s.httpProxySvc, err = NewHTTPPortForwardService(s.network, s.connTracker, s.log)
7778
if err != nil {
7879
return err
@@ -92,15 +93,15 @@ func (s *WorkspaceServer) Start(ctx context.Context) error {
9293
if err != nil {
9394
return err
9495
}
95-
s.netProxySvc.Start(ctx)
96+
go s.netProxySvc.Start(ctx)
9697

9798
// Start the heartbeat service.
9899
s.heartbeatSvc = NewHeartbeatService(s.config, s.network, lc, projectName, workspaceName, s.connTracker, s.log)
99100
go s.heartbeatSvc.Start(ctx)
100101

101102
// Start netmap watcher.
102103
s.netmapWatcher = NewNetmapWatcherService(s.config.RootDir, lc, s.log)
103-
s.netmapWatcher.Start(ctx)
104+
go s.netmapWatcher.Start(ctx)
104105

105106
// Wait until the context is canceled.
106107
<-ctx.Done()

pkg/daemon/workspace/network/ssh.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ func (s *SSHService) acceptLoop(ctx context.Context) {
5959
}
6060

6161
func (s *SSHService) handleConnection(conn net.Conn) {
62-
s.tracker.Add()
63-
defer s.tracker.Remove()
62+
s.tracker.Add("SSHService")
63+
defer s.tracker.Remove("SSHService")
6464
defer conn.Close()
6565

6666
localAddr := fmt.Sprintf("127.0.0.1:%d", sshServer.DefaultUserPort)
@@ -71,10 +71,23 @@ func (s *SSHService) handleConnection(conn net.Conn) {
7171
}
7272
defer backendConn.Close()
7373

74+
// We need to wait for copying to finish before the function returns and Remove is called.
75+
errChan := make(chan error, 2)
76+
7477
go func() {
75-
_, _ = io.Copy(backendConn, conn)
78+
_, err := io.Copy(backendConn, conn)
79+
errChan <- err
7680
}()
77-
_, _ = io.Copy(conn, backendConn)
81+
82+
go func() {
83+
_, err := io.Copy(conn, backendConn)
84+
errChan <- err
85+
}()
86+
87+
// Wait for one side of the connection to close or error
88+
<-errChan
89+
// Optionally wait for the second one too, or just proceed to cleanup
90+
// <-errChan
7891
}
7992

8093
// Stop stops the SSH server by closing its listener.

0 commit comments

Comments
 (0)