From 236b285dd72c2e030f784e551d10c067f2096098 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Thu, 19 Mar 2026 13:09:23 +0000 Subject: [PATCH 01/18] Add DiffTracker core types, state management, and sync operations Introduces the difftracker package with core K8s/NRP state tracking, diff computation, and state mutation logic. Includes comprehensive test coverage (25+ test functions). --- pkg/provider/difftracker/config.go | 49 ++ pkg/provider/difftracker/difftracker.go | 63 ++ pkg/provider/difftracker/difftracker_test.go | 632 +++++++++++++++++ pkg/provider/difftracker/k8s_state_updates.go | 252 +++++++ pkg/provider/difftracker/nrp_state_updates.go | 108 +++ pkg/provider/difftracker/sync_operations.go | 191 ++++++ pkg/provider/difftracker/types.go | 136 ++++ pkg/provider/difftracker/util.go | 302 ++++++++ pkg/provider/difftracker/util_test.go | 648 ++++++++++++++++++ pkg/util/sets/string.go | 33 + pkg/util/sets/string_test.go | 78 +++ 11 files changed, 2492 insertions(+) create mode 100644 pkg/provider/difftracker/config.go create mode 100644 pkg/provider/difftracker/difftracker.go create mode 100644 pkg/provider/difftracker/difftracker_test.go create mode 100644 pkg/provider/difftracker/k8s_state_updates.go create mode 100644 pkg/provider/difftracker/nrp_state_updates.go create mode 100644 pkg/provider/difftracker/sync_operations.go create mode 100644 pkg/provider/difftracker/types.go create mode 100644 pkg/provider/difftracker/util.go create mode 100644 pkg/provider/difftracker/util_test.go diff --git a/pkg/provider/difftracker/config.go b/pkg/provider/difftracker/config.go new file mode 100644 index 0000000000..198df8fe98 --- /dev/null +++ b/pkg/provider/difftracker/config.go @@ -0,0 +1,49 @@ +package difftracker + +import "fmt" + +// Config holds the configuration values needed by DiffTracker +// to perform Azure operations without depending on the entire AzureCloud struct +// This allows DiffTracker to be more modular and testable +type Config struct { + // Azure subscription ID + SubscriptionID string + + // Azure resource group name + ResourceGroup string + + // Azure location/region + Location string + + // Service Gateway resource name + ServiceGatewayResourceName string + + // Full Service Gateway resource ID + ServiceGatewayID string + + // Virtual Network name (required for backend pool configuration) + VNetName string +} + +// Validate checks if the configuration has all required fields +func (c *Config) Validate() error { + if c.SubscriptionID == "" { + return fmt.Errorf("config validation failed: SubscriptionID is required") + } + if c.ResourceGroup == "" { + return fmt.Errorf("config validation failed: ResourceGroup is required") + } + if c.Location == "" { + return fmt.Errorf("config validation failed: Location is required") + } + if c.ServiceGatewayResourceName == "" { + return fmt.Errorf("config validation failed: ServiceGatewayResourceName is required") + } + if c.ServiceGatewayID == "" { + return fmt.Errorf("config validation failed: ServiceGatewayID is required") + } + if c.VNetName == "" { + return fmt.Errorf("config validation failed: VNetName is required") + } + return nil +} diff --git a/pkg/provider/difftracker/difftracker.go b/pkg/provider/difftracker/difftracker.go new file mode 100644 index 0000000000..c8d850bf3a --- /dev/null +++ b/pkg/provider/difftracker/difftracker.go @@ -0,0 +1,63 @@ +package difftracker + +import ( + "fmt" + + "k8s.io/client-go/kubernetes" + "k8s.io/klog/v2" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// InitializeDiffTracker creates and initializes a new DiffTracker with the given state and configuration. +// It validates the configuration and ensures all required dependencies are present. +// Panics if critical dependencies (config, networkClientFactory, kubeClient) are invalid. +func InitializeDiffTracker(K8s K8s_State, NRP NRP_State, config Config, networkClientFactory azclient.ClientFactory, kubeClient kubernetes.Interface) *DiffTracker { + // Validate configuration + if err := config.Validate(); err != nil { + panic(fmt.Sprintf("InitializeDiffTracker: %v", err)) + } + + // Validate required dependencies + if networkClientFactory == nil { + panic("InitializeDiffTracker: networkClientFactory must not be nil") + } + if kubeClient == nil { + panic("InitializeDiffTracker: kubeClient must not be nil") + } + + klog.V(2).Infof("InitializeDiffTracker: initializing with config: subscription=%s, resourceGroup=%s, location=%s", + config.SubscriptionID, config.ResourceGroup, config.Location) + + // If any field is nil, initialize it + if K8s.Services == nil { + K8s.Services = utilsets.NewString() + } + if K8s.Egresses == nil { + K8s.Egresses = utilsets.NewString() + } + if K8s.Nodes == nil { + K8s.Nodes = make(map[string]Node) + } + if NRP.LoadBalancers == nil { + NRP.LoadBalancers = utilsets.NewString() + } + if NRP.NATGateways == nil { + NRP.NATGateways = utilsets.NewString() + } + if NRP.Locations == nil { + NRP.Locations = make(map[string]NRPLocation) + } + + diffTracker := &DiffTracker{ + K8sResources: K8s, + NRPResources: NRP, + + // Configuration and clients + config: config, + networkClientFactory: networkClientFactory, + kubeClient: kubeClient, + } + + return diffTracker +} diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go new file mode 100644 index 0000000000..11dacb8ed7 --- /dev/null +++ b/pkg/provider/difftracker/difftracker_test.go @@ -0,0 +1,632 @@ +package difftracker + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/client-go/kubernetes/fake" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +func TestDiffTracker_DeepEqual(t *testing.T) { + tests := []struct { + name string + dt *DiffTracker + expected bool + }{ + { + name: "equal empty states", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "equal states with services", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("service1", "service2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1", "service2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "services not equal", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("service1", "service2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.dt.DeepEqual() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateK8sService(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + }, + } + + // Test ADD operation + err := dt.UpdateK8sService(UpdateK8sResource{ + Operation: ADD, + ID: "service1", + }) + assert.NoError(t, err) + assert.True(t, dt.K8sResources.Services.Has("service1")) + + // Test REMOVE operation + err = dt.UpdateK8sService(UpdateK8sResource{ + Operation: REMOVE, + ID: "service1", + }) + assert.NoError(t, err) + assert.False(t, dt.K8sResources.Services.Has("service1")) + + // Test invalid operation + err = dt.UpdateK8sService(UpdateK8sResource{ + Operation: UPDATE, + ID: "service1", + }) + assert.Error(t, err) +} + +func TestGetSyncLoadBalancerServices(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("service1", "service2", "service3"), + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service2", "service3", "service4"), + }, + } + + result := dt.GetSyncLoadBalancerServices() + + assert.True(t, result.Additions.Has("service1")) + assert.Equal(t, 1, result.Additions.Len()) + + assert.True(t, result.Removals.Has("service4")) + assert.Equal(t, 1, result.Removals.Len()) +} + +func TestUpdateK8sEndpoints(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Nodes: map[string]Node{}, + }, + } + + // Test adding new endpoint + input := UpdateK8sEndpointsInputType{ + InboundIdentity: "service1", + OldAddresses: map[string]string{}, + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + } + + errs := dt.UpdateK8sEndpoints(input) + assert.Empty(t, errs) + assert.Contains(t, dt.K8sResources.Nodes, "node1") + assert.Contains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") + assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("service1")) + + // Test removing an endpoint + input = UpdateK8sEndpointsInputType{ + InboundIdentity: "service1", + OldAddresses: map[string]string{"10.0.0.1": "node1"}, + NewAddresses: map[string]string{}, + } + + errs = dt.UpdateK8sEndpoints(input) + assert.Empty(t, errs) + assert.NotContains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") +} + +func TestUpdateK8sPod(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Nodes: map[string]Node{}, + }, + } + + // Test adding new egress assignment + input := UpdatePodInputType{ + PodOperation: ADD, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + } + + err := dt.UpdateK8sPod(input) + assert.NoError(t, err) + assert.Contains(t, dt.K8sResources.Nodes, "node1") + assert.Contains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") + assert.Equal(t, "public1", dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].PublicOutboundIdentity) + + // Test removing egress assignment + input = UpdatePodInputType{ + PodOperation: REMOVE, + Location: "node1", + Address: "10.0.0.1", + } + + err = dt.UpdateK8sPod(input) + assert.NoError(t, err) + assert.NotContains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") +} + +func TestGetSyncLocationsAddresses(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Nodes: map[string]Node{ + "node1": { + Pods: map[string]Pod{ + "10.0.0.1": { + InboundIdentities: sets.NewString("service1"), + PublicOutboundIdentity: "public1", + }, + }, + }, + }, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("public1"), + Locations: map[string]NRPLocation{}, + }, + } + + result := dt.GetSyncLocationsAddresses() + + assert.Equal(t, PartialUpdate, result.Action) + assert.Len(t, result.Locations, 1) + + location := result.Locations["node1"] + assert.NotNil(t, location) + assert.Equal(t, FullUpdate, location.AddressUpdateAction) + assert.Len(t, location.Addresses, 1) + + var address string + for addr := range location.Addresses { + address = addr + break + } + + assert.Equal(t, "10.0.0.1", address) + assert.True(t, location.Addresses[address].ServiceRef.Has("service1")) + assert.True(t, location.Addresses[address].ServiceRef.Has("public1")) +} + +func TestOperation_String(t *testing.T) { + assert.Equal(t, "ADD", ADD.String()) + assert.Equal(t, "REMOVE", REMOVE.String()) + assert.Equal(t, "UPDATE", UPDATE.String()) +} + +func TestUpdateNRPLoadBalancers(t *testing.T) { + tests := []struct { + name string + initialState *DiffTracker + expectedNRP *sets.IgnoreCaseSet + }{ + { + name: "add services from K8s to NRP", + initialState: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("service1", "service2", "service3"), + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1"), + }, + }, + expectedNRP: sets.NewString("service1", "service2", "service3"), + }, + { + name: "no changes needed when K8s and NRP are in sync", + initialState: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("service1", "service2"), + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1", "service2"), + }, + }, + expectedNRP: sets.NewString("service1", "service2"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + syncServices := tt.initialState.GetSyncLoadBalancerServices() + tt.initialState.UpdateNRPLoadBalancers(syncServices) + + assert.True(t, tt.expectedNRP.Equals(tt.initialState.NRPResources.LoadBalancers), + "Expected NRP LoadBalancers %v, but got %v", + tt.expectedNRP.UnsortedList(), + tt.initialState.NRPResources.LoadBalancers.UnsortedList()) + }) + } +} + +func TestUpdateK8sEgress(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Egresses: sets.NewString(), + }, + } + + err := dt.UpdateK8sEgress(UpdateK8sResource{Operation: ADD, ID: "egress1"}) + assert.NoError(t, err) + assert.True(t, dt.K8sResources.Egresses.Has("egress1")) + + err = dt.UpdateK8sEgress(UpdateK8sResource{Operation: REMOVE, ID: "egress1"}) + assert.NoError(t, err) + assert.False(t, dt.K8sResources.Egresses.Has("egress1")) + + err = dt.UpdateK8sEgress(UpdateK8sResource{Operation: UPDATE, ID: "egress1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "error - ResourceType=Egress, Operation=UPDATE and ID=egress1") +} + +func TestGetSyncNRPNATGateways(t *testing.T) { + tests := []struct { + name string + k8sEgresses []string + nrpNATGateways []string + expectedAdditions []string + expectedRemovals []string + }{ + { + name: "empty states", + k8sEgresses: []string{}, + nrpNATGateways: []string{}, + expectedAdditions: []string{}, + expectedRemovals: []string{}, + }, + { + name: "mixed state with additions and removals", + k8sEgresses: []string{"egress1", "egress3", "egress5"}, + nrpNATGateways: []string{"egress1", "egress2", "egress4"}, + expectedAdditions: []string{"egress3", "egress5"}, + expectedRemovals: []string{"egress2", "egress4"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Egresses: sets.NewString(tt.k8sEgresses...), + }, + NRPResources: NRP_State{ + NATGateways: sets.NewString(tt.nrpNATGateways...), + }, + } + + result := dt.GetSyncNRPNATGateways() + + assert.Equal(t, len(tt.expectedAdditions), result.Additions.Len()) + for _, addition := range tt.expectedAdditions { + assert.True(t, result.Additions.Has(addition)) + } + + assert.Equal(t, len(tt.expectedRemovals), result.Removals.Len()) + for _, removal := range tt.expectedRemovals { + assert.True(t, result.Removals.Has(removal)) + } + }) + } +} + +func TestUpdateNRPNATGateways(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8s_State{ + Egresses: sets.NewString("egress1", "egress2", "egress4"), + }, + NRPResources: NRP_State{ + NATGateways: sets.NewString("egress1", "egress3", "egress5"), + }, + } + + syncServices := dt.GetSyncNRPNATGateways() + dt.UpdateNRPNATGateways(syncServices) + + expectedNRP := sets.NewString("egress1", "egress2", "egress4") + assert.True(t, expectedNRP.Equals(dt.NRPResources.NATGateways), + "Expected NRP NATGateways %v, but got %v", + expectedNRP.UnsortedList(), + dt.NRPResources.NATGateways.UnsortedList()) +} + +func TestUpdateLocationsAddresses(t *testing.T) { + tests := []struct { + name string + initialState *DiffTracker + expectedNRP map[string]map[string][]string + }{ + { + name: "sync empty states", + initialState: &DiffTracker{ + K8sResources: K8s_State{Nodes: map[string]Node{}}, + NRPResources: NRP_State{Locations: map[string]NRPLocation{}}, + }, + expectedNRP: map[string]map[string][]string{}, + }, + { + name: "add new location and address", + initialState: &DiffTracker{ + K8sResources: K8s_State{ + Nodes: map[string]Node{ + "node1": { + Pods: map[string]Pod{ + "10.0.0.1": { + InboundIdentities: sets.NewString("service1"), + PublicOutboundIdentity: "public1", + }, + }, + }, + }, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("public1"), + Locations: map[string]NRPLocation{}, + }, + }, + expectedNRP: map[string]map[string][]string{ + "node1": {"10.0.0.1": {"service1", "public1"}}, + }, + }, + { + name: "complex case with multiple operations", + initialState: &DiffTracker{ + K8sResources: K8s_State{ + Nodes: map[string]Node{ + "node1": { + Pods: map[string]Pod{ + "10.0.0.1": { + InboundIdentities: sets.NewString("service1", "service3"), + PublicOutboundIdentity: "public1", + }, + }, + }, + "node3": { + Pods: map[string]Pod{ + "10.0.0.5": {InboundIdentities: sets.NewString("service5")}, + }, + }, + }, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1", "service2", "service3", "service4", "service5"), + NATGateways: sets.NewString("public1"), + Locations: map[string]NRPLocation{ + "node1": { + Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1", "service2")}, + "10.0.0.2": {Services: sets.NewString("service4")}, + }, + }, + "node2": { + Addresses: map[string]NRPAddress{ + "10.0.0.3": {Services: sets.NewString("service3")}, + }, + }, + }, + }, + }, + expectedNRP: map[string]map[string][]string{ + "node1": {"10.0.0.1": {"service1", "service3", "public1"}}, + "node3": {"10.0.0.5": {"service5"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + locationData := tt.initialState.GetSyncLocationsAddresses() + tt.initialState.UpdateLocationsAddresses(locationData) + + assert.Equal(t, len(tt.expectedNRP), len(tt.initialState.NRPResources.Locations)) + + for locName, expectedAddressMap := range tt.expectedNRP { + nrpLoc, exists := tt.initialState.NRPResources.Locations[locName] + assert.True(t, exists, "Expected location %s not found", locName) + assert.Equal(t, len(expectedAddressMap), len(nrpLoc.Addresses)) + + for addr, expectedServices := range expectedAddressMap { + nrpAddr, exists := nrpLoc.Addresses[addr] + assert.True(t, exists, "Expected address %s not found in %s", addr, locName) + assert.Equal(t, len(expectedServices), nrpAddr.Services.Len()) + for _, svc := range expectedServices { + assert.True(t, nrpAddr.Services.Has(svc)) + } + } + } + }) + } +} + +func TestGetSyncOperations(t *testing.T) { + tests := []struct { + name string + initialState *DiffTracker + expectedSyncStatus SyncStatus + }{ + { + name: "states already in sync", + initialState: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("service1"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("service1")}, + }}, + }, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("egress1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1")}, + }}, + }, + }, + }, + expectedSyncStatus: ALREADY_IN_SYNC, + }, + { + name: "services out of sync", + initialState: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("service1", "service2"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("service1")}, + }}, + }, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("egress1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1")}, + }}, + }, + }, + }, + expectedSyncStatus: SUCCESS, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.initialState.GetSyncOperations() + assert.Equal(t, tt.expectedSyncStatus, result.SyncStatus) + }) + } +} + +// Real Scenario: CloudProvider is down and K8s Cluster is subject to continuous updates. +// This test verifies if the DiffTracker is able to sync K8s Cluster and NRP correctly +// when there is a huge discrepancy between K8s Cluster and NRP. +func TestInitializeDiffTracker(t *testing.T) { + K8sResources := K8s_State{ + Services: sets.NewString("Service0", "Service1", "Service2"), + Egresses: sets.NewString("Egress0", "Egress1", "Egress2"), + Nodes: map[string]Node{ + "Node1": { + Pods: map[string]Pod{ + "Pod34": {InboundIdentities: sets.NewString("Service0"), PublicOutboundIdentity: ""}, + "Pod0": {InboundIdentities: sets.NewString("Service0"), PublicOutboundIdentity: "Egress0"}, + "Pod1": {InboundIdentities: sets.NewString("Service1", "Service2"), PublicOutboundIdentity: "Egress1"}, + "Pod3": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "Egress2"}, + }, + }, + "Node2": { + Pods: map[string]Pod{ + "Pod2": {InboundIdentities: sets.NewString("Service1"), PublicOutboundIdentity: "Egress2"}, + }, + }, + }, + } + + NRPResources := NRP_State{ + LoadBalancers: sets.NewString("Service0", "Service6", "Service5"), + NATGateways: sets.NewString("Egress0", "Egress6", "Egress5"), + Locations: map[string]NRPLocation{ + "Node1": { + Addresses: map[string]NRPAddress{ + "Pod34": {Services: sets.NewString("Service0", "Service5")}, + "Pod00": {Services: sets.NewString("Service6", "Egress5")}, + "Pod0": {Services: sets.NewString("Service0", "Egress0")}, + }, + }, + "Node3": { + Addresses: map[string]NRPAddress{ + "Pod4": {Services: sets.NewString("Service6", "Eggres6")}, + "Pod5": {Services: sets.NewString("Egress5")}, + }, + }, + }, + } + + config := Config{ + SubscriptionID: "test-subscription", + ResourceGroup: "test-rg", + Location: "eastus", + VNetName: "test-vnet", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockKubeClient := fake.NewSimpleClientset() + diffTracker := InitializeDiffTracker(K8sResources, NRPResources, config, mockFactory, mockKubeClient) + syncOperations := diffTracker.GetSyncOperations() + + diffTracker.UpdateNRPLoadBalancers(syncOperations.LoadBalancerUpdates) + diffTracker.UpdateNRPNATGateways(syncOperations.NATGatewayUpdates) + diffTracker.UpdateLocationsAddresses(syncOperations.LocationData) + + assert.Equal(t, SUCCESS, syncOperations.SyncStatus) + + expectedDiffTracker := &DiffTracker{ + K8sResources: K8sResources, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("Service0", "Service1", "Service2"), + NATGateways: sets.NewString("Egress0", "Egress1", "Egress2"), + Locations: map[string]NRPLocation{ + "Node1": { + Addresses: map[string]NRPAddress{ + "Pod34": {Services: sets.NewString("Service0")}, + "Pod0": {Services: sets.NewString("Service0", "Egress0")}, + }, + }, + }, + }, + } + + assert.True(t, diffTracker.Equals(expectedDiffTracker), + "DiffTracker does not match expected state") +} diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go new file mode 100644 index 0000000000..4a421e58c4 --- /dev/null +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -0,0 +1,252 @@ +package difftracker + +import ( + "fmt" + "strings" + + "k8s.io/klog/v2" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +const ( + ResourceTypeService = "Service" + ResourceTypeEgress = "Egress" +) + +func updateK8Resource(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, resourceType string) error { + if input.ID == "" { + return fmt.Errorf("%s: empty ID not allowed", resourceType) + } + + switch input.Operation { + case ADD: + set.Insert(input.ID) + case REMOVE: + set.Delete(input.ID) + default: + return fmt.Errorf("error - ResourceType=%s, Operation=%s and ID=%s", resourceType, input.Operation, input.ID) + } + return nil +} + +func (dt *DiffTracker) UpdateK8sService(input UpdateK8sResource) error { + dt.mu.Lock() + defer dt.mu.Unlock() + + return updateK8Resource(input, dt.K8sResources.Services, ResourceTypeService) +} + +func (dt *DiffTracker) UpdateK8sEgress(input UpdateK8sResource) error { + dt.mu.Lock() + defer dt.mu.Unlock() + + return updateK8Resource(input, dt.K8sResources.Egresses, ResourceTypeEgress) +} + +// updateK8sEndpointsLocked updates K8s endpoints state. Assumes lock is already held. +func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputType) []error { + var errs []error + for address, location := range input.NewAddresses { + + if location == "" { + errs = append(errs, fmt.Errorf("error UpdateK8sEndpoints, address=%s does not have a node associated", address)) + continue + } + + if _, exists := input.OldAddresses[address]; exists { + continue + } + + nodeState, exists := dt.K8sResources.Nodes[location] + if !exists { + nodeState = Node{ + Pods: make(map[string]Pod), + } + dt.K8sResources.Nodes[location] = nodeState + } + + pod, exists := nodeState.Pods[address] + if !exists { + pod = Pod{ + InboundIdentities: utilsets.NewString(), + } + nodeState.Pods[address] = pod + } + pod.InboundIdentities.Insert(input.InboundIdentity) + } + + for address, location := range input.OldAddresses { + if _, exists := input.NewAddresses[address]; exists { + continue + } + + if location == "" { + errs = append(errs, fmt.Errorf("error UpdateK8sEndpoints, address=%s does not have a node associated", address)) + } + + node, nodeExists := dt.K8sResources.Nodes[location] + if !nodeExists { + continue + } + + pod, podExists := node.Pods[address] + if !podExists { + continue + } + + pod.InboundIdentities.Delete(input.InboundIdentity) + + if !pod.HasIdentities() { + delete(node.Pods, address) + if !node.HasPods() { + delete(dt.K8sResources.Nodes, location) + } + } + } + + return errs +} + +// UpdateK8sEndpoints is a public wrapper that acquires lock before calling updateK8sEndpointsLocked. +func (dt *DiffTracker) UpdateK8sEndpoints(input UpdateK8sEndpointsInputType) []error { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.updateK8sEndpointsLocked(input) +} + +func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) error { + node, exists := dt.K8sResources.Nodes[input.Location] + if !exists { + node = Node{Pods: make(map[string]Pod)} + dt.K8sResources.Nodes[input.Location] = node + } + + pod, exists := node.Pods[input.Address] + if !exists { + pod = Pod{InboundIdentities: utilsets.NewString()} + } + + pod.PublicOutboundIdentity = input.PublicOutboundIdentity + node.Pods[input.Address] = pod + + return nil +} + +// removePod removes a pod from K8s state. Returns true if the pod was actually removed, +// false if it didn't exist (already removed by a previous call). +func (dt *DiffTracker) removePod(input UpdatePodInputType) (removed bool, err error) { + node, exists := dt.K8sResources.Nodes[input.Location] + if !exists { + return false, nil + } + + // Check if pod exists before removing + if _, podExists := node.Pods[input.Address]; !podExists { + return false, nil + } + + delete(node.Pods, input.Address) + if !node.HasPods() { + delete(dt.K8sResources.Nodes, input.Location) + } + + return true, nil +} + +// updateK8sPodLocked updates K8s pod state. Assumes lock is already held. +func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { + switch input.PodOperation { + case ADD, UPDATE: + // Check if pod already exists with the same outbound identity + // This prevents double-counting when pod informer fires AddFunc for pods + // that were already counted during initialization + alreadyExists := false + if node, nodeExists := dt.K8sResources.Nodes[input.Location]; nodeExists { + if pod, podExists := node.Pods[input.Address]; podExists { + if pod.PublicOutboundIdentity == input.PublicOutboundIdentity { + alreadyExists = true + klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s already exists for service %s, skipping counter increment", + input.Location, input.Address, input.PublicOutboundIdentity) + } + } + } + + // Only increment counter if pod doesn't already exist + if !alreadyExists { + counter := 0 + if val, ok := dt.LocalServiceNameToNRPServiceMap.Load(strings.ToLower(input.PublicOutboundIdentity)); ok { + counter = val.(int) + } + dt.LocalServiceNameToNRPServiceMap.Store(strings.ToLower(input.PublicOutboundIdentity), counter+1) + } + return dt.addOrUpdatePod(input) + case REMOVE: + // First, try to remove the pod from K8s state + // This returns false if the pod doesn't exist (duplicate removal) + removed, err := dt.removePod(input) + if err != nil { + return err + } + if !removed { + // Pod didn't exist - this is a duplicate removal, don't decrement counter + klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s was already removed (duplicate delete), skipping counter decrement", + input.Location, input.Address) + return nil + } + + // Pod was actually removed, now decrement the counter + if val, ok := dt.LocalServiceNameToNRPServiceMap.Load(strings.ToLower(input.PublicOutboundIdentity)); ok { + counter := val.(int) + if counter <= 0 { + return fmt.Errorf("error - PublicOutboundIdentity %s has a negative count: %d", input.PublicOutboundIdentity, counter) + } + if counter == 1 { + dt.LocalServiceNameToNRPServiceMap.Delete(strings.ToLower(input.PublicOutboundIdentity)) + } else { + dt.LocalServiceNameToNRPServiceMap.Store(strings.ToLower(input.PublicOutboundIdentity), counter-1) + } + } + return nil + default: + return fmt.Errorf("invalid pod operation: %s for pod at %s:%s", + input.PodOperation, input.Location, input.Address) + } +} + +// UpdateK8sPod is a public wrapper that acquires lock before calling updateK8sPodLocked. +func (dt *DiffTracker) UpdateK8sPod(input UpdatePodInputType) error { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.updateK8sPodLocked(input) +} + +// removeServiceFromK8sStateLocked removes a service from all pod identities in K8s state. +// This is used during service deletion to proactively clear location/address references +// so the LocationsUpdater can sync the removal to NRP. +// Assumes lock is already held. +func (dt *DiffTracker) removeServiceFromK8sStateLocked(serviceUID string, isInbound bool) { + for nodeIP, node := range dt.K8sResources.Nodes { + for podIP, pod := range node.Pods { + if isInbound { + // Remove from inbound identities + if pod.InboundIdentities != nil && pod.InboundIdentities.Has(serviceUID) { + pod.InboundIdentities.Delete(serviceUID) + } + } else { + // Clear outbound identity if it matches + if strings.EqualFold(pod.PublicOutboundIdentity, serviceUID) { + pod.PublicOutboundIdentity = "" + node.Pods[podIP] = pod + } + } + + // Clean up empty pods and nodes + if !pod.HasIdentities() { + delete(node.Pods, podIP) + if !node.HasPods() { + delete(dt.K8sResources.Nodes, nodeIP) + } + } + } + } +} diff --git a/pkg/provider/difftracker/nrp_state_updates.go b/pkg/provider/difftracker/nrp_state_updates.go new file mode 100644 index 0000000000..91f05841b5 --- /dev/null +++ b/pkg/provider/difftracker/nrp_state_updates.go @@ -0,0 +1,108 @@ +package difftracker + +import ( + "k8s.io/klog/v2" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +func (dt *DiffTracker) UpdateNRPLoadBalancers(syncServicesReturnType SyncServicesReturnType) { + dt.mu.Lock() + defer dt.mu.Unlock() + + for _, service := range syncServicesReturnType.Additions.UnsortedList() { + dt.NRPResources.LoadBalancers.Insert(service) + klog.V(2).Infof("UpdateNRPLoadBalancers: Added service %s to NRP LoadBalancers\n", service) + } + + for _, service := range syncServicesReturnType.Removals.UnsortedList() { + dt.NRPResources.LoadBalancers.Delete(service) + klog.V(2).Infof("UpdateNRPLoadBalancers: Removed service %s from NRP LoadBalancers\n", service) + } +} + +func (dt *DiffTracker) UpdateNRPNATGateways(syncServicesReturnType SyncServicesReturnType) { + dt.mu.Lock() + defer dt.mu.Unlock() + + for _, service := range syncServicesReturnType.Additions.UnsortedList() { + dt.NRPResources.NATGateways.Insert(service) + klog.V(2).Infof("UpdateNRPNATGateways: Added service %s to NRP NATGateways\n", service) + } + + for _, service := range syncServicesReturnType.Removals.UnsortedList() { + dt.NRPResources.NATGateways.Delete(service) + klog.V(2).Infof("UpdateNRPNATGateways: Removed service %s from NRP NATGateways\n", service) + } +} + +func (dt *DiffTracker) UpdateLocationsAddresses(locationData LocationData) { + dt.mu.Lock() + defer dt.mu.Unlock() + + for locationKey, locationValue := range locationData.Locations { + // Remove empty locations + if len(locationValue.Addresses) == 0 { + delete(dt.NRPResources.Locations, locationKey) + continue + } + + // Get or create location + nrpLocation, exists := dt.NRPResources.Locations[locationKey] + isFullUpdate := !exists || locationValue.AddressUpdateAction == FullUpdate + + // For full updates, start with a fresh location + if isFullUpdate { + nrpLocation = NRPLocation{ + Addresses: make(map[string]NRPAddress), + } + } + + // Process address updates + for addressKey, addressValue := range locationValue.Addresses { + // Remove empty addresses + if addressValue.ServiceRef.Len() == 0 { + delete(nrpLocation.Addresses, addressKey) + continue + } + + // Create new service references set + serviceRefs := createServiceRefsFromAddress(addressValue) + + // For full update or when address doesn't exist, add new address + if isFullUpdate || !addressExists(nrpLocation, addressKey) { + nrpLocation.Addresses[addressKey] = NRPAddress{Services: serviceRefs} + continue + } + + // For partial updates with existing address + existingAddress := nrpLocation.Addresses[addressKey] + if !serviceRefs.Equals(existingAddress.Services) { + nrpLocation.Addresses[addressKey] = NRPAddress{ + Services: serviceRefs, + } + } + } + + // Save location if it has addresses, otherwise delete it + if len(nrpLocation.Addresses) > 0 { + dt.NRPResources.Locations[locationKey] = nrpLocation + } else { + delete(dt.NRPResources.Locations, locationKey) + } + } +} + +// Helper function to check if address exists in a location +func addressExists(location NRPLocation, addressKey string) bool { + _, exists := location.Addresses[addressKey] + return exists +} + +// Helper function to create service references from an address +func createServiceRefsFromAddress(addressValue Address) *utilsets.IgnoreCaseSet { + serviceRefs := utilsets.NewString() + for _, service := range addressValue.ServiceRef.UnsortedList() { + serviceRefs.Insert(service) + } + return serviceRefs +} diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go new file mode 100644 index 0000000000..6b18b4d4d4 --- /dev/null +++ b/pkg/provider/difftracker/sync_operations.go @@ -0,0 +1,191 @@ +package difftracker + +import ( + "k8s.io/klog/v2" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// GetServicesToSync handles the synchronization of services between K8s and NRP +func GetServicesToSync(k8sServices, Services *utilsets.IgnoreCaseSet) SyncServicesReturnType { + klog.Infof("GetServicesToSync: K8s services (%d): %v", k8sServices.Len(), k8sServices.UnsortedList()) + klog.Infof("GetServicesToSync: NRP services (%d): %v", Services.Len(), Services.UnsortedList()) + + syncServices := SyncServicesReturnType{ + Additions: utilsets.NewString(), + Removals: utilsets.NewString(), + } + + for _, service := range k8sServices.UnsortedList() { + if Services.Has(service) { + continue + } + syncServices.Additions.Insert(service) + klog.Infof("GetServicesToSync: Added service %s to additions", service) + } + + for _, service := range Services.UnsortedList() { + if k8sServices.Has(service) { + continue + } + syncServices.Removals.Insert(service) + klog.Infof("GetServicesToSync: Added service %s to removals", service) + } + + klog.Infof("GetServicesToSync: Result - Additions: %d, Removals: %d", syncServices.Additions.Len(), syncServices.Removals.Len()) + return syncServices +} + +func (dt *DiffTracker) GetSyncLoadBalancerServices() SyncServicesReturnType { + dt.mu.Lock() + defer dt.mu.Unlock() + + return GetServicesToSync(dt.K8sResources.Services, dt.NRPResources.LoadBalancers) +} + +func (dt *DiffTracker) GetSyncNRPNATGateways() SyncServicesReturnType { + dt.mu.Lock() + defer dt.mu.Unlock() + + return GetServicesToSync(dt.K8sResources.Egresses, dt.NRPResources.NATGateways) +} + +func (dt *DiffTracker) GetSyncLocationsAddresses() LocationData { + dt.mu.Lock() + defer dt.mu.Unlock() + + result := LocationData{ + Action: PartialUpdate, + Locations: make(map[string]Location), + } + + // Iterate over all nodes in the K8s state + for nodeIp, node := range dt.K8sResources.Nodes { + nrpLocation, locationExists := dt.NRPResources.Locations[nodeIp] + location := initializeLocation(locationExists) + locationUpdated := false + + for address, pod := range node.Pods { + // Filter services: only include services that exist in NRP + serviceRef := dt.createServiceRefFiltered(pod) + + // Check if address exists in NRP and if service list changed + nrpAddressData, addressExists := nrpLocation.Addresses[address] + + // Skip this address if: + // 1. No ready services AND address doesn't exist in NRP (nothing to sync) + // 2. ServiceRef matches what's already in NRP (no change) + if serviceRef.Len() == 0 && !addressExists { + continue + } + + if addressExists && serviceRef.Equals(nrpAddressData.Services) { + continue + } + + // ServiceRef changed (or address is new) - need to sync + addressData := Address{ServiceRef: serviceRef} + location.Addresses[address] = addressData + locationUpdated = true + } + if locationUpdated { + result.Locations[nodeIp] = location + } + } + + // Iterate over all locations in the NRP state + for location, nrpLocation := range dt.NRPResources.Locations { + node, exists := dt.K8sResources.Nodes[location] + if !exists { + result.Locations[location] = Location{ + AddressUpdateAction: PartialUpdate, + Addresses: make(map[string]Address), + } + } else { + locationData := findLocationData(result, location) + if locationData == nil { + locationData = &Location{ + AddressUpdateAction: PartialUpdate, + Addresses: make(map[string]Address), + } + } + for address := range nrpLocation.Addresses { + if _, exists := node.Pods[address]; !exists { + addressData := Address{ServiceRef: utilsets.NewString()} + locationData.Addresses[address] = addressData + result.Locations[location] = *locationData + } + } + } + } + return result +} + +// Helper function to initialize Location based on existence in NRP +func initializeLocation(exists bool) Location { + if !exists { + return Location{ + AddressUpdateAction: FullUpdate, + Addresses: make(map[string]Address), + } + } + return Location{ + AddressUpdateAction: PartialUpdate, + Addresses: make(map[string]Address), + } +} + +// createServiceRefFiltered creates ServiceRef but only includes services that exist in NRP. +// Must be called with dt.mu held. +func (dt *DiffTracker) createServiceRefFiltered(pod Pod) *utilsets.IgnoreCaseSet { + serviceRef := utilsets.NewString() + + // Check inbound services (LoadBalancers) + for _, serviceUID := range pod.InboundIdentities.UnsortedList() { + if dt.isServiceReady(serviceUID, true) { + serviceRef.Insert(serviceUID) + } + } + + // Check outbound service (NAT Gateway) + if pod.PublicOutboundIdentity != "" { + if dt.isServiceReady(pod.PublicOutboundIdentity, false) { + serviceRef.Insert(pod.PublicOutboundIdentity) + } + } + + return serviceRef +} + +// isServiceReady checks if a service is ready for location sync. +// Returns true if the service exists in NRP. +// Must be called with dt.mu held. +func (dt *DiffTracker) isServiceReady(serviceUID string, isInbound bool) bool { + if isInbound { + return dt.NRPResources.LoadBalancers.Has(serviceUID) + } + return dt.NRPResources.NATGateways.Has(serviceUID) +} + +// Helper function to find LocationData in result +func findLocationData(result LocationData, location string) *Location { + for keyCurrentLocation := range result.Locations { + if keyCurrentLocation == location { + loc := result.Locations[keyCurrentLocation] + return &loc + } + } + return nil +} + +func (dt *DiffTracker) GetSyncOperations() *SyncDiffTrackerReturnType { + if dt.DeepEqual() { + return &SyncDiffTrackerReturnType{SyncStatus: ALREADY_IN_SYNC} + } + + return &SyncDiffTrackerReturnType{ + SyncStatus: SUCCESS, + LoadBalancerUpdates: dt.GetSyncLoadBalancerServices(), + NATGatewayUpdates: dt.GetSyncNRPNATGateways(), + LocationData: dt.GetSyncLocationsAddresses(), + } +} diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go new file mode 100644 index 0000000000..66c7f14e75 --- /dev/null +++ b/pkg/provider/difftracker/types.go @@ -0,0 +1,136 @@ +package difftracker + +import ( + "sync" + + "k8s.io/client-go/kubernetes" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// ================================================================================================ +// ENUMS +// ================================================================================================ +type Operation int + +const ( + ADD Operation = iota + REMOVE + UPDATE +) + +type UpdateAction int + +const ( + PartialUpdate UpdateAction = iota + FullUpdate +) + +type SyncStatus int + +const ( + ALREADY_IN_SYNC SyncStatus = iota + SUCCESS +) + +// -------------------------------------------------------------------------------- +// DiffTracker keeps track of the state of the K8s cluster and NRP +// -------------------------------------------------------------------------------- +type NRPAddress struct { + Services *utilsets.IgnoreCaseSet // all inbound and outbound identities +} + +type NRPLocation struct { + Addresses map[string]NRPAddress +} + +type NRP_State struct { + LoadBalancers *utilsets.IgnoreCaseSet + NATGateways *utilsets.IgnoreCaseSet + Locations map[string]NRPLocation +} + +type Pod struct { + InboundIdentities *utilsets.IgnoreCaseSet + PublicOutboundIdentity string +} + +type Node struct { + Pods map[string]Pod +} + +type K8s_State struct { + Services *utilsets.IgnoreCaseSet + Egresses *utilsets.IgnoreCaseSet + Nodes map[string]Node +} + +// DiffTracker is the main struct that contains the state of the K8s and NRP services +type DiffTracker struct { + mu sync.Mutex // Protects concurrent access to DiffTracker + + K8sResources K8s_State + NRPResources NRP_State + + LocalServiceNameToNRPServiceMap sync.Map + + // Configuration and clients + config Config + networkClientFactory azclient.ClientFactory + kubeClient kubernetes.Interface +} + +// -------------------------------------------------------------------------------- +// Types that are used while events are received and processed in order to update K8s state +// -------------------------------------------------------------------------------- + +// UpdateK8sResource represents input for K8s service or egress updates +type UpdateK8sResource struct { + Operation Operation + ID string +} + +// UpdateK8sEndpointsInputType represents input for K8s endpoints updates +type UpdateK8sEndpointsInputType struct { + InboundIdentity string + OldAddresses map[string]string // address -> location + NewAddresses map[string]string // address -> location +} + +// UpdatePodInputType represents input for K8s pod updates (egress assignments) +type UpdatePodInputType struct { + PodOperation Operation + PublicOutboundIdentity string + Location string + Address string +} + +// -------------------------------------------------------------------------------- +// Types that are used while syncing NRP state to K8s state +// -------------------------------------------------------------------------------- +type Address struct { + ServiceRef *utilsets.IgnoreCaseSet +} + +// Location uses a map for Addresses +type Location struct { + AddressUpdateAction UpdateAction + Addresses map[string]Address // key is Address.Address +} + +type LocationData struct { + Action UpdateAction + Locations map[string]Location // key is Location.Location +} + +type SyncServicesReturnType struct { + Additions *utilsets.IgnoreCaseSet + Removals *utilsets.IgnoreCaseSet +} + +type SyncDiffTrackerReturnType struct { + SyncStatus SyncStatus + LoadBalancerUpdates SyncServicesReturnType + NATGatewayUpdates SyncServicesReturnType + LocationData LocationData +} diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go new file mode 100644 index 0000000000..eb853bfaa4 --- /dev/null +++ b/pkg/provider/difftracker/util.go @@ -0,0 +1,302 @@ +package difftracker + +import ( + "encoding/json" + "fmt" + + "k8s.io/klog/v2" +) + +func (operation Operation) String() string { + return [...]string{"ADD", "REMOVE", "UPDATE"}[operation] +} + +func (operation Operation) MarshalJSON() ([]byte, error) { + return json.Marshal(operation.String()) +} + +func (updateAction UpdateAction) String() string { + return [...]string{"PartialUpdate", "FullUpdate"}[updateAction] +} + +func (updateAction UpdateAction) MarshalJSON() ([]byte, error) { + return json.Marshal(updateAction.String()) +} + +func (updateAction *UpdateAction) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + switch s { + case "PartialUpdate": + *updateAction = PartialUpdate + case "FullUpdate": + *updateAction = FullUpdate + default: + return fmt.Errorf("unknown UpdateAction: %q", s) + } + + return nil +} + +func (syncStatus SyncStatus) String() string { + return [...]string{"ALREADY_IN_SYNC", "SUCCESS"}[syncStatus] +} + +func (syncStatus SyncStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(syncStatus.String()) +} + +func (node *Node) HasPods() bool { return len(node.Pods) > 0 } + +func (pod *Pod) HasIdentities() bool { + return pod.InboundIdentities.Len() > 0 || pod.PublicOutboundIdentity != "" +} + +// DeepEqual compares the K8s and NRP states to check if they are in sync +func (dt *DiffTracker) DeepEqual() bool { + klog.Infof("DeepEqual: Checking equality - K8s Services=%d, NRP LoadBalancers=%d, K8s Egresses=%d, NRP NATGateways=%d", + dt.K8sResources.Services.Len(), dt.NRPResources.LoadBalancers.Len(), + dt.K8sResources.Egresses.Len(), dt.NRPResources.NATGateways.Len()) + + // Compare Services with LoadBalancers + if dt.K8sResources.Services.Len() != dt.NRPResources.LoadBalancers.Len() { + klog.Infof("DeepEqual: Services and LoadBalancers length mismatch") + return false + } + for _, service := range dt.K8sResources.Services.UnsortedList() { + if !dt.NRPResources.LoadBalancers.Has(service) { + klog.Infof("DeepEqual: Service %s not found in LoadBalancers", service) + return false + } + } + for _, service := range dt.NRPResources.LoadBalancers.UnsortedList() { + if !dt.K8sResources.Services.Has(service) { + klog.Infof("DeepEqual: LoadBalancer %s not found in Services", service) + return false + } + } + + // Compare Egresses with NATGateways + if dt.K8sResources.Egresses.Len() != dt.NRPResources.NATGateways.Len() { + klog.Infof("DeepEqual: Egresses and NATGateways length mismatch") + return false + } + for _, egress := range dt.K8sResources.Egresses.UnsortedList() { + if !dt.NRPResources.NATGateways.Has(egress) { + klog.Infof("DeepEqual: Egress %s not found in NATGateways", egress) + return false + } + } + for _, egress := range dt.NRPResources.NATGateways.UnsortedList() { + if !dt.K8sResources.Egresses.Has(egress) { + klog.Infof("DeepEqual: NATGateway %s not found in Egresses", egress) + return false + } + } + + // Compare Nodes with Locations + if len(dt.K8sResources.Nodes) != len(dt.NRPResources.Locations) { + klog.V(2).Infof("DeepEqual: Nodes and Locations length mismatch") + return false + } + for nodeKey, node := range dt.K8sResources.Nodes { + nrpLocation, exists := dt.NRPResources.Locations[nodeKey] + if !exists { + klog.V(2).Infof("DeepEqual: Node %s not found in Locations\n", nodeKey) + return false + } + + // Compare Pods with Addresses + if len(node.Pods) != len(nrpLocation.Addresses) { + klog.V(2).Infof("DeepEqual: Pods and Addresses length mismatch for node %s\n", nodeKey) + return false + } + for podKey, pod := range node.Pods { + nrpAddress, exists := nrpLocation.Addresses[podKey] + if !exists { + klog.V(2).Infof("DeepEqual: Pod %s not found in Addresses for node %s\n", podKey, nodeKey) + return false + } + + // Compare [...InboundIdentities, PublicOutboundIdentity] with Services + combinedIdentities := []string{} + combinedIdentities = append(combinedIdentities, pod.InboundIdentities.UnsortedList()...) + if pod.PublicOutboundIdentity != "" { + combinedIdentities = append(combinedIdentities, pod.PublicOutboundIdentity) + } + + if len(combinedIdentities) != nrpAddress.Services.Len() { + klog.V(2).Infof("DeepEqual: Combined identities length mismatch for pod %s in node %s\n", podKey, nodeKey) + return false + } + + for _, identity := range combinedIdentities { + if !nrpAddress.Services.Has(identity) { + klog.V(2).Infof("DeepEqual: Identity %s not found in Services for pod %s in node %s\n", identity, podKey, nodeKey) + return false + } + } + } + } + + return true +} + +func (syncServicesReturnType *SyncServicesReturnType) Equals(other *SyncServicesReturnType) bool { + return syncServicesReturnType.Additions.Equals(other.Additions) && syncServicesReturnType.Removals.Equals(other.Removals) +} + +// Equals compares two LocationData objects for equality +func (ld *LocationData) Equals(other *LocationData) bool { + if ld.Action != other.Action { + return false + } + + if len(ld.Locations) != len(other.Locations) { + return false + } + + for locName, location := range ld.Locations { + otherLocation, exists := other.Locations[locName] + if !exists { + return false + } + + if location.AddressUpdateAction != otherLocation.AddressUpdateAction { + return false + } + + if len(location.Addresses) != len(otherLocation.Addresses) { + return false + } + + for addrName, address := range location.Addresses { + otherAddress, exists := otherLocation.Addresses[addrName] + if !exists { + return false + } + + if !address.ServiceRef.Equals(otherAddress.ServiceRef) { + return false + } + } + } + + return true +} + +// Equals compares two SyncDiffTrackerReturnType objects for equality +func (sdts *SyncDiffTrackerReturnType) Equals(other *SyncDiffTrackerReturnType) bool { + if sdts.SyncStatus != other.SyncStatus { + return false + } + + if !sdts.LoadBalancerUpdates.Additions.Equals(other.LoadBalancerUpdates.Additions) { + return false + } + + if !sdts.LoadBalancerUpdates.Removals.Equals(other.LoadBalancerUpdates.Removals) { + return false + } + + if !sdts.NATGatewayUpdates.Additions.Equals(other.NATGatewayUpdates.Additions) { + return false + } + + if !sdts.NATGatewayUpdates.Removals.Equals(other.NATGatewayUpdates.Removals) { + return false + } + + if !sdts.LocationData.Equals(&other.LocationData) { + return false + } + + return true +} + +// Equals compares two DiffTracker objects for equality +func (dt *DiffTracker) Equals(other *DiffTracker) bool { + dt.mu.Lock() + defer dt.mu.Unlock() + + other.mu.Lock() + defer other.mu.Unlock() + + if !dt.K8sResources.Services.Equals(other.K8sResources.Services) { + return false + } + + if !dt.K8sResources.Egresses.Equals(other.K8sResources.Egresses) { + return false + } + + if len(dt.K8sResources.Nodes) != len(other.K8sResources.Nodes) { + return false + } + + for nodeKey, node := range dt.K8sResources.Nodes { + otherNode, exists := other.K8sResources.Nodes[nodeKey] + if !exists { + return false + } + + if len(node.Pods) != len(otherNode.Pods) { + return false + } + + for podKey, pod := range node.Pods { + otherPod, exists := otherNode.Pods[podKey] + if !exists { + return false + } + + if !pod.InboundIdentities.Equals(otherPod.InboundIdentities) { + return false + } + + if pod.PublicOutboundIdentity != otherPod.PublicOutboundIdentity { + return false + } + } + } + + // Compare NRP state + if !dt.NRPResources.LoadBalancers.Equals(other.NRPResources.LoadBalancers) { + return false + } + + if !dt.NRPResources.NATGateways.Equals(other.NRPResources.NATGateways) { + return false + } + + if len(dt.NRPResources.Locations) != len(other.NRPResources.Locations) { + return false + } + + for location, nrpLocation := range dt.NRPResources.Locations { + otherNrpLocation, exists := other.NRPResources.Locations[location] + if !exists { + return false + } + + if len(nrpLocation.Addresses) != len(otherNrpLocation.Addresses) { + return false + } + + for address, nrpAddress := range nrpLocation.Addresses { + otherNrpAddress, exists := otherNrpLocation.Addresses[address] + if !exists { + return false + } + + if !nrpAddress.Services.Equals(otherNrpAddress.Services) { + return false + } + } + } + + return true +} diff --git a/pkg/provider/difftracker/util_test.go b/pkg/provider/difftracker/util_test.go new file mode 100644 index 0000000000..9ffce2621f --- /dev/null +++ b/pkg/provider/difftracker/util_test.go @@ -0,0 +1,648 @@ +package difftracker + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// TestOperationStringAndJSON tests Operation String() and MarshalJSON() +func TestOperationStringAndJSON(t *testing.T) { + tests := []struct { + name string + op Operation + expected string + }{ + {"ADD operation", ADD, "ADD"}, + {"REMOVE operation", REMOVE, "REMOVE"}, + {"UPDATE operation", UPDATE, "UPDATE"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test String() + assert.Equal(t, tt.expected, tt.op.String()) + + // Test MarshalJSON() + data, err := tt.op.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, `"`+tt.expected+`"`, string(data)) + }) + } +} + +// TestUpdateActionStringAndJSON tests UpdateAction String(), MarshalJSON(), and UnmarshalJSON() +func TestUpdateActionStringAndJSON(t *testing.T) { + tests := []struct { + name string + action UpdateAction + expected string + }{ + {"PartialUpdate", PartialUpdate, "PartialUpdate"}, + {"FullUpdate", FullUpdate, "FullUpdate"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test String() + assert.Equal(t, tt.expected, tt.action.String()) + + // Test MarshalJSON() + data, err := tt.action.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, `"`+tt.expected+`"`, string(data)) + + // Test UnmarshalJSON() round-trip + var unmarshaled UpdateAction + err = unmarshaled.UnmarshalJSON(data) + assert.NoError(t, err) + assert.Equal(t, tt.action, unmarshaled) + }) + } +} + +// TestUpdateActionUnmarshalJSON_InvalidValue tests error handling +func TestUpdateActionUnmarshalJSON_InvalidValue(t *testing.T) { + var action UpdateAction + err := action.UnmarshalJSON([]byte(`"InvalidAction"`)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown UpdateAction") +} + +// TestUpdateActionUnmarshalJSON_InvalidJSON tests JSON parsing errors +func TestUpdateActionUnmarshalJSON_InvalidJSON(t *testing.T) { + var action UpdateAction + err := action.UnmarshalJSON([]byte(`{invalid json`)) + assert.Error(t, err) +} + +// TestSyncStatusStringAndJSON tests SyncStatus String() and MarshalJSON() +func TestSyncStatusStringAndJSON(t *testing.T) { + tests := []struct { + name string + status SyncStatus + expected string + }{ + {"ALREADY_IN_SYNC", ALREADY_IN_SYNC, "ALREADY_IN_SYNC"}, + {"SUCCESS", SUCCESS, "SUCCESS"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test String() + assert.Equal(t, tt.expected, tt.status.String()) + + // Test MarshalJSON() + data, err := tt.status.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, `"`+tt.expected+`"`, string(data)) + }) + } +} + +// TestNodeHasPods tests Node.HasPods() +func TestNodeHasPods(t *testing.T) { + tests := []struct { + name string + node Node + expected bool + }{ + { + name: "node with pods", + node: Node{Pods: map[string]Pod{"pod1": {}}}, + expected: true, + }, + { + name: "node without pods", + node: Node{Pods: map[string]Pod{}}, + expected: false, + }, + { + name: "node with nil pods map", + node: Node{Pods: nil}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.node.HasPods()) + }) + } +} + +// TestPodHasIdentities tests Pod.HasIdentities() +func TestPodHasIdentities(t *testing.T) { + tests := []struct { + name string + pod Pod + expected bool + }{ + { + name: "pod with inbound identities", + pod: Pod{InboundIdentities: sets.NewString("id1", "id2")}, + expected: true, + }, + { + name: "pod with public outbound identity", + pod: Pod{PublicOutboundIdentity: "outbound-id"}, + expected: true, + }, + { + name: "pod with both identities", + pod: Pod{InboundIdentities: sets.NewString("id1"), PublicOutboundIdentity: "outbound-id"}, + expected: true, + }, + { + name: "pod without identities", + pod: Pod{InboundIdentities: sets.NewString(), PublicOutboundIdentity: ""}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.pod.HasIdentities()) + }) + } +} + +// TestDeepEqual tests DiffTracker.DeepEqual() +func TestDeepEqual(t *testing.T) { + tests := []struct { + name string + dt *DiffTracker + expected bool + }{ + { + name: "in sync - matching services and load balancers", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("svc1", "svc2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("svc1", "svc2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "not in sync - service count mismatch", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("svc1", "svc2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + { + name: "not in sync - service name mismatch", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString("svc2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + { + name: "in sync - matching egresses and NAT gateways", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString("egress1", "egress2"), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString("egress1", "egress2"), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "not in sync - egress count mismatch", + dt: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + NRPResources: NRP_State{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString("egress1", "egress2"), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.dt.DeepEqual()) + }) + } +} + +// TestLocationDataEquals tests LocationData.Equals() +func TestLocationDataEquals(t *testing.T) { + tests := []struct { + name string + ld1 LocationData + ld2 LocationData + expected bool + }{ + { + name: "equal location data - empty locations", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{}, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{}, + }, + expected: true, + }, + { + name: "equal location data - with addresses", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{ + "addr1": {ServiceRef: sets.NewString("svc1")}, + }, + }, + }, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{ + "addr1": {ServiceRef: sets.NewString("svc1")}, + }, + }, + }, + }, + expected: true, + }, + { + name: "different action", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{}, + }, + ld2: LocationData{ + Action: FullUpdate, + Locations: map[string]Location{}, + }, + expected: false, + }, + { + name: "different location count", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + "loc2": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + }, + expected: false, + }, + { + name: "different address update action", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: FullUpdate, Addresses: map[string]Address{}}, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.ld1.Equals(&tt.ld2)) + }) + } +} + +// TestDiffTrackerEquals tests DiffTracker.Equals() +func TestDiffTrackerEquals(t *testing.T) { + tests := []struct { + name string + dt1 *DiffTracker + dt2 *DiffTracker + expected bool + }{ + { + name: "equal diff trackers", + dt1: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + }, + expected: true, + }, + { + name: "different services", + dt1: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString("svc2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + }, + expected: false, + }, + { + name: "different egresses", + dt1: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString("egress2"), + Nodes: map[string]Node{}, + }, + }, + expected: false, + }, + { + name: "different node count", + dt1: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{"node1": {}}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + }, + expected: false, + }, + { + name: "different pod count in node", + dt1: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{"pod1": {}}}, + }, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8s_State{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{}}, + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.dt1.Equals(tt.dt2)) + }) + } +} + +// TestSyncServicesReturnTypeEquals tests SyncServicesReturnType.Equals() +func TestSyncServicesReturnTypeEquals(t *testing.T) { + tests := []struct { + name string + s1 SyncServicesReturnType + s2 SyncServicesReturnType + expected bool + }{ + { + name: "equal - both empty", + s1: SyncServicesReturnType{ + Additions: sets.NewString(), + Removals: sets.NewString(), + }, + s2: SyncServicesReturnType{ + Additions: sets.NewString(), + Removals: sets.NewString(), + }, + expected: true, + }, + { + name: "equal - same additions", + s1: SyncServicesReturnType{ + Additions: sets.NewString("svc1", "svc2"), + Removals: sets.NewString(), + }, + s2: SyncServicesReturnType{ + Additions: sets.NewString("svc1", "svc2"), + Removals: sets.NewString(), + }, + expected: true, + }, + { + name: "not equal - different additions", + s1: SyncServicesReturnType{ + Additions: sets.NewString("svc1"), + Removals: sets.NewString(), + }, + s2: SyncServicesReturnType{ + Additions: sets.NewString("svc2"), + Removals: sets.NewString(), + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.s1.Equals(&tt.s2)) + }) + } +} + +// TestConfigValidate tests Config.Validate() +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config Config + shouldError bool + errorMsg string + }{ + { + name: "valid config", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + Location: "eastus", + VNetName: "test-vnet", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/subscriptions/sub1/resourceGroups/rg1/providers/Microsoft.Network/serviceGateways/sgw", + }, + shouldError: false, + }, + { + name: "missing subscription ID", + config: Config{ + ResourceGroup: "rg1", + Location: "eastus", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "SubscriptionID is required", + }, + { + name: "missing resource group", + config: Config{ + SubscriptionID: "sub1", + Location: "eastus", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "ResourceGroup is required", + }, + { + name: "missing location", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "Location is required", + }, + { + name: "missing ServiceGatewayResourceName", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + Location: "eastus", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "ServiceGatewayResourceName is required", + }, + { + name: "missing ServiceGatewayID", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + Location: "eastus", + ServiceGatewayResourceName: "sgw", + }, + shouldError: true, + errorMsg: "ServiceGatewayID is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.shouldError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestJSONRoundTrip tests JSON marshaling/unmarshaling for various types +func TestJSONRoundTrip(t *testing.T) { + t.Run("UpdateAction round trip", func(t *testing.T) { + original := PartialUpdate + data, err := json.Marshal(original) + assert.NoError(t, err) + + var unmarshaled UpdateAction + err = json.Unmarshal(data, &unmarshaled) + assert.NoError(t, err) + assert.Equal(t, original, unmarshaled) + }) + + t.Run("Operation round trip", func(t *testing.T) { + original := ADD + data, err := json.Marshal(original) + assert.NoError(t, err) + assert.Equal(t, `"ADD"`, string(data)) + }) + + t.Run("SyncStatus round trip", func(t *testing.T) { + original := SUCCESS + data, err := json.Marshal(original) + assert.NoError(t, err) + assert.Equal(t, `"SUCCESS"`, string(data)) + }) +} diff --git a/pkg/util/sets/string.go b/pkg/util/sets/string.go index 2562fcf6ea..8b1ee4495f 100644 --- a/pkg/util/sets/string.go +++ b/pkg/util/sets/string.go @@ -17,6 +17,7 @@ limitations under the License. package sets import ( + "encoding/json" "strings" "k8s.io/apimachinery/pkg/util/sets" @@ -98,3 +99,35 @@ func (s *IgnoreCaseSet) Len() int { } return s.set.Len() } + +// MarshalJSON marshals the set to JSON as an array of strings. +func (s *IgnoreCaseSet) MarshalJSON() ([]byte, error) { + if s == nil { + return []byte("null"), nil + } + if s.Len() == 0 { + return []byte("[]"), nil + } + return json.Marshal(s.UnsortedList()) +} + +// Equals returns true if the two sets are equal. +func (s1 *IgnoreCaseSet) Equals(s2 *IgnoreCaseSet) bool { + // Early exit if sizes are different + if len(s1.UnsortedList()) != len(s2.UnsortedList()) { + return false + } + // Check if all items in s1 are in s2 + for _, item := range s1.UnsortedList() { + if !s2.Has(item) { + return false + } + } + // Check if all items in s2 are in s1 + for _, item := range s2.UnsortedList() { + if !s1.Has(item) { + return false + } + } + return true +} diff --git a/pkg/util/sets/string_test.go b/pkg/util/sets/string_test.go index b0286b7d3d..968cdb3dc7 100644 --- a/pkg/util/sets/string_test.go +++ b/pkg/util/sets/string_test.go @@ -367,3 +367,81 @@ func TestLen(t *testing.T) { }) } } +func TestEquals(t *testing.T) { + tests := []struct { + name string + s1 *IgnoreCaseSet + s2 *IgnoreCaseSet + want bool + }{ + { + name: "both nil", + s1: nil, + s2: nil, + want: true, + }, + { + name: "first nil", + s1: nil, + s2: NewString("foo"), + want: false, + }, + { + name: "second nil", + s1: NewString("foo"), + s2: nil, + want: false, + }, + { + name: "empty sets", + s1: NewString(), + s2: NewString(), + want: true, + }, + { + name: "same elements", + s1: NewString("foo", "bar"), + s2: NewString("foo", "bar"), + want: true, + }, + { + name: "same elements with different case", + s1: NewString("foo", "bar"), + s2: NewString("FOO", "BAR"), + want: true, + }, + { + name: "different sizes", + s1: NewString("foo", "bar"), + s2: NewString("foo"), + want: false, + }, + { + name: "same size but different elements", + s1: NewString("foo", "bar"), + s2: NewString("foo", "baz"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.s1 == nil && tt.s2 == nil { + // Special case for nil sets + if !tt.want { + t.Errorf("Equals() = true, want %v", tt.want) + } + return + } + if tt.s1 == nil || tt.s2 == nil { + // One set is nil, they can't be equal + if tt.want { + t.Errorf("Equals() = false, want %v", tt.want) + } + return + } + if got := tt.s1.Equals(tt.s2); got != tt.want { + t.Errorf("Equals() = %v, want %v", got, tt.want) + } + }) + } +} From 37e919a481f53da0c6e89ccb69588fdd206aaa20 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Wed, 20 May 2026 07:27:38 +0000 Subject: [PATCH 02/18] addressed comments --- pkg/provider/difftracker/config.go | 16 +++ pkg/provider/difftracker/difftracker.go | 31 +++-- pkg/provider/difftracker/difftracker_test.go | 107 +++++++++++------- pkg/provider/difftracker/k8s_state_updates.go | 36 +++++- pkg/provider/difftracker/nrp_state_updates.go | 17 +++ pkg/provider/difftracker/sync_operations.go | 17 +++ pkg/provider/difftracker/types.go | 25 +++- pkg/provider/difftracker/util.go | 16 +++ pkg/provider/difftracker/util_test.go | 56 +++++---- 9 files changed, 240 insertions(+), 81 deletions(-) diff --git a/pkg/provider/difftracker/config.go b/pkg/provider/difftracker/config.go index 198df8fe98..3ce5d7c8c1 100644 --- a/pkg/provider/difftracker/config.go +++ b/pkg/provider/difftracker/config.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import "fmt" diff --git a/pkg/provider/difftracker/difftracker.go b/pkg/provider/difftracker/difftracker.go index c8d850bf3a..19fb3c1794 100644 --- a/pkg/provider/difftracker/difftracker.go +++ b/pkg/provider/difftracker/difftracker.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( @@ -5,25 +21,24 @@ import ( "k8s.io/client-go/kubernetes" "k8s.io/klog/v2" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) // InitializeDiffTracker creates and initializes a new DiffTracker with the given state and configuration. // It validates the configuration and ensures all required dependencies are present. -// Panics if critical dependencies (config, networkClientFactory, kubeClient) are invalid. -func InitializeDiffTracker(K8s K8s_State, NRP NRP_State, config Config, networkClientFactory azclient.ClientFactory, kubeClient kubernetes.Interface) *DiffTracker { - // Validate configuration +// Returns an error if the configuration is invalid or if any required dependency is nil. +func InitializeDiffTracker(K8s K8sState, NRP NRPState, config Config, networkClientFactory azclient.ClientFactory, kubeClient kubernetes.Interface) (*DiffTracker, error) { if err := config.Validate(); err != nil { - panic(fmt.Sprintf("InitializeDiffTracker: %v", err)) + return nil, fmt.Errorf("InitializeDiffTracker: %w", err) } - // Validate required dependencies if networkClientFactory == nil { - panic("InitializeDiffTracker: networkClientFactory must not be nil") + return nil, fmt.Errorf("InitializeDiffTracker: networkClientFactory must not be nil") } if kubeClient == nil { - panic("InitializeDiffTracker: kubeClient must not be nil") + return nil, fmt.Errorf("InitializeDiffTracker: kubeClient must not be nil") } klog.V(2).Infof("InitializeDiffTracker: initializing with config: subscription=%s, resourceGroup=%s, location=%s", @@ -59,5 +74,5 @@ func InitializeDiffTracker(K8s K8s_State, NRP NRP_State, config Config, networkC kubeClient: kubeClient, } - return diffTracker + return diffTracker, nil } diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index 11dacb8ed7..d4e0366e16 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( @@ -5,7 +21,9 @@ import ( "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" + "k8s.io/client-go/kubernetes/fake" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -19,12 +37,12 @@ func TestDiffTracker_DeepEqual(t *testing.T) { { name: "equal empty states", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString(), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString(), NATGateways: sets.NewString(), Locations: map[string]NRPLocation{}, @@ -35,12 +53,12 @@ func TestDiffTracker_DeepEqual(t *testing.T) { { name: "equal states with services", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("service1", "service2"), Egresses: sets.NewString(), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1", "service2"), NATGateways: sets.NewString(), Locations: map[string]NRPLocation{}, @@ -51,12 +69,12 @@ func TestDiffTracker_DeepEqual(t *testing.T) { { name: "services not equal", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("service1", "service2"), Egresses: sets.NewString(), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1"), NATGateways: sets.NewString(), Locations: map[string]NRPLocation{}, @@ -74,15 +92,15 @@ func TestDiffTracker_DeepEqual(t *testing.T) { } } -func TestUpdateK8sService(t *testing.T) { +func TestEnqueueK8sServiceOperation(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), }, } // Test ADD operation - err := dt.UpdateK8sService(UpdateK8sResource{ + err := dt.EnqueueK8sServiceOperation(UpdateK8sResource{ Operation: ADD, ID: "service1", }) @@ -90,7 +108,7 @@ func TestUpdateK8sService(t *testing.T) { assert.True(t, dt.K8sResources.Services.Has("service1")) // Test REMOVE operation - err = dt.UpdateK8sService(UpdateK8sResource{ + err = dt.EnqueueK8sServiceOperation(UpdateK8sResource{ Operation: REMOVE, ID: "service1", }) @@ -98,7 +116,7 @@ func TestUpdateK8sService(t *testing.T) { assert.False(t, dt.K8sResources.Services.Has("service1")) // Test invalid operation - err = dt.UpdateK8sService(UpdateK8sResource{ + err = dt.EnqueueK8sServiceOperation(UpdateK8sResource{ Operation: UPDATE, ID: "service1", }) @@ -107,10 +125,10 @@ func TestUpdateK8sService(t *testing.T) { func TestGetSyncLoadBalancerServices(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("service1", "service2", "service3"), }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service2", "service3", "service4"), }, } @@ -126,7 +144,7 @@ func TestGetSyncLoadBalancerServices(t *testing.T) { func TestUpdateK8sEndpoints(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Nodes: map[string]Node{}, }, } @@ -158,7 +176,7 @@ func TestUpdateK8sEndpoints(t *testing.T) { func TestUpdateK8sPod(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Nodes: map[string]Node{}, }, } @@ -191,7 +209,7 @@ func TestUpdateK8sPod(t *testing.T) { func TestGetSyncLocationsAddresses(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Nodes: map[string]Node{ "node1": { Pods: map[string]Pod{ @@ -203,7 +221,7 @@ func TestGetSyncLocationsAddresses(t *testing.T) { }, }, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1"), NATGateways: sets.NewString("public1"), Locations: map[string]NRPLocation{}, @@ -246,10 +264,10 @@ func TestUpdateNRPLoadBalancers(t *testing.T) { { name: "add services from K8s to NRP", initialState: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("service1", "service2", "service3"), }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1"), }, }, @@ -258,10 +276,10 @@ func TestUpdateNRPLoadBalancers(t *testing.T) { { name: "no changes needed when K8s and NRP are in sync", initialState: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("service1", "service2"), }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1", "service2"), }, }, @@ -282,22 +300,22 @@ func TestUpdateNRPLoadBalancers(t *testing.T) { } } -func TestUpdateK8sEgress(t *testing.T) { +func TestEnqueueK8sEgressOperation(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Egresses: sets.NewString(), }, } - err := dt.UpdateK8sEgress(UpdateK8sResource{Operation: ADD, ID: "egress1"}) + err := dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: ADD, ID: "egress1"}) assert.NoError(t, err) assert.True(t, dt.K8sResources.Egresses.Has("egress1")) - err = dt.UpdateK8sEgress(UpdateK8sResource{Operation: REMOVE, ID: "egress1"}) + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: REMOVE, ID: "egress1"}) assert.NoError(t, err) assert.False(t, dt.K8sResources.Egresses.Has("egress1")) - err = dt.UpdateK8sEgress(UpdateK8sResource{Operation: UPDATE, ID: "egress1"}) + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: UPDATE, ID: "egress1"}) assert.Error(t, err) assert.Contains(t, err.Error(), "error - ResourceType=Egress, Operation=UPDATE and ID=egress1") } @@ -329,10 +347,10 @@ func TestGetSyncNRPNATGateways(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Egresses: sets.NewString(tt.k8sEgresses...), }, - NRPResources: NRP_State{ + NRPResources: NRPState{ NATGateways: sets.NewString(tt.nrpNATGateways...), }, } @@ -354,10 +372,10 @@ func TestGetSyncNRPNATGateways(t *testing.T) { func TestUpdateNRPNATGateways(t *testing.T) { dt := &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Egresses: sets.NewString("egress1", "egress2", "egress4"), }, - NRPResources: NRP_State{ + NRPResources: NRPState{ NATGateways: sets.NewString("egress1", "egress3", "egress5"), }, } @@ -381,15 +399,15 @@ func TestUpdateLocationsAddresses(t *testing.T) { { name: "sync empty states", initialState: &DiffTracker{ - K8sResources: K8s_State{Nodes: map[string]Node{}}, - NRPResources: NRP_State{Locations: map[string]NRPLocation{}}, + K8sResources: K8sState{Nodes: map[string]Node{}}, + NRPResources: NRPState{Locations: map[string]NRPLocation{}}, }, expectedNRP: map[string]map[string][]string{}, }, { name: "add new location and address", initialState: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Nodes: map[string]Node{ "node1": { Pods: map[string]Pod{ @@ -401,7 +419,7 @@ func TestUpdateLocationsAddresses(t *testing.T) { }, }, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1"), NATGateways: sets.NewString("public1"), Locations: map[string]NRPLocation{}, @@ -414,7 +432,7 @@ func TestUpdateLocationsAddresses(t *testing.T) { { name: "complex case with multiple operations", initialState: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Nodes: map[string]Node{ "node1": { Pods: map[string]Pod{ @@ -431,7 +449,7 @@ func TestUpdateLocationsAddresses(t *testing.T) { }, }, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1", "service2", "service3", "service4", "service5"), NATGateways: sets.NewString("public1"), Locations: map[string]NRPLocation{ @@ -490,7 +508,7 @@ func TestGetSyncOperations(t *testing.T) { { name: "states already in sync", initialState: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("service1"), Egresses: sets.NewString("egress1"), Nodes: map[string]Node{ @@ -499,7 +517,7 @@ func TestGetSyncOperations(t *testing.T) { }}, }, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1"), NATGateways: sets.NewString("egress1"), Locations: map[string]NRPLocation{ @@ -514,7 +532,7 @@ func TestGetSyncOperations(t *testing.T) { { name: "services out of sync", initialState: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("service1", "service2"), Egresses: sets.NewString("egress1"), Nodes: map[string]Node{ @@ -523,7 +541,7 @@ func TestGetSyncOperations(t *testing.T) { }}, }, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("service1"), NATGateways: sets.NewString("egress1"), Locations: map[string]NRPLocation{ @@ -549,7 +567,7 @@ func TestGetSyncOperations(t *testing.T) { // This test verifies if the DiffTracker is able to sync K8s Cluster and NRP correctly // when there is a huge discrepancy between K8s Cluster and NRP. func TestInitializeDiffTracker(t *testing.T) { - K8sResources := K8s_State{ + K8sResources := K8sState{ Services: sets.NewString("Service0", "Service1", "Service2"), Egresses: sets.NewString("Egress0", "Egress1", "Egress2"), Nodes: map[string]Node{ @@ -569,7 +587,7 @@ func TestInitializeDiffTracker(t *testing.T) { }, } - NRPResources := NRP_State{ + NRPResources := NRPState{ LoadBalancers: sets.NewString("Service0", "Service6", "Service5"), NATGateways: sets.NewString("Egress0", "Egress6", "Egress5"), Locations: map[string]NRPLocation{ @@ -602,7 +620,8 @@ func TestInitializeDiffTracker(t *testing.T) { defer ctrl.Finish() mockFactory := mock_azclient.NewMockClientFactory(ctrl) mockKubeClient := fake.NewSimpleClientset() - diffTracker := InitializeDiffTracker(K8sResources, NRPResources, config, mockFactory, mockKubeClient) + diffTracker, err := InitializeDiffTracker(K8sResources, NRPResources, config, mockFactory, mockKubeClient) + assert.NoError(t, err) syncOperations := diffTracker.GetSyncOperations() diffTracker.UpdateNRPLoadBalancers(syncOperations.LoadBalancerUpdates) @@ -613,7 +632,7 @@ func TestInitializeDiffTracker(t *testing.T) { expectedDiffTracker := &DiffTracker{ K8sResources: K8sResources, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("Service0", "Service1", "Service2"), NATGateways: sets.NewString("Egress0", "Egress1", "Egress2"), Locations: map[string]NRPLocation{ diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go index 4a421e58c4..34f6227f7c 100644 --- a/pkg/provider/difftracker/k8s_state_updates.go +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( @@ -5,6 +21,7 @@ import ( "strings" "k8s.io/klog/v2" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -13,7 +30,10 @@ const ( ResourceTypeEgress = "Egress" ) -func updateK8Resource(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, resourceType string) error { +// enqueueK8sResourceOperation applies the requested operation (ADD/REMOVE) to the +// in-memory K8s resource set. It does not perform any Azure update calls; it only +// mutates the local desired-state model that will later be reconciled with NRP. +func enqueueK8sResourceOperation(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, resourceType string) error { if input.ID == "" { return fmt.Errorf("%s: empty ID not allowed", resourceType) } @@ -29,18 +49,24 @@ func updateK8Resource(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, reso return nil } -func (dt *DiffTracker) UpdateK8sService(input UpdateK8sResource) error { +// EnqueueK8sServiceOperation records a service ADD/REMOVE in the local K8s state set. +// The change is reconciled with NRP later by the sync operations; this method itself +// performs no Azure calls. +func (dt *DiffTracker) EnqueueK8sServiceOperation(input UpdateK8sResource) error { dt.mu.Lock() defer dt.mu.Unlock() - return updateK8Resource(input, dt.K8sResources.Services, ResourceTypeService) + return enqueueK8sResourceOperation(input, dt.K8sResources.Services, ResourceTypeService) } -func (dt *DiffTracker) UpdateK8sEgress(input UpdateK8sResource) error { +// EnqueueK8sEgressOperation records an egress ADD/REMOVE in the local K8s state set. +// The change is reconciled with NRP later by the sync operations; this method itself +// performs no Azure calls. +func (dt *DiffTracker) EnqueueK8sEgressOperation(input UpdateK8sResource) error { dt.mu.Lock() defer dt.mu.Unlock() - return updateK8Resource(input, dt.K8sResources.Egresses, ResourceTypeEgress) + return enqueueK8sResourceOperation(input, dt.K8sResources.Egresses, ResourceTypeEgress) } // updateK8sEndpointsLocked updates K8s endpoints state. Assumes lock is already held. diff --git a/pkg/provider/difftracker/nrp_state_updates.go b/pkg/provider/difftracker/nrp_state_updates.go index 91f05841b5..5b59e69d7b 100644 --- a/pkg/provider/difftracker/nrp_state_updates.go +++ b/pkg/provider/difftracker/nrp_state_updates.go @@ -1,7 +1,24 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( "k8s.io/klog/v2" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go index 6b18b4d4d4..8d2050dd44 100644 --- a/pkg/provider/difftracker/sync_operations.go +++ b/pkg/provider/difftracker/sync_operations.go @@ -1,7 +1,24 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( "k8s.io/klog/v2" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go index 66c7f14e75..9cab51fee6 100644 --- a/pkg/provider/difftracker/types.go +++ b/pkg/provider/difftracker/types.go @@ -1,9 +1,26 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( "sync" "k8s.io/client-go/kubernetes" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -44,7 +61,7 @@ type NRPLocation struct { Addresses map[string]NRPAddress } -type NRP_State struct { +type NRPState struct { LoadBalancers *utilsets.IgnoreCaseSet NATGateways *utilsets.IgnoreCaseSet Locations map[string]NRPLocation @@ -59,7 +76,7 @@ type Node struct { Pods map[string]Pod } -type K8s_State struct { +type K8sState struct { Services *utilsets.IgnoreCaseSet Egresses *utilsets.IgnoreCaseSet Nodes map[string]Node @@ -69,8 +86,8 @@ type K8s_State struct { type DiffTracker struct { mu sync.Mutex // Protects concurrent access to DiffTracker - K8sResources K8s_State - NRPResources NRP_State + K8sResources K8sState + NRPResources NRPState LocalServiceNameToNRPServiceMap sync.Map diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go index eb853bfaa4..f91a82df0c 100644 --- a/pkg/provider/difftracker/util.go +++ b/pkg/provider/difftracker/util.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( diff --git a/pkg/provider/difftracker/util_test.go b/pkg/provider/difftracker/util_test.go index 9ffce2621f..7b7b81136b 100644 --- a/pkg/provider/difftracker/util_test.go +++ b/pkg/provider/difftracker/util_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( @@ -179,12 +195,12 @@ func TestDeepEqual(t *testing.T) { { name: "in sync - matching services and load balancers", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("svc1", "svc2"), Egresses: sets.NewString(), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("svc1", "svc2"), NATGateways: sets.NewString(), Locations: map[string]NRPLocation{}, @@ -195,12 +211,12 @@ func TestDeepEqual(t *testing.T) { { name: "not in sync - service count mismatch", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("svc1"), Egresses: sets.NewString(), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("svc1", "svc2"), NATGateways: sets.NewString(), Locations: map[string]NRPLocation{}, @@ -211,12 +227,12 @@ func TestDeepEqual(t *testing.T) { { name: "not in sync - service name mismatch", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("svc1"), Egresses: sets.NewString(), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString("svc2"), NATGateways: sets.NewString(), Locations: map[string]NRPLocation{}, @@ -227,12 +243,12 @@ func TestDeepEqual(t *testing.T) { { name: "in sync - matching egresses and NAT gateways", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString("egress1", "egress2"), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString(), NATGateways: sets.NewString("egress1", "egress2"), Locations: map[string]NRPLocation{}, @@ -243,12 +259,12 @@ func TestDeepEqual(t *testing.T) { { name: "not in sync - egress count mismatch", dt: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString("egress1"), Nodes: map[string]Node{}, }, - NRPResources: NRP_State{ + NRPResources: NRPState{ LoadBalancers: sets.NewString(), NATGateways: sets.NewString("egress1", "egress2"), Locations: map[string]NRPLocation{}, @@ -376,14 +392,14 @@ func TestDiffTrackerEquals(t *testing.T) { { name: "equal diff trackers", dt1: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("svc1"), Egresses: sets.NewString("egress1"), Nodes: map[string]Node{}, }, }, dt2: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("svc1"), Egresses: sets.NewString("egress1"), Nodes: map[string]Node{}, @@ -394,14 +410,14 @@ func TestDiffTrackerEquals(t *testing.T) { { name: "different services", dt1: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("svc1"), Egresses: sets.NewString(), Nodes: map[string]Node{}, }, }, dt2: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString("svc2"), Egresses: sets.NewString(), Nodes: map[string]Node{}, @@ -412,14 +428,14 @@ func TestDiffTrackerEquals(t *testing.T) { { name: "different egresses", dt1: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString("egress1"), Nodes: map[string]Node{}, }, }, dt2: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString("egress2"), Nodes: map[string]Node{}, @@ -430,14 +446,14 @@ func TestDiffTrackerEquals(t *testing.T) { { name: "different node count", dt1: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString(), Nodes: map[string]Node{"node1": {}}, }, }, dt2: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString(), Nodes: map[string]Node{}, @@ -448,7 +464,7 @@ func TestDiffTrackerEquals(t *testing.T) { { name: "different pod count in node", dt1: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString(), Nodes: map[string]Node{ @@ -457,7 +473,7 @@ func TestDiffTrackerEquals(t *testing.T) { }, }, dt2: &DiffTracker{ - K8sResources: K8s_State{ + K8sResources: K8sState{ Services: sets.NewString(), Egresses: sets.NewString(), Nodes: map[string]Node{ From cfaaaa75143f47399af3f641a8804894079eceef Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Wed, 20 May 2026 12:59:44 +0000 Subject: [PATCH 03/18] address review comments: scope helper to DiffTracker; clarify NRP type docs --- pkg/provider/difftracker/k8s_state_updates.go | 6 +++--- pkg/provider/difftracker/types.go | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go index 34f6227f7c..2c8a20b7e7 100644 --- a/pkg/provider/difftracker/k8s_state_updates.go +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -33,7 +33,7 @@ const ( // enqueueK8sResourceOperation applies the requested operation (ADD/REMOVE) to the // in-memory K8s resource set. It does not perform any Azure update calls; it only // mutates the local desired-state model that will later be reconciled with NRP. -func enqueueK8sResourceOperation(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, resourceType string) error { +func (dt *DiffTracker) enqueueK8sResourceOperation(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, resourceType string) error { if input.ID == "" { return fmt.Errorf("%s: empty ID not allowed", resourceType) } @@ -56,7 +56,7 @@ func (dt *DiffTracker) EnqueueK8sServiceOperation(input UpdateK8sResource) error dt.mu.Lock() defer dt.mu.Unlock() - return enqueueK8sResourceOperation(input, dt.K8sResources.Services, ResourceTypeService) + return dt.enqueueK8sResourceOperation(input, dt.K8sResources.Services, ResourceTypeService) } // EnqueueK8sEgressOperation records an egress ADD/REMOVE in the local K8s state set. @@ -66,7 +66,7 @@ func (dt *DiffTracker) EnqueueK8sEgressOperation(input UpdateK8sResource) error dt.mu.Lock() defer dt.mu.Unlock() - return enqueueK8sResourceOperation(input, dt.K8sResources.Egresses, ResourceTypeEgress) + return dt.enqueueK8sResourceOperation(input, dt.K8sResources.Egresses, ResourceTypeEgress) } // updateK8sEndpointsLocked updates K8s endpoints state. Assumes lock is already held. diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go index 9cab51fee6..43a8865cb7 100644 --- a/pkg/provider/difftracker/types.go +++ b/pkg/provider/difftracker/types.go @@ -53,18 +53,28 @@ const ( // -------------------------------------------------------------------------------- // DiffTracker keeps track of the state of the K8s cluster and NRP // -------------------------------------------------------------------------------- +// NRPAddress holds the NRP-side state for a single pod address (pod IP). type NRPAddress struct { - Services *utilsets.IgnoreCaseSet // all inbound and outbound identities + // Services holds the SGW service identities (LBs for inbound, NATGWs for + // outbound) currently associated with this address on the NRP side. + // These are SGW service identities, not Kubernetes Service names. + Services *utilsets.IgnoreCaseSet } +// NRPLocation holds the NRP-side state for a single node/VM and groups the +// pod addresses running on it. type NRPLocation struct { + // Addresses is keyed by pod IP. Each pod IP is added to the ServiceGateway + // as an address under this location once the pod is created. Addresses map[string]NRPAddress } type NRPState struct { LoadBalancers *utilsets.IgnoreCaseSet NATGateways *utilsets.IgnoreCaseSet - Locations map[string]NRPLocation + // Locations is keyed by node/VM IP (e.g. "10.0.0.1"). "Location" here is + // an SGW concept identifying a node, not an Azure region (e.g. "eastus2"). + Locations map[string]NRPLocation } type Pod struct { From 03bbd36ed92150b6dd5498cb24792a57631db5c9 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Tue, 9 Jun 2026 07:40:58 +0000 Subject: [PATCH 04/18] refactor: standardize SyncStatus constants and improve test assertions --- pkg/provider/difftracker/difftracker_test.go | 6 +++--- pkg/provider/difftracker/k8s_state_updates.go | 2 ++ pkg/provider/difftracker/sync_operations.go | 10 +++++----- pkg/provider/difftracker/types.go | 4 ++-- pkg/provider/difftracker/util.go | 2 +- pkg/provider/difftracker/util_test.go | 9 +++++---- pkg/util/sets/string.go | 16 ++++++++-------- 7 files changed, 26 insertions(+), 23 deletions(-) diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index d4e0366e16..e78b5c3200 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -527,7 +527,7 @@ func TestGetSyncOperations(t *testing.T) { }, }, }, - expectedSyncStatus: ALREADY_IN_SYNC, + expectedSyncStatus: AlreadyInSync, }, { name: "services out of sync", @@ -551,7 +551,7 @@ func TestGetSyncOperations(t *testing.T) { }, }, }, - expectedSyncStatus: SUCCESS, + expectedSyncStatus: Success, }, } @@ -628,7 +628,7 @@ func TestInitializeDiffTracker(t *testing.T) { diffTracker.UpdateNRPNATGateways(syncOperations.NATGatewayUpdates) diffTracker.UpdateLocationsAddresses(syncOperations.LocationData) - assert.Equal(t, SUCCESS, syncOperations.SyncStatus) + assert.Equal(t, Success, syncOperations.SyncStatus) expectedDiffTracker := &DiffTracker{ K8sResources: K8sResources, diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go index 2c8a20b7e7..dc213ab4b7 100644 --- a/pkg/provider/difftracker/k8s_state_updates.go +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -250,6 +250,8 @@ func (dt *DiffTracker) UpdateK8sPod(input UpdatePodInputType) error { // This is used during service deletion to proactively clear location/address references // so the LocationsUpdater can sync the removal to NRP. // Assumes lock is already held. +// +//nolint:unused // wired up by the engine in a follow-up PR func (dt *DiffTracker) removeServiceFromK8sStateLocked(serviceUID string, isInbound bool) { for nodeIP, node := range dt.K8sResources.Nodes { for podIP, pod := range node.Pods { diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go index 8d2050dd44..bb4baa4033 100644 --- a/pkg/provider/difftracker/sync_operations.go +++ b/pkg/provider/difftracker/sync_operations.go @@ -76,8 +76,8 @@ func (dt *DiffTracker) GetSyncLocationsAddresses() LocationData { } // Iterate over all nodes in the K8s state - for nodeIp, node := range dt.K8sResources.Nodes { - nrpLocation, locationExists := dt.NRPResources.Locations[nodeIp] + for nodeIP, node := range dt.K8sResources.Nodes { + nrpLocation, locationExists := dt.NRPResources.Locations[nodeIP] location := initializeLocation(locationExists) locationUpdated := false @@ -105,7 +105,7 @@ func (dt *DiffTracker) GetSyncLocationsAddresses() LocationData { locationUpdated = true } if locationUpdated { - result.Locations[nodeIp] = location + result.Locations[nodeIP] = location } } @@ -196,11 +196,11 @@ func findLocationData(result LocationData, location string) *Location { func (dt *DiffTracker) GetSyncOperations() *SyncDiffTrackerReturnType { if dt.DeepEqual() { - return &SyncDiffTrackerReturnType{SyncStatus: ALREADY_IN_SYNC} + return &SyncDiffTrackerReturnType{SyncStatus: AlreadyInSync} } return &SyncDiffTrackerReturnType{ - SyncStatus: SUCCESS, + SyncStatus: Success, LoadBalancerUpdates: dt.GetSyncLoadBalancerServices(), NATGatewayUpdates: dt.GetSyncNRPNATGateways(), LocationData: dt.GetSyncLocationsAddresses(), diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go index 43a8865cb7..d39a4119df 100644 --- a/pkg/provider/difftracker/types.go +++ b/pkg/provider/difftracker/types.go @@ -46,8 +46,8 @@ const ( type SyncStatus int const ( - ALREADY_IN_SYNC SyncStatus = iota - SUCCESS + AlreadyInSync SyncStatus = iota + Success ) // -------------------------------------------------------------------------------- diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go index f91a82df0c..d2d8dabaf3 100644 --- a/pkg/provider/difftracker/util.go +++ b/pkg/provider/difftracker/util.go @@ -58,7 +58,7 @@ func (updateAction *UpdateAction) UnmarshalJSON(data []byte) error { } func (syncStatus SyncStatus) String() string { - return [...]string{"ALREADY_IN_SYNC", "SUCCESS"}[syncStatus] + return [...]string{"AlreadyInSync", "Success"}[syncStatus] } func (syncStatus SyncStatus) MarshalJSON() ([]byte, error) { diff --git a/pkg/provider/difftracker/util_test.go b/pkg/provider/difftracker/util_test.go index 7b7b81136b..130d3cf316 100644 --- a/pkg/provider/difftracker/util_test.go +++ b/pkg/provider/difftracker/util_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -101,8 +102,8 @@ func TestSyncStatusStringAndJSON(t *testing.T) { status SyncStatus expected string }{ - {"ALREADY_IN_SYNC", ALREADY_IN_SYNC, "ALREADY_IN_SYNC"}, - {"SUCCESS", SUCCESS, "SUCCESS"}, + {"AlreadyInSync", AlreadyInSync, "AlreadyInSync"}, + {"Success", Success, "Success"}, } for _, tt := range tests { @@ -656,9 +657,9 @@ func TestJSONRoundTrip(t *testing.T) { }) t.Run("SyncStatus round trip", func(t *testing.T) { - original := SUCCESS + original := Success data, err := json.Marshal(original) assert.NoError(t, err) - assert.Equal(t, `"SUCCESS"`, string(data)) + assert.Equal(t, `"Success"`, string(data)) }) } diff --git a/pkg/util/sets/string.go b/pkg/util/sets/string.go index 8b1ee4495f..7c22980a83 100644 --- a/pkg/util/sets/string.go +++ b/pkg/util/sets/string.go @@ -112,20 +112,20 @@ func (s *IgnoreCaseSet) MarshalJSON() ([]byte, error) { } // Equals returns true if the two sets are equal. -func (s1 *IgnoreCaseSet) Equals(s2 *IgnoreCaseSet) bool { +func (s *IgnoreCaseSet) Equals(other *IgnoreCaseSet) bool { // Early exit if sizes are different - if len(s1.UnsortedList()) != len(s2.UnsortedList()) { + if len(s.UnsortedList()) != len(other.UnsortedList()) { return false } - // Check if all items in s1 are in s2 - for _, item := range s1.UnsortedList() { - if !s2.Has(item) { + // Check if all items in s are in other + for _, item := range s.UnsortedList() { + if !other.Has(item) { return false } } - // Check if all items in s2 are in s1 - for _, item := range s2.UnsortedList() { - if !s1.Has(item) { + // Check if all items in other are in s + for _, item := range other.UnsortedList() { + if !s.Has(item) { return false } } From 522c44aca026e68d38d53acdd28b4e2357bcaf81 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Mon, 15 Jun 2026 13:17:27 +0000 Subject: [PATCH 05/18] addressed comments: - Added unit test `TestUpdateK8sPodRemoveUsesStoredIdentity` to verify correct decrementing of outbound identity reference counts when removing pods. - Updated `removePod` method to return the stored PublicOutboundIdentity of the removed pod for accurate reference counting. - Refactored `updateK8sPodLocked` to decrement the reference count using the authoritative identity from the pod rather than the input. - Improved logging verbosity in `GetServicesToSync` and `DeepEqual` methods for better traceability during debugging. - Introduced lock-free methods for synchronization operations to enhance performance and reduce contention. - Enhanced enum string representations to handle out-of-range values gracefully. - Simplified `Equals` method in `IgnoreCaseSet` to improve efficiency by removing redundant checks. --- pkg/provider/difftracker/config.go | 3 +- pkg/provider/difftracker/coverage_test.go | 539 ++++++++++++++++++ pkg/provider/difftracker/difftracker_test.go | 33 ++ pkg/provider/difftracker/k8s_state_updates.go | 62 +- pkg/provider/difftracker/sync_operations.go | 52 +- pkg/provider/difftracker/types.go | 21 +- pkg/provider/difftracker/util.go | 54 +- pkg/provider/difftracker/util_test.go | 9 + pkg/util/sets/string.go | 12 +- 9 files changed, 718 insertions(+), 67 deletions(-) create mode 100644 pkg/provider/difftracker/coverage_test.go diff --git a/pkg/provider/difftracker/config.go b/pkg/provider/difftracker/config.go index 3ce5d7c8c1..b277a4cdbc 100644 --- a/pkg/provider/difftracker/config.go +++ b/pkg/provider/difftracker/config.go @@ -28,7 +28,8 @@ type Config struct { // Azure resource group name ResourceGroup string - // Azure location/region + // Azure location/region (e.g. "eastus2"). Distinct from the difftracker + // Location type, which identifies a node by IP. Location string // Service Gateway resource name diff --git a/pkg/provider/difftracker/coverage_test.go b/pkg/provider/difftracker/coverage_test.go new file mode 100644 index 0000000000..bfc5c07206 --- /dev/null +++ b/pkg/provider/difftracker/coverage_test.go @@ -0,0 +1,539 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "k8s.io/client-go/kubernetes/fake" + + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +func validTestConfig() Config { + return Config{ + SubscriptionID: "test-subscription", + ResourceGroup: "test-rg", + Location: "eastus", + VNetName: "test-vnet", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } +} + +// TestInitializeDiffTrackerErrorPaths covers the validation/error branches of +// InitializeDiffTracker and the nil-field initialization. +func TestInitializeDiffTrackerErrorPaths(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockKubeClient := fake.NewSimpleClientset() + + // Invalid config (empty) -> validation error. + _, err := InitializeDiffTracker(K8sState{}, NRPState{}, Config{}, mockFactory, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "InitializeDiffTracker") + + // Nil networkClientFactory. + _, err = InitializeDiffTracker(K8sState{}, NRPState{}, validTestConfig(), nil, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "networkClientFactory must not be nil") + + // Nil kubeClient. + _, err = InitializeDiffTracker(K8sState{}, NRPState{}, validTestConfig(), mockFactory, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "kubeClient must not be nil") + + // Valid call with empty states -> nil fields get initialized. + dt, err := InitializeDiffTracker(K8sState{}, NRPState{}, validTestConfig(), mockFactory, mockKubeClient) + assert.NoError(t, err) + assert.NotNil(t, dt) + assert.NotNil(t, dt.K8sResources.Services) + assert.NotNil(t, dt.K8sResources.Egresses) + assert.NotNil(t, dt.K8sResources.Nodes) + assert.NotNil(t, dt.NRPResources.LoadBalancers) + assert.NotNil(t, dt.NRPResources.NATGateways) + assert.NotNil(t, dt.NRPResources.Locations) +} + +// TestEnqueueK8sResourceOperationErrors covers the empty-ID and invalid-operation +// branches of enqueueK8sResourceOperation (via the public wrappers). +func TestEnqueueK8sResourceOperationErrors(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{Services: sets.NewString(), Egresses: sets.NewString()}, + } + + // Empty ID. + err := dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: ADD, ID: ""}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty ID") + + // Invalid operation (UPDATE is not valid for the resource set). + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: UPDATE, ID: "egress1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Operation=UPDATE") + + // Successful ADD then REMOVE. + assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: ADD, ID: "svc1"})) + assert.True(t, dt.K8sResources.Services.Has("svc1")) + assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: REMOVE, ID: "svc1"})) + assert.False(t, dt.K8sResources.Services.Has("svc1")) +} + +// TestUpdateK8sPodInvalidOperation covers the default branch of updateK8sPodLocked. +func TestUpdateK8sPodInvalidOperation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + err := dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Operation(99), + Location: "node1", + Address: "10.0.0.1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid pod operation") +} + +// TestUpdateK8sPodRemoveNonExistent covers the duplicate-removal branch. +func TestUpdateK8sPodRemoveNonExistent(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + // Removing a pod that was never added is a no-op (no error, no counter change). + err := dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: REMOVE, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + }) + assert.NoError(t, err) + _, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.False(t, ok) +} + +// TestUpdateK8sEndpointsMissingLocation covers the error branch where an address +// has no associated node location. +func TestUpdateK8sEndpointsMissingLocation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": ""}, + }) + assert.NotEmpty(t, errs) + assert.Contains(t, errs[0].Error(), "does not have a node associated") +} + +// TestRemoveServiceFromK8sStateLocked covers removeServiceFromK8sStateLocked for +// both inbound and outbound identities, including empty-pod/node cleanup. +func TestRemoveServiceFromK8sStateLocked(t *testing.T) { + t.Run("inbound removal cleans up empty pods and nodes", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + // Pod backs only svc1 inbound -> becomes empty and is removed. + "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, + // Pod backs svc1 and svc2 -> keeps svc2. + "10.0.0.2": {InboundIdentities: sets.NewString("svc1", "svc2")}, + }}, + }, + }, + } + dt.removeServiceFromK8sStateLocked("svc1", true) + + // 10.0.0.1 had only svc1 -> removed. + _, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] + assert.False(t, ok) + // 10.0.0.2 still has svc2. + pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.2"] + assert.True(t, ok) + assert.True(t, pod.InboundIdentities.Has("svc2")) + assert.False(t, pod.InboundIdentities.Has("svc1")) + }) + + t.Run("outbound removal clears identity and removes empty node", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, + }}, + }, + }, + } + dt.removeServiceFromK8sStateLocked("egress1", false) + + // Pod had only the outbound identity -> pod and node removed. + _, nodeOK := dt.K8sResources.Nodes["node1"] + assert.False(t, nodeOK) + }) + + t.Run("public wrapper acquires lock and removes inbound identity", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, + }}, + }, + }, + } + dt.RemoveServiceFromK8sState("svc1", true) + + _, nodeOK := dt.K8sResources.Nodes["node1"] + assert.False(t, nodeOK) + }) +} + +// TestLocationDataEqualsMoreCases covers additional LocationData.Equals branches. +func TestLocationDataEqualsMoreCases(t *testing.T) { + base := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, + }, + }, + } + equal := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, + }, + }, + } + assert.True(t, base.Equals(&equal)) + + // Different top-level Action. + diffAction := equal + diffAction.Action = FullUpdate + assert.False(t, base.Equals(&diffAction)) + + // Different number of locations. + diffLen := LocationData{Action: PartialUpdate, Locations: map[string]Location{}} + assert.False(t, base.Equals(&diffLen)) + + // Missing location name. + diffName := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{"node2": base.Locations["node1"]}, + } + assert.False(t, base.Equals(&diffName)) + + // Different AddressUpdateAction. + diffAUA := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: FullUpdate, Addresses: base.Locations["node1"].Addresses}, + }, + } + assert.False(t, base.Equals(&diffAUA)) + + // Different addresses length. + diffAddrLen := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + } + assert.False(t, base.Equals(&diffAddrLen)) + + // Missing address name. + diffAddrName := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.2": {ServiceRef: sets.NewString("svc1")}}}, + }, + } + assert.False(t, base.Equals(&diffAddrName)) + + // Different ServiceRef. + diffRef := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc2")}}}, + }, + } + assert.False(t, base.Equals(&diffRef)) +} + +// TestSyncDiffTrackerReturnTypeEquals covers SyncDiffTrackerReturnType.Equals branches. +func TestSyncDiffTrackerReturnTypeEquals(t *testing.T) { + mk := func() SyncDiffTrackerReturnType { + return SyncDiffTrackerReturnType{ + SyncStatus: Success, + LoadBalancerUpdates: SyncServicesReturnType{Additions: sets.NewString("a"), Removals: sets.NewString("b")}, + NATGatewayUpdates: SyncServicesReturnType{Additions: sets.NewString("c"), Removals: sets.NewString("d")}, + LocationData: LocationData{Action: PartialUpdate, Locations: map[string]Location{}}, + } + } + + base := mk() + equal := mk() + assert.True(t, base.Equals(&equal)) + + // Different SyncStatus. + diffStatus := mk() + diffStatus.SyncStatus = AlreadyInSync + assert.False(t, base.Equals(&diffStatus)) + + // Different LB additions. + diffLBAdd := mk() + diffLBAdd.LoadBalancerUpdates.Additions = sets.NewString("x") + assert.False(t, base.Equals(&diffLBAdd)) + + // Different LB removals. + diffLBRem := mk() + diffLBRem.LoadBalancerUpdates.Removals = sets.NewString("x") + assert.False(t, base.Equals(&diffLBRem)) + + // Different NATGW additions. + diffNGAdd := mk() + diffNGAdd.NATGatewayUpdates.Additions = sets.NewString("x") + assert.False(t, base.Equals(&diffNGAdd)) + + // Different NATGW removals. + diffNGRem := mk() + diffNGRem.NATGatewayUpdates.Removals = sets.NewString("x") + assert.False(t, base.Equals(&diffNGRem)) + + // Different LocationData. + diffLoc := mk() + diffLoc.LocationData.Action = FullUpdate + assert.False(t, base.Equals(&diffLoc)) +} + +// TestDiffTrackerEqualsMoreCases covers additional DiffTracker.Equals branches. +func TestDiffTrackerEqualsMoreCases(t *testing.T) { + mk := func() *DiffTracker { + return &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egr1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1"), + NATGateways: sets.NewString("egr1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }}, + }, + }, + } + } + + assert.True(t, mk().Equals(mk())) + + // Different Services. + a := mk() + b := mk() + b.K8sResources.Services = sets.NewString("other") + assert.False(t, a.Equals(b)) + + // Different Egresses. + b = mk() + b.K8sResources.Egresses = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different node count. + b = mk() + b.K8sResources.Nodes["node2"] = Node{Pods: map[string]Pod{}} + assert.False(t, mk().Equals(b)) + + // Missing node name. + b = mk() + delete(b.K8sResources.Nodes, "node1") + b.K8sResources.Nodes["nodeX"] = Node{Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }} + assert.False(t, mk().Equals(b)) + + // Different pod count. + b = mk() + n := b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.2"] = Pod{InboundIdentities: sets.NewString()} + assert.False(t, mk().Equals(b)) + + // Missing pod address. + b = mk() + n = b.K8sResources.Nodes["node1"] + delete(n.Pods, "10.0.0.1") + n.Pods["10.0.0.9"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"} + assert.False(t, mk().Equals(b)) + + // Different InboundIdentities. + b = mk() + n = b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("other"), PublicOutboundIdentity: "egr1"} + assert.False(t, mk().Equals(b)) + + // Different PublicOutboundIdentity. + b = mk() + n = b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "other"} + assert.False(t, mk().Equals(b)) + + // Different NRP LoadBalancers. + b = mk() + b.NRPResources.LoadBalancers = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different NRP NATGateways. + b = mk() + b.NRPResources.NATGateways = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different NRP location count. + b = mk() + b.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} + assert.False(t, mk().Equals(b)) +} + +// TestUpdateK8sPodAddIdempotent covers the alreadyExists branch of updateK8sPodLocked: +// a repeated ADD for the same pod+identity must not double-count the counter. +func TestUpdateK8sPodAddIdempotent(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + in := UpdatePodInputType{ + PodOperation: ADD, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + } + assert.NoError(t, dt.UpdateK8sPod(in)) + // Second identical ADD (e.g. informer resync) must be a no-op for the counter. + assert.NoError(t, dt.UpdateK8sPod(in)) + + val, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) +} + +// TestUpdateK8sEndpointsAddThenRemove covers the OldAddresses removal path of +// updateK8sEndpointsLocked, including empty-pod/node cleanup. +func TestUpdateK8sEndpointsAddThenRemove(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + // Add endpoint for svc1 at pod 10.0.0.1 on node1. + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("svc1")) + + // Remove the same endpoint (now in OldAddresses, absent from NewAddresses). + errs = dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + OldAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + // Pod had only svc1 -> pod and node cleaned up. + _, ok := dt.K8sResources.Nodes["node1"] + assert.False(t, ok) +} + +// TestDeepEqualMoreCases covers the node/pod/identity mismatch branches of DeepEqual. +func TestDeepEqualMoreCases(t *testing.T) { + // Helper to build a DiffTracker that is in-sync by construction. + inSync := func() *DiffTracker { + return &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egr1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1"), + NATGateways: sets.NewString("egr1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }}, + }, + }, + } + } + + assert.True(t, inSync().DeepEqual()) + + // LoadBalancer present in NRP but not in K8s Services (reverse-direction check). + d := inSync() + d.NRPResources.LoadBalancers = sets.NewString("svc1", "extra") + d.K8sResources.Services = sets.NewString("svc1", "different") + assert.False(t, d.DeepEqual()) + + // Egress name mismatch (reverse direction). + d = inSync() + d.NRPResources.NATGateways = sets.NewString("other") + d.K8sResources.Egresses = sets.NewString("egr1") + // lengths equal (1==1) but names differ -> mismatch + d.NRPResources.NATGateways = sets.NewString("egr1") + d.K8sResources.Egresses = sets.NewString("egr2") + d.NRPResources.NATGateways = sets.NewString("egr2x") + assert.False(t, d.DeepEqual()) + + // Nodes vs Locations length mismatch. + d = inSync() + d.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} + assert.False(t, d.DeepEqual()) + + // Node missing in Locations (same count, different key). + d = inSync() + delete(d.NRPResources.Locations, "node1") + d.NRPResources.Locations["nodeX"] = NRPLocation{Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }} + assert.False(t, d.DeepEqual()) + + // Pods vs Addresses length mismatch. + d = inSync() + loc := d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.2"] = NRPAddress{Services: sets.NewString("svc1")} + assert.False(t, d.DeepEqual()) + + // Pod missing in Addresses (same count, different key). + d = inSync() + loc = d.NRPResources.Locations["node1"] + delete(loc.Addresses, "10.0.0.1") + loc.Addresses["10.0.0.9"] = NRPAddress{Services: sets.NewString("svc1", "egr1")} + assert.False(t, d.DeepEqual()) + + // Combined identities length mismatch. + d = inSync() + loc = d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1")} + assert.False(t, d.DeepEqual()) + + // Identity not found in Services (same count, different identity). + d = inSync() + loc = d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1", "egrX")} + assert.False(t, d.DeepEqual()) +} diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index e78b5c3200..0439e277b3 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -207,6 +207,39 @@ func TestUpdateK8sPod(t *testing.T) { assert.NotContains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") } +// TestUpdateK8sPodRemoveUsesStoredIdentity verifies that a REMOVE whose input omits +// (or mismatches) PublicOutboundIdentity still decrements the counter of the identity +// actually stored on the pod, rather than corrupting a different ("" or wrong) counter. +func TestUpdateK8sPodRemoveUsesStoredIdentity(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{Nodes: map[string]Node{}}, + } + + // Add a pod with outbound identity "public1". + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: ADD, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + // Remove with the identity omitted from the input (the problematic case). + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: REMOVE, + Location: "node1", + Address: "10.0.0.1", + })) + + // The stored identity's counter must be cleared, and no bogus "" entry created. + _, ok = dt.outboundIdentityPodRefCount.Load("public1") + assert.False(t, ok, "counter for stored identity public1 should be removed") + _, ok = dt.outboundIdentityPodRefCount.Load("") + assert.False(t, ok, "no counter should be created for an empty identity") +} + func TestGetSyncLocationsAddresses(t *testing.T) { dt := &DiffTracker{ K8sResources: K8sState{ diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go index dc213ab4b7..02b51c8019 100644 --- a/pkg/provider/difftracker/k8s_state_updates.go +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -70,10 +70,12 @@ func (dt *DiffTracker) EnqueueK8sEgressOperation(input UpdateK8sResource) error } // updateK8sEndpointsLocked updates K8s endpoints state. Assumes lock is already held. +// Terminology: an "endpoint" here is a Kubernetes EndpointSlice endpoint, i.e. a pod IP +// (the map key "address"), and its "location" is the IP of the node/VM hosting that pod +// (the map value). Both NewAddresses and OldAddresses are address(podIP) -> location(nodeIP). func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputType) []error { var errs []error for address, location := range input.NewAddresses { - if location == "" { errs = append(errs, fmt.Errorf("error UpdateK8sEndpoints, address=%s does not have a node associated", address)) continue @@ -133,7 +135,10 @@ func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputTyp return errs } -// UpdateK8sEndpoints is a public wrapper that acquires lock before calling updateK8sEndpointsLocked. +// UpdateK8sEndpoints is the public, lock-acquiring entry point for applying an +// EndpointSlice change to K8s state. It is called by the EndpointSlice informer +// handlers (Add/Update/Delete). Use this rather than updateK8sEndpointsLocked +// unless the caller already holds dt.mu. func (dt *DiffTracker) UpdateK8sEndpoints(input UpdateK8sEndpointsInputType) []error { dt.mu.Lock() defer dt.mu.Unlock() @@ -158,25 +163,29 @@ func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) error { return nil } -// removePod removes a pod from K8s state. Returns true if the pod was actually removed, -// false if it didn't exist (already removed by a previous call). -func (dt *DiffTracker) removePod(input UpdatePodInputType) (removed bool, err error) { +// removePod removes a pod from K8s state. Returns true and the removed pod's stored +// PublicOutboundIdentity if the pod was actually removed, or false if it didn't exist +// (already removed by a previous call). The returned identity is authoritative for +// reference-counter bookkeeping, independent of whatever the caller passed in input. +func (dt *DiffTracker) removePod(input UpdatePodInputType) (removed bool, identity string, err error) { node, exists := dt.K8sResources.Nodes[input.Location] if !exists { - return false, nil + return false, "", nil } // Check if pod exists before removing - if _, podExists := node.Pods[input.Address]; !podExists { - return false, nil + pod, podExists := node.Pods[input.Address] + if !podExists { + return false, "", nil } + identity = pod.PublicOutboundIdentity delete(node.Pods, input.Address) if !node.HasPods() { delete(dt.K8sResources.Nodes, input.Location) } - return true, nil + return true, identity, nil } // updateK8sPodLocked updates K8s pod state. Assumes lock is already held. @@ -200,16 +209,16 @@ func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { // Only increment counter if pod doesn't already exist if !alreadyExists { counter := 0 - if val, ok := dt.LocalServiceNameToNRPServiceMap.Load(strings.ToLower(input.PublicOutboundIdentity)); ok { + if val, ok := dt.outboundIdentityPodRefCount.Load(strings.ToLower(input.PublicOutboundIdentity)); ok { counter = val.(int) } - dt.LocalServiceNameToNRPServiceMap.Store(strings.ToLower(input.PublicOutboundIdentity), counter+1) + dt.outboundIdentityPodRefCount.Store(strings.ToLower(input.PublicOutboundIdentity), counter+1) } return dt.addOrUpdatePod(input) case REMOVE: // First, try to remove the pod from K8s state // This returns false if the pod doesn't exist (duplicate removal) - removed, err := dt.removePod(input) + removed, identity, err := dt.removePod(input) if err != nil { return err } @@ -220,16 +229,19 @@ func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { return nil } - // Pod was actually removed, now decrement the counter - if val, ok := dt.LocalServiceNameToNRPServiceMap.Load(strings.ToLower(input.PublicOutboundIdentity)); ok { + // Decrement using the pod's stored identity (authoritative), not input. + if identity == "" { + return nil + } + if val, ok := dt.outboundIdentityPodRefCount.Load(strings.ToLower(identity)); ok { counter := val.(int) if counter <= 0 { - return fmt.Errorf("error - PublicOutboundIdentity %s has a negative count: %d", input.PublicOutboundIdentity, counter) + return fmt.Errorf("error - PublicOutboundIdentity %s has a non-positive count: %d", identity, counter) } if counter == 1 { - dt.LocalServiceNameToNRPServiceMap.Delete(strings.ToLower(input.PublicOutboundIdentity)) + dt.outboundIdentityPodRefCount.Delete(strings.ToLower(identity)) } else { - dt.LocalServiceNameToNRPServiceMap.Store(strings.ToLower(input.PublicOutboundIdentity), counter-1) + dt.outboundIdentityPodRefCount.Store(strings.ToLower(identity), counter-1) } } return nil @@ -239,7 +251,10 @@ func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { } } -// UpdateK8sPod is a public wrapper that acquires lock before calling updateK8sPodLocked. +// UpdateK8sPod is the public, lock-acquiring entry point for applying a pod +// egress-assignment change to K8s state. It is called by the pod informer +// handlers (Add/Update/Delete). Use this rather than updateK8sPodLocked unless +// the caller already holds dt.mu. func (dt *DiffTracker) UpdateK8sPod(input UpdatePodInputType) error { dt.mu.Lock() defer dt.mu.Unlock() @@ -250,8 +265,6 @@ func (dt *DiffTracker) UpdateK8sPod(input UpdatePodInputType) error { // This is used during service deletion to proactively clear location/address references // so the LocationsUpdater can sync the removal to NRP. // Assumes lock is already held. -// -//nolint:unused // wired up by the engine in a follow-up PR func (dt *DiffTracker) removeServiceFromK8sStateLocked(serviceUID string, isInbound bool) { for nodeIP, node := range dt.K8sResources.Nodes { for podIP, pod := range node.Pods { @@ -278,3 +291,12 @@ func (dt *DiffTracker) removeServiceFromK8sStateLocked(serviceUID string, isInbo } } } + +// RemoveServiceFromK8sState is the public, lock-acquiring entry point for +// removeServiceFromK8sStateLocked. Use it to clear a deleted service's references +// from pod identities when the caller does not already hold dt.mu. +func (dt *DiffTracker) RemoveServiceFromK8sState(serviceUID string, isInbound bool) { + dt.mu.Lock() + defer dt.mu.Unlock() + dt.removeServiceFromK8sStateLocked(serviceUID, isInbound) +} diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go index bb4baa4033..19edd01b93 100644 --- a/pkg/provider/difftracker/sync_operations.go +++ b/pkg/provider/difftracker/sync_operations.go @@ -23,9 +23,9 @@ import ( ) // GetServicesToSync handles the synchronization of services between K8s and NRP -func GetServicesToSync(k8sServices, Services *utilsets.IgnoreCaseSet) SyncServicesReturnType { - klog.Infof("GetServicesToSync: K8s services (%d): %v", k8sServices.Len(), k8sServices.UnsortedList()) - klog.Infof("GetServicesToSync: NRP services (%d): %v", Services.Len(), Services.UnsortedList()) +func GetServicesToSync(k8sServices, nrpServices *utilsets.IgnoreCaseSet) SyncServicesReturnType { + klog.V(2).Infof("GetServicesToSync: K8s services (%d): %v", k8sServices.Len(), k8sServices.UnsortedList()) + klog.V(2).Infof("GetServicesToSync: NRP services (%d): %v", nrpServices.Len(), nrpServices.UnsortedList()) syncServices := SyncServicesReturnType{ Additions: utilsets.NewString(), @@ -33,22 +33,22 @@ func GetServicesToSync(k8sServices, Services *utilsets.IgnoreCaseSet) SyncServic } for _, service := range k8sServices.UnsortedList() { - if Services.Has(service) { + if nrpServices.Has(service) { continue } syncServices.Additions.Insert(service) - klog.Infof("GetServicesToSync: Added service %s to additions", service) + klog.V(4).Infof("GetServicesToSync: Added service %s to additions", service) } - for _, service := range Services.UnsortedList() { + for _, service := range nrpServices.UnsortedList() { if k8sServices.Has(service) { continue } syncServices.Removals.Insert(service) - klog.Infof("GetServicesToSync: Added service %s to removals", service) + klog.V(4).Infof("GetServicesToSync: Added service %s to removals", service) } - klog.Infof("GetServicesToSync: Result - Additions: %d, Removals: %d", syncServices.Additions.Len(), syncServices.Removals.Len()) + klog.V(2).Infof("GetServicesToSync: Result - Additions: %d, Removals: %d", syncServices.Additions.Len(), syncServices.Removals.Len()) return syncServices } @@ -56,6 +56,11 @@ func (dt *DiffTracker) GetSyncLoadBalancerServices() SyncServicesReturnType { dt.mu.Lock() defer dt.mu.Unlock() + return dt.getSyncLoadBalancerServicesLocked() +} + +// getSyncLoadBalancerServicesLocked is the lock-free body. Callers must hold dt.mu. +func (dt *DiffTracker) getSyncLoadBalancerServicesLocked() SyncServicesReturnType { return GetServicesToSync(dt.K8sResources.Services, dt.NRPResources.LoadBalancers) } @@ -63,6 +68,11 @@ func (dt *DiffTracker) GetSyncNRPNATGateways() SyncServicesReturnType { dt.mu.Lock() defer dt.mu.Unlock() + return dt.getSyncNRPNATGatewaysLocked() +} + +// getSyncNRPNATGatewaysLocked is the lock-free body. Callers must hold dt.mu. +func (dt *DiffTracker) getSyncNRPNATGatewaysLocked() SyncServicesReturnType { return GetServicesToSync(dt.K8sResources.Egresses, dt.NRPResources.NATGateways) } @@ -70,6 +80,11 @@ func (dt *DiffTracker) GetSyncLocationsAddresses() LocationData { dt.mu.Lock() defer dt.mu.Unlock() + return dt.getSyncLocationsAddressesLocked() +} + +// getSyncLocationsAddressesLocked is the lock-free body. Callers must hold dt.mu. +func (dt *DiffTracker) getSyncLocationsAddressesLocked() LocationData { result := LocationData{ Action: PartialUpdate, Locations: make(map[string]Location), @@ -185,24 +200,27 @@ func (dt *DiffTracker) isServiceReady(serviceUID string, isInbound bool) bool { // Helper function to find LocationData in result func findLocationData(result LocationData, location string) *Location { - for keyCurrentLocation := range result.Locations { - if keyCurrentLocation == location { - loc := result.Locations[keyCurrentLocation] - return &loc - } + if loc, ok := result.Locations[location]; ok { + return &loc } return nil } func (dt *DiffTracker) GetSyncOperations() *SyncDiffTrackerReturnType { - if dt.DeepEqual() { + // Take the lock once so DeepEqual and all three sync computations observe a + // single consistent snapshot of the state (avoids a data race with mutating + // methods and inconsistency between the individual GetSync* results). + dt.mu.Lock() + defer dt.mu.Unlock() + + if dt.deepEqualLocked() { return &SyncDiffTrackerReturnType{SyncStatus: AlreadyInSync} } return &SyncDiffTrackerReturnType{ SyncStatus: Success, - LoadBalancerUpdates: dt.GetSyncLoadBalancerServices(), - NATGatewayUpdates: dt.GetSyncNRPNATGateways(), - LocationData: dt.GetSyncLocationsAddresses(), + LoadBalancerUpdates: dt.getSyncLoadBalancerServicesLocked(), + NATGatewayUpdates: dt.getSyncNRPNATGatewaysLocked(), + LocationData: dt.getSyncLocationsAddressesLocked(), } } diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go index d39a4119df..a39161e0fc 100644 --- a/pkg/provider/difftracker/types.go +++ b/pkg/provider/difftracker/types.go @@ -70,15 +70,25 @@ type NRPLocation struct { } type NRPState struct { + // LoadBalancers holds the UIDs of inbound services that have a LoadBalancer + // registered on the NRP side. These are SGW service identities, not Azure + // LoadBalancer resource names. LoadBalancers *utilsets.IgnoreCaseSet - NATGateways *utilsets.IgnoreCaseSet + // NATGateways holds the UIDs of outbound/egress services that have a NAT + // Gateway registered on the NRP side (SGW service identities, not Azure + // resource names). + NATGateways *utilsets.IgnoreCaseSet // Locations is keyed by node/VM IP (e.g. "10.0.0.1"). "Location" here is // an SGW concept identifying a node, not an Azure region (e.g. "eastus2"). Locations map[string]NRPLocation } type Pod struct { - InboundIdentities *utilsets.IgnoreCaseSet + // InboundIdentities holds the UIDs of the inbound ServiceGateway services + // (LoadBalancers) this pod backs. A pod may back several, hence a set. + InboundIdentities *utilsets.IgnoreCaseSet + // PublicOutboundIdentity is the UID of the single outbound/egress ServiceGateway + // service (NAT Gateway) this pod uses for egress; empty if the pod has no egress. PublicOutboundIdentity string } @@ -99,7 +109,12 @@ type DiffTracker struct { K8sResources K8sState NRPResources NRPState - LocalServiceNameToNRPServiceMap sync.Map + // outboundIdentityPodRefCount counts how many pods reference each outbound + // (egress) identity, keyed by lowercased PublicOutboundIdentity. It lets the + // engine delete a NAT Gateway when its last egress pod is removed. Inbound + // (LoadBalancer) services are not tracked here; their lifecycle follows the + // Kubernetes Service object. + outboundIdentityPodRefCount sync.Map // Configuration and clients config Config diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go index d2d8dabaf3..da05777fd1 100644 --- a/pkg/provider/difftracker/util.go +++ b/pkg/provider/difftracker/util.go @@ -24,7 +24,11 @@ import ( ) func (operation Operation) String() string { - return [...]string{"ADD", "REMOVE", "UPDATE"}[operation] + names := [...]string{"ADD", "REMOVE", "UPDATE"} + if int(operation) < 0 || int(operation) >= len(names) { + return fmt.Sprintf("Operation(%d)", int(operation)) + } + return names[operation] } func (operation Operation) MarshalJSON() ([]byte, error) { @@ -32,7 +36,11 @@ func (operation Operation) MarshalJSON() ([]byte, error) { } func (updateAction UpdateAction) String() string { - return [...]string{"PartialUpdate", "FullUpdate"}[updateAction] + names := [...]string{"PartialUpdate", "FullUpdate"} + if int(updateAction) < 0 || int(updateAction) >= len(names) { + return fmt.Sprintf("UpdateAction(%d)", int(updateAction)) + } + return names[updateAction] } func (updateAction UpdateAction) MarshalJSON() ([]byte, error) { @@ -58,7 +66,11 @@ func (updateAction *UpdateAction) UnmarshalJSON(data []byte) error { } func (syncStatus SyncStatus) String() string { - return [...]string{"AlreadyInSync", "Success"}[syncStatus] + names := [...]string{"AlreadyInSync", "Success"} + if int(syncStatus) < 0 || int(syncStatus) >= len(names) { + return fmt.Sprintf("SyncStatus(%d)", int(syncStatus)) + } + return names[syncStatus] } func (syncStatus SyncStatus) MarshalJSON() ([]byte, error) { @@ -71,69 +83,77 @@ func (pod *Pod) HasIdentities() bool { return pod.InboundIdentities.Len() > 0 || pod.PublicOutboundIdentity != "" } -// DeepEqual compares the K8s and NRP states to check if they are in sync +// DeepEqual compares the K8s and NRP states to check if they are in sync. +// It acquires dt.mu so callers get a consistent snapshot. func (dt *DiffTracker) DeepEqual() bool { - klog.Infof("DeepEqual: Checking equality - K8s Services=%d, NRP LoadBalancers=%d, K8s Egresses=%d, NRP NATGateways=%d", + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.deepEqualLocked() +} + +// deepEqualLocked is the lock-free body of DeepEqual. Callers must hold dt.mu. +func (dt *DiffTracker) deepEqualLocked() bool { + klog.V(4).Infof("DeepEqual: Checking equality - K8s Services=%d, NRP LoadBalancers=%d, K8s Egresses=%d, NRP NATGateways=%d", dt.K8sResources.Services.Len(), dt.NRPResources.LoadBalancers.Len(), dt.K8sResources.Egresses.Len(), dt.NRPResources.NATGateways.Len()) // Compare Services with LoadBalancers if dt.K8sResources.Services.Len() != dt.NRPResources.LoadBalancers.Len() { - klog.Infof("DeepEqual: Services and LoadBalancers length mismatch") + klog.V(4).Infof("DeepEqual: Services and LoadBalancers length mismatch") return false } for _, service := range dt.K8sResources.Services.UnsortedList() { if !dt.NRPResources.LoadBalancers.Has(service) { - klog.Infof("DeepEqual: Service %s not found in LoadBalancers", service) + klog.V(4).Infof("DeepEqual: Service %s not found in LoadBalancers", service) return false } } for _, service := range dt.NRPResources.LoadBalancers.UnsortedList() { if !dt.K8sResources.Services.Has(service) { - klog.Infof("DeepEqual: LoadBalancer %s not found in Services", service) + klog.V(4).Infof("DeepEqual: LoadBalancer %s not found in Services", service) return false } } // Compare Egresses with NATGateways if dt.K8sResources.Egresses.Len() != dt.NRPResources.NATGateways.Len() { - klog.Infof("DeepEqual: Egresses and NATGateways length mismatch") + klog.V(4).Infof("DeepEqual: Egresses and NATGateways length mismatch") return false } for _, egress := range dt.K8sResources.Egresses.UnsortedList() { if !dt.NRPResources.NATGateways.Has(egress) { - klog.Infof("DeepEqual: Egress %s not found in NATGateways", egress) + klog.V(4).Infof("DeepEqual: Egress %s not found in NATGateways", egress) return false } } for _, egress := range dt.NRPResources.NATGateways.UnsortedList() { if !dt.K8sResources.Egresses.Has(egress) { - klog.Infof("DeepEqual: NATGateway %s not found in Egresses", egress) + klog.V(4).Infof("DeepEqual: NATGateway %s not found in Egresses", egress) return false } } // Compare Nodes with Locations if len(dt.K8sResources.Nodes) != len(dt.NRPResources.Locations) { - klog.V(2).Infof("DeepEqual: Nodes and Locations length mismatch") + klog.V(4).Infof("DeepEqual: Nodes and Locations length mismatch") return false } for nodeKey, node := range dt.K8sResources.Nodes { nrpLocation, exists := dt.NRPResources.Locations[nodeKey] if !exists { - klog.V(2).Infof("DeepEqual: Node %s not found in Locations\n", nodeKey) + klog.V(4).Infof("DeepEqual: Node %s not found in Locations\n", nodeKey) return false } // Compare Pods with Addresses if len(node.Pods) != len(nrpLocation.Addresses) { - klog.V(2).Infof("DeepEqual: Pods and Addresses length mismatch for node %s\n", nodeKey) + klog.V(4).Infof("DeepEqual: Pods and Addresses length mismatch for node %s\n", nodeKey) return false } for podKey, pod := range node.Pods { nrpAddress, exists := nrpLocation.Addresses[podKey] if !exists { - klog.V(2).Infof("DeepEqual: Pod %s not found in Addresses for node %s\n", podKey, nodeKey) + klog.V(4).Infof("DeepEqual: Pod %s not found in Addresses for node %s\n", podKey, nodeKey) return false } @@ -145,13 +165,13 @@ func (dt *DiffTracker) DeepEqual() bool { } if len(combinedIdentities) != nrpAddress.Services.Len() { - klog.V(2).Infof("DeepEqual: Combined identities length mismatch for pod %s in node %s\n", podKey, nodeKey) + klog.V(4).Infof("DeepEqual: Combined identities length mismatch for pod %s in node %s\n", podKey, nodeKey) return false } for _, identity := range combinedIdentities { if !nrpAddress.Services.Has(identity) { - klog.V(2).Infof("DeepEqual: Identity %s not found in Services for pod %s in node %s\n", identity, podKey, nodeKey) + klog.V(4).Infof("DeepEqual: Identity %s not found in Services for pod %s in node %s\n", identity, podKey, nodeKey) return false } } diff --git a/pkg/provider/difftracker/util_test.go b/pkg/provider/difftracker/util_test.go index 130d3cf316..42f8168837 100644 --- a/pkg/provider/difftracker/util_test.go +++ b/pkg/provider/difftracker/util_test.go @@ -80,6 +80,15 @@ func TestUpdateActionStringAndJSON(t *testing.T) { } } +// TestEnumStringOutOfRange verifies String() does not panic for out-of-range enum +// values and returns a descriptive fallback instead. +func TestEnumStringOutOfRange(t *testing.T) { + assert.Equal(t, "Operation(99)", Operation(99).String()) + assert.Equal(t, "Operation(-1)", Operation(-1).String()) + assert.Equal(t, "UpdateAction(99)", UpdateAction(99).String()) + assert.Equal(t, "SyncStatus(99)", SyncStatus(99).String()) +} + // TestUpdateActionUnmarshalJSON_InvalidValue tests error handling func TestUpdateActionUnmarshalJSON_InvalidValue(t *testing.T) { var action UpdateAction diff --git a/pkg/util/sets/string.go b/pkg/util/sets/string.go index 7c22980a83..a9f79a407f 100644 --- a/pkg/util/sets/string.go +++ b/pkg/util/sets/string.go @@ -113,21 +113,15 @@ func (s *IgnoreCaseSet) MarshalJSON() ([]byte, error) { // Equals returns true if the two sets are equal. func (s *IgnoreCaseSet) Equals(other *IgnoreCaseSet) bool { - // Early exit if sizes are different - if len(s.UnsortedList()) != len(other.UnsortedList()) { + // Two sets of equal size are equal iff every item in one is contained in the + // other, so a single containment check in one direction is sufficient. + if s.Len() != other.Len() { return false } - // Check if all items in s are in other for _, item := range s.UnsortedList() { if !other.Has(item) { return false } } - // Check if all items in other are in s - for _, item := range other.UnsortedList() { - if !s.Has(item) { - return false - } - } return true } From abfa9ce7a4fd839f72d8d497fbb41ff3a2ec206d Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Tue, 16 Jun 2026 13:10:48 +0000 Subject: [PATCH 06/18] refactor: replace string array with switch statements for Operation, UpdateAction, and SyncStatus --- pkg/provider/difftracker/util.go | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go index da05777fd1..40d66db603 100644 --- a/pkg/provider/difftracker/util.go +++ b/pkg/provider/difftracker/util.go @@ -24,11 +24,16 @@ import ( ) func (operation Operation) String() string { - names := [...]string{"ADD", "REMOVE", "UPDATE"} - if int(operation) < 0 || int(operation) >= len(names) { + switch operation { + case ADD: + return "ADD" + case REMOVE: + return "REMOVE" + case UPDATE: + return "UPDATE" + default: return fmt.Sprintf("Operation(%d)", int(operation)) } - return names[operation] } func (operation Operation) MarshalJSON() ([]byte, error) { @@ -36,11 +41,14 @@ func (operation Operation) MarshalJSON() ([]byte, error) { } func (updateAction UpdateAction) String() string { - names := [...]string{"PartialUpdate", "FullUpdate"} - if int(updateAction) < 0 || int(updateAction) >= len(names) { + switch updateAction { + case PartialUpdate: + return "PartialUpdate" + case FullUpdate: + return "FullUpdate" + default: return fmt.Sprintf("UpdateAction(%d)", int(updateAction)) } - return names[updateAction] } func (updateAction UpdateAction) MarshalJSON() ([]byte, error) { @@ -66,11 +74,14 @@ func (updateAction *UpdateAction) UnmarshalJSON(data []byte) error { } func (syncStatus SyncStatus) String() string { - names := [...]string{"AlreadyInSync", "Success"} - if int(syncStatus) < 0 || int(syncStatus) >= len(names) { + switch syncStatus { + case AlreadyInSync: + return "AlreadyInSync" + case Success: + return "Success" + default: return fmt.Sprintf("SyncStatus(%d)", int(syncStatus)) } - return names[syncStatus] } func (syncStatus SyncStatus) MarshalJSON() ([]byte, error) { From 2e5c96e141cbf9abf6eb47c91206e9c3c609e4bb Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Thu, 18 Jun 2026 19:50:59 +0000 Subject: [PATCH 07/18] Refactor difftracker: Improve logging, enhance equality checks, and add set operations - Updated logging statements in UpdateNRPLoadBalancers and UpdateNRPNATGateways to remove unnecessary newline characters. - Simplified the GetServicesToSync function by utilizing set operations for additions and removals. - Enhanced equality checks in DiffTracker, SyncDiffTrackerReturnType, and LocationData to cover more edge cases. - Introduced Difference method in IgnoreCaseSet for set difference operations. - Added unit tests for new functionality and edge cases in LocationData and DiffTracker equality checks. - Removed unused helper functions and improved code readability. --- pkg/provider/difftracker/coverage_test.go | 539 ------------------ pkg/provider/difftracker/difftracker.go | 46 +- pkg/provider/difftracker/difftracker_test.go | 398 ++++++++++++- pkg/provider/difftracker/k8s_state_updates.go | 155 +++-- pkg/provider/difftracker/nrp_state_updates.go | 25 +- pkg/provider/difftracker/sync_operations.go | 37 +- pkg/provider/difftracker/types.go | 27 +- pkg/provider/difftracker/util.go | 142 +++-- pkg/provider/difftracker/util_test.go | 303 +++++++++- pkg/util/sets/string.go | 13 + pkg/util/sets/string_test.go | 83 ++- 11 files changed, 974 insertions(+), 794 deletions(-) delete mode 100644 pkg/provider/difftracker/coverage_test.go diff --git a/pkg/provider/difftracker/coverage_test.go b/pkg/provider/difftracker/coverage_test.go deleted file mode 100644 index bfc5c07206..0000000000 --- a/pkg/provider/difftracker/coverage_test.go +++ /dev/null @@ -1,539 +0,0 @@ -/* -Copyright 2026 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package difftracker - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - - "k8s.io/client-go/kubernetes/fake" - - "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" - "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" -) - -func validTestConfig() Config { - return Config{ - SubscriptionID: "test-subscription", - ResourceGroup: "test-rg", - Location: "eastus", - VNetName: "test-vnet", - ServiceGatewayResourceName: "test-sgw", - ServiceGatewayID: "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", - } -} - -// TestInitializeDiffTrackerErrorPaths covers the validation/error branches of -// InitializeDiffTracker and the nil-field initialization. -func TestInitializeDiffTrackerErrorPaths(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockFactory := mock_azclient.NewMockClientFactory(ctrl) - mockKubeClient := fake.NewSimpleClientset() - - // Invalid config (empty) -> validation error. - _, err := InitializeDiffTracker(K8sState{}, NRPState{}, Config{}, mockFactory, mockKubeClient) - assert.Error(t, err) - assert.Contains(t, err.Error(), "InitializeDiffTracker") - - // Nil networkClientFactory. - _, err = InitializeDiffTracker(K8sState{}, NRPState{}, validTestConfig(), nil, mockKubeClient) - assert.Error(t, err) - assert.Contains(t, err.Error(), "networkClientFactory must not be nil") - - // Nil kubeClient. - _, err = InitializeDiffTracker(K8sState{}, NRPState{}, validTestConfig(), mockFactory, nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "kubeClient must not be nil") - - // Valid call with empty states -> nil fields get initialized. - dt, err := InitializeDiffTracker(K8sState{}, NRPState{}, validTestConfig(), mockFactory, mockKubeClient) - assert.NoError(t, err) - assert.NotNil(t, dt) - assert.NotNil(t, dt.K8sResources.Services) - assert.NotNil(t, dt.K8sResources.Egresses) - assert.NotNil(t, dt.K8sResources.Nodes) - assert.NotNil(t, dt.NRPResources.LoadBalancers) - assert.NotNil(t, dt.NRPResources.NATGateways) - assert.NotNil(t, dt.NRPResources.Locations) -} - -// TestEnqueueK8sResourceOperationErrors covers the empty-ID and invalid-operation -// branches of enqueueK8sResourceOperation (via the public wrappers). -func TestEnqueueK8sResourceOperationErrors(t *testing.T) { - dt := &DiffTracker{ - K8sResources: K8sState{Services: sets.NewString(), Egresses: sets.NewString()}, - } - - // Empty ID. - err := dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: ADD, ID: ""}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "empty ID") - - // Invalid operation (UPDATE is not valid for the resource set). - err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: UPDATE, ID: "egress1"}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "Operation=UPDATE") - - // Successful ADD then REMOVE. - assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: ADD, ID: "svc1"})) - assert.True(t, dt.K8sResources.Services.Has("svc1")) - assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: REMOVE, ID: "svc1"})) - assert.False(t, dt.K8sResources.Services.Has("svc1")) -} - -// TestUpdateK8sPodInvalidOperation covers the default branch of updateK8sPodLocked. -func TestUpdateK8sPodInvalidOperation(t *testing.T) { - dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} - err := dt.UpdateK8sPod(UpdatePodInputType{ - PodOperation: Operation(99), - Location: "node1", - Address: "10.0.0.1", - }) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid pod operation") -} - -// TestUpdateK8sPodRemoveNonExistent covers the duplicate-removal branch. -func TestUpdateK8sPodRemoveNonExistent(t *testing.T) { - dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} - // Removing a pod that was never added is a no-op (no error, no counter change). - err := dt.UpdateK8sPod(UpdatePodInputType{ - PodOperation: REMOVE, - PublicOutboundIdentity: "public1", - Location: "node1", - Address: "10.0.0.1", - }) - assert.NoError(t, err) - _, ok := dt.outboundIdentityPodRefCount.Load("public1") - assert.False(t, ok) -} - -// TestUpdateK8sEndpointsMissingLocation covers the error branch where an address -// has no associated node location. -func TestUpdateK8sEndpointsMissingLocation(t *testing.T) { - dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} - errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ - InboundIdentity: "svc1", - NewAddresses: map[string]string{"10.0.0.1": ""}, - }) - assert.NotEmpty(t, errs) - assert.Contains(t, errs[0].Error(), "does not have a node associated") -} - -// TestRemoveServiceFromK8sStateLocked covers removeServiceFromK8sStateLocked for -// both inbound and outbound identities, including empty-pod/node cleanup. -func TestRemoveServiceFromK8sStateLocked(t *testing.T) { - t.Run("inbound removal cleans up empty pods and nodes", func(t *testing.T) { - dt := &DiffTracker{ - K8sResources: K8sState{ - Nodes: map[string]Node{ - "node1": {Pods: map[string]Pod{ - // Pod backs only svc1 inbound -> becomes empty and is removed. - "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, - // Pod backs svc1 and svc2 -> keeps svc2. - "10.0.0.2": {InboundIdentities: sets.NewString("svc1", "svc2")}, - }}, - }, - }, - } - dt.removeServiceFromK8sStateLocked("svc1", true) - - // 10.0.0.1 had only svc1 -> removed. - _, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] - assert.False(t, ok) - // 10.0.0.2 still has svc2. - pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.2"] - assert.True(t, ok) - assert.True(t, pod.InboundIdentities.Has("svc2")) - assert.False(t, pod.InboundIdentities.Has("svc1")) - }) - - t.Run("outbound removal clears identity and removes empty node", func(t *testing.T) { - dt := &DiffTracker{ - K8sResources: K8sState{ - Nodes: map[string]Node{ - "node1": {Pods: map[string]Pod{ - "10.0.0.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, - }}, - }, - }, - } - dt.removeServiceFromK8sStateLocked("egress1", false) - - // Pod had only the outbound identity -> pod and node removed. - _, nodeOK := dt.K8sResources.Nodes["node1"] - assert.False(t, nodeOK) - }) - - t.Run("public wrapper acquires lock and removes inbound identity", func(t *testing.T) { - dt := &DiffTracker{ - K8sResources: K8sState{ - Nodes: map[string]Node{ - "node1": {Pods: map[string]Pod{ - "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, - }}, - }, - }, - } - dt.RemoveServiceFromK8sState("svc1", true) - - _, nodeOK := dt.K8sResources.Nodes["node1"] - assert.False(t, nodeOK) - }) -} - -// TestLocationDataEqualsMoreCases covers additional LocationData.Equals branches. -func TestLocationDataEqualsMoreCases(t *testing.T) { - base := LocationData{ - Action: PartialUpdate, - Locations: map[string]Location{ - "node1": { - AddressUpdateAction: PartialUpdate, - Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, - }, - }, - } - equal := LocationData{ - Action: PartialUpdate, - Locations: map[string]Location{ - "node1": { - AddressUpdateAction: PartialUpdate, - Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, - }, - }, - } - assert.True(t, base.Equals(&equal)) - - // Different top-level Action. - diffAction := equal - diffAction.Action = FullUpdate - assert.False(t, base.Equals(&diffAction)) - - // Different number of locations. - diffLen := LocationData{Action: PartialUpdate, Locations: map[string]Location{}} - assert.False(t, base.Equals(&diffLen)) - - // Missing location name. - diffName := LocationData{ - Action: PartialUpdate, - Locations: map[string]Location{"node2": base.Locations["node1"]}, - } - assert.False(t, base.Equals(&diffName)) - - // Different AddressUpdateAction. - diffAUA := LocationData{ - Action: PartialUpdate, - Locations: map[string]Location{ - "node1": {AddressUpdateAction: FullUpdate, Addresses: base.Locations["node1"].Addresses}, - }, - } - assert.False(t, base.Equals(&diffAUA)) - - // Different addresses length. - diffAddrLen := LocationData{ - Action: PartialUpdate, - Locations: map[string]Location{ - "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, - }, - } - assert.False(t, base.Equals(&diffAddrLen)) - - // Missing address name. - diffAddrName := LocationData{ - Action: PartialUpdate, - Locations: map[string]Location{ - "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.2": {ServiceRef: sets.NewString("svc1")}}}, - }, - } - assert.False(t, base.Equals(&diffAddrName)) - - // Different ServiceRef. - diffRef := LocationData{ - Action: PartialUpdate, - Locations: map[string]Location{ - "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc2")}}}, - }, - } - assert.False(t, base.Equals(&diffRef)) -} - -// TestSyncDiffTrackerReturnTypeEquals covers SyncDiffTrackerReturnType.Equals branches. -func TestSyncDiffTrackerReturnTypeEquals(t *testing.T) { - mk := func() SyncDiffTrackerReturnType { - return SyncDiffTrackerReturnType{ - SyncStatus: Success, - LoadBalancerUpdates: SyncServicesReturnType{Additions: sets.NewString("a"), Removals: sets.NewString("b")}, - NATGatewayUpdates: SyncServicesReturnType{Additions: sets.NewString("c"), Removals: sets.NewString("d")}, - LocationData: LocationData{Action: PartialUpdate, Locations: map[string]Location{}}, - } - } - - base := mk() - equal := mk() - assert.True(t, base.Equals(&equal)) - - // Different SyncStatus. - diffStatus := mk() - diffStatus.SyncStatus = AlreadyInSync - assert.False(t, base.Equals(&diffStatus)) - - // Different LB additions. - diffLBAdd := mk() - diffLBAdd.LoadBalancerUpdates.Additions = sets.NewString("x") - assert.False(t, base.Equals(&diffLBAdd)) - - // Different LB removals. - diffLBRem := mk() - diffLBRem.LoadBalancerUpdates.Removals = sets.NewString("x") - assert.False(t, base.Equals(&diffLBRem)) - - // Different NATGW additions. - diffNGAdd := mk() - diffNGAdd.NATGatewayUpdates.Additions = sets.NewString("x") - assert.False(t, base.Equals(&diffNGAdd)) - - // Different NATGW removals. - diffNGRem := mk() - diffNGRem.NATGatewayUpdates.Removals = sets.NewString("x") - assert.False(t, base.Equals(&diffNGRem)) - - // Different LocationData. - diffLoc := mk() - diffLoc.LocationData.Action = FullUpdate - assert.False(t, base.Equals(&diffLoc)) -} - -// TestDiffTrackerEqualsMoreCases covers additional DiffTracker.Equals branches. -func TestDiffTrackerEqualsMoreCases(t *testing.T) { - mk := func() *DiffTracker { - return &DiffTracker{ - K8sResources: K8sState{ - Services: sets.NewString("svc1"), - Egresses: sets.NewString("egr1"), - Nodes: map[string]Node{ - "node1": {Pods: map[string]Pod{ - "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, - }}, - }, - }, - NRPResources: NRPState{ - LoadBalancers: sets.NewString("svc1"), - NATGateways: sets.NewString("egr1"), - Locations: map[string]NRPLocation{ - "node1": {Addresses: map[string]NRPAddress{ - "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, - }}, - }, - }, - } - } - - assert.True(t, mk().Equals(mk())) - - // Different Services. - a := mk() - b := mk() - b.K8sResources.Services = sets.NewString("other") - assert.False(t, a.Equals(b)) - - // Different Egresses. - b = mk() - b.K8sResources.Egresses = sets.NewString("other") - assert.False(t, mk().Equals(b)) - - // Different node count. - b = mk() - b.K8sResources.Nodes["node2"] = Node{Pods: map[string]Pod{}} - assert.False(t, mk().Equals(b)) - - // Missing node name. - b = mk() - delete(b.K8sResources.Nodes, "node1") - b.K8sResources.Nodes["nodeX"] = Node{Pods: map[string]Pod{ - "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, - }} - assert.False(t, mk().Equals(b)) - - // Different pod count. - b = mk() - n := b.K8sResources.Nodes["node1"] - n.Pods["10.0.0.2"] = Pod{InboundIdentities: sets.NewString()} - assert.False(t, mk().Equals(b)) - - // Missing pod address. - b = mk() - n = b.K8sResources.Nodes["node1"] - delete(n.Pods, "10.0.0.1") - n.Pods["10.0.0.9"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"} - assert.False(t, mk().Equals(b)) - - // Different InboundIdentities. - b = mk() - n = b.K8sResources.Nodes["node1"] - n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("other"), PublicOutboundIdentity: "egr1"} - assert.False(t, mk().Equals(b)) - - // Different PublicOutboundIdentity. - b = mk() - n = b.K8sResources.Nodes["node1"] - n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "other"} - assert.False(t, mk().Equals(b)) - - // Different NRP LoadBalancers. - b = mk() - b.NRPResources.LoadBalancers = sets.NewString("other") - assert.False(t, mk().Equals(b)) - - // Different NRP NATGateways. - b = mk() - b.NRPResources.NATGateways = sets.NewString("other") - assert.False(t, mk().Equals(b)) - - // Different NRP location count. - b = mk() - b.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} - assert.False(t, mk().Equals(b)) -} - -// TestUpdateK8sPodAddIdempotent covers the alreadyExists branch of updateK8sPodLocked: -// a repeated ADD for the same pod+identity must not double-count the counter. -func TestUpdateK8sPodAddIdempotent(t *testing.T) { - dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} - in := UpdatePodInputType{ - PodOperation: ADD, - PublicOutboundIdentity: "public1", - Location: "node1", - Address: "10.0.0.1", - } - assert.NoError(t, dt.UpdateK8sPod(in)) - // Second identical ADD (e.g. informer resync) must be a no-op for the counter. - assert.NoError(t, dt.UpdateK8sPod(in)) - - val, ok := dt.outboundIdentityPodRefCount.Load("public1") - assert.True(t, ok) - assert.Equal(t, 1, val.(int)) -} - -// TestUpdateK8sEndpointsAddThenRemove covers the OldAddresses removal path of -// updateK8sEndpointsLocked, including empty-pod/node cleanup. -func TestUpdateK8sEndpointsAddThenRemove(t *testing.T) { - dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} - - // Add endpoint for svc1 at pod 10.0.0.1 on node1. - errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ - InboundIdentity: "svc1", - NewAddresses: map[string]string{"10.0.0.1": "node1"}, - }) - assert.Empty(t, errs) - assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("svc1")) - - // Remove the same endpoint (now in OldAddresses, absent from NewAddresses). - errs = dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ - InboundIdentity: "svc1", - OldAddresses: map[string]string{"10.0.0.1": "node1"}, - }) - assert.Empty(t, errs) - // Pod had only svc1 -> pod and node cleaned up. - _, ok := dt.K8sResources.Nodes["node1"] - assert.False(t, ok) -} - -// TestDeepEqualMoreCases covers the node/pod/identity mismatch branches of DeepEqual. -func TestDeepEqualMoreCases(t *testing.T) { - // Helper to build a DiffTracker that is in-sync by construction. - inSync := func() *DiffTracker { - return &DiffTracker{ - K8sResources: K8sState{ - Services: sets.NewString("svc1"), - Egresses: sets.NewString("egr1"), - Nodes: map[string]Node{ - "node1": {Pods: map[string]Pod{ - "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, - }}, - }, - }, - NRPResources: NRPState{ - LoadBalancers: sets.NewString("svc1"), - NATGateways: sets.NewString("egr1"), - Locations: map[string]NRPLocation{ - "node1": {Addresses: map[string]NRPAddress{ - "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, - }}, - }, - }, - } - } - - assert.True(t, inSync().DeepEqual()) - - // LoadBalancer present in NRP but not in K8s Services (reverse-direction check). - d := inSync() - d.NRPResources.LoadBalancers = sets.NewString("svc1", "extra") - d.K8sResources.Services = sets.NewString("svc1", "different") - assert.False(t, d.DeepEqual()) - - // Egress name mismatch (reverse direction). - d = inSync() - d.NRPResources.NATGateways = sets.NewString("other") - d.K8sResources.Egresses = sets.NewString("egr1") - // lengths equal (1==1) but names differ -> mismatch - d.NRPResources.NATGateways = sets.NewString("egr1") - d.K8sResources.Egresses = sets.NewString("egr2") - d.NRPResources.NATGateways = sets.NewString("egr2x") - assert.False(t, d.DeepEqual()) - - // Nodes vs Locations length mismatch. - d = inSync() - d.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} - assert.False(t, d.DeepEqual()) - - // Node missing in Locations (same count, different key). - d = inSync() - delete(d.NRPResources.Locations, "node1") - d.NRPResources.Locations["nodeX"] = NRPLocation{Addresses: map[string]NRPAddress{ - "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, - }} - assert.False(t, d.DeepEqual()) - - // Pods vs Addresses length mismatch. - d = inSync() - loc := d.NRPResources.Locations["node1"] - loc.Addresses["10.0.0.2"] = NRPAddress{Services: sets.NewString("svc1")} - assert.False(t, d.DeepEqual()) - - // Pod missing in Addresses (same count, different key). - d = inSync() - loc = d.NRPResources.Locations["node1"] - delete(loc.Addresses, "10.0.0.1") - loc.Addresses["10.0.0.9"] = NRPAddress{Services: sets.NewString("svc1", "egr1")} - assert.False(t, d.DeepEqual()) - - // Combined identities length mismatch. - d = inSync() - loc = d.NRPResources.Locations["node1"] - loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1")} - assert.False(t, d.DeepEqual()) - - // Identity not found in Services (same count, different identity). - d = inSync() - loc = d.NRPResources.Locations["node1"] - loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1", "egrX")} - assert.False(t, d.DeepEqual()) -} diff --git a/pkg/provider/difftracker/difftracker.go b/pkg/provider/difftracker/difftracker.go index 19fb3c1794..01727b2545 100644 --- a/pkg/provider/difftracker/difftracker.go +++ b/pkg/provider/difftracker/difftracker.go @@ -23,50 +23,50 @@ import ( "k8s.io/klog/v2" "sigs.k8s.io/cloud-provider-azure/pkg/azclient" - utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) -// InitializeDiffTracker creates and initializes a new DiffTracker with the given state and configuration. +// New creates and initializes a new DiffTracker with the given state and configuration. // It validates the configuration and ensures all required dependencies are present. // Returns an error if the configuration is invalid or if any required dependency is nil. -func InitializeDiffTracker(K8s K8sState, NRP NRPState, config Config, networkClientFactory azclient.ClientFactory, kubeClient kubernetes.Interface) (*DiffTracker, error) { +func New(k8s K8sState, nrp NRPState, config Config, networkClientFactory azclient.ClientFactory, kubeClient kubernetes.Interface) (*DiffTracker, error) { if err := config.Validate(); err != nil { - return nil, fmt.Errorf("InitializeDiffTracker: %w", err) + return nil, fmt.Errorf("difftracker.New: %w", err) } if networkClientFactory == nil { - return nil, fmt.Errorf("InitializeDiffTracker: networkClientFactory must not be nil") + return nil, fmt.Errorf("difftracker.New: networkClientFactory must not be nil") } if kubeClient == nil { - return nil, fmt.Errorf("InitializeDiffTracker: kubeClient must not be nil") + return nil, fmt.Errorf("difftracker.New: kubeClient must not be nil") } - klog.V(2).Infof("InitializeDiffTracker: initializing with config: subscription=%s, resourceGroup=%s, location=%s", - config.SubscriptionID, config.ResourceGroup, config.Location) + klog.V(2).Infof("difftracker.New: initializing with config: subscription=%s, resourceGroup=%s, location=%s, serviceGatewayResourceName=%s, serviceGatewayID=%s, vNetName=%s", + config.SubscriptionID, config.ResourceGroup, config.Location, config.ServiceGatewayResourceName, config.ServiceGatewayID, config.VNetName) - // If any field is nil, initialize it - if K8s.Services == nil { - K8s.Services = utilsets.NewString() + // The caller is expected to pass fully initialized state structs. A nil + // field is unexpected and indicates a programming error, so error out. + if k8s.Services == nil { + return nil, fmt.Errorf("difftracker.New: k8s.Services must not be nil") } - if K8s.Egresses == nil { - K8s.Egresses = utilsets.NewString() + if k8s.Egresses == nil { + return nil, fmt.Errorf("difftracker.New: k8s.Egresses must not be nil") } - if K8s.Nodes == nil { - K8s.Nodes = make(map[string]Node) + if k8s.Nodes == nil { + return nil, fmt.Errorf("difftracker.New: k8s.Nodes must not be nil") } - if NRP.LoadBalancers == nil { - NRP.LoadBalancers = utilsets.NewString() + if nrp.LoadBalancers == nil { + return nil, fmt.Errorf("difftracker.New: nrp.LoadBalancers must not be nil") } - if NRP.NATGateways == nil { - NRP.NATGateways = utilsets.NewString() + if nrp.NATGateways == nil { + return nil, fmt.Errorf("difftracker.New: nrp.NATGateways must not be nil") } - if NRP.Locations == nil { - NRP.Locations = make(map[string]NRPLocation) + if nrp.Locations == nil { + return nil, fmt.Errorf("difftracker.New: nrp.Locations must not be nil") } diffTracker := &DiffTracker{ - K8sResources: K8s, - NRPResources: NRP, + K8sResources: k8s, + NRPResources: nrp, // Configuration and clients config: config, diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index 0439e277b3..607eb708e4 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -99,17 +99,17 @@ func TestEnqueueK8sServiceOperation(t *testing.T) { }, } - // Test ADD operation + // Test Add operation err := dt.EnqueueK8sServiceOperation(UpdateK8sResource{ - Operation: ADD, + Operation: Add, ID: "service1", }) assert.NoError(t, err) assert.True(t, dt.K8sResources.Services.Has("service1")) - // Test REMOVE operation + // Test Remove operation err = dt.EnqueueK8sServiceOperation(UpdateK8sResource{ - Operation: REMOVE, + Operation: Remove, ID: "service1", }) assert.NoError(t, err) @@ -117,7 +117,7 @@ func TestEnqueueK8sServiceOperation(t *testing.T) { // Test invalid operation err = dt.EnqueueK8sServiceOperation(UpdateK8sResource{ - Operation: UPDATE, + Operation: Update, ID: "service1", }) assert.Error(t, err) @@ -183,7 +183,7 @@ func TestUpdateK8sPod(t *testing.T) { // Test adding new egress assignment input := UpdatePodInputType{ - PodOperation: ADD, + PodOperation: Add, PublicOutboundIdentity: "public1", Location: "node1", Address: "10.0.0.1", @@ -197,7 +197,7 @@ func TestUpdateK8sPod(t *testing.T) { // Test removing egress assignment input = UpdatePodInputType{ - PodOperation: REMOVE, + PodOperation: Remove, Location: "node1", Address: "10.0.0.1", } @@ -207,17 +207,74 @@ func TestUpdateK8sPod(t *testing.T) { assert.NotContains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") } -// TestUpdateK8sPodRemoveUsesStoredIdentity verifies that a REMOVE whose input omits -// (or mismatches) PublicOutboundIdentity still decrements the counter of the identity -// actually stored on the pod, rather than corrupting a different ("" or wrong) counter. -func TestUpdateK8sPodRemoveUsesStoredIdentity(t *testing.T) { +// TestUpdateK8sPodIdentityChangeDecrementsOld verifies that when a pod's egress +// identity changes (X -> Y), the old identity's ref-counter is decremented so it +// doesn't leak, while the new identity is counted. +func TestUpdateK8sPodIdentityChangeDecrementsOld(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + // Pod starts using egress "X". + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "X", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("x") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + // Pod's egress changes to "Y": X must be released, Y counted. + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Update, + PublicOutboundIdentity: "Y", + Location: "node1", + Address: "10.0.0.1", + })) + _, ok = dt.outboundIdentityPodRefCount.Load("x") + assert.False(t, ok, "old identity X must be decremented on identity change, not leaked") + val, ok = dt.outboundIdentityPodRefCount.Load("y") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) +} + +// TestUpdateK8sPodCaseInsensitiveReAdd verifies that re-adding the same pod with a +// different-cased identity is treated as the same identity (no double counting), +// since the counter is keyed case-insensitively. +func TestUpdateK8sPodCaseInsensitiveReAdd(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "Svc1", + Location: "node1", + Address: "10.0.0.1", + })) + // Informer resync delivers the same identity with different casing. + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "svc1", + Location: "node1", + Address: "10.0.0.1", + })) + + val, ok := dt.outboundIdentityPodRefCount.Load("svc1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int), "case-only re-ADD must not double-count the identity") +} + +// TestUpdateK8sPodRemoveDecrementsCounter verifies that removing a pod decrements +// the outbound-identity ref-counter for the identity passed in the remove input +// (matching the engine's DeletePod, which always passes the service UID), clearing +// the entry when the last pod for that identity is removed. +func TestUpdateK8sPodRemoveDecrementsCounter(t *testing.T) { dt := &DiffTracker{ K8sResources: K8sState{Nodes: map[string]Node{}}, } // Add a pod with outbound identity "public1". assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ - PodOperation: ADD, + PodOperation: Add, PublicOutboundIdentity: "public1", Location: "node1", Address: "10.0.0.1", @@ -226,16 +283,16 @@ func TestUpdateK8sPodRemoveUsesStoredIdentity(t *testing.T) { assert.True(t, ok) assert.Equal(t, 1, val.(int)) - // Remove with the identity omitted from the input (the problematic case). + // Remove passing the same identity; the counter must be cleared. assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ - PodOperation: REMOVE, - Location: "node1", - Address: "10.0.0.1", + PodOperation: Remove, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", })) - // The stored identity's counter must be cleared, and no bogus "" entry created. _, ok = dt.outboundIdentityPodRefCount.Load("public1") - assert.False(t, ok, "counter for stored identity public1 should be removed") + assert.False(t, ok, "counter for public1 should be removed after the last pod is removed") _, ok = dt.outboundIdentityPodRefCount.Load("") assert.False(t, ok, "no counter should be created for an empty identity") } @@ -283,9 +340,9 @@ func TestGetSyncLocationsAddresses(t *testing.T) { } func TestOperation_String(t *testing.T) { - assert.Equal(t, "ADD", ADD.String()) - assert.Equal(t, "REMOVE", REMOVE.String()) - assert.Equal(t, "UPDATE", UPDATE.String()) + assert.Equal(t, "Add", Add.String()) + assert.Equal(t, "Remove", Remove.String()) + assert.Equal(t, "Update", Update.String()) } func TestUpdateNRPLoadBalancers(t *testing.T) { @@ -340,17 +397,17 @@ func TestEnqueueK8sEgressOperation(t *testing.T) { }, } - err := dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: ADD, ID: "egress1"}) + err := dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Add, ID: "egress1"}) assert.NoError(t, err) assert.True(t, dt.K8sResources.Egresses.Has("egress1")) - err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: REMOVE, ID: "egress1"}) + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Remove, ID: "egress1"}) assert.NoError(t, err) assert.False(t, dt.K8sResources.Egresses.Has("egress1")) - err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: UPDATE, ID: "egress1"}) + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Update, ID: "egress1"}) assert.Error(t, err) - assert.Contains(t, err.Error(), "error - ResourceType=Egress, Operation=UPDATE and ID=egress1") + assert.Contains(t, err.Error(), "error - ResourceType=Egress, Operation=Update and ID=egress1") } func TestGetSyncNRPNATGateways(t *testing.T) { @@ -599,7 +656,7 @@ func TestGetSyncOperations(t *testing.T) { // Real Scenario: CloudProvider is down and K8s Cluster is subject to continuous updates. // This test verifies if the DiffTracker is able to sync K8s Cluster and NRP correctly // when there is a huge discrepancy between K8s Cluster and NRP. -func TestInitializeDiffTracker(t *testing.T) { +func TestNew(t *testing.T) { K8sResources := K8sState{ Services: sets.NewString("Service0", "Service1", "Service2"), Egresses: sets.NewString("Egress0", "Egress1", "Egress2"), @@ -653,7 +710,7 @@ func TestInitializeDiffTracker(t *testing.T) { defer ctrl.Finish() mockFactory := mock_azclient.NewMockClientFactory(ctrl) mockKubeClient := fake.NewSimpleClientset() - diffTracker, err := InitializeDiffTracker(K8sResources, NRPResources, config, mockFactory, mockKubeClient) + diffTracker, err := New(K8sResources, NRPResources, config, mockFactory, mockKubeClient) assert.NoError(t, err) syncOperations := diffTracker.GetSyncOperations() @@ -682,3 +739,290 @@ func TestInitializeDiffTracker(t *testing.T) { assert.True(t, diffTracker.Equals(expectedDiffTracker), "DiffTracker does not match expected state") } + +func validTestConfig() Config { + return Config{ + SubscriptionID: "test-subscription", + ResourceGroup: "test-rg", + Location: "eastus", + VNetName: "test-vnet", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } +} + +func emptyK8sState() K8sState { + return K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: make(map[string]Node), + } +} + +func emptyNRPState() NRPState { + return NRPState{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString(), + Locations: make(map[string]NRPLocation), + } +} + +// TestNewErrorPaths covers the validation/error branches of +// New and the successful initialization path. +func TestNewErrorPaths(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockKubeClient := fake.NewSimpleClientset() + + // Invalid config (empty) -> validation error. + _, err := New(K8sState{}, NRPState{}, Config{}, mockFactory, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "New") + + // Nil networkClientFactory. + _, err = New(K8sState{}, NRPState{}, validTestConfig(), nil, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "networkClientFactory must not be nil") + + // Nil kubeClient. + _, err = New(K8sState{}, NRPState{}, validTestConfig(), mockFactory, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "kubeClient must not be nil") + + // Uninitialized state field -> error out instead of silently initializing. + k8sMissingServices := emptyK8sState() + k8sMissingServices.Services = nil + _, err = New(k8sMissingServices, emptyNRPState(), validTestConfig(), mockFactory, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "k8s.Services must not be nil") + + nrpMissingLocations := emptyNRPState() + nrpMissingLocations.Locations = nil + _, err = New(emptyK8sState(), nrpMissingLocations, validTestConfig(), mockFactory, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nrp.Locations must not be nil") + + // Valid call with fully initialized (empty) states. + dt, err := New(emptyK8sState(), emptyNRPState(), validTestConfig(), mockFactory, mockKubeClient) + assert.NoError(t, err) + assert.NotNil(t, dt) + assert.NotNil(t, dt.K8sResources.Services) + assert.NotNil(t, dt.K8sResources.Egresses) + assert.NotNil(t, dt.K8sResources.Nodes) + assert.NotNil(t, dt.NRPResources.LoadBalancers) + assert.NotNil(t, dt.NRPResources.NATGateways) + assert.NotNil(t, dt.NRPResources.Locations) +} + +// TestEnqueueK8sResourceOperationErrors covers the empty-ID and invalid-operation +// branches of enqueueK8sResourceOperation (via the public wrappers). +func TestEnqueueK8sResourceOperationErrors(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{Services: sets.NewString(), Egresses: sets.NewString()}, + } + + // Empty ID. + err := dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: Add, ID: ""}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty ID") + + // Invalid operation (UPDATE is not valid for the resource set). + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Update, ID: "egress1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Operation=Update") + + // Successful Add then Remove. + assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: Add, ID: "svc1"})) + assert.True(t, dt.K8sResources.Services.Has("svc1")) + assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: Remove, ID: "svc1"})) + assert.False(t, dt.K8sResources.Services.Has("svc1")) +} + +// TestUpdateK8sPodInvalidOperation covers the default branch of updateK8sPodLocked. +func TestUpdateK8sPodInvalidOperation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + err := dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Operation(99), + Location: "node1", + Address: "10.0.0.1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid pod operation") +} + +// TestUpdateK8sPodRemoveNonExistent covers the duplicate-removal branch. +func TestUpdateK8sPodRemoveNonExistent(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + // Removing a pod that was never added is a no-op (no error, no counter change). + err := dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + }) + assert.NoError(t, err) + _, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.False(t, ok) +} + +// TestUpdateK8sEndpointsMissingLocation covers the error branch where an address +// has no associated node location. +func TestUpdateK8sEndpointsMissingLocation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": ""}, + }) + assert.NotEmpty(t, errs) + assert.Contains(t, errs[0].Error(), "does not have a node associated") +} + +// TestRemoveServiceFromK8sStateLocked covers removeServiceFromK8sStateLocked for +// both inbound and outbound identities, including empty-pod/node cleanup. +func TestRemoveServiceFromK8sStateLocked(t *testing.T) { + t.Run("inbound removal cleans up empty pods and nodes", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + // Pod backs only svc1 inbound -> becomes empty and is removed. + "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, + // Pod backs svc1 and svc2 -> keeps svc2. + "10.0.0.2": {InboundIdentities: sets.NewString("svc1", "svc2")}, + }}, + }, + }, + } + dt.removeServiceFromK8sStateLocked("svc1", true) + + // 10.0.0.1 had only svc1 -> removed. + _, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] + assert.False(t, ok) + // 10.0.0.2 still has svc2. + pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.2"] + assert.True(t, ok) + assert.True(t, pod.InboundIdentities.Has("svc2")) + assert.False(t, pod.InboundIdentities.Has("svc1")) + }) + + t.Run("outbound removal clears identity and removes empty node", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, + }}, + }, + }, + } + dt.removeServiceFromK8sStateLocked("egress1", false) + + // Pod had only the outbound identity -> pod and node removed. + _, nodeOK := dt.K8sResources.Nodes["node1"] + assert.False(t, nodeOK) + }) + + t.Run("public wrapper acquires lock and removes inbound identity", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, + }}, + }, + }, + } + dt.RemoveServiceFromK8sState("svc1", true) + + _, nodeOK := dt.K8sResources.Nodes["node1"] + assert.False(t, nodeOK) + }) + + t.Run("outbound service removal decrements ref-counter", func(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + // Add an egress pod so the counter is seeded. + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "egress1", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("egress1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + // Removing the egress service must release the counter, not leak it. + dt.RemoveServiceFromK8sState("egress1", false) + _, ok = dt.outboundIdentityPodRefCount.Load("egress1") + assert.False(t, ok, "service removal must decrement the outbound ref-counter") + }) +} + +// TestUpdateK8sPodAddNoOutboundIdentity verifies that adding a pod with no +// outbound access (empty PublicOutboundIdentity) tracks the pod in state but +// does not create a bogus ref-counter entry under the empty-string key. This +// keeps the Add path symmetric with the Remove path, which skips the decrement +// for an empty identity. +func TestUpdateK8sPodAddNoOutboundIdentity(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + in := UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "", + Location: "node1", + Address: "10.0.0.1", + } + assert.NoError(t, dt.UpdateK8sPod(in)) + + // The pod must be tracked in state. + pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] + assert.True(t, ok) + assert.Equal(t, "", pod.PublicOutboundIdentity) + + // No counter entry should be created for the empty identity. + _, ok = dt.outboundIdentityPodRefCount.Load("") + assert.False(t, ok, "no counter should be created for an empty outbound identity") +} + +// TestUpdateK8sPodAddIdempotent covers the alreadyExists branch of updateK8sPodLocked: +// a repeated Add for the same pod+identity must not double-count the counter. +func TestUpdateK8sPodAddIdempotent(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + in := UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + } + assert.NoError(t, dt.UpdateK8sPod(in)) + // Second identical Add (e.g. informer resync) must be a no-op for the counter. + assert.NoError(t, dt.UpdateK8sPod(in)) + + val, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) +} + +// TestUpdateK8sEndpointsAddThenRemove covers the OldAddresses removal path of +// updateK8sEndpointsLocked, including empty-pod/node cleanup. +func TestUpdateK8sEndpointsAddThenRemove(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + // Add endpoint for svc1 at pod 10.0.0.1 on node1. + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("svc1")) + + // Remove the same endpoint (now in OldAddresses, absent from NewAddresses). + errs = dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + OldAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + // Pod had only svc1 -> pod and node cleaned up. + _, ok := dt.K8sResources.Nodes["node1"] + assert.False(t, ok) +} diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go index 02b51c8019..bba35dc6e0 100644 --- a/pkg/provider/difftracker/k8s_state_updates.go +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -30,7 +30,7 @@ const ( ResourceTypeEgress = "Egress" ) -// enqueueK8sResourceOperation applies the requested operation (ADD/REMOVE) to the +// enqueueK8sResourceOperation applies the requested operation (Add/Remove) to the // in-memory K8s resource set. It does not perform any Azure update calls; it only // mutates the local desired-state model that will later be reconciled with NRP. func (dt *DiffTracker) enqueueK8sResourceOperation(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, resourceType string) error { @@ -39,17 +39,19 @@ func (dt *DiffTracker) enqueueK8sResourceOperation(input UpdateK8sResource, set } switch input.Operation { - case ADD: + case Add: set.Insert(input.ID) - case REMOVE: + klog.V(2).Infof("enqueueK8sResourceOperation: Added %s %s to K8s state", resourceType, input.ID) + case Remove: set.Delete(input.ID) + klog.V(2).Infof("enqueueK8sResourceOperation: Removed %s %s from K8s state", resourceType, input.ID) default: return fmt.Errorf("error - ResourceType=%s, Operation=%s and ID=%s", resourceType, input.Operation, input.ID) } return nil } -// EnqueueK8sServiceOperation records a service ADD/REMOVE in the local K8s state set. +// EnqueueK8sServiceOperation records a service Add/Remove in the local K8s state set. // The change is reconciled with NRP later by the sync operations; this method itself // performs no Azure calls. func (dt *DiffTracker) EnqueueK8sServiceOperation(input UpdateK8sResource) error { @@ -59,7 +61,7 @@ func (dt *DiffTracker) EnqueueK8sServiceOperation(input UpdateK8sResource) error return dt.enqueueK8sResourceOperation(input, dt.K8sResources.Services, ResourceTypeService) } -// EnqueueK8sEgressOperation records an egress ADD/REMOVE in the local K8s state set. +// EnqueueK8sEgressOperation records an egress Add/Remove in the local K8s state set. // The change is reconciled with NRP later by the sync operations; this method itself // performs no Azure calls. func (dt *DiffTracker) EnqueueK8sEgressOperation(input UpdateK8sResource) error { @@ -87,20 +89,17 @@ func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputTyp nodeState, exists := dt.K8sResources.Nodes[location] if !exists { - nodeState = Node{ - Pods: make(map[string]Pod), - } + nodeState = newNode() dt.K8sResources.Nodes[location] = nodeState } pod, exists := nodeState.Pods[address] if !exists { - pod = Pod{ - InboundIdentities: utilsets.NewString(), - } + pod = newPod() nodeState.Pods[address] = pod } pod.InboundIdentities.Insert(input.InboundIdentity) + klog.V(2).Infof("updateK8sEndpointsLocked: Added inbound identity %s to pod %s on node %s", input.InboundIdentity, address, location) } for address, location := range input.OldAddresses { @@ -110,6 +109,7 @@ func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputTyp if location == "" { errs = append(errs, fmt.Errorf("error UpdateK8sEndpoints, address=%s does not have a node associated", address)) + continue } node, nodeExists := dt.K8sResources.Nodes[location] @@ -123,6 +123,7 @@ func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputTyp } pod.InboundIdentities.Delete(input.InboundIdentity) + klog.V(2).Infof("updateK8sEndpointsLocked: Removed inbound identity %s from pod %s on node %s", input.InboundIdentity, address, location) if !pod.HasIdentities() { delete(node.Pods, address) @@ -148,103 +149,128 @@ func (dt *DiffTracker) UpdateK8sEndpoints(input UpdateK8sEndpointsInputType) []e func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) error { node, exists := dt.K8sResources.Nodes[input.Location] if !exists { - node = Node{Pods: make(map[string]Pod)} + node = newNode() dt.K8sResources.Nodes[input.Location] = node } pod, exists := node.Pods[input.Address] if !exists { - pod = Pod{InboundIdentities: utilsets.NewString()} + pod = newPod() } pod.PublicOutboundIdentity = input.PublicOutboundIdentity node.Pods[input.Address] = pod + klog.V(2).Infof("addOrUpdatePod: Set outbound identity %q for pod %s on node %s", input.PublicOutboundIdentity, input.Address, input.Location) return nil } -// removePod removes a pod from K8s state. Returns true and the removed pod's stored -// PublicOutboundIdentity if the pod was actually removed, or false if it didn't exist -// (already removed by a previous call). The returned identity is authoritative for -// reference-counter bookkeeping, independent of whatever the caller passed in input. -func (dt *DiffTracker) removePod(input UpdatePodInputType) (removed bool, identity string, err error) { +// removePod removes a pod from K8s state. Returns true if the pod was actually +// removed, or false if it didn't exist (already removed by a previous call). +func (dt *DiffTracker) removePod(input UpdatePodInputType) bool { node, exists := dt.K8sResources.Nodes[input.Location] if !exists { - return false, "", nil + return false } - // Check if pod exists before removing - pod, podExists := node.Pods[input.Address] - if !podExists { - return false, "", nil + if _, podExists := node.Pods[input.Address]; !podExists { + return false } - identity = pod.PublicOutboundIdentity delete(node.Pods, input.Address) if !node.HasPods() { delete(dt.K8sResources.Nodes, input.Location) } + klog.V(2).Infof("removePod: Removed pod %s from node %s", input.Address, input.Location) - return true, identity, nil + return true +} + +// incrementOutboundRefCount increments the ref-counter for a pod's outbound +// (egress) identity. Empty identities are not counted. +func (dt *DiffTracker) incrementOutboundRefCount(identity string) { + if identity == "" { + return + } + key := strings.ToLower(identity) + counter := 0 + if val, ok := dt.outboundIdentityPodRefCount.Load(key); ok { + counter = val.(int) + } + dt.outboundIdentityPodRefCount.Store(key, counter+1) +} + +// decrementOutboundRefCount decrements the ref-counter for an outbound (egress) +// identity, deleting the entry when it reaches zero. Empty or unknown identities +// are a no-op. Returns an error if the counter is already non-positive. +func (dt *DiffTracker) decrementOutboundRefCount(identity string) error { + if identity == "" { + return nil + } + key := strings.ToLower(identity) + val, ok := dt.outboundIdentityPodRefCount.Load(key) + if !ok { + return nil + } + counter := val.(int) + if counter <= 0 { + return fmt.Errorf("error - PublicOutboundIdentity %s has a non-positive count: %d", identity, counter) + } + if counter == 1 { + dt.outboundIdentityPodRefCount.Delete(key) + } else { + dt.outboundIdentityPodRefCount.Store(key, counter-1) + } + return nil } // updateK8sPodLocked updates K8s pod state. Assumes lock is already held. func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { switch input.PodOperation { - case ADD, UPDATE: - // Check if pod already exists with the same outbound identity - // This prevents double-counting when pod informer fires AddFunc for pods - // that were already counted during initialization + case Add, Update: + // Determine the pod's current (old) outbound identity, if any, and whether + // this exact pod+identity is already counted (idempotent re-ADD, e.g. an + // informer resync). The comparison is case-insensitive because the counter + // is keyed on the lowercased identity. + oldIdentity := "" alreadyExists := false if node, nodeExists := dt.K8sResources.Nodes[input.Location]; nodeExists { if pod, podExists := node.Pods[input.Address]; podExists { - if pod.PublicOutboundIdentity == input.PublicOutboundIdentity { + oldIdentity = pod.PublicOutboundIdentity + if strings.EqualFold(oldIdentity, input.PublicOutboundIdentity) { alreadyExists = true - klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s already exists for service %s, skipping counter increment", + klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s already exists for service %s, skipping counter update", input.Location, input.Address, input.PublicOutboundIdentity) } } } - // Only increment counter if pod doesn't already exist if !alreadyExists { - counter := 0 - if val, ok := dt.outboundIdentityPodRefCount.Load(strings.ToLower(input.PublicOutboundIdentity)); ok { - counter = val.(int) + // The pod's egress identity is changing (old -> new). Release the old + // identity's count before counting the new one, otherwise the old + // identity's counter would leak (a later remove decrements the new + // identity, never the old). + if !strings.EqualFold(oldIdentity, input.PublicOutboundIdentity) { + if err := dt.decrementOutboundRefCount(oldIdentity); err != nil { + return err + } } - dt.outboundIdentityPodRefCount.Store(strings.ToLower(input.PublicOutboundIdentity), counter+1) + dt.incrementOutboundRefCount(input.PublicOutboundIdentity) } return dt.addOrUpdatePod(input) - case REMOVE: - // First, try to remove the pod from K8s state - // This returns false if the pod doesn't exist (duplicate removal) - removed, identity, err := dt.removePod(input) - if err != nil { - return err - } - if !removed { - // Pod didn't exist - this is a duplicate removal, don't decrement counter + case Remove: + // First, try to remove the pod from K8s state. + // removePod returns false if the pod doesn't exist (duplicate removal). + if !dt.removePod(input) { klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s was already removed (duplicate delete), skipping counter decrement", input.Location, input.Address) return nil } - // Decrement using the pod's stored identity (authoritative), not input. - if identity == "" { - return nil - } - if val, ok := dt.outboundIdentityPodRefCount.Load(strings.ToLower(identity)); ok { - counter := val.(int) - if counter <= 0 { - return fmt.Errorf("error - PublicOutboundIdentity %s has a non-positive count: %d", identity, counter) - } - if counter == 1 { - dt.outboundIdentityPodRefCount.Delete(strings.ToLower(identity)) - } else { - dt.outboundIdentityPodRefCount.Store(strings.ToLower(identity), counter-1) - } - } - return nil + // Decrement the ref-counter for the outbound identity passed in the + // remove input. A missing key (e.g. an empty identity, which is never + // counted) is a no-op. + return dt.decrementOutboundRefCount(input.PublicOutboundIdentity) default: return fmt.Errorf("invalid pod operation: %s for pod at %s:%s", input.PodOperation, input.Location, input.Address) @@ -274,10 +300,15 @@ func (dt *DiffTracker) removeServiceFromK8sStateLocked(serviceUID string, isInbo pod.InboundIdentities.Delete(serviceUID) } } else { - // Clear outbound identity if it matches + // Clear outbound identity if it matches, releasing its ref-count so + // the counter doesn't leak when a service is deleted before its pods + // are removed. if strings.EqualFold(pod.PublicOutboundIdentity, serviceUID) { pod.PublicOutboundIdentity = "" node.Pods[podIP] = pod + if err := dt.decrementOutboundRefCount(serviceUID); err != nil { + klog.Warningf("removeServiceFromK8sStateLocked: %v", err) + } } } diff --git a/pkg/provider/difftracker/nrp_state_updates.go b/pkg/provider/difftracker/nrp_state_updates.go index 5b59e69d7b..94d427493c 100644 --- a/pkg/provider/difftracker/nrp_state_updates.go +++ b/pkg/provider/difftracker/nrp_state_updates.go @@ -18,8 +18,6 @@ package difftracker import ( "k8s.io/klog/v2" - - utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) func (dt *DiffTracker) UpdateNRPLoadBalancers(syncServicesReturnType SyncServicesReturnType) { @@ -28,12 +26,12 @@ func (dt *DiffTracker) UpdateNRPLoadBalancers(syncServicesReturnType SyncService for _, service := range syncServicesReturnType.Additions.UnsortedList() { dt.NRPResources.LoadBalancers.Insert(service) - klog.V(2).Infof("UpdateNRPLoadBalancers: Added service %s to NRP LoadBalancers\n", service) + klog.V(2).Infof("UpdateNRPLoadBalancers: Added service %s to NRP LoadBalancers", service) } for _, service := range syncServicesReturnType.Removals.UnsortedList() { dt.NRPResources.LoadBalancers.Delete(service) - klog.V(2).Infof("UpdateNRPLoadBalancers: Removed service %s from NRP LoadBalancers\n", service) + klog.V(2).Infof("UpdateNRPLoadBalancers: Removed service %s from NRP LoadBalancers", service) } } @@ -43,12 +41,12 @@ func (dt *DiffTracker) UpdateNRPNATGateways(syncServicesReturnType SyncServicesR for _, service := range syncServicesReturnType.Additions.UnsortedList() { dt.NRPResources.NATGateways.Insert(service) - klog.V(2).Infof("UpdateNRPNATGateways: Added service %s to NRP NATGateways\n", service) + klog.V(2).Infof("UpdateNRPNATGateways: Added service %s to NRP NATGateways", service) } for _, service := range syncServicesReturnType.Removals.UnsortedList() { dt.NRPResources.NATGateways.Delete(service) - klog.V(2).Infof("UpdateNRPNATGateways: Removed service %s from NRP NATGateways\n", service) + klog.V(2).Infof("UpdateNRPNATGateways: Removed service %s from NRP NATGateways", service) } } @@ -108,18 +106,3 @@ func (dt *DiffTracker) UpdateLocationsAddresses(locationData LocationData) { } } } - -// Helper function to check if address exists in a location -func addressExists(location NRPLocation, addressKey string) bool { - _, exists := location.Addresses[addressKey] - return exists -} - -// Helper function to create service references from an address -func createServiceRefsFromAddress(addressValue Address) *utilsets.IgnoreCaseSet { - serviceRefs := utilsets.NewString() - for _, service := range addressValue.ServiceRef.UnsortedList() { - serviceRefs.Insert(service) - } - return serviceRefs -} diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go index 19edd01b93..ae6e13128a 100644 --- a/pkg/provider/difftracker/sync_operations.go +++ b/pkg/provider/difftracker/sync_operations.go @@ -28,25 +28,13 @@ func GetServicesToSync(k8sServices, nrpServices *utilsets.IgnoreCaseSet) SyncSer klog.V(2).Infof("GetServicesToSync: NRP services (%d): %v", nrpServices.Len(), nrpServices.UnsortedList()) syncServices := SyncServicesReturnType{ - Additions: utilsets.NewString(), - Removals: utilsets.NewString(), - } - - for _, service := range k8sServices.UnsortedList() { - if nrpServices.Has(service) { - continue - } - syncServices.Additions.Insert(service) - klog.V(4).Infof("GetServicesToSync: Added service %s to additions", service) - } - - for _, service := range nrpServices.UnsortedList() { - if k8sServices.Has(service) { - continue - } - syncServices.Removals.Insert(service) - klog.V(4).Infof("GetServicesToSync: Added service %s to removals", service) + // Additions are in K8s but not yet in NRP; removals are in NRP but no + // longer in K8s. + Additions: k8sServices.Difference(nrpServices), + Removals: nrpServices.Difference(k8sServices), } + klog.V(4).Infof("GetServicesToSync: additions=%v, removals=%v", + syncServices.Additions.UnsortedList(), syncServices.Removals.UnsortedList()) klog.V(2).Infof("GetServicesToSync: Result - Additions: %d, Removals: %d", syncServices.Additions.Len(), syncServices.Removals.Len()) return syncServices @@ -101,16 +89,16 @@ func (dt *DiffTracker) getSyncLocationsAddressesLocked() LocationData { serviceRef := dt.createServiceRefFiltered(pod) // Check if address exists in NRP and if service list changed - nrpAddressData, addressExists := nrpLocation.Addresses[address] + nrpAddressData, nrpAddrExists := nrpLocation.Addresses[address] // Skip this address if: // 1. No ready services AND address doesn't exist in NRP (nothing to sync) // 2. ServiceRef matches what's already in NRP (no change) - if serviceRef.Len() == 0 && !addressExists { + if serviceRef.Len() == 0 && !nrpAddrExists { continue } - if addressExists && serviceRef.Equals(nrpAddressData.Services) { + if nrpAddrExists && serviceRef.Equals(nrpAddressData.Services) { continue } @@ -198,9 +186,10 @@ func (dt *DiffTracker) isServiceReady(serviceUID string, isInbound bool) bool { return dt.NRPResources.NATGateways.Has(serviceUID) } -// Helper function to find LocationData in result -func findLocationData(result LocationData, location string) *Location { - if loc, ok := result.Locations[location]; ok { +// findLocationData returns a pointer to the Location stored under the given key +// in data, or nil if no such location exists. +func findLocationData(data LocationData, location string) *Location { + if loc, ok := data.Locations[location]; ok { return &loc } return nil diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go index a39161e0fc..d99cb0ead8 100644 --- a/pkg/provider/difftracker/types.go +++ b/pkg/provider/difftracker/types.go @@ -31,22 +31,25 @@ import ( type Operation int const ( - ADD Operation = iota - REMOVE - UPDATE + UnknownOperation Operation = iota + Add + Remove + Update ) type UpdateAction int const ( - PartialUpdate UpdateAction = iota + UnknownUpdateAction UpdateAction = iota + PartialUpdate FullUpdate ) type SyncStatus int const ( - AlreadyInSync SyncStatus = iota + UnknownSyncStatus SyncStatus = iota + AlreadyInSync Success ) @@ -92,10 +95,20 @@ type Pod struct { PublicOutboundIdentity string } +// newPod returns a Pod with its InboundIdentities set initialized. +func newPod() Pod { + return Pod{InboundIdentities: utilsets.NewString()} +} + type Node struct { Pods map[string]Pod } +// newNode returns a Node with its Pods map initialized. +func newNode() Node { + return Node{Pods: make(map[string]Pod)} +} + type K8sState struct { Services *utilsets.IgnoreCaseSet Egresses *utilsets.IgnoreCaseSet @@ -157,12 +170,12 @@ type Address struct { // Location uses a map for Addresses type Location struct { AddressUpdateAction UpdateAction - Addresses map[string]Address // key is Address.Address + Addresses map[string]Address // key is the pod IP } type LocationData struct { Action UpdateAction - Locations map[string]Location // key is Location.Location + Locations map[string]Location // key is the node IP } type SyncServicesReturnType struct { diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go index 40d66db603..a4426317e7 100644 --- a/pkg/provider/difftracker/util.go +++ b/pkg/provider/difftracker/util.go @@ -19,18 +19,23 @@ package difftracker import ( "encoding/json" "fmt" + "reflect" "k8s.io/klog/v2" + + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) func (operation Operation) String() string { switch operation { - case ADD: - return "ADD" - case REMOVE: - return "REMOVE" - case UPDATE: - return "UPDATE" + case UnknownOperation: + return "UnknownOperation" + case Add: + return "Add" + case Remove: + return "Remove" + case Update: + return "Update" default: return fmt.Sprintf("Operation(%d)", int(operation)) } @@ -42,6 +47,8 @@ func (operation Operation) MarshalJSON() ([]byte, error) { func (updateAction UpdateAction) String() string { switch updateAction { + case UnknownUpdateAction: + return "UnknownUpdateAction" case PartialUpdate: return "PartialUpdate" case FullUpdate: @@ -75,6 +82,8 @@ func (updateAction *UpdateAction) UnmarshalJSON(data []byte) error { func (syncStatus SyncStatus) String() string { switch syncStatus { + case UnknownSyncStatus: + return "UnknownSyncStatus" case AlreadyInSync: return "AlreadyInSync" case Success: @@ -104,47 +113,22 @@ func (dt *DiffTracker) DeepEqual() bool { // deepEqualLocked is the lock-free body of DeepEqual. Callers must hold dt.mu. func (dt *DiffTracker) deepEqualLocked() bool { - klog.V(4).Infof("DeepEqual: Checking equality - K8s Services=%d, NRP LoadBalancers=%d, K8s Egresses=%d, NRP NATGateways=%d", + klog.V(4).Infof("DeepEqual: Checking equality - K8s Services=%d, NRP LoadBalancers=%d, K8s Egresses=%d, NRP NATGateways=%d, K8s Nodes=%d, NRP Locations=%d", dt.K8sResources.Services.Len(), dt.NRPResources.LoadBalancers.Len(), - dt.K8sResources.Egresses.Len(), dt.NRPResources.NATGateways.Len()) + dt.K8sResources.Egresses.Len(), dt.NRPResources.NATGateways.Len(), + len(dt.K8sResources.Nodes), len(dt.NRPResources.Locations)) - // Compare Services with LoadBalancers - if dt.K8sResources.Services.Len() != dt.NRPResources.LoadBalancers.Len() { - klog.V(4).Infof("DeepEqual: Services and LoadBalancers length mismatch") + // Compare Services with LoadBalancers and Egresses with NATGateways. + if !dt.K8sResources.Services.Equals(dt.NRPResources.LoadBalancers) { + klog.V(4).Infof("DeepEqual: Services and LoadBalancers mismatch") return false } - for _, service := range dt.K8sResources.Services.UnsortedList() { - if !dt.NRPResources.LoadBalancers.Has(service) { - klog.V(4).Infof("DeepEqual: Service %s not found in LoadBalancers", service) - return false - } - } - for _, service := range dt.NRPResources.LoadBalancers.UnsortedList() { - if !dt.K8sResources.Services.Has(service) { - klog.V(4).Infof("DeepEqual: LoadBalancer %s not found in Services", service) - return false - } - } - - // Compare Egresses with NATGateways - if dt.K8sResources.Egresses.Len() != dt.NRPResources.NATGateways.Len() { - klog.V(4).Infof("DeepEqual: Egresses and NATGateways length mismatch") + if !dt.K8sResources.Egresses.Equals(dt.NRPResources.NATGateways) { + klog.V(4).Infof("DeepEqual: Egresses and NATGateways mismatch") return false } - for _, egress := range dt.K8sResources.Egresses.UnsortedList() { - if !dt.NRPResources.NATGateways.Has(egress) { - klog.V(4).Infof("DeepEqual: Egress %s not found in NATGateways", egress) - return false - } - } - for _, egress := range dt.NRPResources.NATGateways.UnsortedList() { - if !dt.K8sResources.Egresses.Has(egress) { - klog.V(4).Infof("DeepEqual: NATGateway %s not found in Egresses", egress) - return false - } - } - // Compare Nodes with Locations + // Compare Nodes with Locations. if len(dt.K8sResources.Nodes) != len(dt.NRPResources.Locations) { klog.V(4).Infof("DeepEqual: Nodes and Locations length mismatch") return false @@ -152,48 +136,39 @@ func (dt *DiffTracker) deepEqualLocked() bool { for nodeKey, node := range dt.K8sResources.Nodes { nrpLocation, exists := dt.NRPResources.Locations[nodeKey] if !exists { - klog.V(4).Infof("DeepEqual: Node %s not found in Locations\n", nodeKey) + klog.V(4).Infof("DeepEqual: Node %s not found in Locations", nodeKey) return false } - // Compare Pods with Addresses + // Compare Pods with Addresses. if len(node.Pods) != len(nrpLocation.Addresses) { - klog.V(4).Infof("DeepEqual: Pods and Addresses length mismatch for node %s\n", nodeKey) + klog.V(4).Infof("DeepEqual: Pods and Addresses length mismatch for node %s", nodeKey) return false } for podKey, pod := range node.Pods { nrpAddress, exists := nrpLocation.Addresses[podKey] if !exists { - klog.V(4).Infof("DeepEqual: Pod %s not found in Addresses for node %s\n", podKey, nodeKey) + klog.V(4).Infof("DeepEqual: Pod %s not found in Addresses for node %s", podKey, nodeKey) return false } - // Compare [...InboundIdentities, PublicOutboundIdentity] with Services - combinedIdentities := []string{} - combinedIdentities = append(combinedIdentities, pod.InboundIdentities.UnsortedList()...) + // Compare [...InboundIdentities, PublicOutboundIdentity] with Services. + combinedIdentities := utilsets.NewString(pod.InboundIdentities.UnsortedList()...) if pod.PublicOutboundIdentity != "" { - combinedIdentities = append(combinedIdentities, pod.PublicOutboundIdentity) + combinedIdentities.Insert(pod.PublicOutboundIdentity) } - - if len(combinedIdentities) != nrpAddress.Services.Len() { - klog.V(4).Infof("DeepEqual: Combined identities length mismatch for pod %s in node %s\n", podKey, nodeKey) + if !combinedIdentities.Equals(nrpAddress.Services) { + klog.V(4).Infof("DeepEqual: Identities and Services mismatch for pod %s in node %s", podKey, nodeKey) return false } - - for _, identity := range combinedIdentities { - if !nrpAddress.Services.Has(identity) { - klog.V(4).Infof("DeepEqual: Identity %s not found in Services for pod %s in node %s\n", identity, podKey, nodeKey) - return false - } - } } } return true } -func (syncServicesReturnType *SyncServicesReturnType) Equals(other *SyncServicesReturnType) bool { - return syncServicesReturnType.Additions.Equals(other.Additions) && syncServicesReturnType.Removals.Equals(other.Removals) +func (s *SyncServicesReturnType) Equals(other *SyncServicesReturnType) bool { + return s.Additions.Equals(other.Additions) && s.Removals.Equals(other.Removals) } // Equals compares two LocationData objects for equality @@ -236,28 +211,28 @@ func (ld *LocationData) Equals(other *LocationData) bool { } // Equals compares two SyncDiffTrackerReturnType objects for equality -func (sdts *SyncDiffTrackerReturnType) Equals(other *SyncDiffTrackerReturnType) bool { - if sdts.SyncStatus != other.SyncStatus { +func (s *SyncDiffTrackerReturnType) Equals(other *SyncDiffTrackerReturnType) bool { + if s.SyncStatus != other.SyncStatus { return false } - if !sdts.LoadBalancerUpdates.Additions.Equals(other.LoadBalancerUpdates.Additions) { + if !s.LoadBalancerUpdates.Additions.Equals(other.LoadBalancerUpdates.Additions) { return false } - if !sdts.LoadBalancerUpdates.Removals.Equals(other.LoadBalancerUpdates.Removals) { + if !s.LoadBalancerUpdates.Removals.Equals(other.LoadBalancerUpdates.Removals) { return false } - if !sdts.NATGatewayUpdates.Additions.Equals(other.NATGatewayUpdates.Additions) { + if !s.NATGatewayUpdates.Additions.Equals(other.NATGatewayUpdates.Additions) { return false } - if !sdts.NATGatewayUpdates.Removals.Equals(other.NATGatewayUpdates.Removals) { + if !s.NATGatewayUpdates.Removals.Equals(other.NATGatewayUpdates.Removals) { return false } - if !sdts.LocationData.Equals(&other.LocationData) { + if !s.LocationData.Equals(&other.LocationData) { return false } @@ -266,11 +241,19 @@ func (sdts *SyncDiffTrackerReturnType) Equals(other *SyncDiffTrackerReturnType) // Equals compares two DiffTracker objects for equality func (dt *DiffTracker) Equals(other *DiffTracker) bool { - dt.mu.Lock() - defer dt.mu.Unlock() - - other.mu.Lock() - defer other.mu.Unlock() + // Lock both trackers in a consistent order (by pointer address) so that + // concurrent dt.Equals(other) and other.Equals(dt) calls can't deadlock. + // If both refer to the same object, lock only once to avoid self-deadlock. + first, second := dt, other + if reflect.ValueOf(first).Pointer() > reflect.ValueOf(second).Pointer() { + first, second = second, first + } + first.mu.Lock() + defer first.mu.Unlock() + if second != first { + second.mu.Lock() + defer second.mu.Unlock() + } if !dt.K8sResources.Services.Equals(other.K8sResources.Services) { return false @@ -347,3 +330,18 @@ func (dt *DiffTracker) Equals(other *DiffTracker) bool { return true } + +// addressExists reports whether an address with the given key exists in a location. +func addressExists(location NRPLocation, addressKey string) bool { + _, exists := location.Addresses[addressKey] + return exists +} + +// createServiceRefsFromAddress returns a copy of the address's service references. +func createServiceRefsFromAddress(addressValue Address) *utilsets.IgnoreCaseSet { + serviceRefs := utilsets.NewString() + for _, service := range addressValue.ServiceRef.UnsortedList() { + serviceRefs.Insert(service) + } + return serviceRefs +} diff --git a/pkg/provider/difftracker/util_test.go b/pkg/provider/difftracker/util_test.go index 42f8168837..802796b398 100644 --- a/pkg/provider/difftracker/util_test.go +++ b/pkg/provider/difftracker/util_test.go @@ -32,9 +32,9 @@ func TestOperationStringAndJSON(t *testing.T) { op Operation expected string }{ - {"ADD operation", ADD, "ADD"}, - {"REMOVE operation", REMOVE, "REMOVE"}, - {"UPDATE operation", UPDATE, "UPDATE"}, + {"Add operation", Add, "Add"}, + {"Remove operation", Remove, "Remove"}, + {"Update operation", Update, "Update"}, } for _, tt := range tests { @@ -659,10 +659,10 @@ func TestJSONRoundTrip(t *testing.T) { }) t.Run("Operation round trip", func(t *testing.T) { - original := ADD + original := Add data, err := json.Marshal(original) assert.NoError(t, err) - assert.Equal(t, `"ADD"`, string(data)) + assert.Equal(t, `"Add"`, string(data)) }) t.Run("SyncStatus round trip", func(t *testing.T) { @@ -672,3 +672,296 @@ func TestJSONRoundTrip(t *testing.T) { assert.Equal(t, `"Success"`, string(data)) }) } + +// TestLocationDataEqualsMoreCases covers additional LocationData.Equals branches. +func TestLocationDataEqualsMoreCases(t *testing.T) { + base := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, + }, + }, + } + equal := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, + }, + }, + } + assert.True(t, base.Equals(&equal)) + + // Different top-level Action. + diffAction := equal + diffAction.Action = FullUpdate + assert.False(t, base.Equals(&diffAction)) + + // Different number of locations. + diffLen := LocationData{Action: PartialUpdate, Locations: map[string]Location{}} + assert.False(t, base.Equals(&diffLen)) + + // Missing location name. + diffName := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{"node2": base.Locations["node1"]}, + } + assert.False(t, base.Equals(&diffName)) + + // Different AddressUpdateAction. + diffAUA := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: FullUpdate, Addresses: base.Locations["node1"].Addresses}, + }, + } + assert.False(t, base.Equals(&diffAUA)) + + // Different addresses length. + diffAddrLen := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + } + assert.False(t, base.Equals(&diffAddrLen)) + + // Missing address name. + diffAddrName := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.2": {ServiceRef: sets.NewString("svc1")}}}, + }, + } + assert.False(t, base.Equals(&diffAddrName)) + + // Different ServiceRef. + diffRef := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc2")}}}, + }, + } + assert.False(t, base.Equals(&diffRef)) +} + +// TestSyncDiffTrackerReturnTypeEquals covers SyncDiffTrackerReturnType.Equals branches. +func TestSyncDiffTrackerReturnTypeEquals(t *testing.T) { + mk := func() SyncDiffTrackerReturnType { + return SyncDiffTrackerReturnType{ + SyncStatus: Success, + LoadBalancerUpdates: SyncServicesReturnType{Additions: sets.NewString("a"), Removals: sets.NewString("b")}, + NATGatewayUpdates: SyncServicesReturnType{Additions: sets.NewString("c"), Removals: sets.NewString("d")}, + LocationData: LocationData{Action: PartialUpdate, Locations: map[string]Location{}}, + } + } + + base := mk() + equal := mk() + assert.True(t, base.Equals(&equal)) + + // Different SyncStatus. + diffStatus := mk() + diffStatus.SyncStatus = AlreadyInSync + assert.False(t, base.Equals(&diffStatus)) + + // Different LB additions. + diffLBAdd := mk() + diffLBAdd.LoadBalancerUpdates.Additions = sets.NewString("x") + assert.False(t, base.Equals(&diffLBAdd)) + + // Different LB removals. + diffLBRem := mk() + diffLBRem.LoadBalancerUpdates.Removals = sets.NewString("x") + assert.False(t, base.Equals(&diffLBRem)) + + // Different NATGW additions. + diffNGAdd := mk() + diffNGAdd.NATGatewayUpdates.Additions = sets.NewString("x") + assert.False(t, base.Equals(&diffNGAdd)) + + // Different NATGW removals. + diffNGRem := mk() + diffNGRem.NATGatewayUpdates.Removals = sets.NewString("x") + assert.False(t, base.Equals(&diffNGRem)) + + // Different LocationData. + diffLoc := mk() + diffLoc.LocationData.Action = FullUpdate + assert.False(t, base.Equals(&diffLoc)) +} + +// TestDiffTrackerEqualsMoreCases covers additional DiffTracker.Equals branches. +func TestDiffTrackerEqualsMoreCases(t *testing.T) { + mk := func() *DiffTracker { + return &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egr1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1"), + NATGateways: sets.NewString("egr1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }}, + }, + }, + } + } + + assert.True(t, mk().Equals(mk())) + + // Different Services. + a := mk() + b := mk() + b.K8sResources.Services = sets.NewString("other") + assert.False(t, a.Equals(b)) + + // Different Egresses. + b = mk() + b.K8sResources.Egresses = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different node count. + b = mk() + b.K8sResources.Nodes["node2"] = Node{Pods: map[string]Pod{}} + assert.False(t, mk().Equals(b)) + + // Missing node name. + b = mk() + delete(b.K8sResources.Nodes, "node1") + b.K8sResources.Nodes["nodeX"] = Node{Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }} + assert.False(t, mk().Equals(b)) + + // Different pod count. + b = mk() + n := b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.2"] = Pod{InboundIdentities: sets.NewString()} + assert.False(t, mk().Equals(b)) + + // Missing pod address. + b = mk() + n = b.K8sResources.Nodes["node1"] + delete(n.Pods, "10.0.0.1") + n.Pods["10.0.0.9"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"} + assert.False(t, mk().Equals(b)) + + // Different InboundIdentities. + b = mk() + n = b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("other"), PublicOutboundIdentity: "egr1"} + assert.False(t, mk().Equals(b)) + + // Different PublicOutboundIdentity. + b = mk() + n = b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "other"} + assert.False(t, mk().Equals(b)) + + // Different NRP LoadBalancers. + b = mk() + b.NRPResources.LoadBalancers = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different NRP NATGateways. + b = mk() + b.NRPResources.NATGateways = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different NRP location count. + b = mk() + b.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} + assert.False(t, mk().Equals(b)) +} + +// TestDeepEqualMoreCases covers the node/pod/identity mismatch branches of DeepEqual. +func TestDeepEqualMoreCases(t *testing.T) { + // Helper to build a DiffTracker that is in-sync by construction. + inSync := func() *DiffTracker { + return &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egr1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1"), + NATGateways: sets.NewString("egr1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }}, + }, + }, + } + } + + assert.True(t, inSync().DeepEqual()) + + // LoadBalancer present in NRP but not in K8s Services (reverse-direction check). + d := inSync() + d.NRPResources.LoadBalancers = sets.NewString("svc1", "extra") + d.K8sResources.Services = sets.NewString("svc1", "different") + assert.False(t, d.DeepEqual()) + + // Egress name mismatch (reverse direction): lengths equal (1==1) but names + // differ -> mismatch. + d = inSync() + d.K8sResources.Egresses = sets.NewString("egr2") + d.NRPResources.NATGateways = sets.NewString("egr2x") + assert.False(t, d.DeepEqual()) + + // Nodes vs Locations length mismatch. + d = inSync() + d.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} + assert.False(t, d.DeepEqual()) + + // Node missing in Locations (same count, different key). + d = inSync() + delete(d.NRPResources.Locations, "node1") + d.NRPResources.Locations["nodeX"] = NRPLocation{Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }} + assert.False(t, d.DeepEqual()) + + // Pods vs Addresses length mismatch. + d = inSync() + loc := d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.2"] = NRPAddress{Services: sets.NewString("svc1")} + assert.False(t, d.DeepEqual()) + + // Pod missing in Addresses (same count, different key). + d = inSync() + loc = d.NRPResources.Locations["node1"] + delete(loc.Addresses, "10.0.0.1") + loc.Addresses["10.0.0.9"] = NRPAddress{Services: sets.NewString("svc1", "egr1")} + assert.False(t, d.DeepEqual()) + + // Combined identities length mismatch. + d = inSync() + loc = d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1")} + assert.False(t, d.DeepEqual()) + + // Identity not found in Services (same count, different identity). + d = inSync() + loc = d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1", "egrX")} + assert.False(t, d.DeepEqual()) +} diff --git a/pkg/util/sets/string.go b/pkg/util/sets/string.go index a9f79a407f..5a1c9b1da1 100644 --- a/pkg/util/sets/string.go +++ b/pkg/util/sets/string.go @@ -125,3 +125,16 @@ func (s *IgnoreCaseSet) Equals(other *IgnoreCaseSet) bool { } return true } + +// Difference returns a new IgnoreCaseSet containing the items in s that are not +// present in other (i.e. the set difference s \ other). It is safe to call on +// nil or uninitialized sets. +func (s *IgnoreCaseSet) Difference(other *IgnoreCaseSet) *IgnoreCaseSet { + result := NewString() + for _, item := range s.UnsortedList() { + if !other.Has(item) { + result.Insert(item) + } + } + return result +} diff --git a/pkg/util/sets/string_test.go b/pkg/util/sets/string_test.go index 968cdb3dc7..0b380bd522 100644 --- a/pkg/util/sets/string_test.go +++ b/pkg/util/sets/string_test.go @@ -425,23 +425,78 @@ func TestEquals(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.s1 == nil && tt.s2 == nil { - // Special case for nil sets - if !tt.want { - t.Errorf("Equals() = true, want %v", tt.want) - } - return - } - if tt.s1 == nil || tt.s2 == nil { - // One set is nil, they can't be equal - if tt.want { - t.Errorf("Equals() = false, want %v", tt.want) - } - return - } + // Equals is nil-safe on both the receiver and the argument, so call it + // directly for every case (including the nil ones) to exercise the + // production code path rather than asserting an expectation against itself. if got := tt.s1.Equals(tt.s2); got != tt.want { t.Errorf("Equals() = %v, want %v", got, tt.want) } }) } } + +func TestDifference(t *testing.T) { + tests := []struct { + name string + s *IgnoreCaseSet + other *IgnoreCaseSet + want *IgnoreCaseSet + }{ + { + name: "disjoint sets", + s: NewString("foo", "bar"), + other: NewString("baz"), + want: NewString("foo", "bar"), + }, + { + name: "partial overlap", + s: NewString("foo", "bar", "baz"), + other: NewString("bar"), + want: NewString("foo", "baz"), + }, + { + name: "all removed", + s: NewString("foo", "bar"), + other: NewString("foo", "bar"), + want: NewString(), + }, + { + name: "case-insensitive", + s: NewString("Foo", "BAR"), + other: NewString("foo"), + want: NewString("bar"), + }, + { + name: "other empty", + s: NewString("foo", "bar"), + other: NewString(), + want: NewString("foo", "bar"), + }, + { + name: "s empty", + s: NewString(), + other: NewString("foo"), + want: NewString(), + }, + { + name: "other nil", + s: NewString("foo"), + other: nil, + want: NewString("foo"), + }, + { + name: "s nil", + s: nil, + other: NewString("foo"), + want: NewString(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.s.Difference(tt.other) + if !got.Equals(tt.want) { + t.Errorf("Difference() = %v, want %v", got.UnsortedList(), tt.want.UnsortedList()) + } + }) + } +} From bd70126881c134a70f43424099fee290ded194d1 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Fri, 19 Jun 2026 15:21:09 +0000 Subject: [PATCH 08/18] difftracker: drop unused public DeepEqual wrapper DeepEqual() had no production callers; GetSyncOperations uses the lock-free deepEqualLocked() body under a single held lock. Tests now call deepEqualLocked() directly. --- pkg/provider/difftracker/difftracker_test.go | 2 +- pkg/provider/difftracker/util.go | 11 ++-------- pkg/provider/difftracker/util_test.go | 22 ++++++++++---------- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index 607eb708e4..a47bca0d3b 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -86,7 +86,7 @@ func TestDiffTracker_DeepEqual(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := tt.dt.DeepEqual() + result := tt.dt.deepEqualLocked() assert.Equal(t, tt.expected, result) }) } diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go index a4426317e7..0719391047 100644 --- a/pkg/provider/difftracker/util.go +++ b/pkg/provider/difftracker/util.go @@ -103,15 +103,8 @@ func (pod *Pod) HasIdentities() bool { return pod.InboundIdentities.Len() > 0 || pod.PublicOutboundIdentity != "" } -// DeepEqual compares the K8s and NRP states to check if they are in sync. -// It acquires dt.mu so callers get a consistent snapshot. -func (dt *DiffTracker) DeepEqual() bool { - dt.mu.Lock() - defer dt.mu.Unlock() - return dt.deepEqualLocked() -} - -// deepEqualLocked is the lock-free body of DeepEqual. Callers must hold dt.mu. +// deepEqualLocked compares the K8s and NRP states to check if they are in sync. +// Callers must hold dt.mu. func (dt *DiffTracker) deepEqualLocked() bool { klog.V(4).Infof("DeepEqual: Checking equality - K8s Services=%d, NRP LoadBalancers=%d, K8s Egresses=%d, NRP NATGateways=%d, K8s Nodes=%d, NRP Locations=%d", dt.K8sResources.Services.Len(), dt.NRPResources.LoadBalancers.Len(), diff --git a/pkg/provider/difftracker/util_test.go b/pkg/provider/difftracker/util_test.go index 802796b398..dbd7d18e99 100644 --- a/pkg/provider/difftracker/util_test.go +++ b/pkg/provider/difftracker/util_test.go @@ -195,7 +195,7 @@ func TestPodHasIdentities(t *testing.T) { } } -// TestDeepEqual tests DiffTracker.DeepEqual() +// TestDeepEqual tests DiffTracker.deepEqualLocked() func TestDeepEqual(t *testing.T) { tests := []struct { name string @@ -286,7 +286,7 @@ func TestDeepEqual(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, tt.dt.DeepEqual()) + assert.Equal(t, tt.expected, tt.dt.deepEqualLocked()) }) } } @@ -912,25 +912,25 @@ func TestDeepEqualMoreCases(t *testing.T) { } } - assert.True(t, inSync().DeepEqual()) + assert.True(t, inSync().deepEqualLocked()) // LoadBalancer present in NRP but not in K8s Services (reverse-direction check). d := inSync() d.NRPResources.LoadBalancers = sets.NewString("svc1", "extra") d.K8sResources.Services = sets.NewString("svc1", "different") - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) // Egress name mismatch (reverse direction): lengths equal (1==1) but names // differ -> mismatch. d = inSync() d.K8sResources.Egresses = sets.NewString("egr2") d.NRPResources.NATGateways = sets.NewString("egr2x") - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) // Nodes vs Locations length mismatch. d = inSync() d.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) // Node missing in Locations (same count, different key). d = inSync() @@ -938,30 +938,30 @@ func TestDeepEqualMoreCases(t *testing.T) { d.NRPResources.Locations["nodeX"] = NRPLocation{Addresses: map[string]NRPAddress{ "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, }} - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) // Pods vs Addresses length mismatch. d = inSync() loc := d.NRPResources.Locations["node1"] loc.Addresses["10.0.0.2"] = NRPAddress{Services: sets.NewString("svc1")} - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) // Pod missing in Addresses (same count, different key). d = inSync() loc = d.NRPResources.Locations["node1"] delete(loc.Addresses, "10.0.0.1") loc.Addresses["10.0.0.9"] = NRPAddress{Services: sets.NewString("svc1", "egr1")} - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) // Combined identities length mismatch. d = inSync() loc = d.NRPResources.Locations["node1"] loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1")} - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) // Identity not found in Services (same count, different identity). d = inSync() loc = d.NRPResources.Locations["node1"] loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1", "egrX")} - assert.False(t, d.DeepEqual()) + assert.False(t, d.deepEqualLocked()) } From a852c29eb6833f6a5baa5a54e5a2aa177baea1eb Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Mon, 22 Jun 2026 09:19:46 +0000 Subject: [PATCH 09/18] addressed comments + added general improvements --- pkg/provider/difftracker/difftracker.go | 8 ++ pkg/provider/difftracker/difftracker_test.go | 75 +++++++++++++++++++ pkg/provider/difftracker/k8s_state_updates.go | 14 ++-- pkg/provider/difftracker/nrp_state_updates.go | 4 +- pkg/util/sets/string.go | 11 ++- pkg/util/sets/string_test.go | 11 +++ 6 files changed, 110 insertions(+), 13 deletions(-) diff --git a/pkg/provider/difftracker/difftracker.go b/pkg/provider/difftracker/difftracker.go index 01727b2545..3066d690bb 100644 --- a/pkg/provider/difftracker/difftracker.go +++ b/pkg/provider/difftracker/difftracker.go @@ -74,5 +74,13 @@ func New(k8s K8sState, nrp NRPState, config Config, networkClientFactory azclien kubeClient: kubeClient, } + // Seed the outbound ref-counter from egress pods already in the initial state + // so a later REMOVE can drive the counter to zero. + for _, node := range k8s.Nodes { + for _, pod := range node.Pods { + diffTracker.incrementOutboundRefCount(pod.PublicOutboundIdentity) + } + } + return diffTracker, nil } diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index a47bca0d3b..e79bad7b78 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -815,6 +815,40 @@ func TestNewErrorPaths(t *testing.T) { assert.NotNil(t, dt.NRPResources.Locations) } +// TestNewSeedsOutboundRefCount verifies New seeds the outbound ref-counter from +// the egress pods already present in the initial state, so the counter is +// non-zero for identities that have backing pods at construction time. +func TestNewSeedsOutboundRefCount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockKubeClient := fake.NewSimpleClientset() + + k8s := emptyK8sState() + k8s.Nodes["node1"] = Node{Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, + "10.0.0.2": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, + "10.0.0.3": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: ""}, + }} + k8s.Nodes["node2"] = Node{Pods: map[string]Pod{ + "10.0.1.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "Egress2"}, + }} + + dt, err := New(k8s, emptyNRPState(), validTestConfig(), mockFactory, mockKubeClient) + assert.NoError(t, err) + + val, ok := dt.outboundIdentityPodRefCount.Load("egress1") + assert.True(t, ok) + assert.Equal(t, 2, val.(int)) + + val, ok = dt.outboundIdentityPodRefCount.Load("egress2") + assert.True(t, ok, "identity key is lowercased") + assert.Equal(t, 1, val.(int)) + + _, ok = dt.outboundIdentityPodRefCount.Load("") + assert.False(t, ok, "pods without an egress identity are not counted") +} + // TestEnqueueK8sResourceOperationErrors covers the empty-ID and invalid-operation // branches of enqueueK8sResourceOperation (via the public wrappers). func TestEnqueueK8sResourceOperationErrors(t *testing.T) { @@ -1026,3 +1060,44 @@ func TestUpdateK8sEndpointsAddThenRemove(t *testing.T) { _, ok := dt.K8sResources.Nodes["node1"] assert.False(t, ok) } + +// TestUpdateK8sEndpointsRelocation covers the case where the same pod IP appears +// in both OldAddresses and NewAddresses but on a different node (relocation): the +// pod must be removed from the old node and added to the new one. +func TestUpdateK8sEndpointsRelocation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("svc1")) + + // Same pod IP moves from node1 to node2. + errs = dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + OldAddresses: map[string]string{"10.0.0.1": "node1"}, + NewAddresses: map[string]string{"10.0.0.1": "node2"}, + }) + assert.Empty(t, errs) + + // Old node is gone, new node holds the pod with svc1. + _, ok := dt.K8sResources.Nodes["node1"] + assert.False(t, ok, "pod must be removed from the old node") + pod, ok := dt.K8sResources.Nodes["node2"].Pods["10.0.0.1"] + assert.True(t, ok, "pod must be added to the new node") + assert.True(t, pod.InboundIdentities.Has("svc1")) +} + +func TestUpdateK8sPodRejectsEmptyLocationOrAddress(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + err := dt.UpdateK8sPod(UpdatePodInputType{PodOperation: Add, Location: "", Address: "10.0.0.1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must not be empty") + + err = dt.UpdateK8sPod(UpdatePodInputType{PodOperation: Add, Location: "node1", Address: ""}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must not be empty") +} diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go index bba35dc6e0..04adee1630 100644 --- a/pkg/provider/difftracker/k8s_state_updates.go +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -83,7 +83,7 @@ func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputTyp continue } - if _, exists := input.OldAddresses[address]; exists { + if oldLocation, exists := input.OldAddresses[address]; exists && oldLocation == location { continue } @@ -103,7 +103,7 @@ func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputTyp } for address, location := range input.OldAddresses { - if _, exists := input.NewAddresses[address]; exists { + if newLocation, exists := input.NewAddresses[address]; exists && newLocation == location { continue } @@ -146,7 +146,7 @@ func (dt *DiffTracker) UpdateK8sEndpoints(input UpdateK8sEndpointsInputType) []e return dt.updateK8sEndpointsLocked(input) } -func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) error { +func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) { node, exists := dt.K8sResources.Nodes[input.Location] if !exists { node = newNode() @@ -161,8 +161,6 @@ func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) error { pod.PublicOutboundIdentity = input.PublicOutboundIdentity node.Pods[input.Address] = pod klog.V(2).Infof("addOrUpdatePod: Set outbound identity %q for pod %s on node %s", input.PublicOutboundIdentity, input.Address, input.Location) - - return nil } // removePod removes a pod from K8s state. Returns true if the pod was actually @@ -226,6 +224,9 @@ func (dt *DiffTracker) decrementOutboundRefCount(identity string) error { // updateK8sPodLocked updates K8s pod state. Assumes lock is already held. func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { + if input.Location == "" || input.Address == "" { + return fmt.Errorf("updateK8sPodLocked: Location and Address must not be empty (location=%q, address=%q)", input.Location, input.Address) + } switch input.PodOperation { case Add, Update: // Determine the pod's current (old) outbound identity, if any, and whether @@ -257,7 +258,8 @@ func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { } dt.incrementOutboundRefCount(input.PublicOutboundIdentity) } - return dt.addOrUpdatePod(input) + dt.addOrUpdatePod(input) + return nil case Remove: // First, try to remove the pod from K8s state. // removePod returns false if the pod doesn't exist (duplicate removal). diff --git a/pkg/provider/difftracker/nrp_state_updates.go b/pkg/provider/difftracker/nrp_state_updates.go index 94d427493c..df7c5419de 100644 --- a/pkg/provider/difftracker/nrp_state_updates.go +++ b/pkg/provider/difftracker/nrp_state_updates.go @@ -76,7 +76,9 @@ func (dt *DiffTracker) UpdateLocationsAddresses(locationData LocationData) { for addressKey, addressValue := range locationValue.Addresses { // Remove empty addresses if addressValue.ServiceRef.Len() == 0 { - delete(nrpLocation.Addresses, addressKey) + if !isFullUpdate { + delete(nrpLocation.Addresses, addressKey) + } continue } diff --git a/pkg/util/sets/string.go b/pkg/util/sets/string.go index 5a1c9b1da1..83f881ca85 100644 --- a/pkg/util/sets/string.go +++ b/pkg/util/sets/string.go @@ -38,14 +38,13 @@ func NewString(items ...string) *IgnoreCaseSet { return &IgnoreCaseSet{set: set} } -// Insert adds the given items to the set. It only works if the set is initialized. +// Insert adds the given items to the set, initializing the underlying set if needed. func (s *IgnoreCaseSet) Insert(items ...string) { - var lowerItems []string - for _, item := range items { - lowerItems = append(lowerItems, strings.ToLower(item)) + if s.set == nil { + s.set = sets.New[string]() } - for _, item := range lowerItems { - s.set.Insert(item) + for _, item := range items { + s.set.Insert(strings.ToLower(item)) } } diff --git a/pkg/util/sets/string_test.go b/pkg/util/sets/string_test.go index 0b380bd522..79d640d0a2 100644 --- a/pkg/util/sets/string_test.go +++ b/pkg/util/sets/string_test.go @@ -500,3 +500,14 @@ func TestDifference(t *testing.T) { }) } } + +func TestInsertOnUninitializedSet(t *testing.T) { + s := &IgnoreCaseSet{} + s.Insert("Foo", "BAR") + if !s.Has("foo") || !s.Has("bar") { + t.Errorf("Insert on uninitialized set should lazily initialize and add items, got %v", s.UnsortedList()) + } + if s.Len() != 2 { + t.Errorf("expected len 2, got %d", s.Len()) + } +} From b4f46579fcd332359785a56e7ac8c782dc17ecd9 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Tue, 23 Jun 2026 09:22:40 +0000 Subject: [PATCH 10/18] difftracker: revert conditional empty-address delete to unconditional The full-update branch builds a fresh address map, so guarding the delete with !isFullUpdate only skipped a no-op while reading as an inconsistency. Restore the unconditional delete: empty-ServiceRef addresses are always removed, and an all-empty location is dropped by the trailing cleanup. --- pkg/provider/difftracker/nrp_state_updates.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pkg/provider/difftracker/nrp_state_updates.go b/pkg/provider/difftracker/nrp_state_updates.go index df7c5419de..94d427493c 100644 --- a/pkg/provider/difftracker/nrp_state_updates.go +++ b/pkg/provider/difftracker/nrp_state_updates.go @@ -76,9 +76,7 @@ func (dt *DiffTracker) UpdateLocationsAddresses(locationData LocationData) { for addressKey, addressValue := range locationValue.Addresses { // Remove empty addresses if addressValue.ServiceRef.Len() == 0 { - if !isFullUpdate { - delete(nrpLocation.Addresses, addressKey) - } + delete(nrpLocation.Addresses, addressKey) continue } From 462c9fcbb43adc017e970a1977f364ebd65fad9d Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Tue, 23 Jun 2026 12:17:30 +0000 Subject: [PATCH 11/18] difftracker: fix egress-remove inbound loss and gone-node address leak Make the pod Remove path identity-aware: clear only the matching outbound identity and delete the pod only once it has no inbound or outbound identities left. Previously removePod deleted the whole pod record, silently dropping the inbound (LoadBalancer) backing of a pod that was both an egress pod and an LB backend when its egress was removed/changed. Fix the whole-node removal path in getSyncLocationsAddressesLocked: when a node is gone from K8s but still present in NRP, enumerate each NRP address with an empty ServiceRef instead of emitting an empty Addresses map. An empty Addresses map under PartialUpdate is a no-op on the Service Gateway, so the stale addresses were never removed while the local model dropped the location, leaving NRP and the SGW permanently out of sync. Also make the PublicOutboundIdentity comparison in DiffTracker.Equals case-insensitive for consistency with the rest of the package. Adds regression tests for both fixes and updates TestUpdateK8sPod to pass the egress identity on Remove (matching how the engine drives it). --- pkg/provider/difftracker/difftracker_test.go | 98 ++++++++++++++++++- pkg/provider/difftracker/k8s_state_updates.go | 48 +++++---- pkg/provider/difftracker/sync_operations.go | 9 +- pkg/provider/difftracker/util.go | 3 +- 4 files changed, 132 insertions(+), 26 deletions(-) diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index e79bad7b78..6ac3f8ed23 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -197,9 +197,10 @@ func TestUpdateK8sPod(t *testing.T) { // Test removing egress assignment input = UpdatePodInputType{ - PodOperation: Remove, - Location: "node1", - Address: "10.0.0.1", + PodOperation: Remove, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", } err = dt.UpdateK8sPod(input) @@ -1101,3 +1102,94 @@ func TestUpdateK8sPodRejectsEmptyLocationOrAddress(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "must not be empty") } + +// TestUpdateK8sPodRemovePreservesInboundIdentities verifies that removing a pod's +// egress assignment clears only the outbound identity and keeps the pod while it +// still backs an inbound (LoadBalancer) service, whereas an egress-only pod is +// removed entirely and its ref-counter released. +func TestUpdateK8sPodRemovePreservesInboundIdentities(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "lb1", + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "egressA", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("egressa") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "egressA", + Location: "node1", + Address: "10.0.0.1", + })) + pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] + assert.True(t, ok, "pod backing an inbound service must be kept") + assert.True(t, pod.InboundIdentities.Has("lb1")) + assert.Equal(t, "", pod.PublicOutboundIdentity) + _, ok = dt.outboundIdentityPodRefCount.Load("egressa") + assert.False(t, ok, "egress ref-counter must be released") + + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "egressB", + Location: "node2", + Address: "10.0.0.2", + })) + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "egressB", + Location: "node2", + Address: "10.0.0.2", + })) + _, nodeOK := dt.K8sResources.Nodes["node2"] + assert.False(t, nodeOK, "egress-only pod and its empty node must be removed") + _, ok = dt.outboundIdentityPodRefCount.Load("egressb") + assert.False(t, ok) +} + +// TestGetSyncLocationsAddressesRemovesGoneNodeAddresses verifies that when a node +// is gone from K8s but still present in NRP, every address is emitted with an empty +// ServiceRef so the PartialUpdate removes them on the Service Gateway, and applying +// the result drops the location locally. +func TestGetSyncLocationsAddressesRemovesGoneNodeAddresses(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{Nodes: map[string]Node{}}, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1", "service2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{ + "node1": { + Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1")}, + "10.0.0.2": {Services: sets.NewString("service2")}, + }, + }, + }, + }, + } + + result := dt.GetSyncLocationsAddresses() + + loc, ok := result.Locations["node1"] + assert.True(t, ok) + assert.Equal(t, PartialUpdate, loc.AddressUpdateAction) + assert.Len(t, loc.Addresses, 2) + for _, addr := range []string{"10.0.0.1", "10.0.0.2"} { + a, ok := loc.Addresses[addr] + assert.True(t, ok, "address %s must be enumerated for removal", addr) + assert.Equal(t, 0, a.ServiceRef.Len(), "address %s must have empty ServiceRef", addr) + } + + dt.UpdateLocationsAddresses(result) + _, ok = dt.NRPResources.Locations["node1"] + assert.False(t, ok, "gone node's location must be removed locally") +} diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go index 04adee1630..46fa116bd3 100644 --- a/pkg/provider/difftracker/k8s_state_updates.go +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -163,25 +163,36 @@ func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) { klog.V(2).Infof("addOrUpdatePod: Set outbound identity %q for pod %s on node %s", input.PublicOutboundIdentity, input.Address, input.Location) } -// removePod removes a pod from K8s state. Returns true if the pod was actually -// removed, or false if it didn't exist (already removed by a previous call). -func (dt *DiffTracker) removePod(input UpdatePodInputType) bool { - node, exists := dt.K8sResources.Nodes[input.Location] - if !exists { - return false +// removePod clears the pod's outbound identity when it matches the input and +// deletes the pod only once it has no identities left, so an egress removal never +// drops a pod's inbound (LoadBalancer) backing. It returns whether the pod existed +// and any error from decrementing the ref-counter. +func (dt *DiffTracker) removePod(input UpdatePodInputType) (existed bool, err error) { + node, nodeExists := dt.K8sResources.Nodes[input.Location] + if !nodeExists { + return false, nil + } + + pod, podExists := node.Pods[input.Address] + if !podExists { + return false, nil } - if _, podExists := node.Pods[input.Address]; !podExists { - return false + if pod.PublicOutboundIdentity != "" && strings.EqualFold(pod.PublicOutboundIdentity, input.PublicOutboundIdentity) { + pod.PublicOutboundIdentity = "" + node.Pods[input.Address] = pod + err = dt.decrementOutboundRefCount(input.PublicOutboundIdentity) } - delete(node.Pods, input.Address) - if !node.HasPods() { - delete(dt.K8sResources.Nodes, input.Location) + if !pod.HasIdentities() { + delete(node.Pods, input.Address) + if !node.HasPods() { + delete(dt.K8sResources.Nodes, input.Location) + } + klog.V(2).Infof("removePod: Removed pod %s from node %s", input.Address, input.Location) } - klog.V(2).Infof("removePod: Removed pod %s from node %s", input.Address, input.Location) - return true + return true, err } // incrementOutboundRefCount increments the ref-counter for a pod's outbound @@ -261,18 +272,13 @@ func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { dt.addOrUpdatePod(input) return nil case Remove: - // First, try to remove the pod from K8s state. - // removePod returns false if the pod doesn't exist (duplicate removal). - if !dt.removePod(input) { + existed, err := dt.removePod(input) + if !existed { klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s was already removed (duplicate delete), skipping counter decrement", input.Location, input.Address) return nil } - - // Decrement the ref-counter for the outbound identity passed in the - // remove input. A missing key (e.g. an empty identity, which is never - // counted) is a no-op. - return dt.decrementOutboundRefCount(input.PublicOutboundIdentity) + return err default: return fmt.Errorf("invalid pod operation: %s for pod at %s:%s", input.PodOperation, input.Location, input.Address) diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go index ae6e13128a..fd7d934ff6 100644 --- a/pkg/provider/difftracker/sync_operations.go +++ b/pkg/provider/difftracker/sync_operations.go @@ -116,10 +116,17 @@ func (dt *DiffTracker) getSyncLocationsAddressesLocked() LocationData { for location, nrpLocation := range dt.NRPResources.Locations { node, exists := dt.K8sResources.Nodes[location] if !exists { - result.Locations[location] = Location{ + // Node gone from K8s but still in NRP: enumerate each address with an + // empty ServiceRef so the PartialUpdate removes them on the SGW. An empty + // Addresses map under PartialUpdate is a no-op that would leak them. + loc := Location{ AddressUpdateAction: PartialUpdate, Addresses: make(map[string]Address), } + for address := range nrpLocation.Addresses { + loc.Addresses[address] = Address{ServiceRef: utilsets.NewString()} + } + result.Locations[location] = loc } else { locationData := findLocationData(result, location) if locationData == nil { diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go index 0719391047..4406285e51 100644 --- a/pkg/provider/difftracker/util.go +++ b/pkg/provider/difftracker/util.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" "reflect" + "strings" "k8s.io/klog/v2" @@ -280,7 +281,7 @@ func (dt *DiffTracker) Equals(other *DiffTracker) bool { return false } - if pod.PublicOutboundIdentity != otherPod.PublicOutboundIdentity { + if !strings.EqualFold(pod.PublicOutboundIdentity, otherPod.PublicOutboundIdentity) { return false } } From 2f81e3a763b2365adb4e15e84d5d7f921e40a032 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Tue, 23 Jun 2026 13:03:56 +0000 Subject: [PATCH 12/18] difftracker: revert gone-node address enumeration Emitting an empty Addresses map under PartialUpdate already deletes the whole location on the Service Gateway (a null/empty addresses array deletes the location). Enumerating each address with an empty ServiceRef instead deletes the addresses but can leave an empty location container behind, which is strictly worse than the location-level deletion. Revert getSyncLocationsAddressesLocked's gone-node branch to emit an empty Addresses map and update the test accordingly. The egress-remove inbound-loss fix and the EqualFold change from the previous commit are kept. --- pkg/provider/difftracker/difftracker_test.go | 17 ++++++----------- pkg/provider/difftracker/sync_operations.go | 9 +-------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go index 6ac3f8ed23..274c739b7d 100644 --- a/pkg/provider/difftracker/difftracker_test.go +++ b/pkg/provider/difftracker/difftracker_test.go @@ -1156,11 +1156,11 @@ func TestUpdateK8sPodRemovePreservesInboundIdentities(t *testing.T) { assert.False(t, ok) } -// TestGetSyncLocationsAddressesRemovesGoneNodeAddresses verifies that when a node -// is gone from K8s but still present in NRP, every address is emitted with an empty -// ServiceRef so the PartialUpdate removes them on the Service Gateway, and applying -// the result drops the location locally. -func TestGetSyncLocationsAddressesRemovesGoneNodeAddresses(t *testing.T) { +// TestGetSyncLocationsAddressesRemovesGoneNode verifies that when a node is gone +// from K8s but still present in NRP, the location is emitted with an empty +// Addresses map (a PartialUpdate that deletes the whole location on the Service +// Gateway), and applying the result drops the location locally. +func TestGetSyncLocationsAddressesRemovesGoneNode(t *testing.T) { dt := &DiffTracker{ K8sResources: K8sState{Nodes: map[string]Node{}}, NRPResources: NRPState{ @@ -1182,12 +1182,7 @@ func TestGetSyncLocationsAddressesRemovesGoneNodeAddresses(t *testing.T) { loc, ok := result.Locations["node1"] assert.True(t, ok) assert.Equal(t, PartialUpdate, loc.AddressUpdateAction) - assert.Len(t, loc.Addresses, 2) - for _, addr := range []string{"10.0.0.1", "10.0.0.2"} { - a, ok := loc.Addresses[addr] - assert.True(t, ok, "address %s must be enumerated for removal", addr) - assert.Equal(t, 0, a.ServiceRef.Len(), "address %s must have empty ServiceRef", addr) - } + assert.Empty(t, loc.Addresses) dt.UpdateLocationsAddresses(result) _, ok = dt.NRPResources.Locations["node1"] diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go index fd7d934ff6..ae6e13128a 100644 --- a/pkg/provider/difftracker/sync_operations.go +++ b/pkg/provider/difftracker/sync_operations.go @@ -116,17 +116,10 @@ func (dt *DiffTracker) getSyncLocationsAddressesLocked() LocationData { for location, nrpLocation := range dt.NRPResources.Locations { node, exists := dt.K8sResources.Nodes[location] if !exists { - // Node gone from K8s but still in NRP: enumerate each address with an - // empty ServiceRef so the PartialUpdate removes them on the SGW. An empty - // Addresses map under PartialUpdate is a no-op that would leak them. - loc := Location{ + result.Locations[location] = Location{ AddressUpdateAction: PartialUpdate, Addresses: make(map[string]Address), } - for address := range nrpLocation.Addresses { - loc.Addresses[address] = Address{ServiceRef: utilsets.NewString()} - } - result.Locations[location] = loc } else { locationData := findLocationData(result, location) if locationData == nil { From 0d05062bdcb66f77cb2dd4f1f1a78c48a41a4122 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Tue, 23 Jun 2026 13:09:48 +0000 Subject: [PATCH 13/18] difftracker: rename isServiceReady to isServiceReadyToSync The bare "ready" collides with the conventional Service readiness meaning (a Service is "ready" once its external IP is provisioned and serving traffic). isServiceReadyToSync names the actual intent - whether the service is ready to be synced to the Service Gateway - without overloading "ready". --- pkg/provider/difftracker/sync_operations.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go index ae6e13128a..0923400cb2 100644 --- a/pkg/provider/difftracker/sync_operations.go +++ b/pkg/provider/difftracker/sync_operations.go @@ -161,14 +161,14 @@ func (dt *DiffTracker) createServiceRefFiltered(pod Pod) *utilsets.IgnoreCaseSet // Check inbound services (LoadBalancers) for _, serviceUID := range pod.InboundIdentities.UnsortedList() { - if dt.isServiceReady(serviceUID, true) { + if dt.isServiceReadyToSync(serviceUID, true) { serviceRef.Insert(serviceUID) } } // Check outbound service (NAT Gateway) if pod.PublicOutboundIdentity != "" { - if dt.isServiceReady(pod.PublicOutboundIdentity, false) { + if dt.isServiceReadyToSync(pod.PublicOutboundIdentity, false) { serviceRef.Insert(pod.PublicOutboundIdentity) } } @@ -176,10 +176,9 @@ func (dt *DiffTracker) createServiceRefFiltered(pod Pod) *utilsets.IgnoreCaseSet return serviceRef } -// isServiceReady checks if a service is ready for location sync. -// Returns true if the service exists in NRP. -// Must be called with dt.mu held. -func (dt *DiffTracker) isServiceReady(serviceUID string, isInbound bool) bool { +// isServiceReadyToSync reports whether a service is ready to be synced to the +// Service Gateway, i.e. its NRP resource exists. Must be called with dt.mu held. +func (dt *DiffTracker) isServiceReadyToSync(serviceUID string, isInbound bool) bool { if isInbound { return dt.NRPResources.LoadBalancers.Has(serviceUID) } From 5ee8d536e6bde4552936d79e154434f7c0affa0a Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Tue, 16 Jun 2026 13:39:59 +0000 Subject: [PATCH 14/18] CP2: ServiceGateway Azure operations layer Adds the Azure operations + resource-builder layer for Container LB: - azure_operations.go: PIP/LB/NAT Gateway CRUD + ServiceGateway service/location updates - resource_helpers.go: inbound/outbound resource builders, config extraction - dto_mappers.go: K8s updates -> ServiceGateway ServicesDataDTO mapper - types.go: service config + ServiceGateway DTO types (deferred from CP1) - consts.go: NatGatewayIDTemplate Adapted to public armnetwork/v9 SDK. Tests ported from the engine branch. --- pkg/consts/consts.go | 2 + pkg/provider/difftracker/azure_operations.go | 463 ++++++++++++++++++ pkg/provider/difftracker/dto_mappers.go | 80 +++ pkg/provider/difftracker/dto_mappers_test.go | 108 ++++ pkg/provider/difftracker/resource_helpers.go | 301 ++++++++++++ .../difftracker/resource_helpers_test.go | 443 +++++++++++++++++ pkg/provider/difftracker/types.go | 169 +++++++ 7 files changed, 1566 insertions(+) create mode 100644 pkg/provider/difftracker/azure_operations.go create mode 100644 pkg/provider/difftracker/dto_mappers.go create mode 100644 pkg/provider/difftracker/dto_mappers_test.go create mode 100644 pkg/provider/difftracker/resource_helpers.go create mode 100644 pkg/provider/difftracker/resource_helpers_test.go diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index e2305842be..7211b6ebc7 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -367,6 +367,8 @@ const ( FrontendIPConfigIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/frontendIPConfigurations/%s" // BackendPoolIDTemplate is the template of the backend pool BackendPoolIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/backendAddressPools/%s" + // NatGatewayIDTemplate is the template of the nat gateway + NatGatewayIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/natGateways/%s" // LoadBalancerProbeIDTemplate is the template of the load balancer probe LoadBalancerProbeIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/probes/%s" diff --git a/pkg/provider/difftracker/azure_operations.go b/pkg/provider/difftracker/azure_operations.go new file mode 100644 index 0000000000..6454b6f52a --- /dev/null +++ b/pkg/provider/difftracker/azure_operations.go @@ -0,0 +1,463 @@ +// Package difftracker provides state tracking and synchronization between Kubernetes +// resources and Azure Network Resource Provider (NRP) resources. +// +// Architecture: +// - DiffTracker: Maintains the desired state (K8s) and actual state (NRP) +// - ServiceUpdater: Handles service creation/deletion (LoadBalancers, NAT Gateways) +// - LocationsUpdater: Handles address/location synchronization +// +// Azure Operations: +// All Azure SDK operations in azure_operations.go attempt their action once and return +// errors immediately. Retry logic is the responsibility of the callers (ServiceUpdater, +// LocationsUpdater) which implement appropriate retry strategies based on their use cases. +// +// Error Handling: +// Transient errors (throttling, timeouts) should be retried by callers. +// Permanent errors (authentication, resource not found) should not be retried. +package difftracker + +import ( + "context" + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + servicehelper "k8s.io/cloud-provider/service/helpers" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" +) + +// createOrUpdatePIP creates or updates a public IP address +func (dt *DiffTracker) createOrUpdatePIP(ctx context.Context, pipResourceGroup string, pip *armnetwork.PublicIPAddress) error { + _, err := dt.createOrUpdatePIPWithResponse(ctx, pipResourceGroup, pip) + return err +} + +// createOrUpdatePIPWithResponse creates or updates a public IP address and returns the response +// containing the allocated IP address. Use this when you need the IP address after PIP creation. +func (dt *DiffTracker) createOrUpdatePIPWithResponse(ctx context.Context, pipResourceGroup string, pip *armnetwork.PublicIPAddress) (*armnetwork.PublicIPAddress, error) { + pipName := ptr.Deref(pip.Name, "") + if pipName == "" { + return nil, fmt.Errorf("createOrUpdatePIPWithResponse: pip name is empty") + } + klog.Infof("createOrUpdatePIPWithResponse(%s): start", pipName) + + response, err := dt.networkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(ctx, pipResourceGroup, pipName, *pip) + klog.V(10).Infof("PublicIPAddressClient.CreateOrUpdate(%s, %s): end", pipResourceGroup, pipName) + if err != nil { + klog.Warningf("PublicIPAddressClient.CreateOrUpdate(%s, %s) failed: %s", pipResourceGroup, pipName, err.Error()) + return nil, err + } + + return response, nil +} + +// deletePublicIP deletes a public IP address +func (dt *DiffTracker) deletePublicIP(ctx context.Context, pipResourceGroup string, pipName string) error { + if pipName == "" { + return fmt.Errorf("deletePublicIP: pipName is empty") + } + klog.Infof("deletePublicIP(%s): start", pipName) + + err := dt.networkClientFactory.GetPublicIPAddressClient().Delete(ctx, pipResourceGroup, pipName) + klog.Infof("deletePublicIP(%s): end, error: %v", pipName, err) + + if err != nil { + klog.Warningf("deletePublicIP(%s) failed: %s", pipName, err.Error()) + return err + } + + return nil +} + +// createOrUpdateLB creates or updates a load balancer +func (dt *DiffTracker) createOrUpdateLB(ctx context.Context, lb armnetwork.LoadBalancer) error { + lbName := ptr.Deref(lb.Name, "") + if lbName == "" { + return fmt.Errorf("createOrUpdateLB: load balancer name is empty") + } + klog.Infof("createOrUpdateLB(%s): start", lbName) + + _, err := dt.networkClientFactory.GetLoadBalancerClient().CreateOrUpdate(ctx, dt.config.ResourceGroup, lbName, lb) + klog.V(10).Infof("LoadBalancerClient.CreateOrUpdate(%s): end", lbName) + if err != nil { + klog.Warningf("LoadBalancerClient.CreateOrUpdate(%s) failed: %v", lbName, err) + return err + } + + return nil +} + +// deleteLB deletes a load balancer by service UID +func (dt *DiffTracker) deleteLB(ctx context.Context, uid string) error { + // Normalize + uid = strings.ToLower(uid) + + // Try to retrieve the live Service + svc, err := dt.getServiceByUID(ctx, uid) + if err != nil { + // Service not found - try direct cleanup + klog.V(3).Infof("deleteLB: service uid %s not found, attempting direct LoadBalancer deletion", uid) + + // Delete the load balancer directly by name (uid is the LB name) + if err := dt.networkClientFactory.GetLoadBalancerClient().Delete(ctx, dt.config.ResourceGroup, uid); err != nil { + klog.Warningf("deleteLB: failed to delete LoadBalancer %s directly: %v", uid, err) + return err + } + + klog.V(3).Infof("deleteLB: successfully deleted LoadBalancer %s directly", uid) + return nil + } + + // Service exists - use standard deletion + if err := dt.networkClientFactory.GetLoadBalancerClient().Delete(ctx, dt.config.ResourceGroup, uid); err != nil { + return fmt.Errorf("deleteLB: failed to delete LoadBalancer for service %s/%s (uid=%s): %w", + svc.Namespace, svc.Name, uid, err) + } + + return nil +} + +// createOrUpdateNatGateway creates or updates a NAT gateway +func (dt *DiffTracker) createOrUpdateNatGateway(ctx context.Context, natGatewayResourceGroup string, natGateway armnetwork.NatGateway) error { + natGatewayName := ptr.Deref(natGateway.Name, "") + if natGatewayName == "" { + return fmt.Errorf("createOrUpdateNatGateway: NAT gateway name is empty") + } + klog.Infof("createOrUpdateNatGateway(%s): start", natGatewayName) + + _, err := dt.networkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, natGatewayResourceGroup, natGatewayName, natGateway) + if err != nil { + klog.Warningf("NatGatewayClient.CreateOrUpdate(%s) failed: %v", natGatewayName, err) + return err + } + + klog.V(10).Infof("NatGatewayClient.CreateOrUpdate(%s): success", natGatewayName) + klog.Infof("createOrUpdateNatGateway(%s): end, error: nil", natGatewayName) + return nil +} + +// deleteNatGateway deletes a NAT gateway +func (dt *DiffTracker) deleteNatGateway(ctx context.Context, natGatewayResourceGroup string, natGatewayName string) error { + if natGatewayName == "" { + return fmt.Errorf("deleteNatGateway: NAT gateway name is empty") + } + klog.Infof("deleteNatGateway(%s) in resource group %s: start", natGatewayName, natGatewayResourceGroup) + + err := dt.networkClientFactory.GetNatGatewayClient().Delete(ctx, natGatewayResourceGroup, natGatewayName) + if err != nil { + klog.Errorf("NatGatewayClient.Delete(%s) in resource group %s failed: %v", natGatewayName, natGatewayResourceGroup, err) + return err + } + + klog.V(10).Infof("NatGatewayClient.Delete(%s) in resource group %s: success", natGatewayName, natGatewayResourceGroup) + klog.Infof("deleteNatGateway(%s) in resource group %s: end, error: nil", natGatewayName, natGatewayResourceGroup) + return nil +} + +// disassociateNatGatewayFromServiceGateway removes the NAT gateway association from the Service Gateway +// This should be called before deleting the NAT gateway to properly clean up the references +func (dt *DiffTracker) disassociateNatGatewayFromServiceGateway(ctx context.Context, serviceGatewayName string, natGatewayName string) error { + klog.Infof("disassociateNatGatewayFromServiceGateway: Disassociating NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) + + // Step 1: Get the service and remove the NAT gateway reference + services, err := dt.networkClientFactory.GetServiceGatewayClient().GetServices(ctx, dt.config.ResourceGroup, serviceGatewayName) + if err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to get Service Gateway %s: %v", serviceGatewayName, err) + return fmt.Errorf("failed to get Service Gateway services: %w", err) + } + + var serviceToBeUpdated *armnetwork.ServiceGatewayService + for _, service := range services { + if service.Name != nil && *service.Name == natGatewayName { + serviceToBeUpdated = service + break + } + } + + if serviceToBeUpdated == nil { + klog.Infof("disassociateNatGatewayFromServiceGateway: NAT Gateway %s is not associated with Service Gateway %s", natGatewayName, serviceGatewayName) + return nil + } + + // Remove the NAT gateway reference from the service + if serviceToBeUpdated.Properties != nil { + serviceToBeUpdated.Properties.PublicNatGatewayID = nil + } + + updateServicesRequest := armnetwork.ServiceGatewayUpdateServicesRequest{ + Action: ptr.To(armnetwork.ServiceUpdateActionPartialUpdate), + ServiceRequests: []*armnetwork.ServiceGatewayServiceRequest{ + { + IsDelete: ptr.To(false), + Service: serviceToBeUpdated, + }, + }, + } + + err = dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, updateServicesRequest) + if err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update Service Gateway %s to disassociate NAT Gateway %s: %v", serviceGatewayName, natGatewayName, err) + return fmt.Errorf("failed to update Service Gateway: %w", err) + } + klog.Infof("disassociateNatGatewayFromServiceGateway: Successfully removed NAT Gateway %s reference from Service Gateway %s", natGatewayName, serviceGatewayName) + + // Step 2: Get the NAT gateway and remove the service gateway reference + natGateway, err := dt.networkClientFactory.GetNatGatewayClient().Get(ctx, dt.config.ResourceGroup, natGatewayName, nil) + if err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to get NAT Gateway %s: %v", natGatewayName, err) + return fmt.Errorf("failed to get NAT Gateway: %w", err) + } + + if natGateway.Properties != nil { + natGateway.Properties.ServiceGateway = nil + } + + _, err = dt.networkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, dt.config.ResourceGroup, natGatewayName, *natGateway) + if err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update NAT Gateway %s to remove Service Gateway reference: %v", natGatewayName, err) + return fmt.Errorf("failed to update NAT Gateway: %w", err) + } + + klog.Infof("disassociateNatGatewayFromServiceGateway: Successfully disassociated NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) + return nil +} + +// updateNRPSGWServices updates services in the Service Gateway +func (dt *DiffTracker) updateNRPSGWServices(ctx context.Context, serviceGatewayName string, updateServicesRequestDTO ServicesDataDTO) error { + // Early return if no services to update + if len(updateServicesRequestDTO.Services) == 0 { + klog.Infof("updateNRPSGWServices(%s): no services to update", serviceGatewayName) + return nil + } + + klog.Infof("updateNRPSGWServices(%s): start", serviceGatewayName) + + // Convert DTO to ARM SDK request + req := armnetwork.ServiceGatewayUpdateServicesRequest{ + Action: convertServicesUpdateActionToARM(updateServicesRequestDTO.Action), + ServiceRequests: convertServiceDTOsToServiceRequests(updateServicesRequestDTO.Services, dt.config), + } + + err := dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, req) + if err != nil { + klog.Warningf("ServiceGatewayClient.UpdateServices(%s) failed: %v", serviceGatewayName, err) + return err + } + + klog.V(10).Infof("ServiceGatewayClient.UpdateServices(%s): success", serviceGatewayName) + klog.Infof("updateNRPSGWServices(%s): end, error: nil", serviceGatewayName) + return nil +} + +// updateNRPSGWAddressLocations updates address locations in the Service Gateway +func (dt *DiffTracker) updateNRPSGWAddressLocations(ctx context.Context, serviceGatewayName string, locationsDTO LocationsDataDTO) error { + klog.Infof("updateNRPSGWAddressLocations(%s): start", serviceGatewayName) + + // Convert DTO to ARM SDK request + req := armnetwork.ServiceGatewayUpdateAddressLocationsRequest{ + Action: convertLocationsUpdateActionToARM(locationsDTO.Action), + AddressLocations: convertLocationDTOsToAddressLocations(locationsDTO.Locations), + } + + err := dt.networkClientFactory.GetServiceGatewayClient().UpdateAddressLocations(ctx, dt.config.ResourceGroup, serviceGatewayName, req) + if err != nil { + klog.Warningf("ServiceGatewayClient.UpdateAddressLocations(%s) failed: %v", serviceGatewayName, err) + return err + } + + klog.V(10).Infof("ServiceGatewayClient.UpdateAddressLocations(%s): success", serviceGatewayName) + klog.Infof("updateNRPSGWAddressLocations(%s): end, error: nil", serviceGatewayName) + return nil +} + +// getServiceByUID returns the Service whose UID matches the given uid +func (dt *DiffTracker) getServiceByUID(ctx context.Context, uid string) (*v1.Service, error) { + // list via client (could be expensive; acceptable for initialization) + svcList, err := dt.kubeClient.CoreV1().Services(v1.NamespaceAll).List(ctx, metav1.ListOptions{}) + if err != nil { + return nil, fmt.Errorf("getServiceByUID: list failed: %w", err) + } + for _, svc := range svcList.Items { + if string(svc.UID) == uid { + return svc.DeepCopy(), nil + } + } + return nil, fmt.Errorf("service with uid %s not found", uid) +} + +// updateServiceLoadBalancerStatus updates the K8s Service status with the LoadBalancer IP address. +// This is called after the PIP is successfully created in ServiceGateway mode to populate +// the Service.Status.LoadBalancer.Ingress field, which would otherwise be empty since +// EnsureLoadBalancer returns immediately in async mode. +func (dt *DiffTracker) updateServiceLoadBalancerStatus(ctx context.Context, serviceUID string, ip string) error { + if ip == "" { + return fmt.Errorf("updateServiceLoadBalancerStatus: ip is empty") + } + + svc, err := dt.getServiceByUID(ctx, serviceUID) + if err != nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: failed to get service: %w", err) + } + + // Check if the status already has this IP to avoid unnecessary updates + for _, ingress := range svc.Status.LoadBalancer.Ingress { + if ingress.IP == ip { + klog.V(3).Infof("updateServiceLoadBalancerStatus: service %s/%s already has IP %s", svc.Namespace, svc.Name, ip) + return nil + } + } + + // Make a copy so we don't mutate the shared informer cache + updated := svc.DeepCopy() + updated.Status.LoadBalancer = v1.LoadBalancerStatus{ + Ingress: []v1.LoadBalancerIngress{ + {IP: ip}, + }, + } + + _, err = servicehelper.PatchService(dt.kubeClient.CoreV1(), svc, updated) + if err != nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: failed to patch service: %w", err) + } + + klog.V(2).Infof("updateServiceLoadBalancerStatus: updated service %s/%s with LoadBalancer IP %s", svc.Namespace, svc.Name, ip) + return nil +} + +// Helper functions to convert DTOs to ARM SDK types + +func convertServicesUpdateActionToARM(action UpdateAction) *armnetwork.ServiceUpdateAction { + switch action { + case PartialUpdate: + return ptr.To(armnetwork.ServiceUpdateActionPartialUpdate) + case FullUpdate: + return ptr.To(armnetwork.ServiceUpdateActionFullUpdate) + default: + return ptr.To(armnetwork.ServiceUpdateActionPartialUpdate) + } +} + +func convertLocationsUpdateActionToARM(action UpdateAction) *armnetwork.UpdateAction { + switch action { + case PartialUpdate: + return ptr.To(armnetwork.UpdateActionPartialUpdate) + case FullUpdate: + return ptr.To(armnetwork.UpdateActionFullUpdate) + default: + return ptr.To(armnetwork.UpdateActionPartialUpdate) + } +} + +// extractResourceChildName extracts a child resource name from an Azure resource ID +func extractResourceChildName(id, segment string) string { + if id == "" { + return "" + } + parts := strings.Split(id, "/") + // Look for the explicit segment first + for i := 0; i < len(parts)-1; i++ { + if strings.EqualFold(parts[i], segment) && parts[i+1] != "" { + return parts[i+1] + } + } + // Fallback: last non-empty + for i := len(parts) - 1; i >= 0; i-- { + if parts[i] != "" { + return parts[i] + } + } + return "" +} + +func convertServiceDTOsToServiceRequests(services []ServiceDTO, config Config) []*armnetwork.ServiceGatewayServiceRequest { + serviceRequests := make([]*armnetwork.ServiceGatewayServiceRequest, 0, len(services)) + + // Construct VNet resource ID once for all backend pools + vnetID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/virtualNetworks/%s", + config.SubscriptionID, config.ResourceGroup, config.VNetName) + + for _, svc := range services { + // Build backend pools with full details + loadBalancerBackendPools := make([]*armnetwork.BackendAddressPool, len(svc.LoadBalancerBackendPools)) + for i := range svc.LoadBalancerBackendPools { + backendPoolResourceID := svc.LoadBalancerBackendPools[i].Id + backendPoolName := extractResourceChildName(backendPoolResourceID, "backendAddressPools") + loadBalancerBackendPools[i] = &armnetwork.BackendAddressPool{ + ID: &backendPoolResourceID, + Name: &backendPoolName, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ + ID: &vnetID, + }, + }, + } + } + + // Set service type and NAT gateway based on service type + var serviceType armnetwork.ServiceType + var publicNatGatewayID *string + switch svc.ServiceType { + case Inbound: + serviceType = armnetwork.ServiceTypeInbound + publicNatGatewayID = nil + case Outbound: + serviceType = armnetwork.ServiceTypeOutbound + publicNatGatewayID = &svc.PublicNatGateway.Id + } + + // Build and append the service request + serviceRequests = append(serviceRequests, &armnetwork.ServiceGatewayServiceRequest{ + IsDelete: &svc.IsDelete, + Service: &armnetwork.ServiceGatewayService{ + Name: &svc.Service, + Properties: &armnetwork.ServiceGatewayServicePropertiesFormat{ + LoadBalancerBackendPools: loadBalancerBackendPools, + PublicNatGatewayID: publicNatGatewayID, + ServiceType: &serviceType, + }, + }, + }) + } + return serviceRequests +} + +func convertLocationDTOsToAddressLocations(locations []LocationDTO) []*armnetwork.ServiceGatewayAddressLocation { + armLocations := make([]*armnetwork.ServiceGatewayAddressLocation, 0, len(locations)) + for _, loc := range locations { + armLoc := &armnetwork.ServiceGatewayAddressLocation{ + AddressLocation: ptr.To(loc.Location), + } + + // Set address update action + switch loc.AddressUpdateAction { + case PartialUpdate: + armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionPartialUpdate) + case FullUpdate: + armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionFullUpdate) + } + + // Convert addresses - always initialize the slice to avoid null in JSON + armLoc.Addresses = make([]*armnetwork.ServiceGatewayAddress, 0, len(loc.Addresses)) + for _, addr := range loc.Addresses { + armAddr := &armnetwork.ServiceGatewayAddress{ + Address: ptr.To(addr.Address), + } + + // Convert service names - always initialize the slice to avoid null in JSON + armAddr.Services = make([]*string, 0, addr.ServiceNames.Len()) + if addr.ServiceNames != nil && addr.ServiceNames.Len() > 0 { + for _, svcName := range addr.ServiceNames.UnsortedList() { + armAddr.Services = append(armAddr.Services, ptr.To(svcName)) + } + } + + armLoc.Addresses = append(armLoc.Addresses, armAddr) + } + + armLocations = append(armLocations, armLoc) + } + return armLocations +} diff --git a/pkg/provider/difftracker/dto_mappers.go b/pkg/provider/difftracker/dto_mappers.go new file mode 100644 index 0000000000..0efc1779bf --- /dev/null +++ b/pkg/provider/difftracker/dto_mappers.go @@ -0,0 +1,80 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "fmt" + + "sigs.k8s.io/cloud-provider-azure/pkg/consts" +) + +func MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO(loadBalancerUpdates SyncServicesReturnType, natGatewayUpdates SyncServicesReturnType, subscriptionID string, resourceGroup string) ServicesDataDTO { + var ServicesDataDTO ServicesDataDTO + ServicesDataDTO.Action = PartialUpdate + ServicesDataDTO.Services = []ServiceDTO{} + for _, service := range loadBalancerUpdates.Additions.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + ServiceType: Inbound, + LoadBalancerBackendPools: []LoadBalancerBackendPoolDTO{ + { + Id: fmt.Sprintf( + consts.BackendPoolIDTemplate, + subscriptionID, + resourceGroup, + service, + service, + // fmt.Sprintf("%s-backendpool", service), + ), + }, + }, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + for _, service := range loadBalancerUpdates.Removals.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + IsDelete: true, + ServiceType: Inbound, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + for _, service := range natGatewayUpdates.Additions.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + ServiceType: Outbound, + PublicNatGateway: NatGatewayDTO{ + Id: fmt.Sprintf( + consts.NatGatewayIDTemplate, + subscriptionID, + resourceGroup, + service, + ), + }, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + for _, service := range natGatewayUpdates.Removals.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + IsDelete: true, + ServiceType: Outbound, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + return ServicesDataDTO +} diff --git a/pkg/provider/difftracker/dto_mappers_test.go b/pkg/provider/difftracker/dto_mappers_test.go new file mode 100644 index 0000000000..e6dce7807d --- /dev/null +++ b/pkg/provider/difftracker/dto_mappers_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +func makeInboundConfig(frontendPorts ...int32) *InboundConfig { + cfg := &InboundConfig{} + for _, p := range frontendPorts { + cfg.FrontendPorts = append(cfg.FrontendPorts, PortMapping{Port: p, Protocol: "TCP"}) + cfg.BackendPorts = append(cfg.BackendPorts, PortMapping{Port: p, Protocol: "TCP"}) + } + return cfg +} + +func TestInboundConfig_Equals(t *testing.T) { + a := makeInboundConfig(80, 443) + b := makeInboundConfig(80, 443) + assert.True(t, a.Equals(b), "identical configs should be equal") + + c := makeInboundConfig(80, 8080) + assert.False(t, a.Equals(c), "different ports should be unequal") + + d := makeInboundConfig(443, 80) + assert.False(t, a.Equals(d), "ordered comparison: reversed ports must be unequal") + + assert.True(t, (*InboundConfig)(nil).Equals(nil), "nil-nil equal") + assert.False(t, a.Equals(nil), "non-nil vs nil unequal") +} + +func TestMapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO(t *testing.T) { + tests := []struct { + name string + lbUpdates SyncServicesReturnType + natUpdates SyncServicesReturnType + expectedLen int + }{ + { + name: "only inbound additions", + lbUpdates: SyncServicesReturnType{ + Additions: sets.NewString("svc1", "svc2"), + }, + natUpdates: SyncServicesReturnType{}, + expectedLen: 2, + }, + { + name: "only outbound additions", + lbUpdates: SyncServicesReturnType{}, + natUpdates: SyncServicesReturnType{ + Additions: sets.NewString("egress1"), + }, + expectedLen: 1, + }, + { + name: "mixed additions", + lbUpdates: SyncServicesReturnType{ + Additions: sets.NewString("svc1"), + }, + natUpdates: SyncServicesReturnType{ + Additions: sets.NewString("egress1"), + }, + expectedLen: 2, + }, + { + name: "removals only", + lbUpdates: SyncServicesReturnType{ + Removals: sets.NewString("svc1"), + }, + natUpdates: SyncServicesReturnType{ + Removals: sets.NewString("egress1"), + }, + expectedLen: 2, + }, + { + name: "empty updates", + lbUpdates: SyncServicesReturnType{}, + natUpdates: SyncServicesReturnType{}, + expectedLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO(tt.lbUpdates, tt.natUpdates, "sub1", "rg1") + assert.Equal(t, tt.expectedLen, len(result.Services)) + }) + } +} diff --git a/pkg/provider/difftracker/resource_helpers.go b/pkg/provider/difftracker/resource_helpers.go new file mode 100644 index 0000000000..ce4715995b --- /dev/null +++ b/pkg/provider/difftracker/resource_helpers.go @@ -0,0 +1,301 @@ +package difftracker + +import ( + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + v1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" + "sigs.k8s.io/cloud-provider-azure/pkg/consts" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// buildInboundServiceResources constructs the PIP, LoadBalancer, and ServicesDTO for an inbound service +// Returns the resources ready to be created via createOrUpdatePIP/createOrUpdateLB/updateNRPSGWServices +func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtConfig Config) ( + pip armnetwork.PublicIPAddress, + lb armnetwork.LoadBalancer, + servicesDTO ServicesDataDTO, +) { + pipName := fmt.Sprintf("%s-pip", serviceUID) + + // Build Public IP + pip = armnetwork.PublicIPAddress{ + Name: to.Ptr(pipName), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandardV2), + }, + Location: to.Ptr(dtConfig.Location), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + }, + } + + // Build LoadBalancer with backend pool and rules + backendPoolName := serviceUID + frontendIPConfigID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/frontendIPConfigurations/frontend", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID) + backendPoolID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/backendAddressPools/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID, backendPoolName) + + // Build backend pool + backendPools := []*armnetwork.BackendAddressPool{ + { + Name: to.Ptr(backendPoolName), + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + // Backend pool will be populated by ServiceGateway with pod IPs + }, + }, + } + + // Build LB rules and probes from config + var lbRules []*armnetwork.LoadBalancingRule + var probes []*armnetwork.Probe + + if config != nil && len(config.FrontendPorts) > 0 { + // For SLB with PodIP backend pool, we disable floating IP and don't create health probes + // Traffic goes directly to pod IPs on the backend port + for i, frontendPort := range config.FrontendPorts { + backendPort := frontendPort.Port + if i < len(config.BackendPorts) { + backendPort = config.BackendPorts[i].Port + } + + protocol := armnetwork.TransportProtocolTCP + if frontendPort.Protocol == "UDP" { + protocol = armnetwork.TransportProtocolUDP + } + + ruleName := fmt.Sprintf("rule-%s-%d", strings.ToLower(frontendPort.Protocol), frontendPort.Port) + + lbRules = append(lbRules, &armnetwork.LoadBalancingRule{ + Name: to.Ptr(ruleName), + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + FrontendPort: to.Ptr(frontendPort.Port), + BackendPort: to.Ptr(backendPort), + EnableFloatingIP: to.Ptr(false), // Disabled for PodIP backend + IdleTimeoutInMinutes: to.Ptr(int32(4)), + EnableTCPReset: to.Ptr(true), + FrontendIPConfiguration: &armnetwork.SubResource{ + ID: to.Ptr(frontendIPConfigID), + }, + BackendAddressPool: &armnetwork.SubResource{ + ID: to.Ptr(backendPoolID), + }, + // No probe for PodIP backend pools + }, + }) + + klog.V(4).Infof("buildInboundServiceResources: created LB rule %s: frontend=%d backend=%d protocol=%s for service %s", + ruleName, frontendPort.Port, backendPort, frontendPort.Protocol, serviceUID) + } + } else { + klog.V(2).Infof("buildInboundServiceResources: no port configuration provided for service %s, creating LB without rules", serviceUID) + } + + lb = armnetwork.LoadBalancer{ + Name: to.Ptr(serviceUID), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s", dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID)), + Location: to.Ptr(dtConfig.Location), + SKU: &armnetwork.LoadBalancerSKU{ + Name: to.Ptr(armnetwork.LoadBalancerSKUName(consts.LoadBalancerSKUService)), + }, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Scope: to.Ptr(armnetwork.LoadBalancerScopePublic), + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ + { + Name: to.Ptr("frontend"), + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + }, + }, + }, + }, + BackendAddressPools: backendPools, + LoadBalancingRules: lbRules, + Probes: probes, + }, + } + + // Build ServicesDTO for ServiceGateway registration + servicesDTO = MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: newIgnoreCaseSetFromSlice([]string{serviceUID}), + Removals: nil, + }, + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) + + return pip, lb, servicesDTO +} + +// buildOutboundServiceResources constructs the PIP, NAT Gateway, and ServicesDTO for an outbound service +// Returns the resources ready to be created via createOrUpdatePIP/createOrUpdateNatGateway/updateNRPSGWServices +func buildOutboundServiceResources(serviceUID string, config *OutboundConfig, dtConfig Config) ( + pip armnetwork.PublicIPAddress, + natGateway armnetwork.NatGateway, + servicesDTO ServicesDataDTO, +) { + pipName := fmt.Sprintf("%s-pip", serviceUID) + + // Build Public IP + pip = armnetwork.PublicIPAddress{ + Name: to.Ptr(pipName), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandardV2), + }, + Location: to.Ptr(dtConfig.Location), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + }, + } + + // Build NAT Gateway + natGateway = armnetwork.NatGateway{ + Name: to.Ptr(serviceUID), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/natGateways/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID)), + SKU: &armnetwork.NatGatewaySKU{ + Name: to.Ptr(armnetwork.NatGatewaySKUNameStandardV2), + }, + Location: to.Ptr(dtConfig.Location), + Properties: &armnetwork.NatGatewayPropertiesFormat{ + ServiceGateway: &armnetwork.SubResource{ + ID: to.Ptr(dtConfig.ServiceGatewayID), + }, + PublicIPAddresses: []*armnetwork.SubResource{ + { + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + }, + }, + }, + } + + // Build ServicesDTO for ServiceGateway registration + servicesDTO = MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + SyncServicesReturnType{ + Additions: newIgnoreCaseSetFromSlice([]string{serviceUID}), + Removals: nil, + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) + + return pip, natGateway, servicesDTO +} + +// newIgnoreCaseSetFromSlice creates an IgnoreCaseSet from a slice of strings +func newIgnoreCaseSetFromSlice(items []string) *utilsets.IgnoreCaseSet { + set := utilsets.NewString() + for _, item := range items { + set.Insert(item) + } + return set +} + +// ExtractInboundConfigFromService creates InboundConfig from a Kubernetes Service +// This is shared between initialization and the provider layer +func ExtractInboundConfigFromService(service *v1.Service) *InboundConfig { + if service == nil || len(service.Spec.Ports) == 0 { + return nil + } + + config := &InboundConfig{ + FrontendPorts: make([]PortMapping, 0, len(service.Spec.Ports)), + BackendPorts: make([]PortMapping, 0, len(service.Spec.Ports)), + } + + // Extract port mappings from service + for _, port := range service.Spec.Ports { + protocol := string(port.Protocol) + if protocol == "" { + protocol = "TCP" + } + + // Frontend port (service port) + config.FrontendPorts = append(config.FrontendPorts, PortMapping{ + Port: port.Port, + Protocol: protocol, + }) + + // Backend port (target port) + // For SLB with PodIP backend, we use TargetPort + // If TargetPort is not specified, default to Port + backendPort := port.Port + if port.TargetPort.Type == 0 && port.TargetPort.IntVal > 0 { // intstr.Int + backendPort = port.TargetPort.IntVal + } else if port.TargetPort.Type == 1 { // intstr.String + // Named ports not supported in SLB mode - use Port as fallback + serviceName := fmt.Sprintf("%s/%s", service.Namespace, service.Name) + klog.V(2).Infof("Named targetPort %s not supported in SLB mode for service %s, using Port %d", + port.TargetPort.StrVal, serviceName, port.Port) + backendPort = port.Port + } + + config.BackendPorts = append(config.BackendPorts, PortMapping{ + Port: backendPort, + Protocol: protocol, + }) + } + + return config +} + +// buildInboundResourceNames returns the resource names for an inbound service +func buildInboundResourceNames(serviceUID string) (lbName string, pipName string, backendPoolName string) { + return serviceUID, fmt.Sprintf("%s-pip", serviceUID), serviceUID +} + +// buildOutboundResourceNames returns the resource names for an outbound service +func buildOutboundResourceNames(serviceUID string) (natGatewayName string, pipName string) { + return serviceUID, fmt.Sprintf("%s-pip", serviceUID) +} + +// buildServiceGatewayRemovalDTO creates a ServicesDTO for removing a service from ServiceGateway +func buildServiceGatewayRemovalDTO(serviceUID string, isInbound bool, dtConfig Config) ServicesDataDTO { + if isInbound { + return MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: nil, + Removals: newIgnoreCaseSetFromSlice([]string{serviceUID}), + }, + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) + } + return MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + SyncServicesReturnType{ + Additions: nil, + Removals: newIgnoreCaseSetFromSlice([]string{serviceUID}), + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) +} diff --git a/pkg/provider/difftracker/resource_helpers_test.go b/pkg/provider/difftracker/resource_helpers_test.go new file mode 100644 index 0000000000..d8ad99d90d --- /dev/null +++ b/pkg/provider/difftracker/resource_helpers_test.go @@ -0,0 +1,443 @@ +package difftracker + +import ( + "testing" + + "sigs.k8s.io/cloud-provider-azure/pkg/consts" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" +) + +func TestExtractInboundConfigFromService_NilService(t *testing.T) { + config := ExtractInboundConfigFromService(nil) + assert.Nil(t, config) +} + +func TestExtractInboundConfigFromService_EmptyPorts(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{}, + }, + } + config := ExtractInboundConfigFromService(service) + assert.Nil(t, config) +} + +func TestExtractInboundConfigFromService_SingleTCPPort(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromInt(8080), + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + assert.Len(t, config.FrontendPorts, 1) + assert.Len(t, config.BackendPorts, 1) + + // Check frontend port + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, "TCP", config.FrontendPorts[0].Protocol) + + // Check backend port (should be TargetPort) + assert.Equal(t, int32(8080), config.BackendPorts[0].Port) + assert.Equal(t, "TCP", config.BackendPorts[0].Protocol) +} + +func TestExtractInboundConfigFromService_MultiplePortsWithUDP(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromInt(8080), + }, + { + Name: "dns", + Protocol: v1.ProtocolUDP, + Port: 53, + TargetPort: intstr.FromInt(5353), + }, + { + Name: "https", + Protocol: v1.ProtocolTCP, + Port: 443, + TargetPort: intstr.FromInt(8443), + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + assert.Len(t, config.FrontendPorts, 3) + assert.Len(t, config.BackendPorts, 3) + + // Verify HTTP + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, "TCP", config.FrontendPorts[0].Protocol) + assert.Equal(t, int32(8080), config.BackendPorts[0].Port) + + // Verify DNS (UDP) + assert.Equal(t, int32(53), config.FrontendPorts[1].Port) + assert.Equal(t, "UDP", config.FrontendPorts[1].Protocol) + assert.Equal(t, int32(5353), config.BackendPorts[1].Port) + + // Verify HTTPS + assert.Equal(t, int32(443), config.FrontendPorts[2].Port) + assert.Equal(t, "TCP", config.FrontendPorts[2].Protocol) + assert.Equal(t, int32(8443), config.BackendPorts[2].Port) +} + +func TestExtractInboundConfigFromService_NoTargetPort(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + // TargetPort not specified + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + + // When TargetPort is not specified, backend port should equal frontend port + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, int32(80), config.BackendPorts[0].Port) +} + +func TestExtractInboundConfigFromService_NamedTargetPort(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromString("http-port"), // Named port + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + + // Named ports should fall back to Port + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, int32(80), config.BackendPorts[0].Port) +} + +func TestExtractInboundConfigFromService_EmptyProtocol(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Port: 80, + // Protocol not specified + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + + // Default protocol should be TCP + assert.Equal(t, "TCP", config.FrontendPorts[0].Protocol) + assert.Equal(t, "TCP", config.BackendPorts[0].Protocol) +} + +func TestBuildInboundServiceResources_WithConfig(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{ + {Port: 80, Protocol: "TCP"}, + {Port: 443, Protocol: "TCP"}, + }, + BackendPorts: []PortMapping{ + {Port: 8080, Protocol: "TCP"}, + {Port: 8443, Protocol: "TCP"}, + }, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + pip, lb, servicesDTO := buildInboundServiceResources("service-uid-123", config, dtConfig) + + // Verify PIP + assert.NotNil(t, pip.Name) + assert.Equal(t, "service-uid-123-pip", *pip.Name) + assert.Equal(t, armnetwork.PublicIPAddressSKUNameStandardV2, *pip.SKU.Name) + assert.Equal(t, "eastus", *pip.Location) + + // Verify LoadBalancer + assert.NotNil(t, lb.Name) + assert.Equal(t, "service-uid-123", *lb.Name) + assert.Equal(t, armnetwork.LoadBalancerSKUName(consts.LoadBalancerSKUService), *lb.SKU.Name) + assert.Equal(t, "eastus", *lb.Location) + + // Verify backend pool + assert.Len(t, lb.Properties.BackendAddressPools, 1) + assert.Equal(t, "service-uid-123", *lb.Properties.BackendAddressPools[0].Name) + + // Verify LB rules + assert.Len(t, lb.Properties.LoadBalancingRules, 2) + + // Rule 1: port 80 -> 8080 + rule1 := lb.Properties.LoadBalancingRules[0] + assert.Equal(t, "rule-tcp-80", *rule1.Name) + assert.Equal(t, armnetwork.TransportProtocolTCP, *rule1.Properties.Protocol) + assert.Equal(t, int32(80), *rule1.Properties.FrontendPort) + assert.Equal(t, int32(8080), *rule1.Properties.BackendPort) + assert.False(t, *rule1.Properties.EnableFloatingIP) + + // Rule 2: port 443 -> 8443 + rule2 := lb.Properties.LoadBalancingRules[1] + assert.Equal(t, "rule-tcp-443", *rule2.Name) + assert.Equal(t, int32(443), *rule2.Properties.FrontendPort) + assert.Equal(t, int32(8443), *rule2.Properties.BackendPort) + + // Verify ServicesDTO + assert.Len(t, servicesDTO.Services, 1) + assert.Contains(t, servicesDTO.Services[0].Service, "service-uid-123") + assert.Equal(t, Inbound, servicesDTO.Services[0].ServiceType) +} + +func TestBuildInboundServiceResources_NilConfig(t *testing.T) { + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + pip, lb, servicesDTO := buildInboundServiceResources("service-uid-123", nil, dtConfig) + + // Should still create LB, just without rules + assert.NotNil(t, lb.Name) + assert.Equal(t, "service-uid-123", *lb.Name) + + // Should have backend pool but no rules + assert.Len(t, lb.Properties.BackendAddressPools, 1) + assert.Empty(t, lb.Properties.LoadBalancingRules) + + // PIP should still be created + assert.NotNil(t, pip.Name) + + // ServicesDTO should still be valid + assert.Len(t, servicesDTO.Services, 1) + assert.Equal(t, Inbound, servicesDTO.Services[0].ServiceType) +} + +func TestBuildInboundServiceResources_UDPProtocol(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{ + {Port: 53, Protocol: "UDP"}, + }, + BackendPorts: []PortMapping{ + {Port: 5353, Protocol: "UDP"}, + }, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "westus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + _, lb, _ := buildInboundServiceResources("service-uid-udp", config, dtConfig) + + // Verify UDP rule + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + rule := lb.Properties.LoadBalancingRules[0] + assert.Equal(t, "rule-udp-53", *rule.Name) + assert.Equal(t, armnetwork.TransportProtocolUDP, *rule.Properties.Protocol) + assert.Equal(t, int32(53), *rule.Properties.FrontendPort) + assert.Equal(t, int32(5353), *rule.Properties.BackendPort) +} + +func TestBuildOutboundServiceResources_Basic(t *testing.T) { + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "centralus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + pip, natGw, servicesDTO := buildOutboundServiceResources("egress-uid-456", nil, dtConfig) + + // Verify PIP + assert.NotNil(t, pip.Name) + assert.Equal(t, "egress-uid-456-pip", *pip.Name) + assert.Equal(t, armnetwork.PublicIPAddressSKUNameStandardV2, *pip.SKU.Name) + assert.Equal(t, "centralus", *pip.Location) + + // Verify NAT Gateway + assert.NotNil(t, natGw.Name) + assert.Equal(t, "egress-uid-456", *natGw.Name) + assert.Equal(t, armnetwork.NatGatewaySKUNameStandardV2, *natGw.SKU.Name) + assert.Equal(t, "centralus", *natGw.Location) + + // Verify NAT Gateway has ServiceGateway reference + assert.NotNil(t, natGw.Properties.ServiceGateway) + assert.Equal(t, dtConfig.ServiceGatewayID, *natGw.Properties.ServiceGateway.ID) + + // Verify NAT Gateway has PIP reference + assert.Len(t, natGw.Properties.PublicIPAddresses, 1) + assert.Contains(t, *natGw.Properties.PublicIPAddresses[0].ID, "egress-uid-456-pip") + + // Verify ServicesDTO + assert.Len(t, servicesDTO.Services, 1) + assert.Contains(t, servicesDTO.Services[0].Service, "egress-uid-456") + assert.Equal(t, Outbound, servicesDTO.Services[0].ServiceType) +} + +func TestNewIgnoreCaseSetFromSlice_Empty(t *testing.T) { + set := newIgnoreCaseSetFromSlice([]string{}) + assert.NotNil(t, set) + assert.Equal(t, 0, set.Len()) +} + +func TestNewIgnoreCaseSetFromSlice_WithItems(t *testing.T) { + items := []string{"service1", "service2", "SERVICE3"} + set := newIgnoreCaseSetFromSlice(items) + + assert.Equal(t, 3, set.Len()) + assert.True(t, set.Has("service1")) + assert.True(t, set.Has("service2")) + assert.True(t, set.Has("service3")) // Case insensitive + assert.True(t, set.Has("SERVICE3")) +} + +func TestBuildInboundServiceResources_BackendPoolNaming(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + _, lb, _ := buildInboundServiceResources("my-service-uid", config, dtConfig) + + // Backend pool name must match serviceUID for SLB mode + assert.Len(t, lb.Properties.BackendAddressPools, 1) + backendPool := lb.Properties.BackendAddressPools[0] + assert.Equal(t, "my-service-uid", *backendPool.Name) + + // LB rule should reference the correct backend pool + rule := lb.Properties.LoadBalancingRules[0] + assert.Contains(t, *rule.Properties.BackendAddressPool.ID, "my-service-uid") +} + +func TestBuildInboundServiceResources_NoProbesForPodIPBackend(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + _, lb, _ := buildInboundServiceResources("service-uid", config, dtConfig) + + // For PodIP backend pools, no health probes should be created + assert.Empty(t, lb.Properties.Probes) + + // LB rules should have no probe reference + rule := lb.Properties.LoadBalancingRules[0] + assert.Nil(t, rule.Properties.Probe) +} + +func TestBuildInboundServiceResources_ResourceIDs(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + + dtConfig := Config{ + SubscriptionID: "sub-123", + ResourceGroup: "rg-456", + Location: "eastus", + ServiceGatewayResourceName: "sgw-789", + ServiceGatewayID: "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/serviceGateways/sgw-789", + } + + pip, lb, _ := buildInboundServiceResources("svc-abc", config, dtConfig) + + // Verify PIP ID format + expectedPIPID := "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/publicIPAddresses/svc-abc-pip" + assert.Equal(t, expectedPIPID, *pip.ID) + + // Verify LB references PIP correctly + frontendConfig := lb.Properties.FrontendIPConfigurations[0] + assert.Equal(t, expectedPIPID, *frontendConfig.Properties.PublicIPAddress.ID) + + // Verify backend pool ID reference in rule + rule := lb.Properties.LoadBalancingRules[0] + expectedBackendPoolID := "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/loadBalancers/svc-abc/backendAddressPools/svc-abc" + assert.Equal(t, expectedBackendPoolID, *rule.Properties.BackendAddressPool.ID) +} diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go index d99cb0ead8..4b014f1f2a 100644 --- a/pkg/provider/difftracker/types.go +++ b/pkg/provider/difftracker/types.go @@ -189,3 +189,172 @@ type SyncDiffTrackerReturnType struct { NATGatewayUpdates SyncServicesReturnType LocationData LocationData } + +// ================================================================================================ +// CP2: Service configuration types + ServiceGateway DTOs (deferred from CP1) +// ================================================================================================ + +// PortMapping represents a port mapping configuration +type PortMapping struct { + Port int32 + Protocol string // TCP or UDP +} + +// HealthProbeConfig represents health probe configuration +type HealthProbeConfig struct { + Protocol string // TCP, HTTP, or HTTPS + Port int32 + IntervalInSeconds int32 + NumberOfProbes int32 + RequestPath *string // For HTTP/HTTPS probes +} + +// InboundConfig contains Load Balancer configuration for inbound services +type InboundConfig struct { + FrontendPorts []PortMapping // nullable for future use + BackendPorts []PortMapping // nullable for future use + Protocol *string // TCP/UDP, nullable + IdleTimeoutMinutes *int32 // nullable + SessionPersistence *string // nullable + HealthProbe *HealthProbeConfig // nullable +} + +// OutboundConfig contains NAT Gateway configuration for outbound services +type OutboundConfig struct { + // Placeholder for future NAT Gateway options +} + +// Equals returns true if two InboundConfigs describe the same desired LB shape. +// Used by UpdateService to short-circuit no-op reconciles. +// Comparison is order-sensitive for FrontendPorts/BackendPorts because the +// position of a port determines its pairing with a backend port in +// buildInboundServiceResources. +func (c *InboundConfig) Equals(other *InboundConfig) bool { + if c == nil && other == nil { + return true + } + if c == nil || other == nil { + return false + } + if !portMappingsEqual(c.FrontendPorts, other.FrontendPorts) { + return false + } + if !portMappingsEqual(c.BackendPorts, other.BackendPorts) { + return false + } + if !strPtrEqual(c.Protocol, other.Protocol) { + return false + } + if !int32PtrEqual(c.IdleTimeoutMinutes, other.IdleTimeoutMinutes) { + return false + } + if !strPtrEqual(c.SessionPersistence, other.SessionPersistence) { + return false + } + if !healthProbeEqual(c.HealthProbe, other.HealthProbe) { + return false + } + return true +} + +func portMappingsEqual(a, b []PortMapping) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func strPtrEqual(a, b *string) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +func int32PtrEqual(a, b *int32) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +func healthProbeEqual(a, b *HealthProbeConfig) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + if a.Protocol != b.Protocol || a.Port != b.Port || + a.IntervalInSeconds != b.IntervalInSeconds || a.NumberOfProbes != b.NumberOfProbes { + return false + } + return strPtrEqual(a.RequestPath, b.RequestPath) +} + + +// -------------------------------------------------------------------------------- +// Data Transfer Objects (DTOs) for LocationData (following the ServiceGateway API documentation) +// -------------------------------------------------------------------------------- + +// AddressDTO represents the DTO for Address +type AddressDTO struct { + Address string `json:"Address"` + ServiceNames *utilsets.IgnoreCaseSet `json:"ServiceNames"` +} + +// LocationDTO represents the DTO for Location +type LocationDTO struct { + Location string `json:"Location"` + AddressUpdateAction UpdateAction `json:"AddressUpdateAction"` + Addresses []AddressDTO `json:"Addresses"` +} + +// LocationsDataDTO represents the DTO for LocationData +type LocationsDataDTO struct { + Action UpdateAction `json:"Action"` + Locations []LocationDTO `json:"Locations"` +} + +// ================================================================================================ +// Data Transfer Objects (DTOs) for ServiceData (following the ServiceGateway API documentation) +// ================================================================================================ + +type ServiceType string + +const ( + Inbound ServiceType = "Inbound" + Outbound ServiceType = "Outbound" +) + +type LoadBalancerBackendPoolDTO struct { + Id string `json:"Id"` +} + +type NatGatewayDTO struct { + Id string `json:"Id"` +} + +type ServiceDTO struct { + Service string `json:"Service"` + ServiceType ServiceType `json:"ServiceType"` + LoadBalancerBackendPools []LoadBalancerBackendPoolDTO `json:"LoadBalancerBackendPools"` + PublicNatGateway NatGatewayDTO `json:"PublicNatGateway"` + IsDelete bool +} + +type ServicesDataDTO struct { + Action UpdateAction `json:"Action"` + Services []ServiceDTO `json:"Services"` +} From e6032854462f1798453ecaacd0d5ca709a2b9baf Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Fri, 19 Jun 2026 11:47:26 +0000 Subject: [PATCH 15/18] CP2: harden ServiceGateway Azure operations layer Address review findings in the Azure operations layer: - Idempotent deletes (treat 404 as success); delete LB by UID without a cluster-wide Service list. - Service request mapping: no empty PublicNatGatewayID on removals, error on unknown ServiceType, allow empty FullUpdate, warn on unknown update action. - LB rules: only set EnableTCPReset on TCP, case-insensitive protocol with rejection of unsupported ones, frontend/backend port range validation, apply configured idle timeout. - Reject named targetPort instead of silently misrouting. - Preserve dual-stack/hostname ingress when programming the LB IP and retry on conflict. - Make NAT gateway disassociation idempotent and independent per side. - Nil-guard createOrUpdatePIPWithResponse; return "" for missing ID segments. --- pkg/provider/difftracker/azure_operations.go | 205 ++++++++++-------- .../difftracker/azure_operations_test.go | 128 +++++++++++ pkg/provider/difftracker/resource_helpers.go | 83 ++++--- .../difftracker/resource_helpers_test.go | 115 +++++++++- pkg/provider/difftracker/types.go | 1 - 5 files changed, 396 insertions(+), 136 deletions(-) create mode 100644 pkg/provider/difftracker/azure_operations_test.go diff --git a/pkg/provider/difftracker/azure_operations.go b/pkg/provider/difftracker/azure_operations.go index 6454b6f52a..db689d05db 100644 --- a/pkg/provider/difftracker/azure_operations.go +++ b/pkg/provider/difftracker/azure_operations.go @@ -19,14 +19,18 @@ package difftracker import ( "context" "fmt" + "net" "strings" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/util/retry" servicehelper "k8s.io/cloud-provider/service/helpers" "k8s.io/klog/v2" "k8s.io/utils/ptr" + + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" ) // createOrUpdatePIP creates or updates a public IP address @@ -38,6 +42,9 @@ func (dt *DiffTracker) createOrUpdatePIP(ctx context.Context, pipResourceGroup s // createOrUpdatePIPWithResponse creates or updates a public IP address and returns the response // containing the allocated IP address. Use this when you need the IP address after PIP creation. func (dt *DiffTracker) createOrUpdatePIPWithResponse(ctx context.Context, pipResourceGroup string, pip *armnetwork.PublicIPAddress) (*armnetwork.PublicIPAddress, error) { + if pip == nil { + return nil, fmt.Errorf("createOrUpdatePIPWithResponse: pip is nil") + } pipName := ptr.Deref(pip.Name, "") if pipName == "" { return nil, fmt.Errorf("createOrUpdatePIPWithResponse: pip name is empty") @@ -64,7 +71,7 @@ func (dt *DiffTracker) deletePublicIP(ctx context.Context, pipResourceGroup stri err := dt.networkClientFactory.GetPublicIPAddressClient().Delete(ctx, pipResourceGroup, pipName) klog.Infof("deletePublicIP(%s): end, error: %v", pipName, err) - if err != nil { + if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { klog.Warningf("deletePublicIP(%s) failed: %s", pipName, err.Error()) return err } @@ -92,29 +99,12 @@ func (dt *DiffTracker) createOrUpdateLB(ctx context.Context, lb armnetwork.LoadB // deleteLB deletes a load balancer by service UID func (dt *DiffTracker) deleteLB(ctx context.Context, uid string) error { - // Normalize uid = strings.ToLower(uid) + klog.V(3).Infof("deleteLB: deleting LoadBalancer %s", uid) - // Try to retrieve the live Service - svc, err := dt.getServiceByUID(ctx, uid) - if err != nil { - // Service not found - try direct cleanup - klog.V(3).Infof("deleteLB: service uid %s not found, attempting direct LoadBalancer deletion", uid) - - // Delete the load balancer directly by name (uid is the LB name) - if err := dt.networkClientFactory.GetLoadBalancerClient().Delete(ctx, dt.config.ResourceGroup, uid); err != nil { - klog.Warningf("deleteLB: failed to delete LoadBalancer %s directly: %v", uid, err) - return err - } - - klog.V(3).Infof("deleteLB: successfully deleted LoadBalancer %s directly", uid) - return nil - } - - // Service exists - use standard deletion - if err := dt.networkClientFactory.GetLoadBalancerClient().Delete(ctx, dt.config.ResourceGroup, uid); err != nil { - return fmt.Errorf("deleteLB: failed to delete LoadBalancer for service %s/%s (uid=%s): %w", - svc.Namespace, svc.Name, uid, err) + err := dt.networkClientFactory.GetLoadBalancerClient().Delete(ctx, dt.config.ResourceGroup, uid) + if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { + return fmt.Errorf("deleteLB: failed to delete LoadBalancer (uid=%s): %w", uid, err) } return nil @@ -147,7 +137,7 @@ func (dt *DiffTracker) deleteNatGateway(ctx context.Context, natGatewayResourceG klog.Infof("deleteNatGateway(%s) in resource group %s: start", natGatewayName, natGatewayResourceGroup) err := dt.networkClientFactory.GetNatGatewayClient().Delete(ctx, natGatewayResourceGroup, natGatewayName) - if err != nil { + if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { klog.Errorf("NatGatewayClient.Delete(%s) in resource group %s failed: %v", natGatewayName, natGatewayResourceGroup, err) return err } @@ -162,7 +152,7 @@ func (dt *DiffTracker) deleteNatGateway(ctx context.Context, natGatewayResourceG func (dt *DiffTracker) disassociateNatGatewayFromServiceGateway(ctx context.Context, serviceGatewayName string, natGatewayName string) error { klog.Infof("disassociateNatGatewayFromServiceGateway: Disassociating NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) - // Step 1: Get the service and remove the NAT gateway reference + // Step 1: clear the ServiceGateway-side reference if it still has one. services, err := dt.networkClientFactory.GetServiceGatewayClient().GetServices(ctx, dt.config.ResourceGroup, serviceGatewayName) if err != nil { klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to get Service Gateway %s: %v", serviceGatewayName, err) @@ -177,48 +167,41 @@ func (dt *DiffTracker) disassociateNatGatewayFromServiceGateway(ctx context.Cont } } - if serviceToBeUpdated == nil { - klog.Infof("disassociateNatGatewayFromServiceGateway: NAT Gateway %s is not associated with Service Gateway %s", natGatewayName, serviceGatewayName) - return nil - } - - // Remove the NAT gateway reference from the service - if serviceToBeUpdated.Properties != nil { + if serviceToBeUpdated != nil && serviceToBeUpdated.Properties != nil && serviceToBeUpdated.Properties.PublicNatGatewayID != nil { serviceToBeUpdated.Properties.PublicNatGatewayID = nil - } - - updateServicesRequest := armnetwork.ServiceGatewayUpdateServicesRequest{ - Action: ptr.To(armnetwork.ServiceUpdateActionPartialUpdate), - ServiceRequests: []*armnetwork.ServiceGatewayServiceRequest{ - { - IsDelete: ptr.To(false), - Service: serviceToBeUpdated, + updateServicesRequest := armnetwork.ServiceGatewayUpdateServicesRequest{ + Action: ptr.To(armnetwork.ServiceUpdateActionPartialUpdate), + ServiceRequests: []*armnetwork.ServiceGatewayServiceRequest{ + { + IsDelete: ptr.To(false), + Service: serviceToBeUpdated, + }, }, - }, - } - - err = dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, updateServicesRequest) - if err != nil { - klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update Service Gateway %s to disassociate NAT Gateway %s: %v", serviceGatewayName, natGatewayName, err) - return fmt.Errorf("failed to update Service Gateway: %w", err) + } + if err := dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, updateServicesRequest); err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update Service Gateway %s to disassociate NAT Gateway %s: %v", serviceGatewayName, natGatewayName, err) + return fmt.Errorf("failed to update Service Gateway: %w", err) + } } - klog.Infof("disassociateNatGatewayFromServiceGateway: Successfully removed NAT Gateway %s reference from Service Gateway %s", natGatewayName, serviceGatewayName) - // Step 2: Get the NAT gateway and remove the service gateway reference + // Step 2: clear the NAT-gateway-side reference if it still points back at the + // Service Gateway. This runs independently of Step 1 so a retry after a + // partial failure still reconciles the NAT gateway. natGateway, err := dt.networkClientFactory.GetNatGatewayClient().Get(ctx, dt.config.ResourceGroup, natGatewayName, nil) if err != nil { + if exists, cerr := errutils.CheckResourceExistsFromAzcoreError(err); cerr == nil && !exists { + return nil + } klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to get NAT Gateway %s: %v", natGatewayName, err) return fmt.Errorf("failed to get NAT Gateway: %w", err) } - if natGateway.Properties != nil { + if natGateway.Properties != nil && natGateway.Properties.ServiceGateway != nil { natGateway.Properties.ServiceGateway = nil - } - - _, err = dt.networkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, dt.config.ResourceGroup, natGatewayName, *natGateway) - if err != nil { - klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update NAT Gateway %s to remove Service Gateway reference: %v", natGatewayName, err) - return fmt.Errorf("failed to update NAT Gateway: %w", err) + if _, err := dt.networkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, dt.config.ResourceGroup, natGatewayName, *natGateway); err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update NAT Gateway %s to remove Service Gateway reference: %v", natGatewayName, err) + return fmt.Errorf("failed to update NAT Gateway: %w", err) + } } klog.Infof("disassociateNatGatewayFromServiceGateway: Successfully disassociated NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) @@ -227,8 +210,7 @@ func (dt *DiffTracker) disassociateNatGatewayFromServiceGateway(ctx context.Cont // updateNRPSGWServices updates services in the Service Gateway func (dt *DiffTracker) updateNRPSGWServices(ctx context.Context, serviceGatewayName string, updateServicesRequestDTO ServicesDataDTO) error { - // Early return if no services to update - if len(updateServicesRequestDTO.Services) == 0 { + if len(updateServicesRequestDTO.Services) == 0 && updateServicesRequestDTO.Action != FullUpdate { klog.Infof("updateNRPSGWServices(%s): no services to update", serviceGatewayName) return nil } @@ -236,12 +218,16 @@ func (dt *DiffTracker) updateNRPSGWServices(ctx context.Context, serviceGatewayN klog.Infof("updateNRPSGWServices(%s): start", serviceGatewayName) // Convert DTO to ARM SDK request + serviceRequests, err := convertServiceDTOsToServiceRequests(updateServicesRequestDTO.Services, dt.config) + if err != nil { + return fmt.Errorf("updateNRPSGWServices(%s): %w", serviceGatewayName, err) + } req := armnetwork.ServiceGatewayUpdateServicesRequest{ Action: convertServicesUpdateActionToARM(updateServicesRequestDTO.Action), - ServiceRequests: convertServiceDTOsToServiceRequests(updateServicesRequestDTO.Services, dt.config), + ServiceRequests: serviceRequests, } - err := dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, req) + err = dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, req) if err != nil { klog.Warningf("ServiceGatewayClient.UpdateServices(%s) failed: %v", serviceGatewayName, err) return err @@ -296,35 +282,65 @@ func (dt *DiffTracker) updateServiceLoadBalancerStatus(ctx context.Context, serv if ip == "" { return fmt.Errorf("updateServiceLoadBalancerStatus: ip is empty") } - - svc, err := dt.getServiceByUID(ctx, serviceUID) - if err != nil { - return fmt.Errorf("updateServiceLoadBalancerStatus: failed to get service: %w", err) + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: invalid ip %q", ip) } + newIsIPv4 := parsedIP.To4() != nil + + return retry.RetryOnConflict(retry.DefaultRetry, func() error { + svc, err := dt.getServiceByUID(ctx, serviceUID) + if err != nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: failed to get service: %w", err) + } + + desired := make([]v1.LoadBalancerIngress, 0, len(svc.Status.LoadBalancer.Ingress)+1) + newPresent := false + for _, ingress := range svc.Status.LoadBalancer.Ingress { + if ingress.IP == "" { + desired = append(desired, ingress) + continue + } + ingressIP := net.ParseIP(ingress.IP) + if ingressIP != nil && (ingressIP.To4() != nil) == newIsIPv4 { + if ingress.IP == ip { + newPresent = true + desired = append(desired, ingress) + } + continue + } + desired = append(desired, ingress) + } + if !newPresent { + desired = append(desired, v1.LoadBalancerIngress{IP: ip}) + } - // Check if the status already has this IP to avoid unnecessary updates - for _, ingress := range svc.Status.LoadBalancer.Ingress { - if ingress.IP == ip { + if loadBalancerIngressEqual(svc.Status.LoadBalancer.Ingress, desired) { klog.V(3).Infof("updateServiceLoadBalancerStatus: service %s/%s already has IP %s", svc.Namespace, svc.Name, ip) return nil } - } - // Make a copy so we don't mutate the shared informer cache - updated := svc.DeepCopy() - updated.Status.LoadBalancer = v1.LoadBalancerStatus{ - Ingress: []v1.LoadBalancerIngress{ - {IP: ip}, - }, - } + updated := svc.DeepCopy() + updated.Status.LoadBalancer.Ingress = desired + if _, err := servicehelper.PatchService(dt.kubeClient.CoreV1(), svc, updated); err != nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: failed to patch service: %w", err) + } - _, err = servicehelper.PatchService(dt.kubeClient.CoreV1(), svc, updated) - if err != nil { - return fmt.Errorf("updateServiceLoadBalancerStatus: failed to patch service: %w", err) - } + klog.V(2).Infof("updateServiceLoadBalancerStatus: updated service %s/%s with LoadBalancer IP %s", svc.Namespace, svc.Name, ip) + return nil + }) +} - klog.V(2).Infof("updateServiceLoadBalancerStatus: updated service %s/%s with LoadBalancer IP %s", svc.Namespace, svc.Name, ip) - return nil +func loadBalancerIngressEqual(a, b []v1.LoadBalancerIngress) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].IP != b[i].IP || a[i].Hostname != b[i].Hostname { + return false + } + } + return true } // Helper functions to convert DTOs to ARM SDK types @@ -336,6 +352,7 @@ func convertServicesUpdateActionToARM(action UpdateAction) *armnetwork.ServiceUp case FullUpdate: return ptr.To(armnetwork.ServiceUpdateActionFullUpdate) default: + klog.Warningf("convertServicesUpdateActionToARM: unknown UpdateAction %q, defaulting to PartialUpdate", action) return ptr.To(armnetwork.ServiceUpdateActionPartialUpdate) } } @@ -347,32 +364,27 @@ func convertLocationsUpdateActionToARM(action UpdateAction) *armnetwork.UpdateAc case FullUpdate: return ptr.To(armnetwork.UpdateActionFullUpdate) default: + klog.Warningf("convertLocationsUpdateActionToARM: unknown UpdateAction %q, defaulting to PartialUpdate", action) return ptr.To(armnetwork.UpdateActionPartialUpdate) } } -// extractResourceChildName extracts a child resource name from an Azure resource ID +// extractResourceChildName extracts a child resource name from an Azure resource ID. +// Returns "" if the requested segment is not present. func extractResourceChildName(id, segment string) string { if id == "" { return "" } parts := strings.Split(id, "/") - // Look for the explicit segment first for i := 0; i < len(parts)-1; i++ { if strings.EqualFold(parts[i], segment) && parts[i+1] != "" { return parts[i+1] } } - // Fallback: last non-empty - for i := len(parts) - 1; i >= 0; i-- { - if parts[i] != "" { - return parts[i] - } - } return "" } -func convertServiceDTOsToServiceRequests(services []ServiceDTO, config Config) []*armnetwork.ServiceGatewayServiceRequest { +func convertServiceDTOsToServiceRequests(services []ServiceDTO, config Config) ([]*armnetwork.ServiceGatewayServiceRequest, error) { serviceRequests := make([]*armnetwork.ServiceGatewayServiceRequest, 0, len(services)) // Construct VNet resource ID once for all backend pools @@ -405,7 +417,12 @@ func convertServiceDTOsToServiceRequests(services []ServiceDTO, config Config) [ publicNatGatewayID = nil case Outbound: serviceType = armnetwork.ServiceTypeOutbound - publicNatGatewayID = &svc.PublicNatGateway.Id + if !svc.IsDelete && svc.PublicNatGateway.Id != "" { + natID := svc.PublicNatGateway.Id + publicNatGatewayID = &natID + } + default: + return nil, fmt.Errorf("convertServiceDTOsToServiceRequests: unknown ServiceType %q for service %q", svc.ServiceType, svc.Service) } // Build and append the service request @@ -421,7 +438,7 @@ func convertServiceDTOsToServiceRequests(services []ServiceDTO, config Config) [ }, }) } - return serviceRequests + return serviceRequests, nil } func convertLocationDTOsToAddressLocations(locations []LocationDTO) []*armnetwork.ServiceGatewayAddressLocation { @@ -448,10 +465,8 @@ func convertLocationDTOsToAddressLocations(locations []LocationDTO) []*armnetwor // Convert service names - always initialize the slice to avoid null in JSON armAddr.Services = make([]*string, 0, addr.ServiceNames.Len()) - if addr.ServiceNames != nil && addr.ServiceNames.Len() > 0 { - for _, svcName := range addr.ServiceNames.UnsortedList() { - armAddr.Services = append(armAddr.Services, ptr.To(svcName)) - } + for _, svcName := range addr.ServiceNames.UnsortedList() { + armAddr.Services = append(armAddr.Services, ptr.To(svcName)) } armLoc.Addresses = append(armLoc.Addresses, armAddr) diff --git a/pkg/provider/difftracker/azure_operations_test.go b/pkg/provider/difftracker/azure_operations_test.go new file mode 100644 index 0000000000..94355f6163 --- /dev/null +++ b/pkg/provider/difftracker/azure_operations_test.go @@ -0,0 +1,128 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +func TestConvertServiceDTOsToServiceRequests_OutboundRemovalHasNoNatGatewayID(t *testing.T) { + reqs, err := convertServiceDTOsToServiceRequests([]ServiceDTO{ + {Service: "egr1", ServiceType: Outbound, IsDelete: true}, + }, Config{SubscriptionID: "sub", ResourceGroup: "rg", VNetName: "vnet"}) + assert.NoError(t, err) + assert.Len(t, reqs, 1) + assert.Nil(t, reqs[0].Service.Properties.PublicNatGatewayID) +} + +func TestConvertServiceDTOsToServiceRequests_OutboundAddHasNatGatewayID(t *testing.T) { + natID := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/natGateways/egr1" + reqs, err := convertServiceDTOsToServiceRequests([]ServiceDTO{ + {Service: "egr1", ServiceType: Outbound, PublicNatGateway: NatGatewayDTO{Id: natID}}, + }, Config{SubscriptionID: "sub", ResourceGroup: "rg", VNetName: "vnet"}) + assert.NoError(t, err) + assert.Len(t, reqs, 1) + if assert.NotNil(t, reqs[0].Service.Properties.PublicNatGatewayID) { + assert.Equal(t, natID, *reqs[0].Service.Properties.PublicNatGatewayID) + } +} + +func TestConvertServiceDTOsToServiceRequests_UnknownServiceTypeErrors(t *testing.T) { + _, err := convertServiceDTOsToServiceRequests([]ServiceDTO{ + {Service: "x", ServiceType: ServiceType("")}, + }, Config{SubscriptionID: "sub", ResourceGroup: "rg", VNetName: "vnet"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown ServiceType") +} + +func TestExtractResourceChildName(t *testing.T) { + id := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb1/backendAddressPools/pool1" + assert.Equal(t, "pool1", extractResourceChildName(id, "backendAddressPools")) + assert.Equal(t, "", extractResourceChildName("/subscriptions/sub/loadBalancers/lb1", "backendAddressPools")) + assert.Equal(t, "", extractResourceChildName("", "backendAddressPools")) +} + +func TestCreateOrUpdatePIPWithResponse_NilPip(t *testing.T) { + dt := &DiffTracker{} + _, err := dt.createOrUpdatePIPWithResponse(context.Background(), "rg", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "pip is nil") +} + +func TestUpdateServiceLoadBalancerStatus_PreservesDualStackAndHostname(t *testing.T) { + svc := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "svc", Namespace: "ns", UID: "uid-1"}, + Status: v1.ServiceStatus{ + LoadBalancer: v1.LoadBalancerStatus{ + Ingress: []v1.LoadBalancerIngress{ + {IP: "2001:db8::1"}, + {Hostname: "example.com"}, + }, + }, + }, + } + kubeClient := fake.NewSimpleClientset(svc) + dt := &DiffTracker{kubeClient: kubeClient} + + err := dt.updateServiceLoadBalancerStatus(context.Background(), "uid-1", "10.0.0.1") + assert.NoError(t, err) + + got, err := kubeClient.CoreV1().Services("ns").Get(context.Background(), "svc", metav1.GetOptions{}) + assert.NoError(t, err) + var ips, hosts []string + for _, ing := range got.Status.LoadBalancer.Ingress { + if ing.IP != "" { + ips = append(ips, ing.IP) + } + if ing.Hostname != "" { + hosts = append(hosts, ing.Hostname) + } + } + assert.Contains(t, ips, "10.0.0.1") + assert.Contains(t, ips, "2001:db8::1") + assert.Contains(t, hosts, "example.com") +} + +func TestUpdateServiceLoadBalancerStatus_ReplacesStaleSameFamily(t *testing.T) { + svc := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "svc", Namespace: "ns", UID: "uid-2"}, + Status: v1.ServiceStatus{ + LoadBalancer: v1.LoadBalancerStatus{ + Ingress: []v1.LoadBalancerIngress{{IP: "10.0.0.9"}}, + }, + }, + } + kubeClient := fake.NewSimpleClientset(svc) + dt := &DiffTracker{kubeClient: kubeClient} + + err := dt.updateServiceLoadBalancerStatus(context.Background(), "uid-2", "10.0.0.1") + assert.NoError(t, err) + + got, err := kubeClient.CoreV1().Services("ns").Get(context.Background(), "svc", metav1.GetOptions{}) + assert.NoError(t, err) + var ips []string + for _, ing := range got.Status.LoadBalancer.Ingress { + ips = append(ips, ing.IP) + } + assert.Equal(t, []string{"10.0.0.1"}, ips) +} diff --git a/pkg/provider/difftracker/resource_helpers.go b/pkg/provider/difftracker/resource_helpers.go index ce4715995b..8aa0cea0c7 100644 --- a/pkg/provider/difftracker/resource_helpers.go +++ b/pkg/provider/difftracker/resource_helpers.go @@ -7,7 +7,9 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/klog/v2" + "sigs.k8s.io/cloud-provider-azure/pkg/consts" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -18,6 +20,7 @@ func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtCo pip armnetwork.PublicIPAddress, lb armnetwork.LoadBalancer, servicesDTO ServicesDataDTO, + err error, ) { pipName := fmt.Sprintf("%s-pip", serviceUID) @@ -57,38 +60,59 @@ func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtCo var probes []*armnetwork.Probe if config != nil && len(config.FrontendPorts) > 0 { - // For SLB with PodIP backend pool, we disable floating IP and don't create health probes - // Traffic goes directly to pod IPs on the backend port + idleTimeout := int32(4) + if config.IdleTimeoutMinutes != nil { + idleTimeout = *config.IdleTimeoutMinutes + if idleTimeout < 4 || idleTimeout > 30 { + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: idle timeout %d out of range (4-30) for service %s", idleTimeout, serviceUID) + } + } for i, frontendPort := range config.FrontendPorts { backendPort := frontendPort.Port if i < len(config.BackendPorts) { backendPort = config.BackendPorts[i].Port } - protocol := armnetwork.TransportProtocolTCP - if frontendPort.Protocol == "UDP" { + if frontendPort.Port < 1 || frontendPort.Port > 65534 { + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: frontend port %d out of range (1-65534) for service %s", frontendPort.Port, serviceUID) + } + if backendPort < 1 || backendPort > 65534 { + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: backend port %d out of range (1-65534) for service %s", backendPort, serviceUID) + } + + var protocol armnetwork.TransportProtocol + switch { + case strings.EqualFold(frontendPort.Protocol, "TCP"): + protocol = armnetwork.TransportProtocolTCP + case strings.EqualFold(frontendPort.Protocol, "UDP"): protocol = armnetwork.TransportProtocolUDP + default: + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: unsupported protocol %q for service %s", frontendPort.Protocol, serviceUID) } ruleName := fmt.Sprintf("rule-%s-%d", strings.ToLower(frontendPort.Protocol), frontendPort.Port) - lbRules = append(lbRules, &armnetwork.LoadBalancingRule{ - Name: to.Ptr(ruleName), - Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ - Protocol: to.Ptr(protocol), - FrontendPort: to.Ptr(frontendPort.Port), - BackendPort: to.Ptr(backendPort), - EnableFloatingIP: to.Ptr(false), // Disabled for PodIP backend - IdleTimeoutInMinutes: to.Ptr(int32(4)), - EnableTCPReset: to.Ptr(true), - FrontendIPConfiguration: &armnetwork.SubResource{ - ID: to.Ptr(frontendIPConfigID), - }, - BackendAddressPool: &armnetwork.SubResource{ - ID: to.Ptr(backendPoolID), - }, - // No probe for PodIP backend pools + ruleProps := &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + FrontendPort: to.Ptr(frontendPort.Port), + BackendPort: to.Ptr(backendPort), + EnableFloatingIP: to.Ptr(false), // Disabled for PodIP backend + IdleTimeoutInMinutes: to.Ptr(idleTimeout), + FrontendIPConfiguration: &armnetwork.SubResource{ + ID: to.Ptr(frontendIPConfigID), + }, + BackendAddressPool: &armnetwork.SubResource{ + ID: to.Ptr(backendPoolID), }, + // No probe for PodIP backend pools + } + if protocol == armnetwork.TransportProtocolTCP { + ruleProps.EnableTCPReset = to.Ptr(true) + } + + lbRules = append(lbRules, &armnetwork.LoadBalancingRule{ + Name: to.Ptr(ruleName), + Properties: ruleProps, }) klog.V(4).Infof("buildInboundServiceResources: created LB rule %s: frontend=%d backend=%d protocol=%s for service %s", @@ -138,7 +162,7 @@ func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtCo dtConfig.ResourceGroup, ) - return pip, lb, servicesDTO + return pip, lb, servicesDTO, nil } // buildOutboundServiceResources constructs the PIP, NAT Gateway, and ServicesDTO for an outbound service @@ -241,14 +265,15 @@ func ExtractInboundConfigFromService(service *v1.Service) *InboundConfig { // For SLB with PodIP backend, we use TargetPort // If TargetPort is not specified, default to Port backendPort := port.Port - if port.TargetPort.Type == 0 && port.TargetPort.IntVal > 0 { // intstr.Int - backendPort = port.TargetPort.IntVal - } else if port.TargetPort.Type == 1 { // intstr.String - // Named ports not supported in SLB mode - use Port as fallback - serviceName := fmt.Sprintf("%s/%s", service.Namespace, service.Name) - klog.V(2).Infof("Named targetPort %s not supported in SLB mode for service %s, using Port %d", - port.TargetPort.StrVal, serviceName, port.Port) - backendPort = port.Port + switch port.TargetPort.Type { + case intstr.Int: + if port.TargetPort.IntVal > 0 { + backendPort = port.TargetPort.IntVal + } + case intstr.String: + klog.Warningf("ExtractInboundConfigFromService: named targetPort %q is not supported for service %s/%s; rejecting", + port.TargetPort.StrVal, service.Namespace, service.Name) + return nil } config.BackendPorts = append(config.BackendPorts, PortMapping{ diff --git a/pkg/provider/difftracker/resource_helpers_test.go b/pkg/provider/difftracker/resource_helpers_test.go index d8ad99d90d..eda01ded30 100644 --- a/pkg/provider/difftracker/resource_helpers_test.go +++ b/pkg/provider/difftracker/resource_helpers_test.go @@ -159,11 +159,7 @@ func TestExtractInboundConfigFromService_NamedTargetPort(t *testing.T) { } config := ExtractInboundConfigFromService(service) - assert.NotNil(t, config) - - // Named ports should fall back to Port - assert.Equal(t, int32(80), config.FrontendPorts[0].Port) - assert.Equal(t, int32(80), config.BackendPorts[0].Port) + assert.Nil(t, config) } func TestExtractInboundConfigFromService_EmptyProtocol(t *testing.T) { @@ -211,7 +207,8 @@ func TestBuildInboundServiceResources_WithConfig(t *testing.T) { ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", } - pip, lb, servicesDTO := buildInboundServiceResources("service-uid-123", config, dtConfig) + pip, lb, servicesDTO, err := buildInboundServiceResources("service-uid-123", config, dtConfig) + assert.NoError(t, err) // Verify PIP assert.NotNil(t, pip.Name) @@ -261,7 +258,8 @@ func TestBuildInboundServiceResources_NilConfig(t *testing.T) { ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", } - pip, lb, servicesDTO := buildInboundServiceResources("service-uid-123", nil, dtConfig) + pip, lb, servicesDTO, err := buildInboundServiceResources("service-uid-123", nil, dtConfig) + assert.NoError(t, err) // Should still create LB, just without rules assert.NotNil(t, lb.Name) @@ -297,7 +295,8 @@ func TestBuildInboundServiceResources_UDPProtocol(t *testing.T) { ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", } - _, lb, _ := buildInboundServiceResources("service-uid-udp", config, dtConfig) + _, lb, _, err := buildInboundServiceResources("service-uid-udp", config, dtConfig) + assert.NoError(t, err) // Verify UDP rule assert.Len(t, lb.Properties.LoadBalancingRules, 1) @@ -305,6 +304,7 @@ func TestBuildInboundServiceResources_UDPProtocol(t *testing.T) { assert.Equal(t, "rule-udp-53", *rule.Name) assert.Equal(t, armnetwork.TransportProtocolUDP, *rule.Properties.Protocol) assert.Equal(t, int32(53), *rule.Properties.FrontendPort) + assert.Nil(t, rule.Properties.EnableTCPReset) assert.Equal(t, int32(5353), *rule.Properties.BackendPort) } @@ -376,7 +376,8 @@ func TestBuildInboundServiceResources_BackendPoolNaming(t *testing.T) { ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", } - _, lb, _ := buildInboundServiceResources("my-service-uid", config, dtConfig) + _, lb, _, err := buildInboundServiceResources("my-service-uid", config, dtConfig) + assert.NoError(t, err) // Backend pool name must match serviceUID for SLB mode assert.Len(t, lb.Properties.BackendAddressPools, 1) @@ -402,7 +403,8 @@ func TestBuildInboundServiceResources_NoProbesForPodIPBackend(t *testing.T) { ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", } - _, lb, _ := buildInboundServiceResources("service-uid", config, dtConfig) + _, lb, _, err := buildInboundServiceResources("service-uid", config, dtConfig) + assert.NoError(t, err) // For PodIP backend pools, no health probes should be created assert.Empty(t, lb.Properties.Probes) @@ -426,7 +428,8 @@ func TestBuildInboundServiceResources_ResourceIDs(t *testing.T) { ServiceGatewayID: "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/serviceGateways/sgw-789", } - pip, lb, _ := buildInboundServiceResources("svc-abc", config, dtConfig) + pip, lb, _, err := buildInboundServiceResources("svc-abc", config, dtConfig) + assert.NoError(t, err) // Verify PIP ID format expectedPIPID := "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/publicIPAddresses/svc-abc-pip" @@ -441,3 +444,93 @@ func TestBuildInboundServiceResources_ResourceIDs(t *testing.T) { expectedBackendPoolID := "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/loadBalancers/svc-abc/backendAddressPools/svc-abc" assert.Equal(t, expectedBackendPoolID, *rule.Properties.BackendAddressPool.ID) } + +func TestBuildInboundServiceResources_LowercaseUDP(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 53, Protocol: "udp"}}, + BackendPorts: []PortMapping{{Port: 5353, Protocol: "udp"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + assert.Equal(t, armnetwork.TransportProtocolUDP, *lb.Properties.LoadBalancingRules[0].Properties.Protocol) +} + +func TestBuildInboundServiceResources_UnsupportedProtocolErrors(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 53, Protocol: "SCTP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported protocol") +} + +func TestBuildInboundServiceResources_PortOutOfRangeErrors(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 65535, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "out of range") +} + +func TestBuildInboundServiceResources_TCPHasResetEnabled(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + if assert.NotNil(t, lb.Properties.LoadBalancingRules[0].Properties.EnableTCPReset) { + assert.True(t, *lb.Properties.LoadBalancingRules[0].Properties.EnableTCPReset) + } +} + +func TestBuildInboundServiceResources_BackendPortOutOfRangeErrors(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 65535, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "backend port") +} + +func TestBuildInboundServiceResources_AppliesIdleTimeout(t *testing.T) { + idle := int32(30) + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + IdleTimeoutMinutes: &idle, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + assert.Equal(t, int32(30), *lb.Properties.LoadBalancingRules[0].Properties.IdleTimeoutInMinutes) +} + +func TestBuildInboundServiceResources_IdleTimeoutOutOfRangeErrors(t *testing.T) { + idle := int32(99) + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + IdleTimeoutMinutes: &idle, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "idle timeout") +} diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go index 4b014f1f2a..3ac5611801 100644 --- a/pkg/provider/difftracker/types.go +++ b/pkg/provider/difftracker/types.go @@ -303,7 +303,6 @@ func healthProbeEqual(a, b *HealthProbeConfig) bool { return strPtrEqual(a.RequestPath, b.RequestPath) } - // -------------------------------------------------------------------------------- // Data Transfer Objects (DTOs) for LocationData (following the ServiceGateway API documentation) // -------------------------------------------------------------------------------- From 54bdc2d4856cc925a6f193f26e3ec5a76c9c0e8b Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Wed, 24 Jun 2026 17:06:14 +0000 Subject: [PATCH 16/18] Enhance Azure operations: add LoadBalancerSKUNameService constant, improve logging verbosity, and update tests for backend port validation --- pkg/consts/consts.go | 3 + pkg/provider/difftracker/azure_operations.go | 93 +++++++++++-------- pkg/provider/difftracker/dto_mappers.go | 1 - pkg/provider/difftracker/resource_helpers.go | 50 +++++----- .../difftracker/resource_helpers_test.go | 44 ++++----- 5 files changed, 107 insertions(+), 84 deletions(-) diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 7211b6ebc7..523639fc89 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -243,6 +243,9 @@ const ( LoadBalancerSKUStandard = "standard" // LoadBalancerSKUService is the load balancer service SKU LoadBalancerSKUService = "service" + // LoadBalancerSKUNameService is the case-sensitive ARM SKU name for the + // ServiceGateway inbound load balancer; it must be sent as "Service". + LoadBalancerSKUNameService = "Service" // ServiceAnnotationLoadBalancerInternal is the annotation used on the service ServiceAnnotationLoadBalancerInternal = "service.beta.kubernetes.io/azure-load-balancer-internal" diff --git a/pkg/provider/difftracker/azure_operations.go b/pkg/provider/difftracker/azure_operations.go index db689d05db..d12144c085 100644 --- a/pkg/provider/difftracker/azure_operations.go +++ b/pkg/provider/difftracker/azure_operations.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + // Package difftracker provides state tracking and synchronization between Kubernetes // resources and Azure Network Resource Provider (NRP) resources. // @@ -24,6 +40,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" v1 "k8s.io/api/core/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/util/retry" servicehelper "k8s.io/cloud-provider/service/helpers" @@ -49,10 +66,10 @@ func (dt *DiffTracker) createOrUpdatePIPWithResponse(ctx context.Context, pipRes if pipName == "" { return nil, fmt.Errorf("createOrUpdatePIPWithResponse: pip name is empty") } - klog.Infof("createOrUpdatePIPWithResponse(%s): start", pipName) + klog.V(4).Infof("createOrUpdatePIPWithResponse(%s): start", pipName) response, err := dt.networkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(ctx, pipResourceGroup, pipName, *pip) - klog.V(10).Infof("PublicIPAddressClient.CreateOrUpdate(%s, %s): end", pipResourceGroup, pipName) + klog.V(4).Infof("PublicIPAddressClient.CreateOrUpdate(%s, %s): end", pipResourceGroup, pipName) if err != nil { klog.Warningf("PublicIPAddressClient.CreateOrUpdate(%s, %s) failed: %s", pipResourceGroup, pipName, err.Error()) return nil, err @@ -66,10 +83,10 @@ func (dt *DiffTracker) deletePublicIP(ctx context.Context, pipResourceGroup stri if pipName == "" { return fmt.Errorf("deletePublicIP: pipName is empty") } - klog.Infof("deletePublicIP(%s): start", pipName) + klog.V(4).Infof("deletePublicIP(%s): start", pipName) err := dt.networkClientFactory.GetPublicIPAddressClient().Delete(ctx, pipResourceGroup, pipName) - klog.Infof("deletePublicIP(%s): end, error: %v", pipName, err) + klog.V(4).Infof("deletePublicIP(%s): end, error: %v", pipName, err) if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { klog.Warningf("deletePublicIP(%s) failed: %s", pipName, err.Error()) @@ -85,10 +102,10 @@ func (dt *DiffTracker) createOrUpdateLB(ctx context.Context, lb armnetwork.LoadB if lbName == "" { return fmt.Errorf("createOrUpdateLB: load balancer name is empty") } - klog.Infof("createOrUpdateLB(%s): start", lbName) + klog.V(4).Infof("createOrUpdateLB(%s): start", lbName) _, err := dt.networkClientFactory.GetLoadBalancerClient().CreateOrUpdate(ctx, dt.config.ResourceGroup, lbName, lb) - klog.V(10).Infof("LoadBalancerClient.CreateOrUpdate(%s): end", lbName) + klog.V(4).Infof("LoadBalancerClient.CreateOrUpdate(%s): end", lbName) if err != nil { klog.Warningf("LoadBalancerClient.CreateOrUpdate(%s) failed: %v", lbName, err) return err @@ -100,7 +117,7 @@ func (dt *DiffTracker) createOrUpdateLB(ctx context.Context, lb armnetwork.LoadB // deleteLB deletes a load balancer by service UID func (dt *DiffTracker) deleteLB(ctx context.Context, uid string) error { uid = strings.ToLower(uid) - klog.V(3).Infof("deleteLB: deleting LoadBalancer %s", uid) + klog.V(4).Infof("deleteLB: deleting LoadBalancer %s", uid) err := dt.networkClientFactory.GetLoadBalancerClient().Delete(ctx, dt.config.ResourceGroup, uid) if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { @@ -116,7 +133,7 @@ func (dt *DiffTracker) createOrUpdateNatGateway(ctx context.Context, natGatewayR if natGatewayName == "" { return fmt.Errorf("createOrUpdateNatGateway: NAT gateway name is empty") } - klog.Infof("createOrUpdateNatGateway(%s): start", natGatewayName) + klog.V(4).Infof("createOrUpdateNatGateway(%s): start", natGatewayName) _, err := dt.networkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, natGatewayResourceGroup, natGatewayName, natGateway) if err != nil { @@ -124,8 +141,8 @@ func (dt *DiffTracker) createOrUpdateNatGateway(ctx context.Context, natGatewayR return err } - klog.V(10).Infof("NatGatewayClient.CreateOrUpdate(%s): success", natGatewayName) - klog.Infof("createOrUpdateNatGateway(%s): end, error: nil", natGatewayName) + klog.V(4).Infof("NatGatewayClient.CreateOrUpdate(%s): success", natGatewayName) + klog.V(4).Infof("createOrUpdateNatGateway(%s): end, error: nil", natGatewayName) return nil } @@ -134,7 +151,7 @@ func (dt *DiffTracker) deleteNatGateway(ctx context.Context, natGatewayResourceG if natGatewayName == "" { return fmt.Errorf("deleteNatGateway: NAT gateway name is empty") } - klog.Infof("deleteNatGateway(%s) in resource group %s: start", natGatewayName, natGatewayResourceGroup) + klog.V(4).Infof("deleteNatGateway(%s) in resource group %s: start", natGatewayName, natGatewayResourceGroup) err := dt.networkClientFactory.GetNatGatewayClient().Delete(ctx, natGatewayResourceGroup, natGatewayName) if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { @@ -142,15 +159,15 @@ func (dt *DiffTracker) deleteNatGateway(ctx context.Context, natGatewayResourceG return err } - klog.V(10).Infof("NatGatewayClient.Delete(%s) in resource group %s: success", natGatewayName, natGatewayResourceGroup) - klog.Infof("deleteNatGateway(%s) in resource group %s: end, error: nil", natGatewayName, natGatewayResourceGroup) + klog.V(4).Infof("NatGatewayClient.Delete(%s) in resource group %s: success", natGatewayName, natGatewayResourceGroup) + klog.V(4).Infof("deleteNatGateway(%s) in resource group %s: end, error: nil", natGatewayName, natGatewayResourceGroup) return nil } // disassociateNatGatewayFromServiceGateway removes the NAT gateway association from the Service Gateway // This should be called before deleting the NAT gateway to properly clean up the references func (dt *DiffTracker) disassociateNatGatewayFromServiceGateway(ctx context.Context, serviceGatewayName string, natGatewayName string) error { - klog.Infof("disassociateNatGatewayFromServiceGateway: Disassociating NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) + klog.V(2).Infof("disassociateNatGatewayFromServiceGateway: Disassociating NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) // Step 1: clear the ServiceGateway-side reference if it still has one. services, err := dt.networkClientFactory.GetServiceGatewayClient().GetServices(ctx, dt.config.ResourceGroup, serviceGatewayName) @@ -204,18 +221,18 @@ func (dt *DiffTracker) disassociateNatGatewayFromServiceGateway(ctx context.Cont } } - klog.Infof("disassociateNatGatewayFromServiceGateway: Successfully disassociated NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) + klog.V(2).Infof("disassociateNatGatewayFromServiceGateway: Successfully disassociated NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) return nil } // updateNRPSGWServices updates services in the Service Gateway func (dt *DiffTracker) updateNRPSGWServices(ctx context.Context, serviceGatewayName string, updateServicesRequestDTO ServicesDataDTO) error { if len(updateServicesRequestDTO.Services) == 0 && updateServicesRequestDTO.Action != FullUpdate { - klog.Infof("updateNRPSGWServices(%s): no services to update", serviceGatewayName) + klog.V(4).Infof("updateNRPSGWServices(%s): no services to update", serviceGatewayName) return nil } - klog.Infof("updateNRPSGWServices(%s): start", serviceGatewayName) + klog.V(4).Infof("updateNRPSGWServices(%s): start", serviceGatewayName) // Convert DTO to ARM SDK request serviceRequests, err := convertServiceDTOsToServiceRequests(updateServicesRequestDTO.Services, dt.config) @@ -233,14 +250,14 @@ func (dt *DiffTracker) updateNRPSGWServices(ctx context.Context, serviceGatewayN return err } - klog.V(10).Infof("ServiceGatewayClient.UpdateServices(%s): success", serviceGatewayName) - klog.Infof("updateNRPSGWServices(%s): end, error: nil", serviceGatewayName) + klog.V(4).Infof("ServiceGatewayClient.UpdateServices(%s): success", serviceGatewayName) + klog.V(4).Infof("updateNRPSGWServices(%s): end, error: nil", serviceGatewayName) return nil } // updateNRPSGWAddressLocations updates address locations in the Service Gateway func (dt *DiffTracker) updateNRPSGWAddressLocations(ctx context.Context, serviceGatewayName string, locationsDTO LocationsDataDTO) error { - klog.Infof("updateNRPSGWAddressLocations(%s): start", serviceGatewayName) + klog.V(4).Infof("updateNRPSGWAddressLocations(%s): start", serviceGatewayName) // Convert DTO to ARM SDK request req := armnetwork.ServiceGatewayUpdateAddressLocationsRequest{ @@ -254,14 +271,19 @@ func (dt *DiffTracker) updateNRPSGWAddressLocations(ctx context.Context, service return err } - klog.V(10).Infof("ServiceGatewayClient.UpdateAddressLocations(%s): success", serviceGatewayName) - klog.Infof("updateNRPSGWAddressLocations(%s): end, error: nil", serviceGatewayName) + klog.V(4).Infof("ServiceGatewayClient.UpdateAddressLocations(%s): success", serviceGatewayName) + klog.V(4).Infof("updateNRPSGWAddressLocations(%s): end, error: nil", serviceGatewayName) return nil } // getServiceByUID returns the Service whose UID matches the given uid func (dt *DiffTracker) getServiceByUID(ctx context.Context, uid string) (*v1.Service, error) { - // list via client (could be expensive; acceptable for initialization) + // This lists all Services and scans for a UID match, which is expensive. It is + // acceptable for now since it runs once per service operation (plus conflict retries), + // not in a hot path. + // TODO: maintain a UID -> namespace/name map from the Service informer events so this + // becomes an O(1) lookup plus a direct Services(ns).Get(name) (or lister-cache read) + // instead of a NamespaceAll list. svcList, err := dt.kubeClient.CoreV1().Services(v1.NamespaceAll).List(ctx, metav1.ListOptions{}) if err != nil { return nil, fmt.Errorf("getServiceByUID: list failed: %w", err) @@ -275,9 +297,10 @@ func (dt *DiffTracker) getServiceByUID(ctx context.Context, uid string) (*v1.Ser } // updateServiceLoadBalancerStatus updates the K8s Service status with the LoadBalancer IP address. -// This is called after the PIP is successfully created in ServiceGateway mode to populate -// the Service.Status.LoadBalancer.Ingress field, which would otherwise be empty since -// EnsureLoadBalancer returns immediately in async mode. +// EnsureLoadBalancer returns an empty LoadBalancer status immediately while the difftracker +// engine provisions the PIP, LB and ServiceGateway registration asynchronously in the +// background. This function backfills Service.Status.LoadBalancer.Ingress once the PIP is +// created, since it would otherwise stay empty. func (dt *DiffTracker) updateServiceLoadBalancerStatus(ctx context.Context, serviceUID string, ip string) error { if ip == "" { return fmt.Errorf("updateServiceLoadBalancerStatus: ip is empty") @@ -296,6 +319,10 @@ func (dt *DiffTracker) updateServiceLoadBalancerStatus(ctx context.Context, serv desired := make([]v1.LoadBalancerIngress, 0, len(svc.Status.LoadBalancer.Ingress)+1) newPresent := false + // Rebuild the ingress list, keeping it dual-stack safe: preserve non-IP (hostname-only) + // entries and entries of the other IP family untouched, while replacing any stale + // same-family IP with the new one. If the new IP is already present we keep it and + // avoid appending a duplicate below. for _, ingress := range svc.Status.LoadBalancer.Ingress { if ingress.IP == "" { desired = append(desired, ingress) @@ -315,7 +342,7 @@ func (dt *DiffTracker) updateServiceLoadBalancerStatus(ctx context.Context, serv desired = append(desired, v1.LoadBalancerIngress{IP: ip}) } - if loadBalancerIngressEqual(svc.Status.LoadBalancer.Ingress, desired) { + if apiequality.Semantic.DeepEqual(svc.Status.LoadBalancer.Ingress, desired) { klog.V(3).Infof("updateServiceLoadBalancerStatus: service %s/%s already has IP %s", svc.Namespace, svc.Name, ip) return nil } @@ -331,18 +358,6 @@ func (dt *DiffTracker) updateServiceLoadBalancerStatus(ctx context.Context, serv }) } -func loadBalancerIngressEqual(a, b []v1.LoadBalancerIngress) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i].IP != b[i].IP || a[i].Hostname != b[i].Hostname { - return false - } - } - return true -} - // Helper functions to convert DTOs to ARM SDK types func convertServicesUpdateActionToARM(action UpdateAction) *armnetwork.ServiceUpdateAction { diff --git a/pkg/provider/difftracker/dto_mappers.go b/pkg/provider/difftracker/dto_mappers.go index 0efc1779bf..dc85795a11 100644 --- a/pkg/provider/difftracker/dto_mappers.go +++ b/pkg/provider/difftracker/dto_mappers.go @@ -38,7 +38,6 @@ func MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO(loadBalancerUpdates Sy resourceGroup, service, service, - // fmt.Sprintf("%s-backendpool", service), ), }, }, diff --git a/pkg/provider/difftracker/resource_helpers.go b/pkg/provider/difftracker/resource_helpers.go index 8aa0cea0c7..b9b40bfd8c 100644 --- a/pkg/provider/difftracker/resource_helpers.go +++ b/pkg/provider/difftracker/resource_helpers.go @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package difftracker import ( @@ -55,9 +71,9 @@ func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtCo }, } - // Build LB rules and probes from config + // Build LB rules from config. No health probes are created: PodIP backend pools + // don't support them, so the LB's Probes field is left nil. var lbRules []*armnetwork.LoadBalancingRule - var probes []*armnetwork.Probe if config != nil && len(config.FrontendPorts) > 0 { idleTimeout := int32(4) @@ -76,8 +92,8 @@ func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtCo if frontendPort.Port < 1 || frontendPort.Port > 65534 { return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: frontend port %d out of range (1-65534) for service %s", frontendPort.Port, serviceUID) } - if backendPort < 1 || backendPort > 65534 { - return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: backend port %d out of range (1-65534) for service %s", backendPort, serviceUID) + if backendPort < 1 || backendPort > 65535 { + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: backend port %d out of range (1-65535) for service %s", backendPort, serviceUID) } var protocol armnetwork.TransportProtocol @@ -127,7 +143,7 @@ func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtCo ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s", dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID)), Location: to.Ptr(dtConfig.Location), SKU: &armnetwork.LoadBalancerSKU{ - Name: to.Ptr(armnetwork.LoadBalancerSKUName(consts.LoadBalancerSKUService)), + Name: to.Ptr(armnetwork.LoadBalancerSKUName(consts.LoadBalancerSKUNameService)), }, Properties: &armnetwork.LoadBalancerPropertiesFormat{ Scope: to.Ptr(armnetwork.LoadBalancerScopePublic), @@ -144,14 +160,13 @@ func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtCo }, BackendAddressPools: backendPools, LoadBalancingRules: lbRules, - Probes: probes, }, } // Build ServicesDTO for ServiceGateway registration servicesDTO = MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( SyncServicesReturnType{ - Additions: newIgnoreCaseSetFromSlice([]string{serviceUID}), + Additions: utilsets.NewString(serviceUID), Removals: nil, }, SyncServicesReturnType{ @@ -217,7 +232,7 @@ func buildOutboundServiceResources(serviceUID string, config *OutboundConfig, dt Removals: nil, }, SyncServicesReturnType{ - Additions: newIgnoreCaseSetFromSlice([]string{serviceUID}), + Additions: utilsets.NewString(serviceUID), Removals: nil, }, dtConfig.SubscriptionID, @@ -227,15 +242,6 @@ func buildOutboundServiceResources(serviceUID string, config *OutboundConfig, dt return pip, natGateway, servicesDTO } -// newIgnoreCaseSetFromSlice creates an IgnoreCaseSet from a slice of strings -func newIgnoreCaseSetFromSlice(items []string) *utilsets.IgnoreCaseSet { - set := utilsets.NewString() - for _, item := range items { - set.Insert(item) - } - return set -} - // ExtractInboundConfigFromService creates InboundConfig from a Kubernetes Service // This is shared between initialization and the provider layer func ExtractInboundConfigFromService(service *v1.Service) *InboundConfig { @@ -271,9 +277,9 @@ func ExtractInboundConfigFromService(service *v1.Service) *InboundConfig { backendPort = port.TargetPort.IntVal } case intstr.String: - klog.Warningf("ExtractInboundConfigFromService: named targetPort %q is not supported for service %s/%s; rejecting", - port.TargetPort.StrVal, service.Namespace, service.Name) - return nil + klog.Warningf("ExtractInboundConfigFromService: named targetPort %q is not supported for service %s/%s; falling back to port %d", + port.TargetPort.StrVal, service.Namespace, service.Name, port.Port) + backendPort = port.Port } config.BackendPorts = append(config.BackendPorts, PortMapping{ @@ -301,7 +307,7 @@ func buildServiceGatewayRemovalDTO(serviceUID string, isInbound bool, dtConfig C return MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( SyncServicesReturnType{ Additions: nil, - Removals: newIgnoreCaseSetFromSlice([]string{serviceUID}), + Removals: utilsets.NewString(serviceUID), }, SyncServicesReturnType{ Additions: nil, @@ -318,7 +324,7 @@ func buildServiceGatewayRemovalDTO(serviceUID string, isInbound bool, dtConfig C }, SyncServicesReturnType{ Additions: nil, - Removals: newIgnoreCaseSetFromSlice([]string{serviceUID}), + Removals: utilsets.NewString(serviceUID), }, dtConfig.SubscriptionID, dtConfig.ResourceGroup, diff --git a/pkg/provider/difftracker/resource_helpers_test.go b/pkg/provider/difftracker/resource_helpers_test.go index eda01ded30..2c5ef31166 100644 --- a/pkg/provider/difftracker/resource_helpers_test.go +++ b/pkg/provider/difftracker/resource_helpers_test.go @@ -3,8 +3,6 @@ package difftracker import ( "testing" - "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" "github.com/stretchr/testify/assert" v1 "k8s.io/api/core/v1" @@ -159,7 +157,12 @@ func TestExtractInboundConfigFromService_NamedTargetPort(t *testing.T) { } config := ExtractInboundConfigFromService(service) - assert.Nil(t, config) + assert.NotNil(t, config) + assert.Len(t, config.FrontendPorts, 1) + assert.Len(t, config.BackendPorts, 1) + + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, config.FrontendPorts[0].Port, config.BackendPorts[0].Port) } func TestExtractInboundConfigFromService_EmptyProtocol(t *testing.T) { @@ -219,7 +222,7 @@ func TestBuildInboundServiceResources_WithConfig(t *testing.T) { // Verify LoadBalancer assert.NotNil(t, lb.Name) assert.Equal(t, "service-uid-123", *lb.Name) - assert.Equal(t, armnetwork.LoadBalancerSKUName(consts.LoadBalancerSKUService), *lb.SKU.Name) + assert.Equal(t, "Service", string(*lb.SKU.Name)) assert.Equal(t, "eastus", *lb.Location) // Verify backend pool @@ -345,23 +348,6 @@ func TestBuildOutboundServiceResources_Basic(t *testing.T) { assert.Equal(t, Outbound, servicesDTO.Services[0].ServiceType) } -func TestNewIgnoreCaseSetFromSlice_Empty(t *testing.T) { - set := newIgnoreCaseSetFromSlice([]string{}) - assert.NotNil(t, set) - assert.Equal(t, 0, set.Len()) -} - -func TestNewIgnoreCaseSetFromSlice_WithItems(t *testing.T) { - items := []string{"service1", "service2", "SERVICE3"} - set := newIgnoreCaseSetFromSlice(items) - - assert.Equal(t, 3, set.Len()) - assert.True(t, set.Has("service1")) - assert.True(t, set.Has("service2")) - assert.True(t, set.Has("service3")) // Case insensitive - assert.True(t, set.Has("SERVICE3")) -} - func TestBuildInboundServiceResources_BackendPoolNaming(t *testing.T) { config := &InboundConfig{ FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, @@ -495,13 +481,27 @@ func TestBuildInboundServiceResources_TCPHasResetEnabled(t *testing.T) { } } -func TestBuildInboundServiceResources_BackendPortOutOfRangeErrors(t *testing.T) { +func TestBuildInboundServiceResources_BackendPortMaxIsValid(t *testing.T) { config := &InboundConfig{ FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, BackendPorts: []PortMapping{{Port: 65535, Protocol: "TCP"}}, } dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + if assert.Len(t, lb.Properties.LoadBalancingRules, 1) { + assert.Equal(t, int32(65535), *lb.Properties.LoadBalancingRules[0].Properties.BackendPort) + } +} + +func TestBuildInboundServiceResources_BackendPortOutOfRangeErrors(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 65536, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) assert.Error(t, err) assert.Contains(t, err.Error(), "backend port") From 5fd7ab210c63371570f444662e3e930de1f1a3d0 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Wed, 24 Jun 2026 18:43:11 +0000 Subject: [PATCH 17/18] Add unit tests for inbound and outbound resource name generation and service gateway removal DTO --- .../azure_operations_clients_test.go | 378 ++++++++++++++++++ .../difftracker/resource_helpers_test.go | 37 ++ 2 files changed, 415 insertions(+) create mode 100644 pkg/provider/difftracker/azure_operations_clients_test.go diff --git a/pkg/provider/difftracker/azure_operations_clients_test.go b/pkg/provider/difftracker/azure_operations_clients_test.go new file mode 100644 index 0000000000..bbdf729221 --- /dev/null +++ b/pkg/provider/difftracker/azure_operations_clients_test.go @@ -0,0 +1,378 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/utils/ptr" + + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/natgatewayclient/mock_natgatewayclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/servicegatewayclient/mock_servicegatewayclient" +) + +func testConfig() Config { + return Config{ + SubscriptionID: "sub", + ResourceGroup: "rg", + Location: "eastus", + VNetName: "vnet", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/serviceGateways/sgw", + } +} + +func notFoundError() error { + return &azcore.ResponseError{StatusCode: http.StatusNotFound} +} + +func TestCreateOrUpdatePIP_Mock(t *testing.T) { + pip := &armnetwork.PublicIPAddress{Name: ptr.To("svc-pip")} + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc-pip", gomock.Any()).Return(pip, nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.createOrUpdatePIP(context.Background(), "rg", pip)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc-pip", gomock.Any()).Return(nil, errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.createOrUpdatePIP(context.Background(), "rg", pip)) + }) +} + +func TestDeletePublicIP_Mock(t *testing.T) { + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().Delete(gomock.Any(), "rg", "svc-pip").Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deletePublicIP(context.Background(), "rg", "svc-pip")) + }) + + t.Run("not-found is success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().Delete(gomock.Any(), "rg", "svc-pip").Return(notFoundError()) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deletePublicIP(context.Background(), "rg", "svc-pip")) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().Delete(gomock.Any(), "rg", "svc-pip").Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.deletePublicIP(context.Background(), "rg", "svc-pip")) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.deletePublicIP(context.Background(), "rg", "")) + }) +} + +func TestCreateOrUpdateLB_Mock(t *testing.T) { + lb := armnetwork.LoadBalancer{Name: ptr.To("svc")} + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.createOrUpdateLB(context.Background(), lb)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.createOrUpdateLB(context.Background(), lb)) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.createOrUpdateLB(context.Background(), armnetwork.LoadBalancer{})) + }) +} + +func TestDeleteLB_Mock(t *testing.T) { + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().Delete(gomock.Any(), "rg", "uid").Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deleteLB(context.Background(), "uid")) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().Delete(gomock.Any(), "rg", "uid").Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.deleteLB(context.Background(), "uid")) + }) +} + +func TestCreateOrUpdateNatGateway_Mock(t *testing.T) { + natGW := armnetwork.NatGateway{Name: ptr.To("svc")} + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.createOrUpdateNatGateway(context.Background(), "rg", natGW)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.createOrUpdateNatGateway(context.Background(), "rg", natGW)) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.createOrUpdateNatGateway(context.Background(), "rg", armnetwork.NatGateway{})) + }) +} + +func TestDeleteNatGateway_Mock(t *testing.T) { + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().Delete(gomock.Any(), "rg", "svc").Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deleteNatGateway(context.Background(), "rg", "svc")) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().Delete(gomock.Any(), "rg", "svc").Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.deleteNatGateway(context.Background(), "rg", "svc")) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.deleteNatGateway(context.Background(), "rg", "")) + }) +} + +func TestUpdateNRPSGWServices_Mock(t *testing.T) { + servicesDTO := ServicesDataDTO{ + Action: PartialUpdate, + Services: []ServiceDTO{ + {Service: "svc", ServiceType: Inbound, IsDelete: true}, + }, + } + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateServices(gomock.Any(), "rg", "sgw", gomock.Any()).Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.updateNRPSGWServices(context.Background(), "sgw", servicesDTO)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateServices(gomock.Any(), "rg", "sgw", gomock.Any()).Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.updateNRPSGWServices(context.Background(), "sgw", servicesDTO)) + }) + + t.Run("no-op when empty and not full update", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + // No client call expected. + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.updateNRPSGWServices(context.Background(), "sgw", ServicesDataDTO{Action: PartialUpdate})) + }) +} + +func TestUpdateNRPSGWAddressLocations_Mock(t *testing.T) { + locationsDTO := LocationsDataDTO{ + Action: PartialUpdate, + Locations: []LocationDTO{ + {Location: "node1", AddressUpdateAction: PartialUpdate, Addresses: []AddressDTO{}}, + }, + } + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateAddressLocations(gomock.Any(), "rg", "sgw", gomock.Any()).Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.updateNRPSGWAddressLocations(context.Background(), "sgw", locationsDTO)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateAddressLocations(gomock.Any(), "rg", "sgw", gomock.Any()).Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.updateNRPSGWAddressLocations(context.Background(), "sgw", locationsDTO)) + }) +} + +func TestDisassociateNatGatewayFromServiceGateway_Mock(t *testing.T) { + // Simplest reconcile path: no matching SGW service to clear, and the NAT + // gateway is already gone (404) -> method returns nil. + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + + mockSGW.EXPECT().GetServices(gomock.Any(), "rg", "sgw").Return([]*armnetwork.ServiceGatewayService{}, nil) + mockNAT.EXPECT().Get(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, notFoundError()) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.disassociateNatGatewayFromServiceGateway(context.Background(), "sgw", "svc")) +} + +func TestConvertServicesUpdateActionToARM(t *testing.T) { + assert.Equal(t, armnetwork.ServiceUpdateActionPartialUpdate, *convertServicesUpdateActionToARM(PartialUpdate)) + assert.Equal(t, armnetwork.ServiceUpdateActionFullUpdate, *convertServicesUpdateActionToARM(FullUpdate)) + // Unknown defaults to PartialUpdate. + assert.Equal(t, armnetwork.ServiceUpdateActionPartialUpdate, *convertServicesUpdateActionToARM(UnknownUpdateAction)) +} + +func TestConvertLocationsUpdateActionToARM(t *testing.T) { + assert.Equal(t, armnetwork.UpdateActionPartialUpdate, *convertLocationsUpdateActionToARM(PartialUpdate)) + assert.Equal(t, armnetwork.UpdateActionFullUpdate, *convertLocationsUpdateActionToARM(FullUpdate)) + // Unknown defaults to PartialUpdate. + assert.Equal(t, armnetwork.UpdateActionPartialUpdate, *convertLocationsUpdateActionToARM(UnknownUpdateAction)) +} + +func TestConvertLocationDTOsToAddressLocations(t *testing.T) { + t.Run("drained node keeps non-nil empty Addresses", func(t *testing.T) { + locs := convertLocationDTOsToAddressLocations([]LocationDTO{ + {Location: "node1", AddressUpdateAction: FullUpdate, Addresses: []AddressDTO{}}, + }) + assert.Len(t, locs, 1) + assert.NotNil(t, locs[0].Addresses) + assert.Empty(t, locs[0].Addresses) + assert.Equal(t, armnetwork.AddressUpdateActionFullUpdate, *locs[0].AddressUpdateAction) + }) + + t.Run("address with empty ServiceNames keeps non-nil empty Services", func(t *testing.T) { + locs := convertLocationDTOsToAddressLocations([]LocationDTO{ + {Location: "node1", AddressUpdateAction: PartialUpdate, Addresses: []AddressDTO{ + {Address: "10.0.0.5", ServiceNames: nil}, + }}, + }) + assert.Len(t, locs, 1) + assert.Equal(t, armnetwork.AddressUpdateActionPartialUpdate, *locs[0].AddressUpdateAction) + assert.Len(t, locs[0].Addresses, 1) + assert.NotNil(t, locs[0].Addresses[0].Services) + assert.Empty(t, locs[0].Addresses[0].Services) + assert.Equal(t, "10.0.0.5", *locs[0].Addresses[0].Address) + }) +} diff --git a/pkg/provider/difftracker/resource_helpers_test.go b/pkg/provider/difftracker/resource_helpers_test.go index 2c5ef31166..2170e40a96 100644 --- a/pkg/provider/difftracker/resource_helpers_test.go +++ b/pkg/provider/difftracker/resource_helpers_test.go @@ -534,3 +534,40 @@ func TestBuildInboundServiceResources_IdleTimeoutOutOfRangeErrors(t *testing.T) assert.Error(t, err) assert.Contains(t, err.Error(), "idle timeout") } + +func TestBuildInboundResourceNames(t *testing.T) { + lbName, pipName, backendPoolName := buildInboundResourceNames("uid") + assert.Equal(t, "uid", lbName) + assert.Equal(t, "uid-pip", pipName) + assert.Equal(t, "uid", backendPoolName) +} + +func TestBuildOutboundResourceNames(t *testing.T) { + natGatewayName, pipName := buildOutboundResourceNames("uid") + assert.Equal(t, "uid", natGatewayName) + assert.Equal(t, "uid-pip", pipName) +} + +func TestBuildServiceGatewayRemovalDTO(t *testing.T) { + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg"} + + t.Run("inbound removal", func(t *testing.T) { + dto := buildServiceGatewayRemovalDTO("uid", true, dtConfig) + assert.Equal(t, PartialUpdate, dto.Action) + if assert.Len(t, dto.Services, 1) { + assert.Equal(t, "uid", dto.Services[0].Service) + assert.Equal(t, Inbound, dto.Services[0].ServiceType) + assert.True(t, dto.Services[0].IsDelete) + } + }) + + t.Run("outbound removal", func(t *testing.T) { + dto := buildServiceGatewayRemovalDTO("uid", false, dtConfig) + assert.Equal(t, PartialUpdate, dto.Action) + if assert.Len(t, dto.Services, 1) { + assert.Equal(t, "uid", dto.Services[0].Service) + assert.Equal(t, Outbound, dto.Services[0].ServiceType) + assert.True(t, dto.Services[0].IsDelete) + } + }) +} From f5c527e2ced077049b1db192510d3ca9baba0745 Mon Sep 17 00:00:00 2001 From: George Edward Nechitoaia <58257818+georgeedward2000@users.noreply.github.com> Date: Wed, 24 Jun 2026 20:28:58 +0000 Subject: [PATCH 18/18] Ensure unknown AddressUpdateAction defaults to PartialUpdate and add corresponding test --- pkg/provider/difftracker/azure_operations.go | 7 ++++++- .../difftracker/azure_operations_clients_test.go | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pkg/provider/difftracker/azure_operations.go b/pkg/provider/difftracker/azure_operations.go index d12144c085..1d0baeef79 100644 --- a/pkg/provider/difftracker/azure_operations.go +++ b/pkg/provider/difftracker/azure_operations.go @@ -463,12 +463,17 @@ func convertLocationDTOsToAddressLocations(locations []LocationDTO) []*armnetwor AddressLocation: ptr.To(loc.Location), } - // Set address update action + // Set address update action. Mirror the service/location action converters by + // defaulting an unknown (e.g. unset/zero) value to PartialUpdate instead of + // leaving it nil, so NRP always receives an explicit action. switch loc.AddressUpdateAction { case PartialUpdate: armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionPartialUpdate) case FullUpdate: armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionFullUpdate) + default: + klog.Warningf("convertLocationDTOsToAddressLocations: unknown AddressUpdateAction %v for location %q, defaulting to PartialUpdate", loc.AddressUpdateAction, loc.Location) + armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionPartialUpdate) } // Convert addresses - always initialize the slice to avoid null in JSON diff --git a/pkg/provider/difftracker/azure_operations_clients_test.go b/pkg/provider/difftracker/azure_operations_clients_test.go index bbdf729221..1089c79cbe 100644 --- a/pkg/provider/difftracker/azure_operations_clients_test.go +++ b/pkg/provider/difftracker/azure_operations_clients_test.go @@ -375,4 +375,16 @@ func TestConvertLocationDTOsToAddressLocations(t *testing.T) { assert.Empty(t, locs[0].Addresses[0].Services) assert.Equal(t, "10.0.0.5", *locs[0].Addresses[0].Address) }) + + t.Run("unknown AddressUpdateAction defaults to PartialUpdate", func(t *testing.T) { + // A LocationDTO whose AddressUpdateAction is left unset (zero value + // UnknownUpdateAction) must still produce an explicit action, matching the + // service/location action converters, rather than a nil AddressUpdateAction. + locs := convertLocationDTOsToAddressLocations([]LocationDTO{ + {Location: "node1", Addresses: []AddressDTO{}}, + }) + assert.Len(t, locs, 1) + assert.NotNil(t, locs[0].AddressUpdateAction) + assert.Equal(t, armnetwork.AddressUpdateActionPartialUpdate, *locs[0].AddressUpdateAction) + }) }