From f0fb6d797594d7e1a640988712b6c32fd0c143e5 Mon Sep 17 00:00:00 2001
From: Vadim Berezniker
Date: Mon, 20 Apr 2026 15:59:25 -0700
Subject: [PATCH] Add locality support to the xds control plane.

Kubernetes services may be annotated with `xds.lmwn.com/locality-preference` to indicate how traffic to that service should be load balanced, with the following options:

zone: clients are matched to backends using only the zone (`topology.kubernetes.io/zone`) label from the parent node. Zone-local endpoints are returned with a higher priority. If no zone-local endpoints are available, the client falls back to endpoints in other zones.

sub_zone: clients are matched to backends using both the sub_zone (`topology.kubernetes.io/rack` by default) and zone (`topology.kubernetes.io/zone`) labels. Endpoints are returned with three priority levels: rack-local endpoints with the highest priority, zone-local endpoints with the next highest priority, and all other endpoints last.

Services that are not annotated load balance across all backends.

endpoints.go is the component that processes incoming k8s endpoint information and translates it into xDS resources for use by clients. Prior to this change, it did a simple conversion and handed the result off to the go-control-plane code to do the rest. With the introduction of locality, different clients may receive different assignments. To that end, endpointsCache (in endpointscache.go) becomes the new "entry point" for clients subscribing to endpoint information. endpoints.go groups endpoints by locality (depending on how the service is configured) and stores the data in the endpointsCache. endpointsCache then takes the stored data and generates endpoint assignments tailored to the locality information requested by the client.

nodeLocalityStore (nodes.go) subscribes to node information from k8s and stores zone and rack information for each node. If node locality metadata changes, we re-process the endpoints via endpoints.go in case the endpoint locality has been affected.
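For illustration, a minimal client-go sketch (not part of this change; the service name, namespace, port, and label values are made up) of how a Service opts into sub_zone locality and what the corresponding Node labels look like:

    // Example only: annotate a Service for sub_zone locality and label the
    // Node its pods run on. The annotation key and node label keys match the
    // constants introduced in snapshot/locality.go; everything else is illustrative.
    package main

    import (
    	"context"

    	corev1 "k8s.io/api/core/v1"
    	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    	"k8s.io/client-go/kubernetes"
    )

    func annotateForLocality(ctx context.Context, client kubernetes.Interface) error {
    	// Opt the service into sub_zone (rack + zone) aware load balancing.
    	svc := &corev1.Service{
    		ObjectMeta: metav1.ObjectMeta{
    			Name:      "my-service",
    			Namespace: "default",
    			Annotations: map[string]string{
    				"xds.lmwn.com/locality-preference": "sub_zone", // or "zone"
    			},
    		},
    		Spec: corev1.ServiceSpec{
    			Ports: []corev1.ServicePort{{Name: "grpc", Port: 50051}},
    		},
    	}
    	if _, err := client.CoreV1().Services("default").Create(ctx, svc, metav1.CreateOptions{}); err != nil {
    		return err
    	}

    	// Nodes carry the zone and (by default) rack labels the control plane reads.
    	node := &corev1.Node{
    		ObjectMeta: metav1.ObjectMeta{
    			Name: "node-1",
    			Labels: map[string]string{
    				"topology.kubernetes.io/zone": "zone-a",
    				"topology.kubernetes.io/rack": "rack-1",
    			},
    		},
    	}
    	_, err := client.CoreV1().Nodes().Create(ctx, node, metav1.CreateOptions{})
    	return err
    }

The node label read as the sub-zone can be overridden with the new -sub-zone-label flag; `topology.kubernetes.io/rack` is only the default.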
--- internal/di/k8sxds.go | 4 +- internal/di/wire.go | 5 +- internal/di/wire_gen.go | 9 +- main.go | 5 +- snapshot/endpoints.go | 99 ++++++++++----- snapshot/endpointscache.go | 213 ++++++++++++++++++++++++++++++++ snapshot/endpointscache_test.go | 150 ++++++++++++++++++++++ snapshot/locality.go | 46 +++++++ snapshot/nodes.go | 112 +++++++++++++++++ snapshot/services.go | 6 + snapshot/snapshotter.go | 58 ++++++++- test/integration_test.go | 202 +++++++++++++++++++++++++++++- test/xds_test.go | 2 +- 13 files changed, 863 insertions(+), 48 deletions(-) create mode 100644 snapshot/endpointscache.go create mode 100644 snapshot/endpointscache_test.go create mode 100644 snapshot/locality.go create mode 100644 snapshot/nodes.go diff --git a/internal/di/k8sxds.go b/internal/di/k8sxds.go index 3241000..e7f6eb5 100644 --- a/internal/di/k8sxds.go +++ b/internal/di/k8sxds.go @@ -22,9 +22,9 @@ var K8sXdsSet = wire.NewSet( ProvideLRSServer, ) -func ProvideSnapshotter(ctx context.Context, k8sClient kubernetes.Interface) (*snapshot.Snapshotter, func()) { +func ProvideSnapshotter(ctx context.Context, k8sClient kubernetes.Interface, subZoneLabel snapshot.SubZoneLabel) (*snapshot.Snapshotter, func()) { stopCtx, stop := context.WithCancel(ctx) - snapshotter := snapshot.New(k8sClient) + snapshotter := snapshot.New(k8sClient, subZoneLabel) go func() { err := snapshotter.Start(stopCtx) diff --git a/internal/di/wire.go b/internal/di/wire.go index 257b8ab..3fa1ec9 100644 --- a/internal/di/wire.go +++ b/internal/di/wire.go @@ -6,6 +6,7 @@ import ( "context" "github.com/google/wire" "github.com/wongnai/xds/debug" + "github.com/wongnai/xds/snapshot" "google.golang.org/grpc" "k8s.io/client-go/kubernetes" ) @@ -25,7 +26,7 @@ type DevServer struct { GrpcServer *grpc.Server } -func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds) (Servers, func(), error) { +func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (Servers, func(), error) { wire.Build( KubernetesSet, GrpcSet, @@ -38,7 +39,7 @@ func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSec return Servers{}, nil, nil } -func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds) (TestServer, func(), error) { +func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (TestServer, func(), error) { wire.Build( GrpcSet, K8sXdsSet, diff --git a/internal/di/wire_gen.go b/internal/di/wire_gen.go index 7ef58be..2700167 100644 --- a/internal/di/wire_gen.go +++ b/internal/di/wire_gen.go @@ -9,13 +9,14 @@ package di import ( "context" "github.com/wongnai/xds/debug" + "github.com/wongnai/xds/snapshot" "google.golang.org/grpc" "k8s.io/client-go/kubernetes" ) // Injectors from wire.go: -func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds) (Servers, func(), error) { +func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (Servers, func(), error) { v := ProvideOtelGrpcServerOptions() server, cleanup := ProvideGrpcServer(v) config, err := ProvideClientConfig() @@ -34,7 +35,7 @@ func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSec cleanup() return Servers{}, nil, err } - snapshotter, cleanup2 := ProvideSnapshotter(ctx, kubernetesInterface) + snapshotter, cleanup2 := 
ProvideSnapshotter(ctx, kubernetesInterface, subZoneLabel) callbackFuncs := ProvideXdsLogger() serverServer, cleanup3 := ProvideXdsServer(ctx, snapshotter, callbackFuncs) sideEffectADSRegistered := ProvideSideEffectADSRegistered(server, serverServer) @@ -76,10 +77,10 @@ func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSec }, nil } -func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds) (TestServer, func(), error) { +func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (TestServer, func(), error) { v := ProvideGrpcTestOption() server, cleanup := ProvideGrpcServer(v) - snapshotter, cleanup2 := ProvideSnapshotter(ctx, kubeClient) + snapshotter, cleanup2 := ProvideSnapshotter(ctx, kubeClient, subZoneLabel) callbackFuncs := ProvideXdsLogger() serverServer, cleanup3 := ProvideXdsServer(ctx, snapshotter, callbackFuncs) sideEffectADSRegistered := ProvideSideEffectADSRegistered(server, serverServer) diff --git a/main.go b/main.go index 823e9b6..ca3ac1e 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "github.com/wongnai/xds/internal/di" "github.com/wongnai/xds/meter" + "github.com/wongnai/xds/snapshot" _ "k8s.io/client-go/plugin/pkg/client/auth/oidc" "k8s.io/klog/v2" ) @@ -19,13 +20,15 @@ func main() { var statsIntervalInSeconds int64 flag.CommandLine.Int64Var(&statsIntervalInSeconds, "statsinterval", 300, "stats update interval in seconds") + subZoneLabel := flag.String("sub-zone-label", snapshot.DefaultSubZoneLabel, + "Kubernetes node label read as sub-zone when a Service requests sub_zone locality") flag.Parse() ctx := context.Background() meter.InstallPromExporter() - servers, stop, err := di.InitializeServer(context.Background(), statsIntervalInSeconds) + servers, stop, err := di.InitializeServer(context.Background(), statsIntervalInSeconds, snapshot.SubZoneLabel(*subZoneLabel)) if err != nil { klog.Fatal(err) } diff --git a/snapshot/endpoints.go b/snapshot/endpoints.go index c26638b..dd31203 100644 --- a/snapshot/endpoints.go +++ b/snapshot/endpoints.go @@ -3,13 +3,14 @@ package snapshot import ( "context" "fmt" + "maps" + "slices" "sort" "github.com/ccoveille/go-safecast/v2" corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" "github.com/envoyproxy/go-control-plane/pkg/cache/types" - "github.com/envoyproxy/go-control-plane/pkg/cache/v3" "github.com/wongnai/xds/meter" "go.opentelemetry.io/otel/metric" "google.golang.org/protobuf/types/known/wrapperspb" @@ -22,15 +23,15 @@ import ( ) type endpointCacheItem struct { - version string - resources []types.Resource + version string + nodeVersion uint64 + mode LocalityMode + resources []types.Resource } func (s *Snapshotter) startEndpoints(ctx context.Context) error { - emit := func() {} - store := k8scache.NewUndeltaStore(func(v []interface{}) { - emit() + s.triggerEndpointsEmit() }, k8scache.DeletionHandlingMetaNamespaceKeyFunc) reflector := k8scache.NewReflector(&k8scache.ListWatch{ @@ -44,7 +45,7 @@ func (s *Snapshotter) startEndpoints(ctx context.Context) error { var lastSnapshotHash uint64 - emit = func() { + emit := func() { version := reflector.LastSyncResourceVersion() s.kubeEventCounter.Add(ctx, 1, metric.WithAttributes(meter.ResourceAttrKey.String("endpoints"))) @@ -53,7 +54,7 @@ func (s *Snapshotter) startEndpoints(ctx 
context.Context) error { hash, err := resourcesHash(endpointsResources) if err == nil { if hash == lastSnapshotHash { - klog.V(5).Info("new snapshot is equivalent to the previous one") + klog.V(5).Info("new endpoints snapshot is equivalent to the previous one") return } lastSnapshotHash = hash @@ -64,13 +65,11 @@ func (s *Snapshotter) startEndpoints(ctx context.Context) error { resourcesByType := resourcesToMap(endpointsResources) s.setEndpointResourcesByType(resourcesByType) - snapshot, err := cache.NewSnapshot(version, resourcesByType) - if err != nil { - panic(err) - } - - s.endpointsCache.SetSnapshot(ctx, "", snapshot) + s.endpointsCache.setResources(ctx, version, resourcesByType) } + s.emitMu.Lock() + s.emitEndpointsFn = emit + s.emitMu.Unlock() reflector.Run(ctx.Done()) return nil @@ -101,7 +100,14 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso klog.Errorf("fail to get object key: %s", err) return nil } - if val, ok := s.endpointResourceCache[name]; ok && val.version == ep.ResourceVersion { + + mode := s.localityModeFor(ep.Namespace, ep.Name) + nodeVersion := s.nodeLocality.getVersion() + + if val, ok := s.endpointResourceCache[name]; ok && + val.version == ep.ResourceVersion && + val.nodeVersion == nodeVersion && + val.mode == mode { return val.resources } @@ -109,32 +115,26 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso for _, subset := range ep.Subsets { for _, port := range subset.Ports { - var portName string + var clusterName string if port.Name == "" { - portName = fmt.Sprintf("%s.%s:%d", ep.Name, ep.Namespace, port.Port) + clusterName = fmt.Sprintf("%s.%s:%d", ep.Name, ep.Namespace, port.Port) } else { - portName = fmt.Sprintf("%s.%s:%s", ep.Name, ep.Namespace, port.Name) + clusterName = fmt.Sprintf("%s.%s:%s", ep.Name, ep.Namespace, port.Name) } cla := &endpointv3.ClusterLoadAssignment{ - ClusterName: portName, - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LoadBalancingWeight: wrapperspb.UInt32(1), - Locality: &corev3.Locality{}, - LbEndpoints: []*endpointv3.LbEndpoint{}, - }, - }, + ClusterName: clusterName, } out = append(out, cla) sortedAddresses := subset.Addresses sort.SliceStable(sortedAddresses, func(i, j int) bool { - l := sortedAddresses[i].IP - r := sortedAddresses[j].IP - return l < r + return sortedAddresses[i].IP < sortedAddresses[j].IP }) + portU32 := safecast.MustConvert[uint32](port.Port) + groups := map[string]*endpointv3.LocalityLbEndpoints{} + for _, addr := range sortedAddresses { hostname := addr.Hostname if hostname == "" && addr.TargetRef != nil { @@ -143,9 +143,18 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso if hostname == "" && addr.NodeName != nil { hostname = *addr.NodeName } - portU32 := safecast.MustConvert[uint32](port.Port) - cla.Endpoints[0].LbEndpoints = append(cla.Endpoints[0].LbEndpoints, &endpointv3.LbEndpoint{ + loc := s.localityForAddress(addr, mode) + key := loc.GetZone() + localityKeySep + loc.GetSubZone() + g, ok := groups[key] + if !ok { + g = &endpointv3.LocalityLbEndpoints{ + Locality: loc, + LoadBalancingWeight: wrapperspb.UInt32(1), + } + groups[key] = g + } + g.LbEndpoints = append(g.LbEndpoints, &endpointv3.LbEndpoint{ HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ Endpoint: &endpointv3.Endpoint{ Address: &corev3.Address{ @@ -164,13 +173,37 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso }, }) } + + // Emit in sorted locality-key order so hash is stable. 
+ for _, k := range slices.Sorted(maps.Keys(groups)) { + cla.Endpoints = append(cla.Endpoints, groups[k]) + } } } s.endpointResourceCache[name] = endpointCacheItem{ - version: ep.ResourceVersion, - resources: out, + version: ep.ResourceVersion, + nodeVersion: nodeVersion, + mode: mode, + resources: out, } return out } + +// localityForAddress builds the Locality for a single endpoint +// address according to the service's locality mode. +func (s *Snapshotter) localityForAddress(addr corev1.EndpointAddress, mode LocalityMode) *corev3.Locality { + if mode == LocalityNone || addr.NodeName == nil { + return &corev3.Locality{} + } + info := s.nodeLocality.get(*addr.NodeName) + switch mode { + case LocalityZone: + return &corev3.Locality{Zone: info.zone} + case LocalitySubZone: + return &corev3.Locality{Zone: info.zone, SubZone: info.subZone} + default: + return &corev3.Locality{} + } +} diff --git a/snapshot/endpointscache.go b/snapshot/endpointscache.go new file mode 100644 index 0000000..4690d5c --- /dev/null +++ b/snapshot/endpointscache.go @@ -0,0 +1,213 @@ +package snapshot + +import ( + "context" + "slices" + "strings" + "sync" + + "github.com/envoyproxy/go-control-plane/pkg/cache/types" + "github.com/envoyproxy/go-control-plane/pkg/cache/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "google.golang.org/protobuf/proto" + "k8s.io/klog/v2" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" +) + +const localityKeySep = "|" + +// localityNodeHash partitions xDS clients by (zone, sub_zone). +type localityNodeHash struct{} + +func (localityNodeHash) ID(node *corev3.Node) string { + zone, subZone := nodeLocalityFromXds(node) + return zone + localityKeySep + subZone +} + +func nodeLocalityFromXds(node *corev3.Node) (zone, subZone string) { + if node != nil && node.GetLocality() != nil { + zone = node.GetLocality().GetZone() + subZone = node.GetLocality().GetSubZone() + } + return +} + +func splitLocalityKey(k string) (zone, subZone string) { + if idx := strings.Index(k, localityKeySep); idx >= 0 { + return k[:idx], k[idx+1:] + } + return k, "" +} + +// endpointsCache wraps a SnapshotCache to produce per-client-locality +// endpoint responses. It stores the current endpoint resources and, when +// a client issues a watch from a previously-unseen locality, synthesizes +// a snapshot for that locality on demand. +type endpointsCache struct { + inner cache.SnapshotCache + + mu sync.RWMutex + version string + resourcesByType map[string][]types.Resource +} + +func newEndpointsCache() *endpointsCache { + return &endpointsCache{ + inner: cache.NewSnapshotCache(false, localityNodeHash{}, Logger), + } +} + +// setResources replaces the cached endpoint resources and pushes a new +// snapshot for every locality that already has an active watch, plus the +// default (no-locality) key. 
+func (c *endpointsCache) setResources(ctx context.Context, version string, resourcesByType map[string][]types.Resource) { + c.mu.Lock() + c.version = version + c.resourcesByType = resourcesByType + c.mu.Unlock() + + keys := c.inner.GetStatusKeys() + if !slices.Contains(keys, "") { + keys = append(keys, "") + } + for _, k := range keys { + c.setSnapshotForKey(ctx, k, version, resourcesByType) + } +} + +func (c *endpointsCache) setSnapshotForKey(ctx context.Context, key, version string, resourcesByType map[string][]types.Resource) { + zone, subZone := splitLocalityKey(key) + filtered := resourcesForLocality(zone, subZone, resourcesByType) + snap, err := cache.NewSnapshot(version, filtered) + if err != nil { + klog.Errorf("fail to create endpoint snapshot for locality %q: %s", key, err) + return + } + if err := c.inner.SetSnapshot(ctx, key, snap); err != nil { + klog.Errorf("fail to set endpoint snapshot for locality %q: %s", key, err) + } +} + +// ensureSnapshotForNode makes sure a snapshot exists for the node's locality +// before the inner cache's CreateWatch runs. +func (c *endpointsCache) ensureSnapshotForNode(ctx context.Context, node *corev3.Node) { + key := localityNodeHash{}.ID(node) + if _, err := c.inner.GetSnapshot(key); err == nil { + return + } + c.mu.RLock() + version := c.version + resourcesByType := c.resourcesByType + c.mu.RUnlock() + if resourcesByType == nil { + return + } + c.setSnapshotForKey(ctx, key, version, resourcesByType) +} + +// resourcesForLocality returns a copy of resourcesByType where each +// ClusterLoadAssignment's LocalityLbEndpoints are re-prioritized relative +// to the client at (clientZone, clientSubZone). Non-endpoint types pass +// through untouched. Priorities are compacted so they always start at 0 +// and are contiguous (gRPC's priority LB requires that). +func resourcesForLocality(clientZone, clientSubZone string, resourcesByType map[string][]types.Resource) map[string][]types.Resource { + eps, ok := resourcesByType[resource.EndpointType] + if !ok { + return resourcesByType + } + out := make(map[string][]types.Resource, len(resourcesByType)) + for k, v := range resourcesByType { + if k != resource.EndpointType { + out[k] = v + } + } + newEps := make([]types.Resource, 0, len(eps)) + for _, r := range eps { + cla, ok := r.(*endpointv3.ClusterLoadAssignment) + if !ok { + newEps = append(newEps, r) + continue + } + newEps = append(newEps, assignmentWithPriorities(cla, clientZone, clientSubZone)) + } + out[resource.EndpointType] = newEps + return out +} + +// assignmentWithPriorities clones the input and assigns a priority to each +// LocalityLbEndpoints group based on how closely it matches the client's +// locality: +// +// score 2: group.zone == client.zone && group.sub_zone == client.sub_zone +// score 1: group.zone == client.zone +// score 0: no match (or client has no locality info) +// +// Higher score -> lower (more preferred) priority. Empty buckets are +// skipped so priorities stay contiguous. 
+func assignmentWithPriorities(in *endpointv3.ClusterLoadAssignment, clientZone, clientSubZone string) *endpointv3.ClusterLoadAssignment { + out := proto.Clone(in).(*endpointv3.ClusterLoadAssignment) + if len(out.Endpoints) == 0 { + return out + } + // score -> groups at that score + bySort := map[int][]*endpointv3.LocalityLbEndpoints{} + scores := make([]int, 0, 3) + for _, g := range out.Endpoints { + s := matchScore(g.GetLocality(), clientZone, clientSubZone) + if _, seen := bySort[s]; !seen { + scores = append(scores, s) + } + bySort[s] = append(bySort[s], g) + } + // Sort the scores in reverse order. + slices.Sort(scores) + slices.Reverse(scores) + + // xDS priorities, on the other hand, start from 0 and go up. + // Clients will first try to fill backends with priority 0 before + // spilling to backends with a higher priority value. + priority := uint32(0) + regrouped := make([]*endpointv3.LocalityLbEndpoints, 0, len(out.Endpoints)) + for _, s := range scores { + for _, g := range bySort[s] { + g.Priority = priority + regrouped = append(regrouped, g) + } + priority++ + } + out.Endpoints = regrouped + return out +} + +func matchScore(loc *corev3.Locality, clientZone, clientSubZone string) int { + if loc == nil || clientZone == "" { + return 0 + } + if loc.GetZone() == "" || loc.GetZone() != clientZone { + return 0 + } + if loc.GetSubZone() != "" && clientSubZone != "" && loc.GetSubZone() == clientSubZone { + return 2 + } + return 1 +} + +// CreateWatch implements cache.ConfigWatcher. +func (c *endpointsCache) CreateWatch(request *cache.Request, sub cache.Subscription, value chan cache.Response) (func(), error) { + c.ensureSnapshotForNode(context.Background(), request.GetNode()) + return c.inner.CreateWatch(request, sub, value) +} + +// CreateDeltaWatch implements cache.ConfigWatcher. +func (c *endpointsCache) CreateDeltaWatch(request *cache.DeltaRequest, sub cache.Subscription, value chan cache.DeltaResponse) (func(), error) { + c.ensureSnapshotForNode(context.Background(), request.GetNode()) + return c.inner.CreateDeltaWatch(request, sub, value) +} + +// Fetch implements cache.ConfigFetcher. +func (c *endpointsCache) Fetch(ctx context.Context, request *cache.Request) (cache.Response, error) { + c.ensureSnapshotForNode(ctx, request.GetNode()) + return c.inner.Fetch(ctx, request) +} diff --git a/snapshot/endpointscache_test.go b/snapshot/endpointscache_test.go new file mode 100644 index 0000000..9891c13 --- /dev/null +++ b/snapshot/endpointscache_test.go @@ -0,0 +1,150 @@ +package snapshot + +import ( + "testing" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + "github.com/envoyproxy/go-control-plane/pkg/cache/types" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" +) + +func TestClaWithPrioritiesZone(t *testing.T) { + cla := &endpointv3.ClusterLoadAssignment{ + ClusterName: "svc", + Endpoints: []*endpointv3.LocalityLbEndpoints{ + {Locality: &corev3.Locality{Zone: "zone-a"}}, + {Locality: &corev3.Locality{Zone: "zone-b"}}, + {Locality: &corev3.Locality{Zone: "zone-c"}}, + }, + } + + out := assignmentWithPriorities(cla, "zone-b", "") + + if got := len(out.Endpoints); got != 3 { + t.Fatalf("expected 3 groups, got %d", got) + } + // zone-b should be priority 0; everyone else priority 1, and must be + // contiguous. 
+ seenPri := map[uint32]string{} + for _, g := range out.Endpoints { + seenPri[g.Priority] = g.GetLocality().GetZone() + } + if seenPri[0] != "zone-b" { + t.Errorf("priority 0 = %q, want zone-b", seenPri[0]) + } + foundP1 := 0 + for _, g := range out.Endpoints { + if g.Priority == 1 { + foundP1++ + } + } + if foundP1 != 2 { + t.Errorf("priority 1 count = %d, want 2", foundP1) + } + // Must be contiguous. No priority 2. + for _, g := range out.Endpoints { + if g.Priority > 1 { + t.Errorf("unexpected priority %d (priorities must be contiguous)", g.Priority) + } + } +} + +func TestClaWithPrioritiesSubZone(t *testing.T) { + cla := &endpointv3.ClusterLoadAssignment{ + ClusterName: "svc", + Endpoints: []*endpointv3.LocalityLbEndpoints{ + {Locality: &corev3.Locality{Zone: "z1", SubZone: "r1"}}, // exact + {Locality: &corev3.Locality{Zone: "z1", SubZone: "r2"}}, // same zone + {Locality: &corev3.Locality{Zone: "z2", SubZone: "r1"}}, // diff zone + }, + } + + out := assignmentWithPriorities(cla, "z1", "r1") + + pri := map[string]uint32{} + for _, g := range out.Endpoints { + loc := g.GetLocality() + pri[loc.Zone+"/"+loc.SubZone] = g.Priority + } + if pri["z1/r1"] != 0 { + t.Errorf("exact match priority = %d, want 0", pri["z1/r1"]) + } + if pri["z1/r2"] != 1 { + t.Errorf("same-zone priority = %d, want 1", pri["z1/r2"]) + } + if pri["z2/r1"] != 2 { + t.Errorf("diff-zone priority = %d, want 2", pri["z2/r1"]) + } +} + +func TestClaWithPrioritiesCompaction(t *testing.T) { + // Only same-zone and diff-zone exist (no exact sub_zone match). The + // output must start at priority 0 and stay contiguous. + cla := &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + {Locality: &corev3.Locality{Zone: "z1", SubZone: "r2"}}, + {Locality: &corev3.Locality{Zone: "z2", SubZone: "r1"}}, + }, + } + + out := assignmentWithPriorities(cla, "z1", "r1") + + maxPri := uint32(0) + sawZero := false + for _, g := range out.Endpoints { + if g.Priority == 0 { + sawZero = true + } + if g.Priority > maxPri { + maxPri = g.Priority + } + } + if !sawZero { + t.Error("missing priority 0") + } + if maxPri != 1 { + t.Errorf("max priority = %d, want 1 (contiguous)", maxPri) + } +} + +func TestClaWithPrioritiesNoClientLocality(t *testing.T) { + // Client has no locality info. All groups get the same priority (0) + // so no preference is expressed. 
+ cla := &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + {Locality: &corev3.Locality{Zone: "z1"}}, + {Locality: &corev3.Locality{Zone: "z2"}}, + }, + } + + out := assignmentWithPriorities(cla, "", "") + + for _, g := range out.Endpoints { + if g.Priority != 0 { + t.Errorf("got priority %d with empty client locality; all should be 0", g.Priority) + } + } +} + +func TestResourcesForLocalityPassesThroughNonEndpoints(t *testing.T) { + in := map[string][]types.Resource{ + resource.ClusterType: {&endpointv3.ClusterLoadAssignment{ClusterName: "x"}}, // wrong type on purpose + } + out := resourcesForLocality("z", "", in) + if _, ok := out[resource.EndpointType]; ok { + t.Error("unexpected endpoint resources added") + } + if len(out[resource.ClusterType]) != 1 { + t.Error("non-endpoint resources should pass through unchanged") + } +} + +func TestLocalityNodeHashIncludesSubZone(t *testing.T) { + h := localityNodeHash{} + a := h.ID(&corev3.Node{Locality: &corev3.Locality{Zone: "z1", SubZone: "r1"}}) + b := h.ID(&corev3.Node{Locality: &corev3.Locality{Zone: "z1", SubZone: "r2"}}) + if a == b { + t.Errorf("expected distinct hashes for different sub_zone; both are %q", a) + } +} diff --git a/snapshot/locality.go b/snapshot/locality.go new file mode 100644 index 0000000..8f85299 --- /dev/null +++ b/snapshot/locality.go @@ -0,0 +1,46 @@ +package snapshot + +import ( + corev1 "k8s.io/api/core/v1" +) + +const ( + // AnnotationLocalityPreference on a Service selects how endpoints are + // split into localities. Values: + // "zone" - split by topology.kubernetes.io/zone + // "sub_zone" - split by zone + the sub-zone label + // "" - no split; a single empty-locality group (default) + AnnotationLocalityPreference = "xds.lmwn.com/locality-preference" + + // LabelZone is the well-known Kubernetes node zone label. + LabelZone = "topology.kubernetes.io/zone" + + // DefaultSubZoneLabel is the default node label we read as sub-zone when + // no override is provided. There is no Kubernetes standard for sub-zone, + // so we default to the common "rack" convention that mirrors the + // topology.kubernetes.io/zone shape. + DefaultSubZoneLabel = "topology.kubernetes.io/rack" +) + +// LocalityMode describes the locality split for a single service. +type LocalityMode int + +const ( + LocalityNone LocalityMode = iota + LocalityZone + LocalitySubZone +) + +func localityModeFromService(svc *corev1.Service) LocalityMode { + if svc == nil { + return LocalityNone + } + switch svc.GetAnnotations()[AnnotationLocalityPreference] { + case "zone": + return LocalityZone + case "sub_zone": + return LocalitySubZone + default: + return LocalityNone + } +} diff --git a/snapshot/nodes.go b/snapshot/nodes.go new file mode 100644 index 0000000..54e6f03 --- /dev/null +++ b/snapshot/nodes.go @@ -0,0 +1,112 @@ +package snapshot + +import ( + "context" + "sync" + + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/klog/v2" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8scache "k8s.io/client-go/tools/cache" +) + +type nodeLocality struct { + zone string + subZone string +} + +// nodeLocalityStore holds a snapshot of node to {zone, sub-zone} mapping. +// The endpoints builder consults this store when assigning locality to each +// LbEndpoint. +// Version is bumped on every observed change so callers can invalidate +// per-endpoint caches. 
+type nodeLocalityStore struct { + subZoneLabel string + + mu sync.RWMutex + nodes map[string]nodeLocality + version uint64 +} + +func newNodeLocalityStore(subZoneLabel string) *nodeLocalityStore { + if subZoneLabel == "" { + subZoneLabel = DefaultSubZoneLabel + } + return &nodeLocalityStore{ + subZoneLabel: subZoneLabel, + nodes: map[string]nodeLocality{}, + } +} + +func (s *nodeLocalityStore) get(name string) nodeLocality { + s.mu.RLock() + defer s.mu.RUnlock() + return s.nodes[name] +} + +func (s *nodeLocalityStore) getVersion() uint64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.version +} + +// apply replaces the store with the nodes. Returns true if anything changed. +func (s *nodeLocalityStore) apply(nodes []*corev1.Node) bool { + next := make(map[string]nodeLocality, len(nodes)) + for _, n := range nodes { + next[n.Name] = nodeLocality{ + zone: n.Labels[LabelZone], + subZone: n.Labels[s.subZoneLabel], + } + } + s.mu.Lock() + defer s.mu.Unlock() + if nodeMapsEqual(s.nodes, next) { + return false + } + s.nodes = next + s.version++ + return true +} + +func nodeMapsEqual(a, b map[string]nodeLocality) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if b[k] != v { + return false + } + } + return true +} + +func (s *Snapshotter) startNodes(ctx context.Context) error { + store := k8scache.NewUndeltaStore(func(v []interface{}) { + nodes := make([]*corev1.Node, 0, len(v)) + for _, obj := range v { + if n, ok := obj.(*corev1.Node); ok { + nodes = append(nodes, n) + } + } + if s.nodeLocality.apply(nodes) { + klog.V(2).Infof("node locality updated to version %d (%d nodes)", s.nodeLocality.getVersion(), len(nodes)) + s.triggerEndpointsEmit() + } + }, k8scache.DeletionHandlingMetaNamespaceKeyFunc) + + reflector := k8scache.NewReflector(&k8scache.ListWatch{ + ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) { + return s.client.CoreV1().Nodes().List(ctx, options) + }, + WatchFuncWithContext: func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) { + return s.client.CoreV1().Nodes().Watch(ctx, options) + }, + }, &corev1.Node{}, store, s.ResyncPeriod) + + reflector.Run(ctx.Done()) + return nil +} diff --git a/snapshot/services.go b/snapshot/services.go index 24b1a84..4561e05 100644 --- a/snapshot/services.go +++ b/snapshot/services.go @@ -35,7 +35,13 @@ func (s *Snapshotter) startServices(ctx context.Context) error { store := k8scache.NewUndeltaStore(func(v []interface{}) { emit() + // A service's locality-preference annotation affects how its + // endpoints are split into LocalityLbEndpoints, so any service + // change requires re-emitting endpoints. The endpoint emit path + // dedups via its own hash. 
+ s.triggerEndpointsEmit() }, k8scache.DeletionHandlingMetaNamespaceKeyFunc) + s.serviceStore = store reflector := k8scache.NewReflector(&k8scache.ListWatch{ ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) { diff --git a/snapshot/snapshotter.go b/snapshot/snapshotter.go index 9332c51..6aa5ffb 100644 --- a/snapshot/snapshotter.go +++ b/snapshot/snapshotter.go @@ -12,7 +12,9 @@ import ( "github.com/wongnai/xds/meter" "go.opentelemetry.io/otel/metric" "golang.org/x/sync/errgroup" + corev1 "k8s.io/api/core/v1" "k8s.io/client-go/kubernetes" + k8scache "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" ) @@ -47,20 +49,34 @@ type Snapshotter struct { client kubernetes.Interface servicesCache cache.SnapshotCache - endpointsCache cache.SnapshotCache + endpointsCache *endpointsCache muxCache cache.MuxCache - endpointResourceCache map[string]endpointCacheItem + nodeLocality *nodeLocalityStore + + endpointResourceCache map[string]endpointCacheItem + resourcesByTypeLock sync.RWMutex serviceResourcesByType map[string][]types.Resource endpointResourcesByType map[string][]types.Resource apiGatewayStats map[string]int kubeEventCounter metric.Int64Counter + + // serviceStore is captured at informer start so other code paths + // (e.g. per-endpoint locality lookups) can consult Service objects. + serviceStore k8scache.Store + + emitMu sync.Mutex + emitEndpointsFn func() } -func New(client kubernetes.Interface) *Snapshotter { +// SubZoneLabel selects the Kubernetes node label used as sub-zone for the +// sub_zone-preference locality mode. Empty = use DefaultSubZoneLabel. +type SubZoneLabel string + +func New(client kubernetes.Interface, subZoneLabel SubZoneLabel) *Snapshotter { servicesCache := cache.NewSnapshotCache(false, EmptyNodeID{}, Logger) - endpointsCache := cache.NewSnapshotCache(false, EmptyNodeID{}, Logger) + endpointsCache := newEndpointsCache() muxCache := cache.MuxCache{ Classify: func(r *cache.Request) string { return mapTypeURL(r.TypeUrl) @@ -82,6 +98,8 @@ func New(client kubernetes.Interface) *Snapshotter { endpointsCache: endpointsCache, muxCache: muxCache, + nodeLocality: newNodeLocalityStore(string(subZoneLabel)), + endpointResourceCache: map[string]endpointCacheItem{}, } @@ -105,9 +123,41 @@ func (s *Snapshotter) Start(stopCtx context.Context) error { group.Go(func() error { return s.startEndpoints(groupCtx) }) + group.Go(func() error { + return s.startNodes(groupCtx) + }) return group.Wait() } +// triggerEndpointsEmit runs the endpoints emit pipeline under a mutex so +// concurrent triggers from multiple informers don't race on the shared +// endpointResourceCache and the emit closure's snapshot-hash state. +func (s *Snapshotter) triggerEndpointsEmit() { + s.emitMu.Lock() + defer s.emitMu.Unlock() + if s.emitEndpointsFn != nil { + s.emitEndpointsFn() + } +} + +// localityModeFor reads the Service object behind a namespace/name and +// returns its configured locality mode. Returns LocalityNone if the +// service is not yet known. 
+func (s *Snapshotter) localityModeFor(namespace, name string) LocalityMode { + if s.serviceStore == nil { + return LocalityNone + } + obj, exists, err := s.serviceStore.GetByKey(namespace + "/" + name) + if err != nil || !exists { + return LocalityNone + } + svc, ok := obj.(*corev1.Service) + if !ok { + return LocalityNone + } + return localityModeFromService(svc) +} + func (s *Snapshotter) snapshotResourceGaugeCallback(_ context.Context, result metric.Int64Observer) error { for k, r := range s.getServiceResourcesByType() { result.Observe(int64(len(r)), metric.WithAttributes(meter.TypeURLAttrKey.String(k))) diff --git a/test/integration_test.go b/test/integration_test.go index 33cee65..f6597ec 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -1,15 +1,22 @@ package test_test import ( + "context" "fmt" "net" "testing" + "time" + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + discoveryv3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/wongnai/xds/internal/di" + "github.com/wongnai/xds/snapshot" "github.com/wongnai/xds/snapshot/apigateway" "github.com/wongnai/xds/test" "google.golang.org/grpc" @@ -247,6 +254,199 @@ func (s *XdsIntegrationTestSuite) TestApiGateway() { require.NoError(s.T(), err) } +// localityEndpoint describes one backend to register behind a Service, along +// with which Node it runs on. +type localityEndpoint struct { + IP string + NodeName string +} + +func (s *XdsIntegrationTestSuite) createKubeNode(name, zone, subZone string) { + node := &corev1.Node{ + TypeMeta: metav1.TypeMeta{APIVersion: "v1", Kind: "Node"}, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: map[string]string{ + snapshot.LabelZone: zone, + snapshot.DefaultSubZoneLabel: subZone, + }, + }, + } + err := s.kube.Tracker().Add(node) + s.Require().NoError(err) +} + +func (s *XdsIntegrationTestSuite) createLocalityService(name, namespace string, port int32, mode string) { + svc := &test.K8SService{ + Name: name, + Namespace: namespace, + Annotations: map[string]string{ + snapshot.AnnotationLocalityPreference: mode, + }, + Ports: []corev1.ServicePort{{ + Name: "grpc", + Port: port, + Protocol: corev1.ProtocolTCP, + }}, + } + err := s.kube.Tracker().Add(svc.AsK8S()) + s.Require().NoError(err) +} + +func (s *XdsIntegrationTestSuite) createKubeEndpointWithNodes(name, namespace string, addrs []localityEndpoint, port int32) { + addresses := make([]corev1.EndpointAddress, 0, len(addrs)) //nolint:staticcheck // legacy Endpoints API + for _, a := range addrs { + nodeName := a.NodeName + addresses = append(addresses, corev1.EndpointAddress{ //nolint:staticcheck // See above + IP: a.IP, + NodeName: &nodeName, + }) + } + endpoint := &corev1.Endpoints{ //nolint:staticcheck // See above + TypeMeta: metav1.TypeMeta{APIVersion: "v1", Kind: "Endpoints"}, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Subsets: []corev1.EndpointSubset{{ //nolint:staticcheck // See above + Addresses: addresses, + Ports: []corev1.EndpointPort{{Name: "grpc", Port: port}}, //nolint:staticcheck // See above + }}, + } + err := s.kube.Tracker().Add(endpoint) + s.Require().NoError(err) +} + +// fetchEDS opens a one-shot ADS stream with the given client 
locality and +// returns the first matching ClusterLoadAssignment. The stream is +// re-established each call so snapshot-version state is fresh. Uses Eventually +// because the K8s reflector is asynchronous. +func (s *XdsIntegrationTestSuite) fetchEDS(clientZone, clientSubZone, resourceName string, expectedGroups int) *endpointv3.ClusterLoadAssignment { + s.T().Helper() + var cla *endpointv3.ClusterLoadAssignment + s.Require().Eventually(func() bool { + conn, err := grpc.NewClient(s.listener.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return false + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(s.T().Context(), 2*time.Second) + defer cancel() + + stream, err := discoveryv3.NewAggregatedDiscoveryServiceClient(conn).StreamAggregatedResources(ctx) + if err != nil { + return false + } + err = stream.Send(&discoveryv3.DiscoveryRequest{ + Node: &corev3.Node{ + Id: "test-" + clientZone + "-" + clientSubZone, + Locality: &corev3.Locality{Zone: clientZone, SubZone: clientSubZone}, + }, + ResourceNames: []string{resourceName}, + TypeUrl: resource.EndpointType, + }) + if err != nil { + return false + } + resp, err := stream.Recv() + if err != nil || len(resp.Resources) == 0 { + return false + } + candidate := &endpointv3.ClusterLoadAssignment{} + if err := resp.Resources[0].UnmarshalTo(candidate); err != nil { + return false + } + if len(candidate.Endpoints) != expectedGroups { + return false + } + cla = candidate + return true + }, 5*time.Second, 100*time.Millisecond, "timed out waiting for CLA with %d locality groups", expectedGroups) + return cla +} + +func (s *XdsIntegrationTestSuite) TestLocalityZonePriorities() { + s.createKubeNode("node-a", "zone-a", "") + s.createKubeNode("node-b", "zone-b", "") + s.createKubeNode("node-c", "zone-c", "") + + s.createLocalityService("zsvc", "default", 50100, "zone") + s.createKubeEndpointWithNodes("zsvc", "default", []localityEndpoint{ + {IP: "10.200.0.1", NodeName: "node-a"}, + {IP: "10.200.0.2", NodeName: "node-b"}, + {IP: "10.200.0.3", NodeName: "node-c"}, + }, 50100) + + // Client in zone-b: zone-b should be priority 0, others contiguous at 1. + cla := s.fetchEDS("zone-b", "", "zsvc.default:grpc", 3) + priByZone := map[string]uint32{} + for _, g := range cla.Endpoints { + priByZone[g.GetLocality().GetZone()] = g.GetPriority() + } + s.Equal(uint32(0), priByZone["zone-b"], "zone-b should be preferred for zone-b client") + s.Equal(uint32(1), priByZone["zone-a"]) + s.Equal(uint32(1), priByZone["zone-c"]) + + // Different client locality flips the priorities. 
+ cla = s.fetchEDS("zone-a", "", "zsvc.default:grpc", 3) + priByZone = map[string]uint32{} + for _, g := range cla.Endpoints { + priByZone[g.GetLocality().GetZone()] = g.GetPriority() + } + s.Equal(uint32(0), priByZone["zone-a"], "zone-a should be preferred for zone-a client") + s.Equal(uint32(1), priByZone["zone-b"]) + s.Equal(uint32(1), priByZone["zone-c"]) +} + +func (s *XdsIntegrationTestSuite) TestLocalitySubZonePriorities() { + s.createKubeNode("n1", "zone-a", "rack-1") + s.createKubeNode("n2", "zone-a", "rack-2") + s.createKubeNode("n3", "zone-b", "rack-1") + + s.createLocalityService("ssvc", "default", 50101, "sub_zone") + s.createKubeEndpointWithNodes("ssvc", "default", []localityEndpoint{ + {IP: "10.201.0.1", NodeName: "n1"}, + {IP: "10.201.0.2", NodeName: "n2"}, + {IP: "10.201.0.3", NodeName: "n3"}, + }, 50101) + + // Client at (zone-a, rack-1): + // exact match → priority 0 + // same-zone only → priority 1 + // different zone → priority 2 + cla := s.fetchEDS("zone-a", "rack-1", "ssvc.default:grpc", 3) + priByLoc := map[string]uint32{} + for _, g := range cla.Endpoints { + loc := g.GetLocality() + priByLoc[loc.GetZone()+"/"+loc.GetSubZone()] = g.GetPriority() + } + s.Equal(uint32(0), priByLoc["zone-a/rack-1"]) + s.Equal(uint32(1), priByLoc["zone-a/rack-2"]) + s.Equal(uint32(2), priByLoc["zone-b/rack-1"]) +} + +func (s *XdsIntegrationTestSuite) TestLocalityNoAnnotationNoSplit() { + // Without the annotation, all endpoints should land in a single + // empty-locality group regardless of node labels — behavior matches the + // pre-locality default. + s.createKubeNode("plain-a", "zone-a", "") + s.createKubeNode("plain-b", "zone-b", "") + + s.createKubeService("plainsvc", "default", 50102) + s.createKubeEndpointWithNodes("plainsvc", "default", []localityEndpoint{ + {IP: "10.202.0.1", NodeName: "plain-a"}, + {IP: "10.202.0.2", NodeName: "plain-b"}, + }, 50102) + + cla := s.fetchEDS("zone-a", "", "plainsvc.default:grpc", 1) + s.Require().Len(cla.Endpoints, 1) + s.Equal(uint32(0), cla.Endpoints[0].GetPriority()) + s.Empty(cla.Endpoints[0].GetLocality().GetZone()) + s.Len(cla.Endpoints[0].GetLbEndpoints(), 2) +} + func TestXdsIntegration(t *testing.T) { // The fake clientset doesn't implement the WatchList bookmark protocol, // so disable WatchListClient to use the traditional List+Watch flow. @@ -254,7 +454,7 @@ func TestXdsIntegration(t *testing.T) { kube := fake.NewClientset() - testServer, stop, err := di.InitializeTestServer(t.Context(), kube, 1) + testServer, stop, err := di.InitializeTestServer(t.Context(), kube, 1, "") require.NoError(t, err) defer stop() diff --git a/test/xds_test.go b/test/xds_test.go index 2cf5467..337eeca 100644 --- a/test/xds_test.go +++ b/test/xds_test.go @@ -30,7 +30,7 @@ func (s *XdsSuite) SetupTest() { var err error s.kube = fake.NewClientset() s.conn = bufconn.Listen(1) - s.TestServer, s.stop, err = di.InitializeTestServer(s.T().Context(), s.kube, 1) + s.TestServer, s.stop, err = di.InitializeTestServer(s.T().Context(), s.kube, 1, "") s.Require().NoError(err) go s.TestServer.GrpcServer.Serve(s.conn)