Skip to content

Untie prefix from Broker communication paths #2108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions broker/broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,35 @@ func getHelloWorldHandler(t *testing.T) func(resp http.ResponseWriter, req *http
}
}

// getFreePort returns an available port.
// A free port is needed beforehand so that it can be used for adding
// an origin entry in the registry database.
func getFreePort() (int, error) {
l, err := net.Listen("tcp", ":0")
if err != nil {
return 0, err
}
defer l.Close()

addr := l.Addr().(*net.TCPAddr)
return addr.Port, nil
}

func Setup(t *testing.T, ctx context.Context, egrp *errgroup.Group) {
dirpath := t.TempDir()

port, err := getFreePort()
require.NoError(t, err)
server_utils.ResetTestState()
viper.Set("Logging.Level", "Debug")
viper.Set("ConfigDir", filepath.Join(dirpath, "config"))
config.InitConfig()
viper.Set("Server.WebPort", "0")
viper.Set("Server.WebPort", port)
viper.Set("Registry.DbLocation", filepath.Join(dirpath, "ns-registry.sqlite"))
viper.Set("Origin.FederationPrefix", "/foo")

err := config.InitServer(ctx, server_structs.BrokerType)
modules := server_structs.ServerType(0)
modules.Set(server_structs.RegistryType)
modules.Set(server_structs.BrokerType)
err = config.InitServer(ctx, modules)
require.NoError(t, err)

err = registry.InitializeDB()
Expand All @@ -102,9 +119,11 @@ func Setup(t *testing.T, ctx context.Context, egrp *errgroup.Group) {
Identity: "test_data",
})
require.NoError(t, err)

