diff --git a/edge-apis/api_session.go b/edge-apis/api_session.go new file mode 100644 index 00000000..31831b1b --- /dev/null +++ b/edge-apis/api_session.go @@ -0,0 +1,429 @@ +package edge_apis + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/go-openapi/runtime" + "github.com/go-openapi/strfmt" + "github.com/golang-jwt/jwt/v5" + "github.com/openziti/edge-api/rest_model" + "github.com/openziti/foundation/v2/stringz" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" +) + +var _ json.Marshaler = (*ApiSessionJsonWrapper)(nil) +var _ json.Unmarshaler = (*ApiSessionJsonWrapper)(nil) + +// ApiSessionJsonWrapper provides JSON marshaling and unmarshaling capabilities for ApiSession +// interface types. It allows polymorphic ApiSession implementations (ApiSessionLegacy and +// ApiSessionOidc) to be correctly serialized and deserialized by delegating to the underlying +// ApiSession's JSON methods. +// +// This wrapper enables ApiSession instances to be embedded in structs and marshaled to/from +// JSON. +type ApiSessionJsonWrapper struct { + ApiSession ApiSession +} + +func (a *ApiSessionJsonWrapper) UnmarshalJSON(bytes []byte) error { + var err error + a.ApiSession, err = UnmarshalApiSession(bytes) + + return err +} + +func (a *ApiSessionJsonWrapper) MarshalJSON() ([]byte, error) { + return a.ApiSession.MarshalJSON() +} + +type ApiSession interface { + //GetAccessHeader returns the HTTP header name and value that should be used to represent this ApiSession + GetAccessHeader() (string, string) + + //AuthenticateRequest fulfills the interface defined by the OpenAPI libraries to authenticate client HTTP requests + AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error + + //GetToken returns the ApiSessions' token bytes + GetToken() []byte + + //GetExpiresAt returns the time when the ApiSession will expire. + GetExpiresAt() *time.Time + + //GetAuthQueries returns a list of authentication queries the ApiSession is subjected to + GetAuthQueries() rest_model.AuthQueryList + + //GetIdentityName returns the name of the authenticating identity + GetIdentityName() string + + //GetIdentityId returns the id of the authenticating identity + GetIdentityId() string + + //GetId returns the id of the ApiSession + GetId() string + + //RequiresRouterTokenUpdate returns true if the token is a bearer token requires updating on edge router connections. + RequiresRouterTokenUpdate() bool + + GetRequestHeaders() http.Header + + // GetType returns the authentication method used to establish this session, enabling + // callers to determine whether legacy or OIDC-based authentication is in use. + GetType() ApiSessionType + + json.Marshaler + json.Unmarshaler +} + +type ApiSessionJson struct { + Type string `json:"type"` + ZtSessionToken string `json:"ztSessionToken,omitempty"` + OidcAccessToken string `json:"oidcAccessToken,omitempty"` + OidcRefreshToken string `json:"oidcRefreshToken,omitempty"` +} + +// ApiSessionType identifies the authentication mechanism used to establish an API session. +type ApiSessionType string + +const ( + // ApiSessionTypeLegacy indicates a session created using the original Ziti authentication + // with session tokens passed in the zt-session header. + ApiSessionTypeLegacy ApiSessionType = "legacy" + + // ApiSessionTypeOidc indicates a session created using OpenID Connect authentication + // with JWT bearer tokens. + ApiSessionTypeOidc ApiSessionType = "oidc" +) + +func UnmarshalApiSession(data []byte) (ApiSession, error) { + apiSessionJson := &ApiSessionJson{} + + err := json.Unmarshal(data, apiSessionJson) + + if err != nil { + return nil, err + } + + switch apiSessionJson.Type { + case string(ApiSessionTypeLegacy): + result := &ApiSessionLegacy{} + err := result.setFromJson(apiSessionJson) + if err != nil { + return nil, err + } + return result, nil + case string(ApiSessionTypeOidc): + result := &ApiSessionOidc{} + err := result.setFromJson(apiSessionJson) + if err != nil { + return nil, err + } + return result, nil + } + + return nil, fmt.Errorf("unsupported api session type %s", apiSessionJson.Type) +} + +var _ ApiSession = (*ApiSessionLegacy)(nil) +var _ ApiSession = (*ApiSessionOidc)(nil) + +// ApiSessionLegacy represents OpenZiti's original authentication API Session Detail, supplied in the `zt-session` header. +// It has been supplanted by OIDC authentication represented by ApiSessionOidc. +type ApiSessionLegacy struct { + Detail *rest_model.CurrentAPISessionDetail + RequestHeaders http.Header +} + +func NewApiSessionLegacy(token string) *ApiSessionLegacy { + return &ApiSessionLegacy{ + Detail: &rest_model.CurrentAPISessionDetail{ + APISessionDetail: rest_model.APISessionDetail{ + Token: &token, + }, + }, + } +} + +func (a *ApiSessionLegacy) NewApiSessionLegacy(token string) *ApiSessionLegacy { + return &ApiSessionLegacy{ + Detail: &rest_model.CurrentAPISessionDetail{ + APISessionDetail: rest_model.APISessionDetail{ + Token: &token, + }, + }, + } +} + +func (a *ApiSessionLegacy) GetType() ApiSessionType { + return ApiSessionTypeLegacy +} + +func (a *ApiSessionLegacy) GetRequestHeaders() http.Header { + return a.RequestHeaders +} + +func (a *ApiSessionLegacy) RequiresRouterTokenUpdate() bool { + return false +} + +func (a *ApiSessionLegacy) GetId() string { + return stringz.OrEmpty(a.Detail.ID) +} + +func (a *ApiSessionLegacy) GetIdentityName() string { + return a.Detail.Identity.Name +} + +func (a *ApiSessionLegacy) GetIdentityId() string { + return stringz.OrEmpty(a.Detail.IdentityID) +} + +// GetAccessHeader returns the header and header token value should be used for authentication requests +func (a *ApiSessionLegacy) GetAccessHeader() (string, string) { + if a.Detail != nil && a.Detail.Token != nil { + return "zt-session", *a.Detail.Token + } + + return "", "" +} + +func (a *ApiSessionLegacy) AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error { + if a == nil { + return errors.New("api session is nil") + } + + for h, v := range a.RequestHeaders { + err := request.SetHeaderParam(h, v...) + if err != nil { + return err + } + } + + //legacy does not support multiple zt-session headers, so we can it sfely + header, val := a.GetAccessHeader() + err := request.SetHeaderParam(header, val) + if err != nil { + return err + } + + return nil +} + +func (a *ApiSessionLegacy) GetToken() []byte { + if a.Detail != nil && a.Detail.Token != nil { + return []byte(*a.Detail.Token) + } + + return nil +} + +func (a *ApiSessionLegacy) GetAuthQueries() rest_model.AuthQueryList { + return a.Detail.AuthQueries +} + +func (a *ApiSessionLegacy) GetExpiresAt() *time.Time { + if a.Detail != nil { + return (*time.Time)(a.Detail.ExpiresAt) + } + + return nil +} + +func (a *ApiSessionLegacy) MarshalJSON() ([]byte, error) { + apiSessionJson := ApiSessionJson{ + Type: string(a.GetType()), + ZtSessionToken: string(a.GetToken()), + } + + return json.Marshal(apiSessionJson) +} + +func (a *ApiSessionLegacy) UnmarshalJSON(bytes []byte) error { + apiSessionJson := ApiSessionJson{} + err := json.Unmarshal(bytes, &apiSessionJson) + if err != nil { + return err + } + + return a.setFromJson(&apiSessionJson) +} + +func (a *ApiSessionLegacy) setFromJson(apiSessionJson *ApiSessionJson) error { + if apiSessionJson.Type != string(ApiSessionTypeLegacy) { + return fmt.Errorf("unsupported api session type %s", apiSessionJson.Type) + } + + a.Detail = &rest_model.CurrentAPISessionDetail{ + APISessionDetail: rest_model.APISessionDetail{ + Token: &apiSessionJson.ZtSessionToken, + }, + } + + return nil +} + +// ApiSessionOidc represents an authenticated session backed by OIDC tokens. +type ApiSessionOidc struct { + OidcTokens *oidc.Tokens[*oidc.IDTokenClaims] + RequestHeaders http.Header +} + +func NewApiSessionOidc(accessToken, refreshToken string) *ApiSessionOidc { + return &ApiSessionOidc{ + OidcTokens: &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + } +} + +func (a *ApiSessionOidc) GetType() ApiSessionType { + return ApiSessionTypeOidc +} + +func (a *ApiSessionOidc) GetRequestHeaders() http.Header { + return a.RequestHeaders +} + +func (a *ApiSessionOidc) RequiresRouterTokenUpdate() bool { + return true +} + +func (a *ApiSessionOidc) GetAccessClaims() (*ApiAccessClaims, error) { + claims := &ApiAccessClaims{} + + parser := jwt.NewParser() + _, _, err := parser.ParseUnverified(a.OidcTokens.AccessToken, claims) + + if err != nil { + return nil, err + } + + return claims, nil +} + +func (a *ApiSessionOidc) GetId() string { + claims, err := a.GetAccessClaims() + + if err != nil { + return "" + } + + return claims.ApiSessionId +} + +func (a *ApiSessionOidc) GetIdentityName() string { + return a.OidcTokens.IDTokenClaims.Name +} + +func (a *ApiSessionOidc) GetIdentityId() string { + return a.OidcTokens.IDTokenClaims.Subject +} + +// GetAccessHeader returns the header and header token value should be used for authentication requests +func (a *ApiSessionOidc) GetAccessHeader() (string, string) { + if a.OidcTokens != nil { + return "authorization", "Bearer " + a.OidcTokens.AccessToken + } + + return "", "" +} + +func (a *ApiSessionOidc) AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error { + if a == nil { + return errors.New("api session is nil") + } + + if a.RequestHeaders == nil { + a.RequestHeaders = http.Header{} + } + + //multiple Authorization headers are allowed, obtain all auth header candidates + primaryAuthHeader, primaryAuthValue := a.GetAccessHeader() + altAuthValues := a.RequestHeaders.Get(primaryAuthHeader) + + authValues := []string{primaryAuthValue} + + if len(altAuthValues) > 0 { + authValues = append(authValues, altAuthValues) + } + + //set request headers + for h, v := range a.RequestHeaders { + err := request.SetHeaderParam(h, v...) + if err != nil { + return err + } + } + + //restore auth headers + err := request.SetHeaderParam(primaryAuthHeader, authValues...) + + if err != nil { + return err + } + + return nil +} + +func (a *ApiSessionOidc) GetToken() []byte { + if a.OidcTokens != nil && a.OidcTokens.AccessToken != "" { + return []byte(a.OidcTokens.AccessToken) + } + + return nil +} + +func (a *ApiSessionOidc) GetAuthQueries() rest_model.AuthQueryList { + //todo convert JWT auth queries to rest_model.AuthQueryList + return nil +} + +func (a *ApiSessionOidc) GetExpiresAt() *time.Time { + if a.OidcTokens != nil { + return &a.OidcTokens.Expiry + } + return nil +} + +func (a *ApiSessionOidc) MarshalJSON() ([]byte, error) { + apiSessionJson := &ApiSessionJson{ + Type: string(a.GetType()), + OidcAccessToken: a.OidcTokens.AccessToken, + OidcRefreshToken: a.OidcTokens.RefreshToken, + } + + return json.Marshal(apiSessionJson) +} + +func (a *ApiSessionOidc) UnmarshalJSON(bytes []byte) error { + apiSessionJson := &ApiSessionJson{} + + err := json.Unmarshal(bytes, &apiSessionJson) + if err != nil { + return err + } + + if apiSessionJson.Type != string(ApiSessionTypeOidc) { + return fmt.Errorf("unsupported api session type %s", apiSessionJson.Type) + } + + return a.setFromJson(apiSessionJson) +} + +func (a *ApiSessionOidc) setFromJson(apiSessionJson *ApiSessionJson) error { + a.OidcTokens = &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: apiSessionJson.OidcAccessToken, + RefreshToken: apiSessionJson.OidcRefreshToken, + }, + } + + return nil +} diff --git a/edge-apis/api_session_test.go b/edge-apis/api_session_test.go new file mode 100644 index 00000000..a3a9293a --- /dev/null +++ b/edge-apis/api_session_test.go @@ -0,0 +1,33 @@ +package edge_apis + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_ApiSessionMarshalling(t *testing.T) { + + t.Run("test marshalling", func(t *testing.T) { + req := require.New(t) + type testStruct struct { + ApiSession ApiSessionJsonWrapper `json:"apiSession"` + } + + test := &testStruct{ + ApiSession: ApiSessionJsonWrapper{ + ApiSession: NewApiSessionOidc("access", "refresh"), + }, + } + + testJson, err := json.Marshal(test) + req.NoError(err) + + testUnmarhsal := &testStruct{} + + err = json.Unmarshal(testJson, testUnmarhsal) + req.NoError(err) + req.Equal(test.ApiSession.ApiSession.GetToken(), testUnmarhsal.ApiSession.ApiSession.GetToken()) + }) +} diff --git a/edge-apis/authwrapper.go b/edge-apis/authwrapper.go deleted file mode 100644 index 7afe51f7..00000000 --- a/edge-apis/authwrapper.go +++ /dev/null @@ -1,910 +0,0 @@ -// Package edge_apis_2 edge_apis_2 provides a wrapper around the generated Edge Client and Management APIs improve ease -// of use. -package edge_apis - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "sync" - "time" - - "github.com/go-openapi/runtime" - "github.com/go-openapi/strfmt" - "github.com/go-resty/resty/v2" - "github.com/golang-jwt/jwt/v5" - "github.com/openziti/edge-api/rest_client_api_client" - clientAuth "github.com/openziti/edge-api/rest_client_api_client/authentication" - clientControllers "github.com/openziti/edge-api/rest_client_api_client/controllers" - clientApiSession "github.com/openziti/edge-api/rest_client_api_client/current_api_session" - clientInfo "github.com/openziti/edge-api/rest_client_api_client/informational" - "github.com/openziti/edge-api/rest_management_api_client" - manAuth "github.com/openziti/edge-api/rest_management_api_client/authentication" - manControllers "github.com/openziti/edge-api/rest_management_api_client/controllers" - manCurApiSession "github.com/openziti/edge-api/rest_management_api_client/current_api_session" - manInfo "github.com/openziti/edge-api/rest_management_api_client/informational" - "github.com/openziti/edge-api/rest_model" - "github.com/openziti/edge-api/rest_util" - "github.com/openziti/foundation/v2/errorz" - "github.com/openziti/foundation/v2/stringz" - "github.com/pkg/errors" - "github.com/zitadel/oidc/v3/pkg/client/tokenexchange" - "github.com/zitadel/oidc/v3/pkg/oidc" - "golang.org/x/oauth2" -) - -const ( - AuthRequestIdHeader = "auth-request-id" - TotpRequiredHeader = "totp-required" -) - -// AuthEnabledApi is used as a sentinel interface to detect APIs that support authentication and to work around a golang -// limitation dealing with accessing field of generically typed fields. -type AuthEnabledApi interface { - //Authenticate will attempt to issue an authentication request using the provided credentials and http client. - //These functions act as abstraction around the underlying go-swagger generated client and will use the default - //http client if not provided. - Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) - SetUseOidc(bool) - ListControllers() (*rest_model.ControllersList, error) - GetClientTransportPool() ClientTransportPool - SetClientTransportPool(ClientTransportPool) -} - -type ApiSession interface { - //GetAccessHeader returns the HTTP header name and value that should be used to represent this ApiSession - GetAccessHeader() (string, string) - - //AuthenticateRequest fulfills the interface defined by the OpenAPI libraries to authenticate client HTTP requests - AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error - - //GetToken returns the ApiSessions' token bytes - GetToken() []byte - - //GetExpiresAt returns the time when the ApiSession will expire. - GetExpiresAt() *time.Time - - //GetAuthQueries returns a list of authentication queries the ApiSession is subjected to - GetAuthQueries() rest_model.AuthQueryList - - //GetIdentityName returns the name of the authenticating identity - GetIdentityName() string - - //GetIdentityId returns the id of the authenticating identity - GetIdentityId() string - - //GetId returns the id of the ApiSession - GetId() string - - //RequiresRouterTokenUpdate returns true if the token is a bearer token requires updating on edge router connections. - RequiresRouterTokenUpdate() bool - - GetRequestHeaders() http.Header - - // GetType returns the authentication method used to establish this session, enabling - // callers to determine whether legacy or OIDC-based authentication is in use. - GetType() ApiSessionType -} - -// ApiSessionType identifies the authentication mechanism used to establish an API session. -type ApiSessionType string - -const ( - // ApiSessionTypeLegacy indicates a session created using the original Ziti authentication - // with session tokens passed in the zt-session header. - ApiSessionTypeLegacy ApiSessionType = "legacy" - - // ApiSessionTypeOidc indicates a session created using OpenID Connect authentication - // with JWT bearer tokens. - ApiSessionTypeOidc ApiSessionType = "oidc" -) - -var _ ApiSession = (*ApiSessionLegacy)(nil) -var _ ApiSession = (*ApiSessionOidc)(nil) - -// ApiSessionLegacy represents OpenZiti's original authentication API Session Detail, supplied in the `zt-session` header. -// It has been supplanted by OIDC authentication represented by ApiSessionOidc. -type ApiSessionLegacy struct { - Detail *rest_model.CurrentAPISessionDetail - RequestHeaders http.Header -} - -func (a *ApiSessionLegacy) GetType() ApiSessionType { - return ApiSessionTypeLegacy -} - -func (a *ApiSessionLegacy) GetRequestHeaders() http.Header { - return a.RequestHeaders -} - -func (a *ApiSessionLegacy) RequiresRouterTokenUpdate() bool { - return false -} - -func (a *ApiSessionLegacy) GetId() string { - return stringz.OrEmpty(a.Detail.ID) -} - -func (a *ApiSessionLegacy) GetIdentityName() string { - return a.Detail.Identity.Name -} - -func (a *ApiSessionLegacy) GetIdentityId() string { - return stringz.OrEmpty(a.Detail.IdentityID) -} - -// GetAccessHeader returns the header and header token value should be used for authentication requests -func (a *ApiSessionLegacy) GetAccessHeader() (string, string) { - if a.Detail != nil && a.Detail.Token != nil { - return "zt-session", *a.Detail.Token - } - - return "", "" -} - -func (a *ApiSessionLegacy) AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error { - if a == nil { - return errors.New("api session is nil") - } - - for h, v := range a.RequestHeaders { - err := request.SetHeaderParam(h, v...) - if err != nil { - return err - } - } - - //legacy does not support multiple zt-session headers, so we can it sfely - header, val := a.GetAccessHeader() - err := request.SetHeaderParam(header, val) - if err != nil { - return err - } - - return nil -} - -func (a *ApiSessionLegacy) GetToken() []byte { - if a.Detail != nil && a.Detail.Token != nil { - return []byte(*a.Detail.Token) - } - - return nil -} - -func (a *ApiSessionLegacy) GetAuthQueries() rest_model.AuthQueryList { - return a.Detail.AuthQueries -} - -func (a *ApiSessionLegacy) GetExpiresAt() *time.Time { - if a.Detail != nil { - return (*time.Time)(a.Detail.ExpiresAt) - } - - return nil -} - -// ApiSessionOidc represents an authenticated session backed by OIDC tokens. -type ApiSessionOidc struct { - OidcTokens *oidc.Tokens[*oidc.IDTokenClaims] - RequestHeaders http.Header -} - -func (a *ApiSessionOidc) GetType() ApiSessionType { - return ApiSessionTypeOidc -} - -func (a *ApiSessionOidc) GetRequestHeaders() http.Header { - return a.RequestHeaders -} - -func (a *ApiSessionOidc) RequiresRouterTokenUpdate() bool { - return true -} - -func (a *ApiSessionOidc) GetAccessClaims() (*ApiAccessClaims, error) { - claims := &ApiAccessClaims{} - - parser := jwt.NewParser() - _, _, err := parser.ParseUnverified(a.OidcTokens.AccessToken, claims) - - if err != nil { - return nil, err - } - - return claims, nil -} - -func (a *ApiSessionOidc) GetId() string { - claims, err := a.GetAccessClaims() - - if err != nil { - return "" - } - - return claims.ApiSessionId -} - -func (a *ApiSessionOidc) GetIdentityName() string { - return a.OidcTokens.IDTokenClaims.Name -} - -func (a *ApiSessionOidc) GetIdentityId() string { - return a.OidcTokens.IDTokenClaims.Subject -} - -// GetAccessHeader returns the header and header token value should be used for authentication requests -func (a *ApiSessionOidc) GetAccessHeader() (string, string) { - if a.OidcTokens != nil { - return "authorization", "Bearer " + a.OidcTokens.AccessToken - } - - return "", "" -} - -func (a *ApiSessionOidc) AuthenticateRequest(request runtime.ClientRequest, _ strfmt.Registry) error { - if a == nil { - return errors.New("api session is nil") - } - - if a.RequestHeaders == nil { - a.RequestHeaders = http.Header{} - } - - //multiple Authorization headers are allowed, obtain all auth header candidates - primaryAuthHeader, primaryAuthValue := a.GetAccessHeader() - altAuthValues := a.RequestHeaders.Get(primaryAuthHeader) - - authValues := []string{primaryAuthValue} - - if len(altAuthValues) > 0 { - authValues = append(authValues, altAuthValues) - } - - //set request headers - for h, v := range a.RequestHeaders { - err := request.SetHeaderParam(h, v...) - if err != nil { - return err - } - } - - //restore auth headers - err := request.SetHeaderParam(primaryAuthHeader, authValues...) - - if err != nil { - return err - } - - return nil -} - -func (a *ApiSessionOidc) GetToken() []byte { - if a.OidcTokens != nil && a.OidcTokens.AccessToken != "" { - return []byte(a.OidcTokens.AccessToken) - } - - return nil -} - -func (a *ApiSessionOidc) GetAuthQueries() rest_model.AuthQueryList { - //todo convert JWT auth queries to rest_model.AuthQueryList - return nil -} - -func (a *ApiSessionOidc) GetExpiresAt() *time.Time { - if a.OidcTokens != nil { - return &a.OidcTokens.Expiry - } - return nil -} - -var _ AuthEnabledApi = (*ZitiEdgeManagement)(nil) - -// ZitiEdgeManagement is an alias of the go-swagger generated client that allows this package to add additional -// functionality to the alias type to implement the AuthEnabledApi interface. -type ZitiEdgeManagement struct { - *rest_management_api_client.ZitiEdgeManagement - // useOidc tracks if OIDC auth should be used - useOidc bool - - // useOidcExplicitlySet signals if useOidc was set from an external caller and should be used as is - useOidcExplicitlySet bool - - // oidcDynamicallyEnabled will cause the client to check the controller for OIDC support and use if possible as long as useOidc was not explicitly set - oidcDynamicallyEnabled bool //currently defaults false till HA release - - versionOnce sync.Once - versionInfo *rest_model.Version - - TotpCallback func(chan string) - ClientTransportPool ClientTransportPool -} - -func (self *ZitiEdgeManagement) SetClientTransportPool(transportPool ClientTransportPool) { - self.ClientTransportPool = transportPool -} - -func (self *ZitiEdgeManagement) GetClientTransportPool() ClientTransportPool { - return self.ClientTransportPool -} - -func (self *ZitiEdgeManagement) ListControllers() (*rest_model.ControllersList, error) { - params := manControllers.NewListControllersParams() - resp, err := self.Controllers.ListControllers(params, nil) - if err != nil { - return nil, err - } - - return &resp.GetPayload().Data, nil -} - -func (self *ZitiEdgeManagement) Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { - self.versionOnce.Do(func() { - if self.useOidcExplicitlySet { - return - } - - if self.oidcDynamicallyEnabled { - versionParams := manInfo.NewListVersionParams() - - versionResp, _ := self.Informational.ListVersion(versionParams) - - if versionResp != nil { - self.versionInfo = versionResp.Payload.Data - self.useOidc = stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH)) - } - } else { - self.useOidc = false - } - }) - - if self.useOidc { - return self.oidcAuth(credentials, configTypes, httpClient) - } - - return self.legacyAuth(credentials, configTypes, httpClient) -} - -func (self *ZitiEdgeManagement) legacyAuth(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { - params := manAuth.NewAuthenticateParams() - params.Auth = credentials.Payload() - params.Method = credentials.Method() - params.Auth.ConfigTypes = append(params.Auth.ConfigTypes, configTypes...) - - certs := credentials.TlsCerts() - if len(certs) != 0 { - if transport, ok := httpClient.Transport.(*http.Transport); ok { - transport.TLSClientConfig.Certificates = certs - transport.CloseIdleConnections() - } - } - - resp, err := self.Authentication.Authenticate(params, getClientAuthInfoOp(credentials, httpClient)) - - if err != nil { - return nil, err - } - - return &ApiSessionLegacy{ - Detail: resp.GetPayload().Data, - RequestHeaders: credentials.GetRequestHeaders()}, err -} - -func (self *ZitiEdgeManagement) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) { - return oidcAuth(self.ClientTransportPool, credentials, configTypeOverrides, httpClient, self.TotpCallback) -} - -func (self *ZitiEdgeManagement) SetUseOidc(use bool) { - self.useOidcExplicitlySet = true - self.useOidc = use -} - -func (self *ZitiEdgeManagement) SetAllowOidcDynamicallyEnabled(allow bool) { - self.oidcDynamicallyEnabled = allow -} - -func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession, httpClient *http.Client) (ApiSession, error) { - switch s := apiSession.(type) { - case *ApiSessionLegacy: - params := manCurApiSession.NewGetCurrentAPISessionParams() - _, err := self.CurrentAPISession.GetCurrentAPISession(params, s) - - if err != nil { - return nil, rest_util.WrapErr(err) - } - - return s, nil - case *ApiSessionOidc: - tokens, err := self.ExchangeTokens(s.OidcTokens, httpClient) - - if err != nil { - return nil, err - } - - return &ApiSessionOidc{ - OidcTokens: tokens, - RequestHeaders: apiSession.GetRequestHeaders(), - }, nil - } - - return nil, errors.New("api session does not have any tokens") -} - -func (self *ZitiEdgeManagement) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { - return exchangeTokens(self.ClientTransportPool, curTokens, httpClient) -} - -var _ AuthEnabledApi = (*ZitiEdgeClient)(nil) - -// ZitiEdgeClient is an alias of the go-swagger generated client that allows this package to add additional -// functionality to the alias type to implement the AuthEnabledApi interface. -type ZitiEdgeClient struct { - *rest_client_api_client.ZitiEdgeClient - // useOidc tracks if OIDC auth should be used - useOidc bool - - // useOidcExplicitlySet signals if useOidc was set from an external caller and should be used as is - useOidcExplicitlySet bool - - // oidcDynamicallyEnabled will cause the client to check the controller for OIDC support and use if possible as long as useOidc was not explicitly set. - oidcDynamicallyEnabled bool //currently defaults false till HA release - - versionInfo *rest_model.Version - versionOnce sync.Once - - TotpCallback func(chan string) - ClientTransportPool ClientTransportPool -} - -func (self *ZitiEdgeClient) GetClientTransportPool() ClientTransportPool { - return self.ClientTransportPool -} - -func (self *ZitiEdgeClient) SetClientTransportPool(transportPool ClientTransportPool) { - self.ClientTransportPool = transportPool -} - -func (self *ZitiEdgeClient) ListControllers() (*rest_model.ControllersList, error) { - params := clientControllers.NewListControllersParams() - resp, err := self.Controllers.ListControllers(params, nil) - if err != nil { - return nil, err - } - - return &resp.GetPayload().Data, nil -} - -func (self *ZitiEdgeClient) Authenticate(credentials Credentials, configTypesOverrides []string, httpClient *http.Client) (ApiSession, error) { - self.versionOnce.Do(func() { - if self.useOidcExplicitlySet { - return - } - - if self.oidcDynamicallyEnabled { - versionParams := clientInfo.NewListVersionParams() - - versionResp, _ := self.Informational.ListVersion(versionParams) - - if versionResp != nil { - self.versionInfo = versionResp.Payload.Data - self.useOidc = stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH)) - } - } else { - self.useOidc = false - } - }) - - if self.useOidc { - return self.oidcAuth(credentials, configTypesOverrides, httpClient) - } - - return self.legacyAuth(credentials, configTypesOverrides, httpClient) -} - -func (self *ZitiEdgeClient) legacyAuth(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { - params := clientAuth.NewAuthenticateParams() - params.Auth = credentials.Payload() - params.Method = credentials.Method() - params.Auth.ConfigTypes = append(params.Auth.ConfigTypes, configTypes...) - - certs := credentials.TlsCerts() - if len(certs) != 0 { - if transport, ok := httpClient.Transport.(*http.Transport); ok { - transport.TLSClientConfig.Certificates = certs - transport.CloseIdleConnections() - } - } - - resp, err := self.Authentication.Authenticate(params, getClientAuthInfoOp(credentials, httpClient)) - - if err != nil { - return nil, err - } - - return &ApiSessionLegacy{Detail: resp.GetPayload().Data, RequestHeaders: credentials.GetRequestHeaders()}, err -} - -func (self *ZitiEdgeClient) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) { - return oidcAuth(self.ClientTransportPool, credentials, configTypeOverrides, httpClient, self.TotpCallback) -} - -func (self *ZitiEdgeClient) SetUseOidc(use bool) { - self.useOidcExplicitlySet = true - self.useOidc = use -} - -func (self *ZitiEdgeClient) SetAllowOidcDynamicallyEnabled(allow bool) { - self.oidcDynamicallyEnabled = allow -} - -func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient *http.Client) (ApiSession, error) { - switch s := apiSession.(type) { - case *ApiSessionLegacy: - params := clientApiSession.NewGetCurrentAPISessionParams() - newApiSessionDetail, err := self.CurrentAPISession.GetCurrentAPISession(params, s) - - if err != nil { - return nil, rest_util.WrapErr(err) - } - - newApiSession := &ApiSessionLegacy{ - Detail: newApiSessionDetail.Payload.Data, - RequestHeaders: apiSession.GetRequestHeaders(), - } - - return newApiSession, nil - case *ApiSessionOidc: - tokens, err := self.ExchangeTokens(s.OidcTokens, httpClient) - - if err != nil { - return nil, err - } - - return &ApiSessionOidc{ - OidcTokens: tokens, - RequestHeaders: apiSession.GetRequestHeaders(), - }, nil - } - - return nil, errors.New("api session does not have any tokens") -} - -func (self *ZitiEdgeClient) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { - return exchangeTokens(self.ClientTransportPool, curTokens, httpClient) -} - -func exchangeTokens(clientTransportPool ClientTransportPool, curTokens *oidc.Tokens[*oidc.IDTokenClaims], client *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { - subjectToken := curTokens.RefreshToken - subjectTokenType := oidc.RefreshTokenType - - // if subjectToken is "", then we don't have a refresh token, attempt to exchange a non-expired access token - if subjectToken == "" { - if curTokens.Expiry.Before(time.Now()) { - return nil, errors.New("cannot exchange token: refresh token not found, access token expired") - } - - if curTokens.AccessToken == "" { - return nil, errors.New("cannot exchange token: refresh token not found, access token not found") - } - subjectToken = curTokens.AccessToken - subjectTokenType = oidc.AccessTokenType - } - - var outTokens *oidc.Tokens[*oidc.IDTokenClaims] - - _, err := clientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) { - timeoutCtx, cancelF := context.WithTimeout(context.Background(), 30*time.Second) - defer cancelF() - - apiHost := transport.ApiUrl.Host - issuer := "https://" + apiHost + "/oidc" - tokenEndpoint := "https://" + apiHost + "/oidc/oauth/token" - - te, err := tokenexchange.NewTokenExchangerClientCredentials(timeoutCtx, issuer, "native", "", tokenexchange.WithHTTPClient(client), tokenexchange.WithStaticTokenEndpoint(issuer, tokenEndpoint)) - - if err != nil { - return nil, err - } - - var tokenResponse *oidc.TokenExchangeResponse - - now := time.Now() - - switch subjectTokenType { - case oidc.RefreshTokenType: - tokenResponse, err = tokenexchange.ExchangeToken(timeoutCtx, te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType) - case oidc.AccessTokenType: - tokenResponse, err = tokenexchange.ExchangeToken(timeoutCtx, te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.AccessTokenType) - } - - if err != nil { - return nil, err - } - - idResp, err := tokenexchange.ExchangeToken(timeoutCtx, te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.IDTokenType) - - if err != nil { - return nil, err - } - - idClaims := &IdClaims{} - - //access token is used to hold id token per zitadel comments - _, _, err = jwt.NewParser().ParseUnverified(idResp.AccessToken, idClaims) - - if err != nil { - return nil, err - } - - outTokens = &oidc.Tokens[*oidc.IDTokenClaims]{ - Token: &oauth2.Token{ - AccessToken: tokenResponse.AccessToken, - TokenType: tokenResponse.TokenType, - RefreshToken: tokenResponse.RefreshToken, - Expiry: now.Add(time.Second * time.Duration(tokenResponse.ExpiresIn)), - }, - IDTokenClaims: &idClaims.IDTokenClaims, - IDToken: idResp.AccessToken, //access token field is used to hold id token per zitadel comments - } - - return outTokens, nil - }) - - if err != nil { - return nil, err - } - - return outTokens, nil -} - -type authPayload struct { - *rest_model.Authenticate - AuthRequestId string `json:"id"` -} - -type totpCodePayload struct { - rest_model.MfaCode - AuthRequestId string `json:"id"` -} - -func (a *authPayload) toValues() url.Values { - result := url.Values{ - "id": []string{a.AuthRequestId}, - "password": []string{string(a.Password)}, - "username": []string{string(a.Username)}, - "configTypes": a.ConfigTypes, - "envArch": []string{a.EnvInfo.Arch}, - "envOs": []string{a.EnvInfo.Os}, - "envOsRelease": []string{a.EnvInfo.OsRelease}, - "envOsVersion": []string{a.EnvInfo.OsVersion}, - "sdkAppID": []string{a.SdkInfo.AppID}, - "sdkAppVersion": []string{a.SdkInfo.AppVersion}, - "sdkBranch": []string{a.SdkInfo.Branch}, - "sdkRevision": []string{a.SdkInfo.Revision}, - "sdkType": []string{a.SdkInfo.Type}, - "sdkVersion": []string{a.SdkInfo.Version}, - } - - return result -} - -func oidcAuth(clientTransportPool ClientTransportPool, credentials Credentials, configTypeOverrides []string, httpClient *http.Client, totpCallback func(chan string)) (ApiSession, error) { - payload := &authPayload{ - Authenticate: credentials.Payload(), - } - method := credentials.Method() - - if configTypeOverrides != nil { - payload.ConfigTypes = configTypeOverrides - } - - certs := credentials.TlsCerts() - - if len(certs) != 0 { - if transport, ok := httpClient.Transport.(*http.Transport); ok { - transport.TLSClientConfig.Certificates = certs - transport.CloseIdleConnections() - } - } - - var outTokens *oidc.Tokens[*oidc.IDTokenClaims] - - _, err := clientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) { - rpServer, err := newLocalRpServer(transport.ApiUrl.Host, method) - - if err != nil { - return nil, err - } - - rpServer.Start() - defer rpServer.Stop() - - client := resty.NewWithClient(httpClient) - apiHost := transport.ApiUrl.Hostname() - - client.SetRedirectPolicy(resty.DomainCheckRedirectPolicy("127.0.0.1", "localhost", apiHost)) - resp, err := client.R().Get(rpServer.LoginUri) - - if err != nil { - return nil, err - } - - if resp.StatusCode() != http.StatusOK { - return nil, fmt.Errorf("local rp login response is expected to be HTTP status %d got %d with body: %s", http.StatusOK, resp.StatusCode(), resp.Body()) - } - payload.AuthRequestId = resp.Header().Get(AuthRequestIdHeader) - - if payload.AuthRequestId == "" { - return nil, errors.New("could not find auth request id header") - } - - opLoginUri := "https://" + resp.RawResponse.Request.URL.Host + "/oidc/login/" + method - totpUri := "https://" + resp.RawResponse.Request.URL.Host + "/oidc/login/totp" - - formData := payload.toValues() - - req := client.R() - clientRequest := asClientRequest(req, client) - - err = credentials.AuthenticateRequest(clientRequest, strfmt.Default) - - if err != nil { - return nil, err - } - - resp, err = req.SetFormDataFromValues(formData).Post(opLoginUri) - - if err != nil { - return nil, err - } - - if resp.StatusCode() != http.StatusOK { - return nil, fmt.Errorf("remote op login response is expected to be HTTP status %d got %d with body: %s", http.StatusOK, resp.StatusCode(), resp.Body()) - } - - authRequestId := payload.AuthRequestId - totpRequiredHeader := resp.Header().Get(TotpRequiredHeader) - totpRequired := totpRequiredHeader != "" - totpCode := "" - - if totpRequired { - - if totpCallback == nil { - return nil, errors.New("totp is required but not totp callback was defined") - } - codeChan := make(chan string) - go totpCallback(codeChan) - - select { - case code := <-codeChan: - totpCode = code - case <-time.After(30 * time.Minute): - return nil, fmt.Errorf("timedout waiting for totpT callback") - } - - resp, err = client.R().SetBody(&totpCodePayload{ - MfaCode: rest_model.MfaCode{ - Code: &totpCode, - }, - AuthRequestId: authRequestId, - }).Post(totpUri) - - if err != nil { - return nil, err - } - - if resp.StatusCode() != http.StatusOK { - apiErr := &errorz.ApiError{} - err = json.Unmarshal(resp.Body(), apiErr) - - if err != nil { - return nil, fmt.Errorf("could not verify TOTP MFA code recieved %d - could not parse body: %s", resp.StatusCode(), string(resp.Body())) - } - - return nil, apiErr - } - } - - var tokens *oidc.Tokens[*oidc.IDTokenClaims] - select { - case tokens = <-rpServer.TokenChan: - case <-time.After(30 * time.Minute): - } - - if tokens == nil { - return nil, errors.New("authentication did not complete, received nil tokens") - } - outTokens = tokens - - return nil, nil - }) - - if err != nil { - return nil, err - } - - return &ApiSessionOidc{ - OidcTokens: outTokens, - RequestHeaders: credentials.GetRequestHeaders(), - }, nil -} - -// restyClientRequest is meant to mimic open api's client request which is a combination -// of resty's request and client. -type restyClientRequest struct { - restyRequest *resty.Request - restyClient *resty.Client -} - -func (r *restyClientRequest) SetHeaderParam(s string, s2 ...string) error { - r.restyRequest.Header[s] = s2 - return nil -} - -func (r *restyClientRequest) GetHeaderParams() http.Header { - return r.restyRequest.Header -} - -func (r *restyClientRequest) SetQueryParam(s string, s2 ...string) error { - r.restyRequest.QueryParam[s] = s2 - return nil -} - -func (r *restyClientRequest) SetFormParam(s string, s2 ...string) error { - r.restyRequest.FormData[s] = s2 - return nil -} - -func (r *restyClientRequest) SetPathParam(s string, s2 string) error { - r.restyRequest.PathParams[s] = s2 - return nil -} - -func (r *restyClientRequest) GetQueryParams() url.Values { - return r.restyRequest.QueryParam -} - -func (r *restyClientRequest) SetFileParam(s string, closer ...runtime.NamedReadCloser) error { - for _, curCloser := range closer { - r.restyRequest.SetFileReader(s, curCloser.Name(), curCloser) - } - - return nil -} - -func (r *restyClientRequest) SetBodyParam(i interface{}) error { - r.restyRequest.SetBody(i) - return nil -} - -func (r *restyClientRequest) SetTimeout(duration time.Duration) error { - r.restyClient.SetTimeout(duration) - return nil -} - -func (r *restyClientRequest) GetMethod() string { - return r.restyRequest.Method -} - -func (r *restyClientRequest) GetPath() string { - return r.restyRequest.URL -} - -func (r *restyClientRequest) GetBody() []byte { - return r.restyRequest.Body.([]byte) -} - -func (r *restyClientRequest) GetBodyParam() interface{} { - return r.restyRequest.Body -} - -func (r *restyClientRequest) GetFileParam() map[string][]runtime.NamedReadCloser { - return nil -} - -func asClientRequest(request *resty.Request, client *resty.Client) runtime.ClientRequest { - return &restyClientRequest{request, client} -} diff --git a/edge-apis/client_base.go b/edge-apis/client_base.go new file mode 100644 index 00000000..27f90bae --- /dev/null +++ b/edge-apis/client_base.go @@ -0,0 +1,245 @@ +package edge_apis + +import ( + "net/http" + "net/url" + "strings" + "sync/atomic" + + "github.com/go-openapi/runtime" + openapiclient "github.com/go-openapi/runtime/client" + "github.com/go-openapi/strfmt" + "github.com/michaelquigley/pfxlog" + "github.com/openziti/edge-api/rest_model" +) + +const ( + AuthRequestIdHeader = "auth-request-id" + TotpRequiredHeader = "totp-required" +) + +// AuthEnabledApi is a sentinel interface that detects APIs supporting authentication. +// It provides methods for authenticating, managing sessions, and discovering controllers for high-availability. +type AuthEnabledApi interface { + // Authenticate authenticates using the provided credentials and returns an ApiSession for subsequent authenticated requests. + Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) + // SetUseOidc forces OIDC mode (true) or legacy mode (false). + SetUseOidc(bool) + // ListControllers returns the list of available controllers for HA failover. + ListControllers() (*rest_model.ControllersList, error) + // GetClientTransportPool returns the transport pool managing multiple controller endpoints. + GetClientTransportPool() ClientTransportPool + // SetClientTransportPool sets the transport pool. + SetClientTransportPool(ClientTransportPool) + // RefreshApiSession refreshes an existing session. + RefreshApiSession(apiSession ApiSession, httpClient *http.Client) (ApiSession, error) +} + +// BaseClient provides shared authentication and session management for OpenZiti API clients. +// It handles credential-based authentication, TLS configuration, session storage, and controller failover. +type BaseClient[A ApiType] struct { + API *A + AuthEnabledApi AuthEnabledApi + Components + AuthInfoWriter runtime.ClientAuthInfoWriter + ApiSession atomic.Pointer[ApiSession] + Credentials Credentials + ApiUrls []*url.URL + ApiBinding string + ApiVersion string + Schemes []string + onControllerListeners []func([]*url.URL) +} + +// Url returns the URL of the currently active controller endpoint. +func (self *BaseClient[A]) Url() url.URL { + return *self.AuthEnabledApi.GetClientTransportPool().GetActiveTransport().ApiUrl +} + +// AddOnControllerUpdateListeners registers a callback that is invoked when the list of +// available controller endpoints changes. +func (self *BaseClient[A]) AddOnControllerUpdateListeners(listener func([]*url.URL)) { + self.onControllerListeners = append(self.onControllerListeners, listener) +} + +// GetCurrentApiSession returns the ApiSession that is being used to authenticate requests. +func (self *BaseClient[A]) GetCurrentApiSession() ApiSession { + ptr := self.ApiSession.Load() + if ptr == nil { + return nil + } + + return *ptr +} + +// SetUseOidc forces the API client to operate in OIDC mode when true, or legacy mode when false. +func (self *BaseClient[A]) SetUseOidc(use bool) { + v := any(self.API) + apiType := v.(OidcEnabledApi) + apiType.SetUseOidc(use) +} + +// SetAllowOidcDynamicallyEnabled configures whether the client checks the controller for +// OIDC support and switches modes accordingly. +func (self *BaseClient[A]) SetAllowOidcDynamicallyEnabled(allow bool) { + v := any(self.API) + apiType := v.(OidcEnabledApi) + apiType.SetAllowOidcDynamicallyEnabled(allow) +} + +func (self *BaseClient[A]) SetOidcRedirectUri(redirectUri string) { + v := any(self.API) + apiType := v.(OidcEnabledApi) + apiType.SetOidcRedirectUri(redirectUri) +} + +// Authenticate authenticates using provided credentials, updating the TLS configuration based on the credential's CA pool. +// On success, stores the session and processes controller endpoints for HA failover. +// On failure, clears the session and credentials. +func (self *BaseClient[A]) Authenticate(credentials Credentials, configTypesOverride []string) (ApiSession, error) { + self.Credentials = nil + self.ApiSession.Store(nil) + + tlsClientConfig := self.TlsAwareTransport.GetTlsClientConfig() + + if credCaPool := credentials.GetCaPool(); credCaPool != nil { + tlsClientConfig.RootCAs = credCaPool + } else { + tlsClientConfig.RootCAs = self.CaPool + } + + apiSession, err := self.AuthEnabledApi.Authenticate(credentials, configTypesOverride, self.HttpClient) + + if err != nil { + return nil, err + } + + self.Credentials = credentials + self.ApiSession.Store(&apiSession) + + self.ProcessControllers(self.AuthEnabledApi) + + return apiSession, nil +} + +func (self *BaseClient[A]) AuthenticateWithPreviousSession(credentials Credentials, prevApiSession ApiSession) (ApiSession, error) { + self.Credentials = nil + self.ApiSession.Store(nil) + + tlsClientConfig := self.TlsAwareTransport.GetTlsClientConfig() + if credCaPool := credentials.GetCaPool(); credCaPool != nil { + tlsClientConfig.RootCAs = credCaPool + } else { + tlsClientConfig.RootCAs = self.CaPool + } + + refreshedSession, refreshErr := self.AuthEnabledApi.RefreshApiSession(prevApiSession, self.HttpClient) + + if refreshErr != nil { + return nil, refreshErr + } + + self.Credentials = credentials + self.ApiSession.Store(&refreshedSession) + + self.ProcessControllers(self.AuthEnabledApi) + + return refreshedSession, nil +} + +// initializeComponents assembles HTTP client infrastructure, either using provided Components or creating new ones. +// If Components are provided with nil transport/client, they are initialized with warnings logged. +func (self *BaseClient[A]) initializeComponents(config *ApiClientConfig) { + //have a config and either the client or transport are set, verify them, else an empty components was supplied + // then initialize them with defaults + if config.Components != nil && (config.Components.HttpClient != nil || config.Components.TlsAwareTransport != nil) { + config.Components.assertComponents(config) + self.Components = *config.Components + + if config.Proxy != nil { + pfxlog.Logger().Warn("components were provided along with a proxy function on the ApiClientConfig, it is being ignored, if needed properly set on components") + } + return + } + + components := NewComponentsWithConfig(&ComponentsConfig{ + Proxy: config.Proxy, + }) + + tlsClientConfig := components.TlsAwareTransport.GetTlsClientConfig() + tlsClientConfig.RootCAs = config.CaPool + components.CaPool = config.CaPool + + self.Components = *components +} + +// NewRuntime creates an OpenAPI runtime for communicating with a controller endpoint. Used for HA failover to add multiple controller endpoints. +func NewRuntime(apiUrl *url.URL, schemes []string, httpClient *http.Client) *openapiclient.Runtime { + return openapiclient.NewWithClient(apiUrl.Host, apiUrl.Path, schemes, httpClient) +} + +// AuthenticateRequest authenticates outgoing API requests using the current session or credentials. +// It implements the openapi runtime.ClientAuthInfoWriter interface. +func (self *BaseClient[A]) AuthenticateRequest(request runtime.ClientRequest, registry strfmt.Registry) error { + if self.AuthInfoWriter != nil { + return self.AuthInfoWriter.AuthenticateRequest(request, registry) + } + + // do not add auth to authenticating endpoints + if strings.Contains(request.GetPath(), "/oidc/auth") || strings.Contains(request.GetPath(), "/authenticate") { + return nil + } + + currentSessionPtr := self.ApiSession.Load() + if currentSessionPtr != nil { + currentSession := *currentSessionPtr + + if currentSession != nil && currentSession.GetToken() != nil { + if err := currentSession.AuthenticateRequest(request, registry); err != nil { + return err + } + } + } + + if self.Credentials != nil { + if err := self.Credentials.AuthenticateRequest(request, registry); err != nil { + return err + } + } + + return nil +} + +// ProcessControllers discovers peer controllers and registers them for HA failover. Called after successful authentication. +func (self *BaseClient[A]) ProcessControllers(authEnabledApi AuthEnabledApi) { + list, err := authEnabledApi.ListControllers() + + if err != nil { + pfxlog.Logger().WithError(err).Debug("error listing controllers, continuing with 1 default configured controller") + return + } + + if list == nil || len(*list) <= 1 { + pfxlog.Logger().Debug("no additional controllers reported, continuing with 1 default configured controller") + return + } + + //look for matching api binding and versions + for _, controller := range *list { + apis := controller.APIAddresses[self.ApiBinding] + + for _, apiAddr := range apis { + if apiAddr.Version == self.ApiVersion { + apiUrl, parseErr := url.Parse(apiAddr.URL) + if parseErr == nil { + self.AuthEnabledApi.GetClientTransportPool().Add(apiUrl, NewRuntime(apiUrl, self.Schemes, self.HttpClient)) + } + } + } + } + + apis := self.AuthEnabledApi.GetClientTransportPool().GetApiUrls() + for _, listener := range self.onControllerListeners { + listener(apis) + } +} diff --git a/edge-apis/client_edge_client.go b/edge-apis/client_edge_client.go new file mode 100644 index 00000000..ed80c5c3 --- /dev/null +++ b/edge-apis/client_edge_client.go @@ -0,0 +1,273 @@ +package edge_apis + +import ( + "crypto/x509" + "errors" + "fmt" + "net/http" + "net/url" + "sync" + + "github.com/openziti/edge-api/rest_client_api_client" + clientAuth "github.com/openziti/edge-api/rest_client_api_client/authentication" + clientControllers "github.com/openziti/edge-api/rest_client_api_client/controllers" + clientApiSession "github.com/openziti/edge-api/rest_client_api_client/current_api_session" + clientInfo "github.com/openziti/edge-api/rest_client_api_client/informational" + "github.com/openziti/edge-api/rest_model" + "github.com/openziti/edge-api/rest_util" + "github.com/openziti/foundation/v2/stringz" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +var _ OidcEnabledApi = (*ZitiEdgeClient)(nil) +var _ AuthEnabledApi = (*ZitiEdgeClient)(nil) + +// ClientApiClient provides access to the Ziti Edge Client API for identity operations. +type ClientApiClient struct { + BaseClient[ZitiEdgeClient] +} + +// NewClientApiClient will assemble a ClientApiClient. The apiUrl should be the full URL +// to the Edge Client API (e.g. `https://example.com/edge/client/v1`). +// +// The `caPool` argument should be a list of trusted root CAs. If provided as `nil` here unauthenticated requests +// will use the system certificate pool. If authentication occurs, and a certificate pool is set on the Credentials +// the certificate pool from the Credentials will be used from that point forward. Credentials implementations +// based on an identity.Identity are likely to provide a certificate pool. +// +// For OpenZiti instances not using publicly signed certificates, `ziti.GetControllerWellKnownCaPool()` can be used +// to obtain and verify the target controllers CAs. Tools should allow users to verify and accept new controllers +// that have not been verified from an outside secret (such as an enrollment token). +func NewClientApiClient(apiUrls []*url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ClientApiClient { + return NewClientApiClientWithConfig(&ApiClientConfig{ + ApiUrls: apiUrls, + CaPool: caPool, + TotpCodeProvider: NewTotpCodeProviderFromChStringFunc(totpCallback), + Proxy: http.ProxyFromEnvironment, + }) +} + +// NewClientApiClientWithConfig creates a Client API client using the provided configuration. +func NewClientApiClientWithConfig(config *ApiClientConfig) *ClientApiClient { + ret := &ClientApiClient{} + ret.ApiBinding = "edge-client" + ret.ApiVersion = "v1" + ret.Schemes = rest_client_api_client.DefaultSchemes + ret.ApiUrls = config.ApiUrls + + ret.initializeComponents(config) + + transportPool := NewClientTransportPoolRandom() + + for _, apiUrl := range config.ApiUrls { + newRuntime := NewRuntime(apiUrl, ret.Schemes, ret.HttpClient) + newRuntime.DefaultAuthentication = ret + transportPool.Add(apiUrl, newRuntime) + } + + newApi := rest_client_api_client.New(transportPool, nil) + api := ZitiEdgeClient{ + ZitiEdgeClient: newApi, + TotpCodeProvider: config.TotpCodeProvider, + ClientTransportPool: transportPool, + } + ret.API = &api + ret.AuthEnabledApi = &api + + api.doOnceCacheVersionInfo() + + return ret +} + +var _ AuthEnabledApi = (*ZitiEdgeClient)(nil) + +// ZitiEdgeClient is an alias of the go-swagger generated client that allows this package to add additional +// functionality to the alias type to implement the AuthEnabledApi interface. +type ZitiEdgeClient struct { + *rest_client_api_client.ZitiEdgeClient + // useOidc tracks if OIDC auth should be used + useOidc bool + + // useOidcExplicitlySet signals if useOidc was set from an external caller and should be used as is + useOidcExplicitlySet bool + + // oidcDynamicallyEnabled will cause the client to check the controller for OIDC support and use if possible as long as useOidc was not explicitly set. + oidcDynamicallyEnabled bool //currently defaults false till HA release + + versionInfo *rest_model.Version + versionOnce sync.Once + + TotpCodeProvider TotpCodeProvider + ClientTransportPool ClientTransportPool + OidcRedirectUri string +} + +func (self *ZitiEdgeClient) SetOidcRedirectUri(redirectUri string) { + self.OidcRedirectUri = redirectUri +} + +// GetClientTransportPool returns the transport pool managing multiple controller endpoints for failover. +func (self *ZitiEdgeClient) GetClientTransportPool() ClientTransportPool { + return self.ClientTransportPool +} + +// SetClientTransportPool sets the transport pool. +func (self *ZitiEdgeClient) SetClientTransportPool(transportPool ClientTransportPool) { + self.ClientTransportPool = transportPool +} + +// ListControllers returns the list of available controllers for high-availability failover. +func (self *ZitiEdgeClient) ListControllers() (*rest_model.ControllersList, error) { + params := clientControllers.NewListControllersParams() + resp, err := self.Controllers.ListControllers(params, nil) + if err != nil { + return nil, err + } + + return &resp.GetPayload().Data, nil +} + +func (self *ZitiEdgeClient) Authenticate(credentials Credentials, configTypesOverrides []string, httpClient *http.Client) (ApiSession, error) { + self.doOnceCacheVersionInfo() + useOidc := false + + if self.useOidcExplicitlySet { + useOidc = self.useOidc + } else if self.oidcDynamicallyEnabled { + useOidc = self.ControllerSupportsOidc() + } + + if useOidc { + return self.oidcAuth(credentials, configTypesOverrides, httpClient) + } + + return self.legacyAuth(credentials, configTypesOverrides, httpClient) +} + +// legacyAuth performs zt-session token based authentication. +func (self *ZitiEdgeClient) legacyAuth(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { + params := clientAuth.NewAuthenticateParams() + params.Auth = credentials.Payload() + params.Method = string(credentials.Method()) + params.Auth.ConfigTypes = append(params.Auth.ConfigTypes, configTypes...) + + if credentials.Method() == AuthMethodEmpty { + return nil, fmt.Errorf("auth method %s cannot be used for authentication, please provide alternate credentials", AuthMethodEmpty) + } + + certs := credentials.TlsCerts() + if len(certs) != 0 { + if transport, ok := httpClient.Transport.(TlsAwareTransport); ok { + tlsClientConf := transport.GetTlsClientConfig() + tlsClientConf.Certificates = certs + transport.CloseIdleConnections() + } + } + + resp, err := self.Authentication.Authenticate(params, getClientAuthInfoOp(credentials, httpClient)) + + if err != nil { + return nil, err + } + + return &ApiSessionLegacy{Detail: resp.GetPayload().Data, RequestHeaders: credentials.GetRequestHeaders()}, err +} + +// oidcAuth performs OIDC OAuth flow based authentication. +func (self *ZitiEdgeClient) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) { + config := &EdgeOidcAuthConfig{ + ClientTransportPool: self.ClientTransportPool, + Credentials: credentials, + ConfigTypeOverrides: configTypeOverrides, + HttpClient: httpClient, + TotpCodeProvider: self.TotpCodeProvider, + RedirectUri: self.OidcRedirectUri, + } + + return oidcAuth(config) +} + +// SetUseOidc forces OIDC mode (true) or legacy mode (false), overriding automatic detection. +func (self *ZitiEdgeClient) SetUseOidc(use bool) { + self.useOidcExplicitlySet = true + self.useOidc = use +} + +// SetAllowOidcDynamicallyEnabled enables automatic OIDC capability detection on the controller. +func (self *ZitiEdgeClient) SetAllowOidcDynamicallyEnabled(allow bool) { + self.oidcDynamicallyEnabled = allow +} + +// RefreshApiSession refreshes an existing API session (both legacy and OIDC types). +func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient *http.Client) (ApiSession, error) { + switch s := apiSession.(type) { + case *ApiSessionLegacy: + params := clientApiSession.NewGetCurrentAPISessionParams() + newApiSessionDetail, err := self.CurrentAPISession.GetCurrentAPISession(params, s) + + if err != nil { + return nil, rest_util.WrapErr(err) + } + + newApiSession := &ApiSessionLegacy{ + Detail: newApiSessionDetail.Payload.Data, + RequestHeaders: apiSession.GetRequestHeaders(), + } + + return newApiSession, nil + case *ApiSessionOidc: + tokens, err := self.ExchangeTokens(s.OidcTokens, httpClient) + + if err != nil { + return nil, err + } + + return &ApiSessionOidc{ + OidcTokens: tokens, + RequestHeaders: apiSession.GetRequestHeaders(), + }, nil + } + + return nil, errors.New("api session is an unknown type") +} + +// ExchangeTokens exchanges OIDC tokens for refreshed tokens. +func (self *ZitiEdgeClient) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { + return exchangeTokens(self.ClientTransportPool, curTokens, httpClient) +} + +// ControllerSupportsHa checks if the controller supports high-availability by inspecting its capabilities. +func (self *ZitiEdgeClient) ControllerSupportsHa() bool { + self.doOnceCacheVersionInfo() + + if self.versionInfo != nil && self.versionInfo.Capabilities != nil { + return stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesHACONTROLLER)) + } + + return false +} + +// ControllerSupportsOidc checks if the controller supports OIDC authentication by inspecting its capabilities. +func (self *ZitiEdgeClient) ControllerSupportsOidc() bool { + self.doOnceCacheVersionInfo() + + if self.versionInfo != nil && self.versionInfo.Capabilities != nil { + return stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH)) + } + + return false +} + +// doOnceCacheVersionInfo caches the controller version information including capabilities on first call. +// Subsequent calls are no-ops due to sync.Once synchronization. +func (self *ZitiEdgeClient) doOnceCacheVersionInfo() { + self.versionOnce.Do(func() { + versionParams := clientInfo.NewListVersionParams() + + versionResp, _ := self.Informational.ListVersion(versionParams) + + if versionResp != nil { + self.versionInfo = versionResp.Payload.Data + } + }) +} diff --git a/edge-apis/client_edge_management.go b/edge-apis/client_edge_management.go new file mode 100644 index 00000000..08439057 --- /dev/null +++ b/edge-apis/client_edge_management.go @@ -0,0 +1,273 @@ +package edge_apis + +import ( + "crypto/x509" + "errors" + "fmt" + "net/http" + "net/url" + "sync" + + "github.com/openziti/edge-api/rest_management_api_client" + manAuth "github.com/openziti/edge-api/rest_management_api_client/authentication" + manControllers "github.com/openziti/edge-api/rest_management_api_client/controllers" + manCurApiSession "github.com/openziti/edge-api/rest_management_api_client/current_api_session" + manInfo "github.com/openziti/edge-api/rest_management_api_client/informational" + "github.com/openziti/edge-api/rest_model" + "github.com/openziti/edge-api/rest_util" + "github.com/openziti/foundation/v2/stringz" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// ManagementApiClient provides the ability to authenticate and interact with the Edge Management API. +type ManagementApiClient struct { + BaseClient[ZitiEdgeManagement] +} + +// NewManagementApiClient will assemble an ManagementApiClient. The apiUrl should be the full URL +// to the Edge Management API (e.g. `https://example.com/edge/management/v1`). +// +// The `caPool` argument should be a list of trusted root CAs. If provided as `nil` here unauthenticated requests +// will use the system certificate pool. If authentication occurs, and a certificate pool is set on the Credentials +// the certificate pool from the Credentials will be used from that point forward. Credentials implementations +// based on an identity.Identity are likely to provide a certificate pool. +// +// For OpenZiti instances not using publicly signed certificates, `ziti.GetControllerWellKnownCaPool()` can be used +// to obtain and verify the target controllers CAs. Tools should allow users to verify and accept new controllers +// that have not been verified from an outside secret (such as an enrollment token). +func NewManagementApiClient(apiUrls []*url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ManagementApiClient { + return NewManagementApiClientWithConfig(&ApiClientConfig{ + ApiUrls: apiUrls, + CaPool: caPool, + TotpCodeProvider: NewTotpCodeProviderFromChStringFunc(totpCallback), + Proxy: http.ProxyFromEnvironment, + }) +} + +// NewManagementApiClientWithConfig creates a Management API client using the provided configuration. +func NewManagementApiClientWithConfig(config *ApiClientConfig) *ManagementApiClient { + ret := &ManagementApiClient{} + ret.Schemes = rest_management_api_client.DefaultSchemes + ret.ApiBinding = "edge-management" + ret.ApiVersion = "v1" + ret.ApiUrls = config.ApiUrls + + ret.initializeComponents(config) + + transportPool := NewClientTransportPoolRandom() + + for _, apiUrl := range config.ApiUrls { + newRuntime := NewRuntime(apiUrl, ret.Schemes, ret.HttpClient) + newRuntime.DefaultAuthentication = ret + transportPool.Add(apiUrl, newRuntime) + } + + newApi := rest_management_api_client.New(transportPool, nil) + api := ZitiEdgeManagement{ + ZitiEdgeManagement: newApi, + TotpCodeProvider: config.TotpCodeProvider, + ClientTransportPool: transportPool, + } + + ret.API = &api + ret.AuthEnabledApi = &api + + api.doOnceCacheVersionInfo() + + return ret +} + +var _ AuthEnabledApi = (*ZitiEdgeManagement)(nil) +var _ OidcEnabledApi = (*ZitiEdgeManagement)(nil) + +// ZitiEdgeManagement is an alias of the go-swagger generated client that allows this package to add additional +// functionality to the alias type to implement the AuthEnabledApi interface. +type ZitiEdgeManagement struct { + *rest_management_api_client.ZitiEdgeManagement + // useOidc tracks if OIDC auth should be used + useOidc bool + + // useOidcExplicitlySet signals if useOidc was set from an external caller and should be used as is + useOidcExplicitlySet bool + + // oidcDynamicallyEnabled will cause the client to check the controller for OIDC support and use if possible as long as useOidc was not explicitly set + oidcDynamicallyEnabled bool //currently defaults false till HA release + + versionOnce sync.Once + versionInfo *rest_model.Version + + TotpCodeProvider TotpCodeProvider + ClientTransportPool ClientTransportPool + OidcRedirectUri string +} + +func (self *ZitiEdgeManagement) SetOidcRedirectUri(redirectUri string) { + self.OidcRedirectUri = redirectUri +} + +// SetClientTransportPool sets the transport pool. +func (self *ZitiEdgeManagement) SetClientTransportPool(transportPool ClientTransportPool) { + self.ClientTransportPool = transportPool +} + +// GetClientTransportPool returns the transport pool managing multiple controller endpoints for failover. +func (self *ZitiEdgeManagement) GetClientTransportPool() ClientTransportPool { + return self.ClientTransportPool +} + +// ListControllers returns the list of available controllers for high-availability failover. +func (self *ZitiEdgeManagement) ListControllers() (*rest_model.ControllersList, error) { + params := manControllers.NewListControllersParams() + resp, err := self.Controllers.ListControllers(params, nil) + if err != nil { + return nil, err + } + + return &resp.GetPayload().Data, nil +} + +func (self *ZitiEdgeManagement) Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { + self.doOnceCacheVersionInfo() + useOidc := false + + if self.useOidcExplicitlySet { + useOidc = self.useOidc + } else if self.oidcDynamicallyEnabled { + useOidc = self.ControllerSupportsOidc() + } + + if useOidc { + return self.oidcAuth(credentials, configTypes, httpClient) + } + + return self.legacyAuth(credentials, configTypes, httpClient) +} + +// legacyAuth performs zt-session token based authentication. +func (self *ZitiEdgeManagement) legacyAuth(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { + params := manAuth.NewAuthenticateParams() + params.Auth = credentials.Payload() + params.Method = string(credentials.Method()) + params.Auth.ConfigTypes = append(params.Auth.ConfigTypes, configTypes...) + + if credentials.Method() == AuthMethodEmpty { + return nil, fmt.Errorf("auth method %s cannot be used for authentication, please provide alternate credentials", AuthMethodEmpty) + } + + certs := credentials.TlsCerts() + if len(certs) != 0 { + if transport, ok := httpClient.Transport.(TlsAwareTransport); ok { + tlsClientConf := transport.GetTlsClientConfig() + tlsClientConf.Certificates = certs + transport.CloseIdleConnections() + } + } + + resp, err := self.Authentication.Authenticate(params, getClientAuthInfoOp(credentials, httpClient)) + + if err != nil { + return nil, err + } + + return &ApiSessionLegacy{ + Detail: resp.GetPayload().Data, + RequestHeaders: credentials.GetRequestHeaders()}, err +} + +// oidcAuth performs OIDC OAuth flow based authentication. +func (self *ZitiEdgeManagement) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) { + config := &EdgeOidcAuthConfig{ + ClientTransportPool: self.ClientTransportPool, + Credentials: credentials, + ConfigTypeOverrides: configTypeOverrides, + HttpClient: httpClient, + TotpCodeProvider: self.TotpCodeProvider, + RedirectUri: self.OidcRedirectUri, + } + return oidcAuth(config) +} + +// SetUseOidc forces OIDC mode (true) or legacy mode (false), overriding automatic detection. +func (self *ZitiEdgeManagement) SetUseOidc(use bool) { + self.useOidcExplicitlySet = true + self.useOidc = use +} + +// SetAllowOidcDynamicallyEnabled enables automatic OIDC capability detection on the controller. +func (self *ZitiEdgeManagement) SetAllowOidcDynamicallyEnabled(allow bool) { + self.oidcDynamicallyEnabled = allow +} + +// RefreshApiSession refreshes an existing API session (both legacy and OIDC types). +func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession, httpClient *http.Client) (ApiSession, error) { + switch s := apiSession.(type) { + case *ApiSessionLegacy: + params := manCurApiSession.NewGetCurrentAPISessionParams() + newApiSessionDetail, err := self.CurrentAPISession.GetCurrentAPISession(params, s) + + if err != nil { + return nil, rest_util.WrapErr(err) + } + + newApiSession := &ApiSessionLegacy{ + Detail: newApiSessionDetail.Payload.Data, + RequestHeaders: apiSession.GetRequestHeaders(), + } + + return newApiSession, nil + case *ApiSessionOidc: + tokens, err := self.ExchangeTokens(s.OidcTokens, httpClient) + + if err != nil { + return nil, err + } + + return &ApiSessionOidc{ + OidcTokens: tokens, + RequestHeaders: apiSession.GetRequestHeaders(), + }, nil + } + + return nil, errors.New("api session is an unknown type") +} + +// ExchangeTokens exchanges OIDC tokens for refreshed tokens. +func (self *ZitiEdgeManagement) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { + return exchangeTokens(self.ClientTransportPool, curTokens, httpClient) +} + +// ControllerSupportsHa checks if the controller supports high-availability by inspecting its capabilities. +func (self *ZitiEdgeManagement) ControllerSupportsHa() bool { + self.doOnceCacheVersionInfo() + + if self.versionInfo != nil && self.versionInfo.Capabilities != nil { + return stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesHACONTROLLER)) + } + + return false +} + +// ControllerSupportsOidc checks if the controller supports OIDC authentication by inspecting its capabilities. +func (self *ZitiEdgeManagement) ControllerSupportsOidc() bool { + self.doOnceCacheVersionInfo() + + if self.versionInfo != nil && self.versionInfo.Capabilities != nil { + return stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH)) + } + + return false +} + +// doOnceCacheVersionInfo caches the controller version information including capabilities on first call. +// Subsequent calls are no-ops due to sync.Once synchronization. +func (self *ZitiEdgeManagement) doOnceCacheVersionInfo() { + self.versionOnce.Do(func() { + versionParams := manInfo.NewListVersionParams() + + versionResp, _ := self.Informational.ListVersion(versionParams) + + if versionResp != nil { + self.versionInfo = versionResp.Payload.Data + } + }) +} diff --git a/edge-apis/clients.go b/edge-apis/clients.go deleted file mode 100644 index 385f6c6f..00000000 --- a/edge-apis/clients.go +++ /dev/null @@ -1,331 +0,0 @@ -/* - Copyright 2019 NetFoundry Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package edge_apis - -import ( - "crypto/x509" - "net/http" - "net/url" - "strings" - "sync/atomic" - - "github.com/go-openapi/runtime" - openapiclient "github.com/go-openapi/runtime/client" - "github.com/go-openapi/strfmt" - "github.com/michaelquigley/pfxlog" - "github.com/openziti/edge-api/rest_client_api_client" - "github.com/openziti/edge-api/rest_management_api_client" -) - -// ApiType is an interface constraint for generics. The underlying go-swagger types only have fields, which are -// insufficient to attempt to make a generic type from. Instead, this constraint is used that points at the -// aliased types. -type ApiType interface { - ZitiEdgeManagement | ZitiEdgeClient -} - -type OidcEnabledApi interface { - // SetUseOidc forces an API Client to operate in OIDC mode (true) or legacy mode (false). The state of the controller - // is ignored and dynamic enable/disable of OIDC support is suspended. - SetUseOidc(use bool) - - // SetAllowOidcDynamicallyEnabled sets whether clients will check the controller for OIDC support or not. If supported - // OIDC is favored over legacy authentication. - SetAllowOidcDynamicallyEnabled(allow bool) -} - -// BaseClient implements the Client interface specifically for the types specified in the ApiType constraint. It -// provides shared functionality that all ApiType types require. -type BaseClient[A ApiType] struct { - API *A - AuthEnabledApi AuthEnabledApi - Components - AuthInfoWriter runtime.ClientAuthInfoWriter - ApiSession atomic.Pointer[ApiSession] - Credentials Credentials - ApiUrls []*url.URL - ApiBinding string - ApiVersion string - Schemes []string - onControllerListeners []func([]*url.URL) -} - -// Url returns the URL of the currently active controller endpoint. -func (self *BaseClient[A]) Url() url.URL { - return *self.AuthEnabledApi.GetClientTransportPool().GetActiveTransport().ApiUrl -} - -// AddOnControllerUpdateListeners registers a callback that is invoked when the list of -// available controller endpoints changes. -func (self *BaseClient[A]) AddOnControllerUpdateListeners(listener func([]*url.URL)) { - self.onControllerListeners = append(self.onControllerListeners, listener) -} - -// GetCurrentApiSession returns the ApiSession that is being used to authenticate requests. -func (self *BaseClient[A]) GetCurrentApiSession() ApiSession { - ptr := self.ApiSession.Load() - if ptr == nil { - return nil - } - - return *ptr -} - -// SetUseOidc forces the API client to operate in OIDC mode when true, or legacy mode when false. -func (self *BaseClient[A]) SetUseOidc(use bool) { - v := any(self.API) - apiType := v.(OidcEnabledApi) - apiType.SetUseOidc(use) -} - -// SetAllowOidcDynamicallyEnabled configures whether the client checks the controller for -// OIDC support and switches modes accordingly. -func (self *BaseClient[A]) SetAllowOidcDynamicallyEnabled(allow bool) { - v := any(self.API) - apiType := v.(OidcEnabledApi) - apiType.SetAllowOidcDynamicallyEnabled(allow) -} - -// Authenticate will attempt to use the provided credentials to authenticate via the underlying ApiType. On success -// the API Session details will be returned and the current client will make authenticated requests on future -// calls. On an error the API Session in use will be cleared and subsequent requests will become/continue to be -// made in an unauthenticated fashion. -func (self *BaseClient[A]) Authenticate(credentials Credentials, configTypesOverride []string) (ApiSession, error) { - - self.Credentials = nil - self.ApiSession.Store(nil) - - if credCaPool := credentials.GetCaPool(); credCaPool != nil { - self.HttpTransport.TLSClientConfig.RootCAs = credCaPool - } else { - self.HttpTransport.TLSClientConfig.RootCAs = self.CaPool - } - - apiSession, err := self.AuthEnabledApi.Authenticate(credentials, configTypesOverride, self.HttpClient) - - if err != nil { - return nil, err - } - - self.Credentials = credentials - self.ApiSession.Store(&apiSession) - - self.ProcessControllers(self.AuthEnabledApi) - - return apiSession, nil -} - -// initializeComponents assembles the lower level components necessary for the go-swagger/openapi facilities. -func (self *BaseClient[A]) initializeComponents(config *ApiClientConfig) { - components := NewComponentsWithConfig(&ComponentsConfig{ - Proxy: config.Proxy, - }) - components.HttpTransport.TLSClientConfig.RootCAs = config.CaPool - components.CaPool = config.CaPool - - self.Components = *components -} - -// NewRuntime creates an OpenAPI runtime configured for the specified API endpoint. -func NewRuntime(apiUrl *url.URL, schemes []string, httpClient *http.Client) *openapiclient.Runtime { - return openapiclient.NewWithClient(apiUrl.Host, apiUrl.Path, schemes, httpClient) -} - -// AuthenticateRequest implements the openapi runtime.ClientAuthInfoWriter interface from the OpenAPI libraries. It is used -// to authenticate outgoing requests. -func (self *BaseClient[A]) AuthenticateRequest(request runtime.ClientRequest, registry strfmt.Registry) error { - if self.AuthInfoWriter != nil { - return self.AuthInfoWriter.AuthenticateRequest(request, registry) - } - - // do not add auth to authenticating endpoints - if strings.Contains(request.GetPath(), "/oidc/auth") || strings.Contains(request.GetPath(), "/authenticate") { - return nil - } - - currentSessionPtr := self.ApiSession.Load() - if currentSessionPtr != nil { - currentSession := *currentSessionPtr - - if currentSession != nil && currentSession.GetToken() != nil { - if err := currentSession.AuthenticateRequest(request, registry); err != nil { - return err - } - } - } - - if self.Credentials != nil { - if err := self.Credentials.AuthenticateRequest(request, registry); err != nil { - return err - } - } - - return nil -} - -// ProcessControllers queries the authenticated controller for its list of peer controllers -// and registers them for high-availability failover. -func (self *BaseClient[A]) ProcessControllers(authEnabledApi AuthEnabledApi) { - list, err := authEnabledApi.ListControllers() - - if err != nil { - pfxlog.Logger().WithError(err).Debug("error listing controllers, continuing with 1 default configured controller") - return - } - - if list == nil || len(*list) <= 1 { - pfxlog.Logger().Debug("no additional controllers reported, continuing with 1 default configured controller") - return - } - - //look for matching api binding and versions - for _, controller := range *list { - apis := controller.APIAddresses[self.ApiBinding] - - for _, apiAddr := range apis { - if apiAddr.Version == self.ApiVersion { - apiUrl, parseErr := url.Parse(apiAddr.URL) - if parseErr == nil { - self.AuthEnabledApi.GetClientTransportPool().Add(apiUrl, NewRuntime(apiUrl, self.Schemes, self.HttpClient)) - } - } - } - } - - apis := self.AuthEnabledApi.GetClientTransportPool().GetApiUrls() - for _, listener := range self.onControllerListeners { - listener(apis) - } -} - -// ManagementApiClient provides the ability to authenticate and interact with the Edge Management API. -type ManagementApiClient struct { - BaseClient[ZitiEdgeManagement] -} - -// ApiClientConfig contains configuration options for creating API clients. -type ApiClientConfig struct { - ApiUrls []*url.URL - CaPool *x509.CertPool - TotpCallback func(chan string) - Proxy func(r *http.Request) (*url.URL, error) -} - -// NewManagementApiClient will assemble an ManagementApiClient. The apiUrl should be the full URL -// to the Edge Management API (e.g. `https://example.com/edge/management/v1`). -// -// The `caPool` argument should be a list of trusted root CAs. If provided as `nil` here unauthenticated requests -// will use the system certificate pool. If authentication occurs, and a certificate pool is set on the Credentials -// the certificate pool from the Credentials will be used from that point forward. Credentials implementations -// based on an identity.Identity are likely to provide a certificate pool. -// -// For OpenZiti instances not using publicly signed certificates, `ziti.GetControllerWellKnownCaPool()` can be used -// to obtain and verify the target controllers CAs. Tools should allow users to verify and accept new controllers -// that have not been verified from an outside secret (such as an enrollment token). -func NewManagementApiClient(apiUrls []*url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ManagementApiClient { - return NewManagementApiClientWithConfig(&ApiClientConfig{ - ApiUrls: apiUrls, - CaPool: caPool, - TotpCallback: totpCallback, - Proxy: http.ProxyFromEnvironment, - }) -} - -// NewManagementApiClientWithConfig creates a Management API client using the provided configuration. -func NewManagementApiClientWithConfig(config *ApiClientConfig) *ManagementApiClient { - ret := &ManagementApiClient{} - ret.Schemes = rest_management_api_client.DefaultSchemes - ret.ApiBinding = "edge-management" - ret.ApiVersion = "v1" - ret.ApiUrls = config.ApiUrls - ret.initializeComponents(config) - - transportPool := NewClientTransportPoolRandom() - - for _, apiUrl := range config.ApiUrls { - newRuntime := NewRuntime(apiUrl, ret.Schemes, ret.HttpClient) - newRuntime.DefaultAuthentication = ret - transportPool.Add(apiUrl, newRuntime) - } - - newApi := rest_management_api_client.New(transportPool, nil) - api := ZitiEdgeManagement{ - ZitiEdgeManagement: newApi, - TotpCallback: config.TotpCallback, - ClientTransportPool: transportPool, - } - - ret.API = &api - ret.AuthEnabledApi = &api - - return ret -} - -// ClientApiClient provides access to the Ziti Edge Client API for identity operations. -type ClientApiClient struct { - BaseClient[ZitiEdgeClient] -} - -// NewClientApiClient will assemble a ClientApiClient. The apiUrl should be the full URL -// to the Edge Client API (e.g. `https://example.com/edge/client/v1`). -// -// The `caPool` argument should be a list of trusted root CAs. If provided as `nil` here unauthenticated requests -// will use the system certificate pool. If authentication occurs, and a certificate pool is set on the Credentials -// the certificate pool from the Credentials will be used from that point forward. Credentials implementations -// based on an identity.Identity are likely to provide a certificate pool. -// -// For OpenZiti instances not using publicly signed certificates, `ziti.GetControllerWellKnownCaPool()` can be used -// to obtain and verify the target controllers CAs. Tools should allow users to verify and accept new controllers -// that have not been verified from an outside secret (such as an enrollment token). -func NewClientApiClient(apiUrls []*url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ClientApiClient { - return NewClientApiClientWithConfig(&ApiClientConfig{ - ApiUrls: apiUrls, - CaPool: caPool, - TotpCallback: totpCallback, - Proxy: http.ProxyFromEnvironment, - }) -} - -// NewClientApiClientWithConfig creates a Client API client using the provided configuration. -func NewClientApiClientWithConfig(config *ApiClientConfig) *ClientApiClient { - ret := &ClientApiClient{} - ret.ApiBinding = "edge-client" - ret.ApiVersion = "v1" - ret.Schemes = rest_client_api_client.DefaultSchemes - ret.ApiUrls = config.ApiUrls - - ret.initializeComponents(config) - - transportPool := NewClientTransportPoolRandom() - - for _, apiUrl := range config.ApiUrls { - newRuntime := NewRuntime(apiUrl, ret.Schemes, ret.HttpClient) - newRuntime.DefaultAuthentication = ret - transportPool.Add(apiUrl, newRuntime) - } - - newApi := rest_client_api_client.New(transportPool, nil) - api := ZitiEdgeClient{ - ZitiEdgeClient: newApi, - TotpCallback: config.TotpCallback, - ClientTransportPool: transportPool, - } - ret.API = &api - ret.AuthEnabledApi = &api - - return ret -} diff --git a/edge-apis/clients_shared.go b/edge-apis/clients_shared.go new file mode 100644 index 00000000..ef8844cb --- /dev/null +++ b/edge-apis/clients_shared.go @@ -0,0 +1,728 @@ +/* + Copyright 2019 NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package edge_apis + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-openapi/runtime" + "github.com/go-openapi/strfmt" + "github.com/go-resty/resty/v2" + "github.com/golang-jwt/jwt/v5" + "github.com/openziti/edge-api/rest_model" + "github.com/zitadel/oidc/v3/pkg/client/tokenexchange" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" +) + +// DefaultOidcRedirectUri is the default redirect URI for the OIDC PKCE flow that satisfies the default OIDC redirects +// for the Ziti Edge OIDC API. It is not an actual server, rather an intercepted redirect URI that is used to extract +// the resulting OIDC tokens. +const DefaultOidcRedirectUri = "http://localhost:8080/auth/callback" + +// ApiType is an interface constraint for generics. The underlying go-swagger types only have fields, which are +// insufficient to attempt to make a generic type from. Instead, this constraint is used that points at the +// aliased types. +type ApiType interface { + ZitiEdgeManagement | ZitiEdgeClient +} + +type OidcEnabledApi interface { + // SetUseOidc forces an API Client to operate in OIDC mode (true) or legacy mode (false). The state of the controller + // is ignored and dynamic enable/disable of OIDC support is suspended. + SetUseOidc(use bool) + + // SetAllowOidcDynamicallyEnabled sets whether clients will check the controller for OIDC support or not. If supported + // OIDC is favored over legacy authentication. + SetAllowOidcDynamicallyEnabled(allow bool) + + // SetOidcRedirectUri sets the redirect URI for the OIDC PKCE flow. The default value is used if not set. + // Should only be necessary to call for custom redirect controller configurations. + SetOidcRedirectUri(redirectUri string) +} + +// EdgeOidcAuthConfig represents the options necessary to complete an OAuth 2.0 PKCE authentication flow against an +// OpenZiti controller. +type EdgeOidcAuthConfig struct { + ClientTransportPool ClientTransportPool + Credentials Credentials + ConfigTypeOverrides []string + HttpClient *http.Client + TotpCodeProvider TotpCodeProvider + RedirectUri string + ApiHost string +} + +// ApiClientConfig contains configuration options for creating API clients. +type ApiClientConfig struct { + ApiUrls []*url.URL + CaPool *x509.CertPool + TotpCodeProvider TotpCodeProvider + Components *Components + Proxy func(r *http.Request) (*url.URL, error) +} + +// exchangeTokens exchanges OIDC tokens for refreshed tokens. It uses refresh tokens preferentially, +// falling back to non-expired access tokens if refresh is unavailable. +func exchangeTokens(clientTransportPool ClientTransportPool, curTokens *oidc.Tokens[*oidc.IDTokenClaims], client *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { + subjectToken := "" + var subjectTokenType oidc.TokenType + + if curTokens.RefreshToken != "" { + subjectToken = curTokens.RefreshToken + subjectTokenType = oidc.RefreshTokenType + } else if curTokens.AccessToken != "" { + // if subjectToken is "", then we don't have a refresh token, attempt to exchange a non-expired access token + expired, err := isAccessTokenExpired(curTokens) + + if err != nil { + return nil, err + } + + if expired { + return nil, errors.New("cannot exchange token: refresh token not found, access token expired") + } + + if curTokens.AccessToken == "" { + return nil, errors.New("cannot exchange token: refresh token not found, access token not found") + } + subjectToken = curTokens.AccessToken + subjectTokenType = oidc.AccessTokenType + } + + if subjectToken == "" { + return nil, errors.New("cannot exchange token: refresh token not found, access token not found or expired") + } + + var outTokens *oidc.Tokens[*oidc.IDTokenClaims] + + _, err := clientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) { + timeoutCtx, cancelF := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelF() + + apiHost := transport.ApiUrl.Host + issuer := "https://" + apiHost + "/oidc" + tokenEndpoint := "https://" + apiHost + "/oidc/oauth/token" + + te, err := tokenexchange.NewTokenExchangerClientCredentials(timeoutCtx, issuer, "native", "", tokenexchange.WithHTTPClient(client), tokenexchange.WithStaticTokenEndpoint(issuer, tokenEndpoint)) + + if err != nil { + return nil, err + } + + var tokenResponse *oidc.TokenExchangeResponse + + now := time.Now() + + switch subjectTokenType { + case oidc.RefreshTokenType: + tokenResponse, err = tokenexchange.ExchangeToken(timeoutCtx, te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType) + case oidc.AccessTokenType: + tokenResponse, err = tokenexchange.ExchangeToken(timeoutCtx, te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.AccessTokenType) + } + + if err != nil { + return nil, err + } + + idResp, err := tokenexchange.ExchangeToken(timeoutCtx, te, subjectToken, subjectTokenType, "", "", nil, nil, nil, oidc.IDTokenType) + + if err != nil { + return nil, err + } + + idClaims := &IdClaims{} + + //access token is used to hold id token per zitadel comments + _, _, err = jwt.NewParser().ParseUnverified(idResp.AccessToken, idClaims) + + if err != nil { + return nil, err + } + + outTokens = &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: tokenResponse.AccessToken, + TokenType: tokenResponse.TokenType, + RefreshToken: tokenResponse.RefreshToken, + Expiry: now.Add(time.Second * time.Duration(tokenResponse.ExpiresIn)), + }, + IDTokenClaims: &idClaims.IDTokenClaims, + IDToken: idResp.AccessToken, //access token field is used to hold id token per zitadel comments + } + + return outTokens, nil + }) + + if err != nil { + return nil, err + } + + return outTokens, nil +} + +// isAccessTokenExpired checks if an access token is expired. If token metadata is unavailable, +// it parses the JWT claims to determine expiration. +func isAccessTokenExpired(tokens *oidc.Tokens[*oidc.IDTokenClaims]) (bool, error) { + if tokens.Expiry.IsZero() { + //meta data isn't set, we need to parse the token + idClaims := &IdClaims{} + _, _, err := jwt.NewParser().ParseUnverified(tokens.AccessToken, idClaims) + + if err != nil { + return true, fmt.Errorf("token meta data is empty, could not parse token to determine token validity: %w", err) + } + + //failed to parse out a required exp field for oAuth2, we have no idea of this token is good + if idClaims.GetExpiration().IsZero() { + return true, errors.New("token meta data is empty, parsed token does not have an expiration value") + } + + return idClaims.GetExpiration().Before(time.Now()), nil + } + + return tokens.Expiry.Before(time.Now()), nil +} + +type authPayload struct { + *rest_model.Authenticate + AuthRequestId string `json:"id"` +} + +type totpCodePayload struct { + rest_model.MfaCode + AuthRequestId string `json:"id"` +} + +func (a *authPayload) toValues() url.Values { + result := url.Values{ + "id": []string{a.AuthRequestId}, + "password": []string{string(a.Password)}, + "username": []string{string(a.Username)}, + "configTypes": a.ConfigTypes, + "envArch": []string{a.EnvInfo.Arch}, + "envOs": []string{a.EnvInfo.Os}, + "envOsRelease": []string{a.EnvInfo.OsRelease}, + "envOsVersion": []string{a.EnvInfo.OsVersion}, + "sdkAppID": []string{a.SdkInfo.AppID}, + "sdkAppVersion": []string{a.SdkInfo.AppVersion}, + "sdkBranch": []string{a.SdkInfo.Branch}, + "sdkRevision": []string{a.SdkInfo.Revision}, + "sdkType": []string{a.SdkInfo.Type}, + "sdkVersion": []string{a.SdkInfo.Version}, + } + + return result +} + +// oidcAuth performs OIDC authentication using OAuth flow with PKCE. +// It handles TOTP if required and returns an OIDC session with tokens. +func oidcAuth(config *EdgeOidcAuthConfig) (ApiSession, error) { + if config.Credentials.Method() == AuthMethodEmpty { + return nil, fmt.Errorf("auth method %s cannot be used for authentication, please provide alternate credentials", AuthMethodEmpty) + } + + certificates := config.Credentials.TlsCerts() + + if len(certificates) != 0 { + if transport, ok := config.HttpClient.Transport.(TlsAwareTransport); ok { + tlsClientConf := transport.GetTlsClientConfig() + tlsClientConf.Certificates = certificates + transport.CloseIdleConnections() + } + } + + var outTokens *oidc.Tokens[*oidc.IDTokenClaims] + + _, err := config.ClientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) { + config.ApiHost = transport.ApiUrl.Host + edgeOidcAuth := NewEdgeOidcAuthenticator(config) + + var err error + outTokens, err = edgeOidcAuth.Authenticate() + + if err != nil { + return nil, err + } + + return outTokens, nil + }) + + if err != nil { + return nil, err + } + + return &ApiSessionOidc{ + OidcTokens: outTokens, + RequestHeaders: config.Credentials.GetRequestHeaders(), + }, nil +} + +// EdgeOidcAuthenticator handles the OAuth 2.0 PKCE authentication flow for the Ziti Edge API. +// It submits user credentials to the authorization endpoint, handles optional TOTP verification, +// and exchanges the authorization code for OIDC tokens. The HTTP client follows redirects +// during the authorization flow and extracts the authorization code from the final redirect. +type EdgeOidcAuthenticator struct { + *EdgeOidcAuthConfig + restyClient *resty.Client +} + +// NewEdgeOidcAuthenticator creates a new EdgeOidcAuthenticator configured for PKCE authentication. +// It sets up an HTTP client with a custom redirect policy that follows redirects during the +// authorization flow but stops when the callback redirect URI is reached, allowing code extraction +// from the redirect URL. The redirectUri parameter defines where the authorization server will +// redirect with the authorization code in the query parameters. +func NewEdgeOidcAuthenticator(config *EdgeOidcAuthConfig) *EdgeOidcAuthenticator { + client := resty.NewWithClient(config.HttpClient) + + if config.RedirectUri == "" { + config.RedirectUri = DefaultOidcRedirectUri + } + + // allows resty to follow redirects for us during the OAuth flow, but not for the end PKCE callback + // there is no server running for that redirect to hit, as it is this code + client.SetRedirectPolicy(RedirectUntilUrlPrefix(DefaultOidcRedirectUri)) + + return &EdgeOidcAuthenticator{ + EdgeOidcAuthConfig: config, + restyClient: client, + } +} + +// SetRedirectUri sets the redirect URI for the authorization server. The default value is +// included in the default Edge OIDC controller configuration, but if it has been set to custom +// values, this function can be used to reflect that configuration. +func (e *EdgeOidcAuthenticator) SetRedirectUri(redirectUri string) { + e.RedirectUri = redirectUri +} + +// Authenticate performs the complete OAuth 2.0 PKCE authentication flow. It initiates authorization +// with PKCE parameters, submits credentials and handles optional TOTP verification, then exchanges +// the resulting authorization code for OIDC tokens. +func (e *EdgeOidcAuthenticator) Authenticate() (*oidc.Tokens[*oidc.IDTokenClaims], error) { + pkceParams, err := newPkceParameters() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE parameters: %w", err) + } + + verificationParams, err := e.initOAuthFlow(pkceParams) + + if err != nil { + return nil, fmt.Errorf("failed to initiate authorization flow: %w", err) + } + + redirectResp, err := e.handlePrimaryAndSecondaryAuth(verificationParams) + if err != nil { + return nil, err + } + + tokens, err := e.finishOAuthFlow(redirectResp, verificationParams, pkceParams) + if err != nil { + return nil, err + } + + return tokens, nil +} + +// finishOAuthFlow extracts the authorization code from the callback redirect and exchanges it for tokens. +// The authorization server returns the code as a query parameter in the Location header of the redirect response. +// The code is then used with the PKCE verifier to obtain OIDC tokens via the token endpoint. +func (e *EdgeOidcAuthenticator) finishOAuthFlow(redirectResp *resty.Response, verificationParams *verificationParameters, pkceParams *pkceParameters) (*oidc.Tokens[*oidc.IDTokenClaims], error) { + if redirectResp.StatusCode() != http.StatusFound { + return nil, fmt.Errorf("authentication failed, expected a 302, got %d", redirectResp.StatusCode()) + } + + redirectStr := redirectResp.Header().Get("Location") + redirectUrl, err := url.Parse(redirectStr) + if err != nil { + return nil, fmt.Errorf("authentication failed, could not parse redirect url [%s]: %w", redirectStr, err) + } + + state := redirectUrl.Query().Get("state") + + if state == "" { + return nil, errors.New("authentication failed, no state found in redirect url") + } + + if state != verificationParams.State { + return nil, errors.New("authentication failed, state mismatch") + } + + code := redirectUrl.Query().Get("code") + if code == "" { + return nil, errors.New("authentication failed, no code found in redirect url") + } + + tokens, err := e.exchangeAuthorizationCodeForTokens(code, pkceParams) + if err != nil { + return nil, fmt.Errorf("failed to exchange authorization code: %w", err) + } + + if tokens.IDTokenClaims.Nonce != verificationParams.Nonce { + return nil, errors.New("authentication failed, nonce mismatch") + } + + return tokens, nil +} + +// handlePrimaryAndSecondaryAuth submits credentials to the authorization endpoint and handles optional TOTP. +func (e *EdgeOidcAuthenticator) handlePrimaryAndSecondaryAuth(verificationParams *verificationParameters) (*resty.Response, error) { + loginUri := "https://" + e.ApiHost + "/oidc/login/" + string(e.Credentials.Method()) + totpUri := "https://" + e.ApiHost + "/oidc/login/totp" + + payload := &authPayload{ + Authenticate: e.Credentials.Payload(), + AuthRequestId: verificationParams.AuthRequestId, + } + + if e.ConfigTypeOverrides != nil { + payload.ConfigTypes = e.ConfigTypeOverrides + } + + formData := payload.toValues() + req := e.restyClient.R() + clientRequest := asClientRequest(req, e.restyClient) + + err := e.Credentials.AuthenticateRequest(clientRequest, strfmt.Default) + if err != nil { + return nil, err + } + + resp, err := req.SetFormDataFromValues(formData).Post(loginUri) + if err != nil { + return nil, err + } + + // no additional secondary authentication required + if resp.StatusCode() == http.StatusFound { + return resp, nil + } + + // something went wrong + if resp.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("credential submission failed with status %d", resp.StatusCode()) + } + + totpRequiredHeader := resp.Header().Get(TotpRequiredHeader) + if totpRequiredHeader == "" { + return nil, errors.New("response was not a redirect and TOTP is not required, unknown additional authentication steps are required but unsupported") + } + + if e.TotpCodeProvider == nil { + return nil, errors.New("totp is required but no totp callback was defined") + } + + totpCodeResultCh := e.TotpCodeProvider.GetTotpCode() + var totpCode string + + select { + case totpCodeResult := <-totpCodeResultCh: + if totpCodeResult.Err != nil { + return nil, fmt.Errorf("error getting totp code: %w", totpCodeResult.Err) + } + totpCode = totpCodeResult.Code + case <-time.After(30 * time.Minute): + return nil, fmt.Errorf("timeout waiting for totp code provider") + } + + resp, err = e.restyClient.R().SetBody(&totpCodePayload{ + MfaCode: rest_model.MfaCode{ + Code: &totpCode, + }, + AuthRequestId: payload.AuthRequestId, + }).Post(totpUri) + + if err != nil { + return nil, err + } + + switch resp.StatusCode() { + case http.StatusOK: + return nil, errors.New("totp code verified, but additional authentication is required that is not supported or not configured, cannot authenticate") + case http.StatusFound: + return resp, nil + case http.StatusBadRequest: + return nil, errors.New("totp code did not verify") + default: + return nil, fmt.Errorf("unexpected response code %d from TOTP verification", resp.StatusCode()) + } +} + +// initOAuthFlow initiates the OAuth authorization request with PKCE parameters and returns the authorization request ID. +func (e *EdgeOidcAuthenticator) initOAuthFlow(pkceParams *pkceParameters) (*verificationParameters, error) { + verificationParams := &verificationParameters{ + State: generateRandomState(), + Nonce: generateNonce(), + } + + authUrl := "https://" + e.ApiHost + "/oidc/authorize?" + url.Values{ + "client_id": []string{"native"}, + "response_type": []string{"code"}, + "scope": []string{"openid offline_access"}, + "state": []string{verificationParams.State}, + "code_challenge": []string{pkceParams.Challenge}, + "code_challenge_method": []string{pkceParams.Method}, + "redirect_uri": []string{e.RedirectUri}, + "nonce": []string{verificationParams.Nonce}, + }.Encode() + + resp, err := e.restyClient.R().SetDoNotParseResponse(true).Get(authUrl) + if err != nil { + return nil, err + } + defer func() { _ = resp.RawResponse.Body.Close() }() + + if resp.StatusCode() != http.StatusOK { + body, _ := io.ReadAll(resp.RawResponse.Body) + + if len(body) == 0 { + body = []byte("") + } + + return nil, fmt.Errorf("authentication request start failed with status %d, either a misconfigured request was sent or the expected redirect URL (%s) is not allowed: %s", resp.StatusCode(), e.RedirectUri, body) + } + + verificationParams.AuthRequestId = resp.Header().Get(AuthRequestIdHeader) + if verificationParams.AuthRequestId == "" { + return nil, errors.New("could not find auth request id header from authorize endpoint") + } + + return verificationParams, nil +} + +// RedirectUntilUrlPrefix returns a redirect policy that follows redirects until the request URL +// matches one of the provided URL prefixes. Once a matching prefix is encountered, the redirect +// is not followed, allowing the caller to inspect the redirect response. +func RedirectUntilUrlPrefix(urlPrefixToStopAt ...string) resty.RedirectPolicy { + return resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { + reqUrl := req.URL.String() + for _, urlToStopAt := range urlPrefixToStopAt { + if strings.HasPrefix(reqUrl, urlToStopAt) { + return http.ErrUseLastResponse + } + } + return nil + }) +} + +// exchangeAuthorizationCodeForTokens exchanges an authorization code and PKCE verifier for OIDC tokens. +func (e *EdgeOidcAuthenticator) exchangeAuthorizationCodeForTokens(code string, pkceParams *pkceParameters) (*oidc.Tokens[*oidc.IDTokenClaims], error) { + tokenEndpoint := "https://" + e.ApiHost + "/oidc/oauth/token" + + tokenResp, err := e.restyClient.R().SetFormData(map[string]string{ + "grant_type": "authorization_code", + "client_id": "native", + "code_verifier": pkceParams.Verifier, + "code": code, + "redirect_uri": DefaultOidcRedirectUri, + }).Post(tokenEndpoint) + + if err != nil { + return nil, fmt.Errorf("failed to exchange authorization code for tokens: %w", err) + } + + if tokenResp.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("token exchange failed with status %d: %s", tokenResp.StatusCode(), string(tokenResp.Body())) + } + + // Parse token response + var tokenData map[string]interface{} + err = json.Unmarshal(tokenResp.Body(), &tokenData) + if err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + accessToken, ok := tokenData["access_token"].(string) + if !ok { + return nil, errors.New("access_token not found in token response") + } + + refreshToken, _ := tokenData["refresh_token"].(string) + expiresIn, _ := tokenData["expires_in"].(float64) + + // Parse ID token + idToken, _ := tokenData["id_token"].(string) + idClaims := &IdClaims{} + + if idToken != "" { + _, _, err = jwt.NewParser().ParseUnverified(idToken, idClaims) + if err != nil { + // Log but don't fail if ID token parsing fails + return nil, fmt.Errorf("failed to parse ID token: %w", err) + } + } + + tokens := &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + RefreshToken: refreshToken, + Expiry: time.Now().Add(time.Duration(expiresIn) * time.Second), + }, + IDTokenClaims: &idClaims.IDTokenClaims, + IDToken: idToken, + } + + return tokens, nil +} + +// pkceParameters holds the PKCE parameters used for OAuth 2.0 Proof Key for Public Clients flow. +type pkceParameters struct { + Verifier string + Challenge string + Method string +} + +type verificationParameters struct { + State string + AuthRequestId string + Nonce string +} + +// newPkceParameters generates PKCE parameters for OAuth 2.0 PKCE flow. +// It creates a random code verifier and derives the code challenge by applying SHA256 hashing. +func newPkceParameters() (*pkceParameters, error) { + var err error + params := &pkceParameters{ + Method: "S256", + } + + b := make([]byte, 32) + _, err = rand.Read(b) + if err != nil { + return nil, fmt.Errorf("failed to generate random bytes: %w", err) + } + params.Verifier = base64URLEncodeNoPadding(b) + + hash := sha256.Sum256([]byte(params.Verifier)) + params.Challenge = base64URLEncodeNoPadding(hash[:]) + + return params, nil +} + +// generateRandomState generates a random state string for CSRF protection. +func generateRandomState() string { + b := make([]byte, 16) + _, _ = rand.Read(b) + return base64URLEncodeNoPadding(b) +} + +// generateNonce generates a random nonce for binding the authorization request to the ID token. +func generateNonce() string { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return base64.RawURLEncoding.EncodeToString(b) +} + +// base64URLEncodeNoPadding encodes data to base64URL format without padding. +// Padding is removed because base64URL is designed to work in URLs and query strings where +// the '=' character may have special meaning. +func base64URLEncodeNoPadding(data []byte) string { + encoded := base64.URLEncoding.EncodeToString(data) + return strings.TrimRight(encoded, "=") +} + +// restyClientRequest is meant to mimic open api's client request which is a combination +// of resty's request and client. +type restyClientRequest struct { + restyRequest *resty.Request + restyClient *resty.Client +} + +func (r *restyClientRequest) SetHeaderParam(s string, s2 ...string) error { + r.restyRequest.Header[s] = s2 + return nil +} + +func (r *restyClientRequest) GetHeaderParams() http.Header { + return r.restyRequest.Header +} + +func (r *restyClientRequest) SetQueryParam(s string, s2 ...string) error { + r.restyRequest.QueryParam[s] = s2 + return nil +} + +func (r *restyClientRequest) SetFormParam(s string, s2 ...string) error { + r.restyRequest.FormData[s] = s2 + return nil +} + +func (r *restyClientRequest) SetPathParam(s string, s2 string) error { + r.restyRequest.PathParams[s] = s2 + return nil +} + +func (r *restyClientRequest) GetQueryParams() url.Values { + return r.restyRequest.QueryParam +} + +func (r *restyClientRequest) SetFileParam(s string, closer ...runtime.NamedReadCloser) error { + for _, curCloser := range closer { + r.restyRequest.SetFileReader(s, curCloser.Name(), curCloser) + } + + return nil +} + +func (r *restyClientRequest) SetBodyParam(i interface{}) error { + r.restyRequest.SetBody(i) + return nil +} + +func (r *restyClientRequest) SetTimeout(duration time.Duration) error { + r.restyClient.SetTimeout(duration) + return nil +} + +func (r *restyClientRequest) GetMethod() string { + return r.restyRequest.Method +} + +func (r *restyClientRequest) GetPath() string { + return r.restyRequest.URL +} + +func (r *restyClientRequest) GetBody() []byte { + return r.restyRequest.Body.([]byte) +} + +func (r *restyClientRequest) GetBodyParam() interface{} { + return r.restyRequest.Body +} + +func (r *restyClientRequest) GetFileParam() map[string][]runtime.NamedReadCloser { + return nil +} + +func asClientRequest(request *resty.Request, client *resty.Client) runtime.ClientRequest { + return &restyClientRequest{request, client} +} diff --git a/edge-apis/component.go b/edge-apis/component.go deleted file mode 100644 index 4bc78999..00000000 --- a/edge-apis/component.go +++ /dev/null @@ -1,62 +0,0 @@ -package edge_apis - -import ( - "crypto/x509" - "github.com/openziti/edge-api/rest_util" - "net/http" - "net/http/cookiejar" - "net/url" - "time" -) - -// Components provides the foundational HTTP client infrastructure for OpenAPI clients, -// bundling the HTTP client, transport, and certificate pool as a cohesive unit. -type Components struct { - HttpClient *http.Client - HttpTransport *http.Transport - CaPool *x509.CertPool -} - -// ComponentsConfig contains configuration options for creating Components. -type ComponentsConfig struct { - Proxy func(*http.Request) (*url.URL, error) -} - -// NewComponents assembles a new set of components with reasonable production defaults. -func NewComponents() *Components { - return NewComponentsWithConfig(&ComponentsConfig{ - Proxy: http.ProxyFromEnvironment, - }) -} - -// NewComponentsWithConfig assembles a new set of components using the provided configuration. -func NewComponentsWithConfig(cfg *ComponentsConfig) *Components { - tlsClientConfig, _ := rest_util.NewTlsConfig() - - httpTransport := &http.Transport{ - TLSClientConfig: tlsClientConfig, - ForceAttemptHTTP2: true, - MaxIdleConns: 10, - IdleConnTimeout: 10 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - - if cfg != nil && cfg.Proxy != nil { - httpTransport.Proxy = cfg.Proxy - } - - jar, _ := cookiejar.New(nil) - - httpClient := &http.Client{ - Transport: httpTransport, - CheckRedirect: nil, - Jar: jar, - Timeout: 10 * time.Second, - } - - return &Components{ - HttpClient: httpClient, - HttpTransport: httpTransport, - } -} diff --git a/edge-apis/credentials.go b/edge-apis/credentials.go index f0dcda91..1e6201bd 100644 --- a/edge-apis/credentials.go +++ b/edge-apis/credentials.go @@ -1,16 +1,28 @@ package edge_apis import ( + "bytes" "crypto" "crypto/tls" "crypto/x509" + "net/http" + "github.com/go-openapi/runtime" "github.com/go-openapi/strfmt" + "github.com/michaelquigley/pfxlog" "github.com/openziti/edge-api/rest_model" "github.com/openziti/identity" "github.com/openziti/sdk-golang/ziti/edge/network" "github.com/openziti/sdk-golang/ziti/sdkinfo" - "net/http" +) + +type AuthMethod string + +const ( + AuthMethodCert AuthMethod = "cert" + AuthMethodUpdb AuthMethod = "password" + AuthMethodEmpty AuthMethod = "empty" + AuthMethodJwtExt AuthMethod = "ext-jwt" ) // Credentials represents the minimal information needed across all authentication mechanisms to authenticate an identity @@ -26,7 +38,7 @@ type Credentials interface { GetCaPool() *x509.CertPool // Method returns the authentication necessary to complete an authentication request. - Method() string + Method() AuthMethod // AddAuthHeader adds a header for all authentication requests. AddAuthHeader(key, value string) @@ -218,6 +230,24 @@ type CertCredentials struct { // be provided and the certificate at index zero is assumed to be the leaf client certificate that pairs with the // provided private key. All other certificates are assumed to support the leaf client certificate as a chain. func NewCertCredentials(certs []*x509.Certificate, key crypto.PrivateKey) *CertCredentials { + + leaf := certs[0] + + leafPub := leaf.PublicKey + keySigner, ok := key.(crypto.Signer) + + if ok { + keyPub := keySigner.Public() + + leafPubBytes, _ := x509.MarshalPKIXPublicKey(leafPub) + keyPubBytes, _ := x509.MarshalPKIXPublicKey(keyPub) + if !bytes.Equal(leafPubBytes, keyPubBytes) { + pfxlog.Logger().Warn("key and leaf certificates do not match for NewCertCredentials, cannot verify certificate/key match") + } + } else { + pfxlog.Logger().Warn("key is not a crypto.Signer, cannot verify certificate/key match") + } + return &CertCredentials{ BaseCredentials: BaseCredentials{}, Certs: certs, @@ -225,8 +255,8 @@ func NewCertCredentials(certs []*x509.Certificate, key crypto.PrivateKey) *CertC } } -func (c *CertCredentials) Method() string { - return "cert" +func (c *CertCredentials) Method() AuthMethod { + return AuthMethodCert } func (c *CertCredentials) TlsCerts() []tls.Certificate { @@ -264,8 +294,8 @@ func (c *IdentityCredentials) GetIdentity() identity.Identity { return c.Identity } -func (c *IdentityCredentials) Method() string { - return "cert" +func (c *IdentityCredentials) Method() AuthMethod { + return AuthMethodCert } func (c *IdentityCredentials) GetCaPool() *x509.CertPool { @@ -301,8 +331,8 @@ func NewJwtCredentials(jwt string) *JwtCredentials { } } -func (c *JwtCredentials) Method() string { - return "ext-jwt" +func (c *JwtCredentials) Method() AuthMethod { + return AuthMethodJwtExt } func (c *JwtCredentials) AuthenticateRequest(request runtime.ClientRequest, reg strfmt.Registry) error { @@ -330,8 +360,8 @@ type UpdbCredentials struct { Password string } -func (c *UpdbCredentials) Method() string { - return "password" +func (c *UpdbCredentials) Method() AuthMethod { + return AuthMethodUpdb } // NewUpdbCredentials creates a Credentials instance based on a username/passwords combination. @@ -354,3 +384,13 @@ func (c *UpdbCredentials) Payload() *rest_model.Authenticate { func (c *UpdbCredentials) AuthenticateRequest(request runtime.ClientRequest, reg strfmt.Registry) error { return c.BaseCredentials.AuthenticateRequest(request, reg) } + +var _ Credentials = (*EmptyCredentials)(nil) + +type EmptyCredentials struct { + BaseCredentials +} + +func (e EmptyCredentials) Method() AuthMethod { + return AuthMethodEmpty +} diff --git a/edge-apis/http_components.go b/edge-apis/http_components.go new file mode 100644 index 00000000..64f6f3e5 --- /dev/null +++ b/edge-apis/http_components.go @@ -0,0 +1,147 @@ +package edge_apis + +import ( + "crypto/tls" + "crypto/x509" + "net/http" + "net/http/cookiejar" + "net/url" + "time" + + "github.com/michaelquigley/pfxlog" + "github.com/openziti/edge-api/rest_util" +) + +// Components provides the foundational HTTP client infrastructure for OpenAPI clients, +// bundling the HTTP client, transport, and certificate pool as a cohesive unit. +type Components struct { + HttpClient *http.Client + TlsAwareTransport TlsAwareTransport + CaPool *x509.CertPool +} + +// assertComponents ensures that the components are initialized properly. +func (c Components) assertComponents(config *ApiClientConfig) { + if config.Components.HttpClient == nil { + pfxlog.Logger().Warn("components were provided but the http client was nil, it is being initialized") + + if config.Components.TlsAwareTransport == nil { + config.Components.TlsAwareTransport = NewTlsAwareHttpTransport(nil) + pfxlog.Logger().Warn("components were provided but the client and transport are nil, they are being initialized with a default") + } + + config.Components.HttpClient = NewHttpClient(config.Components.TlsAwareTransport) + } + + if config.Components.TlsAwareTransport == nil { + if tlsAwareTransport, ok := config.Components.HttpClient.Transport.(TlsAwareTransport); ok { + config.Components.TlsAwareTransport = tlsAwareTransport + pfxlog.Logger().Warn("components were provided but the transport was nil, it is being initialized with the transport from the http client") + } else { + pfxlog.Logger().Warn("components were provided but the transport was nil and the client did not have a suitable transport, it is being initialized with a default") + config.Components.TlsAwareTransport = NewTlsAwareHttpTransport(nil) + config.Components.HttpClient.Transport = config.Components.TlsAwareTransport + } + } + + if config.Components.HttpClient.Transport != config.Components.TlsAwareTransport { + pfxlog.Logger().Warn("components were provided but the http client transport was not the same as the transport in components, it is being initialized") + config.Components.HttpClient.Transport = config.Components.TlsAwareTransport + } +} + +// ComponentsConfig contains configuration options for creating Components. +type ComponentsConfig struct { + Proxy func(*http.Request) (*url.URL, error) +} + +// NewComponentsWithConfig assembles a new set of components using the provided configuration. +func NewComponentsWithConfig(cfg *ComponentsConfig) *Components { + tlsAwareHttpTransport := NewTlsAwareHttpTransport(cfg) + httpClient := NewHttpClient(tlsAwareHttpTransport) + + return &Components{ + HttpClient: httpClient, + TlsAwareTransport: tlsAwareHttpTransport, + } +} + +// NewHttpClient creates an HTTP client with the given transport. +func NewHttpClient(tlsAwareHttpTransport TlsAwareTransport) *http.Client { + jar, _ := cookiejar.New(nil) + return &http.Client{ + Transport: tlsAwareHttpTransport, + CheckRedirect: nil, + Jar: jar, + Timeout: 10 * time.Second, + } +} + +// TlsAwareTransport abstracts HTTP transport to allow API implementations to dynamically +// configure TLS settings during authentication (e.g., adding client certificates) and manage +// proxy configuration. +type TlsAwareTransport interface { + http.RoundTripper + + // GetTlsClientConfig returns the current TLS configuration. + GetTlsClientConfig() *tls.Config + // SetTlsClientConfig updates the TLS configuration. + SetTlsClientConfig(*tls.Config) + + // SetProxy sets the proxy function for HTTP requests. + SetProxy(func(*http.Request) (*url.URL, error)) + // GetProxy returns the current proxy function. + GetProxy() func(*http.Request) (*url.URL, error) + + // CloseIdleConnections closes all idle HTTP connections. + CloseIdleConnections() +} + +var _ TlsAwareTransport = (*TlsAwareHttpTransport)(nil) + +// TlsAwareHttpTransport is a concrete implementation of TlsAwareTransport that wraps http.Transport. +type TlsAwareHttpTransport struct { + *http.Transport +} + +// NewTlsAwareHttpTransport creates a TlsAwareHttpTransport with default HTTP/2 and TLS settings. +func NewTlsAwareHttpTransport(cfg *ComponentsConfig) *TlsAwareHttpTransport { + tlsClientConfig, _ := rest_util.NewTlsConfig() + + authAwareTransport := &TlsAwareHttpTransport{ + &http.Transport{ + TLSClientConfig: tlsClientConfig, + ForceAttemptHTTP2: true, + MaxIdleConns: 10, + IdleConnTimeout: 10 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + + if cfg != nil && cfg.Proxy != nil { + authAwareTransport.Proxy = cfg.Proxy + } + + return authAwareTransport +} + +// GetProxy returns the proxy function currently set on the transport. +func (a *TlsAwareHttpTransport) GetProxy() func(*http.Request) (*url.URL, error) { + return a.Proxy +} + +// SetProxy sets the proxy function for the transport. +func (a *TlsAwareHttpTransport) SetProxy(proxyFunc func(*http.Request) (*url.URL, error)) { + a.Proxy = proxyFunc +} + +// GetTlsClientConfig returns the TLS configuration from the underlying transport. +func (a *TlsAwareHttpTransport) GetTlsClientConfig() *tls.Config { + return a.TLSClientConfig +} + +// SetTlsClientConfig updates the TLS configuration on the underlying transport. +func (a *TlsAwareHttpTransport) SetTlsClientConfig(config *tls.Config) { + a.TLSClientConfig = config +} diff --git a/edge-apis/oidc.go b/edge-apis/oidc.go index 0af71bc7..e5e85ddc 100644 --- a/edge-apis/oidc.go +++ b/edge-apis/oidc.go @@ -1,20 +1,8 @@ package edge_apis import ( - "context" - "crypto/rand" - "crypto/tls" - "fmt" "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/michaelquigley/pfxlog" - "github.com/zitadel/oidc/v3/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" - "net" - "net/http" - "net/http/cookiejar" - "time" ) // JwtTokenPrefix is the standard prefix for JWT tokens, representing the first two characters @@ -76,142 +64,3 @@ func (r *IdClaims) GetSubject() (string, error) { func (r *IdClaims) GetAudience() (jwt.ClaimStrings, error) { return jwt.ClaimStrings(r.Audience), nil } - -// localRpServer manages a local HTTP server for OpenID Connect relying party operations, -// handling OAuth callbacks and token exchanges during authentication flows. -type localRpServer struct { - Server *http.Server - Port string - Listener net.Listener - TokenChan chan *oidc.Tokens[*oidc.IDTokenClaims] - CallbackPath string - CallbackUri string - LoginUri string -} - -// Stop shuts down the local server and closes the token channel. -func (t *localRpServer) Stop() { - _ = t.Server.Shutdown(context.Background()) - close(t.TokenChan) -} - -// Start launches the local server and waits for it to become available. -func (t *localRpServer) Start() { - go func() { - _ = t.Server.Serve(t.Listener) - }() - - started := make(chan struct{}) - - go func() { - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - end := time.Now().Add(11 * time.Second) - for time.Now().Before(end) { - time.Sleep(100 * time.Millisecond) - - _, err := client.Get(t.LoginUri) - - if err == nil { - break - } - } - close(started) - }() - select { - case <-started: - case <-time.After(10 * time.Second): - pfxlog.Logger().Warn("local relying party server did not start within 10s") - } -} - -// newLocalRpServer creates and configures a local HTTP server for handling OpenID Connect -// authentication flows, including callback processing and token exchange. -func newLocalRpServer(apiHost string, authMethod string) (*localRpServer, error) { - tokenOutChan := make(chan *oidc.Tokens[*oidc.IDTokenClaims], 1) - result := &localRpServer{ - CallbackPath: "/auth/callback", - TokenChan: tokenOutChan, - } - var err error - - result.Listener, err = net.Listen("tcp", ":0") - - if err != nil { - return nil, fmt.Errorf("could not listen on a random port: %w", err) - } - - _, result.Port, _ = net.SplitHostPort(result.Listener.Addr().String()) - - result.LoginUri = "http://127.0.0.1:" + result.Port + "/login" - - key := make([]byte, 32) - _, err = rand.Read(key) - if err != nil { - return nil, fmt.Errorf("could not generate secure cookie key: %w", err) - } - - urlBase := "https://" + apiHost - issuer := urlBase + "/oidc" - clientID := "native" - clientSecret := "" - scopes := []string{"openid", "offline_access"} - result.CallbackUri = "http://127.0.0.1:" + result.Port + result.CallbackPath - - cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) - jar, _ := cookiejar.New(&cookiejar.Options{}) - httpClient := &http.Client{ - - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - Proxy: http.ProxyFromEnvironment, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, - CheckRedirect: nil, - Jar: jar, - Timeout: 10 * time.Second, - } - - options := []rp.Option{ - rp.WithHTTPClient(httpClient), - rp.WithPKCE(cookieHandler), - } - - provider, err := rp.NewRelyingPartyOIDC(context.Background(), issuer, clientID, clientSecret, result.CallbackUri, scopes, options...) - - if err != nil { - return nil, fmt.Errorf("could not create rp OIDC: %w", err) - } - - state := func() string { - return uuid.New().String() - } - serverMux := http.NewServeMux() - - authHandler := rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!"), rp.WithURLParam("method", authMethod)) - loginHandler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - authHandler.ServeHTTP(writer, request) - }) - - serverMux.Handle("/login", loginHandler) - - marshalToken := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, relyingParty rp.RelyingParty) { - tokenOutChan <- tokens - _, _ = w.Write([]byte("done!")) - } - - serverMux.Handle(result.CallbackPath, rp.CodeExchangeHandler(marshalToken, provider)) - - result.Server = &http.Server{Handler: serverMux} - - return result, nil -} diff --git a/ziti/edge/posture/totp.go b/edge-apis/totp.go similarity index 56% rename from ziti/edge/posture/totp.go rename to edge-apis/totp.go index 5903c115..42a3bb3a 100644 --- a/ziti/edge/posture/totp.go +++ b/edge-apis/totp.go @@ -1,4 +1,4 @@ -package posture +package edge_apis import ( "errors" @@ -23,34 +23,63 @@ type TotpTokenResult struct { Err error } -// TotpCodeProvider defines the interface for obtaining TOTP codes, typically implemented -// by user interaction handlers that prompt for authenticator app codes. +// TotpCodeProvider supplies TOTP codes for multi-factor authentication. Implementations typically +// prompt users to enter codes from authenticator apps. type TotpCodeProvider interface { + // GetTotpCode returns a channel that delivers the TOTP code result. GetTotpCode() <-chan TotpCodeResult } -// TotpTokenRequestor defines the interface for exchanging TOTP codes with the authentication -// service to obtain session tokens. +// TotpCodeProviderFunc is a function adapter that implements TotpCodeProvider. +type TotpCodeProviderFunc func() <-chan TotpCodeResult + +// NewTotpCodeProviderFromChStringFunc adapts legacy func(chan string) callbacks to the TotpCodeProvider interface. +// This enables backward compatibility while allowing a smoother migration path to the new interface. +func NewTotpCodeProviderFromChStringFunc(stringFunc func(ch chan string)) TotpCodeProvider { + return TotpCodeProviderFunc(func() <-chan TotpCodeResult { + resultCh := make(chan TotpCodeResult) + + go func() { + stringCh := make(chan string) + go stringFunc(stringCh) + + code := <-stringCh + + resultCh <- TotpCodeResult{ + Code: code, + } + }() + + return resultCh + }) +} + +func (f TotpCodeProviderFunc) GetTotpCode() <-chan TotpCodeResult { + return f() +} + +// TotpTokenRequestor exchanges TOTP codes with the authentication service for session tokens. type TotpTokenRequestor interface { + // RequestTotpToken exchanges a TOTP code for a session token. RequestTotpToken(code string) <-chan TotpTokenResult } -// TotpTokenProvider abstracts the complete TOTP authentication flow, handling both code -// acquisition and token exchange. +// TotpTokenProvider coordinates the complete TOTP authentication flow, obtaining codes and exchanging them for tokens. type TotpTokenProvider interface { + // Request initiates a TOTP token request, returning a channel with the result. Request() <-chan TotpTokenResult } -// TotpTokenProviderFunc is a function adapter that implements TotpTokenProvider, allowing -// simple functions to satisfy the interface. +// TotpTokenProviderFunc is a function adapter that implements TotpTokenProvider. type TotpTokenProviderFunc func() <-chan TotpTokenResult +// Request implements TotpTokenProvider. func (f TotpTokenProviderFunc) Request() <-chan TotpTokenResult { return f() } -// SingularTokenRequestor ensures only one TOTP token request is active at a time, -// preventing duplicate authentication attempts when multiple operations require TOTP. +// SingularTokenRequestor serializes TOTP token requests, ensuring only one is active at a time. +// This prevents duplicate authentication attempts when multiple operations require TOTP. type SingularTokenRequestor struct { isRequesting sync.Mutex codeProvider TotpCodeProvider @@ -60,8 +89,8 @@ type SingularTokenRequestor struct { const totpCodeProviderTimeout = 5 * time.Minute const totpTokenRequestorTimeout = 30 * time.Second -// NewSingularTokenRequestor creates a requestor that coordinates TOTP code collection -// and token exchange while preventing concurrent requests. +// NewSingularTokenRequestor creates a token requestor that coordinates code collection and token exchange. +// Only one request can be active at a time; subsequent requests return nil if one is already in progress. func NewSingularTokenRequestor(codeProvider TotpCodeProvider, tokenRequestor TotpTokenRequestor) *SingularTokenRequestor { return &SingularTokenRequestor{ codeProvider: codeProvider, @@ -69,9 +98,8 @@ func NewSingularTokenRequestor(codeProvider TotpCodeProvider, tokenRequestor Tot } } -// Request initiates a TOTP token request if none is in progress, returning nil if a request -// is already active. The returned channel delivers the token result once the code is -// collected and exchanged, or an error if the process times out or fails. +// Request initiates a TOTP token request, returning nil if a request is already in progress. +// The returned channel delivers the token result once the code is collected and exchanged. func (r *SingularTokenRequestor) Request() <-chan TotpTokenResult { if lockObtained := r.isRequesting.TryLock(); !lockObtained { //outstanding request don't do anything diff --git a/edge-apis/pool.go b/edge-apis/transport_pool.go similarity index 99% rename from edge-apis/pool.go rename to edge-apis/transport_pool.go index 422de29f..b7ebedcd 100644 --- a/edge-apis/pool.go +++ b/edge-apis/transport_pool.go @@ -17,6 +17,7 @@ package edge_apis import ( + "errors" "math/rand/v2" "net" "net/url" @@ -27,7 +28,6 @@ import ( "github.com/go-openapi/runtime" "github.com/michaelquigley/pfxlog" cmap "github.com/orcaman/concurrent-map/v2" - errors "github.com/pkg/errors" ) // ApiClientTransport wraps a runtime.ClientTransport with its associated API URL, diff --git a/ziti/client.go b/ziti/client.go index 2a992f09..c333f8fa 100644 --- a/ziti/client.go +++ b/ziti/client.go @@ -24,6 +24,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "errors" "fmt" "strings" "sync/atomic" @@ -49,7 +50,6 @@ import ( apis "github.com/openziti/sdk-golang/edge-apis" "github.com/openziti/sdk-golang/ziti/edge/posture" "github.com/openziti/transport/v2" - "github.com/pkg/errors" ) // CtrlClient is a stateful version of ZitiEdgeClient that simplifies operations @@ -71,20 +71,21 @@ type CtrlClient struct { capabilitiesLoaded atomic.Bool } -func (self *CtrlClient) RequestTotpToken(code string) <-chan posture.TotpTokenResult { - totpTokenResultChan := make(chan posture.TotpTokenResult) +// RequestTotpToken implements TotpTokenRequestor, exchanging a TOTP code for a TOTP token. +func (self *CtrlClient) RequestTotpToken(code string) <-chan apis.TotpTokenResult { + totpTokenResultChan := make(chan apis.TotpTokenResult) go func() { totpToken, err := self.CreateTotpToken(code) if err != nil { - totpTokenResultChan <- posture.TotpTokenResult{ + totpTokenResultChan <- apis.TotpTokenResult{ Err: fmt.Errorf("could not request totp token: %v", err), } return } - totpTokenResultChan <- posture.TotpTokenResult{ + totpTokenResultChan <- apis.TotpTokenResult{ Token: *totpToken.Token, IssuedAt: time.Time(*totpToken.IssuedAt), Err: nil, @@ -94,6 +95,7 @@ func (self *CtrlClient) RequestTotpToken(code string) <-chan posture.TotpTokenRe return totpTokenResultChan } +// CreateTotpToken submits a TOTP code to the controller and returns a TOTP token. func (self *CtrlClient) CreateTotpToken(code string) (*rest_model.TotpToken, error) { params := current_api_session.NewCreateTotpTokenParams() params.MfaValidation = &rest_model.MfaCode{ @@ -208,7 +210,7 @@ func (self *CtrlClient) SendPostureResponseBulk(responses []rest_model.PostureRe if len(responses) == 0 { return nil } - + params := posture_checks.NewCreatePostureResponseBulkParams() params.PostureResponse = responses _, err := self.API.PostureChecks.CreatePostureResponseBulk(params, self.GetCurrentApiSession()) @@ -306,7 +308,8 @@ func (self *CtrlClient) GetIdentity() (identity.Identity, error) { } } - return identity.NewClientTokenIdentityWithPool([]*x509.Certificate{self.ApiSessionCertificate}, self.ApiSessionPrivateKey, self.HttpTransport.TLSClientConfig.RootCAs), nil + rootCaPool := self.TlsAwareTransport.GetTlsClientConfig().RootCAs + return identity.NewClientTokenIdentityWithPool([]*x509.Certificate{self.ApiSessionCertificate}, self.ApiSessionPrivateKey, rootCaPool), nil } // EnsureApiSessionCertificate will create an ApiSessionCertificate if one does not already exist. diff --git a/ziti/contexts.go b/ziti/contexts.go index 31548d4a..283602f0 100644 --- a/ziti/contexts.go +++ b/ziti/contexts.go @@ -26,6 +26,8 @@ package ziti import ( + "errors" + "fmt" "net/http" "net/url" "strconv" @@ -37,7 +39,6 @@ import ( "github.com/openziti/sdk-golang/ziti/edge" "github.com/openziti/sdk-golang/ziti/edge/posture" cmap "github.com/orcaman/concurrent-map/v2" - "github.com/pkg/errors" ) var idCount = 0 @@ -120,7 +121,7 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) { apiUrl, err := url.Parse(cfg.ZtAPI) if err != nil { - return nil, errors.Wrapf(err, "could not parse ZtAPI from configuration as URI: %s", apiStr) + return nil, fmt.Errorf("could not parse ZtAPI from configuration as URI: %s: %w", apiStr, err) } apiUrls = append(apiUrls, apiUrl) @@ -129,7 +130,7 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) { apiClientConfig := &edge_apis.ApiClientConfig{ ApiUrls: apiUrls, CaPool: cfg.Credentials.GetCaPool(), - TotpCallback: func(codeCh chan string) { + TotpCodeProvider: edge_apis.NewTotpCodeProviderFromChStringFunc(func(codeCh chan string) { provider := rest_model.MfaProvidersZiti authQuery := &rest_model.AuthQueryDetail{ @@ -153,7 +154,7 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) { } } - }, + }), Proxy: cfg.CtrlProxy, } @@ -170,7 +171,7 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) { newContext.CtrlClt.SetAllowOidcDynamicallyEnabled(true) multiSubmitter := posture.NewMultiSubmitter(newContext.CtrlClt, newContext.CtrlClt, newContext) - totpTokenProvider := posture.NewSingularTokenRequestor(newContext, newContext.CtrlClt) + totpTokenProvider := edge_apis.NewSingularTokenRequestor(newContext, newContext.CtrlClt) newContext.CtrlClt.PostureCache = posture.NewCache(newContext, multiSubmitter, totpTokenProvider, newContext.closeNotify) newContext.CtrlClt.AddOnControllerUpdateListeners(func(urls []*url.URL) { diff --git a/ziti/edge/posture/cache.go b/ziti/edge/posture/cache.go index e569350f..05ea3a0e 100644 --- a/ziti/edge/posture/cache.go +++ b/ziti/edge/posture/cache.go @@ -26,6 +26,7 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/edge-api/rest_model" "github.com/openziti/foundation/v2/stringz" + "github.com/openziti/sdk-golang/edge-apis" "github.com/openziti/sdk-golang/ziti/edge" cmap "github.com/orcaman/concurrent-map/v2" ) @@ -47,7 +48,7 @@ type CacheData struct { MacAddresses []string Os OsInfo Domain string - TotpToken TotpTokenResult + TotpToken edge_apis.TotpTokenResult OnWake WakeEvent OnUnlock UnlockEvent Index uint64 @@ -104,7 +105,7 @@ type Cache struct { MacProvider MacProvider OsProvider OsProvider ProcessProvider ProcessProvider - TotpTokenProvider TotpTokenProvider + TotpTokenProvider edge_apis.TotpTokenProvider lock sync.Mutex totpTimeout int64 @@ -114,7 +115,7 @@ type Cache struct { // NewCache creates a posture cache that monitors device state and coordinates posture response // submission. The cache uses the provided service provider to determine which posture checks // are active, the submitter to send responses, and the token provider for TOTP authentication. -func NewCache(activeServiceProvider ActiveServiceProvider, submitter Submitter, totpTokenProvider TotpTokenProvider, closeNotify <-chan struct{}) *Cache { +func NewCache(activeServiceProvider ActiveServiceProvider, submitter Submitter, totpTokenProvider edge_apis.TotpTokenProvider, closeNotify <-chan struct{}) *Cache { cache := &Cache{ currentData: NewCacheData(), previousData: NewCacheData(), @@ -479,14 +480,14 @@ func (cache *Cache) SimulateUnlock() { } func (cache *Cache) SetTotpToken(token *rest_model.TotpToken) { - cache.currentData.TotpToken = TotpTokenResult{ + cache.currentData.TotpToken = edge_apis.TotpTokenResult{ Token: *token.Token, IssuedAt: time.Time(*token.IssuedAt), } } -func (cache *Cache) SetTotpProviderFunc(f func() <-chan TotpTokenResult) { - p := TotpTokenProviderFunc(f) +func (cache *Cache) SetTotpProviderFunc(f func() <-chan edge_apis.TotpTokenResult) { + p := edge_apis.TotpTokenProviderFunc(f) cache.TotpTokenProvider = &p } diff --git a/ziti/sdkinfo/build_info.go b/ziti/sdkinfo/build_info.go index 1346e248..f5373663 100644 --- a/ziti/sdkinfo/build_info.go +++ b/ziti/sdkinfo/build_info.go @@ -20,5 +20,5 @@ package sdkinfo const ( - Version = "v1.2.10" + Version = "v1.3.0" ) diff --git a/ziti/ziti.go b/ziti/ziti.go index 822fdc69..c22edc0e 100644 --- a/ziti/ziti.go +++ b/ziti/ziti.go @@ -45,7 +45,6 @@ import ( "github.com/openziti/foundation/v2/stringz" apis "github.com/openziti/sdk-golang/edge-apis" "github.com/openziti/sdk-golang/xgress" - "github.com/openziti/sdk-golang/ziti/edge/posture" "github.com/openziti/secretstream/kx" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -282,11 +281,11 @@ func (context *ContextImpl) addActiveBindService(svc *rest_model.ServiceDetail) context.CtrlClt.PostureCache.Evaluate() } -func (context *ContextImpl) GetTotpCode() <-chan posture.TotpCodeResult { - totpCodeResultChan := make(chan posture.TotpCodeResult) +func (context *ContextImpl) GetTotpCode() <-chan apis.TotpCodeResult { + totpCodeResultChan := make(chan apis.TotpCodeResult) if context.ListenerCount(EventMfaTotpCode) == 0 { - totpCodeResultChan <- posture.TotpCodeResult{ + totpCodeResultChan <- apis.TotpCodeResult{ Code: "", Err: errors.New("no MFA TOTP code providers have been added via zitiContext.Events().AddMfaTotpCodeListener()"), } @@ -306,7 +305,7 @@ func (context *ContextImpl) GetTotpCode() <-chan posture.TotpCodeResult { } context.Emit(EventMfaTotpCode, authQuery, MfaCodeResponse(func(code string) error { - totpCodeResultChan <- posture.TotpCodeResult{ + totpCodeResultChan <- apis.TotpCodeResult{ Code: code, } return nil