diff --git a/internal/di/k8sxds.go b/internal/di/k8sxds.go index 3241000..e7f6eb5 100644 --- a/internal/di/k8sxds.go +++ b/internal/di/k8sxds.go @@ -22,9 +22,9 @@ var K8sXdsSet = wire.NewSet( ProvideLRSServer, ) -func ProvideSnapshotter(ctx context.Context, k8sClient kubernetes.Interface) (*snapshot.Snapshotter, func()) { +func ProvideSnapshotter(ctx context.Context, k8sClient kubernetes.Interface, subZoneLabel snapshot.SubZoneLabel) (*snapshot.Snapshotter, func()) { stopCtx, stop := context.WithCancel(ctx) - snapshotter := snapshot.New(k8sClient) + snapshotter := snapshot.New(k8sClient, subZoneLabel) go func() { err := snapshotter.Start(stopCtx) diff --git a/internal/di/wire.go b/internal/di/wire.go index 257b8ab..3fa1ec9 100644 --- a/internal/di/wire.go +++ b/internal/di/wire.go @@ -6,6 +6,7 @@ import ( "context" "github.com/google/wire" "github.com/wongnai/xds/debug" + "github.com/wongnai/xds/snapshot" "google.golang.org/grpc" "k8s.io/client-go/kubernetes" ) @@ -25,7 +26,7 @@ type DevServer struct { GrpcServer *grpc.Server } -func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds) (Servers, func(), error) { +func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (Servers, func(), error) { wire.Build( KubernetesSet, GrpcSet, @@ -38,7 +39,7 @@ func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSec return Servers{}, nil, nil } -func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds) (TestServer, func(), error) { +func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (TestServer, func(), error) { wire.Build( GrpcSet, K8sXdsSet, diff --git a/internal/di/wire_gen.go b/internal/di/wire_gen.go index 7ef58be..2700167 100644 --- a/internal/di/wire_gen.go +++ b/internal/di/wire_gen.go @@ -9,13 +9,14 @@ package di import ( "context" "github.com/wongnai/xds/debug" + "github.com/wongnai/xds/snapshot" "google.golang.org/grpc" "k8s.io/client-go/kubernetes" ) // Injectors from wire.go: -func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds) (Servers, func(), error) { +func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (Servers, func(), error) { v := ProvideOtelGrpcServerOptions() server, cleanup := ProvideGrpcServer(v) config, err := ProvideClientConfig() @@ -34,7 +35,7 @@ func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSec cleanup() return Servers{}, nil, err } - snapshotter, cleanup2 := ProvideSnapshotter(ctx, kubernetesInterface) + snapshotter, cleanup2 := ProvideSnapshotter(ctx, kubernetesInterface, subZoneLabel) callbackFuncs := ProvideXdsLogger() serverServer, cleanup3 := ProvideXdsServer(ctx, snapshotter, callbackFuncs) sideEffectADSRegistered := ProvideSideEffectADSRegistered(server, serverServer) @@ -76,10 +77,10 @@ func InitializeServer(ctx context.Context, statsIntervalSeconds StatsIntervalSec }, nil } -func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds) (TestServer, func(), error) { +func InitializeTestServer(ctx context.Context, kubeClient kubernetes.Interface, statsIntervalSeconds StatsIntervalSeconds, subZoneLabel snapshot.SubZoneLabel) (TestServer, func(), error) { v := ProvideGrpcTestOption() 
server, cleanup := ProvideGrpcServer(v) - snapshotter, cleanup2 := ProvideSnapshotter(ctx, kubeClient) + snapshotter, cleanup2 := ProvideSnapshotter(ctx, kubeClient, subZoneLabel) callbackFuncs := ProvideXdsLogger() serverServer, cleanup3 := ProvideXdsServer(ctx, snapshotter, callbackFuncs) sideEffectADSRegistered := ProvideSideEffectADSRegistered(server, serverServer) diff --git a/main.go b/main.go index 823e9b6..ca3ac1e 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "github.com/wongnai/xds/internal/di" "github.com/wongnai/xds/meter" + "github.com/wongnai/xds/snapshot" _ "k8s.io/client-go/plugin/pkg/client/auth/oidc" "k8s.io/klog/v2" ) @@ -19,13 +20,15 @@ func main() { var statsIntervalInSeconds int64 flag.CommandLine.Int64Var(&statsIntervalInSeconds, "statsinterval", 300, "stats update interval in seconds") + subZoneLabel := flag.String("sub-zone-label", snapshot.DefaultSubZoneLabel, + "Kubernetes node label read as sub-zone when a Service requests sub_zone locality") flag.Parse() ctx := context.Background() meter.InstallPromExporter() - servers, stop, err := di.InitializeServer(context.Background(), statsIntervalInSeconds) + servers, stop, err := di.InitializeServer(context.Background(), statsIntervalInSeconds, snapshot.SubZoneLabel(*subZoneLabel)) if err != nil { klog.Fatal(err) } diff --git a/snapshot/endpoints.go b/snapshot/endpoints.go index c26638b..dd31203 100644 --- a/snapshot/endpoints.go +++ b/snapshot/endpoints.go @@ -3,13 +3,14 @@ package snapshot import ( "context" "fmt" + "maps" + "slices" "sort" "github.com/ccoveille/go-safecast/v2" corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" "github.com/envoyproxy/go-control-plane/pkg/cache/types" - "github.com/envoyproxy/go-control-plane/pkg/cache/v3" "github.com/wongnai/xds/meter" "go.opentelemetry.io/otel/metric" "google.golang.org/protobuf/types/known/wrapperspb" @@ -22,15 +23,15 @@ import ( ) type endpointCacheItem struct { - version string - resources []types.Resource + version string + nodeVersion uint64 + mode LocalityMode + resources []types.Resource } func (s *Snapshotter) startEndpoints(ctx context.Context) error { - emit := func() {} - store := k8scache.NewUndeltaStore(func(v []interface{}) { - emit() + s.triggerEndpointsEmit() }, k8scache.DeletionHandlingMetaNamespaceKeyFunc) reflector := k8scache.NewReflector(&k8scache.ListWatch{ @@ -44,7 +45,7 @@ func (s *Snapshotter) startEndpoints(ctx context.Context) error { var lastSnapshotHash uint64 - emit = func() { + emit := func() { version := reflector.LastSyncResourceVersion() s.kubeEventCounter.Add(ctx, 1, metric.WithAttributes(meter.ResourceAttrKey.String("endpoints"))) @@ -53,7 +54,7 @@ func (s *Snapshotter) startEndpoints(ctx context.Context) error { hash, err := resourcesHash(endpointsResources) if err == nil { if hash == lastSnapshotHash { - klog.V(5).Info("new snapshot is equivalent to the previous one") + klog.V(5).Info("new endpoints snapshot is equivalent to the previous one") return } lastSnapshotHash = hash @@ -64,13 +65,11 @@ func (s *Snapshotter) startEndpoints(ctx context.Context) error { resourcesByType := resourcesToMap(endpointsResources) s.setEndpointResourcesByType(resourcesByType) - snapshot, err := cache.NewSnapshot(version, resourcesByType) - if err != nil { - panic(err) - } - - s.endpointsCache.SetSnapshot(ctx, "", snapshot) + s.endpointsCache.setResources(ctx, version, resourcesByType) } + s.emitMu.Lock() + s.emitEndpointsFn = emit + 
s.emitMu.Unlock() reflector.Run(ctx.Done()) return nil @@ -101,7 +100,14 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso klog.Errorf("fail to get object key: %s", err) return nil } - if val, ok := s.endpointResourceCache[name]; ok && val.version == ep.ResourceVersion { + + mode := s.localityModeFor(ep.Namespace, ep.Name) + nodeVersion := s.nodeLocality.getVersion() + + if val, ok := s.endpointResourceCache[name]; ok && + val.version == ep.ResourceVersion && + val.nodeVersion == nodeVersion && + val.mode == mode { return val.resources } @@ -109,32 +115,26 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso for _, subset := range ep.Subsets { for _, port := range subset.Ports { - var portName string + var clusterName string if port.Name == "" { - portName = fmt.Sprintf("%s.%s:%d", ep.Name, ep.Namespace, port.Port) + clusterName = fmt.Sprintf("%s.%s:%d", ep.Name, ep.Namespace, port.Port) } else { - portName = fmt.Sprintf("%s.%s:%s", ep.Name, ep.Namespace, port.Name) + clusterName = fmt.Sprintf("%s.%s:%s", ep.Name, ep.Namespace, port.Name) } cla := &endpointv3.ClusterLoadAssignment{ - ClusterName: portName, - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LoadBalancingWeight: wrapperspb.UInt32(1), - Locality: &corev3.Locality{}, - LbEndpoints: []*endpointv3.LbEndpoint{}, - }, - }, + ClusterName: clusterName, } out = append(out, cla) sortedAddresses := subset.Addresses sort.SliceStable(sortedAddresses, func(i, j int) bool { - l := sortedAddresses[i].IP - r := sortedAddresses[j].IP - return l < r + return sortedAddresses[i].IP < sortedAddresses[j].IP }) + portU32 := safecast.MustConvert[uint32](port.Port) + groups := map[string]*endpointv3.LocalityLbEndpoints{} + for _, addr := range sortedAddresses { hostname := addr.Hostname if hostname == "" && addr.TargetRef != nil { @@ -143,9 +143,18 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso if hostname == "" && addr.NodeName != nil { hostname = *addr.NodeName } - portU32 := safecast.MustConvert[uint32](port.Port) - cla.Endpoints[0].LbEndpoints = append(cla.Endpoints[0].LbEndpoints, &endpointv3.LbEndpoint{ + loc := s.localityForAddress(addr, mode) + key := loc.GetZone() + localityKeySep + loc.GetSubZone() + g, ok := groups[key] + if !ok { + g = &endpointv3.LocalityLbEndpoints{ + Locality: loc, + LoadBalancingWeight: wrapperspb.UInt32(1), + } + groups[key] = g + } + g.LbEndpoints = append(g.LbEndpoints, &endpointv3.LbEndpoint{ HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ Endpoint: &endpointv3.Endpoint{ Address: &corev3.Address{ @@ -164,13 +173,37 @@ func (s *Snapshotter) kubeEndpointToResources(ep *corev1.Endpoints) []types.Reso }, }) } + + // Emit in sorted locality-key order so hash is stable. + for _, k := range slices.Sorted(maps.Keys(groups)) { + cla.Endpoints = append(cla.Endpoints, groups[k]) + } } } s.endpointResourceCache[name] = endpointCacheItem{ - version: ep.ResourceVersion, - resources: out, + version: ep.ResourceVersion, + nodeVersion: nodeVersion, + mode: mode, + resources: out, } return out } + +// localityForAddress builds the Locality for a single endpoint +// address according to the service's locality mode. 
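+//
+// For example (illustrative label values): an address whose node carries
+// topology.kubernetes.io/zone=us-east-1a and topology.kubernetes.io/rack=r42
+// yields Locality{Zone: "us-east-1a"} in zone mode and
+// Locality{Zone: "us-east-1a", SubZone: "r42"} in sub_zone mode; an address
+// with no NodeName always gets an empty Locality.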
+func (s *Snapshotter) localityForAddress(addr corev1.EndpointAddress, mode LocalityMode) *corev3.Locality { + if mode == LocalityNone || addr.NodeName == nil { + return &corev3.Locality{} + } + info := s.nodeLocality.get(*addr.NodeName) + switch mode { + case LocalityZone: + return &corev3.Locality{Zone: info.zone} + case LocalitySubZone: + return &corev3.Locality{Zone: info.zone, SubZone: info.subZone} + default: + return &corev3.Locality{} + } +} diff --git a/snapshot/endpointscache.go b/snapshot/endpointscache.go new file mode 100644 index 0000000..4690d5c --- /dev/null +++ b/snapshot/endpointscache.go @@ -0,0 +1,213 @@ +package snapshot + +import ( + "context" + "slices" + "strings" + "sync" + + "github.com/envoyproxy/go-control-plane/pkg/cache/types" + "github.com/envoyproxy/go-control-plane/pkg/cache/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "google.golang.org/protobuf/proto" + "k8s.io/klog/v2" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" +) + +const localityKeySep = "|" + +// localityNodeHash partitions xDS clients by (zone, sub_zone). +type localityNodeHash struct{} + +func (localityNodeHash) ID(node *corev3.Node) string { + zone, subZone := nodeLocalityFromXds(node) + return zone + localityKeySep + subZone +} + +func nodeLocalityFromXds(node *corev3.Node) (zone, subZone string) { + if node != nil && node.GetLocality() != nil { + zone = node.GetLocality().GetZone() + subZone = node.GetLocality().GetSubZone() + } + return +} + +func splitLocalityKey(k string) (zone, subZone string) { + if idx := strings.Index(k, localityKeySep); idx >= 0 { + return k[:idx], k[idx+1:] + } + return k, "" +} + +// endpointsCache wraps a SnapshotCache to produce per-client-locality +// endpoint responses. It stores the current endpoint resources and, when +// a client issues a watch from a previously-unseen locality, synthesizes +// a snapshot for that locality on demand. +type endpointsCache struct { + inner cache.SnapshotCache + + mu sync.RWMutex + version string + resourcesByType map[string][]types.Resource +} + +func newEndpointsCache() *endpointsCache { + return &endpointsCache{ + inner: cache.NewSnapshotCache(false, localityNodeHash{}, Logger), + } +} + +// setResources replaces the cached endpoint resources and pushes a new +// snapshot for every locality that already has an active watch, plus the +// default (no-locality) key. 
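+//
+// For example, with active watches keyed "us-east-1a|r1" and "us-east-1b|"
+// (illustrative localities), a refresh writes three snapshots: one per
+// watched key and one for the default "" key.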
+func (c *endpointsCache) setResources(ctx context.Context, version string, resourcesByType map[string][]types.Resource) { + c.mu.Lock() + c.version = version + c.resourcesByType = resourcesByType + c.mu.Unlock() + + keys := c.inner.GetStatusKeys() + if !slices.Contains(keys, "") { + keys = append(keys, "") + } + for _, k := range keys { + c.setSnapshotForKey(ctx, k, version, resourcesByType) + } +} + +func (c *endpointsCache) setSnapshotForKey(ctx context.Context, key, version string, resourcesByType map[string][]types.Resource) { + zone, subZone := splitLocalityKey(key) + filtered := resourcesForLocality(zone, subZone, resourcesByType) + snap, err := cache.NewSnapshot(version, filtered) + if err != nil { + klog.Errorf("fail to create endpoint snapshot for locality %q: %s", key, err) + return + } + if err := c.inner.SetSnapshot(ctx, key, snap); err != nil { + klog.Errorf("fail to set endpoint snapshot for locality %q: %s", key, err) + } +} + +// ensureSnapshotForNode makes sure a snapshot exists for the node's locality +// before the inner cache's CreateWatch runs. +func (c *endpointsCache) ensureSnapshotForNode(ctx context.Context, node *corev3.Node) { + key := localityNodeHash{}.ID(node) + if _, err := c.inner.GetSnapshot(key); err == nil { + return + } + c.mu.RLock() + version := c.version + resourcesByType := c.resourcesByType + c.mu.RUnlock() + if resourcesByType == nil { + return + } + c.setSnapshotForKey(ctx, key, version, resourcesByType) +} + +// resourcesForLocality returns a copy of resourcesByType where each +// ClusterLoadAssignment's LocalityLbEndpoints are re-prioritized relative +// to the client at (clientZone, clientSubZone). Non-endpoint types pass +// through untouched. Priorities are compacted so they always start at 0 +// and are contiguous (gRPC's priority LB requires that). +func resourcesForLocality(clientZone, clientSubZone string, resourcesByType map[string][]types.Resource) map[string][]types.Resource { + eps, ok := resourcesByType[resource.EndpointType] + if !ok { + return resourcesByType + } + out := make(map[string][]types.Resource, len(resourcesByType)) + for k, v := range resourcesByType { + if k != resource.EndpointType { + out[k] = v + } + } + newEps := make([]types.Resource, 0, len(eps)) + for _, r := range eps { + cla, ok := r.(*endpointv3.ClusterLoadAssignment) + if !ok { + newEps = append(newEps, r) + continue + } + newEps = append(newEps, assignmentWithPriorities(cla, clientZone, clientSubZone)) + } + out[resource.EndpointType] = newEps + return out +} + +// assignmentWithPriorities clones the input and assigns a priority to each +// LocalityLbEndpoints group based on how closely it matches the client's +// locality: +// +// score 2: group.zone == client.zone && group.sub_zone == client.sub_zone +// score 1: group.zone == client.zone +// score 0: no match (or client has no locality info) +// +// Higher score -> lower (more preferred) priority. Empty buckets are +// skipped so priorities stay contiguous. 
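+//
+// Worked example: for a client at (z1, r1), groups at (z1, r1), (z1, r2)
+// and (z2, r1) score 2, 1 and 0, and get priorities 0, 1 and 2. Without
+// the exact match, the remaining scores 1 and 0 compact to priorities
+// 0 and 1.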
+func assignmentWithPriorities(in *endpointv3.ClusterLoadAssignment, clientZone, clientSubZone string) *endpointv3.ClusterLoadAssignment {
+	out := proto.Clone(in).(*endpointv3.ClusterLoadAssignment)
+	if len(out.Endpoints) == 0 {
+		return out
+	}
+	// score -> groups at that score
+	byScore := map[int][]*endpointv3.LocalityLbEndpoints{}
+	scores := make([]int, 0, 3)
+	for _, g := range out.Endpoints {
+		s := matchScore(g.GetLocality(), clientZone, clientSubZone)
+		if _, seen := byScore[s]; !seen {
+			scores = append(scores, s)
+		}
+		byScore[s] = append(byScore[s], g)
+	}
+	// Sort scores descending: a higher score is a closer locality match.
+	slices.Sort(scores)
+	slices.Reverse(scores)
+
+	// xDS priorities run the opposite way: they start at 0 and go up.
+	// Clients fill backends at priority 0 first and only spill over to
+	// higher priority values when those are unavailable.
+	priority := uint32(0)
+	regrouped := make([]*endpointv3.LocalityLbEndpoints, 0, len(out.Endpoints))
+	for _, s := range scores {
+		for _, g := range byScore[s] {
+			g.Priority = priority
+			regrouped = append(regrouped, g)
+		}
+		priority++
+	}
+	out.Endpoints = regrouped
+	return out
+}
+
+func matchScore(loc *corev3.Locality, clientZone, clientSubZone string) int {
+	if loc == nil || clientZone == "" {
+		return 0
+	}
+	if loc.GetZone() == "" || loc.GetZone() != clientZone {
+		return 0
+	}
+	if loc.GetSubZone() != "" && clientSubZone != "" && loc.GetSubZone() == clientSubZone {
+		return 2
+	}
+	return 1
+}
+
+// CreateWatch implements cache.ConfigWatcher.
+func (c *endpointsCache) CreateWatch(request *cache.Request, sub cache.Subscription, value chan cache.Response) (func(), error) {
+	c.ensureSnapshotForNode(context.Background(), request.GetNode())
+	return c.inner.CreateWatch(request, sub, value)
+}
+
+// CreateDeltaWatch implements cache.ConfigWatcher.
+func (c *endpointsCache) CreateDeltaWatch(request *cache.DeltaRequest, sub cache.Subscription, value chan cache.DeltaResponse) (func(), error) {
+	c.ensureSnapshotForNode(context.Background(), request.GetNode())
+	return c.inner.CreateDeltaWatch(request, sub, value)
+}
+
+// Fetch implements cache.ConfigFetcher.
+func (c *endpointsCache) Fetch(ctx context.Context, request *cache.Request) (cache.Response, error) {
+	c.ensureSnapshotForNode(ctx, request.GetNode())
+	return c.inner.Fetch(ctx, request)
+}
diff --git a/snapshot/endpointscache_test.go b/snapshot/endpointscache_test.go
new file mode 100644
index 0000000..9891c13
--- /dev/null
+++ b/snapshot/endpointscache_test.go
@@ -0,0 +1,150 @@
+package snapshot
+
+import (
+	"testing"
+
+	corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+	endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
+	"github.com/envoyproxy/go-control-plane/pkg/cache/types"
+	"github.com/envoyproxy/go-control-plane/pkg/resource/v3"
+)
+
+func TestClaWithPrioritiesZone(t *testing.T) {
+	cla := &endpointv3.ClusterLoadAssignment{
+		ClusterName: "svc",
+		Endpoints: []*endpointv3.LocalityLbEndpoints{
+			{Locality: &corev3.Locality{Zone: "zone-a"}},
+			{Locality: &corev3.Locality{Zone: "zone-b"}},
+			{Locality: &corev3.Locality{Zone: "zone-c"}},
+		},
+	}
+
+	out := assignmentWithPriorities(cla, "zone-b", "")
+
+	if got := len(out.Endpoints); got != 3 {
+		t.Fatalf("expected 3 groups, got %d", got)
+	}
+	// zone-b should be priority 0; everyone else priority 1, and must be
+	// contiguous.
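+	// (zone-b scores 1, the others 0; two buckets compact to priorities
+	// 0 and 1.)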
+ seenPri := map[uint32]string{} + for _, g := range out.Endpoints { + seenPri[g.Priority] = g.GetLocality().GetZone() + } + if seenPri[0] != "zone-b" { + t.Errorf("priority 0 = %q, want zone-b", seenPri[0]) + } + foundP1 := 0 + for _, g := range out.Endpoints { + if g.Priority == 1 { + foundP1++ + } + } + if foundP1 != 2 { + t.Errorf("priority 1 count = %d, want 2", foundP1) + } + // Must be contiguous. No priority 2. + for _, g := range out.Endpoints { + if g.Priority > 1 { + t.Errorf("unexpected priority %d (priorities must be contiguous)", g.Priority) + } + } +} + +func TestClaWithPrioritiesSubZone(t *testing.T) { + cla := &endpointv3.ClusterLoadAssignment{ + ClusterName: "svc", + Endpoints: []*endpointv3.LocalityLbEndpoints{ + {Locality: &corev3.Locality{Zone: "z1", SubZone: "r1"}}, // exact + {Locality: &corev3.Locality{Zone: "z1", SubZone: "r2"}}, // same zone + {Locality: &corev3.Locality{Zone: "z2", SubZone: "r1"}}, // diff zone + }, + } + + out := assignmentWithPriorities(cla, "z1", "r1") + + pri := map[string]uint32{} + for _, g := range out.Endpoints { + loc := g.GetLocality() + pri[loc.Zone+"/"+loc.SubZone] = g.Priority + } + if pri["z1/r1"] != 0 { + t.Errorf("exact match priority = %d, want 0", pri["z1/r1"]) + } + if pri["z1/r2"] != 1 { + t.Errorf("same-zone priority = %d, want 1", pri["z1/r2"]) + } + if pri["z2/r1"] != 2 { + t.Errorf("diff-zone priority = %d, want 2", pri["z2/r1"]) + } +} + +func TestClaWithPrioritiesCompaction(t *testing.T) { + // Only same-zone and diff-zone exist (no exact sub_zone match). The + // output must start at priority 0 and stay contiguous. + cla := &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + {Locality: &corev3.Locality{Zone: "z1", SubZone: "r2"}}, + {Locality: &corev3.Locality{Zone: "z2", SubZone: "r1"}}, + }, + } + + out := assignmentWithPriorities(cla, "z1", "r1") + + maxPri := uint32(0) + sawZero := false + for _, g := range out.Endpoints { + if g.Priority == 0 { + sawZero = true + } + if g.Priority > maxPri { + maxPri = g.Priority + } + } + if !sawZero { + t.Error("missing priority 0") + } + if maxPri != 1 { + t.Errorf("max priority = %d, want 1 (contiguous)", maxPri) + } +} + +func TestClaWithPrioritiesNoClientLocality(t *testing.T) { + // Client has no locality info. All groups get the same priority (0) + // so no preference is expressed. 
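+	// (matchScore returns 0 for every group once clientZone is empty, so
+	// there is a single bucket.)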
+ cla := &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + {Locality: &corev3.Locality{Zone: "z1"}}, + {Locality: &corev3.Locality{Zone: "z2"}}, + }, + } + + out := assignmentWithPriorities(cla, "", "") + + for _, g := range out.Endpoints { + if g.Priority != 0 { + t.Errorf("got priority %d with empty client locality; all should be 0", g.Priority) + } + } +} + +func TestResourcesForLocalityPassesThroughNonEndpoints(t *testing.T) { + in := map[string][]types.Resource{ + resource.ClusterType: {&endpointv3.ClusterLoadAssignment{ClusterName: "x"}}, // wrong type on purpose + } + out := resourcesForLocality("z", "", in) + if _, ok := out[resource.EndpointType]; ok { + t.Error("unexpected endpoint resources added") + } + if len(out[resource.ClusterType]) != 1 { + t.Error("non-endpoint resources should pass through unchanged") + } +} + +func TestLocalityNodeHashIncludesSubZone(t *testing.T) { + h := localityNodeHash{} + a := h.ID(&corev3.Node{Locality: &corev3.Locality{Zone: "z1", SubZone: "r1"}}) + b := h.ID(&corev3.Node{Locality: &corev3.Locality{Zone: "z1", SubZone: "r2"}}) + if a == b { + t.Errorf("expected distinct hashes for different sub_zone; both are %q", a) + } +} diff --git a/snapshot/locality.go b/snapshot/locality.go new file mode 100644 index 0000000..8f85299 --- /dev/null +++ b/snapshot/locality.go @@ -0,0 +1,46 @@ +package snapshot + +import ( + corev1 "k8s.io/api/core/v1" +) + +const ( + // AnnotationLocalityPreference on a Service selects how endpoints are + // split into localities. Values: + // "zone" - split by topology.kubernetes.io/zone + // "sub_zone" - split by zone + the sub-zone label + // "" - no split; a single empty-locality group (default) + AnnotationLocalityPreference = "xds.lmwn.com/locality-preference" + + // LabelZone is the well-known Kubernetes node zone label. + LabelZone = "topology.kubernetes.io/zone" + + // DefaultSubZoneLabel is the default node label we read as sub-zone when + // no override is provided. There is no Kubernetes standard for sub-zone, + // so we default to the common "rack" convention that mirrors the + // topology.kubernetes.io/zone shape. + DefaultSubZoneLabel = "topology.kubernetes.io/rack" +) + +// LocalityMode describes the locality split for a single service. +type LocalityMode int + +const ( + LocalityNone LocalityMode = iota + LocalityZone + LocalitySubZone +) + +func localityModeFromService(svc *corev1.Service) LocalityMode { + if svc == nil { + return LocalityNone + } + switch svc.GetAnnotations()[AnnotationLocalityPreference] { + case "zone": + return LocalityZone + case "sub_zone": + return LocalitySubZone + default: + return LocalityNone + } +} diff --git a/snapshot/nodes.go b/snapshot/nodes.go new file mode 100644 index 0000000..54e6f03 --- /dev/null +++ b/snapshot/nodes.go @@ -0,0 +1,112 @@ +package snapshot + +import ( + "context" + "sync" + + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/klog/v2" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8scache "k8s.io/client-go/tools/cache" +) + +type nodeLocality struct { + zone string + subZone string +} + +// nodeLocalityStore holds a snapshot of node to {zone, sub-zone} mapping. +// The endpoints builder consults this store when assigning locality to each +// LbEndpoint. +// Version is bumped on every observed change so callers can invalidate +// per-endpoint caches. 
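+//
+// kubeEndpointToResources records the version it read next to each cached
+// resource set, so a node relabel (which bumps the version) rebuilds the
+// entry even though the Endpoints object's ResourceVersion is unchanged.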
+type nodeLocalityStore struct { + subZoneLabel string + + mu sync.RWMutex + nodes map[string]nodeLocality + version uint64 +} + +func newNodeLocalityStore(subZoneLabel string) *nodeLocalityStore { + if subZoneLabel == "" { + subZoneLabel = DefaultSubZoneLabel + } + return &nodeLocalityStore{ + subZoneLabel: subZoneLabel, + nodes: map[string]nodeLocality{}, + } +} + +func (s *nodeLocalityStore) get(name string) nodeLocality { + s.mu.RLock() + defer s.mu.RUnlock() + return s.nodes[name] +} + +func (s *nodeLocalityStore) getVersion() uint64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.version +} + +// apply replaces the store with the nodes. Returns true if anything changed. +func (s *nodeLocalityStore) apply(nodes []*corev1.Node) bool { + next := make(map[string]nodeLocality, len(nodes)) + for _, n := range nodes { + next[n.Name] = nodeLocality{ + zone: n.Labels[LabelZone], + subZone: n.Labels[s.subZoneLabel], + } + } + s.mu.Lock() + defer s.mu.Unlock() + if nodeMapsEqual(s.nodes, next) { + return false + } + s.nodes = next + s.version++ + return true +} + +func nodeMapsEqual(a, b map[string]nodeLocality) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if b[k] != v { + return false + } + } + return true +} + +func (s *Snapshotter) startNodes(ctx context.Context) error { + store := k8scache.NewUndeltaStore(func(v []interface{}) { + nodes := make([]*corev1.Node, 0, len(v)) + for _, obj := range v { + if n, ok := obj.(*corev1.Node); ok { + nodes = append(nodes, n) + } + } + if s.nodeLocality.apply(nodes) { + klog.V(2).Infof("node locality updated to version %d (%d nodes)", s.nodeLocality.getVersion(), len(nodes)) + s.triggerEndpointsEmit() + } + }, k8scache.DeletionHandlingMetaNamespaceKeyFunc) + + reflector := k8scache.NewReflector(&k8scache.ListWatch{ + ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) { + return s.client.CoreV1().Nodes().List(ctx, options) + }, + WatchFuncWithContext: func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) { + return s.client.CoreV1().Nodes().Watch(ctx, options) + }, + }, &corev1.Node{}, store, s.ResyncPeriod) + + reflector.Run(ctx.Done()) + return nil +} diff --git a/snapshot/services.go b/snapshot/services.go index 24b1a84..4561e05 100644 --- a/snapshot/services.go +++ b/snapshot/services.go @@ -35,7 +35,13 @@ func (s *Snapshotter) startServices(ctx context.Context) error { store := k8scache.NewUndeltaStore(func(v []interface{}) { emit() + // A service's locality-preference annotation affects how its + // endpoints are split into LocalityLbEndpoints, so any service + // change requires re-emitting endpoints. The endpoint emit path + // dedups via its own hash. 
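+		// For example, flipping xds.lmwn.com/locality-preference from ""
+		// to "zone" regroups a service's endpoints with no Endpoints event.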
+ s.triggerEndpointsEmit() }, k8scache.DeletionHandlingMetaNamespaceKeyFunc) + s.serviceStore = store reflector := k8scache.NewReflector(&k8scache.ListWatch{ ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) { diff --git a/snapshot/snapshotter.go b/snapshot/snapshotter.go index 9332c51..6aa5ffb 100644 --- a/snapshot/snapshotter.go +++ b/snapshot/snapshotter.go @@ -12,7 +12,9 @@ import ( "github.com/wongnai/xds/meter" "go.opentelemetry.io/otel/metric" "golang.org/x/sync/errgroup" + corev1 "k8s.io/api/core/v1" "k8s.io/client-go/kubernetes" + k8scache "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" ) @@ -47,20 +49,34 @@ type Snapshotter struct { client kubernetes.Interface servicesCache cache.SnapshotCache - endpointsCache cache.SnapshotCache + endpointsCache *endpointsCache muxCache cache.MuxCache - endpointResourceCache map[string]endpointCacheItem + nodeLocality *nodeLocalityStore + + endpointResourceCache map[string]endpointCacheItem + resourcesByTypeLock sync.RWMutex serviceResourcesByType map[string][]types.Resource endpointResourcesByType map[string][]types.Resource apiGatewayStats map[string]int kubeEventCounter metric.Int64Counter + + // serviceStore is captured at informer start so other code paths + // (e.g. per-endpoint locality lookups) can consult Service objects. + serviceStore k8scache.Store + + emitMu sync.Mutex + emitEndpointsFn func() } -func New(client kubernetes.Interface) *Snapshotter { +// SubZoneLabel selects the Kubernetes node label used as sub-zone for the +// sub_zone-preference locality mode. Empty = use DefaultSubZoneLabel. +type SubZoneLabel string + +func New(client kubernetes.Interface, subZoneLabel SubZoneLabel) *Snapshotter { servicesCache := cache.NewSnapshotCache(false, EmptyNodeID{}, Logger) - endpointsCache := cache.NewSnapshotCache(false, EmptyNodeID{}, Logger) + endpointsCache := newEndpointsCache() muxCache := cache.MuxCache{ Classify: func(r *cache.Request) string { return mapTypeURL(r.TypeUrl) @@ -82,6 +98,8 @@ func New(client kubernetes.Interface) *Snapshotter { endpointsCache: endpointsCache, muxCache: muxCache, + nodeLocality: newNodeLocalityStore(string(subZoneLabel)), + endpointResourceCache: map[string]endpointCacheItem{}, } @@ -105,9 +123,41 @@ func (s *Snapshotter) Start(stopCtx context.Context) error { group.Go(func() error { return s.startEndpoints(groupCtx) }) + group.Go(func() error { + return s.startNodes(groupCtx) + }) return group.Wait() } +// triggerEndpointsEmit runs the endpoints emit pipeline under a mutex so +// concurrent triggers from multiple informers don't race on the shared +// endpointResourceCache and the emit closure's snapshot-hash state. +func (s *Snapshotter) triggerEndpointsEmit() { + s.emitMu.Lock() + defer s.emitMu.Unlock() + if s.emitEndpointsFn != nil { + s.emitEndpointsFn() + } +} + +// localityModeFor reads the Service object behind a namespace/name and +// returns its configured locality mode. Returns LocalityNone if the +// service is not yet known. 
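+//
+// For example, a Service annotated xds.lmwn.com/locality-preference=sub_zone
+// yields LocalitySubZone; a missing or unrecognized value yields
+// LocalityNone.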
+func (s *Snapshotter) localityModeFor(namespace, name string) LocalityMode { + if s.serviceStore == nil { + return LocalityNone + } + obj, exists, err := s.serviceStore.GetByKey(namespace + "/" + name) + if err != nil || !exists { + return LocalityNone + } + svc, ok := obj.(*corev1.Service) + if !ok { + return LocalityNone + } + return localityModeFromService(svc) +} + func (s *Snapshotter) snapshotResourceGaugeCallback(_ context.Context, result metric.Int64Observer) error { for k, r := range s.getServiceResourcesByType() { result.Observe(int64(len(r)), metric.WithAttributes(meter.TypeURLAttrKey.String(k))) diff --git a/test/integration_test.go b/test/integration_test.go index 33cee65..f6597ec 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -1,15 +1,22 @@ package test_test import ( + "context" "fmt" "net" "testing" + "time" + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + discoveryv3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/wongnai/xds/internal/di" + "github.com/wongnai/xds/snapshot" "github.com/wongnai/xds/snapshot/apigateway" "github.com/wongnai/xds/test" "google.golang.org/grpc" @@ -247,6 +254,199 @@ func (s *XdsIntegrationTestSuite) TestApiGateway() { require.NoError(s.T(), err) } +// localityEndpoint describes one backend to register behind a Service, along +// with which Node it runs on. +type localityEndpoint struct { + IP string + NodeName string +} + +func (s *XdsIntegrationTestSuite) createKubeNode(name, zone, subZone string) { + node := &corev1.Node{ + TypeMeta: metav1.TypeMeta{APIVersion: "v1", Kind: "Node"}, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: map[string]string{ + snapshot.LabelZone: zone, + snapshot.DefaultSubZoneLabel: subZone, + }, + }, + } + err := s.kube.Tracker().Add(node) + s.Require().NoError(err) +} + +func (s *XdsIntegrationTestSuite) createLocalityService(name, namespace string, port int32, mode string) { + svc := &test.K8SService{ + Name: name, + Namespace: namespace, + Annotations: map[string]string{ + snapshot.AnnotationLocalityPreference: mode, + }, + Ports: []corev1.ServicePort{{ + Name: "grpc", + Port: port, + Protocol: corev1.ProtocolTCP, + }}, + } + err := s.kube.Tracker().Add(svc.AsK8S()) + s.Require().NoError(err) +} + +func (s *XdsIntegrationTestSuite) createKubeEndpointWithNodes(name, namespace string, addrs []localityEndpoint, port int32) { + addresses := make([]corev1.EndpointAddress, 0, len(addrs)) //nolint:staticcheck // legacy Endpoints API + for _, a := range addrs { + nodeName := a.NodeName + addresses = append(addresses, corev1.EndpointAddress{ //nolint:staticcheck // See above + IP: a.IP, + NodeName: &nodeName, + }) + } + endpoint := &corev1.Endpoints{ //nolint:staticcheck // See above + TypeMeta: metav1.TypeMeta{APIVersion: "v1", Kind: "Endpoints"}, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Subsets: []corev1.EndpointSubset{{ //nolint:staticcheck // See above + Addresses: addresses, + Ports: []corev1.EndpointPort{{Name: "grpc", Port: port}}, //nolint:staticcheck // See above + }}, + } + err := s.kube.Tracker().Add(endpoint) + s.Require().NoError(err) +} + +// fetchEDS opens a one-shot ADS stream with the given client 
locality and +// returns the first matching ClusterLoadAssignment. The stream is +// re-established each call so snapshot-version state is fresh. Uses Eventually +// because the K8s reflector is asynchronous. +func (s *XdsIntegrationTestSuite) fetchEDS(clientZone, clientSubZone, resourceName string, expectedGroups int) *endpointv3.ClusterLoadAssignment { + s.T().Helper() + var cla *endpointv3.ClusterLoadAssignment + s.Require().Eventually(func() bool { + conn, err := grpc.NewClient(s.listener.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return false + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(s.T().Context(), 2*time.Second) + defer cancel() + + stream, err := discoveryv3.NewAggregatedDiscoveryServiceClient(conn).StreamAggregatedResources(ctx) + if err != nil { + return false + } + err = stream.Send(&discoveryv3.DiscoveryRequest{ + Node: &corev3.Node{ + Id: "test-" + clientZone + "-" + clientSubZone, + Locality: &corev3.Locality{Zone: clientZone, SubZone: clientSubZone}, + }, + ResourceNames: []string{resourceName}, + TypeUrl: resource.EndpointType, + }) + if err != nil { + return false + } + resp, err := stream.Recv() + if err != nil || len(resp.Resources) == 0 { + return false + } + candidate := &endpointv3.ClusterLoadAssignment{} + if err := resp.Resources[0].UnmarshalTo(candidate); err != nil { + return false + } + if len(candidate.Endpoints) != expectedGroups { + return false + } + cla = candidate + return true + }, 5*time.Second, 100*time.Millisecond, "timed out waiting for CLA with %d locality groups", expectedGroups) + return cla +} + +func (s *XdsIntegrationTestSuite) TestLocalityZonePriorities() { + s.createKubeNode("node-a", "zone-a", "") + s.createKubeNode("node-b", "zone-b", "") + s.createKubeNode("node-c", "zone-c", "") + + s.createLocalityService("zsvc", "default", 50100, "zone") + s.createKubeEndpointWithNodes("zsvc", "default", []localityEndpoint{ + {IP: "10.200.0.1", NodeName: "node-a"}, + {IP: "10.200.0.2", NodeName: "node-b"}, + {IP: "10.200.0.3", NodeName: "node-c"}, + }, 50100) + + // Client in zone-b: zone-b should be priority 0, others contiguous at 1. + cla := s.fetchEDS("zone-b", "", "zsvc.default:grpc", 3) + priByZone := map[string]uint32{} + for _, g := range cla.Endpoints { + priByZone[g.GetLocality().GetZone()] = g.GetPriority() + } + s.Equal(uint32(0), priByZone["zone-b"], "zone-b should be preferred for zone-b client") + s.Equal(uint32(1), priByZone["zone-a"]) + s.Equal(uint32(1), priByZone["zone-c"]) + + // Different client locality flips the priorities. 
+ cla = s.fetchEDS("zone-a", "", "zsvc.default:grpc", 3) + priByZone = map[string]uint32{} + for _, g := range cla.Endpoints { + priByZone[g.GetLocality().GetZone()] = g.GetPriority() + } + s.Equal(uint32(0), priByZone["zone-a"], "zone-a should be preferred for zone-a client") + s.Equal(uint32(1), priByZone["zone-b"]) + s.Equal(uint32(1), priByZone["zone-c"]) +} + +func (s *XdsIntegrationTestSuite) TestLocalitySubZonePriorities() { + s.createKubeNode("n1", "zone-a", "rack-1") + s.createKubeNode("n2", "zone-a", "rack-2") + s.createKubeNode("n3", "zone-b", "rack-1") + + s.createLocalityService("ssvc", "default", 50101, "sub_zone") + s.createKubeEndpointWithNodes("ssvc", "default", []localityEndpoint{ + {IP: "10.201.0.1", NodeName: "n1"}, + {IP: "10.201.0.2", NodeName: "n2"}, + {IP: "10.201.0.3", NodeName: "n3"}, + }, 50101) + + // Client at (zone-a, rack-1): + // exact match → priority 0 + // same-zone only → priority 1 + // different zone → priority 2 + cla := s.fetchEDS("zone-a", "rack-1", "ssvc.default:grpc", 3) + priByLoc := map[string]uint32{} + for _, g := range cla.Endpoints { + loc := g.GetLocality() + priByLoc[loc.GetZone()+"/"+loc.GetSubZone()] = g.GetPriority() + } + s.Equal(uint32(0), priByLoc["zone-a/rack-1"]) + s.Equal(uint32(1), priByLoc["zone-a/rack-2"]) + s.Equal(uint32(2), priByLoc["zone-b/rack-1"]) +} + +func (s *XdsIntegrationTestSuite) TestLocalityNoAnnotationNoSplit() { + // Without the annotation, all endpoints should land in a single + // empty-locality group regardless of node labels — behavior matches the + // pre-locality default. + s.createKubeNode("plain-a", "zone-a", "") + s.createKubeNode("plain-b", "zone-b", "") + + s.createKubeService("plainsvc", "default", 50102) + s.createKubeEndpointWithNodes("plainsvc", "default", []localityEndpoint{ + {IP: "10.202.0.1", NodeName: "plain-a"}, + {IP: "10.202.0.2", NodeName: "plain-b"}, + }, 50102) + + cla := s.fetchEDS("zone-a", "", "plainsvc.default:grpc", 1) + s.Require().Len(cla.Endpoints, 1) + s.Equal(uint32(0), cla.Endpoints[0].GetPriority()) + s.Empty(cla.Endpoints[0].GetLocality().GetZone()) + s.Len(cla.Endpoints[0].GetLbEndpoints(), 2) +} + func TestXdsIntegration(t *testing.T) { // The fake clientset doesn't implement the WatchList bookmark protocol, // so disable WatchListClient to use the traditional List+Watch flow. @@ -254,7 +454,7 @@ func TestXdsIntegration(t *testing.T) { kube := fake.NewClientset() - testServer, stop, err := di.InitializeTestServer(t.Context(), kube, 1) + testServer, stop, err := di.InitializeTestServer(t.Context(), kube, 1, "") require.NoError(t, err) defer stop() diff --git a/test/xds_test.go b/test/xds_test.go index 2cf5467..337eeca 100644 --- a/test/xds_test.go +++ b/test/xds_test.go @@ -30,7 +30,7 @@ func (s *XdsSuite) SetupTest() { var err error s.kube = fake.NewClientset() s.conn = bufconn.Listen(1) - s.TestServer, s.stop, err = di.InitializeTestServer(s.T().Context(), s.kube, 1) + s.TestServer, s.stop, err = di.InitializeTestServer(s.T().Context(), s.kube, 1, "") s.Require().NoError(err) go s.TestServer.GrpcServer.Serve(s.conn)