Skip to content

Commit a66b186

Browse files
Close heartbeat connection on write error (#123)
* Close heartbeat connection on write error * Decrease metric * Review comment
1 parent 9429683 commit a66b186

File tree

3 files changed

+98
-15
lines changed

3 files changed

+98
-15
lines changed

handler/heartbeat.go

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ import (
1414

1515
var readDeadline = static.WebsocketReadDeadline
1616

17+
type conn interface {
18+
ReadMessage() (int, []byte, error)
19+
SetReadDeadline(time.Time) error
20+
Close() error
21+
}
22+
1723
// Heartbeat implements /v2/heartbeat requests.
1824
// It starts a new persistent connection and a new goroutine
1925
// to read incoming messages.
@@ -34,7 +40,7 @@ func (c *Client) Heartbeat(rw http.ResponseWriter, req *http.Request) {
3440
}
3541

3642
// handleHeartbeats handles incoming messages from the connection.
37-
func (c *Client) handleHeartbeats(ws *websocket.Conn) {
43+
func (c *Client) handleHeartbeats(ws conn) error {
3844
defer ws.Close()
3945
setReadDeadline(ws)
4046

@@ -43,11 +49,8 @@ func (c *Client) handleHeartbeats(ws *websocket.Conn) {
4349
for {
4450
_, message, err := ws.ReadMessage()
4551
if err != nil {
46-
log.Errorf("read error: %v", err)
47-
if experiment != "" {
48-
metrics.CurrentHeartbeatConnections.WithLabelValues(experiment).Dec()
49-
}
50-
return
52+
closeConnection(experiment, err)
53+
return err
5154
}
5255
if message != nil {
5356
setReadDeadline(ws)
@@ -60,19 +63,32 @@ func (c *Client) handleHeartbeats(ws *websocket.Conn) {
6063

6164
switch {
6265
case hbm.Registration != nil:
66+
if err := c.RegisterInstance(*hbm.Registration); err != nil {
67+
closeConnection(experiment, err)
68+
return err
69+
}
6370
hostname = hbm.Registration.Hostname
64-
c.RegisterInstance(*hbm.Registration)
6571
experiment = hbm.Registration.Experiment
6672
metrics.CurrentHeartbeatConnections.WithLabelValues(experiment).Inc()
6773
case hbm.Health != nil:
68-
c.UpdateHealth(hostname, *hbm.Health)
74+
if err := c.UpdateHealth(hostname, *hbm.Health); err != nil {
75+
closeConnection(experiment, err)
76+
return err
77+
}
6978
}
7079
}
7180
}
7281
}
7382

7483
// setReadDeadline sets/resets the read deadline for the connection.
75-
func setReadDeadline(ws *websocket.Conn) {
84+
func setReadDeadline(ws conn) {
7685
deadline := time.Now().Add(readDeadline)
7786
ws.SetReadDeadline(deadline)
7887
}
88+
89+
func closeConnection(experiment string, err error) {
90+
if experiment != "" {
91+
metrics.CurrentHeartbeatConnections.WithLabelValues(experiment).Dec()
92+
}
93+
log.Errorf("closing connection, err: %v", err)
94+
}

handler/heartbeat_test.go

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
package handler
22

33
import (
4+
"encoding/json"
5+
"errors"
46
"net/http"
57
"net/http/httptest"
68
"testing"
9+
"time"
710

811
"github.com/m-lab/locate/clientgeo"
12+
"github.com/m-lab/locate/connection/testdata"
13+
"github.com/m-lab/locate/heartbeat"
14+
"github.com/m-lab/locate/heartbeat/heartbeattest"
915
prom "github.com/prometheus/client_golang/api/prometheus/v1"
1016
)
1117

@@ -14,15 +20,76 @@ func TestClient_Heartbeat_Error(t *testing.T) {
1420
// The header from this request will not contain the
1521
// necessary "upgrade" tokens.
1622
req := httptest.NewRequest(http.MethodGet, "/v2/heartbeat", nil)
17-
c := fakeClient()
23+
c := fakeClient(nil)
1824
c.Heartbeat(rw, req)
1925

2026
if rw.Code != http.StatusBadRequest {
2127
t.Errorf("Heartbeat() wrong status code; got %d, want %d", rw.Code, http.StatusBadRequest)
2228
}
2329
}
2430

25-
func fakeClient() *Client {
26-
return NewClient("mlab-sandbox", &fakeSigner{}, &fakeLocator{}, &fakeLocatorV2{},
31+
func TestClient_handleHeartbeats(t *testing.T) {
32+
wantErr := errors.New("connection error")
33+
tests := []struct {
34+
name string
35+
ws conn
36+
tracker heartbeat.StatusTracker
37+
}{
38+
{
39+
name: "read-err",
40+
ws: &fakeConn{
41+
err: wantErr,
42+
},
43+
},
44+
{
45+
name: "registration-err",
46+
ws: &fakeConn{
47+
msg: testdata.FakeRegistration,
48+
},
49+
tracker: &heartbeattest.FakeStatusTracker{Err: wantErr},
50+
},
51+
{
52+
name: "health-err",
53+
ws: &fakeConn{
54+
msg: testdata.FakeHealth,
55+
},
56+
tracker: &heartbeattest.FakeStatusTracker{Err: wantErr},
57+
},
58+
}
59+
for _, tt := range tests {
60+
t.Run(tt.name, func(t *testing.T) {
61+
c := fakeClient(tt.tracker)
62+
err := c.handleHeartbeats(tt.ws)
63+
if !errors.Is(err, wantErr) {
64+
t.Errorf("Client.handleHeartbeats() error = %v, wantErr %v", err, wantErr)
65+
}
66+
})
67+
}
68+
}
69+
70+
func fakeClient(t heartbeat.StatusTracker) *Client {
71+
locatorv2 := fakeLocatorV2{StatusTracker: t}
72+
return NewClient("mlab-sandbox", &fakeSigner{}, &fakeLocator{}, &locatorv2,
2773
clientgeo.NewAppEngineLocator(), prom.NewAPI(nil))
2874
}
75+
76+
type fakeConn struct {
77+
msg any
78+
err error
79+
}
80+
81+
// ReadMessage returns 0, the JSON encoding of a fake message, and an error.
82+
func (c *fakeConn) ReadMessage() (int, []byte, error) {
83+
jsonMsg, _ := json.Marshal(c.msg)
84+
return 0, jsonMsg, c.err
85+
}
86+
87+
// SetReadDeadline returns nil.
88+
func (c *fakeConn) SetReadDeadline(time.Time) error {
89+
return nil
90+
}
91+
92+
// Close returns nil.
93+
func (c *fakeConn) Close() error {
94+
return nil
95+
}

heartbeat/heartbeat.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func NewHeartbeatStatusTracker(client MemorystoreClient[v2.HeartbeatMessage]) *h
6767
func (h *heartbeatStatusTracker) RegisterInstance(rm v2.Registration) error {
6868
hostname := rm.Hostname
6969
if err := h.Put(hostname, "Registration", &rm, true); err != nil {
70-
return err
70+
return fmt.Errorf("%w: failed to write Registration message to Memorystore", err)
7171
}
7272

7373
h.registerInstance(hostname, rm)
@@ -78,7 +78,7 @@ func (h *heartbeatStatusTracker) RegisterInstance(rm v2.Registration) error {
7878
// updates it locally.
7979
func (h *heartbeatStatusTracker) UpdateHealth(hostname string, hm v2.Health) error {
8080
if err := h.Put(hostname, "Health", &hm, true); err != nil {
81-
return err
81+
return fmt.Errorf("%w: failed to write Health message to Memorystore", err)
8282
}
8383
return h.updateHealth(hostname, hm)
8484
}
@@ -158,7 +158,7 @@ func (h *heartbeatStatusTracker) updatePrometheusMessage(instance v2.HeartbeatMe
158158
// Update in Memorystore.
159159
err := h.Put(hostname, "Prometheus", pm, false)
160160
if err != nil {
161-
return err
161+
return fmt.Errorf("%w: failed to write Prometheus message to Memorystore", err)
162162
}
163163

164164
// Update locally.

0 commit comments

Comments
 (0)