Skip to content

remove zta endpoint #2232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 0 additions & 48 deletions ee/localserver/dt4a.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,11 @@ import (
"context"
"log/slog"
"net/http"
"strings"
"time"

"github.com/kolide/launcher/ee/observability"
)

var (
// legacyDt4aInfoKey is the key that was used to store dt4a info that was not tied to dt4a IDs
// can be removed when unauthed /zta endpoint is removed
legacyDt4aInfoKey = []byte("localserver_zta_info")
)

const (
accelerateInterval = 5 * time.Second
accelerateDuration = 5 * time.Minute
Expand All @@ -29,52 +22,11 @@ func (ls *localServer) requestDt4aInfoHandlerFunc(w http.ResponseWriter, r *http
r, span := observability.StartHttpRequestSpan(r, "path", r.URL.Path)
defer span.End()

// This check is superfluous with the check in the middleware -- but we still have one
// unauthenticated endpoint that points directly to this handler, so we're leaving the check
// in both places for now. We can remove it once /zta is removed.
requestOrigin := r.Header.Get("Origin")
if requestOrigin != "" {
if _, ok := allowlistedDt4aOriginsLookup[requestOrigin]; !ok && !strings.HasPrefix(requestOrigin, safariWebExtensionScheme) {
escapedOrigin := strings.ReplaceAll(strings.ReplaceAll(requestOrigin, "\n", ""), "\r", "") // remove any newlines
ls.slogger.Log(r.Context(), slog.LevelInfo,
"received dt4a request with origin not in allowlist",
"req_origin", escapedOrigin,
)
w.WriteHeader(http.StatusForbidden)
return
}
}

// We only allow acceleration via this endpoint if this testing flag is set.
if ls.knapsack.AllowOverlyBroadDt4aAcceleration() {
ls.accelerate(r.Context())
}

// this should be removed when we drop unauthed endpoint
if r.Header.Get(dt4aAccountUuidHeaderKey) == "" {
// This is a legacy request to the unauthed endpoint that does not include the dt4a account uuid header.
// We will return the dt4a info stored under the legacy key.
dt4aInfo, err := ls.knapsack.Dt4aInfoStore().Get(legacyDt4aInfoKey)
if err != nil {
ls.slogger.Log(r.Context(), slog.LevelWarn,
"could not retrieve dt4a info from store using legacy dt4a key",
"err", err,
)

w.WriteHeader(http.StatusInternalServerError)
return
}

if len(dt4aInfo) == 0 {
w.WriteHeader(http.StatusNoContent)
return
}

w.Header().Set("Content-Type", "application/json")
w.Write(dt4aInfo)
return
}

// dt4aAccountUuid is set, so we will try to get the dt4a info using the account uuid
dt4aInfo, err := ls.knapsack.Dt4aInfoStore().Get([]byte(r.Header.Get(dt4aAccountUuidHeaderKey)))
if err != nil {
Expand Down
168 changes: 148 additions & 20 deletions ee/localserver/dt4a_auth_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ func Test_Dt4aAuthMiddleware(t *testing.T) {
t.Run("handles valid chain", func(t *testing.T) {
t.Parallel()

// create a valid chain of keys
validKeys := make([]*ecdsa.PrivateKey, 4)
validKeys[0] = rootTrustedEcKey

Expand All @@ -148,33 +149,145 @@ func Test_Dt4aAuthMiddleware(t *testing.T) {

b64 := base64.URLEncoding.EncodeToString(chainMarshalled)

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, fmt.Sprintf("/?payload=%s", b64), nil))
require.Equal(t, http.StatusOK, rr.Code,
"should return ok when chain is valid",
)
t.Run("allows missing origin", func(t *testing.T) {
t.Parallel()

bodyBytes, err := io.ReadAll(rr.Body)
require.NoError(t, err)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, fmt.Sprintf("/?payload=%s", b64), nil))
require.Equal(t, http.StatusOK, rr.Code,
"should return ok when chain is valid",
)

var z dt4aResponse
require.NoError(t, json.Unmarshal(bodyBytes, &z))
bodyBytes, err := io.ReadAll(rr.Body)
require.NoError(t, err)

dataDecoded, err := base64.URLEncoding.DecodeString(z.Data)
require.NoError(t, err)
var z dt4aResponse
require.NoError(t, json.Unmarshal(bodyBytes, &z))

x25519Decoded, err := base64.URLEncoding.DecodeString(z.PubKey)
require.NoError(t, err)
dataDecoded, err := base64.URLEncoding.DecodeString(z.Data)
require.NoError(t, err)

x25519 := new([32]byte)
copy(x25519[:], x25519Decoded)
x25519Decoded, err := base64.URLEncoding.DecodeString(z.PubKey)
require.NoError(t, err)

opened, err := echelper.OpenNaCl(dataDecoded, x25519, callerPrivKey)
require.NoError(t, err)
x25519 := new([32]byte)
copy(x25519[:], x25519Decoded)

require.Equal(t, returnData, opened,
"should be able to open NaCl box and get data",
)
opened, err := echelper.OpenNaCl(dataDecoded, x25519, callerPrivKey)
require.NoError(t, err)

require.Equal(t, returnData, opened,
"should be able to open NaCl box and get data",
)
})

t.Run("allows safari web extensions origin", func(t *testing.T) {
t.Parallel()

rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/?payload=%s", b64), nil)

req.Header.Set("origin", fmt.Sprintf("%sexample.com", safariWebExtensionScheme))

handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code,
"should return ok when chain is valid",
)

bodyBytes, err := io.ReadAll(rr.Body)
require.NoError(t, err)

var z dt4aResponse
require.NoError(t, json.Unmarshal(bodyBytes, &z))

dataDecoded, err := base64.URLEncoding.DecodeString(z.Data)
require.NoError(t, err)

x25519Decoded, err := base64.URLEncoding.DecodeString(z.PubKey)
require.NoError(t, err)

x25519 := new([32]byte)
copy(x25519[:], x25519Decoded)

opened, err := echelper.OpenNaCl(dataDecoded, x25519, callerPrivKey)
require.NoError(t, err)

require.Equal(t, returnData, opened,
"should be able to open NaCl box and get data",
)
})

t.Run("allows empty origin", func(t *testing.T) {
t.Parallel()

rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/?payload=%s", b64), nil)

req.Header.Set("origin", "")

handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code,
"should return ok when chain is valid",
)

bodyBytes, err := io.ReadAll(rr.Body)
require.NoError(t, err)

var z dt4aResponse
require.NoError(t, json.Unmarshal(bodyBytes, &z))

dataDecoded, err := base64.URLEncoding.DecodeString(z.Data)
require.NoError(t, err)

x25519Decoded, err := base64.URLEncoding.DecodeString(z.PubKey)
require.NoError(t, err)

x25519 := new([32]byte)
copy(x25519[:], x25519Decoded)

opened, err := echelper.OpenNaCl(dataDecoded, x25519, callerPrivKey)
require.NoError(t, err)

require.Equal(t, returnData, opened,
"should be able to open NaCl box and get data",
)
})

t.Run("allows valid origin", func(t *testing.T) {
t.Parallel()

rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/?payload=%s", b64), nil)

req.Header.Set("origin", acceptableOrigin(t))

handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code,
"should return ok when chain is valid",
)

bodyBytes, err := io.ReadAll(rr.Body)
require.NoError(t, err)

var z dt4aResponse
require.NoError(t, json.Unmarshal(bodyBytes, &z))

dataDecoded, err := base64.URLEncoding.DecodeString(z.Data)
require.NoError(t, err)

x25519Decoded, err := base64.URLEncoding.DecodeString(z.PubKey)
require.NoError(t, err)

x25519 := new([32]byte)
copy(x25519[:], x25519Decoded)

opened, err := echelper.OpenNaCl(dataDecoded, x25519, callerPrivKey)
require.NoError(t, err)

require.Equal(t, returnData, opened,
"should be able to open NaCl box and get data",
)
})
})
}

Expand Down Expand Up @@ -429,3 +542,18 @@ func toJWK(key any, kid string) (*jwk, error) {
return nil, errors.New("unsupported key type")
}
}

func acceptableOrigin(t *testing.T) string {
// Just grab the first origin available in our allowlist
acceptableOrigin := ""
for k := range allowlistedDt4aOriginsLookup {
acceptableOrigin = k
break
}
if acceptableOrigin == "" {
t.Error("no acceptable origins found")
t.FailNow()
}

return acceptableOrigin
}
Loading
Loading