diff --git a/pkg/agentgateway/jwks/config_map_syncer.go b/pkg/agentgateway/jwks/config_map_syncer.go index 02ddf6f970a..fe989dcf37e 100644 --- a/pkg/agentgateway/jwks/config_map_syncer.go +++ b/pkg/agentgateway/jwks/config_map_syncer.go @@ -11,52 +11,42 @@ import ( "istio.io/istio/pkg/kube/kclient" "istio.io/istio/pkg/kube/krt" - "istio.io/istio/pkg/ptr" corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" "github.com/kgateway-dev/kgateway/v2/pkg/apiclient" "github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/krtutil" ) -const jwksStorePrefix = "jwks-store" -const JwksStoreComponent = "app.kubernetes.io/component" +const configMapKey = "jwks-store" +const jwksStoreComponentLabel = "app.kubernetes.io/component" + +func JwksStoreLabelSelector(storePrefix string) string { + return jwksStoreComponentLabel + "=" + storePrefix +} + +func JwksStoreConfigMapLabel(storePrefix string) map[string]string { + return map[string]string{jwksStoreComponentLabel: storePrefix} +} // configMapSyncer is used for writing/reading jwks' to/from ConfigMaps. 
type configMapSyncer struct { - deploymentNamespace string storePrefix string - cmAccessor cmAccessor -} - -// this is an abstraction over ConfigMap access to facilitate testing -type cmAccessor interface { - Create(context.Context, *corev1.ConfigMap) error - Update(context.Context, *corev1.ConfigMap) error - Delete(context.Context, string) error - List() []*corev1.ConfigMap - Get(string) *corev1.ConfigMap - WaitForCacheSync(ctx context.Context) bool + deploymentNamespace string + cmCollection krt.Collection[*corev1.ConfigMap] } func NewConfigMapSyncer(client apiclient.Client, storePrefix, deploymentNamespace string, krtOptions krtutil.KrtOptions) *configMapSyncer { cmCollection := krt.NewFilteredInformer[*corev1.ConfigMap](client, kclient.Filter{ ObjectFilter: client.ObjectFilter(), - LabelSelector: JwksStoreComponent + "=" + storePrefix}, + LabelSelector: JwksStoreLabelSelector(storePrefix)}, krtOptions.ToOptions("config_map_syncer/ConfigMaps")...) toret := configMapSyncer{ deploymentNamespace: deploymentNamespace, storePrefix: storePrefix, - cmAccessor: &defaultCmAccessor{ - client: client, - deploymentNamespace: deploymentNamespace, - cmCollection: cmCollection, - }, + cmCollection: cmCollection, } return &toret @@ -65,7 +55,7 @@ func NewConfigMapSyncer(client apiclient.Client, storePrefix, deploymentNamespac // Load jwks from a ConfigMap. // Returns a map of jwks-uri -> jwks (currently one jwks-uri per ConfigMap). func JwksFromConfigMap(cm *corev1.ConfigMap) (map[string]string, error) { - jwksStore := cm.Data[jwksStorePrefix] + jwksStore := cm.Data[configMapKey] jwks := make(map[string]string) err := json.Unmarshal(([]byte)(jwksStore), &jwks) if err != nil { @@ -74,10 +64,6 @@ func JwksFromConfigMap(cm *corev1.ConfigMap) (map[string]string, error) { return jwks, nil } -func (cs *configMapSyncer) WaitForCacheSync(ctx context.Context) bool { - return cs.cmAccessor.WaitForCacheSync(ctx) -} - // Generates ConfigMap name based on jwks uri. 
Resulting name is a concatenation of "jwks-store-" prefix and an MD5 hash of the jwks uri. // The length of the name is a constant 32 chars (hash) + legth of the prefix. func JwksConfigMapName(storePrefix, jwksUri string) string { @@ -85,59 +71,20 @@ func JwksConfigMapName(storePrefix, jwksUri string) string { return fmt.Sprintf("%s-%s", storePrefix, hex.EncodeToString(hash[:])) } -// Write out jwks' in updates to ConfigMaps, one jwks uri per ConfigMap. updates contains a map of jwks-uri to serialized jwks. -// Each ConfigMap is labelled with "app.kubernetes.io/component":"jwks-store" to support bulk loading of jwks' handled by LoadJwksFromConfigMaps(). -func (cs *configMapSyncer) WriteJwksToConfigMaps(ctx context.Context, updates map[string]string) error { - log := log.FromContext(ctx) - errs := make([]error, 0) - - for uri, jwks := range updates { - switch jwks { - case "": // empty jwks == remove the underlying ConfigMap - err := cs.cmAccessor.Delete(ctx, JwksConfigMapName(cs.storePrefix, uri)) - if client.IgnoreNotFound(err) != nil { - log.Error(err, "error deleting jwks ConfigMap") - errs = append(errs, err) - } - default: - cmData, err := json.Marshal(map[string]string{uri: jwks}) - if err != nil { - log.Error(err, "error serialiazing jwks") - errs = append(errs, err) - continue - } - - existing := cs.cmAccessor.Get(JwksConfigMapName(cs.storePrefix, uri)) - if existing == nil { - cm := cs.newJwksStoreConfigMap(JwksConfigMapName(cs.storePrefix, uri)) - cm.Data[jwksStorePrefix] = string(cmData) - - err := cs.cmAccessor.Create(ctx, cm) - if err != nil { - log.Error(err, "error persisting jwks to ConfigMap") - errs = append(errs, err) - continue - } - } else { - existing.Data[jwksStorePrefix] = string(cmData) - err = cs.cmAccessor.Update(ctx, existing) - if err != nil { - log.Error(err, "error updating jwks ConfigMap") - errs = append(errs, err) - continue - } - } - } +func SetJwksInConfigMap(cm *corev1.ConfigMap, uri, jwks string) error { + b, err := 
json.Marshal(map[string]string{uri: jwks}) + if err != nil { + return err } - - return errors.Join(errs...) + cm.Data[configMapKey] = string(b) + return nil } // Loads all jwks persisted in ConfigMaps. The result is a map of jwks-uri to serialized jwks. func (cs *configMapSyncer) LoadJwksFromConfigMaps(ctx context.Context) (map[string]string, error) { log := log.FromContext(ctx) - allPersistedJwks := cs.cmAccessor.List() + allPersistedJwks := cs.cmCollection.List() if len(allPersistedJwks) == 0 { return nil, nil @@ -158,52 +105,3 @@ func (cs *configMapSyncer) LoadJwksFromConfigMaps(ctx context.Context) (map[stri return toret, errors.Join(errs...) } - -func (cs *configMapSyncer) newJwksStoreConfigMap(name string) *corev1.ConfigMap { - return &corev1.ConfigMap{ - ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: cs.deploymentNamespace, - Labels: map[string]string{JwksStoreComponent: cs.storePrefix}, - }, - Data: make(map[string]string), - } -} - -type defaultCmAccessor struct { - cmCollection krt.Collection[*corev1.ConfigMap] - client apiclient.Client - deploymentNamespace string -} - -var _ cmAccessor = &defaultCmAccessor{} - -func (cma *defaultCmAccessor) Create(ctx context.Context, newCm *corev1.ConfigMap) error { - _, err := cma.client.Kube().CoreV1().ConfigMaps(cma.deploymentNamespace).Create(ctx, newCm, metav1.CreateOptions{}) - return err -} - -func (cma *defaultCmAccessor) Update(ctx context.Context, existingCm *corev1.ConfigMap) error { - _, err := cma.client.Kube().CoreV1().ConfigMaps(cma.deploymentNamespace).Update(ctx, existingCm, metav1.UpdateOptions{}) - return err -} - -func (cma *defaultCmAccessor) Delete(ctx context.Context, cmName string) error { - return cma.client.Kube().CoreV1().ConfigMaps(cma.deploymentNamespace).Delete(ctx, cmName, metav1.DeleteOptions{}) -} - -func (cma *defaultCmAccessor) Get(cmName string) *corev1.ConfigMap { - cmPtr := cma.cmCollection.GetKey(types.NamespacedName{Namespace: cma.deploymentNamespace, Name: 
cmName}.String()) - if cmPtr == nil { - return nil - } - return ptr.Flatten(cmPtr) -} - -func (cma *defaultCmAccessor) List() []*corev1.ConfigMap { - return cma.cmCollection.List() -} - -func (cma *defaultCmAccessor) WaitForCacheSync(ctx context.Context) bool { - return cma.client.Core().WaitForCacheSync("config_map_syncer/ConfigMaps", ctx.Done(), cma.cmCollection.HasSynced) -} diff --git a/pkg/agentgateway/jwks/jwks_cache.go b/pkg/agentgateway/jwks/jwks_cache.go index b6044c5ff99..252702da9d2 100644 --- a/pkg/agentgateway/jwks/jwks_cache.go +++ b/pkg/agentgateway/jwks/jwks_cache.go @@ -3,11 +3,13 @@ package jwks import ( "encoding/json" "errors" + "sync" "github.com/go-jose/go-jose/v4" ) type jwksCache struct { + l sync.Mutex jwks map[string]string // jwks uri -> jwks } @@ -29,26 +31,33 @@ func (c *jwksCache) LoadJwksFromStores(storedJwks map[string]string) error { errs = append(errs, err) continue } - newCache.compareAndAddJwks(uri, jwks) + newCache.addJwks(uri, jwks) } + c.l.Lock() c.jwks = newCache.jwks + c.l.Unlock() return errors.Join(errs...) } +func (c *jwksCache) GetJwks(uri string) (string, bool) { + c.l.Lock() + defer c.l.Unlock() + + jwks, ok := c.jwks[uri] + return jwks, ok +} + // Add a jwks to cache. If an exact same jwks is already present in the cache, the result is a nop. // TODO (dmitri-d) check for max size -func (c *jwksCache) compareAndAddJwks(uri string, jwks jose.JSONWebKeySet) (string, error) { +func (c *jwksCache) addJwks(uri string, jwks jose.JSONWebKeySet) (string, error) { serializedJwks, err := json.Marshal(jwks) if err != nil { return "", err } - if j, ok := c.jwks[uri]; ok { - if j == string(serializedJwks) { - return "", nil - } - } + c.l.Lock() + defer c.l.Unlock() c.jwks[uri] = string(serializedJwks) return c.jwks[uri], nil @@ -56,5 +65,7 @@ func (c *jwksCache) compareAndAddJwks(uri string, jwks jose.JSONWebKeySet) (stri // Remove jwks from cache. 
func (c *jwksCache) deleteJwks(uri string) { + c.l.Lock() delete(c.jwks, uri) + c.l.Unlock() } diff --git a/pkg/agentgateway/jwks/jwks_fetcher.go b/pkg/agentgateway/jwks/jwks_fetcher.go index 106a4573c6c..631894787e9 100644 --- a/pkg/agentgateway/jwks/jwks_fetcher.go +++ b/pkg/agentgateway/jwks/jwks_fetcher.go @@ -5,7 +5,6 @@ import ( "context" "crypto/tls" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -41,10 +40,13 @@ type JwksSource struct { Deleted bool } -type JwksSources []JwksSource +func (js JwksSource) ResourceName() string { + return js.JwksURL +} -func (js JwksSources) ResourceName() string { - return "jwkssources" +func (js JwksSource) Equals(other JwksSource) bool { + return js.JwksURL == other.JwksURL && + js.Ttl == other.Ttl && js.Deleted == other.Deleted } type fetchAt struct { @@ -111,7 +113,6 @@ func (f *JwksFetcher) Run(ctx context.Context) { } func (f *JwksFetcher) maybeFetchJwks(ctx context.Context) { - log := log.FromContext(ctx) updates := make(map[string]string) f.mu.Lock() @@ -128,9 +129,12 @@ func (f *JwksFetcher) maybeFetchJwks(ctx context.Context) { if fetch.keysetSource.Deleted { continue } + + logger.Debug("fetching remote jwks", "jwksUri", fetch.keysetSource.JwksURL) + jwks, err := f.jwksClient.FetchJwks(ctx, fetch.keysetSource.JwksURL) if err != nil { - log.Error(err, "error fetching jwks from ", fetch.keysetSource.JwksURL) + logger.Error("error fetching jwks", "jwksUri", fetch.keysetSource.JwksURL, "error", err) if fetch.retryAttempt < 5 { // backoff by 5s * retry attempt number heap.Push(&f.schedule, fetchAt{at: now.Add(time.Duration(5*(fetch.retryAttempt+1)) * time.Second), keysetSource: fetch.keysetSource, retryAttempt: fetch.retryAttempt + 1}) } else { @@ -140,18 +144,16 @@ func (f *JwksFetcher) maybeFetchJwks(ctx context.Context) { continue } - maybeUpdatedJwks, err := f.cache.compareAndAddJwks(fetch.keysetSource.JwksURL, jwks) + updatedJwks, err := f.cache.addJwks(fetch.keysetSource.JwksURL, jwks) // error 
serializing jwks, shouldn't happen, retry if err != nil { - log.Error(err, "error adding jwks", "uri", fetch.keysetSource.JwksURL) + logger.Error("error adding jwks", "jwksUri", fetch.keysetSource.JwksURL, "error", err) heap.Push(&f.schedule, fetchAt{at: now.Add(time.Duration(5*(fetch.retryAttempt+1)) * time.Second), keysetSource: fetch.keysetSource, retryAttempt: fetch.retryAttempt + 1}) continue } heap.Push(&f.schedule, fetchAt{at: now.Add(fetch.keysetSource.Ttl), keysetSource: fetch.keysetSource}) - if maybeUpdatedJwks != "" { - updates[fetch.keysetSource.JwksURL] = maybeUpdatedJwks - } + updates[fetch.keysetSource.JwksURL] = updatedJwks } if len(updates) > 0 { @@ -171,81 +173,39 @@ func (f *JwksFetcher) SubscribeToUpdates() chan map[string]string { return subscriber } -func (f *JwksFetcher) UpdateJwksSources(ctx context.Context, updates JwksSources) error { - var errs []error - maybeUpdates := make(map[string]JwksSource) - for _, s := range updates { - maybeUpdates[s.JwksURL] = s +func (f *JwksFetcher) AddOrUpdateKeyset(source JwksSource) error { + if _, err := url.Parse(source.JwksURL); err != nil { + return fmt.Errorf("error parsing jwks url %w", err) } f.mu.Lock() defer f.mu.Unlock() - todelete := make([]string, 0) - for s := range f.keysetSources { - if _, ok := maybeUpdates[s]; !ok { - todelete = append(todelete, s) - } + if existingKeysetSource, ok := f.keysetSources[source.JwksURL]; ok { + delete(f.keysetSources, source.JwksURL) + existingKeysetSource.Deleted = true } - for _, s := range updates { - if _, ok := f.keysetSources[s.JwksURL]; !ok { - if err := f.addKeyset(s.JwksURL, s.Ttl); err != nil { - errs = append(errs, err) - } - continue - } - if *f.keysetSources[s.JwksURL] != s { - if err := f.updateKeyset(s.JwksURL, s.Ttl); err != nil { - errs = append(errs, err) - } - } - } - - removals := make(map[string]string) - for _, jwksUri := range todelete { - if f.removeKeyset(jwksUri) { - removals[jwksUri] = "" - } - } - - if len(removals) > 0 { - for _, 
s := range f.subscribers { - s <- removals - } - } - - return errors.Join(errs...) -} - -func (f *JwksFetcher) addKeyset(jwksUrl string, ttl time.Duration) error { - if _, err := url.Parse(jwksUrl); err != nil { - return fmt.Errorf("error parsing jwks url %w", err) - } - - keysetSource := &JwksSource{JwksURL: jwksUrl, Ttl: ttl, Deleted: false} - f.keysetSources[jwksUrl] = keysetSource - heap.Push(&f.schedule, fetchAt{at: time.Now(), keysetSource: keysetSource}) // schedule an immediate fetch + addedKeysetSource := source + f.keysetSources[source.JwksURL] = &addedKeysetSource + heap.Push(&f.schedule, fetchAt{at: time.Now(), keysetSource: &addedKeysetSource}) // schedule an immediate fetch return nil } -func (f *JwksFetcher) removeKeyset(jwksUrl string) bool { - if keysetSource, ok := f.keysetSources[jwksUrl]; ok { - delete(f.keysetSources, jwksUrl) - f.cache.deleteJwks(jwksUrl) - keysetSource.Deleted = true - return true - } - return false -} +func (f *JwksFetcher) RemoveKeyset(source JwksSource) { + f.mu.Lock() + defer f.mu.Unlock() + + if beingDeleted, ok := f.keysetSources[source.JwksURL]; ok { + delete(f.keysetSources, source.JwksURL) + f.cache.deleteJwks(source.JwksURL) + beingDeleted.Deleted = true -func (f *JwksFetcher) updateKeyset(jwksUrl string, ttl time.Duration) error { - if keysetSource, ok := f.keysetSources[jwksUrl]; ok { - delete(f.keysetSources, jwksUrl) - keysetSource.Deleted = true + for _, s := range f.subscribers { + s <- map[string]string{source.JwksURL: ""} + } } - return f.addKeyset(jwksUrl, ttl) } func (c *jwksHttpClientImpl) FetchJwks(ctx context.Context, jwksURL string) (jose.JSONWebKeySet, error) { diff --git a/pkg/agentgateway/jwks/jwks_fetcher_test.go b/pkg/agentgateway/jwks/jwks_fetcher_test.go index 6933dd3b65a..e10bd4f41e4 100644 --- a/pkg/agentgateway/jwks/jwks_fetcher_test.go +++ b/pkg/agentgateway/jwks/jwks_fetcher_test.go @@ -14,10 +14,11 @@ import ( ) func TestAddKeysetToFetcher(t *testing.T) { + expectedKeysetSource := 
JwksSource{JwksURL: "https://test/jwks", Ttl: 5 * time.Minute, Deleted: false} + f := NewJwksFetcher(NewJwksCache()) - f.addKeyset("https://test/jwks", 5*time.Minute) + f.AddOrUpdateKeyset(expectedKeysetSource) - expectedKeysetSource := JwksSource{JwksURL: "https://test/jwks", Ttl: 5 * time.Minute, Deleted: false} fetch := f.schedule.Peek() assert.NotNil(t, fetch) assert.Equal(t, *fetch.keysetSource, expectedKeysetSource) @@ -28,12 +29,12 @@ func TestAddKeysetToFetcher(t *testing.T) { func TestRemoveKeysetFromFetcher(t *testing.T) { f := NewJwksFetcher(NewJwksCache()) - f.addKeyset("https://test/jwks", 5*time.Minute) + f.AddOrUpdateKeyset(JwksSource{JwksURL: "https://test/jwks", Ttl: 5 * time.Minute}) keysetSource := f.keysetSources["https://test/jwks"] assert.NotNil(t, keysetSource) f.cache.jwks["https://test/jwks"] = "jwks" - f.removeKeyset("https://test/jwks") + f.RemoveKeyset(JwksSource{JwksURL: "https://test/jwks"}) assert.NotContains(t, f.keysetSources, "https://test/jwks") assert.NotContains(t, f.cache.jwks, "https://test/jwks") assert.True(t, keysetSource.Deleted) @@ -64,7 +65,7 @@ func TestSuccessfulJwksFetch(t *testing.T) { jwksClient := mocks.NewMockJwksHttpClient(ctrl) f.jwksClient = jwksClient - f.addKeyset("https://test/jwks", 5*time.Minute) + f.AddOrUpdateKeyset(JwksSource{JwksURL: "https://test/jwks", Ttl: 5 * time.Minute}) updates := f.SubscribeToUpdates() expectedJwks := jose.JSONWebKeySet{} @@ -85,7 +86,7 @@ func TestSuccessfulJwksFetch(t *testing.T) { default: assert.Fail(c, "no updates") } - }, 1000*time.Second, 100*time.Millisecond) + }, 2*time.Second, 100*time.Millisecond) f.mu.Lock() defer f.mu.Unlock() @@ -95,8 +96,9 @@ func TestSuccessfulJwksFetch(t *testing.T) { assert.WithinDuration(t, time.Now().Add(5*time.Minute), fetch.at, 3*time.Second) } -// jwks were fetched, but there were no updates to keysets -func TestSuccessfulJwksFetchButNoUpdates(t *testing.T) { +// jwks were fetched, but there were no changes to keysets +// we still notify 
subscribers that a fetch happened (we always sync jwks to ConfigMaps) +func TestSuccessfulJwksFetchButNoChanges(t *testing.T) { ctx := t.Context() f := NewJwksFetcher(NewJwksCache()) @@ -104,7 +106,7 @@ func TestSuccessfulJwksFetchButNoUpdates(t *testing.T) { jwksClient := mocks.NewMockJwksHttpClient(ctrl) f.jwksClient = jwksClient - f.addKeyset("https://test/jwks", 5*time.Minute) + f.AddOrUpdateKeyset(JwksSource{JwksURL: "https://test/jwks", Ttl: 5 * time.Minute}) f.cache.jwks["https://test/jwks"] = jwks updates := f.SubscribeToUpdates() @@ -117,12 +119,14 @@ func TestSuccessfulJwksFetchButNoUpdates(t *testing.T) { Return(existingJwks, nil) go f.maybeFetchJwks(ctx) - assert.Never(t, func() bool { + assert.EventuallyWithT(t, func(c *assert.CollectT) { select { - case <-updates: - return true + case actual := <-updates: + cache := NewJwksCache() + assert.NoError(c, cache.LoadJwksFromStores(actual)) + assert.Equal(c, jwks, cache.jwks["https://test/jwks"]) default: - return false + assert.Fail(c, "no updates") } }, 2*time.Second, 100*time.Millisecond) @@ -142,7 +146,7 @@ func TestFetchJwksWithError(t *testing.T) { jwksClient := mocks.NewMockJwksHttpClient(ctrl) f.jwksClient = jwksClient - f.addKeyset("https://test/jwks", 5*time.Minute) + f.AddOrUpdateKeyset(JwksSource{JwksURL: "https://test/jwks", Ttl: 5 * time.Minute}) updates := f.SubscribeToUpdates() jwksClient.EXPECT(). 
diff --git a/pkg/agentgateway/jwks/jwks_store.go b/pkg/agentgateway/jwks/jwks_store.go index b602f90ef15..29dc1235ef1 100644 --- a/pkg/agentgateway/jwks/jwks_store.go +++ b/pkg/agentgateway/jwks/jwks_store.go @@ -2,15 +2,18 @@ package jwks import ( "context" + "sync" "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/log" "github.com/kgateway-dev/kgateway/v2/pkg/apiclient" "github.com/kgateway-dev/kgateway/v2/pkg/common" + "github.com/kgateway-dev/kgateway/v2/pkg/logging" "github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/collections" ) +var logger = logging.New("jwks_store") + const DefaultJwksStorePrefix = "jwks-store" const RunnableName = "jwks-store" @@ -22,25 +25,27 @@ var JwksConfigMapNamespacedName = func(jwksUri string) *types.NamespacedName { // in ConfigMaps, a jwks per ConfigMap. The ConfigMaps are used to re-create internal // JwksStore state on startup and by traffic-plugins as source of remote jwks. type JwksStore struct { + storePrefix string jwksCache *jwksCache jwksFetcher *JwksFetcher configMapSyncer *configMapSyncer - updates <-chan map[string]string - latestJwks <-chan JwksSources + jwksChanges <-chan JwksSource + cmNameToJwks map[string]string + l sync.Mutex } -func BuildJwksStore(ctx context.Context, cli apiclient.Client, commonCols *collections.CommonCollections, jwksQueue <-chan JwksSources, storePrefix, deploymentNamespace string) *JwksStore { - log := log.Log.WithName("jwks store setup") - log.Info("creating jwks store", "prefix", storePrefix) +func BuildJwksStore(ctx context.Context, cli apiclient.Client, commonCols *collections.CommonCollections, jwksChanges <-chan JwksSource, storePrefix, deploymentNamespace string) *JwksStore { + logger.Info("creating jwks store") jwksCache := NewJwksCache() jwksStore := &JwksStore{ + storePrefix: storePrefix, jwksCache: jwksCache, - latestJwks: jwksQueue, + jwksChanges: jwksChanges, jwksFetcher: NewJwksFetcher(jwksCache), configMapSyncer: NewConfigMapSyncer(cli, storePrefix, 
deploymentNamespace, commonCols.KrtOpts), + cmNameToJwks: make(map[string]string), } - jwksStore.updates = jwksStore.jwksFetcher.SubscribeToUpdates() BuildJwksConfigMapNamespacedNameFunc(storePrefix, deploymentNamespace) return jwksStore } @@ -52,20 +57,17 @@ func BuildJwksConfigMapNamespacedNameFunc(storePrefix, deploymentNamespace strin } func (s *JwksStore) Start(ctx context.Context) error { - log := log.FromContext(ctx) - - s.configMapSyncer.WaitForCacheSync(ctx) + logger.Info("starting jwks store") storedJwks, err := s.configMapSyncer.LoadJwksFromConfigMaps(ctx) if err != nil { - log.Error(err, "error loading jwks store from a ConfigMap") + logger.Error("error loading jwks store from a ConfigMap", "error", err) } err = s.jwksCache.LoadJwksFromStores(storedJwks) if err != nil { - log.Error(err, "error loading jwks store state") + logger.Error("error loading jwks store state", "error", err) } - go s.syncToConfigMaps(ctx) go s.jwksFetcher.Run(ctx) go s.updateJwksSources(ctx) @@ -73,30 +75,49 @@ func (s *JwksStore) Start(ctx context.Context) error { return nil } -func (s *JwksStore) updateJwksSources(ctx context.Context) { - for { - select { - case jwks := <-s.latestJwks: - s.jwksFetcher.UpdateJwksSources(ctx, jwks) - case <-ctx.Done(): - return - } - } +func (s *JwksStore) SubscribeToUpdates() chan map[string]string { + return s.jwksFetcher.SubscribeToUpdates() } -func (s *JwksStore) syncToConfigMaps(ctx context.Context) { - log := log.FromContext(ctx) +func (s *JwksStore) JwksByConfigMapName(cmName string) (string, string, bool) { + s.l.Lock() + defer s.l.Unlock() + uri, ok := s.cmNameToJwks[cmName] + if !ok { + return "", "", false + } + + jwks, ok := s.jwksCache.GetJwks(uri) + if !ok { + return "", "", false + } + + return uri, jwks, true +} + +func (s *JwksStore) updateJwksSources(ctx context.Context) { for { select { + case jwksUpdate := <-s.jwksChanges: + if jwksUpdate.Deleted { + logger.Debug("deleting keyset", "jwksUri", jwksUpdate.JwksURL, "ConfigMap", 
JwksConfigMapName(s.storePrefix, jwksUpdate.JwksURL)) + s.jwksFetcher.RemoveKeyset(jwksUpdate) + s.l.Lock() + delete(s.cmNameToJwks, JwksConfigMapName(s.storePrefix, jwksUpdate.JwksURL)) + s.l.Unlock() + } else { + logger.Debug("updating keyset", "jwksUri", jwksUpdate.JwksURL, "ConfigMap", JwksConfigMapName(s.storePrefix, jwksUpdate.JwksURL)) + err := s.jwksFetcher.AddOrUpdateKeyset(jwksUpdate) + if err != nil { + logger.Error("error adding/updating a jwks keyset", "error", err, "uri", jwksUpdate.JwksURL) + } + s.l.Lock() + s.cmNameToJwks[JwksConfigMapName(s.storePrefix, jwksUpdate.JwksURL)] = jwksUpdate.JwksURL + s.l.Unlock() + } case <-ctx.Done(): return - case update := <-s.updates: - log.Info("received an update") - err := s.configMapSyncer.WriteJwksToConfigMaps(ctx, update) - if err != nil { - log.Error(err, "error(s) syncing jwks cache to ConfigMaps") - } } } } diff --git a/pkg/agentgateway/jwksstore/cm_controller.go b/pkg/agentgateway/jwksstore/cm_controller.go new file mode 100644 index 00000000000..554b0ac2816 --- /dev/null +++ b/pkg/agentgateway/jwksstore/cm_controller.go @@ -0,0 +1,156 @@ +package agentjwksstore + +import ( + "context" + "math" + "time" + + "golang.org/x/time/rate" + "istio.io/istio/pkg/kube/controllers" + "istio.io/istio/pkg/kube/kclient" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kgateway-dev/kgateway/v2/pkg/agentgateway/jwks" + "github.com/kgateway-dev/kgateway/v2/pkg/apiclient" + "github.com/kgateway-dev/kgateway/v2/pkg/logging" +) + +var cmLogger = logging.New("jwks_store_config_map_controller") + +type JwksStoreConfigMapsController struct { + apiClient apiclient.Client + cmClient kclient.Client[*corev1.ConfigMap] + eventQueue controllers.Queue + jwksUpdates chan map[string]string + jwksStore *jwks.JwksStore + deploymentNamespace string 
+ storePrefix string + waitForSync []cache.InformerSynced +} + +var ( + rateLimiter = workqueue.NewTypedMaxOfRateLimiter( + workqueue.NewTypedItemExponentialFailureRateLimiter[any](500*time.Millisecond, 10*time.Second), + // 10 qps, 100 bucket size. This is only for retry speed and its only the overall factor (not per item) + &workqueue.TypedBucketRateLimiter[any]{Limiter: rate.NewLimiter(rate.Limit(10), 100)}, + ) +) + +func NewJWKSStoreConfigMapsController(apiClient apiclient.Client, storePrefix, deploymentNamespace string, jwksStore *jwks.JwksStore) *JwksStoreConfigMapsController { + cmLogger.Info("creating jwks store ConfigMap controller") + return &JwksStoreConfigMapsController{ + apiClient: apiClient, + deploymentNamespace: deploymentNamespace, + storePrefix: storePrefix, + jwksStore: jwksStore, + } +} + +func (jcm *JwksStoreConfigMapsController) Init(ctx context.Context) { + jcm.cmClient = kclient.NewFiltered[*corev1.ConfigMap](jcm.apiClient, + kclient.Filter{ + ObjectFilter: jcm.apiClient.ObjectFilter(), + Namespace: jcm.deploymentNamespace, + LabelSelector: jwks.JwksStoreLabelSelector(jcm.storePrefix)}) + + jcm.waitForSync = []cache.InformerSynced{ + jcm.cmClient.HasSynced, + } + + jcm.jwksUpdates = jcm.jwksStore.SubscribeToUpdates() + jcm.eventQueue = controllers.NewQueue("JwksStoreConfigMapController", controllers.WithReconciler(jcm.Reconcile), controllers.WithMaxAttempts(math.MaxInt), controllers.WithRateLimiter(rateLimiter)) +} + +func (jcm *JwksStoreConfigMapsController) Start(ctx context.Context) error { + cmLogger.Info("waiting for cache to sync") + jcm.apiClient.Core().WaitForCacheSync( + "kube jwks store ConfigMap syncer", + ctx.Done(), + jcm.waitForSync..., + ) + + cmLogger.Info("starting jwks store ConfigMap controller") + jcm.cmClient.AddEventHandler( + controllers.FromEventHandler( + func(o controllers.Event) { + jcm.eventQueue.AddObject(o.Latest()) + })) + + go func() { + for { + select { + case u := <-jcm.jwksUpdates: + for uri := range u { 
+ jcm.eventQueue.AddObject(jcm.newJwksStoreConfigMap(jwks.JwksConfigMapName(jcm.storePrefix, uri))) + } + case <-ctx.Done(): + return + } + } + }() + go jcm.eventQueue.Run(ctx.Done()) + + <-ctx.Done() + return nil +} + +func (jcm *JwksStoreConfigMapsController) Reconcile(req types.NamespacedName) error { + cmLogger.Debug("syncing jwks store to ConfigMap(s)") + ctx := context.Background() + + uri, storedJwks, ok := jcm.jwksStore.JwksByConfigMapName(req.Name) + if !ok { + cmLogger.Debug("deleting ConfigMap", "name", req.Name) + return client.IgnoreNotFound(jcm.apiClient.Kube().CoreV1().ConfigMaps(req.Namespace).Delete(ctx, req.Name, metav1.DeleteOptions{})) + } + + existingCm := jcm.cmClient.Get(req.Name, req.Namespace) + if existingCm == nil { + cmLogger.Debug("creating ConfigMap", "name", req.Name) + newCm := jcm.newJwksStoreConfigMap(jwks.JwksConfigMapName(jcm.storePrefix, uri)) + if err := jwks.SetJwksInConfigMap(newCm, uri, storedJwks); err != nil { + cmLogger.Error("error updating ConfigMap", "error", err) + return err // should we skip retries as json serialization error won't go away? + } + + _, err := jcm.apiClient.Kube().CoreV1().ConfigMaps(req.Namespace).Create(ctx, newCm, metav1.CreateOptions{}) + if err != nil { + cmLogger.Error("error creating ConfigMap", "error", err) + return err + } + } else { + cmLogger.Debug("updating ConfigMap", "name", req.Name) + if err := jwks.SetJwksInConfigMap(existingCm, uri, storedJwks); err != nil { + cmLogger.Error("error updating ConfigMap", "error", err) + return err // should we skip retries as json serialization error won't go away? 
+ } + _, err := jcm.apiClient.Kube().CoreV1().ConfigMaps(req.Namespace).Update(ctx, existingCm, metav1.UpdateOptions{}) + if err != nil { + cmLogger.Error("error updating jwks ConfigMap", "error", err) + return err + } + } + + return nil +} + +// runs on the leader only +func (jcm *JwksStoreConfigMapsController) NeedLeaderElection() bool { + return true +} + +func (jcm *JwksStoreConfigMapsController) newJwksStoreConfigMap(name string) *corev1.ConfigMap { + return &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: jcm.deploymentNamespace, + Labels: jwks.JwksStoreConfigMapLabel(jcm.storePrefix), + }, + Data: make(map[string]string), + } +} diff --git a/pkg/agentgateway/jwksstore/jwks_store_controller.go b/pkg/agentgateway/jwksstore/jwks_store_controller.go deleted file mode 100644 index c2aab903dd7..00000000000 --- a/pkg/agentgateway/jwksstore/jwks_store_controller.go +++ /dev/null @@ -1,134 +0,0 @@ -package agentjwksstore - -import ( - "context" - "time" - - "istio.io/istio/pkg/kube/kclient" - "istio.io/istio/pkg/kube/krt" - "k8s.io/client-go/tools/cache" - - "github.com/kgateway-dev/kgateway/v2/api/v1alpha1/agentgateway" - "github.com/kgateway-dev/kgateway/v2/pkg/agentgateway/jwks" - "github.com/kgateway-dev/kgateway/v2/pkg/agentgateway/plugins" - "github.com/kgateway-dev/kgateway/v2/pkg/apiclient" - "github.com/kgateway-dev/kgateway/v2/pkg/kgateway/wellknown" - "github.com/kgateway-dev/kgateway/v2/pkg/logging" -) - -const JwksStoreConfigMapName = "jwks-store" - -type JwksStoreController struct { - agw *plugins.AgwCollections - apiClient apiclient.Client - jwks krt.Singleton[jwks.JwksSources] - jwksQueue chan jwks.JwksSources - waitForSync []cache.InformerSynced -} - -var logger = logging.New("jwks_store") - -func NewJWKSStoreController(apiClient apiclient.Client, agw *plugins.AgwCollections) *JwksStoreController { - return &JwksStoreController{ - agw: agw, - apiClient: apiClient, - jwksQueue: make(chan jwks.JwksSources), - } -} - -func 
(j *JwksStoreController) Init(ctx context.Context) { - backendCol := krt.WrapClient(kclient.NewFilteredDelayed[*agentgateway.AgentgatewayBackend]( - j.apiClient, - wellknown.AgentgatewayBackendGVR, - kclient.Filter{ObjectFilter: j.agw.Client.ObjectFilter()}, - ), j.agw.KrtOpts.ToOptions("AgentgatewayBackend")...) - policyCol := krt.WrapClient(kclient.NewFilteredDelayed[*agentgateway.AgentgatewayPolicy]( - j.apiClient, - wellknown.AgentgatewayPolicyGVR, - kclient.Filter{ObjectFilter: j.agw.Client.ObjectFilter()}, - ), j.agw.KrtOpts.ToOptions("AgentgatewayPolicy")...) - j.jwks = krt.NewSingleton(func(kctx krt.HandlerContext) *jwks.JwksSources { - pols := krt.Fetch(kctx, policyCol) - toret := make(jwks.JwksSources, 0, len(pols)) - for _, p := range pols { - // enqueue Traffic JWT providers (if present) - if p.Spec.Traffic != nil && p.Spec.Traffic.JWTAuthentication != nil { - for _, provider := range p.Spec.Traffic.JWTAuthentication.Providers { - if provider.JWKS.Remote == nil { - continue - } - toret = append(toret, jwks.JwksSource{ - JwksURL: provider.JWKS.Remote.JwksUri, - Ttl: provider.JWKS.Remote.CacheDuration.Duration, - }) - } - } - - // enqueue Backend MCP authentication JWKS (if present) - if p.Spec.Backend != nil && p.Spec.Backend.MCP != nil && p.Spec.Backend.MCP.Authentication != nil { - ttl := time.Duration(0) - if p.Spec.Backend.MCP.Authentication.JWKS.CacheDuration != nil { - ttl = p.Spec.Backend.MCP.Authentication.JWKS.CacheDuration.Duration - } - if p.Spec.Backend.MCP.Authentication.JWKS.JwksUri != "" { - toret = append(toret, jwks.JwksSource{ - JwksURL: p.Spec.Backend.MCP.Authentication.JWKS.JwksUri, - Ttl: ttl, - }) - } - } - } - - backends := krt.Fetch(kctx, backendCol) - for _, b := range backends { - if b.Spec.MCP == nil { - // ignore non-mcp backend types - continue - } - if b.Spec.Policies != nil && b.Spec.Policies.MCP != nil && b.Spec.Policies.MCP.Authentication != nil { - ttl := time.Duration(0) - if 
b.Spec.Policies.MCP.Authentication.JWKS.CacheDuration != nil { - ttl = b.Spec.Policies.MCP.Authentication.JWKS.CacheDuration.Duration - } - if b.Spec.Policies.MCP.Authentication.JWKS.JwksUri != "" { - toret = append(toret, jwks.JwksSource{ - JwksURL: b.Spec.Policies.MCP.Authentication.JWKS.JwksUri, - Ttl: ttl, - }) - } - } - } - - return &toret - }, j.agw.KrtOpts.ToOptions("JwksSources")...) - - j.waitForSync = []cache.InformerSynced{ - policyCol.HasSynced, - backendCol.HasSynced, - } -} - -func (j *JwksStoreController) Start(ctx context.Context) error { - logger.Info("waiting for cache to sync") - j.apiClient.Core().WaitForCacheSync( - "kube AgentgatewayPolicy syncer", - ctx.Done(), - j.waitForSync..., - ) - - j.jwks.Register(func(o krt.Event[jwks.JwksSources]) { - j.jwksQueue <- o.Latest() - }) - - <-ctx.Done() - return nil -} - -// runs on the leader only -func (j *JwksStoreController) NeedLeaderElection() bool { - return true -} - -func (j *JwksStoreController) JwksQueue() <-chan jwks.JwksSources { - return j.jwksQueue -} diff --git a/pkg/agentgateway/jwksstore/policy_controller.go b/pkg/agentgateway/jwksstore/policy_controller.go new file mode 100644 index 00000000000..578257999c0 --- /dev/null +++ b/pkg/agentgateway/jwksstore/policy_controller.go @@ -0,0 +1,140 @@ +package agentjwksstore + +import ( + "context" + "time" + + "istio.io/istio/pkg/kube/controllers" + "istio.io/istio/pkg/kube/kclient" + "istio.io/istio/pkg/kube/krt" + "k8s.io/client-go/tools/cache" + + "github.com/kgateway-dev/kgateway/v2/api/v1alpha1/agentgateway" + "github.com/kgateway-dev/kgateway/v2/pkg/agentgateway/jwks" + "github.com/kgateway-dev/kgateway/v2/pkg/agentgateway/plugins" + "github.com/kgateway-dev/kgateway/v2/pkg/apiclient" + "github.com/kgateway-dev/kgateway/v2/pkg/kgateway/wellknown" + "github.com/kgateway-dev/kgateway/v2/pkg/logging" +) + +type JwksStorePolicyController struct { + agw *plugins.AgwCollections + apiClient apiclient.Client + jwks krt.Collection[jwks.JwksSource] 
+ jwksChanges chan jwks.JwksSource + waitForSync []cache.InformerSynced +} + +var polLogger = logging.New("jwks_store_policy_controller") + +func NewJWKSStorePolicyController(apiClient apiclient.Client, agw *plugins.AgwCollections) *JwksStorePolicyController { + polLogger.Info("creating jwks store policy controller") + return &JwksStorePolicyController{ + agw: agw, + apiClient: apiClient, + jwksChanges: make(chan jwks.JwksSource), + } +} + +func (j *JwksStorePolicyController) Init(ctx context.Context) { + backendCol := krt.WrapClient(kclient.NewFilteredDelayed[*agentgateway.AgentgatewayBackend]( + j.apiClient, + wellknown.AgentgatewayBackendGVR, + kclient.Filter{ObjectFilter: j.agw.Client.ObjectFilter()}, + ), j.agw.KrtOpts.ToOptions("AgentgatewayBackend")...) + policyCol := krt.WrapClient(kclient.NewFilteredDelayed[*agentgateway.AgentgatewayPolicy]( + j.apiClient, + wellknown.AgentgatewayPolicyGVR, + kclient.Filter{ObjectFilter: j.agw.Client.ObjectFilter()}, + ), j.agw.KrtOpts.ToOptions("AgentgatewayPolicy")...) 
+	// j.jwks maps every AgentgatewayPolicy to the remote JWKS sources it implies:
+	//   - one source per Traffic JWT provider that declares a remote JWKS,
+	//   - the policy's Backend MCP authentication JWKS, when its URI is non-empty,
+	//   - the MCP authentication JWKS of each fetched MCP AgentgatewayBackend,
+	//     when its URI is non-empty.
+	j.jwks = krt.NewManyCollection(policyCol, func(kctx krt.HandlerContext, p *agentgateway.AgentgatewayPolicy) []jwks.JwksSource {
+		// TTL used whenever a JWKS entry carries no CacheDuration.
+		const defaultTTL = 5 * time.Minute
+
+		toret := make([]jwks.JwksSource, 0)
+
+		// enqueue Traffic JWT providers (if present)
+		if p.Spec.Traffic != nil && p.Spec.Traffic.JWTAuthentication != nil {
+			for _, provider := range p.Spec.Traffic.JWTAuthentication.Providers {
+				if provider.JWKS.Remote == nil {
+					continue
+				}
+				// Guard against an unset CacheDuration the same way the MCP
+				// paths below do, instead of dereferencing it unconditionally.
+				ttl := defaultTTL
+				if provider.JWKS.Remote.CacheDuration != nil {
+					ttl = provider.JWKS.Remote.CacheDuration.Duration
+				}
+				toret = append(toret, jwks.JwksSource{
+					JwksURL: provider.JWKS.Remote.JwksUri,
+					Ttl:     ttl,
+				})
+			}
+		}
+
+		// enqueue Backend MCP authentication JWKS (if present)
+		if p.Spec.Backend != nil && p.Spec.Backend.MCP != nil && p.Spec.Backend.MCP.Authentication != nil {
+			ttl := defaultTTL
+			if p.Spec.Backend.MCP.Authentication.JWKS.CacheDuration != nil {
+				ttl = p.Spec.Backend.MCP.Authentication.JWKS.CacheDuration.Duration
+			}
+			if p.Spec.Backend.MCP.Authentication.JWKS.JwksUri != "" {
+				toret = append(toret, jwks.JwksSource{
+					JwksURL: p.Spec.Backend.MCP.Authentication.JWKS.JwksUri,
+					Ttl:     ttl,
+				})
+			}
+		}
+
+		// enqueue MCP authentication JWKS declared on AgentgatewayBackends.
+		// NOTE(review): this loop does not depend on p, so every policy emits
+		// the same backend-derived sources — confirm that is intended downstream.
+		backends := krt.Fetch(kctx, backendCol)
+		for _, b := range backends {
+			if b.Spec.MCP == nil {
+				// ignore non-mcp backend types
+				continue
+			}
+			if b.Spec.Policies != nil && b.Spec.Policies.MCP != nil && b.Spec.Policies.MCP.Authentication != nil {
+				ttl := defaultTTL
+				if b.Spec.Policies.MCP.Authentication.JWKS.CacheDuration != nil {
+					ttl = b.Spec.Policies.MCP.Authentication.JWKS.CacheDuration.Duration
+				}
+				if b.Spec.Policies.MCP.Authentication.JWKS.JwksUri != "" {
+					toret = append(toret, jwks.JwksSource{
+						JwksURL: b.Spec.Policies.MCP.Authentication.JWKS.JwksUri,
+						Ttl:     ttl,
+					})
+				}
+			}
+		}
+
+		return toret
+	}, j.agw.KrtOpts.ToOptions("JwksSources")...)
+ + j.waitForSync = []cache.InformerSynced{ + policyCol.HasSynced, + backendCol.HasSynced, + } +} + +func (j *JwksStorePolicyController) Start(ctx context.Context) error { + polLogger.Info("waiting for cache to sync") + j.apiClient.Core().WaitForCacheSync( + "kube AgentgatewayPolicy syncer", + ctx.Done(), + j.waitForSync..., + ) + + polLogger.Info("starting jwks store policy controller") + j.jwks.Register(func(o krt.Event[jwks.JwksSource]) { + switch o.Event { + case controllers.EventAdd, controllers.EventUpdate: + j.jwksChanges <- *o.New + case controllers.EventDelete: + deleted := *o.Old + deleted.Deleted = true + j.jwksChanges <- deleted + } + }) + + <-ctx.Done() + return nil +} + +// runs on the leader only +func (j *JwksStorePolicyController) NeedLeaderElection() bool { + return true +} + +func (j *JwksStorePolicyController) JwksChanges() chan jwks.JwksSource { + return j.jwksChanges +} diff --git a/pkg/kgateway/setup/setup.go b/pkg/kgateway/setup/setup.go index 09a19c4eefa..191fdce7010 100644 --- a/pkg/kgateway/setup/setup.go +++ b/pkg/kgateway/setup/setup.go @@ -557,15 +557,22 @@ func SetupLogging(levelStr string) { } func buildJwksStore(ctx context.Context, mgr manager.Manager, apiClient apiclient.Client, commonCollections *collections.CommonCollections, agwCollections *agwplugins.AgwCollections) error { - jwksStoreCtrl := agentjwksstore.NewJWKSStoreController(apiClient, agwCollections) - if err := mgr.Add(jwksStoreCtrl); err != nil { + jwksStorePolicyCtrl := agentjwksstore.NewJWKSStorePolicyController(apiClient, agwCollections) + if err := mgr.Add(jwksStorePolicyCtrl); err != nil { return err } - jwksStoreCtrl.Init(ctx) - jwksStore := jwks.BuildJwksStore(ctx, apiClient, commonCollections, jwksStoreCtrl.JwksQueue(), jwks.DefaultJwksStorePrefix, namespaces.GetPodNamespace()) + jwksStorePolicyCtrl.Init(ctx) + + jwksStore := jwks.BuildJwksStore(ctx, apiClient, commonCollections, jwksStorePolicyCtrl.JwksChanges(), jwks.DefaultJwksStorePrefix, 
namespaces.GetPodNamespace()) if err := mgr.Add(jwksStore); err != nil { return err } + jwksStoreCMCtrl := agentjwksstore.NewJWKSStoreConfigMapsController(apiClient, jwks.DefaultJwksStorePrefix, namespaces.GetPodNamespace(), jwksStore) + jwksStoreCMCtrl.Init(ctx) + if err := mgr.Add(jwksStoreCMCtrl); err != nil { + return err + } + return nil }