originUrl, _ := url.Parse(param.Server_ExternalWebUrl.GetString())
err = registry.AddNamespace(&server_structs.Namespace{
ID: 2,
Prefix: "/foo",
Prefix: "/origins/" + originUrl.Host,
Pubkey: string(keysetBytes),
Identity: "test_data",
})
Expand All @@ -119,10 +138,10 @@ func doRetrieveRequest(t *testing.T, ctx context.Context, dur time.Duration) (*h
brokerEndpoint := param.Server_ExternalWebUrl.GetString() + "/api/v1.0/broker/retrieve"
originUrl, err := url.Parse(param.Server_ExternalWebUrl.GetString())
require.NoError(t, err)

serverNs := "/origins/" + originUrl.Host
oReq := originRequest{
Origin: originUrl.Hostname(),
Prefix: param.Origin_FederationPrefix.GetString(),
OriginNs: serverNs,
ServerNs: serverNs,
}
reqBytes, err := json.Marshal(&oReq)
require.NoError(t, err)
Expand All @@ -135,7 +154,7 @@ func doRetrieveRequest(t *testing.T, ctx context.Context, dur time.Duration) (*h
require.NoError(t, err)
brokerAud.Path = ""

token, err := createToken(param.Origin_FederationPrefix.GetString(), param.Server_Hostname.GetString(), brokerAud.String(), token_scopes.Broker_Retrieve)
token, err := createToken(serverNs, param.Server_Hostname.GetString(), brokerAud.String(), token_scopes.Broker_Retrieve)
require.NoError(t, err)

req, err := http.NewRequestWithContext(ctx, "POST", brokerEndpoint, reqReader)
Expand Down Expand Up @@ -191,6 +210,7 @@ func TestBroker(t *testing.T) {
// Launch the origin-side monitoring of requests.
viper.Set("Federation.BrokerURL", param.Server_ExternalWebUrl.GetString())
viper.Set("Federation.RegistryUrl", param.Server_ExternalWebUrl.GetString())

listenerChan := make(chan any)
ctxQuick, deadlineCancel := context.WithTimeout(ctx, 5*time.Second) // Have shorter timeout for this handshake
err = LaunchRequestMonitor(ctxQuick, egrp, listenerChan)
Expand All @@ -202,10 +222,11 @@ func TestBroker(t *testing.T) {

brokerUrl.Path = "/api/v1.0/broker/reverse"
query := brokerUrl.Query()
query.Set("origin", param.Server_Hostname.GetString())
query.Set("prefix", "/foo")
originUrl, err := url.Parse(param.Server_ExternalWebUrl.GetString())
require.NoError(t, err)
query.Set("origin", originUrl.Host)
brokerUrl.RawQuery = query.Encode()
clientConn, err := ConnectToOrigin(ctxQuick, brokerUrl.String(), "/foo", param.Server_Hostname.GetString())
clientConn, err := ConnectToOrigin(ctxQuick, brokerUrl.String(), originUrl.Host)
require.NoError(t, err)
log.Debugf("Cache got reversed client connection with cache side %s and origin side %s", clientConn.LocalAddr(), clientConn.RemoteAddr())

Expand Down
37 changes: 24 additions & 13 deletions broker/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ type (

// Struct holding pending requests waiting on an origin callback
pendingReversals struct {
channel chan http.ResponseWriter
prefix string
channel chan http.ResponseWriter
originNs string
}
)

Expand Down Expand Up @@ -157,7 +157,7 @@ func generateRequestId() string {
}

// Given an origin's broker URL, return a connected socket to the origin
func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string) (conn net.Conn, err error) {
func ConnectToOrigin(ctx context.Context, brokerUrl, originHost string) (conn net.Conn, err error) {

// Ensure we have a local CA for signing an origin host certificate.
if err = config.GenerateCACert(); err != nil {
Expand All @@ -181,8 +181,7 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string)
RequestId: generateRequestId(),
PrivateKey: keyContents,
CallbackUrl: param.Server_ExternalWebUrl.GetString() + "/api/v1.0/broker/callback",
OriginName: originName,
Prefix: prefix,
OriginHost: originHost,
}
reqBytes, err := json.Marshal(&reqC)
if err != nil {
Expand All @@ -194,7 +193,8 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string)
responseChannel := make(chan http.ResponseWriter)
defer close(responseChannel)
responseMapLock.Lock()
response[reqC.RequestId] = pendingReversals{channel: responseChannel, prefix: prefix}
originNs := "/origins/" + originHost
response[reqC.RequestId] = pendingReversals{channel: responseChannel, originNs: originNs}
responseMapLock.Unlock()
defer func() {
responseMapLock.Lock()
Expand Down Expand Up @@ -257,20 +257,24 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string)
if err != nil {
return
}
originHostname, _, err := net.SplitHostPort(originHost)
if err != nil {
return
}
notBefore := time.Now()
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Pelican"},
CommonName: originName,
CommonName: originHostname,
},
NotBefore: notBefore,
NotAfter: notBefore.Add(10 * time.Minute),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
template.DNSNames = []string{originName}
template.DNSNames = []string{originHostname}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, caCert, &privKey.PublicKey, caPrivateKey)
if err != nil {
return
Expand All @@ -287,7 +291,7 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string)
// will write to the channel we originally posted.
tck := time.NewTicker(20 * time.Second)
defer tck.Stop()
log.Debugf("Cache waiting for up to 20 seconds for the origin %s to callback", originName)
log.Debugf("Cache waiting for up to 20 seconds for the origin %s to callback", originHost)
select {
case <-ctx.Done():
log.Debug("Context has been cancelled while waiting for callback")
Expand Down Expand Up @@ -415,7 +419,13 @@ func doCallback(ctx context.Context, brokerResp reversalRequest) (listener net.L
}
cacheAud.Path = ""

token, err := createToken(param.Origin_FederationPrefix.GetString(), param.Server_Hostname.GetString(), cacheAud.String(), token_scopes.Broker_Callback)
originUrl, err := url.Parse(param.Server_ExternalWebUrl.GetString())
if err != nil {
return
}
serverNs := "/origins/" + originUrl.Host

token, err := createToken(serverNs, param.Server_Hostname.GetString(), cacheAud.String(), token_scopes.Broker_Callback)
if err != nil {
err = errors.Wrap(err, "failure when constructing the cache callback token")
return
Expand Down Expand Up @@ -565,9 +575,10 @@ func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan
if err != nil {
return
}
serverNs := "/origins/" + originUrl.Host
oReq := originRequest{
Origin: originUrl.Hostname(),
Prefix: param.Origin_FederationPrefix.GetString(),
OriginNs: serverNs,
ServerNs: serverNs,
}
req, err := json.Marshal(&oReq)
if err != nil {
Expand Down Expand Up @@ -602,7 +613,7 @@ func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan
}
brokerAud.Path = ""

token, err := createToken(param.Origin_FederationPrefix.GetString(), param.Server_Hostname.GetString(), brokerAud.String(), token_scopes.Broker_Retrieve)
token, err := createToken(serverNs, param.Server_Hostname.GetString(), brokerAud.String(), token_scopes.Broker_Retrieve)
if err != nil {
log.Errorln("Failure when constructing the broker retrieve token:", err)
break
Expand Down
22 changes: 10 additions & 12 deletions broker/request_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,15 @@ type (
CallbackUrl string `json:"callback_url,omitempty"`
PrivateKey string `json:"private_key,omitempty"`
RequestId string `json:"request_id,omitempty"`
Prefix string `json:"prefix,omitempty"`
OriginName string `json:"origin,omitempty"`
OriginHost string `json:"origin,omitempty"`
}

requestInfo struct {
channel chan reversalRequest
prefix string
}

requestKey struct {
origin string
prefix string
originNs string
}
)

Expand All @@ -53,22 +50,23 @@ var (
requests map[requestKey]requestInfo = make(map[requestKey]requestInfo)
)

func getOriginQueue(prefix, origin string) chan reversalRequest {
func getOriginQueue(originNs string) chan reversalRequest {
requestsLock.Lock()
defer requestsLock.Unlock()
if req, ok := requests[requestKey{origin: origin, prefix: prefix}]; ok {
if req, ok := requests[requestKey{originNs: originNs}]; ok {
return req.channel
} else {
newChan := make(chan reversalRequest)
requests[requestKey{origin: origin, prefix: prefix}] = requestInfo{channel: newChan, prefix: prefix}
requests[requestKey{originNs: originNs}] = requestInfo{channel: newChan}
return newChan
}
}

// Send a request to a given origin's queue.
// Return a requestTimeout error if no origin retrieved the request before the context timed out.
func handleRequest(ctx context.Context, origin string, req reversalRequest, timeout time.Duration) (err error) {
queue := getOriginQueue(req.Prefix, origin)
func handleRequest(ctx context.Context, originHost string, req reversalRequest, timeout time.Duration) (err error) {
originNs := "/origins/" + originHost
queue := getOriginQueue(originNs)
maxTime := timeout - 500*time.Millisecond - time.Duration(rand.Intn(500))*time.Millisecond
if maxTime <= 0 {
maxTime = time.Millisecond
Expand All @@ -90,7 +88,7 @@ func handleRequest(ctx context.Context, origin string, req reversalRequest, time
}

// Handle the origin's request to retrieve any pending reversals.
func handleRetrieve(appCtx context.Context, ginCtx context.Context, prefix, origin string, timeout time.Duration) (req reversalRequest, err error) {
func handleRetrieve(appCtx context.Context, ginCtx context.Context, originNs string, timeout time.Duration) (req reversalRequest, err error) {
// Return randomly short of the timeout.
maxTime := timeout - 500*time.Millisecond - time.Duration(rand.Intn(500))*time.Millisecond
if maxTime <= 0 {
Expand All @@ -99,7 +97,7 @@ func handleRetrieve(appCtx context.Context, ginCtx context.Context, prefix, orig
tick := time.NewTicker(maxTime)
defer tick.Stop()
select {
case req = <-getOriginQueue(prefix, origin):
case req = <-getOriginQueue(originNs):
break
case <-tick.C:
err = errRetrieveTimeout
Expand Down
18 changes: 7 additions & 11 deletions broker/server_apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ type (

// Structure for an origin's POST to the broker
originRequest struct {
Origin string `json:"origin"`
Prefix string `json:"prefix"`
OriginNs string `json:"origin_ns"`
ServerNs string `json:"server_ns"`
}

// Response for a successful retrieval
Expand Down Expand Up @@ -98,7 +98,7 @@ func retrieveRequest(ctx context.Context, ginCtx *gin.Context) {
return
}

ok, err := verifyToken(ctx, token, originReq.Prefix, param.Server_ExternalWebUrl.GetString(), token_scopes.Broker_Retrieve)
ok, err := verifyToken(ctx, token, originReq.ServerNs, param.Server_ExternalWebUrl.GetString(), token_scopes.Broker_Retrieve)
if err != nil {
log.Errorln("Failed to verify token for reverse request:", err)
ginCtx.AbortWithStatusJSON(http.StatusBadRequest, newBrokerRespFail("Failed to verify provided token"))
Expand All @@ -108,7 +108,7 @@ func retrieveRequest(ctx context.Context, ginCtx *gin.Context) {
ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, newBrokerRespFail("Authorization denied"))
}

req, err := handleRetrieve(ctx, ginCtx, originReq.Prefix, originReq.Origin, timeoutVal)
req, err := handleRetrieve(ctx, ginCtx, originReq.OriginNs, timeoutVal)
if errors.Is(err, errRetrieveTimeout) {
ginCtx.JSON(http.StatusOK, newBrokerRespTimeout())
return
Expand Down Expand Up @@ -160,16 +160,12 @@ func reverseRequest(ctx context.Context, ginCtx *gin.Context) {
ginCtx.AbortWithStatusJSON(http.StatusBadRequest, newBrokerRespFail("Failed to parse the cache's reversal request"))
return
}
if reversalReq.OriginName == "" {
if reversalReq.OriginHost == "" {
ginCtx.AbortWithStatusJSON(http.StatusBadRequest, newBrokerRespFail("Missing 'origin' parameter in request"))
return
}
if reversalReq.Prefix == "" {
ginCtx.AbortWithStatusJSON(http.StatusBadRequest, newBrokerRespFail("Missing 'prefix' parameter in request"))
return
}

if err = handleRequest(ctx, reversalReq.OriginName, reversalReq, timeoutVal); errors.Is(err, errRequestTimeout) {
if err = handleRequest(ctx, reversalReq.OriginHost, reversalReq, timeoutVal); errors.Is(err, errRequestTimeout) {
ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newBrokerRespFail("Timeout when waiting for origin callback"))
return
} else if err != nil {
Expand Down Expand Up @@ -215,7 +211,7 @@ func handleCallback(ctx context.Context, ginCtx *gin.Context) {
return
}

ok, err := verifyToken(ctx, token, pendingRev.prefix, param.Server_ExternalWebUrl.GetString(), token_scopes.Broker_Callback)
ok, err := verifyToken(ctx, token, pendingRev.originNs, param.Server_ExternalWebUrl.GetString(), token_scopes.Broker_Callback)
if err != nil {
log.Errorln("Failed to verify token for cache callback:", err)
ginCtx.AbortWithStatusJSON(http.StatusBadRequest, newBrokerRespFail("Failed to verify provided token"))
Expand Down
5 changes: 2 additions & 3 deletions cache/broker_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ import (
type (
xrootdBrokerRequest struct {
BrokerURL string `json:"broker_url"`
OriginName string `json:"origin"`
Prefix string `json:"prefix"`
OriginHost string `json:"origin"`
err error
}

Expand Down Expand Up @@ -121,7 +120,7 @@ func handleRequest(ctx context.Context, xrdConn net.Conn) {
sendXrootdError(xrdConn, errStr)
return
}
newConn, err := broker.ConnectToOrigin(ctx, xrdReq.BrokerURL, xrdReq.Prefix, xrdReq.OriginName)
newConn, err := broker.ConnectToOrigin(ctx, xrdReq.BrokerURL, xrdReq.OriginHost)
if err != nil {
errStr := "Failure when getting connection reversal from origin: " + err.Error()
log.Warning(errStr)
Expand Down
3 changes: 3 additions & 0 deletions launchers/cache_serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ func CacheServe(ctx context.Context, engine *gin.Engine, egrp *errgroup.Group, m

broker.RegisterBrokerCallback(ctx, engine.Group("/", web_ui.ServerHeaderMiddleware))
broker.LaunchNamespaceKeyMaintenance(ctx, egrp)
if err = cache.LaunchRequestListener(ctx, egrp); err != nil {
return nil, err
}
configPath, err := xrootd.ConfigXrootd(ctx, false)
if err != nil {
return nil, err
Expand Down
Loading
Loading