Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions agent-manager-service/controllers/agent_token_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"net/http"

"github.com/wso2/agent-manager/agent-manager-service/middleware/jwtassertion"
"github.com/wso2/agent-manager/agent-manager-service/middleware/logger"
"github.com/wso2/agent-manager/agent-manager-service/services"
"github.com/wso2/agent-manager/agent-manager-service/spec"
Expand Down Expand Up @@ -62,6 +63,14 @@ func (c *agentTokenController) GenerateToken(w http.ResponseWriter, r *http.Requ
"agentName", agentName,
)

// Extract OrgId from the caller's JWT claims
callerClaims := jwtassertion.GetTokenClaims(ctx)
if callerClaims == nil || callerClaims.OuId == "" {
log.Error("GenerateToken: missing organization identity in caller token")
utils.WriteErrorResponse(w, http.StatusForbidden, "missing organization identity")
return
}

// Parse optional query parameters
environment := r.URL.Query().Get("environment")

Expand All @@ -87,6 +96,7 @@ func (c *agentTokenController) GenerateToken(w http.ResponseWriter, r *http.Requ
AgentName: agentName,
Environment: environment,
ExpiresIn: expiresIn,
OrgId: callerClaims.OuId,
}

// Generate token
Expand Down
4 changes: 3 additions & 1 deletion agent-manager-service/middleware/jwtassertion/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ func NewMockMiddleware(t *testing.T) Middleware {
t.Helper()

tokenClaims := &TokenClaims{
Scope: "scopes",
Scope: "scopes",
OuId: "mock-org-id",
OuHandle: "mock-org",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
Expand Down
8 changes: 8 additions & 0 deletions agent-manager-service/services/agent_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/wso2/agent-manager/agent-manager-service/clients/openchoreosvc/gen"
"github.com/wso2/agent-manager/agent-manager-service/clients/secretmanagersvc"
"github.com/wso2/agent-manager/agent-manager-service/config"
"github.com/wso2/agent-manager/agent-manager-service/middleware/jwtassertion"
"github.com/wso2/agent-manager/agent-manager-service/models"
"github.com/wso2/agent-manager/agent-manager-service/repositories"
"github.com/wso2/agent-manager/agent-manager-service/spec"
Expand Down Expand Up @@ -429,13 +430,20 @@ func (s *agentManagerService) generateAgentAPIKey(ctx context.Context, orgName,
}
firstEnvName := findLowestEnvironment(pipeline.PromotionPaths)

// Extract OrgId from the caller's JWT claims
callerClaims := jwtassertion.GetTokenClaims(ctx)
if callerClaims == nil || callerClaims.OuId == "" {
s.logger.Error("GenerateToken: missing organization identity in caller token")
return "", utils.ErrForbidden
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
// Generate agent API key using token manager service with 1 year expiry
tokenReq := GenerateTokenRequest{
OrgName: orgName,
ProjectName: projectName,
AgentName: agentName,
Environment: firstEnvName,
ExpiresIn: "8760h", // 1 year (365 days * 24 hours)
OrgId: callerClaims.OuId,
}
tokenResp, err := s.tokenManagerService.GenerateToken(ctx, tokenReq)
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions agent-manager-service/services/agent_token_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type GenerateTokenRequest struct {
AgentName string
Environment string // Optional, defaults to config default if not provided
ExpiresIn string // Optional, Go duration format (e.g., "720h")
OrgId string // Organization UID from the caller's JWT claims
}

// AgentTokenClaims represents the custom claims for agent tokens
Expand All @@ -63,6 +64,7 @@ type AgentTokenClaims struct {
ComponentUid string `json:"component_uid"`
EnvironmentUid string `json:"environment_uid"`
ProjectUid string `json:"project_uid,omitempty"`
OrgId string `json:"org_id"`
}

// KeyPair holds a private/public RSA key pair with its metadata
Expand Down Expand Up @@ -246,6 +248,10 @@ func (s *agentTokenManagerService) GenerateToken(ctx context.Context, req Genera
"projectName", req.ProjectName,
)

if req.OrgId == "" {
return nil, fmt.Errorf("org id is required: %w", utils.ErrInvalidInput)
}

// Fetch component UID from OpenChoreo
component, err := s.ocClient.GetComponent(ctx, req.OrgName, req.ProjectName, req.AgentName)
if err != nil {
Expand Down Expand Up @@ -294,6 +300,7 @@ func (s *agentTokenManagerService) GenerateToken(ctx context.Context, req Genera
ComponentUid: component.UUID,
EnvironmentUid: environment.UUID,
ProjectUid: project.UUID,
OrgId: req.OrgId,
}

// Get the active signing key
Expand Down
2 changes: 2 additions & 0 deletions agent-manager-service/tests/agent_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ func TestGenerateAgentToken(t *testing.T) {
// Validate claims exist
require.Contains(t, claims, "component_uid")
require.Contains(t, claims, "environment_uid")
require.Contains(t, claims, "org_id")
require.Contains(t, claims, "iss")
require.Contains(t, claims, "exp")
require.Contains(t, claims, "iat")

// Validate the component_uid matches what we expect
require.Equal(t, tokenComponentUid, claims["component_uid"])
require.Equal(t, tokenEnvUid, claims["environment_uid"])
require.Equal(t, "mock-org-id", claims["org_id"])
})

t.Run("Invalid expiry duration - malformed string", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions console/workspaces/libs/auth/src/no-auth/hooks/authHooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const demoUserInfo : UserInfo = {
orgId: 'default',
orgName: 'Default',
sessionState: '',
sub: 'default',
sub: '8f307351-25c5-4fc6-85e0-f51c2d458f06',
allowedScopes: "openid email profile",
};

Expand All @@ -39,6 +39,6 @@ export const useAuthHooks = () => {
login: () => Promise.resolve(),
logout: () => Promise.resolve(),
trySignInSilently: () => Promise.resolve(),
getToken: () => Promise.resolve('eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJBZ2VudCBNYW5hZ2VtZW50IFBsYXRmb3JtIExvY2FsIiwiaWF0IjoxNzYxNzI3NDY5LCJleHAiOjE3OTMyNjM0NjksImF1ZCI6ImxvY2FsaG9zdCIsInN1YiI6IjhmMzA3MzUxLTI1YzUtNGZjNi04NWUwLWY1MWMyZDQ1OGYwNiJ9.etSp2_pwhdaWnFlK8IYWCptWV1MiZd32Ou6Ri6rBIvE'),
getToken: () => Promise.resolve('eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJBZ2VudCBNYW5hZ2VtZW50IFBsYXRmb3JtIExvY2FsIiwiaWF0IjoxNzYxNzI3NDY5LCJleHAiOjE3OTMyNjM0NjksImF1ZCI6ImxvY2FsaG9zdCIsInN1YiI6IjhmMzA3MzUxLTI1YzUtNGZjNi04NWUwLWY1MWMyZDQ1OGYwNiIsIm91SWQiOiJmYWExODNmMS0zOTgzLTRmNjMtYmIzNS04NmZhNzIzZmQ1ZWYifQ.t3ioFvAOrXYrHrNeMi5BSvu3oWQ8jv5cJgmgCTOiXfY'),
};
};
79 changes: 70 additions & 9 deletions traces-observer-service/observer/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,18 @@ type AuthProvider struct {
mu sync.RWMutex
accessToken string
expiresAt time.Time
usePostAuth bool
}

type tokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
}

type tokenErrorResponse struct {
Error string `json:"error"`
}

// NewAuthProvider creates a new AuthProvider with the given credentials.
func NewAuthProvider(tokenURL, clientID, clientSecret string) *AuthProvider {
return &AuthProvider{
Expand Down Expand Up @@ -114,43 +119,99 @@ func (p *AuthProvider) isTokenValid() bool {
}

func (p *AuthProvider) fetchToken(ctx context.Context) (string, int64, error) {
if p.usePostAuth {
return p.doTokenRequest(ctx, true)
}

token, expiresIn, statusCode, body, err := p.executeTokenRequest(ctx, false)
if err != nil {
return "", 0, err
}

if statusCode == http.StatusOK {
return token, expiresIn, nil
}

if isUnauthorizedClientError(statusCode, body) {
slog.Info("observer auth: client_secret_basic rejected, falling back to client_secret_post",
"status_code", statusCode)
tok, exp, retryErr := p.doTokenRequest(ctx, true)
if retryErr != nil {
return "", 0, retryErr
}
p.usePostAuth = true
return tok, exp, nil
}

return "", 0, fmt.Errorf("token endpoint returned %d: %s", statusCode, string(body))
}

func (p *AuthProvider) doTokenRequest(ctx context.Context, postAuth bool) (string, int64, error) {
token, expiresIn, statusCode, body, err := p.executeTokenRequest(ctx, postAuth)
if err != nil {
return "", 0, err
}
if statusCode != http.StatusOK {
return "", 0, fmt.Errorf("token endpoint returned %d: %s", statusCode, string(body))
}
return token, expiresIn, nil
}

func (p *AuthProvider) executeTokenRequest(ctx context.Context, postAuth bool) (string, int64, int, []byte, error) {
form := url.Values{
"grant_type": {"client_credentials"},
}
if postAuth {
form.Set("client_id", p.clientID)
form.Set("client_secret", p.clientSecret)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.tokenURL,
strings.NewReader(form.Encode()))
if err != nil {
return "", 0, fmt.Errorf("build request: %w", err)
return "", 0, 0, nil, fmt.Errorf("build request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(p.clientID, p.clientSecret)
if !postAuth {
req.SetBasicAuth(p.clientID, p.clientSecret)
}

resp, err := p.httpClient.Do(req)
if err != nil {
return "", 0, fmt.Errorf("execute request: %w", err)
return "", 0, 0, nil, fmt.Errorf("execute request: %w", err)
}
defer func() { _ = resp.Body.Close() }()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", 0, fmt.Errorf("read response body: %w", err)
return "", 0, 0, nil, fmt.Errorf("read response body: %w", err)
}

if resp.StatusCode != http.StatusOK {
return "", 0, fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body))
return "", 0, resp.StatusCode, body, nil
}

var tr tokenResponse
if err := json.Unmarshal(body, &tr); err != nil {
return "", 0, fmt.Errorf("decode token response: %w", err)
return "", 0, 0, nil, fmt.Errorf("decode token response: %w", err)
}
if tr.AccessToken == "" {
return "", 0, fmt.Errorf("empty access token in response")
return "", 0, 0, nil, fmt.Errorf("empty access token in response")
}
if tr.ExpiresIn <= 0 {
return "", 0, fmt.Errorf("invalid expires_in value: %d", tr.ExpiresIn)
return "", 0, 0, nil, fmt.Errorf("invalid expires_in value: %d", tr.ExpiresIn)
}

return tr.AccessToken, tr.ExpiresIn, nil
return tr.AccessToken, tr.ExpiresIn, resp.StatusCode, body, nil
}

func isUnauthorizedClientError(statusCode int, body []byte) bool {
if statusCode != http.StatusBadRequest {
return false
}
var errResp tokenErrorResponse
if err := json.Unmarshal(body, &errResp); err != nil {
return false
}
return errResp.Error == "unauthorized_client"
}
Loading