diff --git a/pkg/kgateway/extensions2/plugins/waypoint/plugin.go b/pkg/kgateway/extensions2/plugins/waypoint/plugin.go index c04738f9661..c6f93125b84 100644 --- a/pkg/kgateway/extensions2/plugins/waypoint/plugin.go +++ b/pkg/kgateway/extensions2/plugins/waypoint/plugin.go @@ -12,8 +12,10 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" gwv1 "sigs.k8s.io/gateway-api/apis/v1" + apisettings "github.com/kgateway-dev/kgateway/v2/api/settings" "github.com/kgateway-dev/kgateway/v2/pkg/kgateway/extensions2/plugins/waypoint/waypointquery" "github.com/kgateway-dev/kgateway/v2/pkg/kgateway/query" + "github.com/kgateway-dev/kgateway/v2/pkg/kgateway/utils" "github.com/kgateway-dev/kgateway/v2/pkg/kgateway/wellknown" sdk "github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk" "github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/collections" @@ -116,15 +118,23 @@ func (t *PerClientProcessor) processBackend(kctx krt.HandlerContext, ctx context } // All preliminary checks passed, process the ingress use waypoint - processIngressUseWaypoint(in, out) + processIngressUseWaypoint(in, out, &t.commonCols.Settings) } // processIngressUseWaypoint configures the cluster of the connected gateway to have a static // inlined addresses of the destination service. This will cause the traffic from the kgateway // to be redirected to the waypoint by the ztunnel. -func processIngressUseWaypoint(in ir.BackendObjectIR, out *envoyclusterv3.Cluster) { +// Addresses are sorted based on DNS lookup family setting, with the primary address in Address +// and additional addresses in AdditionalAddresses. +func processIngressUseWaypoint(in ir.BackendObjectIR, out *envoyclusterv3.Cluster, settings *apisettings.Settings) { addresses := waypointquery.BackendAddresses(in) + // Sort addresses based on DNS lookup family setting. Since this is a static cluster + // with inlined addresses, we can't use DnsLookupFamily (which only applies to DNS-based + // discovery). Instead, we sort the addresses based on the setting and use the primary + // address in Address and additional addresses in AdditionalAddresses. + sortedAddresses := sortAddressesByDnsLookupFamily(addresses, settings) + // Set the output cluster to be of type STATIC and instead of the default EDS and add // the addresses of the backend embedded into the CLA of this cluster config. out.ClusterDiscoveryType = &envoyclusterv3.Cluster_Type{ @@ -133,37 +143,123 @@ func processIngressUseWaypoint(in ir.BackendObjectIR, out *envoyclusterv3.Cluste out.EdsClusterConfig = nil out.LoadAssignment = &envoyendpointv3.ClusterLoadAssignment{ ClusterName: out.GetName(), - Endpoints: make([]*envoyendpointv3.LocalityLbEndpoints, 0, len(addresses)), + Endpoints: make([]*envoyendpointv3.LocalityLbEndpoints, 0, 1), } - for _, addr := range addresses { - out.GetLoadAssignment().Endpoints = append(out.GetLoadAssignment().GetEndpoints(), claEndpoint(addr, uint32(in.Port))) //nolint:gosec // G115: BackendObjectIR.Port is int32 representing a port number, always in valid range + if endpoint := claEndpoint(sortedAddresses, uint32(in.Port)); endpoint != nil { //nolint:gosec // G115: BackendObjectIR.Port is int32 representing a port number, always in valid range + out.GetLoadAssignment().Endpoints = append(out.GetLoadAssignment().GetEndpoints(), endpoint) } } -func claEndpoint(address string, port uint32) *envoyendpointv3.LocalityLbEndpoints { +// claEndpoint creates a LocalityLbEndpoints with the primary address in Address +// and additional addresses in AdditionalAddresses. +func claEndpoint(addresses []string, port uint32) *envoyendpointv3.LocalityLbEndpoints { + if len(addresses) == 0 { + return nil + } + + // Primary address goes in Address + primaryAddr := addresses[0] + endpoint := &envoyendpointv3.Endpoint{ + Address: &envoycorev3.Address{ + Address: &envoycorev3.Address_SocketAddress{ + SocketAddress: &envoycorev3.SocketAddress{ + Address: primaryAddr, + PortSpecifier: &envoycorev3.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + } + + // Additional addresses go in AdditionalAddresses + if len(addresses) > 1 { + additionalAddresses := make([]*envoyendpointv3.Endpoint_AdditionalAddress, 0, len(addresses)-1) + for _, addr := range addresses[1:] { + additionalAddresses = append(additionalAddresses, &envoyendpointv3.Endpoint_AdditionalAddress{ + Address: &envoycorev3.Address{ + Address: &envoycorev3.Address_SocketAddress{ + SocketAddress: &envoycorev3.SocketAddress{ + Address: addr, + PortSpecifier: &envoycorev3.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + }) + } + endpoint.AdditionalAddresses = additionalAddresses + } + return &envoyendpointv3.LocalityLbEndpoints{ LbEndpoints: []*envoyendpointv3.LbEndpoint{ { HostIdentifier: &envoyendpointv3.LbEndpoint_Endpoint{ - Endpoint: &envoyendpointv3.Endpoint{ - Address: &envoycorev3.Address{ - Address: &envoycorev3.Address_SocketAddress{ - SocketAddress: &envoycorev3.SocketAddress{ - Address: address, - PortSpecifier: &envoycorev3.SocketAddress_PortValue{ - PortValue: port, - }, - }, - }, - }, - }, + Endpoint: endpoint, }, }, }, } } +// sortAddressesByDnsLookupFamily sorts addresses based on the DNS lookup family setting. +// Returns a sorted list of addresses where the first address will be used as primary +// (in Address) and the rest as additional (in AdditionalAddresses). +// Since static clusters can't use DnsLookupFamily (it only applies to DNS-based discovery), +// we sort the addresses based on the setting. +func sortAddressesByDnsLookupFamily(addresses []string, settings *apisettings.Settings) []string { + if settings == nil { + // Default to V4_PREFERRED if settings are not available + return sortAddressesByDnsLookupFamily(addresses, &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV4Preferred, + }) + } + + // For ALL mode, we don't need to separate by family - just return all addresses + if settings.DnsLookupFamily == apisettings.DnsLookupFamilyAll { + return addresses + } + + // Separate IPv4 and IPv6 addresses for other modes + var ipv4Addrs, ipv6Addrs []string + for _, addr := range addresses { + validIPv4, _, err := utils.IsIpv4Address(addr) + if err != nil { + // Skip invalid addresses + continue + } + if validIPv4 { + ipv4Addrs = append(ipv4Addrs, addr) + } else { + ipv6Addrs = append(ipv6Addrs, addr) + } + } + + // Sort based on DNS lookup family setting + var sortedAddresses []string + switch settings.DnsLookupFamily { + case apisettings.DnsLookupFamilyV4Only: + // Only IPv4 addresses + sortedAddresses = ipv4Addrs + case apisettings.DnsLookupFamilyV6Only: + // Only IPv6 addresses + sortedAddresses = ipv6Addrs + case apisettings.DnsLookupFamilyV4Preferred: + // IPv4 first, then IPv6 as additional addresses + sortedAddresses = append(ipv4Addrs, ipv6Addrs...) + case apisettings.DnsLookupFamilyAuto: + // IPv6 first, then IPv4 as additional addresses + sortedAddresses = append(ipv6Addrs, ipv4Addrs...) + default: + // Default to V4_PREFERRED for unknown values + sortedAddresses = append(ipv4Addrs, ipv6Addrs...) + } + + return sortedAddresses +} + // hasIngressUseWaypointLabel checks if the backend or any relevant namespace/alias has the ingress-use-waypoint label. func hasIngressUseWaypointLabel(kctx krt.HandlerContext, commonCols *collections.CommonCollections, in ir.BackendObjectIR) bool { // Check the backend's own label first diff --git a/pkg/kgateway/extensions2/plugins/waypoint/plugin_test.go b/pkg/kgateway/extensions2/plugins/waypoint/plugin_test.go new file mode 100644 index 00000000000..73fb8e2e5e9 --- /dev/null +++ b/pkg/kgateway/extensions2/plugins/waypoint/plugin_test.go @@ -0,0 +1,275 @@ +package waypoint + +import ( + "testing" + + envoyendpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + "github.com/stretchr/testify/assert" + + apisettings "github.com/kgateway-dev/kgateway/v2/api/settings" +) + +func TestSortAddressesByDnsLookupFamily(t *testing.T) { + tests := []struct { + name string + addresses []string + settings *apisettings.Settings + want []string + }{ + { + name: "nil settings defaults to V4_PREFERRED", + addresses: []string{"10.0.0.1", "2001:db8::1"}, + settings: nil, + want: []string{"10.0.0.1", "2001:db8::1"}, + }, + { + name: "ALL mode returns all addresses unchanged", + addresses: []string{"2001:db8::1", "10.0.0.1", "10.0.0.2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyAll, + }, + want: []string{"2001:db8::1", "10.0.0.1", "10.0.0.2"}, + }, + { + name: "V4_ONLY returns only IPv4 addresses", + addresses: []string{"10.0.0.1", "2001:db8::1", "10.0.0.2", "2001:db8::2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV4Only, + }, + want: []string{"10.0.0.1", "10.0.0.2"}, + }, + { + name: "V4_ONLY with no IPv4 returns empty", + addresses: []string{"2001:db8::1", "2001:db8::2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV4Only, + }, + want: []string{}, + }, + { + name: "V6_ONLY returns only IPv6 addresses", + addresses: []string{"10.0.0.1", "2001:db8::1", "10.0.0.2", "2001:db8::2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV6Only, + }, + want: []string{"2001:db8::1", "2001:db8::2"}, + }, + { + name: "V6_ONLY with no IPv6 returns empty", + addresses: []string{"10.0.0.1", "10.0.0.2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV6Only, + }, + want: []string{}, + }, + { + name: "V4_PREFERRED returns IPv4 first, then IPv6", + addresses: []string{"2001:db8::1", "10.0.0.1", "2001:db8::2", "10.0.0.2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV4Preferred, + }, + want: []string{"10.0.0.1", "10.0.0.2", "2001:db8::1", "2001:db8::2"}, + }, + { + name: "V4_PREFERRED with no IPv4 returns IPv6 only", + addresses: []string{"2001:db8::1", "2001:db8::2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV4Preferred, + }, + want: []string{"2001:db8::1", "2001:db8::2"}, + }, + { + name: "AUTO returns IPv6 first, then IPv4", + addresses: []string{"10.0.0.1", "2001:db8::1", "10.0.0.2", "2001:db8::2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyAuto, + }, + want: []string{"2001:db8::1", "2001:db8::2", "10.0.0.1", "10.0.0.2"}, + }, + { + name: "AUTO with no IPv6 returns IPv4 only", + addresses: []string{"10.0.0.1", "10.0.0.2"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyAuto, + }, + want: []string{"10.0.0.1", "10.0.0.2"}, + }, + { + name: "invalid addresses are skipped", + addresses: []string{"10.0.0.1", "invalid-address", "2001:db8::1", "not-an-ip"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV4Preferred, + }, + want: []string{"10.0.0.1", "2001:db8::1"}, + }, + { + name: "empty addresses returns empty", + addresses: []string{}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamilyV4Preferred, + }, + want: []string{}, + }, + { + name: "unknown DNS lookup family defaults to V4_PREFERRED", + addresses: []string{"2001:db8::1", "10.0.0.1"}, + settings: &apisettings.Settings{ + DnsLookupFamily: apisettings.DnsLookupFamily("UNKNOWN"), + }, + want: []string{"10.0.0.1", "2001:db8::1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sortAddressesByDnsLookupFamily(tt.addresses, tt.settings) + // Normalize nil and empty slices for comparison + if got == nil { + got = []string{} + } + if tt.want == nil { + tt.want = []string{} + } + assert.Equal(t, tt.want, got, "sortAddressesByDnsLookupFamily() = %v, want %v", got, tt.want) + }) + } +} + +func TestClaEndpoint(t *testing.T) { + tests := []struct { + name string + addresses []string + port uint32 + wantNil bool + validate func(t *testing.T, result *envoyendpointv3.LocalityLbEndpoints) + }{ + { + name: "empty addresses returns nil", + addresses: []string{}, + port: 8080, + wantNil: true, + }, + { + name: "nil addresses returns nil", + addresses: nil, + port: 8080, + wantNil: true, + }, + { + name: "single address - primary only, no AdditionalAddresses", + addresses: []string{"10.0.0.1"}, + port: 8080, + wantNil: false, + validate: func(t *testing.T, result *envoyendpointv3.LocalityLbEndpoints) { + assert.NotNil(t, result) + assert.Len(t, result.LbEndpoints, 1) + + endpoint := result.LbEndpoints[0].GetEndpoint() + assert.NotNil(t, endpoint) + + // Check primary address + socketAddr := endpoint.Address.GetSocketAddress() + assert.NotNil(t, socketAddr) + assert.Equal(t, "10.0.0.1", socketAddr.Address) + assert.Equal(t, uint32(8080), socketAddr.GetPortValue()) + + // Should have no additional addresses + assert.Nil(t, endpoint.AdditionalAddresses) + }, + }, + { + name: "multiple addresses - primary + AdditionalAddresses", + addresses: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}, + port: 9090, + wantNil: false, + validate: func(t *testing.T, result *envoyendpointv3.LocalityLbEndpoints) { + assert.NotNil(t, result) + assert.Len(t, result.LbEndpoints, 1) + + endpoint := result.LbEndpoints[0].GetEndpoint() + assert.NotNil(t, endpoint) + + // Check primary address + socketAddr := endpoint.Address.GetSocketAddress() + assert.NotNil(t, socketAddr) + assert.Equal(t, "10.0.0.1", socketAddr.Address) + assert.Equal(t, uint32(9090), socketAddr.GetPortValue()) + + // Check additional addresses + assert.NotNil(t, endpoint.AdditionalAddresses) + assert.Len(t, endpoint.AdditionalAddresses, 2) + + // First additional address + addr1 := endpoint.AdditionalAddresses[0].Address.GetSocketAddress() + assert.NotNil(t, addr1) + assert.Equal(t, "10.0.0.2", addr1.Address) + assert.Equal(t, uint32(9090), addr1.GetPortValue()) + + // Second additional address + addr2 := endpoint.AdditionalAddresses[1].Address.GetSocketAddress() + assert.NotNil(t, addr2) + assert.Equal(t, "10.0.0.3", addr2.Address) + assert.Equal(t, uint32(9090), addr2.GetPortValue()) + }, + }, + { + name: "mixed IPv4 and IPv6 addresses", + addresses: []string{"10.0.0.1", "2001:db8::1", "10.0.0.2"}, + port: 443, + wantNil: false, + validate: func(t *testing.T, result *envoyendpointv3.LocalityLbEndpoints) { + assert.NotNil(t, result) + endpoint := result.LbEndpoints[0].GetEndpoint() + + // Primary should be first address + assert.Equal(t, "10.0.0.1", endpoint.Address.GetSocketAddress().Address) + assert.Equal(t, uint32(443), endpoint.Address.GetSocketAddress().GetPortValue()) + + // Additional addresses should be in order + assert.Len(t, endpoint.AdditionalAddresses, 2) + assert.Equal(t, "2001:db8::1", endpoint.AdditionalAddresses[0].Address.GetSocketAddress().Address) + assert.Equal(t, uint32(443), endpoint.AdditionalAddresses[0].Address.GetSocketAddress().GetPortValue()) + assert.Equal(t, "10.0.0.2", endpoint.AdditionalAddresses[1].Address.GetSocketAddress().Address) + assert.Equal(t, uint32(443), endpoint.AdditionalAddresses[1].Address.GetSocketAddress().GetPortValue()) + }, + }, + { + name: "port zero is valid", + addresses: []string{"10.0.0.1"}, + port: 0, + wantNil: false, + validate: func(t *testing.T, result *envoyendpointv3.LocalityLbEndpoints) { + assert.NotNil(t, result) + endpoint := result.LbEndpoints[0].GetEndpoint() + assert.Equal(t, uint32(0), endpoint.Address.GetSocketAddress().GetPortValue()) + }, + }, + { + name: "large port number", + addresses: []string{"10.0.0.1", "10.0.0.2"}, + port: 65535, + wantNil: false, + validate: func(t *testing.T, result *envoyendpointv3.LocalityLbEndpoints) { + assert.NotNil(t, result) + endpoint := result.LbEndpoints[0].GetEndpoint() + assert.Equal(t, uint32(65535), endpoint.Address.GetSocketAddress().GetPortValue()) + assert.Equal(t, uint32(65535), endpoint.AdditionalAddresses[0].Address.GetSocketAddress().GetPortValue()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := claEndpoint(tt.addresses, tt.port) + + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +}