Skip to content

Commit 32c9722

Browse files
feat(auth): Add JWT authentication for heartbeat endpoints (#228)
* Add JWT-protected heartbeat endpoint with organization validation
* Fix JWT audience to match token-exchange service
* Fix JWT claims extraction for ESP v1 (GAE Flex)
* Use portable JWT parsing with go-jose library
* Update heartbeat client to use JWT authentication
* feat(auth): add JWT token refresh for heartbeat connections

  Implements automatic JWT token renewal before expiry to handle GAE's 1-hour connection limits. Includes comprehensive test coverage.
* update comment
* feat(auth): update handler to read ESPv1's custom header
* add debug logging
* fix(auth): parse claims as a JSON string
* use base64.StdEncoding
* restore expiry parsing from JWT
* Make tests table-driven
1 parent 981da68 commit 32c9722

File tree

8 files changed

+603
-21
lines changed

8 files changed

+603
-21
lines changed

cmd/heartbeat/main.go

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package main
22

33
import (
4+
"bytes"
45
"context"
6+
"encoding/json"
57
"flag"
68
"fmt"
79
"log"
@@ -29,7 +31,7 @@ import (
2931

3032
var (
3133
heartbeatURL string
32-
hostname flagx.StringFile
34+
hostname flagx.StringFile
3335
experiment string
3436
pod string
3537
node string
@@ -41,13 +43,60 @@ var (
4143
heartbeatPeriod = static.HeartbeatPeriod
4244
mainCtx, mainCancel = context.WithCancel(context.Background())
4345
lbPath = "/metadata/loadbalanced"
46+
47+
// JWT authentication parameters
48+
apiKey string
49+
tokenExchangeURL string
4450
)
4551

4652
// Checker generates a health score for the heartbeat instance (0, 1).
// Any type that provides a matching GetHealth method satisfies it.
4753
type Checker interface {
4854
GetHealth(ctx context.Context) float64 // Health score.
4955
}
5056

57+
// TokenResponse represents the response from the token exchange service.
// getJWTToken decodes the JSON body of a 200 OK reply into this struct.
58+
type TokenResponse struct {
59+
// Token is read from the "token" JSON field; getJWTToken treats an empty
// value as an error.
Token string `json:"token"`
60+
}
61+
62+
// getJWTTokenFunc is the token-exchange function used by main, both for the
// initial token and inside the token refresher it installs on the connection.
// It defaults to getJWTToken and is a package-level variable so that tests
// can replace it with a fake and restore it afterwards.
63+
var getJWTTokenFunc = getJWTToken
64+
65+
// getJWTToken exchanges an API key for a JWT token
66+
func getJWTToken(apiKey, tokenExchangeURL string) (string, error) {
67+
// Prepare the request payload
68+
payload := map[string]string{
69+
"api_key": apiKey,
70+
}
71+
payloadBytes, err := json.Marshal(payload)
72+
if err != nil {
73+
return "", fmt.Errorf("failed to marshal token request: %w", err)
74+
}
75+
76+
// Make the HTTP request to the token exchange service
77+
resp, err := http.Post(tokenExchangeURL, "application/json", bytes.NewBuffer(payloadBytes))
78+
if err != nil {
79+
return "", fmt.Errorf("failed to request JWT token: %w", err)
80+
}
81+
defer resp.Body.Close()
82+
83+
if resp.StatusCode != http.StatusOK {
84+
return "", fmt.Errorf("token exchange failed with status %d", resp.StatusCode)
85+
}
86+
87+
// Parse the response
88+
var tokenResp TokenResponse
89+
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
90+
return "", fmt.Errorf("failed to parse token response: %w", err)
91+
}
92+
93+
if tokenResp.Token == "" {
94+
return "", fmt.Errorf("empty token received from exchange service")
95+
}
96+
97+
return tokenResp.Token, nil
98+
}
99+
51100
func init() {
52101
flag.StringVar(&heartbeatURL, "heartbeat-url", "ws://localhost:8080/v2/platform/heartbeat",
53102
"URL for locate service")
@@ -59,12 +108,23 @@ func init() {
59108
flag.Var(&kubernetesURL, "kubernetes-url", "URL for Kubernetes API")
60109
flag.Var(&registrationURL, "registration-url", "URL for site registration")
61110
flag.Var(&services, "services", "Maps experiment target names to their set of services")
111+
flag.StringVar(&apiKey, "api-key", "", "API key for JWT token exchange (required)")
112+
flag.StringVar(&tokenExchangeURL, "token-exchange-url", "https://auth.mlab-sandbox.measurementlab.net/v0/token/autojoin",
113+
"URL for token exchange service")
62114
}
63115

64116
func main() {
65117
flag.Parse()
66118
rtx.Must(flagx.ArgsFromEnvWithLog(flag.CommandLine, false), "failed to read args from env")
67119

120+
// Validate JWT authentication parameters
121+
if apiKey == "" {
122+
log.Fatal("API key is required for JWT authentication (-api-key flag)")
123+
}
124+
if tokenExchangeURL == "" {
125+
log.Fatal("Token exchange URL is required (-token-exchange-url flag)")
126+
}
127+
68128
// Start metrics server.
69129
prom := prometheusx.MustServeMetrics()
70130
defer prom.Close()
@@ -82,9 +142,26 @@ func main() {
82142
rtx.Must(err, "could not load registration data")
83143
hbm := v2.HeartbeatMessage{Registration: r}
84144

85-
// Establish a connection.
145+
// Get JWT token for authentication
146+
log.Printf("Exchanging API key for JWT token...")
147+
jwtToken, err := getJWTTokenFunc(apiKey, tokenExchangeURL)
148+
rtx.Must(err, "failed to get JWT token")
149+
log.Printf("Successfully obtained JWT token")
150+
151+
// Prepare headers with JWT authentication
152+
headers := http.Header{}
153+
headers.Set("Authorization", "Bearer "+jwtToken)
154+
155+
// Establish a connection with JWT authentication.
86156
conn := connection.NewConn()
87-
err = conn.Dial(heartbeatURL, http.Header{}, hbm)
157+
158+
// Set up JWT token refresh for automatic token renewal
159+
conn.SetTokenRefresher(func() (string, error) {
160+
log.Printf("Refreshing JWT token...")
161+
return getJWTTokenFunc(apiKey, tokenExchangeURL)
162+
})
163+
164+
err = conn.Dial(heartbeatURL, headers, hbm)
88165
rtx.Must(err, "failed to establish a websocket connection with %s", heartbeatURL)
89166

90167
probe := health.NewPortProbe(svcs)

cmd/heartbeat/main_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ import (
1717
)
1818

1919
func Test_main(t *testing.T) {
20+
// Override the JWT token function for testing
21+
getJWTTokenFunc = func(apiKey, tokenExchangeURL string) (string, error) {
22+
return "fake-jwt-token", nil
23+
}
24+
defer func() {
25+
getJWTTokenFunc = getJWTToken // restore original function
26+
}()
27+
2028
mainCtx, mainCancel = context.WithCancel(context.Background())
2129
fh := testdata.FakeHandler{}
2230
s := testdata.FakeServer(fh.Upgrade)
@@ -36,6 +44,8 @@ func Test_main(t *testing.T) {
3644
flag.Set("namespace", "default")
3745
flag.Set("registration-url", "file:./registration/testdata/registration.json")
3846
flag.Set("services", "ndt/ndt7=ws://:"+u.Port()+"/ndt/v7/download")
47+
flag.Set("api-key", "test-api-key")
48+
flag.Set("token-exchange-url", "http://fake-token-exchange.example.com/token")
3949

4050
heartbeatPeriod = 2 * time.Second
4151
timer := time.NewTimer(2 * heartbeatPeriod)

connection/connection.go

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@ import (
77
"log"
88
"net/http"
99
"net/url"
10-
"sync"
10+
"strings"
1111
"time"
1212

1313
"github.com/cenkalti/backoff/v4"
1414
"github.com/gorilla/websocket"
1515
"github.com/m-lab/locate/metrics"
1616
"github.com/m-lab/locate/static"
17+
"gopkg.in/square/go-jose.v2/jwt"
1718
)
1819

1920
var (
@@ -23,8 +24,14 @@ var (
2324
// retryErrors contains the list of errors that may become successful
2425
// if the request is retried.
2526
retryErrors = map[int]bool{408: true, 425: true, 500: true, 502: true, 503: true, 504: true}
27+
28+
// JWT refresh buffer: refresh token if it expires within this duration
29+
jwtRefreshBuffer = 5 * time.Minute
2630
)
2731

32+
// TokenRefresher is a function type that can refresh JWT tokens.
// It takes no arguments and returns the replacement token, or an error if a
// new token could not be obtained.
33+
type TokenRefresher func() (string, error)
34+
2835
// Conn contains the state needed to connect, reconnect, and send
2936
// messages.
3037
// Default values must be updated before calling `Dial`.
@@ -52,9 +59,11 @@ type Conn struct {
5259
url url.URL
5360
header http.Header
5461
ticker time.Ticker
55-
mu sync.Mutex
5662
isDialed bool
5763
isConnected bool
64+
65+
// JWT token refresh functionality
66+
tokenRefresher TokenRefresher
5867
}
5968

6069
// NewConn creates a new Conn with default values.
@@ -69,6 +78,11 @@ func NewConn() *Conn {
6978
return c
7079
}
7180

81+
// SetTokenRefresher sets the function used to refresh JWT tokens when they expire.
// refreshJWTIfNeeded calls the refresher once the current bearer token is
// within jwtRefreshBuffer of its expiry; while no refresher is set (nil), the
// token is never refreshed.
82+
func (c *Conn) SetTokenRefresher(refresher TokenRefresher) {
83+
c.tokenRefresher = refresher
84+
}
85+
7286
// Dial creates a new persistent client connection and sets
7387
// the necessary state for future reconnections. It also
7488
// starts a goroutine to reset the number of reconnections.
@@ -167,6 +181,12 @@ func (c *Conn) close() error {
167181
// In case of failure, it uses an exponential backoff to
168182
// increase the duration of retry attempts.
169183
func (c *Conn) connect() error {
184+
// Check if JWT token needs refreshing before attempting connection
185+
if err := c.refreshJWTIfNeeded(); err != nil {
186+
log.Printf("failed to refresh JWT token: %v", err)
187+
// Continue with existing token - it might still work
188+
}
189+
170190
b := c.getBackoff()
171191
ticker := backoff.NewTicker(b)
172192

@@ -234,3 +254,67 @@ func (c *Conn) getBackoff() *backoff.ExponentialBackOff {
234254
b.MaxElapsedTime = c.MaxElapsedTime
235255
return b
236256
}
257+
258+
// parseJWTExpiry extracts the expiry time from a JWT token without verification
259+
func (c *Conn) parseJWTExpiry(token string) time.Time {
260+
parsed, err := jwt.ParseSigned(token)
261+
if err != nil {
262+
log.Printf("failed to parse JWT for expiry check: %v", err)
263+
return time.Time{} // Return zero time on error
264+
}
265+
266+
var claims jwt.Claims
267+
if err := parsed.UnsafeClaimsWithoutVerification(&claims); err != nil {
268+
log.Printf("failed to extract JWT claims for expiry check: %v", err)
269+
return time.Time{} // Return zero time on error
270+
}
271+
272+
if claims.Expiry == nil {
273+
return time.Time{} // Return zero time if no expiry
274+
}
275+
276+
return claims.Expiry.Time()
277+
}
278+
279+
// SetToken sets a JWT token
280+
func (c *Conn) SetToken(token string) {
281+
c.header.Set("Authorization", "Bearer "+token)
282+
}
283+
284+
func (c *Conn) refreshJWTIfNeeded() error {
285+
// Only refresh if we have a token refresher
286+
if c.tokenRefresher == nil {
287+
return nil
288+
}
289+
290+
// Extract current token from headers
291+
authHeader := c.header.Get("Authorization")
292+
if !strings.HasPrefix(authHeader, "Bearer ") {
293+
return nil // No token present
294+
}
295+
296+
currentToken := strings.TrimPrefix(authHeader, "Bearer ")
297+
298+
// Parse expiry on-demand
299+
expiry := c.parseJWTExpiry(currentToken)
300+
if expiry.IsZero() {
301+
return nil // Token has no expiry or couldn't parse
302+
}
303+
304+
// Check if refresh needed
305+
if time.Until(expiry) > jwtRefreshBuffer {
306+
return nil // Still fresh
307+
}
308+
309+
// Token is expired or close to expiring, refresh it
310+
log.Printf("JWT token expires soon, refreshing...")
311+
newToken, err := c.tokenRefresher()
312+
if err != nil {
313+
return err
314+
}
315+
316+
c.header.Set("Authorization", "Bearer "+newToken)
317+
log.Printf("JWT token refreshed successfully")
318+
319+
return nil
320+
}

connection/connection_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,79 @@ func TestWriteMessage_ClosedServer(t *testing.T) {
200200
}
201201
}
202202

203+
func TestConn_RefreshJWTIfNeeded(t *testing.T) {
204+
tests := []struct {
205+
name string
206+
token string
207+
hasRefresher bool
208+
refreshedToken string
209+
wantRefreshCount int
210+
wantHeaderToken string
211+
wantErr bool
212+
}{
213+
{
214+
name: "expired_token_with_refresher",
215+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDA5MzUzMDB9.fake-signature", // Expired in 2020
216+
hasRefresher: true,
217+
refreshedToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjk5OTk5OTk5OTl9.fake-signature",
218+
wantRefreshCount: 1,
219+
wantHeaderToken: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjk5OTk5OTk5OTl9.fake-signature",
220+
},
221+
{
222+
name: "expired_token_no_refresher",
223+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDA5MzUzMDB9.fake-signature", // Expired in 2020
224+
hasRefresher: false,
225+
wantRefreshCount: 0,
226+
wantHeaderToken: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDA5MzUzMDB9.fake-signature",
227+
},
228+
{
229+
name: "valid_token_with_refresher",
230+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjk5OTk5OTk5OTl9.fake-signature", // Expires far in future
231+
hasRefresher: true,
232+
refreshedToken: "new-token",
233+
wantRefreshCount: 0,
234+
wantHeaderToken: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjk5OTk5OTk5OTl9.fake-signature",
235+
},
236+
}
237+
238+
for _, tt := range tests {
239+
t.Run(tt.name, func(t *testing.T) {
240+
c := NewConn()
241+
242+
refreshCount := 0
243+
if tt.hasRefresher {
244+
c.SetTokenRefresher(func() (string, error) {
245+
refreshCount++
246+
return tt.refreshedToken, nil
247+
})
248+
}
249+
250+
headers := http.Header{}
251+
headers.Set("Authorization", "Bearer "+tt.token)
252+
c.header = headers
253+
254+
fh := testdata.FakeHandler{}
255+
s := testdata.FakeServer(fh.Upgrade)
256+
defer close(c, s)
257+
258+
err := c.refreshJWTIfNeeded()
259+
if (err != nil) != tt.wantErr {
260+
t.Errorf("refreshJWTIfNeeded() error = %v, wantErr %v", err, tt.wantErr)
261+
return
262+
}
263+
264+
if refreshCount != tt.wantRefreshCount {
265+
t.Errorf("Expected %d token refresh calls, got %d", tt.wantRefreshCount, refreshCount)
266+
}
267+
268+
newAuth := c.header.Get("Authorization")
269+
if newAuth != tt.wantHeaderToken {
270+
t.Errorf("Authorization header = %s, want %s", newAuth, tt.wantHeaderToken)
271+
}
272+
})
273+
}
274+
}
275+
203276
func close(c *Conn, s *httptest.Server) {
204277
c.Close()
205278
s.Close()

0 commit comments

Comments
 (0)