diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index e2305842be..523639fc89 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -243,6 +243,9 @@ const ( LoadBalancerSKUStandard = "standard" // LoadBalancerSKUService is the load balancer service SKU LoadBalancerSKUService = "service" + // LoadBalancerSKUNameService is the case-sensitive ARM SKU name for the + // ServiceGateway inbound load balancer; it must be sent as "Service". + LoadBalancerSKUNameService = "Service" // ServiceAnnotationLoadBalancerInternal is the annotation used on the service ServiceAnnotationLoadBalancerInternal = "service.beta.kubernetes.io/azure-load-balancer-internal" @@ -367,6 +370,8 @@ const ( FrontendIPConfigIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/frontendIPConfigurations/%s" // BackendPoolIDTemplate is the template of the backend pool BackendPoolIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/backendAddressPools/%s" + // NatGatewayIDTemplate is the template of the nat gateway + NatGatewayIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/natGateways/%s" // LoadBalancerProbeIDTemplate is the template of the load balancer probe LoadBalancerProbeIDTemplate = "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/probes/%s" diff --git a/pkg/provider/difftracker/azure_operations.go b/pkg/provider/difftracker/azure_operations.go new file mode 100644 index 0000000000..1d0baeef79 --- /dev/null +++ b/pkg/provider/difftracker/azure_operations.go @@ -0,0 +1,498 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package difftracker provides state tracking and synchronization between Kubernetes +// resources and Azure Network Resource Provider (NRP) resources. +// +// Architecture: +// - DiffTracker: Maintains the desired state (K8s) and actual state (NRP) +// - ServiceUpdater: Handles service creation/deletion (LoadBalancers, NAT Gateways) +// - LocationsUpdater: Handles address/location synchronization +// +// Azure Operations: +// All Azure SDK operations in azure_operations.go attempt their action once and return +// errors immediately. Retry logic is the responsibility of the callers (ServiceUpdater, +// LocationsUpdater) which implement appropriate retry strategies based on their use cases. +// +// Error Handling: +// Transient errors (throttling, timeouts) should be retried by callers. +// Permanent errors (authentication, resource not found) should not be retried. +package difftracker + +import ( + "context" + "fmt" + "net" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + v1 "k8s.io/api/core/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/util/retry" + servicehelper "k8s.io/cloud-provider/service/helpers" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" + + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" +) + +// createOrUpdatePIP creates or updates a public IP address +func (dt *DiffTracker) createOrUpdatePIP(ctx context.Context, pipResourceGroup string, pip *armnetwork.PublicIPAddress) error { + _, err := dt.createOrUpdatePIPWithResponse(ctx, pipResourceGroup, pip) + return err +} + +// createOrUpdatePIPWithResponse creates or updates a public IP address and returns the response +// containing the allocated IP address. Use this when you need the IP address after PIP creation. +func (dt *DiffTracker) createOrUpdatePIPWithResponse(ctx context.Context, pipResourceGroup string, pip *armnetwork.PublicIPAddress) (*armnetwork.PublicIPAddress, error) { + if pip == nil { + return nil, fmt.Errorf("createOrUpdatePIPWithResponse: pip is nil") + } + pipName := ptr.Deref(pip.Name, "") + if pipName == "" { + return nil, fmt.Errorf("createOrUpdatePIPWithResponse: pip name is empty") + } + klog.V(4).Infof("createOrUpdatePIPWithResponse(%s): start", pipName) + + response, err := dt.networkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(ctx, pipResourceGroup, pipName, *pip) + klog.V(4).Infof("PublicIPAddressClient.CreateOrUpdate(%s, %s): end", pipResourceGroup, pipName) + if err != nil { + klog.Warningf("PublicIPAddressClient.CreateOrUpdate(%s, %s) failed: %s", pipResourceGroup, pipName, err.Error()) + return nil, err + } + + return response, nil +} + +// deletePublicIP deletes a public IP address +func (dt *DiffTracker) deletePublicIP(ctx context.Context, pipResourceGroup string, pipName string) error { + if pipName == "" { + return fmt.Errorf("deletePublicIP: pipName is empty") + } + klog.V(4).Infof("deletePublicIP(%s): start", pipName) + + err := dt.networkClientFactory.GetPublicIPAddressClient().Delete(ctx, pipResourceGroup, pipName) + klog.V(4).Infof("deletePublicIP(%s): end, error: %v", pipName, err) + + if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { + klog.Warningf("deletePublicIP(%s) failed: %s", pipName, err.Error()) + return err + } + + return nil +} + +// createOrUpdateLB creates or updates a load balancer +func (dt *DiffTracker) createOrUpdateLB(ctx context.Context, lb armnetwork.LoadBalancer) error { + lbName := ptr.Deref(lb.Name, "") + if lbName == "" { + return fmt.Errorf("createOrUpdateLB: load balancer name is empty") + } + klog.V(4).Infof("createOrUpdateLB(%s): start", lbName) + + _, err := dt.networkClientFactory.GetLoadBalancerClient().CreateOrUpdate(ctx, dt.config.ResourceGroup, lbName, lb) + klog.V(4).Infof("LoadBalancerClient.CreateOrUpdate(%s): end", lbName) + if err != nil { + klog.Warningf("LoadBalancerClient.CreateOrUpdate(%s) failed: %v", lbName, err) + return err + } + + return nil +} + +// deleteLB deletes a load balancer by service UID +func (dt *DiffTracker) deleteLB(ctx context.Context, uid string) error { + uid = strings.ToLower(uid) + klog.V(4).Infof("deleteLB: deleting LoadBalancer %s", uid) + + err := dt.networkClientFactory.GetLoadBalancerClient().Delete(ctx, dt.config.ResourceGroup, uid) + if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { + return fmt.Errorf("deleteLB: failed to delete LoadBalancer (uid=%s): %w", uid, err) + } + + return nil +} + +// createOrUpdateNatGateway creates or updates a NAT gateway +func (dt *DiffTracker) createOrUpdateNatGateway(ctx context.Context, natGatewayResourceGroup string, natGateway armnetwork.NatGateway) error { + natGatewayName := ptr.Deref(natGateway.Name, "") + if natGatewayName == "" { + return fmt.Errorf("createOrUpdateNatGateway: NAT gateway name is empty") + } + klog.V(4).Infof("createOrUpdateNatGateway(%s): start", natGatewayName) + + _, err := dt.networkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, natGatewayResourceGroup, natGatewayName, natGateway) + if err != nil { + klog.Warningf("NatGatewayClient.CreateOrUpdate(%s) failed: %v", natGatewayName, err) + return err + } + + klog.V(4).Infof("NatGatewayClient.CreateOrUpdate(%s): success", natGatewayName) + klog.V(4).Infof("createOrUpdateNatGateway(%s): end, error: nil", natGatewayName) + return nil +} + +// deleteNatGateway deletes a NAT gateway +func (dt *DiffTracker) deleteNatGateway(ctx context.Context, natGatewayResourceGroup string, natGatewayName string) error { + if natGatewayName == "" { + return fmt.Errorf("deleteNatGateway: NAT gateway name is empty") + } + klog.V(4).Infof("deleteNatGateway(%s) in resource group %s: start", natGatewayName, natGatewayResourceGroup) + + err := dt.networkClientFactory.GetNatGatewayClient().Delete(ctx, natGatewayResourceGroup, natGatewayName) + if _, err := errutils.CheckResourceExistsFromAzcoreError(err); err != nil { + klog.Errorf("NatGatewayClient.Delete(%s) in resource group %s failed: %v", natGatewayName, natGatewayResourceGroup, err) + return err + } + + klog.V(4).Infof("NatGatewayClient.Delete(%s) in resource group %s: success", natGatewayName, natGatewayResourceGroup) + klog.V(4).Infof("deleteNatGateway(%s) in resource group %s: end, error: nil", natGatewayName, natGatewayResourceGroup) + return nil +} + +// disassociateNatGatewayFromServiceGateway removes the NAT gateway association from the Service Gateway +// This should be called before deleting the NAT gateway to properly clean up the references +func (dt *DiffTracker) disassociateNatGatewayFromServiceGateway(ctx context.Context, serviceGatewayName string, natGatewayName string) error { + klog.V(2).Infof("disassociateNatGatewayFromServiceGateway: Disassociating NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) + + // Step 1: clear the ServiceGateway-side reference if it still has one. + services, err := dt.networkClientFactory.GetServiceGatewayClient().GetServices(ctx, dt.config.ResourceGroup, serviceGatewayName) + if err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to get Service Gateway %s: %v", serviceGatewayName, err) + return fmt.Errorf("failed to get Service Gateway services: %w", err) + } + + var serviceToBeUpdated *armnetwork.ServiceGatewayService + for _, service := range services { + if service.Name != nil && *service.Name == natGatewayName { + serviceToBeUpdated = service + break + } + } + + if serviceToBeUpdated != nil && serviceToBeUpdated.Properties != nil && serviceToBeUpdated.Properties.PublicNatGatewayID != nil { + serviceToBeUpdated.Properties.PublicNatGatewayID = nil + updateServicesRequest := armnetwork.ServiceGatewayUpdateServicesRequest{ + Action: ptr.To(armnetwork.ServiceUpdateActionPartialUpdate), + ServiceRequests: []*armnetwork.ServiceGatewayServiceRequest{ + { + IsDelete: ptr.To(false), + Service: serviceToBeUpdated, + }, + }, + } + if err := dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, updateServicesRequest); err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update Service Gateway %s to disassociate NAT Gateway %s: %v", serviceGatewayName, natGatewayName, err) + return fmt.Errorf("failed to update Service Gateway: %w", err) + } + } + + // Step 2: clear the NAT-gateway-side reference if it still points back at the + // Service Gateway. This runs independently of Step 1 so a retry after a + // partial failure still reconciles the NAT gateway. + natGateway, err := dt.networkClientFactory.GetNatGatewayClient().Get(ctx, dt.config.ResourceGroup, natGatewayName, nil) + if err != nil { + if exists, cerr := errutils.CheckResourceExistsFromAzcoreError(err); cerr == nil && !exists { + return nil + } + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to get NAT Gateway %s: %v", natGatewayName, err) + return fmt.Errorf("failed to get NAT Gateway: %w", err) + } + + if natGateway.Properties != nil && natGateway.Properties.ServiceGateway != nil { + natGateway.Properties.ServiceGateway = nil + if _, err := dt.networkClientFactory.GetNatGatewayClient().CreateOrUpdate(ctx, dt.config.ResourceGroup, natGatewayName, *natGateway); err != nil { + klog.Errorf("disassociateNatGatewayFromServiceGateway: Failed to update NAT Gateway %s to remove Service Gateway reference: %v", natGatewayName, err) + return fmt.Errorf("failed to update NAT Gateway: %w", err) + } + } + + klog.V(2).Infof("disassociateNatGatewayFromServiceGateway: Successfully disassociated NAT Gateway %s from Service Gateway %s in resource group %s", natGatewayName, serviceGatewayName, dt.config.ResourceGroup) + return nil +} + +// updateNRPSGWServices updates services in the Service Gateway +func (dt *DiffTracker) updateNRPSGWServices(ctx context.Context, serviceGatewayName string, updateServicesRequestDTO ServicesDataDTO) error { + if len(updateServicesRequestDTO.Services) == 0 && updateServicesRequestDTO.Action != FullUpdate { + klog.V(4).Infof("updateNRPSGWServices(%s): no services to update", serviceGatewayName) + return nil + } + + klog.V(4).Infof("updateNRPSGWServices(%s): start", serviceGatewayName) + + // Convert DTO to ARM SDK request + serviceRequests, err := convertServiceDTOsToServiceRequests(updateServicesRequestDTO.Services, dt.config) + if err != nil { + return fmt.Errorf("updateNRPSGWServices(%s): %w", serviceGatewayName, err) + } + req := armnetwork.ServiceGatewayUpdateServicesRequest{ + Action: convertServicesUpdateActionToARM(updateServicesRequestDTO.Action), + ServiceRequests: serviceRequests, + } + + err = dt.networkClientFactory.GetServiceGatewayClient().UpdateServices(ctx, dt.config.ResourceGroup, serviceGatewayName, req) + if err != nil { + klog.Warningf("ServiceGatewayClient.UpdateServices(%s) failed: %v", serviceGatewayName, err) + return err + } + + klog.V(4).Infof("ServiceGatewayClient.UpdateServices(%s): success", serviceGatewayName) + klog.V(4).Infof("updateNRPSGWServices(%s): end, error: nil", serviceGatewayName) + return nil +} + +// updateNRPSGWAddressLocations updates address locations in the Service Gateway +func (dt *DiffTracker) updateNRPSGWAddressLocations(ctx context.Context, serviceGatewayName string, locationsDTO LocationsDataDTO) error { + klog.V(4).Infof("updateNRPSGWAddressLocations(%s): start", serviceGatewayName) + + // Convert DTO to ARM SDK request + req := armnetwork.ServiceGatewayUpdateAddressLocationsRequest{ + Action: convertLocationsUpdateActionToARM(locationsDTO.Action), + AddressLocations: convertLocationDTOsToAddressLocations(locationsDTO.Locations), + } + + err := dt.networkClientFactory.GetServiceGatewayClient().UpdateAddressLocations(ctx, dt.config.ResourceGroup, serviceGatewayName, req) + if err != nil { + klog.Warningf("ServiceGatewayClient.UpdateAddressLocations(%s) failed: %v", serviceGatewayName, err) + return err + } + + klog.V(4).Infof("ServiceGatewayClient.UpdateAddressLocations(%s): success", serviceGatewayName) + klog.V(4).Infof("updateNRPSGWAddressLocations(%s): end, error: nil", serviceGatewayName) + return nil +} + +// getServiceByUID returns the Service whose UID matches the given uid +func (dt *DiffTracker) getServiceByUID(ctx context.Context, uid string) (*v1.Service, error) { + // This lists all Services and scans for a UID match, which is expensive. It is + // acceptable for now since it runs once per service operation (plus conflict retries), + // not in a hot path. + // TODO: maintain a UID -> namespace/name map from the Service informer events so this + // becomes an O(1) lookup plus a direct Services(ns).Get(name) (or lister-cache read) + // instead of a NamespaceAll list. + svcList, err := dt.kubeClient.CoreV1().Services(v1.NamespaceAll).List(ctx, metav1.ListOptions{}) + if err != nil { + return nil, fmt.Errorf("getServiceByUID: list failed: %w", err) + } + for _, svc := range svcList.Items { + if string(svc.UID) == uid { + return svc.DeepCopy(), nil + } + } + return nil, fmt.Errorf("service with uid %s not found", uid) +} + +// updateServiceLoadBalancerStatus updates the K8s Service status with the LoadBalancer IP address. +// EnsureLoadBalancer returns an empty LoadBalancer status immediately while the difftracker +// engine provisions the PIP, LB and ServiceGateway registration asynchronously in the +// background. This function backfills Service.Status.LoadBalancer.Ingress once the PIP is +// created, since it would otherwise stay empty. +func (dt *DiffTracker) updateServiceLoadBalancerStatus(ctx context.Context, serviceUID string, ip string) error { + if ip == "" { + return fmt.Errorf("updateServiceLoadBalancerStatus: ip is empty") + } + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: invalid ip %q", ip) + } + newIsIPv4 := parsedIP.To4() != nil + + return retry.RetryOnConflict(retry.DefaultRetry, func() error { + svc, err := dt.getServiceByUID(ctx, serviceUID) + if err != nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: failed to get service: %w", err) + } + + desired := make([]v1.LoadBalancerIngress, 0, len(svc.Status.LoadBalancer.Ingress)+1) + newPresent := false + // Rebuild the ingress list, keeping it dual-stack safe: preserve non-IP (hostname-only) + // entries and entries of the other IP family untouched, while replacing any stale + // same-family IP with the new one. If the new IP is already present we keep it and + // avoid appending a duplicate below. + for _, ingress := range svc.Status.LoadBalancer.Ingress { + if ingress.IP == "" { + desired = append(desired, ingress) + continue + } + ingressIP := net.ParseIP(ingress.IP) + if ingressIP != nil && (ingressIP.To4() != nil) == newIsIPv4 { + if ingress.IP == ip { + newPresent = true + desired = append(desired, ingress) + } + continue + } + desired = append(desired, ingress) + } + if !newPresent { + desired = append(desired, v1.LoadBalancerIngress{IP: ip}) + } + + if apiequality.Semantic.DeepEqual(svc.Status.LoadBalancer.Ingress, desired) { + klog.V(3).Infof("updateServiceLoadBalancerStatus: service %s/%s already has IP %s", svc.Namespace, svc.Name, ip) + return nil + } + + updated := svc.DeepCopy() + updated.Status.LoadBalancer.Ingress = desired + if _, err := servicehelper.PatchService(dt.kubeClient.CoreV1(), svc, updated); err != nil { + return fmt.Errorf("updateServiceLoadBalancerStatus: failed to patch service: %w", err) + } + + klog.V(2).Infof("updateServiceLoadBalancerStatus: updated service %s/%s with LoadBalancer IP %s", svc.Namespace, svc.Name, ip) + return nil + }) +} + +// Helper functions to convert DTOs to ARM SDK types + +func convertServicesUpdateActionToARM(action UpdateAction) *armnetwork.ServiceUpdateAction { + switch action { + case PartialUpdate: + return ptr.To(armnetwork.ServiceUpdateActionPartialUpdate) + case FullUpdate: + return ptr.To(armnetwork.ServiceUpdateActionFullUpdate) + default: + klog.Warningf("convertServicesUpdateActionToARM: unknown UpdateAction %q, defaulting to PartialUpdate", action) + return ptr.To(armnetwork.ServiceUpdateActionPartialUpdate) + } +} + +func convertLocationsUpdateActionToARM(action UpdateAction) *armnetwork.UpdateAction { + switch action { + case PartialUpdate: + return ptr.To(armnetwork.UpdateActionPartialUpdate) + case FullUpdate: + return ptr.To(armnetwork.UpdateActionFullUpdate) + default: + klog.Warningf("convertLocationsUpdateActionToARM: unknown UpdateAction %q, defaulting to PartialUpdate", action) + return ptr.To(armnetwork.UpdateActionPartialUpdate) + } +} + +// extractResourceChildName extracts a child resource name from an Azure resource ID. +// Returns "" if the requested segment is not present. +func extractResourceChildName(id, segment string) string { + if id == "" { + return "" + } + parts := strings.Split(id, "/") + for i := 0; i < len(parts)-1; i++ { + if strings.EqualFold(parts[i], segment) && parts[i+1] != "" { + return parts[i+1] + } + } + return "" +} + +func convertServiceDTOsToServiceRequests(services []ServiceDTO, config Config) ([]*armnetwork.ServiceGatewayServiceRequest, error) { + serviceRequests := make([]*armnetwork.ServiceGatewayServiceRequest, 0, len(services)) + + // Construct VNet resource ID once for all backend pools + vnetID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/virtualNetworks/%s", + config.SubscriptionID, config.ResourceGroup, config.VNetName) + + for _, svc := range services { + // Build backend pools with full details + loadBalancerBackendPools := make([]*armnetwork.BackendAddressPool, len(svc.LoadBalancerBackendPools)) + for i := range svc.LoadBalancerBackendPools { + backendPoolResourceID := svc.LoadBalancerBackendPools[i].Id + backendPoolName := extractResourceChildName(backendPoolResourceID, "backendAddressPools") + loadBalancerBackendPools[i] = &armnetwork.BackendAddressPool{ + ID: &backendPoolResourceID, + Name: &backendPoolName, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ + ID: &vnetID, + }, + }, + } + } + + // Set service type and NAT gateway based on service type + var serviceType armnetwork.ServiceType + var publicNatGatewayID *string + switch svc.ServiceType { + case Inbound: + serviceType = armnetwork.ServiceTypeInbound + publicNatGatewayID = nil + case Outbound: + serviceType = armnetwork.ServiceTypeOutbound + if !svc.IsDelete && svc.PublicNatGateway.Id != "" { + natID := svc.PublicNatGateway.Id + publicNatGatewayID = &natID + } + default: + return nil, fmt.Errorf("convertServiceDTOsToServiceRequests: unknown ServiceType %q for service %q", svc.ServiceType, svc.Service) + } + + // Build and append the service request + serviceRequests = append(serviceRequests, &armnetwork.ServiceGatewayServiceRequest{ + IsDelete: &svc.IsDelete, + Service: &armnetwork.ServiceGatewayService{ + Name: &svc.Service, + Properties: &armnetwork.ServiceGatewayServicePropertiesFormat{ + LoadBalancerBackendPools: loadBalancerBackendPools, + PublicNatGatewayID: publicNatGatewayID, + ServiceType: &serviceType, + }, + }, + }) + } + return serviceRequests, nil +} + +func convertLocationDTOsToAddressLocations(locations []LocationDTO) []*armnetwork.ServiceGatewayAddressLocation { + armLocations := make([]*armnetwork.ServiceGatewayAddressLocation, 0, len(locations)) + for _, loc := range locations { + armLoc := &armnetwork.ServiceGatewayAddressLocation{ + AddressLocation: ptr.To(loc.Location), + } + + // Set address update action. Mirror the service/location action converters by + // defaulting an unknown (e.g. unset/zero) value to PartialUpdate instead of + // leaving it nil, so NRP always receives an explicit action. + switch loc.AddressUpdateAction { + case PartialUpdate: + armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionPartialUpdate) + case FullUpdate: + armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionFullUpdate) + default: + klog.Warningf("convertLocationDTOsToAddressLocations: unknown AddressUpdateAction %v for location %q, defaulting to PartialUpdate", loc.AddressUpdateAction, loc.Location) + armLoc.AddressUpdateAction = ptr.To(armnetwork.AddressUpdateActionPartialUpdate) + } + + // Convert addresses - always initialize the slice to avoid null in JSON + armLoc.Addresses = make([]*armnetwork.ServiceGatewayAddress, 0, len(loc.Addresses)) + for _, addr := range loc.Addresses { + armAddr := &armnetwork.ServiceGatewayAddress{ + Address: ptr.To(addr.Address), + } + + // Convert service names - always initialize the slice to avoid null in JSON + armAddr.Services = make([]*string, 0, addr.ServiceNames.Len()) + for _, svcName := range addr.ServiceNames.UnsortedList() { + armAddr.Services = append(armAddr.Services, ptr.To(svcName)) + } + + armLoc.Addresses = append(armLoc.Addresses, armAddr) + } + + armLocations = append(armLocations, armLoc) + } + return armLocations +} diff --git a/pkg/provider/difftracker/azure_operations_clients_test.go b/pkg/provider/difftracker/azure_operations_clients_test.go new file mode 100644 index 0000000000..1089c79cbe --- /dev/null +++ b/pkg/provider/difftracker/azure_operations_clients_test.go @@ -0,0 +1,390 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/utils/ptr" + + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/natgatewayclient/mock_natgatewayclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/servicegatewayclient/mock_servicegatewayclient" +) + +func testConfig() Config { + return Config{ + SubscriptionID: "sub", + ResourceGroup: "rg", + Location: "eastus", + VNetName: "vnet", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/serviceGateways/sgw", + } +} + +func notFoundError() error { + return &azcore.ResponseError{StatusCode: http.StatusNotFound} +} + +func TestCreateOrUpdatePIP_Mock(t *testing.T) { + pip := &armnetwork.PublicIPAddress{Name: ptr.To("svc-pip")} + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc-pip", gomock.Any()).Return(pip, nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.createOrUpdatePIP(context.Background(), "rg", pip)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc-pip", gomock.Any()).Return(nil, errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.createOrUpdatePIP(context.Background(), "rg", pip)) + }) +} + +func TestDeletePublicIP_Mock(t *testing.T) { + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().Delete(gomock.Any(), "rg", "svc-pip").Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deletePublicIP(context.Background(), "rg", "svc-pip")) + }) + + t.Run("not-found is success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().Delete(gomock.Any(), "rg", "svc-pip").Return(notFoundError()) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deletePublicIP(context.Background(), "rg", "svc-pip")) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockPIP := mock_publicipaddressclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetPublicIPAddressClient().Return(mockPIP).AnyTimes() + mockPIP.EXPECT().Delete(gomock.Any(), "rg", "svc-pip").Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.deletePublicIP(context.Background(), "rg", "svc-pip")) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.deletePublicIP(context.Background(), "rg", "")) + }) +} + +func TestCreateOrUpdateLB_Mock(t *testing.T) { + lb := armnetwork.LoadBalancer{Name: ptr.To("svc")} + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.createOrUpdateLB(context.Background(), lb)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.createOrUpdateLB(context.Background(), lb)) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.createOrUpdateLB(context.Background(), armnetwork.LoadBalancer{})) + }) +} + +func TestDeleteLB_Mock(t *testing.T) { + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().Delete(gomock.Any(), "rg", "uid").Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deleteLB(context.Background(), "uid")) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockLB := mock_loadbalancerclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetLoadBalancerClient().Return(mockLB).AnyTimes() + mockLB.EXPECT().Delete(gomock.Any(), "rg", "uid").Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.deleteLB(context.Background(), "uid")) + }) +} + +func TestCreateOrUpdateNatGateway_Mock(t *testing.T) { + natGW := armnetwork.NatGateway{Name: ptr.To("svc")} + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.createOrUpdateNatGateway(context.Background(), "rg", natGW)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.createOrUpdateNatGateway(context.Background(), "rg", natGW)) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.createOrUpdateNatGateway(context.Background(), "rg", armnetwork.NatGateway{})) + }) +} + +func TestDeleteNatGateway_Mock(t *testing.T) { + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().Delete(gomock.Any(), "rg", "svc").Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.deleteNatGateway(context.Background(), "rg", "svc")) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + mockNAT.EXPECT().Delete(gomock.Any(), "rg", "svc").Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.deleteNatGateway(context.Background(), "rg", "svc")) + }) + + t.Run("empty name", func(t *testing.T) { + dt := &DiffTracker{config: testConfig()} + assert.Error(t, dt.deleteNatGateway(context.Background(), "rg", "")) + }) +} + +func TestUpdateNRPSGWServices_Mock(t *testing.T) { + servicesDTO := ServicesDataDTO{ + Action: PartialUpdate, + Services: []ServiceDTO{ + {Service: "svc", ServiceType: Inbound, IsDelete: true}, + }, + } + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateServices(gomock.Any(), "rg", "sgw", gomock.Any()).Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.updateNRPSGWServices(context.Background(), "sgw", servicesDTO)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateServices(gomock.Any(), "rg", "sgw", gomock.Any()).Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.updateNRPSGWServices(context.Background(), "sgw", servicesDTO)) + }) + + t.Run("no-op when empty and not full update", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + // No client call expected. + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.updateNRPSGWServices(context.Background(), "sgw", ServicesDataDTO{Action: PartialUpdate})) + }) +} + +func TestUpdateNRPSGWAddressLocations_Mock(t *testing.T) { + locationsDTO := LocationsDataDTO{ + Action: PartialUpdate, + Locations: []LocationDTO{ + {Location: "node1", AddressUpdateAction: PartialUpdate, Addresses: []AddressDTO{}}, + }, + } + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateAddressLocations(gomock.Any(), "rg", "sgw", gomock.Any()).Return(nil) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.updateNRPSGWAddressLocations(context.Background(), "sgw", locationsDTO)) + }) + + t.Run("error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockSGW.EXPECT().UpdateAddressLocations(gomock.Any(), "rg", "sgw", gomock.Any()).Return(errors.New("boom")) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.Error(t, dt.updateNRPSGWAddressLocations(context.Background(), "sgw", locationsDTO)) + }) +} + +func TestDisassociateNatGatewayFromServiceGateway_Mock(t *testing.T) { + // Simplest reconcile path: no matching SGW service to clear, and the NAT + // gateway is already gone (404) -> method returns nil. + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockSGW := mock_servicegatewayclient.NewMockInterface(ctrl) + mockNAT := mock_natgatewayclient.NewMockInterface(ctrl) + mockFactory.EXPECT().GetServiceGatewayClient().Return(mockSGW).AnyTimes() + mockFactory.EXPECT().GetNatGatewayClient().Return(mockNAT).AnyTimes() + + mockSGW.EXPECT().GetServices(gomock.Any(), "rg", "sgw").Return([]*armnetwork.ServiceGatewayService{}, nil) + mockNAT.EXPECT().Get(gomock.Any(), "rg", "svc", gomock.Any()).Return(nil, notFoundError()) + + dt := &DiffTracker{networkClientFactory: mockFactory, config: testConfig()} + assert.NoError(t, dt.disassociateNatGatewayFromServiceGateway(context.Background(), "sgw", "svc")) +} + +func TestConvertServicesUpdateActionToARM(t *testing.T) { + assert.Equal(t, armnetwork.ServiceUpdateActionPartialUpdate, *convertServicesUpdateActionToARM(PartialUpdate)) + assert.Equal(t, armnetwork.ServiceUpdateActionFullUpdate, *convertServicesUpdateActionToARM(FullUpdate)) + // Unknown defaults to PartialUpdate. + assert.Equal(t, armnetwork.ServiceUpdateActionPartialUpdate, *convertServicesUpdateActionToARM(UnknownUpdateAction)) +} + +func TestConvertLocationsUpdateActionToARM(t *testing.T) { + assert.Equal(t, armnetwork.UpdateActionPartialUpdate, *convertLocationsUpdateActionToARM(PartialUpdate)) + assert.Equal(t, armnetwork.UpdateActionFullUpdate, *convertLocationsUpdateActionToARM(FullUpdate)) + // Unknown defaults to PartialUpdate. + assert.Equal(t, armnetwork.UpdateActionPartialUpdate, *convertLocationsUpdateActionToARM(UnknownUpdateAction)) +} + +func TestConvertLocationDTOsToAddressLocations(t *testing.T) { + t.Run("drained node keeps non-nil empty Addresses", func(t *testing.T) { + locs := convertLocationDTOsToAddressLocations([]LocationDTO{ + {Location: "node1", AddressUpdateAction: FullUpdate, Addresses: []AddressDTO{}}, + }) + assert.Len(t, locs, 1) + assert.NotNil(t, locs[0].Addresses) + assert.Empty(t, locs[0].Addresses) + assert.Equal(t, armnetwork.AddressUpdateActionFullUpdate, *locs[0].AddressUpdateAction) + }) + + t.Run("address with empty ServiceNames keeps non-nil empty Services", func(t *testing.T) { + locs := convertLocationDTOsToAddressLocations([]LocationDTO{ + {Location: "node1", AddressUpdateAction: PartialUpdate, Addresses: []AddressDTO{ + {Address: "10.0.0.5", ServiceNames: nil}, + }}, + }) + assert.Len(t, locs, 1) + assert.Equal(t, armnetwork.AddressUpdateActionPartialUpdate, *locs[0].AddressUpdateAction) + assert.Len(t, locs[0].Addresses, 1) + assert.NotNil(t, locs[0].Addresses[0].Services) + assert.Empty(t, locs[0].Addresses[0].Services) + assert.Equal(t, "10.0.0.5", *locs[0].Addresses[0].Address) + }) + + t.Run("unknown AddressUpdateAction defaults to PartialUpdate", func(t *testing.T) { + // A LocationDTO whose AddressUpdateAction is left unset (zero value + // UnknownUpdateAction) must still produce an explicit action, matching the + // service/location action converters, rather than a nil AddressUpdateAction. + locs := convertLocationDTOsToAddressLocations([]LocationDTO{ + {Location: "node1", Addresses: []AddressDTO{}}, + }) + assert.Len(t, locs, 1) + assert.NotNil(t, locs[0].AddressUpdateAction) + assert.Equal(t, armnetwork.AddressUpdateActionPartialUpdate, *locs[0].AddressUpdateAction) + }) +} diff --git a/pkg/provider/difftracker/azure_operations_test.go b/pkg/provider/difftracker/azure_operations_test.go new file mode 100644 index 0000000000..94355f6163 --- /dev/null +++ b/pkg/provider/difftracker/azure_operations_test.go @@ -0,0 +1,128 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +func TestConvertServiceDTOsToServiceRequests_OutboundRemovalHasNoNatGatewayID(t *testing.T) { + reqs, err := convertServiceDTOsToServiceRequests([]ServiceDTO{ + {Service: "egr1", ServiceType: Outbound, IsDelete: true}, + }, Config{SubscriptionID: "sub", ResourceGroup: "rg", VNetName: "vnet"}) + assert.NoError(t, err) + assert.Len(t, reqs, 1) + assert.Nil(t, reqs[0].Service.Properties.PublicNatGatewayID) +} + +func TestConvertServiceDTOsToServiceRequests_OutboundAddHasNatGatewayID(t *testing.T) { + natID := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/natGateways/egr1" + reqs, err := convertServiceDTOsToServiceRequests([]ServiceDTO{ + {Service: "egr1", ServiceType: Outbound, PublicNatGateway: NatGatewayDTO{Id: natID}}, + }, Config{SubscriptionID: "sub", ResourceGroup: "rg", VNetName: "vnet"}) + assert.NoError(t, err) + assert.Len(t, reqs, 1) + if assert.NotNil(t, reqs[0].Service.Properties.PublicNatGatewayID) { + assert.Equal(t, natID, *reqs[0].Service.Properties.PublicNatGatewayID) + } +} + +func TestConvertServiceDTOsToServiceRequests_UnknownServiceTypeErrors(t *testing.T) { + _, err := convertServiceDTOsToServiceRequests([]ServiceDTO{ + {Service: "x", ServiceType: ServiceType("")}, + }, Config{SubscriptionID: "sub", ResourceGroup: "rg", VNetName: "vnet"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown ServiceType") +} + +func TestExtractResourceChildName(t *testing.T) { + id := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb1/backendAddressPools/pool1" + assert.Equal(t, "pool1", extractResourceChildName(id, "backendAddressPools")) + assert.Equal(t, "", extractResourceChildName("/subscriptions/sub/loadBalancers/lb1", "backendAddressPools")) + assert.Equal(t, "", extractResourceChildName("", "backendAddressPools")) +} + +func TestCreateOrUpdatePIPWithResponse_NilPip(t *testing.T) { + dt := &DiffTracker{} + _, err := dt.createOrUpdatePIPWithResponse(context.Background(), "rg", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "pip is nil") +} + +func TestUpdateServiceLoadBalancerStatus_PreservesDualStackAndHostname(t *testing.T) { + svc := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "svc", Namespace: "ns", UID: "uid-1"}, + Status: v1.ServiceStatus{ + LoadBalancer: v1.LoadBalancerStatus{ + Ingress: []v1.LoadBalancerIngress{ + {IP: "2001:db8::1"}, + {Hostname: "example.com"}, + }, + }, + }, + } + kubeClient := fake.NewSimpleClientset(svc) + dt := &DiffTracker{kubeClient: kubeClient} + + err := dt.updateServiceLoadBalancerStatus(context.Background(), "uid-1", "10.0.0.1") + assert.NoError(t, err) + + got, err := kubeClient.CoreV1().Services("ns").Get(context.Background(), "svc", metav1.GetOptions{}) + assert.NoError(t, err) + var ips, hosts []string + for _, ing := range got.Status.LoadBalancer.Ingress { + if ing.IP != "" { + ips = append(ips, ing.IP) + } + if ing.Hostname != "" { + hosts = append(hosts, ing.Hostname) + } + } + assert.Contains(t, ips, "10.0.0.1") + assert.Contains(t, ips, "2001:db8::1") + assert.Contains(t, hosts, "example.com") +} + +func TestUpdateServiceLoadBalancerStatus_ReplacesStaleSameFamily(t *testing.T) { + svc := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "svc", Namespace: "ns", UID: "uid-2"}, + Status: v1.ServiceStatus{ + LoadBalancer: v1.LoadBalancerStatus{ + Ingress: []v1.LoadBalancerIngress{{IP: "10.0.0.9"}}, + }, + }, + } + kubeClient := fake.NewSimpleClientset(svc) + dt := &DiffTracker{kubeClient: kubeClient} + + err := dt.updateServiceLoadBalancerStatus(context.Background(), "uid-2", "10.0.0.1") + assert.NoError(t, err) + + got, err := kubeClient.CoreV1().Services("ns").Get(context.Background(), "svc", metav1.GetOptions{}) + assert.NoError(t, err) + var ips []string + for _, ing := range got.Status.LoadBalancer.Ingress { + ips = append(ips, ing.IP) + } + assert.Equal(t, []string{"10.0.0.1"}, ips) +} diff --git a/pkg/provider/difftracker/config.go b/pkg/provider/difftracker/config.go new file mode 100644 index 0000000000..b277a4cdbc --- /dev/null +++ b/pkg/provider/difftracker/config.go @@ -0,0 +1,66 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import "fmt" + +// Config holds the configuration values needed by DiffTracker +// to perform Azure operations without depending on the entire AzureCloud struct +// This allows DiffTracker to be more modular and testable +type Config struct { + // Azure subscription ID + SubscriptionID string + + // Azure resource group name + ResourceGroup string + + // Azure location/region (e.g. "eastus2"). Distinct from the difftracker + // Location type, which identifies a node by IP. + Location string + + // Service Gateway resource name + ServiceGatewayResourceName string + + // Full Service Gateway resource ID + ServiceGatewayID string + + // Virtual Network name (required for backend pool configuration) + VNetName string +} + +// Validate checks if the configuration has all required fields +func (c *Config) Validate() error { + if c.SubscriptionID == "" { + return fmt.Errorf("config validation failed: SubscriptionID is required") + } + if c.ResourceGroup == "" { + return fmt.Errorf("config validation failed: ResourceGroup is required") + } + if c.Location == "" { + return fmt.Errorf("config validation failed: Location is required") + } + if c.ServiceGatewayResourceName == "" { + return fmt.Errorf("config validation failed: ServiceGatewayResourceName is required") + } + if c.ServiceGatewayID == "" { + return fmt.Errorf("config validation failed: ServiceGatewayID is required") + } + if c.VNetName == "" { + return fmt.Errorf("config validation failed: VNetName is required") + } + return nil +} diff --git a/pkg/provider/difftracker/difftracker.go b/pkg/provider/difftracker/difftracker.go new file mode 100644 index 0000000000..3066d690bb --- /dev/null +++ b/pkg/provider/difftracker/difftracker.go @@ -0,0 +1,86 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "fmt" + + "k8s.io/client-go/kubernetes" + "k8s.io/klog/v2" + + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" +) + +// New creates and initializes a new DiffTracker with the given state and configuration. +// It validates the configuration and ensures all required dependencies are present. +// Returns an error if the configuration is invalid or if any required dependency is nil. +func New(k8s K8sState, nrp NRPState, config Config, networkClientFactory azclient.ClientFactory, kubeClient kubernetes.Interface) (*DiffTracker, error) { + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("difftracker.New: %w", err) + } + + if networkClientFactory == nil { + return nil, fmt.Errorf("difftracker.New: networkClientFactory must not be nil") + } + if kubeClient == nil { + return nil, fmt.Errorf("difftracker.New: kubeClient must not be nil") + } + + klog.V(2).Infof("difftracker.New: initializing with config: subscription=%s, resourceGroup=%s, location=%s, serviceGatewayResourceName=%s, serviceGatewayID=%s, vNetName=%s", + config.SubscriptionID, config.ResourceGroup, config.Location, config.ServiceGatewayResourceName, config.ServiceGatewayID, config.VNetName) + + // The caller is expected to pass fully initialized state structs. A nil + // field is unexpected and indicates a programming error, so error out. + if k8s.Services == nil { + return nil, fmt.Errorf("difftracker.New: k8s.Services must not be nil") + } + if k8s.Egresses == nil { + return nil, fmt.Errorf("difftracker.New: k8s.Egresses must not be nil") + } + if k8s.Nodes == nil { + return nil, fmt.Errorf("difftracker.New: k8s.Nodes must not be nil") + } + if nrp.LoadBalancers == nil { + return nil, fmt.Errorf("difftracker.New: nrp.LoadBalancers must not be nil") + } + if nrp.NATGateways == nil { + return nil, fmt.Errorf("difftracker.New: nrp.NATGateways must not be nil") + } + if nrp.Locations == nil { + return nil, fmt.Errorf("difftracker.New: nrp.Locations must not be nil") + } + + diffTracker := &DiffTracker{ + K8sResources: k8s, + NRPResources: nrp, + + // Configuration and clients + config: config, + networkClientFactory: networkClientFactory, + kubeClient: kubeClient, + } + + // Seed the outbound ref-counter from egress pods already in the initial state + // so a later REMOVE can drive the counter to zero. + for _, node := range k8s.Nodes { + for _, pod := range node.Pods { + diffTracker.incrementOutboundRefCount(pod.PublicOutboundIdentity) + } + } + + return diffTracker, nil +} diff --git a/pkg/provider/difftracker/difftracker_test.go b/pkg/provider/difftracker/difftracker_test.go new file mode 100644 index 0000000000..274c739b7d --- /dev/null +++ b/pkg/provider/difftracker/difftracker_test.go @@ -0,0 +1,1190 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "k8s.io/client-go/kubernetes/fake" + + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +func TestDiffTracker_DeepEqual(t *testing.T) { + tests := []struct { + name string + dt *DiffTracker + expected bool + }{ + { + name: "equal empty states", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "equal states with services", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("service1", "service2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1", "service2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "services not equal", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("service1", "service2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.dt.deepEqualLocked() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestEnqueueK8sServiceOperation(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + }, + } + + // Test Add operation + err := dt.EnqueueK8sServiceOperation(UpdateK8sResource{ + Operation: Add, + ID: "service1", + }) + assert.NoError(t, err) + assert.True(t, dt.K8sResources.Services.Has("service1")) + + // Test Remove operation + err = dt.EnqueueK8sServiceOperation(UpdateK8sResource{ + Operation: Remove, + ID: "service1", + }) + assert.NoError(t, err) + assert.False(t, dt.K8sResources.Services.Has("service1")) + + // Test invalid operation + err = dt.EnqueueK8sServiceOperation(UpdateK8sResource{ + Operation: Update, + ID: "service1", + }) + assert.Error(t, err) +} + +func TestGetSyncLoadBalancerServices(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("service1", "service2", "service3"), + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service2", "service3", "service4"), + }, + } + + result := dt.GetSyncLoadBalancerServices() + + assert.True(t, result.Additions.Has("service1")) + assert.Equal(t, 1, result.Additions.Len()) + + assert.True(t, result.Removals.Has("service4")) + assert.Equal(t, 1, result.Removals.Len()) +} + +func TestUpdateK8sEndpoints(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{}, + }, + } + + // Test adding new endpoint + input := UpdateK8sEndpointsInputType{ + InboundIdentity: "service1", + OldAddresses: map[string]string{}, + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + } + + errs := dt.UpdateK8sEndpoints(input) + assert.Empty(t, errs) + assert.Contains(t, dt.K8sResources.Nodes, "node1") + assert.Contains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") + assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("service1")) + + // Test removing an endpoint + input = UpdateK8sEndpointsInputType{ + InboundIdentity: "service1", + OldAddresses: map[string]string{"10.0.0.1": "node1"}, + NewAddresses: map[string]string{}, + } + + errs = dt.UpdateK8sEndpoints(input) + assert.Empty(t, errs) + assert.NotContains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") +} + +func TestUpdateK8sPod(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{}, + }, + } + + // Test adding new egress assignment + input := UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + } + + err := dt.UpdateK8sPod(input) + assert.NoError(t, err) + assert.Contains(t, dt.K8sResources.Nodes, "node1") + assert.Contains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") + assert.Equal(t, "public1", dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].PublicOutboundIdentity) + + // Test removing egress assignment + input = UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + } + + err = dt.UpdateK8sPod(input) + assert.NoError(t, err) + assert.NotContains(t, dt.K8sResources.Nodes["node1"].Pods, "10.0.0.1") +} + +// TestUpdateK8sPodIdentityChangeDecrementsOld verifies that when a pod's egress +// identity changes (X -> Y), the old identity's ref-counter is decremented so it +// doesn't leak, while the new identity is counted. +func TestUpdateK8sPodIdentityChangeDecrementsOld(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + // Pod starts using egress "X". + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "X", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("x") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + // Pod's egress changes to "Y": X must be released, Y counted. + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Update, + PublicOutboundIdentity: "Y", + Location: "node1", + Address: "10.0.0.1", + })) + _, ok = dt.outboundIdentityPodRefCount.Load("x") + assert.False(t, ok, "old identity X must be decremented on identity change, not leaked") + val, ok = dt.outboundIdentityPodRefCount.Load("y") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) +} + +// TestUpdateK8sPodCaseInsensitiveReAdd verifies that re-adding the same pod with a +// different-cased identity is treated as the same identity (no double counting), +// since the counter is keyed case-insensitively. +func TestUpdateK8sPodCaseInsensitiveReAdd(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "Svc1", + Location: "node1", + Address: "10.0.0.1", + })) + // Informer resync delivers the same identity with different casing. + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "svc1", + Location: "node1", + Address: "10.0.0.1", + })) + + val, ok := dt.outboundIdentityPodRefCount.Load("svc1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int), "case-only re-ADD must not double-count the identity") +} + +// TestUpdateK8sPodRemoveDecrementsCounter verifies that removing a pod decrements +// the outbound-identity ref-counter for the identity passed in the remove input +// (matching the engine's DeletePod, which always passes the service UID), clearing +// the entry when the last pod for that identity is removed. +func TestUpdateK8sPodRemoveDecrementsCounter(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{Nodes: map[string]Node{}}, + } + + // Add a pod with outbound identity "public1". + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + // Remove passing the same identity; the counter must be cleared. + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + })) + + _, ok = dt.outboundIdentityPodRefCount.Load("public1") + assert.False(t, ok, "counter for public1 should be removed after the last pod is removed") + _, ok = dt.outboundIdentityPodRefCount.Load("") + assert.False(t, ok, "no counter should be created for an empty identity") +} + +func TestGetSyncLocationsAddresses(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": { + Pods: map[string]Pod{ + "10.0.0.1": { + InboundIdentities: sets.NewString("service1"), + PublicOutboundIdentity: "public1", + }, + }, + }, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("public1"), + Locations: map[string]NRPLocation{}, + }, + } + + result := dt.GetSyncLocationsAddresses() + + assert.Equal(t, PartialUpdate, result.Action) + assert.Len(t, result.Locations, 1) + + location := result.Locations["node1"] + assert.NotNil(t, location) + assert.Equal(t, FullUpdate, location.AddressUpdateAction) + assert.Len(t, location.Addresses, 1) + + var address string + for addr := range location.Addresses { + address = addr + break + } + + assert.Equal(t, "10.0.0.1", address) + assert.True(t, location.Addresses[address].ServiceRef.Has("service1")) + assert.True(t, location.Addresses[address].ServiceRef.Has("public1")) +} + +func TestOperation_String(t *testing.T) { + assert.Equal(t, "Add", Add.String()) + assert.Equal(t, "Remove", Remove.String()) + assert.Equal(t, "Update", Update.String()) +} + +func TestUpdateNRPLoadBalancers(t *testing.T) { + tests := []struct { + name string + initialState *DiffTracker + expectedNRP *sets.IgnoreCaseSet + }{ + { + name: "add services from K8s to NRP", + initialState: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("service1", "service2", "service3"), + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1"), + }, + }, + expectedNRP: sets.NewString("service1", "service2", "service3"), + }, + { + name: "no changes needed when K8s and NRP are in sync", + initialState: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("service1", "service2"), + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1", "service2"), + }, + }, + expectedNRP: sets.NewString("service1", "service2"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + syncServices := tt.initialState.GetSyncLoadBalancerServices() + tt.initialState.UpdateNRPLoadBalancers(syncServices) + + assert.True(t, tt.expectedNRP.Equals(tt.initialState.NRPResources.LoadBalancers), + "Expected NRP LoadBalancers %v, but got %v", + tt.expectedNRP.UnsortedList(), + tt.initialState.NRPResources.LoadBalancers.UnsortedList()) + }) + } +} + +func TestEnqueueK8sEgressOperation(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Egresses: sets.NewString(), + }, + } + + err := dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Add, ID: "egress1"}) + assert.NoError(t, err) + assert.True(t, dt.K8sResources.Egresses.Has("egress1")) + + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Remove, ID: "egress1"}) + assert.NoError(t, err) + assert.False(t, dt.K8sResources.Egresses.Has("egress1")) + + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Update, ID: "egress1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "error - ResourceType=Egress, Operation=Update and ID=egress1") +} + +func TestGetSyncNRPNATGateways(t *testing.T) { + tests := []struct { + name string + k8sEgresses []string + nrpNATGateways []string + expectedAdditions []string + expectedRemovals []string + }{ + { + name: "empty states", + k8sEgresses: []string{}, + nrpNATGateways: []string{}, + expectedAdditions: []string{}, + expectedRemovals: []string{}, + }, + { + name: "mixed state with additions and removals", + k8sEgresses: []string{"egress1", "egress3", "egress5"}, + nrpNATGateways: []string{"egress1", "egress2", "egress4"}, + expectedAdditions: []string{"egress3", "egress5"}, + expectedRemovals: []string{"egress2", "egress4"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Egresses: sets.NewString(tt.k8sEgresses...), + }, + NRPResources: NRPState{ + NATGateways: sets.NewString(tt.nrpNATGateways...), + }, + } + + result := dt.GetSyncNRPNATGateways() + + assert.Equal(t, len(tt.expectedAdditions), result.Additions.Len()) + for _, addition := range tt.expectedAdditions { + assert.True(t, result.Additions.Has(addition)) + } + + assert.Equal(t, len(tt.expectedRemovals), result.Removals.Len()) + for _, removal := range tt.expectedRemovals { + assert.True(t, result.Removals.Has(removal)) + } + }) + } +} + +func TestUpdateNRPNATGateways(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Egresses: sets.NewString("egress1", "egress2", "egress4"), + }, + NRPResources: NRPState{ + NATGateways: sets.NewString("egress1", "egress3", "egress5"), + }, + } + + syncServices := dt.GetSyncNRPNATGateways() + dt.UpdateNRPNATGateways(syncServices) + + expectedNRP := sets.NewString("egress1", "egress2", "egress4") + assert.True(t, expectedNRP.Equals(dt.NRPResources.NATGateways), + "Expected NRP NATGateways %v, but got %v", + expectedNRP.UnsortedList(), + dt.NRPResources.NATGateways.UnsortedList()) +} + +func TestUpdateLocationsAddresses(t *testing.T) { + tests := []struct { + name string + initialState *DiffTracker + expectedNRP map[string]map[string][]string + }{ + { + name: "sync empty states", + initialState: &DiffTracker{ + K8sResources: K8sState{Nodes: map[string]Node{}}, + NRPResources: NRPState{Locations: map[string]NRPLocation{}}, + }, + expectedNRP: map[string]map[string][]string{}, + }, + { + name: "add new location and address", + initialState: &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": { + Pods: map[string]Pod{ + "10.0.0.1": { + InboundIdentities: sets.NewString("service1"), + PublicOutboundIdentity: "public1", + }, + }, + }, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("public1"), + Locations: map[string]NRPLocation{}, + }, + }, + expectedNRP: map[string]map[string][]string{ + "node1": {"10.0.0.1": {"service1", "public1"}}, + }, + }, + { + name: "complex case with multiple operations", + initialState: &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": { + Pods: map[string]Pod{ + "10.0.0.1": { + InboundIdentities: sets.NewString("service1", "service3"), + PublicOutboundIdentity: "public1", + }, + }, + }, + "node3": { + Pods: map[string]Pod{ + "10.0.0.5": {InboundIdentities: sets.NewString("service5")}, + }, + }, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1", "service2", "service3", "service4", "service5"), + NATGateways: sets.NewString("public1"), + Locations: map[string]NRPLocation{ + "node1": { + Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1", "service2")}, + "10.0.0.2": {Services: sets.NewString("service4")}, + }, + }, + "node2": { + Addresses: map[string]NRPAddress{ + "10.0.0.3": {Services: sets.NewString("service3")}, + }, + }, + }, + }, + }, + expectedNRP: map[string]map[string][]string{ + "node1": {"10.0.0.1": {"service1", "service3", "public1"}}, + "node3": {"10.0.0.5": {"service5"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + locationData := tt.initialState.GetSyncLocationsAddresses() + tt.initialState.UpdateLocationsAddresses(locationData) + + assert.Equal(t, len(tt.expectedNRP), len(tt.initialState.NRPResources.Locations)) + + for locName, expectedAddressMap := range tt.expectedNRP { + nrpLoc, exists := tt.initialState.NRPResources.Locations[locName] + assert.True(t, exists, "Expected location %s not found", locName) + assert.Equal(t, len(expectedAddressMap), len(nrpLoc.Addresses)) + + for addr, expectedServices := range expectedAddressMap { + nrpAddr, exists := nrpLoc.Addresses[addr] + assert.True(t, exists, "Expected address %s not found in %s", addr, locName) + assert.Equal(t, len(expectedServices), nrpAddr.Services.Len()) + for _, svc := range expectedServices { + assert.True(t, nrpAddr.Services.Has(svc)) + } + } + } + }) + } +} + +func TestGetSyncOperations(t *testing.T) { + tests := []struct { + name string + initialState *DiffTracker + expectedSyncStatus SyncStatus + }{ + { + name: "states already in sync", + initialState: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("service1"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("service1")}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("egress1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1")}, + }}, + }, + }, + }, + expectedSyncStatus: AlreadyInSync, + }, + { + name: "services out of sync", + initialState: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("service1", "service2"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("service1")}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1"), + NATGateways: sets.NewString("egress1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1")}, + }}, + }, + }, + }, + expectedSyncStatus: Success, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.initialState.GetSyncOperations() + assert.Equal(t, tt.expectedSyncStatus, result.SyncStatus) + }) + } +} + +// Real Scenario: CloudProvider is down and K8s Cluster is subject to continuous updates. +// This test verifies if the DiffTracker is able to sync K8s Cluster and NRP correctly +// when there is a huge discrepancy between K8s Cluster and NRP. +func TestNew(t *testing.T) { + K8sResources := K8sState{ + Services: sets.NewString("Service0", "Service1", "Service2"), + Egresses: sets.NewString("Egress0", "Egress1", "Egress2"), + Nodes: map[string]Node{ + "Node1": { + Pods: map[string]Pod{ + "Pod34": {InboundIdentities: sets.NewString("Service0"), PublicOutboundIdentity: ""}, + "Pod0": {InboundIdentities: sets.NewString("Service0"), PublicOutboundIdentity: "Egress0"}, + "Pod1": {InboundIdentities: sets.NewString("Service1", "Service2"), PublicOutboundIdentity: "Egress1"}, + "Pod3": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "Egress2"}, + }, + }, + "Node2": { + Pods: map[string]Pod{ + "Pod2": {InboundIdentities: sets.NewString("Service1"), PublicOutboundIdentity: "Egress2"}, + }, + }, + }, + } + + NRPResources := NRPState{ + LoadBalancers: sets.NewString("Service0", "Service6", "Service5"), + NATGateways: sets.NewString("Egress0", "Egress6", "Egress5"), + Locations: map[string]NRPLocation{ + "Node1": { + Addresses: map[string]NRPAddress{ + "Pod34": {Services: sets.NewString("Service0", "Service5")}, + "Pod00": {Services: sets.NewString("Service6", "Egress5")}, + "Pod0": {Services: sets.NewString("Service0", "Egress0")}, + }, + }, + "Node3": { + Addresses: map[string]NRPAddress{ + "Pod4": {Services: sets.NewString("Service6", "Eggres6")}, + "Pod5": {Services: sets.NewString("Egress5")}, + }, + }, + }, + } + + config := Config{ + SubscriptionID: "test-subscription", + ResourceGroup: "test-rg", + Location: "eastus", + VNetName: "test-vnet", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockKubeClient := fake.NewSimpleClientset() + diffTracker, err := New(K8sResources, NRPResources, config, mockFactory, mockKubeClient) + assert.NoError(t, err) + syncOperations := diffTracker.GetSyncOperations() + + diffTracker.UpdateNRPLoadBalancers(syncOperations.LoadBalancerUpdates) + diffTracker.UpdateNRPNATGateways(syncOperations.NATGatewayUpdates) + diffTracker.UpdateLocationsAddresses(syncOperations.LocationData) + + assert.Equal(t, Success, syncOperations.SyncStatus) + + expectedDiffTracker := &DiffTracker{ + K8sResources: K8sResources, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("Service0", "Service1", "Service2"), + NATGateways: sets.NewString("Egress0", "Egress1", "Egress2"), + Locations: map[string]NRPLocation{ + "Node1": { + Addresses: map[string]NRPAddress{ + "Pod34": {Services: sets.NewString("Service0")}, + "Pod0": {Services: sets.NewString("Service0", "Egress0")}, + }, + }, + }, + }, + } + + assert.True(t, diffTracker.Equals(expectedDiffTracker), + "DiffTracker does not match expected state") +} + +func validTestConfig() Config { + return Config{ + SubscriptionID: "test-subscription", + ResourceGroup: "test-rg", + Location: "eastus", + VNetName: "test-vnet", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } +} + +func emptyK8sState() K8sState { + return K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: make(map[string]Node), + } +} + +func emptyNRPState() NRPState { + return NRPState{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString(), + Locations: make(map[string]NRPLocation), + } +} + +// TestNewErrorPaths covers the validation/error branches of +// New and the successful initialization path. +func TestNewErrorPaths(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockKubeClient := fake.NewSimpleClientset() + + // Invalid config (empty) -> validation error. + _, err := New(K8sState{}, NRPState{}, Config{}, mockFactory, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "New") + + // Nil networkClientFactory. + _, err = New(K8sState{}, NRPState{}, validTestConfig(), nil, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "networkClientFactory must not be nil") + + // Nil kubeClient. + _, err = New(K8sState{}, NRPState{}, validTestConfig(), mockFactory, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "kubeClient must not be nil") + + // Uninitialized state field -> error out instead of silently initializing. + k8sMissingServices := emptyK8sState() + k8sMissingServices.Services = nil + _, err = New(k8sMissingServices, emptyNRPState(), validTestConfig(), mockFactory, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "k8s.Services must not be nil") + + nrpMissingLocations := emptyNRPState() + nrpMissingLocations.Locations = nil + _, err = New(emptyK8sState(), nrpMissingLocations, validTestConfig(), mockFactory, mockKubeClient) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nrp.Locations must not be nil") + + // Valid call with fully initialized (empty) states. + dt, err := New(emptyK8sState(), emptyNRPState(), validTestConfig(), mockFactory, mockKubeClient) + assert.NoError(t, err) + assert.NotNil(t, dt) + assert.NotNil(t, dt.K8sResources.Services) + assert.NotNil(t, dt.K8sResources.Egresses) + assert.NotNil(t, dt.K8sResources.Nodes) + assert.NotNil(t, dt.NRPResources.LoadBalancers) + assert.NotNil(t, dt.NRPResources.NATGateways) + assert.NotNil(t, dt.NRPResources.Locations) +} + +// TestNewSeedsOutboundRefCount verifies New seeds the outbound ref-counter from +// the egress pods already present in the initial state, so the counter is +// non-zero for identities that have backing pods at construction time. +func TestNewSeedsOutboundRefCount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := mock_azclient.NewMockClientFactory(ctrl) + mockKubeClient := fake.NewSimpleClientset() + + k8s := emptyK8sState() + k8s.Nodes["node1"] = Node{Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, + "10.0.0.2": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, + "10.0.0.3": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: ""}, + }} + k8s.Nodes["node2"] = Node{Pods: map[string]Pod{ + "10.0.1.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "Egress2"}, + }} + + dt, err := New(k8s, emptyNRPState(), validTestConfig(), mockFactory, mockKubeClient) + assert.NoError(t, err) + + val, ok := dt.outboundIdentityPodRefCount.Load("egress1") + assert.True(t, ok) + assert.Equal(t, 2, val.(int)) + + val, ok = dt.outboundIdentityPodRefCount.Load("egress2") + assert.True(t, ok, "identity key is lowercased") + assert.Equal(t, 1, val.(int)) + + _, ok = dt.outboundIdentityPodRefCount.Load("") + assert.False(t, ok, "pods without an egress identity are not counted") +} + +// TestEnqueueK8sResourceOperationErrors covers the empty-ID and invalid-operation +// branches of enqueueK8sResourceOperation (via the public wrappers). +func TestEnqueueK8sResourceOperationErrors(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{Services: sets.NewString(), Egresses: sets.NewString()}, + } + + // Empty ID. + err := dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: Add, ID: ""}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty ID") + + // Invalid operation (UPDATE is not valid for the resource set). + err = dt.EnqueueK8sEgressOperation(UpdateK8sResource{Operation: Update, ID: "egress1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Operation=Update") + + // Successful Add then Remove. + assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: Add, ID: "svc1"})) + assert.True(t, dt.K8sResources.Services.Has("svc1")) + assert.NoError(t, dt.EnqueueK8sServiceOperation(UpdateK8sResource{Operation: Remove, ID: "svc1"})) + assert.False(t, dt.K8sResources.Services.Has("svc1")) +} + +// TestUpdateK8sPodInvalidOperation covers the default branch of updateK8sPodLocked. +func TestUpdateK8sPodInvalidOperation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + err := dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Operation(99), + Location: "node1", + Address: "10.0.0.1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid pod operation") +} + +// TestUpdateK8sPodRemoveNonExistent covers the duplicate-removal branch. +func TestUpdateK8sPodRemoveNonExistent(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + // Removing a pod that was never added is a no-op (no error, no counter change). + err := dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + }) + assert.NoError(t, err) + _, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.False(t, ok) +} + +// TestUpdateK8sEndpointsMissingLocation covers the error branch where an address +// has no associated node location. +func TestUpdateK8sEndpointsMissingLocation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": ""}, + }) + assert.NotEmpty(t, errs) + assert.Contains(t, errs[0].Error(), "does not have a node associated") +} + +// TestRemoveServiceFromK8sStateLocked covers removeServiceFromK8sStateLocked for +// both inbound and outbound identities, including empty-pod/node cleanup. +func TestRemoveServiceFromK8sStateLocked(t *testing.T) { + t.Run("inbound removal cleans up empty pods and nodes", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + // Pod backs only svc1 inbound -> becomes empty and is removed. + "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, + // Pod backs svc1 and svc2 -> keeps svc2. + "10.0.0.2": {InboundIdentities: sets.NewString("svc1", "svc2")}, + }}, + }, + }, + } + dt.removeServiceFromK8sStateLocked("svc1", true) + + // 10.0.0.1 had only svc1 -> removed. + _, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] + assert.False(t, ok) + // 10.0.0.2 still has svc2. + pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.2"] + assert.True(t, ok) + assert.True(t, pod.InboundIdentities.Has("svc2")) + assert.False(t, pod.InboundIdentities.Has("svc1")) + }) + + t.Run("outbound removal clears identity and removes empty node", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString(), PublicOutboundIdentity: "egress1"}, + }}, + }, + }, + } + dt.removeServiceFromK8sStateLocked("egress1", false) + + // Pod had only the outbound identity -> pod and node removed. + _, nodeOK := dt.K8sResources.Nodes["node1"] + assert.False(t, nodeOK) + }) + + t.Run("public wrapper acquires lock and removes inbound identity", func(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{ + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1")}, + }}, + }, + }, + } + dt.RemoveServiceFromK8sState("svc1", true) + + _, nodeOK := dt.K8sResources.Nodes["node1"] + assert.False(t, nodeOK) + }) + + t.Run("outbound service removal decrements ref-counter", func(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + // Add an egress pod so the counter is seeded. + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "egress1", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("egress1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + // Removing the egress service must release the counter, not leak it. + dt.RemoveServiceFromK8sState("egress1", false) + _, ok = dt.outboundIdentityPodRefCount.Load("egress1") + assert.False(t, ok, "service removal must decrement the outbound ref-counter") + }) +} + +// TestUpdateK8sPodAddNoOutboundIdentity verifies that adding a pod with no +// outbound access (empty PublicOutboundIdentity) tracks the pod in state but +// does not create a bogus ref-counter entry under the empty-string key. This +// keeps the Add path symmetric with the Remove path, which skips the decrement +// for an empty identity. +func TestUpdateK8sPodAddNoOutboundIdentity(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + in := UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "", + Location: "node1", + Address: "10.0.0.1", + } + assert.NoError(t, dt.UpdateK8sPod(in)) + + // The pod must be tracked in state. + pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] + assert.True(t, ok) + assert.Equal(t, "", pod.PublicOutboundIdentity) + + // No counter entry should be created for the empty identity. + _, ok = dt.outboundIdentityPodRefCount.Load("") + assert.False(t, ok, "no counter should be created for an empty outbound identity") +} + +// TestUpdateK8sPodAddIdempotent covers the alreadyExists branch of updateK8sPodLocked: +// a repeated Add for the same pod+identity must not double-count the counter. +func TestUpdateK8sPodAddIdempotent(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + in := UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "public1", + Location: "node1", + Address: "10.0.0.1", + } + assert.NoError(t, dt.UpdateK8sPod(in)) + // Second identical Add (e.g. informer resync) must be a no-op for the counter. + assert.NoError(t, dt.UpdateK8sPod(in)) + + val, ok := dt.outboundIdentityPodRefCount.Load("public1") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) +} + +// TestUpdateK8sEndpointsAddThenRemove covers the OldAddresses removal path of +// updateK8sEndpointsLocked, including empty-pod/node cleanup. +func TestUpdateK8sEndpointsAddThenRemove(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + // Add endpoint for svc1 at pod 10.0.0.1 on node1. + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("svc1")) + + // Remove the same endpoint (now in OldAddresses, absent from NewAddresses). + errs = dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + OldAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + // Pod had only svc1 -> pod and node cleaned up. + _, ok := dt.K8sResources.Nodes["node1"] + assert.False(t, ok) +} + +// TestUpdateK8sEndpointsRelocation covers the case where the same pod IP appears +// in both OldAddresses and NewAddresses but on a different node (relocation): the +// pod must be removed from the old node and added to the new one. +func TestUpdateK8sEndpointsRelocation(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + assert.True(t, dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"].InboundIdentities.Has("svc1")) + + // Same pod IP moves from node1 to node2. + errs = dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "svc1", + OldAddresses: map[string]string{"10.0.0.1": "node1"}, + NewAddresses: map[string]string{"10.0.0.1": "node2"}, + }) + assert.Empty(t, errs) + + // Old node is gone, new node holds the pod with svc1. + _, ok := dt.K8sResources.Nodes["node1"] + assert.False(t, ok, "pod must be removed from the old node") + pod, ok := dt.K8sResources.Nodes["node2"].Pods["10.0.0.1"] + assert.True(t, ok, "pod must be added to the new node") + assert.True(t, pod.InboundIdentities.Has("svc1")) +} + +func TestUpdateK8sPodRejectsEmptyLocationOrAddress(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + err := dt.UpdateK8sPod(UpdatePodInputType{PodOperation: Add, Location: "", Address: "10.0.0.1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must not be empty") + + err = dt.UpdateK8sPod(UpdatePodInputType{PodOperation: Add, Location: "node1", Address: ""}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must not be empty") +} + +// TestUpdateK8sPodRemovePreservesInboundIdentities verifies that removing a pod's +// egress assignment clears only the outbound identity and keeps the pod while it +// still backs an inbound (LoadBalancer) service, whereas an egress-only pod is +// removed entirely and its ref-counter released. +func TestUpdateK8sPodRemovePreservesInboundIdentities(t *testing.T) { + dt := &DiffTracker{K8sResources: K8sState{Nodes: map[string]Node{}}} + + errs := dt.UpdateK8sEndpoints(UpdateK8sEndpointsInputType{ + InboundIdentity: "lb1", + NewAddresses: map[string]string{"10.0.0.1": "node1"}, + }) + assert.Empty(t, errs) + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "egressA", + Location: "node1", + Address: "10.0.0.1", + })) + val, ok := dt.outboundIdentityPodRefCount.Load("egressa") + assert.True(t, ok) + assert.Equal(t, 1, val.(int)) + + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "egressA", + Location: "node1", + Address: "10.0.0.1", + })) + pod, ok := dt.K8sResources.Nodes["node1"].Pods["10.0.0.1"] + assert.True(t, ok, "pod backing an inbound service must be kept") + assert.True(t, pod.InboundIdentities.Has("lb1")) + assert.Equal(t, "", pod.PublicOutboundIdentity) + _, ok = dt.outboundIdentityPodRefCount.Load("egressa") + assert.False(t, ok, "egress ref-counter must be released") + + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Add, + PublicOutboundIdentity: "egressB", + Location: "node2", + Address: "10.0.0.2", + })) + assert.NoError(t, dt.UpdateK8sPod(UpdatePodInputType{ + PodOperation: Remove, + PublicOutboundIdentity: "egressB", + Location: "node2", + Address: "10.0.0.2", + })) + _, nodeOK := dt.K8sResources.Nodes["node2"] + assert.False(t, nodeOK, "egress-only pod and its empty node must be removed") + _, ok = dt.outboundIdentityPodRefCount.Load("egressb") + assert.False(t, ok) +} + +// TestGetSyncLocationsAddressesRemovesGoneNode verifies that when a node is gone +// from K8s but still present in NRP, the location is emitted with an empty +// Addresses map (a PartialUpdate that deletes the whole location on the Service +// Gateway), and applying the result drops the location locally. +func TestGetSyncLocationsAddressesRemovesGoneNode(t *testing.T) { + dt := &DiffTracker{ + K8sResources: K8sState{Nodes: map[string]Node{}}, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("service1", "service2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{ + "node1": { + Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("service1")}, + "10.0.0.2": {Services: sets.NewString("service2")}, + }, + }, + }, + }, + } + + result := dt.GetSyncLocationsAddresses() + + loc, ok := result.Locations["node1"] + assert.True(t, ok) + assert.Equal(t, PartialUpdate, loc.AddressUpdateAction) + assert.Empty(t, loc.Addresses) + + dt.UpdateLocationsAddresses(result) + _, ok = dt.NRPResources.Locations["node1"] + assert.False(t, ok, "gone node's location must be removed locally") +} diff --git a/pkg/provider/difftracker/dto_mappers.go b/pkg/provider/difftracker/dto_mappers.go new file mode 100644 index 0000000000..dc85795a11 --- /dev/null +++ b/pkg/provider/difftracker/dto_mappers.go @@ -0,0 +1,79 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "fmt" + + "sigs.k8s.io/cloud-provider-azure/pkg/consts" +) + +func MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO(loadBalancerUpdates SyncServicesReturnType, natGatewayUpdates SyncServicesReturnType, subscriptionID string, resourceGroup string) ServicesDataDTO { + var ServicesDataDTO ServicesDataDTO + ServicesDataDTO.Action = PartialUpdate + ServicesDataDTO.Services = []ServiceDTO{} + for _, service := range loadBalancerUpdates.Additions.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + ServiceType: Inbound, + LoadBalancerBackendPools: []LoadBalancerBackendPoolDTO{ + { + Id: fmt.Sprintf( + consts.BackendPoolIDTemplate, + subscriptionID, + resourceGroup, + service, + service, + ), + }, + }, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + for _, service := range loadBalancerUpdates.Removals.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + IsDelete: true, + ServiceType: Inbound, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + for _, service := range natGatewayUpdates.Additions.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + ServiceType: Outbound, + PublicNatGateway: NatGatewayDTO{ + Id: fmt.Sprintf( + consts.NatGatewayIDTemplate, + subscriptionID, + resourceGroup, + service, + ), + }, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + for _, service := range natGatewayUpdates.Removals.UnsortedList() { + serviceDTO := ServiceDTO{ + Service: service, + IsDelete: true, + ServiceType: Outbound, + } + ServicesDataDTO.Services = append(ServicesDataDTO.Services, serviceDTO) + } + return ServicesDataDTO +} diff --git a/pkg/provider/difftracker/dto_mappers_test.go b/pkg/provider/difftracker/dto_mappers_test.go new file mode 100644 index 0000000000..e6dce7807d --- /dev/null +++ b/pkg/provider/difftracker/dto_mappers_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +func makeInboundConfig(frontendPorts ...int32) *InboundConfig { + cfg := &InboundConfig{} + for _, p := range frontendPorts { + cfg.FrontendPorts = append(cfg.FrontendPorts, PortMapping{Port: p, Protocol: "TCP"}) + cfg.BackendPorts = append(cfg.BackendPorts, PortMapping{Port: p, Protocol: "TCP"}) + } + return cfg +} + +func TestInboundConfig_Equals(t *testing.T) { + a := makeInboundConfig(80, 443) + b := makeInboundConfig(80, 443) + assert.True(t, a.Equals(b), "identical configs should be equal") + + c := makeInboundConfig(80, 8080) + assert.False(t, a.Equals(c), "different ports should be unequal") + + d := makeInboundConfig(443, 80) + assert.False(t, a.Equals(d), "ordered comparison: reversed ports must be unequal") + + assert.True(t, (*InboundConfig)(nil).Equals(nil), "nil-nil equal") + assert.False(t, a.Equals(nil), "non-nil vs nil unequal") +} + +func TestMapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO(t *testing.T) { + tests := []struct { + name string + lbUpdates SyncServicesReturnType + natUpdates SyncServicesReturnType + expectedLen int + }{ + { + name: "only inbound additions", + lbUpdates: SyncServicesReturnType{ + Additions: sets.NewString("svc1", "svc2"), + }, + natUpdates: SyncServicesReturnType{}, + expectedLen: 2, + }, + { + name: "only outbound additions", + lbUpdates: SyncServicesReturnType{}, + natUpdates: SyncServicesReturnType{ + Additions: sets.NewString("egress1"), + }, + expectedLen: 1, + }, + { + name: "mixed additions", + lbUpdates: SyncServicesReturnType{ + Additions: sets.NewString("svc1"), + }, + natUpdates: SyncServicesReturnType{ + Additions: sets.NewString("egress1"), + }, + expectedLen: 2, + }, + { + name: "removals only", + lbUpdates: SyncServicesReturnType{ + Removals: sets.NewString("svc1"), + }, + natUpdates: SyncServicesReturnType{ + Removals: sets.NewString("egress1"), + }, + expectedLen: 2, + }, + { + name: "empty updates", + lbUpdates: SyncServicesReturnType{}, + natUpdates: SyncServicesReturnType{}, + expectedLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO(tt.lbUpdates, tt.natUpdates, "sub1", "rg1") + assert.Equal(t, tt.expectedLen, len(result.Services)) + }) + } +} diff --git a/pkg/provider/difftracker/k8s_state_updates.go b/pkg/provider/difftracker/k8s_state_updates.go new file mode 100644 index 0000000000..46fa116bd3 --- /dev/null +++ b/pkg/provider/difftracker/k8s_state_updates.go @@ -0,0 +1,341 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "fmt" + "strings" + + "k8s.io/klog/v2" + + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +const ( + ResourceTypeService = "Service" + ResourceTypeEgress = "Egress" +) + +// enqueueK8sResourceOperation applies the requested operation (Add/Remove) to the +// in-memory K8s resource set. It does not perform any Azure update calls; it only +// mutates the local desired-state model that will later be reconciled with NRP. +func (dt *DiffTracker) enqueueK8sResourceOperation(input UpdateK8sResource, set *utilsets.IgnoreCaseSet, resourceType string) error { + if input.ID == "" { + return fmt.Errorf("%s: empty ID not allowed", resourceType) + } + + switch input.Operation { + case Add: + set.Insert(input.ID) + klog.V(2).Infof("enqueueK8sResourceOperation: Added %s %s to K8s state", resourceType, input.ID) + case Remove: + set.Delete(input.ID) + klog.V(2).Infof("enqueueK8sResourceOperation: Removed %s %s from K8s state", resourceType, input.ID) + default: + return fmt.Errorf("error - ResourceType=%s, Operation=%s and ID=%s", resourceType, input.Operation, input.ID) + } + return nil +} + +// EnqueueK8sServiceOperation records a service Add/Remove in the local K8s state set. +// The change is reconciled with NRP later by the sync operations; this method itself +// performs no Azure calls. +func (dt *DiffTracker) EnqueueK8sServiceOperation(input UpdateK8sResource) error { + dt.mu.Lock() + defer dt.mu.Unlock() + + return dt.enqueueK8sResourceOperation(input, dt.K8sResources.Services, ResourceTypeService) +} + +// EnqueueK8sEgressOperation records an egress Add/Remove in the local K8s state set. +// The change is reconciled with NRP later by the sync operations; this method itself +// performs no Azure calls. +func (dt *DiffTracker) EnqueueK8sEgressOperation(input UpdateK8sResource) error { + dt.mu.Lock() + defer dt.mu.Unlock() + + return dt.enqueueK8sResourceOperation(input, dt.K8sResources.Egresses, ResourceTypeEgress) +} + +// updateK8sEndpointsLocked updates K8s endpoints state. Assumes lock is already held. +// Terminology: an "endpoint" here is a Kubernetes EndpointSlice endpoint, i.e. a pod IP +// (the map key "address"), and its "location" is the IP of the node/VM hosting that pod +// (the map value). Both NewAddresses and OldAddresses are address(podIP) -> location(nodeIP). +func (dt *DiffTracker) updateK8sEndpointsLocked(input UpdateK8sEndpointsInputType) []error { + var errs []error + for address, location := range input.NewAddresses { + if location == "" { + errs = append(errs, fmt.Errorf("error UpdateK8sEndpoints, address=%s does not have a node associated", address)) + continue + } + + if oldLocation, exists := input.OldAddresses[address]; exists && oldLocation == location { + continue + } + + nodeState, exists := dt.K8sResources.Nodes[location] + if !exists { + nodeState = newNode() + dt.K8sResources.Nodes[location] = nodeState + } + + pod, exists := nodeState.Pods[address] + if !exists { + pod = newPod() + nodeState.Pods[address] = pod + } + pod.InboundIdentities.Insert(input.InboundIdentity) + klog.V(2).Infof("updateK8sEndpointsLocked: Added inbound identity %s to pod %s on node %s", input.InboundIdentity, address, location) + } + + for address, location := range input.OldAddresses { + if newLocation, exists := input.NewAddresses[address]; exists && newLocation == location { + continue + } + + if location == "" { + errs = append(errs, fmt.Errorf("error UpdateK8sEndpoints, address=%s does not have a node associated", address)) + continue + } + + node, nodeExists := dt.K8sResources.Nodes[location] + if !nodeExists { + continue + } + + pod, podExists := node.Pods[address] + if !podExists { + continue + } + + pod.InboundIdentities.Delete(input.InboundIdentity) + klog.V(2).Infof("updateK8sEndpointsLocked: Removed inbound identity %s from pod %s on node %s", input.InboundIdentity, address, location) + + if !pod.HasIdentities() { + delete(node.Pods, address) + if !node.HasPods() { + delete(dt.K8sResources.Nodes, location) + } + } + } + + return errs +} + +// UpdateK8sEndpoints is the public, lock-acquiring entry point for applying an +// EndpointSlice change to K8s state. It is called by the EndpointSlice informer +// handlers (Add/Update/Delete). Use this rather than updateK8sEndpointsLocked +// unless the caller already holds dt.mu. +func (dt *DiffTracker) UpdateK8sEndpoints(input UpdateK8sEndpointsInputType) []error { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.updateK8sEndpointsLocked(input) +} + +func (dt *DiffTracker) addOrUpdatePod(input UpdatePodInputType) { + node, exists := dt.K8sResources.Nodes[input.Location] + if !exists { + node = newNode() + dt.K8sResources.Nodes[input.Location] = node + } + + pod, exists := node.Pods[input.Address] + if !exists { + pod = newPod() + } + + pod.PublicOutboundIdentity = input.PublicOutboundIdentity + node.Pods[input.Address] = pod + klog.V(2).Infof("addOrUpdatePod: Set outbound identity %q for pod %s on node %s", input.PublicOutboundIdentity, input.Address, input.Location) +} + +// removePod clears the pod's outbound identity when it matches the input and +// deletes the pod only once it has no identities left, so an egress removal never +// drops a pod's inbound (LoadBalancer) backing. It returns whether the pod existed +// and any error from decrementing the ref-counter. +func (dt *DiffTracker) removePod(input UpdatePodInputType) (existed bool, err error) { + node, nodeExists := dt.K8sResources.Nodes[input.Location] + if !nodeExists { + return false, nil + } + + pod, podExists := node.Pods[input.Address] + if !podExists { + return false, nil + } + + if pod.PublicOutboundIdentity != "" && strings.EqualFold(pod.PublicOutboundIdentity, input.PublicOutboundIdentity) { + pod.PublicOutboundIdentity = "" + node.Pods[input.Address] = pod + err = dt.decrementOutboundRefCount(input.PublicOutboundIdentity) + } + + if !pod.HasIdentities() { + delete(node.Pods, input.Address) + if !node.HasPods() { + delete(dt.K8sResources.Nodes, input.Location) + } + klog.V(2).Infof("removePod: Removed pod %s from node %s", input.Address, input.Location) + } + + return true, err +} + +// incrementOutboundRefCount increments the ref-counter for a pod's outbound +// (egress) identity. Empty identities are not counted. +func (dt *DiffTracker) incrementOutboundRefCount(identity string) { + if identity == "" { + return + } + key := strings.ToLower(identity) + counter := 0 + if val, ok := dt.outboundIdentityPodRefCount.Load(key); ok { + counter = val.(int) + } + dt.outboundIdentityPodRefCount.Store(key, counter+1) +} + +// decrementOutboundRefCount decrements the ref-counter for an outbound (egress) +// identity, deleting the entry when it reaches zero. Empty or unknown identities +// are a no-op. Returns an error if the counter is already non-positive. +func (dt *DiffTracker) decrementOutboundRefCount(identity string) error { + if identity == "" { + return nil + } + key := strings.ToLower(identity) + val, ok := dt.outboundIdentityPodRefCount.Load(key) + if !ok { + return nil + } + counter := val.(int) + if counter <= 0 { + return fmt.Errorf("error - PublicOutboundIdentity %s has a non-positive count: %d", identity, counter) + } + if counter == 1 { + dt.outboundIdentityPodRefCount.Delete(key) + } else { + dt.outboundIdentityPodRefCount.Store(key, counter-1) + } + return nil +} + +// updateK8sPodLocked updates K8s pod state. Assumes lock is already held. +func (dt *DiffTracker) updateK8sPodLocked(input UpdatePodInputType) error { + if input.Location == "" || input.Address == "" { + return fmt.Errorf("updateK8sPodLocked: Location and Address must not be empty (location=%q, address=%q)", input.Location, input.Address) + } + switch input.PodOperation { + case Add, Update: + // Determine the pod's current (old) outbound identity, if any, and whether + // this exact pod+identity is already counted (idempotent re-ADD, e.g. an + // informer resync). The comparison is case-insensitive because the counter + // is keyed on the lowercased identity. + oldIdentity := "" + alreadyExists := false + if node, nodeExists := dt.K8sResources.Nodes[input.Location]; nodeExists { + if pod, podExists := node.Pods[input.Address]; podExists { + oldIdentity = pod.PublicOutboundIdentity + if strings.EqualFold(oldIdentity, input.PublicOutboundIdentity) { + alreadyExists = true + klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s already exists for service %s, skipping counter update", + input.Location, input.Address, input.PublicOutboundIdentity) + } + } + } + + if !alreadyExists { + // The pod's egress identity is changing (old -> new). Release the old + // identity's count before counting the new one, otherwise the old + // identity's counter would leak (a later remove decrements the new + // identity, never the old). + if !strings.EqualFold(oldIdentity, input.PublicOutboundIdentity) { + if err := dt.decrementOutboundRefCount(oldIdentity); err != nil { + return err + } + } + dt.incrementOutboundRefCount(input.PublicOutboundIdentity) + } + dt.addOrUpdatePod(input) + return nil + case Remove: + existed, err := dt.removePod(input) + if !existed { + klog.V(4).Infof("updateK8sPodLocked: Pod at %s:%s was already removed (duplicate delete), skipping counter decrement", + input.Location, input.Address) + return nil + } + return err + default: + return fmt.Errorf("invalid pod operation: %s for pod at %s:%s", + input.PodOperation, input.Location, input.Address) + } +} + +// UpdateK8sPod is the public, lock-acquiring entry point for applying a pod +// egress-assignment change to K8s state. It is called by the pod informer +// handlers (Add/Update/Delete). Use this rather than updateK8sPodLocked unless +// the caller already holds dt.mu. +func (dt *DiffTracker) UpdateK8sPod(input UpdatePodInputType) error { + dt.mu.Lock() + defer dt.mu.Unlock() + return dt.updateK8sPodLocked(input) +} + +// removeServiceFromK8sStateLocked removes a service from all pod identities in K8s state. +// This is used during service deletion to proactively clear location/address references +// so the LocationsUpdater can sync the removal to NRP. +// Assumes lock is already held. +func (dt *DiffTracker) removeServiceFromK8sStateLocked(serviceUID string, isInbound bool) { + for nodeIP, node := range dt.K8sResources.Nodes { + for podIP, pod := range node.Pods { + if isInbound { + // Remove from inbound identities + if pod.InboundIdentities != nil && pod.InboundIdentities.Has(serviceUID) { + pod.InboundIdentities.Delete(serviceUID) + } + } else { + // Clear outbound identity if it matches, releasing its ref-count so + // the counter doesn't leak when a service is deleted before its pods + // are removed. + if strings.EqualFold(pod.PublicOutboundIdentity, serviceUID) { + pod.PublicOutboundIdentity = "" + node.Pods[podIP] = pod + if err := dt.decrementOutboundRefCount(serviceUID); err != nil { + klog.Warningf("removeServiceFromK8sStateLocked: %v", err) + } + } + } + + // Clean up empty pods and nodes + if !pod.HasIdentities() { + delete(node.Pods, podIP) + if !node.HasPods() { + delete(dt.K8sResources.Nodes, nodeIP) + } + } + } + } +} + +// RemoveServiceFromK8sState is the public, lock-acquiring entry point for +// removeServiceFromK8sStateLocked. Use it to clear a deleted service's references +// from pod identities when the caller does not already hold dt.mu. +func (dt *DiffTracker) RemoveServiceFromK8sState(serviceUID string, isInbound bool) { + dt.mu.Lock() + defer dt.mu.Unlock() + dt.removeServiceFromK8sStateLocked(serviceUID, isInbound) +} diff --git a/pkg/provider/difftracker/nrp_state_updates.go b/pkg/provider/difftracker/nrp_state_updates.go new file mode 100644 index 0000000000..94d427493c --- /dev/null +++ b/pkg/provider/difftracker/nrp_state_updates.go @@ -0,0 +1,108 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "k8s.io/klog/v2" +) + +func (dt *DiffTracker) UpdateNRPLoadBalancers(syncServicesReturnType SyncServicesReturnType) { + dt.mu.Lock() + defer dt.mu.Unlock() + + for _, service := range syncServicesReturnType.Additions.UnsortedList() { + dt.NRPResources.LoadBalancers.Insert(service) + klog.V(2).Infof("UpdateNRPLoadBalancers: Added service %s to NRP LoadBalancers", service) + } + + for _, service := range syncServicesReturnType.Removals.UnsortedList() { + dt.NRPResources.LoadBalancers.Delete(service) + klog.V(2).Infof("UpdateNRPLoadBalancers: Removed service %s from NRP LoadBalancers", service) + } +} + +func (dt *DiffTracker) UpdateNRPNATGateways(syncServicesReturnType SyncServicesReturnType) { + dt.mu.Lock() + defer dt.mu.Unlock() + + for _, service := range syncServicesReturnType.Additions.UnsortedList() { + dt.NRPResources.NATGateways.Insert(service) + klog.V(2).Infof("UpdateNRPNATGateways: Added service %s to NRP NATGateways", service) + } + + for _, service := range syncServicesReturnType.Removals.UnsortedList() { + dt.NRPResources.NATGateways.Delete(service) + klog.V(2).Infof("UpdateNRPNATGateways: Removed service %s from NRP NATGateways", service) + } +} + +func (dt *DiffTracker) UpdateLocationsAddresses(locationData LocationData) { + dt.mu.Lock() + defer dt.mu.Unlock() + + for locationKey, locationValue := range locationData.Locations { + // Remove empty locations + if len(locationValue.Addresses) == 0 { + delete(dt.NRPResources.Locations, locationKey) + continue + } + + // Get or create location + nrpLocation, exists := dt.NRPResources.Locations[locationKey] + isFullUpdate := !exists || locationValue.AddressUpdateAction == FullUpdate + + // For full updates, start with a fresh location + if isFullUpdate { + nrpLocation = NRPLocation{ + Addresses: make(map[string]NRPAddress), + } + } + + // Process address updates + for addressKey, addressValue := range locationValue.Addresses { + // Remove empty addresses + if addressValue.ServiceRef.Len() == 0 { + delete(nrpLocation.Addresses, addressKey) + continue + } + + // Create new service references set + serviceRefs := createServiceRefsFromAddress(addressValue) + + // For full update or when address doesn't exist, add new address + if isFullUpdate || !addressExists(nrpLocation, addressKey) { + nrpLocation.Addresses[addressKey] = NRPAddress{Services: serviceRefs} + continue + } + + // For partial updates with existing address + existingAddress := nrpLocation.Addresses[addressKey] + if !serviceRefs.Equals(existingAddress.Services) { + nrpLocation.Addresses[addressKey] = NRPAddress{ + Services: serviceRefs, + } + } + } + + // Save location if it has addresses, otherwise delete it + if len(nrpLocation.Addresses) > 0 { + dt.NRPResources.Locations[locationKey] = nrpLocation + } else { + delete(dt.NRPResources.Locations, locationKey) + } + } +} diff --git a/pkg/provider/difftracker/resource_helpers.go b/pkg/provider/difftracker/resource_helpers.go new file mode 100644 index 0000000000..b9b40bfd8c --- /dev/null +++ b/pkg/provider/difftracker/resource_helpers.go @@ -0,0 +1,332 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/klog/v2" + + "sigs.k8s.io/cloud-provider-azure/pkg/consts" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// buildInboundServiceResources constructs the PIP, LoadBalancer, and ServicesDTO for an inbound service +// Returns the resources ready to be created via createOrUpdatePIP/createOrUpdateLB/updateNRPSGWServices +func buildInboundServiceResources(serviceUID string, config *InboundConfig, dtConfig Config) ( + pip armnetwork.PublicIPAddress, + lb armnetwork.LoadBalancer, + servicesDTO ServicesDataDTO, + err error, +) { + pipName := fmt.Sprintf("%s-pip", serviceUID) + + // Build Public IP + pip = armnetwork.PublicIPAddress{ + Name: to.Ptr(pipName), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandardV2), + }, + Location: to.Ptr(dtConfig.Location), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + }, + } + + // Build LoadBalancer with backend pool and rules + backendPoolName := serviceUID + frontendIPConfigID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/frontendIPConfigurations/frontend", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID) + backendPoolID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/backendAddressPools/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID, backendPoolName) + + // Build backend pool + backendPools := []*armnetwork.BackendAddressPool{ + { + Name: to.Ptr(backendPoolName), + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + // Backend pool will be populated by ServiceGateway with pod IPs + }, + }, + } + + // Build LB rules from config. No health probes are created: PodIP backend pools + // don't support them, so the LB's Probes field is left nil. + var lbRules []*armnetwork.LoadBalancingRule + + if config != nil && len(config.FrontendPorts) > 0 { + idleTimeout := int32(4) + if config.IdleTimeoutMinutes != nil { + idleTimeout = *config.IdleTimeoutMinutes + if idleTimeout < 4 || idleTimeout > 30 { + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: idle timeout %d out of range (4-30) for service %s", idleTimeout, serviceUID) + } + } + for i, frontendPort := range config.FrontendPorts { + backendPort := frontendPort.Port + if i < len(config.BackendPorts) { + backendPort = config.BackendPorts[i].Port + } + + if frontendPort.Port < 1 || frontendPort.Port > 65534 { + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: frontend port %d out of range (1-65534) for service %s", frontendPort.Port, serviceUID) + } + if backendPort < 1 || backendPort > 65535 { + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: backend port %d out of range (1-65535) for service %s", backendPort, serviceUID) + } + + var protocol armnetwork.TransportProtocol + switch { + case strings.EqualFold(frontendPort.Protocol, "TCP"): + protocol = armnetwork.TransportProtocolTCP + case strings.EqualFold(frontendPort.Protocol, "UDP"): + protocol = armnetwork.TransportProtocolUDP + default: + return pip, lb, servicesDTO, fmt.Errorf("buildInboundServiceResources: unsupported protocol %q for service %s", frontendPort.Protocol, serviceUID) + } + + ruleName := fmt.Sprintf("rule-%s-%d", strings.ToLower(frontendPort.Protocol), frontendPort.Port) + + ruleProps := &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(protocol), + FrontendPort: to.Ptr(frontendPort.Port), + BackendPort: to.Ptr(backendPort), + EnableFloatingIP: to.Ptr(false), // Disabled for PodIP backend + IdleTimeoutInMinutes: to.Ptr(idleTimeout), + FrontendIPConfiguration: &armnetwork.SubResource{ + ID: to.Ptr(frontendIPConfigID), + }, + BackendAddressPool: &armnetwork.SubResource{ + ID: to.Ptr(backendPoolID), + }, + // No probe for PodIP backend pools + } + if protocol == armnetwork.TransportProtocolTCP { + ruleProps.EnableTCPReset = to.Ptr(true) + } + + lbRules = append(lbRules, &armnetwork.LoadBalancingRule{ + Name: to.Ptr(ruleName), + Properties: ruleProps, + }) + + klog.V(4).Infof("buildInboundServiceResources: created LB rule %s: frontend=%d backend=%d protocol=%s for service %s", + ruleName, frontendPort.Port, backendPort, frontendPort.Protocol, serviceUID) + } + } else { + klog.V(2).Infof("buildInboundServiceResources: no port configuration provided for service %s, creating LB without rules", serviceUID) + } + + lb = armnetwork.LoadBalancer{ + Name: to.Ptr(serviceUID), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s", dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID)), + Location: to.Ptr(dtConfig.Location), + SKU: &armnetwork.LoadBalancerSKU{ + Name: to.Ptr(armnetwork.LoadBalancerSKUName(consts.LoadBalancerSKUNameService)), + }, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Scope: to.Ptr(armnetwork.LoadBalancerScopePublic), + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ + { + Name: to.Ptr("frontend"), + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + }, + }, + }, + }, + BackendAddressPools: backendPools, + LoadBalancingRules: lbRules, + }, + } + + // Build ServicesDTO for ServiceGateway registration + servicesDTO = MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: utilsets.NewString(serviceUID), + Removals: nil, + }, + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) + + return pip, lb, servicesDTO, nil +} + +// buildOutboundServiceResources constructs the PIP, NAT Gateway, and ServicesDTO for an outbound service +// Returns the resources ready to be created via createOrUpdatePIP/createOrUpdateNatGateway/updateNRPSGWServices +func buildOutboundServiceResources(serviceUID string, config *OutboundConfig, dtConfig Config) ( + pip armnetwork.PublicIPAddress, + natGateway armnetwork.NatGateway, + servicesDTO ServicesDataDTO, +) { + pipName := fmt.Sprintf("%s-pip", serviceUID) + + // Build Public IP + pip = armnetwork.PublicIPAddress{ + Name: to.Ptr(pipName), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandardV2), + }, + Location: to.Ptr(dtConfig.Location), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + }, + } + + // Build NAT Gateway + natGateway = armnetwork.NatGateway{ + Name: to.Ptr(serviceUID), + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/natGateways/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, serviceUID)), + SKU: &armnetwork.NatGatewaySKU{ + Name: to.Ptr(armnetwork.NatGatewaySKUNameStandardV2), + }, + Location: to.Ptr(dtConfig.Location), + Properties: &armnetwork.NatGatewayPropertiesFormat{ + ServiceGateway: &armnetwork.SubResource{ + ID: to.Ptr(dtConfig.ServiceGatewayID), + }, + PublicIPAddresses: []*armnetwork.SubResource{ + { + ID: to.Ptr(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + dtConfig.SubscriptionID, dtConfig.ResourceGroup, pipName)), + }, + }, + }, + } + + // Build ServicesDTO for ServiceGateway registration + servicesDTO = MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + SyncServicesReturnType{ + Additions: utilsets.NewString(serviceUID), + Removals: nil, + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) + + return pip, natGateway, servicesDTO +} + +// ExtractInboundConfigFromService creates InboundConfig from a Kubernetes Service +// This is shared between initialization and the provider layer +func ExtractInboundConfigFromService(service *v1.Service) *InboundConfig { + if service == nil || len(service.Spec.Ports) == 0 { + return nil + } + + config := &InboundConfig{ + FrontendPorts: make([]PortMapping, 0, len(service.Spec.Ports)), + BackendPorts: make([]PortMapping, 0, len(service.Spec.Ports)), + } + + // Extract port mappings from service + for _, port := range service.Spec.Ports { + protocol := string(port.Protocol) + if protocol == "" { + protocol = "TCP" + } + + // Frontend port (service port) + config.FrontendPorts = append(config.FrontendPorts, PortMapping{ + Port: port.Port, + Protocol: protocol, + }) + + // Backend port (target port) + // For SLB with PodIP backend, we use TargetPort + // If TargetPort is not specified, default to Port + backendPort := port.Port + switch port.TargetPort.Type { + case intstr.Int: + if port.TargetPort.IntVal > 0 { + backendPort = port.TargetPort.IntVal + } + case intstr.String: + klog.Warningf("ExtractInboundConfigFromService: named targetPort %q is not supported for service %s/%s; falling back to port %d", + port.TargetPort.StrVal, service.Namespace, service.Name, port.Port) + backendPort = port.Port + } + + config.BackendPorts = append(config.BackendPorts, PortMapping{ + Port: backendPort, + Protocol: protocol, + }) + } + + return config +} + +// buildInboundResourceNames returns the resource names for an inbound service +func buildInboundResourceNames(serviceUID string) (lbName string, pipName string, backendPoolName string) { + return serviceUID, fmt.Sprintf("%s-pip", serviceUID), serviceUID +} + +// buildOutboundResourceNames returns the resource names for an outbound service +func buildOutboundResourceNames(serviceUID string) (natGatewayName string, pipName string) { + return serviceUID, fmt.Sprintf("%s-pip", serviceUID) +} + +// buildServiceGatewayRemovalDTO creates a ServicesDTO for removing a service from ServiceGateway +func buildServiceGatewayRemovalDTO(serviceUID string, isInbound bool, dtConfig Config) ServicesDataDTO { + if isInbound { + return MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: nil, + Removals: utilsets.NewString(serviceUID), + }, + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) + } + return MapLoadBalancerAndNATGatewayUpdatesToServicesDataDTO( + SyncServicesReturnType{ + Additions: nil, + Removals: nil, + }, + SyncServicesReturnType{ + Additions: nil, + Removals: utilsets.NewString(serviceUID), + }, + dtConfig.SubscriptionID, + dtConfig.ResourceGroup, + ) +} diff --git a/pkg/provider/difftracker/resource_helpers_test.go b/pkg/provider/difftracker/resource_helpers_test.go new file mode 100644 index 0000000000..2170e40a96 --- /dev/null +++ b/pkg/provider/difftracker/resource_helpers_test.go @@ -0,0 +1,573 @@ +package difftracker + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" +) + +func TestExtractInboundConfigFromService_NilService(t *testing.T) { + config := ExtractInboundConfigFromService(nil) + assert.Nil(t, config) +} + +func TestExtractInboundConfigFromService_EmptyPorts(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{}, + }, + } + config := ExtractInboundConfigFromService(service) + assert.Nil(t, config) +} + +func TestExtractInboundConfigFromService_SingleTCPPort(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromInt(8080), + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + assert.Len(t, config.FrontendPorts, 1) + assert.Len(t, config.BackendPorts, 1) + + // Check frontend port + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, "TCP", config.FrontendPorts[0].Protocol) + + // Check backend port (should be TargetPort) + assert.Equal(t, int32(8080), config.BackendPorts[0].Port) + assert.Equal(t, "TCP", config.BackendPorts[0].Protocol) +} + +func TestExtractInboundConfigFromService_MultiplePortsWithUDP(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromInt(8080), + }, + { + Name: "dns", + Protocol: v1.ProtocolUDP, + Port: 53, + TargetPort: intstr.FromInt(5353), + }, + { + Name: "https", + Protocol: v1.ProtocolTCP, + Port: 443, + TargetPort: intstr.FromInt(8443), + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + assert.Len(t, config.FrontendPorts, 3) + assert.Len(t, config.BackendPorts, 3) + + // Verify HTTP + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, "TCP", config.FrontendPorts[0].Protocol) + assert.Equal(t, int32(8080), config.BackendPorts[0].Port) + + // Verify DNS (UDP) + assert.Equal(t, int32(53), config.FrontendPorts[1].Port) + assert.Equal(t, "UDP", config.FrontendPorts[1].Protocol) + assert.Equal(t, int32(5353), config.BackendPorts[1].Port) + + // Verify HTTPS + assert.Equal(t, int32(443), config.FrontendPorts[2].Port) + assert.Equal(t, "TCP", config.FrontendPorts[2].Protocol) + assert.Equal(t, int32(8443), config.BackendPorts[2].Port) +} + +func TestExtractInboundConfigFromService_NoTargetPort(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + // TargetPort not specified + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + + // When TargetPort is not specified, backend port should equal frontend port + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, int32(80), config.BackendPorts[0].Port) +} + +func TestExtractInboundConfigFromService_NamedTargetPort(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 80, + TargetPort: intstr.FromString("http-port"), // Named port + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + assert.Len(t, config.FrontendPorts, 1) + assert.Len(t, config.BackendPorts, 1) + + assert.Equal(t, int32(80), config.FrontendPorts[0].Port) + assert.Equal(t, config.FrontendPorts[0].Port, config.BackendPorts[0].Port) +} + +func TestExtractInboundConfigFromService_EmptyProtocol(t *testing.T) { + service := &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Name: "http", + Port: 80, + // Protocol not specified + }, + }, + }, + } + + config := ExtractInboundConfigFromService(service) + assert.NotNil(t, config) + + // Default protocol should be TCP + assert.Equal(t, "TCP", config.FrontendPorts[0].Protocol) + assert.Equal(t, "TCP", config.BackendPorts[0].Protocol) +} + +func TestBuildInboundServiceResources_WithConfig(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{ + {Port: 80, Protocol: "TCP"}, + {Port: 443, Protocol: "TCP"}, + }, + BackendPorts: []PortMapping{ + {Port: 8080, Protocol: "TCP"}, + {Port: 8443, Protocol: "TCP"}, + }, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + pip, lb, servicesDTO, err := buildInboundServiceResources("service-uid-123", config, dtConfig) + assert.NoError(t, err) + + // Verify PIP + assert.NotNil(t, pip.Name) + assert.Equal(t, "service-uid-123-pip", *pip.Name) + assert.Equal(t, armnetwork.PublicIPAddressSKUNameStandardV2, *pip.SKU.Name) + assert.Equal(t, "eastus", *pip.Location) + + // Verify LoadBalancer + assert.NotNil(t, lb.Name) + assert.Equal(t, "service-uid-123", *lb.Name) + assert.Equal(t, "Service", string(*lb.SKU.Name)) + assert.Equal(t, "eastus", *lb.Location) + + // Verify backend pool + assert.Len(t, lb.Properties.BackendAddressPools, 1) + assert.Equal(t, "service-uid-123", *lb.Properties.BackendAddressPools[0].Name) + + // Verify LB rules + assert.Len(t, lb.Properties.LoadBalancingRules, 2) + + // Rule 1: port 80 -> 8080 + rule1 := lb.Properties.LoadBalancingRules[0] + assert.Equal(t, "rule-tcp-80", *rule1.Name) + assert.Equal(t, armnetwork.TransportProtocolTCP, *rule1.Properties.Protocol) + assert.Equal(t, int32(80), *rule1.Properties.FrontendPort) + assert.Equal(t, int32(8080), *rule1.Properties.BackendPort) + assert.False(t, *rule1.Properties.EnableFloatingIP) + + // Rule 2: port 443 -> 8443 + rule2 := lb.Properties.LoadBalancingRules[1] + assert.Equal(t, "rule-tcp-443", *rule2.Name) + assert.Equal(t, int32(443), *rule2.Properties.FrontendPort) + assert.Equal(t, int32(8443), *rule2.Properties.BackendPort) + + // Verify ServicesDTO + assert.Len(t, servicesDTO.Services, 1) + assert.Contains(t, servicesDTO.Services[0].Service, "service-uid-123") + assert.Equal(t, Inbound, servicesDTO.Services[0].ServiceType) +} + +func TestBuildInboundServiceResources_NilConfig(t *testing.T) { + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + pip, lb, servicesDTO, err := buildInboundServiceResources("service-uid-123", nil, dtConfig) + assert.NoError(t, err) + + // Should still create LB, just without rules + assert.NotNil(t, lb.Name) + assert.Equal(t, "service-uid-123", *lb.Name) + + // Should have backend pool but no rules + assert.Len(t, lb.Properties.BackendAddressPools, 1) + assert.Empty(t, lb.Properties.LoadBalancingRules) + + // PIP should still be created + assert.NotNil(t, pip.Name) + + // ServicesDTO should still be valid + assert.Len(t, servicesDTO.Services, 1) + assert.Equal(t, Inbound, servicesDTO.Services[0].ServiceType) +} + +func TestBuildInboundServiceResources_UDPProtocol(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{ + {Port: 53, Protocol: "UDP"}, + }, + BackendPorts: []PortMapping{ + {Port: 5353, Protocol: "UDP"}, + }, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "westus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + _, lb, _, err := buildInboundServiceResources("service-uid-udp", config, dtConfig) + assert.NoError(t, err) + + // Verify UDP rule + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + rule := lb.Properties.LoadBalancingRules[0] + assert.Equal(t, "rule-udp-53", *rule.Name) + assert.Equal(t, armnetwork.TransportProtocolUDP, *rule.Properties.Protocol) + assert.Equal(t, int32(53), *rule.Properties.FrontendPort) + assert.Nil(t, rule.Properties.EnableTCPReset) + assert.Equal(t, int32(5353), *rule.Properties.BackendPort) +} + +func TestBuildOutboundServiceResources_Basic(t *testing.T) { + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "centralus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + pip, natGw, servicesDTO := buildOutboundServiceResources("egress-uid-456", nil, dtConfig) + + // Verify PIP + assert.NotNil(t, pip.Name) + assert.Equal(t, "egress-uid-456-pip", *pip.Name) + assert.Equal(t, armnetwork.PublicIPAddressSKUNameStandardV2, *pip.SKU.Name) + assert.Equal(t, "centralus", *pip.Location) + + // Verify NAT Gateway + assert.NotNil(t, natGw.Name) + assert.Equal(t, "egress-uid-456", *natGw.Name) + assert.Equal(t, armnetwork.NatGatewaySKUNameStandardV2, *natGw.SKU.Name) + assert.Equal(t, "centralus", *natGw.Location) + + // Verify NAT Gateway has ServiceGateway reference + assert.NotNil(t, natGw.Properties.ServiceGateway) + assert.Equal(t, dtConfig.ServiceGatewayID, *natGw.Properties.ServiceGateway.ID) + + // Verify NAT Gateway has PIP reference + assert.Len(t, natGw.Properties.PublicIPAddresses, 1) + assert.Contains(t, *natGw.Properties.PublicIPAddresses[0].ID, "egress-uid-456-pip") + + // Verify ServicesDTO + assert.Len(t, servicesDTO.Services, 1) + assert.Contains(t, servicesDTO.Services[0].Service, "egress-uid-456") + assert.Equal(t, Outbound, servicesDTO.Services[0].ServiceType) +} + +func TestBuildInboundServiceResources_BackendPoolNaming(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + _, lb, _, err := buildInboundServiceResources("my-service-uid", config, dtConfig) + assert.NoError(t, err) + + // Backend pool name must match serviceUID for SLB mode + assert.Len(t, lb.Properties.BackendAddressPools, 1) + backendPool := lb.Properties.BackendAddressPools[0] + assert.Equal(t, "my-service-uid", *backendPool.Name) + + // LB rule should reference the correct backend pool + rule := lb.Properties.LoadBalancingRules[0] + assert.Contains(t, *rule.Properties.BackendAddressPool.ID, "my-service-uid") +} + +func TestBuildInboundServiceResources_NoProbesForPodIPBackend(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + + dtConfig := Config{ + SubscriptionID: "test-sub", + ResourceGroup: "test-rg", + Location: "eastus", + ServiceGatewayResourceName: "test-sgw", + ServiceGatewayID: "/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/serviceGateways/test-sgw", + } + + _, lb, _, err := buildInboundServiceResources("service-uid", config, dtConfig) + assert.NoError(t, err) + + // For PodIP backend pools, no health probes should be created + assert.Empty(t, lb.Properties.Probes) + + // LB rules should have no probe reference + rule := lb.Properties.LoadBalancingRules[0] + assert.Nil(t, rule.Properties.Probe) +} + +func TestBuildInboundServiceResources_ResourceIDs(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + + dtConfig := Config{ + SubscriptionID: "sub-123", + ResourceGroup: "rg-456", + Location: "eastus", + ServiceGatewayResourceName: "sgw-789", + ServiceGatewayID: "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/serviceGateways/sgw-789", + } + + pip, lb, _, err := buildInboundServiceResources("svc-abc", config, dtConfig) + assert.NoError(t, err) + + // Verify PIP ID format + expectedPIPID := "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/publicIPAddresses/svc-abc-pip" + assert.Equal(t, expectedPIPID, *pip.ID) + + // Verify LB references PIP correctly + frontendConfig := lb.Properties.FrontendIPConfigurations[0] + assert.Equal(t, expectedPIPID, *frontendConfig.Properties.PublicIPAddress.ID) + + // Verify backend pool ID reference in rule + rule := lb.Properties.LoadBalancingRules[0] + expectedBackendPoolID := "/subscriptions/sub-123/resourceGroups/rg-456/providers/Microsoft.Network/loadBalancers/svc-abc/backendAddressPools/svc-abc" + assert.Equal(t, expectedBackendPoolID, *rule.Properties.BackendAddressPool.ID) +} + +func TestBuildInboundServiceResources_LowercaseUDP(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 53, Protocol: "udp"}}, + BackendPorts: []PortMapping{{Port: 5353, Protocol: "udp"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + assert.Equal(t, armnetwork.TransportProtocolUDP, *lb.Properties.LoadBalancingRules[0].Properties.Protocol) +} + +func TestBuildInboundServiceResources_UnsupportedProtocolErrors(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 53, Protocol: "SCTP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported protocol") +} + +func TestBuildInboundServiceResources_PortOutOfRangeErrors(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 65535, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "out of range") +} + +func TestBuildInboundServiceResources_TCPHasResetEnabled(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + if assert.NotNil(t, lb.Properties.LoadBalancingRules[0].Properties.EnableTCPReset) { + assert.True(t, *lb.Properties.LoadBalancingRules[0].Properties.EnableTCPReset) + } +} + +func TestBuildInboundServiceResources_BackendPortMaxIsValid(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 65535, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + if assert.Len(t, lb.Properties.LoadBalancingRules, 1) { + assert.Equal(t, int32(65535), *lb.Properties.LoadBalancingRules[0].Properties.BackendPort) + } +} + +func TestBuildInboundServiceResources_BackendPortOutOfRangeErrors(t *testing.T) { + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 65536, Protocol: "TCP"}}, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "backend port") +} + +func TestBuildInboundServiceResources_AppliesIdleTimeout(t *testing.T) { + idle := int32(30) + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + BackendPorts: []PortMapping{{Port: 8080, Protocol: "TCP"}}, + IdleTimeoutMinutes: &idle, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, lb, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.NoError(t, err) + assert.Len(t, lb.Properties.LoadBalancingRules, 1) + assert.Equal(t, int32(30), *lb.Properties.LoadBalancingRules[0].Properties.IdleTimeoutInMinutes) +} + +func TestBuildInboundServiceResources_IdleTimeoutOutOfRangeErrors(t *testing.T) { + idle := int32(99) + config := &InboundConfig{ + FrontendPorts: []PortMapping{{Port: 80, Protocol: "TCP"}}, + IdleTimeoutMinutes: &idle, + } + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg", Location: "westus"} + + _, _, _, err := buildInboundServiceResources("svc", config, dtConfig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "idle timeout") +} + +func TestBuildInboundResourceNames(t *testing.T) { + lbName, pipName, backendPoolName := buildInboundResourceNames("uid") + assert.Equal(t, "uid", lbName) + assert.Equal(t, "uid-pip", pipName) + assert.Equal(t, "uid", backendPoolName) +} + +func TestBuildOutboundResourceNames(t *testing.T) { + natGatewayName, pipName := buildOutboundResourceNames("uid") + assert.Equal(t, "uid", natGatewayName) + assert.Equal(t, "uid-pip", pipName) +} + +func TestBuildServiceGatewayRemovalDTO(t *testing.T) { + dtConfig := Config{SubscriptionID: "sub", ResourceGroup: "rg"} + + t.Run("inbound removal", func(t *testing.T) { + dto := buildServiceGatewayRemovalDTO("uid", true, dtConfig) + assert.Equal(t, PartialUpdate, dto.Action) + if assert.Len(t, dto.Services, 1) { + assert.Equal(t, "uid", dto.Services[0].Service) + assert.Equal(t, Inbound, dto.Services[0].ServiceType) + assert.True(t, dto.Services[0].IsDelete) + } + }) + + t.Run("outbound removal", func(t *testing.T) { + dto := buildServiceGatewayRemovalDTO("uid", false, dtConfig) + assert.Equal(t, PartialUpdate, dto.Action) + if assert.Len(t, dto.Services, 1) { + assert.Equal(t, "uid", dto.Services[0].Service) + assert.Equal(t, Outbound, dto.Services[0].ServiceType) + assert.True(t, dto.Services[0].IsDelete) + } + }) +} diff --git a/pkg/provider/difftracker/sync_operations.go b/pkg/provider/difftracker/sync_operations.go new file mode 100644 index 0000000000..0923400cb2 --- /dev/null +++ b/pkg/provider/difftracker/sync_operations.go @@ -0,0 +1,214 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "k8s.io/klog/v2" + + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// GetServicesToSync handles the synchronization of services between K8s and NRP +func GetServicesToSync(k8sServices, nrpServices *utilsets.IgnoreCaseSet) SyncServicesReturnType { + klog.V(2).Infof("GetServicesToSync: K8s services (%d): %v", k8sServices.Len(), k8sServices.UnsortedList()) + klog.V(2).Infof("GetServicesToSync: NRP services (%d): %v", nrpServices.Len(), nrpServices.UnsortedList()) + + syncServices := SyncServicesReturnType{ + // Additions are in K8s but not yet in NRP; removals are in NRP but no + // longer in K8s. + Additions: k8sServices.Difference(nrpServices), + Removals: nrpServices.Difference(k8sServices), + } + klog.V(4).Infof("GetServicesToSync: additions=%v, removals=%v", + syncServices.Additions.UnsortedList(), syncServices.Removals.UnsortedList()) + + klog.V(2).Infof("GetServicesToSync: Result - Additions: %d, Removals: %d", syncServices.Additions.Len(), syncServices.Removals.Len()) + return syncServices +} + +func (dt *DiffTracker) GetSyncLoadBalancerServices() SyncServicesReturnType { + dt.mu.Lock() + defer dt.mu.Unlock() + + return dt.getSyncLoadBalancerServicesLocked() +} + +// getSyncLoadBalancerServicesLocked is the lock-free body. Callers must hold dt.mu. +func (dt *DiffTracker) getSyncLoadBalancerServicesLocked() SyncServicesReturnType { + return GetServicesToSync(dt.K8sResources.Services, dt.NRPResources.LoadBalancers) +} + +func (dt *DiffTracker) GetSyncNRPNATGateways() SyncServicesReturnType { + dt.mu.Lock() + defer dt.mu.Unlock() + + return dt.getSyncNRPNATGatewaysLocked() +} + +// getSyncNRPNATGatewaysLocked is the lock-free body. Callers must hold dt.mu. +func (dt *DiffTracker) getSyncNRPNATGatewaysLocked() SyncServicesReturnType { + return GetServicesToSync(dt.K8sResources.Egresses, dt.NRPResources.NATGateways) +} + +func (dt *DiffTracker) GetSyncLocationsAddresses() LocationData { + dt.mu.Lock() + defer dt.mu.Unlock() + + return dt.getSyncLocationsAddressesLocked() +} + +// getSyncLocationsAddressesLocked is the lock-free body. Callers must hold dt.mu. +func (dt *DiffTracker) getSyncLocationsAddressesLocked() LocationData { + result := LocationData{ + Action: PartialUpdate, + Locations: make(map[string]Location), + } + + // Iterate over all nodes in the K8s state + for nodeIP, node := range dt.K8sResources.Nodes { + nrpLocation, locationExists := dt.NRPResources.Locations[nodeIP] + location := initializeLocation(locationExists) + locationUpdated := false + + for address, pod := range node.Pods { + // Filter services: only include services that exist in NRP + serviceRef := dt.createServiceRefFiltered(pod) + + // Check if address exists in NRP and if service list changed + nrpAddressData, nrpAddrExists := nrpLocation.Addresses[address] + + // Skip this address if: + // 1. No ready services AND address doesn't exist in NRP (nothing to sync) + // 2. ServiceRef matches what's already in NRP (no change) + if serviceRef.Len() == 0 && !nrpAddrExists { + continue + } + + if nrpAddrExists && serviceRef.Equals(nrpAddressData.Services) { + continue + } + + // ServiceRef changed (or address is new) - need to sync + addressData := Address{ServiceRef: serviceRef} + location.Addresses[address] = addressData + locationUpdated = true + } + if locationUpdated { + result.Locations[nodeIP] = location + } + } + + // Iterate over all locations in the NRP state + for location, nrpLocation := range dt.NRPResources.Locations { + node, exists := dt.K8sResources.Nodes[location] + if !exists { + result.Locations[location] = Location{ + AddressUpdateAction: PartialUpdate, + Addresses: make(map[string]Address), + } + } else { + locationData := findLocationData(result, location) + if locationData == nil { + locationData = &Location{ + AddressUpdateAction: PartialUpdate, + Addresses: make(map[string]Address), + } + } + for address := range nrpLocation.Addresses { + if _, exists := node.Pods[address]; !exists { + addressData := Address{ServiceRef: utilsets.NewString()} + locationData.Addresses[address] = addressData + result.Locations[location] = *locationData + } + } + } + } + return result +} + +// Helper function to initialize Location based on existence in NRP +func initializeLocation(exists bool) Location { + if !exists { + return Location{ + AddressUpdateAction: FullUpdate, + Addresses: make(map[string]Address), + } + } + return Location{ + AddressUpdateAction: PartialUpdate, + Addresses: make(map[string]Address), + } +} + +// createServiceRefFiltered creates ServiceRef but only includes services that exist in NRP. +// Must be called with dt.mu held. +func (dt *DiffTracker) createServiceRefFiltered(pod Pod) *utilsets.IgnoreCaseSet { + serviceRef := utilsets.NewString() + + // Check inbound services (LoadBalancers) + for _, serviceUID := range pod.InboundIdentities.UnsortedList() { + if dt.isServiceReadyToSync(serviceUID, true) { + serviceRef.Insert(serviceUID) + } + } + + // Check outbound service (NAT Gateway) + if pod.PublicOutboundIdentity != "" { + if dt.isServiceReadyToSync(pod.PublicOutboundIdentity, false) { + serviceRef.Insert(pod.PublicOutboundIdentity) + } + } + + return serviceRef +} + +// isServiceReadyToSync reports whether a service is ready to be synced to the +// Service Gateway, i.e. its NRP resource exists. Must be called with dt.mu held. +func (dt *DiffTracker) isServiceReadyToSync(serviceUID string, isInbound bool) bool { + if isInbound { + return dt.NRPResources.LoadBalancers.Has(serviceUID) + } + return dt.NRPResources.NATGateways.Has(serviceUID) +} + +// findLocationData returns a pointer to the Location stored under the given key +// in data, or nil if no such location exists. +func findLocationData(data LocationData, location string) *Location { + if loc, ok := data.Locations[location]; ok { + return &loc + } + return nil +} + +func (dt *DiffTracker) GetSyncOperations() *SyncDiffTrackerReturnType { + // Take the lock once so DeepEqual and all three sync computations observe a + // single consistent snapshot of the state (avoids a data race with mutating + // methods and inconsistency between the individual GetSync* results). + dt.mu.Lock() + defer dt.mu.Unlock() + + if dt.deepEqualLocked() { + return &SyncDiffTrackerReturnType{SyncStatus: AlreadyInSync} + } + + return &SyncDiffTrackerReturnType{ + SyncStatus: Success, + LoadBalancerUpdates: dt.getSyncLoadBalancerServicesLocked(), + NATGatewayUpdates: dt.getSyncNRPNATGatewaysLocked(), + LocationData: dt.getSyncLocationsAddressesLocked(), + } +} diff --git a/pkg/provider/difftracker/types.go b/pkg/provider/difftracker/types.go new file mode 100644 index 0000000000..3ac5611801 --- /dev/null +++ b/pkg/provider/difftracker/types.go @@ -0,0 +1,359 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "sync" + + "k8s.io/client-go/kubernetes" + + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// ================================================================================================ +// ENUMS +// ================================================================================================ +type Operation int + +const ( + UnknownOperation Operation = iota + Add + Remove + Update +) + +type UpdateAction int + +const ( + UnknownUpdateAction UpdateAction = iota + PartialUpdate + FullUpdate +) + +type SyncStatus int + +const ( + UnknownSyncStatus SyncStatus = iota + AlreadyInSync + Success +) + +// -------------------------------------------------------------------------------- +// DiffTracker keeps track of the state of the K8s cluster and NRP +// -------------------------------------------------------------------------------- +// NRPAddress holds the NRP-side state for a single pod address (pod IP). +type NRPAddress struct { + // Services holds the SGW service identities (LBs for inbound, NATGWs for + // outbound) currently associated with this address on the NRP side. + // These are SGW service identities, not Kubernetes Service names. + Services *utilsets.IgnoreCaseSet +} + +// NRPLocation holds the NRP-side state for a single node/VM and groups the +// pod addresses running on it. +type NRPLocation struct { + // Addresses is keyed by pod IP. Each pod IP is added to the ServiceGateway + // as an address under this location once the pod is created. + Addresses map[string]NRPAddress +} + +type NRPState struct { + // LoadBalancers holds the UIDs of inbound services that have a LoadBalancer + // registered on the NRP side. These are SGW service identities, not Azure + // LoadBalancer resource names. + LoadBalancers *utilsets.IgnoreCaseSet + // NATGateways holds the UIDs of outbound/egress services that have a NAT + // Gateway registered on the NRP side (SGW service identities, not Azure + // resource names). + NATGateways *utilsets.IgnoreCaseSet + // Locations is keyed by node/VM IP (e.g. "10.0.0.1"). "Location" here is + // an SGW concept identifying a node, not an Azure region (e.g. "eastus2"). + Locations map[string]NRPLocation +} + +type Pod struct { + // InboundIdentities holds the UIDs of the inbound ServiceGateway services + // (LoadBalancers) this pod backs. A pod may back several, hence a set. + InboundIdentities *utilsets.IgnoreCaseSet + // PublicOutboundIdentity is the UID of the single outbound/egress ServiceGateway + // service (NAT Gateway) this pod uses for egress; empty if the pod has no egress. + PublicOutboundIdentity string +} + +// newPod returns a Pod with its InboundIdentities set initialized. +func newPod() Pod { + return Pod{InboundIdentities: utilsets.NewString()} +} + +type Node struct { + Pods map[string]Pod +} + +// newNode returns a Node with its Pods map initialized. +func newNode() Node { + return Node{Pods: make(map[string]Pod)} +} + +type K8sState struct { + Services *utilsets.IgnoreCaseSet + Egresses *utilsets.IgnoreCaseSet + Nodes map[string]Node +} + +// DiffTracker is the main struct that contains the state of the K8s and NRP services +type DiffTracker struct { + mu sync.Mutex // Protects concurrent access to DiffTracker + + K8sResources K8sState + NRPResources NRPState + + // outboundIdentityPodRefCount counts how many pods reference each outbound + // (egress) identity, keyed by lowercased PublicOutboundIdentity. It lets the + // engine delete a NAT Gateway when its last egress pod is removed. Inbound + // (LoadBalancer) services are not tracked here; their lifecycle follows the + // Kubernetes Service object. + outboundIdentityPodRefCount sync.Map + + // Configuration and clients + config Config + networkClientFactory azclient.ClientFactory + kubeClient kubernetes.Interface +} + +// -------------------------------------------------------------------------------- +// Types that are used while events are received and processed in order to update K8s state +// -------------------------------------------------------------------------------- + +// UpdateK8sResource represents input for K8s service or egress updates +type UpdateK8sResource struct { + Operation Operation + ID string +} + +// UpdateK8sEndpointsInputType represents input for K8s endpoints updates +type UpdateK8sEndpointsInputType struct { + InboundIdentity string + OldAddresses map[string]string // address -> location + NewAddresses map[string]string // address -> location +} + +// UpdatePodInputType represents input for K8s pod updates (egress assignments) +type UpdatePodInputType struct { + PodOperation Operation + PublicOutboundIdentity string + Location string + Address string +} + +// -------------------------------------------------------------------------------- +// Types that are used while syncing NRP state to K8s state +// -------------------------------------------------------------------------------- +type Address struct { + ServiceRef *utilsets.IgnoreCaseSet +} + +// Location uses a map for Addresses +type Location struct { + AddressUpdateAction UpdateAction + Addresses map[string]Address // key is the pod IP +} + +type LocationData struct { + Action UpdateAction + Locations map[string]Location // key is the node IP +} + +type SyncServicesReturnType struct { + Additions *utilsets.IgnoreCaseSet + Removals *utilsets.IgnoreCaseSet +} + +type SyncDiffTrackerReturnType struct { + SyncStatus SyncStatus + LoadBalancerUpdates SyncServicesReturnType + NATGatewayUpdates SyncServicesReturnType + LocationData LocationData +} + +// ================================================================================================ +// CP2: Service configuration types + ServiceGateway DTOs (deferred from CP1) +// ================================================================================================ + +// PortMapping represents a port mapping configuration +type PortMapping struct { + Port int32 + Protocol string // TCP or UDP +} + +// HealthProbeConfig represents health probe configuration +type HealthProbeConfig struct { + Protocol string // TCP, HTTP, or HTTPS + Port int32 + IntervalInSeconds int32 + NumberOfProbes int32 + RequestPath *string // For HTTP/HTTPS probes +} + +// InboundConfig contains Load Balancer configuration for inbound services +type InboundConfig struct { + FrontendPorts []PortMapping // nullable for future use + BackendPorts []PortMapping // nullable for future use + Protocol *string // TCP/UDP, nullable + IdleTimeoutMinutes *int32 // nullable + SessionPersistence *string // nullable + HealthProbe *HealthProbeConfig // nullable +} + +// OutboundConfig contains NAT Gateway configuration for outbound services +type OutboundConfig struct { + // Placeholder for future NAT Gateway options +} + +// Equals returns true if two InboundConfigs describe the same desired LB shape. +// Used by UpdateService to short-circuit no-op reconciles. +// Comparison is order-sensitive for FrontendPorts/BackendPorts because the +// position of a port determines its pairing with a backend port in +// buildInboundServiceResources. +func (c *InboundConfig) Equals(other *InboundConfig) bool { + if c == nil && other == nil { + return true + } + if c == nil || other == nil { + return false + } + if !portMappingsEqual(c.FrontendPorts, other.FrontendPorts) { + return false + } + if !portMappingsEqual(c.BackendPorts, other.BackendPorts) { + return false + } + if !strPtrEqual(c.Protocol, other.Protocol) { + return false + } + if !int32PtrEqual(c.IdleTimeoutMinutes, other.IdleTimeoutMinutes) { + return false + } + if !strPtrEqual(c.SessionPersistence, other.SessionPersistence) { + return false + } + if !healthProbeEqual(c.HealthProbe, other.HealthProbe) { + return false + } + return true +} + +func portMappingsEqual(a, b []PortMapping) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func strPtrEqual(a, b *string) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +func int32PtrEqual(a, b *int32) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +func healthProbeEqual(a, b *HealthProbeConfig) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + if a.Protocol != b.Protocol || a.Port != b.Port || + a.IntervalInSeconds != b.IntervalInSeconds || a.NumberOfProbes != b.NumberOfProbes { + return false + } + return strPtrEqual(a.RequestPath, b.RequestPath) +} + +// -------------------------------------------------------------------------------- +// Data Transfer Objects (DTOs) for LocationData (following the ServiceGateway API documentation) +// -------------------------------------------------------------------------------- + +// AddressDTO represents the DTO for Address +type AddressDTO struct { + Address string `json:"Address"` + ServiceNames *utilsets.IgnoreCaseSet `json:"ServiceNames"` +} + +// LocationDTO represents the DTO for Location +type LocationDTO struct { + Location string `json:"Location"` + AddressUpdateAction UpdateAction `json:"AddressUpdateAction"` + Addresses []AddressDTO `json:"Addresses"` +} + +// LocationsDataDTO represents the DTO for LocationData +type LocationsDataDTO struct { + Action UpdateAction `json:"Action"` + Locations []LocationDTO `json:"Locations"` +} + +// ================================================================================================ +// Data Transfer Objects (DTOs) for ServiceData (following the ServiceGateway API documentation) +// ================================================================================================ + +type ServiceType string + +const ( + Inbound ServiceType = "Inbound" + Outbound ServiceType = "Outbound" +) + +type LoadBalancerBackendPoolDTO struct { + Id string `json:"Id"` +} + +type NatGatewayDTO struct { + Id string `json:"Id"` +} + +type ServiceDTO struct { + Service string `json:"Service"` + ServiceType ServiceType `json:"ServiceType"` + LoadBalancerBackendPools []LoadBalancerBackendPoolDTO `json:"LoadBalancerBackendPools"` + PublicNatGateway NatGatewayDTO `json:"PublicNatGateway"` + IsDelete bool +} + +type ServicesDataDTO struct { + Action UpdateAction `json:"Action"` + Services []ServiceDTO `json:"Services"` +} diff --git a/pkg/provider/difftracker/util.go b/pkg/provider/difftracker/util.go new file mode 100644 index 0000000000..4406285e51 --- /dev/null +++ b/pkg/provider/difftracker/util.go @@ -0,0 +1,341 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "k8s.io/klog/v2" + + utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +func (operation Operation) String() string { + switch operation { + case UnknownOperation: + return "UnknownOperation" + case Add: + return "Add" + case Remove: + return "Remove" + case Update: + return "Update" + default: + return fmt.Sprintf("Operation(%d)", int(operation)) + } +} + +func (operation Operation) MarshalJSON() ([]byte, error) { + return json.Marshal(operation.String()) +} + +func (updateAction UpdateAction) String() string { + switch updateAction { + case UnknownUpdateAction: + return "UnknownUpdateAction" + case PartialUpdate: + return "PartialUpdate" + case FullUpdate: + return "FullUpdate" + default: + return fmt.Sprintf("UpdateAction(%d)", int(updateAction)) + } +} + +func (updateAction UpdateAction) MarshalJSON() ([]byte, error) { + return json.Marshal(updateAction.String()) +} + +func (updateAction *UpdateAction) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + switch s { + case "PartialUpdate": + *updateAction = PartialUpdate + case "FullUpdate": + *updateAction = FullUpdate + default: + return fmt.Errorf("unknown UpdateAction: %q", s) + } + + return nil +} + +func (syncStatus SyncStatus) String() string { + switch syncStatus { + case UnknownSyncStatus: + return "UnknownSyncStatus" + case AlreadyInSync: + return "AlreadyInSync" + case Success: + return "Success" + default: + return fmt.Sprintf("SyncStatus(%d)", int(syncStatus)) + } +} + +func (syncStatus SyncStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(syncStatus.String()) +} + +func (node *Node) HasPods() bool { return len(node.Pods) > 0 } + +func (pod *Pod) HasIdentities() bool { + return pod.InboundIdentities.Len() > 0 || pod.PublicOutboundIdentity != "" +} + +// deepEqualLocked compares the K8s and NRP states to check if they are in sync. +// Callers must hold dt.mu. +func (dt *DiffTracker) deepEqualLocked() bool { + klog.V(4).Infof("DeepEqual: Checking equality - K8s Services=%d, NRP LoadBalancers=%d, K8s Egresses=%d, NRP NATGateways=%d, K8s Nodes=%d, NRP Locations=%d", + dt.K8sResources.Services.Len(), dt.NRPResources.LoadBalancers.Len(), + dt.K8sResources.Egresses.Len(), dt.NRPResources.NATGateways.Len(), + len(dt.K8sResources.Nodes), len(dt.NRPResources.Locations)) + + // Compare Services with LoadBalancers and Egresses with NATGateways. + if !dt.K8sResources.Services.Equals(dt.NRPResources.LoadBalancers) { + klog.V(4).Infof("DeepEqual: Services and LoadBalancers mismatch") + return false + } + if !dt.K8sResources.Egresses.Equals(dt.NRPResources.NATGateways) { + klog.V(4).Infof("DeepEqual: Egresses and NATGateways mismatch") + return false + } + + // Compare Nodes with Locations. + if len(dt.K8sResources.Nodes) != len(dt.NRPResources.Locations) { + klog.V(4).Infof("DeepEqual: Nodes and Locations length mismatch") + return false + } + for nodeKey, node := range dt.K8sResources.Nodes { + nrpLocation, exists := dt.NRPResources.Locations[nodeKey] + if !exists { + klog.V(4).Infof("DeepEqual: Node %s not found in Locations", nodeKey) + return false + } + + // Compare Pods with Addresses. + if len(node.Pods) != len(nrpLocation.Addresses) { + klog.V(4).Infof("DeepEqual: Pods and Addresses length mismatch for node %s", nodeKey) + return false + } + for podKey, pod := range node.Pods { + nrpAddress, exists := nrpLocation.Addresses[podKey] + if !exists { + klog.V(4).Infof("DeepEqual: Pod %s not found in Addresses for node %s", podKey, nodeKey) + return false + } + + // Compare [...InboundIdentities, PublicOutboundIdentity] with Services. + combinedIdentities := utilsets.NewString(pod.InboundIdentities.UnsortedList()...) + if pod.PublicOutboundIdentity != "" { + combinedIdentities.Insert(pod.PublicOutboundIdentity) + } + if !combinedIdentities.Equals(nrpAddress.Services) { + klog.V(4).Infof("DeepEqual: Identities and Services mismatch for pod %s in node %s", podKey, nodeKey) + return false + } + } + } + + return true +} + +func (s *SyncServicesReturnType) Equals(other *SyncServicesReturnType) bool { + return s.Additions.Equals(other.Additions) && s.Removals.Equals(other.Removals) +} + +// Equals compares two LocationData objects for equality +func (ld *LocationData) Equals(other *LocationData) bool { + if ld.Action != other.Action { + return false + } + + if len(ld.Locations) != len(other.Locations) { + return false + } + + for locName, location := range ld.Locations { + otherLocation, exists := other.Locations[locName] + if !exists { + return false + } + + if location.AddressUpdateAction != otherLocation.AddressUpdateAction { + return false + } + + if len(location.Addresses) != len(otherLocation.Addresses) { + return false + } + + for addrName, address := range location.Addresses { + otherAddress, exists := otherLocation.Addresses[addrName] + if !exists { + return false + } + + if !address.ServiceRef.Equals(otherAddress.ServiceRef) { + return false + } + } + } + + return true +} + +// Equals compares two SyncDiffTrackerReturnType objects for equality +func (s *SyncDiffTrackerReturnType) Equals(other *SyncDiffTrackerReturnType) bool { + if s.SyncStatus != other.SyncStatus { + return false + } + + if !s.LoadBalancerUpdates.Additions.Equals(other.LoadBalancerUpdates.Additions) { + return false + } + + if !s.LoadBalancerUpdates.Removals.Equals(other.LoadBalancerUpdates.Removals) { + return false + } + + if !s.NATGatewayUpdates.Additions.Equals(other.NATGatewayUpdates.Additions) { + return false + } + + if !s.NATGatewayUpdates.Removals.Equals(other.NATGatewayUpdates.Removals) { + return false + } + + if !s.LocationData.Equals(&other.LocationData) { + return false + } + + return true +} + +// Equals compares two DiffTracker objects for equality +func (dt *DiffTracker) Equals(other *DiffTracker) bool { + // Lock both trackers in a consistent order (by pointer address) so that + // concurrent dt.Equals(other) and other.Equals(dt) calls can't deadlock. + // If both refer to the same object, lock only once to avoid self-deadlock. + first, second := dt, other + if reflect.ValueOf(first).Pointer() > reflect.ValueOf(second).Pointer() { + first, second = second, first + } + first.mu.Lock() + defer first.mu.Unlock() + if second != first { + second.mu.Lock() + defer second.mu.Unlock() + } + + if !dt.K8sResources.Services.Equals(other.K8sResources.Services) { + return false + } + + if !dt.K8sResources.Egresses.Equals(other.K8sResources.Egresses) { + return false + } + + if len(dt.K8sResources.Nodes) != len(other.K8sResources.Nodes) { + return false + } + + for nodeKey, node := range dt.K8sResources.Nodes { + otherNode, exists := other.K8sResources.Nodes[nodeKey] + if !exists { + return false + } + + if len(node.Pods) != len(otherNode.Pods) { + return false + } + + for podKey, pod := range node.Pods { + otherPod, exists := otherNode.Pods[podKey] + if !exists { + return false + } + + if !pod.InboundIdentities.Equals(otherPod.InboundIdentities) { + return false + } + + if !strings.EqualFold(pod.PublicOutboundIdentity, otherPod.PublicOutboundIdentity) { + return false + } + } + } + + // Compare NRP state + if !dt.NRPResources.LoadBalancers.Equals(other.NRPResources.LoadBalancers) { + return false + } + + if !dt.NRPResources.NATGateways.Equals(other.NRPResources.NATGateways) { + return false + } + + if len(dt.NRPResources.Locations) != len(other.NRPResources.Locations) { + return false + } + + for location, nrpLocation := range dt.NRPResources.Locations { + otherNrpLocation, exists := other.NRPResources.Locations[location] + if !exists { + return false + } + + if len(nrpLocation.Addresses) != len(otherNrpLocation.Addresses) { + return false + } + + for address, nrpAddress := range nrpLocation.Addresses { + otherNrpAddress, exists := otherNrpLocation.Addresses[address] + if !exists { + return false + } + + if !nrpAddress.Services.Equals(otherNrpAddress.Services) { + return false + } + } + } + + return true +} + +// addressExists reports whether an address with the given key exists in a location. +func addressExists(location NRPLocation, addressKey string) bool { + _, exists := location.Addresses[addressKey] + return exists +} + +// createServiceRefsFromAddress returns a copy of the address's service references. +func createServiceRefsFromAddress(addressValue Address) *utilsets.IgnoreCaseSet { + serviceRefs := utilsets.NewString() + for _, service := range addressValue.ServiceRef.UnsortedList() { + serviceRefs.Insert(service) + } + return serviceRefs +} diff --git a/pkg/provider/difftracker/util_test.go b/pkg/provider/difftracker/util_test.go new file mode 100644 index 0000000000..dbd7d18e99 --- /dev/null +++ b/pkg/provider/difftracker/util_test.go @@ -0,0 +1,967 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package difftracker + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" +) + +// TestOperationStringAndJSON tests Operation String() and MarshalJSON() +func TestOperationStringAndJSON(t *testing.T) { + tests := []struct { + name string + op Operation + expected string + }{ + {"Add operation", Add, "Add"}, + {"Remove operation", Remove, "Remove"}, + {"Update operation", Update, "Update"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test String() + assert.Equal(t, tt.expected, tt.op.String()) + + // Test MarshalJSON() + data, err := tt.op.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, `"`+tt.expected+`"`, string(data)) + }) + } +} + +// TestUpdateActionStringAndJSON tests UpdateAction String(), MarshalJSON(), and UnmarshalJSON() +func TestUpdateActionStringAndJSON(t *testing.T) { + tests := []struct { + name string + action UpdateAction + expected string + }{ + {"PartialUpdate", PartialUpdate, "PartialUpdate"}, + {"FullUpdate", FullUpdate, "FullUpdate"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test String() + assert.Equal(t, tt.expected, tt.action.String()) + + // Test MarshalJSON() + data, err := tt.action.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, `"`+tt.expected+`"`, string(data)) + + // Test UnmarshalJSON() round-trip + var unmarshaled UpdateAction + err = unmarshaled.UnmarshalJSON(data) + assert.NoError(t, err) + assert.Equal(t, tt.action, unmarshaled) + }) + } +} + +// TestEnumStringOutOfRange verifies String() does not panic for out-of-range enum +// values and returns a descriptive fallback instead. +func TestEnumStringOutOfRange(t *testing.T) { + assert.Equal(t, "Operation(99)", Operation(99).String()) + assert.Equal(t, "Operation(-1)", Operation(-1).String()) + assert.Equal(t, "UpdateAction(99)", UpdateAction(99).String()) + assert.Equal(t, "SyncStatus(99)", SyncStatus(99).String()) +} + +// TestUpdateActionUnmarshalJSON_InvalidValue tests error handling +func TestUpdateActionUnmarshalJSON_InvalidValue(t *testing.T) { + var action UpdateAction + err := action.UnmarshalJSON([]byte(`"InvalidAction"`)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown UpdateAction") +} + +// TestUpdateActionUnmarshalJSON_InvalidJSON tests JSON parsing errors +func TestUpdateActionUnmarshalJSON_InvalidJSON(t *testing.T) { + var action UpdateAction + err := action.UnmarshalJSON([]byte(`{invalid json`)) + assert.Error(t, err) +} + +// TestSyncStatusStringAndJSON tests SyncStatus String() and MarshalJSON() +func TestSyncStatusStringAndJSON(t *testing.T) { + tests := []struct { + name string + status SyncStatus + expected string + }{ + {"AlreadyInSync", AlreadyInSync, "AlreadyInSync"}, + {"Success", Success, "Success"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test String() + assert.Equal(t, tt.expected, tt.status.String()) + + // Test MarshalJSON() + data, err := tt.status.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, `"`+tt.expected+`"`, string(data)) + }) + } +} + +// TestNodeHasPods tests Node.HasPods() +func TestNodeHasPods(t *testing.T) { + tests := []struct { + name string + node Node + expected bool + }{ + { + name: "node with pods", + node: Node{Pods: map[string]Pod{"pod1": {}}}, + expected: true, + }, + { + name: "node without pods", + node: Node{Pods: map[string]Pod{}}, + expected: false, + }, + { + name: "node with nil pods map", + node: Node{Pods: nil}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.node.HasPods()) + }) + } +} + +// TestPodHasIdentities tests Pod.HasIdentities() +func TestPodHasIdentities(t *testing.T) { + tests := []struct { + name string + pod Pod + expected bool + }{ + { + name: "pod with inbound identities", + pod: Pod{InboundIdentities: sets.NewString("id1", "id2")}, + expected: true, + }, + { + name: "pod with public outbound identity", + pod: Pod{PublicOutboundIdentity: "outbound-id"}, + expected: true, + }, + { + name: "pod with both identities", + pod: Pod{InboundIdentities: sets.NewString("id1"), PublicOutboundIdentity: "outbound-id"}, + expected: true, + }, + { + name: "pod without identities", + pod: Pod{InboundIdentities: sets.NewString(), PublicOutboundIdentity: ""}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.pod.HasIdentities()) + }) + } +} + +// TestDeepEqual tests DiffTracker.deepEqualLocked() +func TestDeepEqual(t *testing.T) { + tests := []struct { + name string + dt *DiffTracker + expected bool + }{ + { + name: "in sync - matching services and load balancers", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1", "svc2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1", "svc2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "not in sync - service count mismatch", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1", "svc2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + { + name: "not in sync - service name mismatch", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc2"), + NATGateways: sets.NewString(), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + { + name: "in sync - matching egresses and NAT gateways", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString("egress1", "egress2"), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString("egress1", "egress2"), + Locations: map[string]NRPLocation{}, + }, + }, + expected: true, + }, + { + name: "not in sync - egress count mismatch", + dt: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString(), + NATGateways: sets.NewString("egress1", "egress2"), + Locations: map[string]NRPLocation{}, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.dt.deepEqualLocked()) + }) + } +} + +// TestLocationDataEquals tests LocationData.Equals() +func TestLocationDataEquals(t *testing.T) { + tests := []struct { + name string + ld1 LocationData + ld2 LocationData + expected bool + }{ + { + name: "equal location data - empty locations", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{}, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{}, + }, + expected: true, + }, + { + name: "equal location data - with addresses", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{ + "addr1": {ServiceRef: sets.NewString("svc1")}, + }, + }, + }, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{ + "addr1": {ServiceRef: sets.NewString("svc1")}, + }, + }, + }, + }, + expected: true, + }, + { + name: "different action", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{}, + }, + ld2: LocationData{ + Action: FullUpdate, + Locations: map[string]Location{}, + }, + expected: false, + }, + { + name: "different location count", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + "loc2": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + }, + expected: false, + }, + { + name: "different address update action", + ld1: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + }, + ld2: LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "loc1": {AddressUpdateAction: FullUpdate, Addresses: map[string]Address{}}, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.ld1.Equals(&tt.ld2)) + }) + } +} + +// TestDiffTrackerEquals tests DiffTracker.Equals() +func TestDiffTrackerEquals(t *testing.T) { + tests := []struct { + name string + dt1 *DiffTracker + dt2 *DiffTracker + expected bool + }{ + { + name: "equal diff trackers", + dt1: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + }, + expected: true, + }, + { + name: "different services", + dt1: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc2"), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + }, + expected: false, + }, + { + name: "different egresses", + dt1: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString("egress1"), + Nodes: map[string]Node{}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString("egress2"), + Nodes: map[string]Node{}, + }, + }, + expected: false, + }, + { + name: "different node count", + dt1: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{"node1": {}}, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{}, + }, + }, + expected: false, + }, + { + name: "different pod count in node", + dt1: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{"pod1": {}}}, + }, + }, + }, + dt2: &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString(), + Egresses: sets.NewString(), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{}}, + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.dt1.Equals(tt.dt2)) + }) + } +} + +// TestSyncServicesReturnTypeEquals tests SyncServicesReturnType.Equals() +func TestSyncServicesReturnTypeEquals(t *testing.T) { + tests := []struct { + name string + s1 SyncServicesReturnType + s2 SyncServicesReturnType + expected bool + }{ + { + name: "equal - both empty", + s1: SyncServicesReturnType{ + Additions: sets.NewString(), + Removals: sets.NewString(), + }, + s2: SyncServicesReturnType{ + Additions: sets.NewString(), + Removals: sets.NewString(), + }, + expected: true, + }, + { + name: "equal - same additions", + s1: SyncServicesReturnType{ + Additions: sets.NewString("svc1", "svc2"), + Removals: sets.NewString(), + }, + s2: SyncServicesReturnType{ + Additions: sets.NewString("svc1", "svc2"), + Removals: sets.NewString(), + }, + expected: true, + }, + { + name: "not equal - different additions", + s1: SyncServicesReturnType{ + Additions: sets.NewString("svc1"), + Removals: sets.NewString(), + }, + s2: SyncServicesReturnType{ + Additions: sets.NewString("svc2"), + Removals: sets.NewString(), + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.s1.Equals(&tt.s2)) + }) + } +} + +// TestConfigValidate tests Config.Validate() +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config Config + shouldError bool + errorMsg string + }{ + { + name: "valid config", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + Location: "eastus", + VNetName: "test-vnet", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/subscriptions/sub1/resourceGroups/rg1/providers/Microsoft.Network/serviceGateways/sgw", + }, + shouldError: false, + }, + { + name: "missing subscription ID", + config: Config{ + ResourceGroup: "rg1", + Location: "eastus", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "SubscriptionID is required", + }, + { + name: "missing resource group", + config: Config{ + SubscriptionID: "sub1", + Location: "eastus", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "ResourceGroup is required", + }, + { + name: "missing location", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + ServiceGatewayResourceName: "sgw", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "Location is required", + }, + { + name: "missing ServiceGatewayResourceName", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + Location: "eastus", + ServiceGatewayID: "/id", + }, + shouldError: true, + errorMsg: "ServiceGatewayResourceName is required", + }, + { + name: "missing ServiceGatewayID", + config: Config{ + SubscriptionID: "sub1", + ResourceGroup: "rg1", + Location: "eastus", + ServiceGatewayResourceName: "sgw", + }, + shouldError: true, + errorMsg: "ServiceGatewayID is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.shouldError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestJSONRoundTrip tests JSON marshaling/unmarshaling for various types +func TestJSONRoundTrip(t *testing.T) { + t.Run("UpdateAction round trip", func(t *testing.T) { + original := PartialUpdate + data, err := json.Marshal(original) + assert.NoError(t, err) + + var unmarshaled UpdateAction + err = json.Unmarshal(data, &unmarshaled) + assert.NoError(t, err) + assert.Equal(t, original, unmarshaled) + }) + + t.Run("Operation round trip", func(t *testing.T) { + original := Add + data, err := json.Marshal(original) + assert.NoError(t, err) + assert.Equal(t, `"Add"`, string(data)) + }) + + t.Run("SyncStatus round trip", func(t *testing.T) { + original := Success + data, err := json.Marshal(original) + assert.NoError(t, err) + assert.Equal(t, `"Success"`, string(data)) + }) +} + +// TestLocationDataEqualsMoreCases covers additional LocationData.Equals branches. +func TestLocationDataEqualsMoreCases(t *testing.T) { + base := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, + }, + }, + } + equal := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": { + AddressUpdateAction: PartialUpdate, + Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc1")}}, + }, + }, + } + assert.True(t, base.Equals(&equal)) + + // Different top-level Action. + diffAction := equal + diffAction.Action = FullUpdate + assert.False(t, base.Equals(&diffAction)) + + // Different number of locations. + diffLen := LocationData{Action: PartialUpdate, Locations: map[string]Location{}} + assert.False(t, base.Equals(&diffLen)) + + // Missing location name. + diffName := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{"node2": base.Locations["node1"]}, + } + assert.False(t, base.Equals(&diffName)) + + // Different AddressUpdateAction. + diffAUA := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: FullUpdate, Addresses: base.Locations["node1"].Addresses}, + }, + } + assert.False(t, base.Equals(&diffAUA)) + + // Different addresses length. + diffAddrLen := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{}}, + }, + } + assert.False(t, base.Equals(&diffAddrLen)) + + // Missing address name. + diffAddrName := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.2": {ServiceRef: sets.NewString("svc1")}}}, + }, + } + assert.False(t, base.Equals(&diffAddrName)) + + // Different ServiceRef. + diffRef := LocationData{ + Action: PartialUpdate, + Locations: map[string]Location{ + "node1": {AddressUpdateAction: PartialUpdate, Addresses: map[string]Address{"10.0.0.1": {ServiceRef: sets.NewString("svc2")}}}, + }, + } + assert.False(t, base.Equals(&diffRef)) +} + +// TestSyncDiffTrackerReturnTypeEquals covers SyncDiffTrackerReturnType.Equals branches. +func TestSyncDiffTrackerReturnTypeEquals(t *testing.T) { + mk := func() SyncDiffTrackerReturnType { + return SyncDiffTrackerReturnType{ + SyncStatus: Success, + LoadBalancerUpdates: SyncServicesReturnType{Additions: sets.NewString("a"), Removals: sets.NewString("b")}, + NATGatewayUpdates: SyncServicesReturnType{Additions: sets.NewString("c"), Removals: sets.NewString("d")}, + LocationData: LocationData{Action: PartialUpdate, Locations: map[string]Location{}}, + } + } + + base := mk() + equal := mk() + assert.True(t, base.Equals(&equal)) + + // Different SyncStatus. + diffStatus := mk() + diffStatus.SyncStatus = AlreadyInSync + assert.False(t, base.Equals(&diffStatus)) + + // Different LB additions. + diffLBAdd := mk() + diffLBAdd.LoadBalancerUpdates.Additions = sets.NewString("x") + assert.False(t, base.Equals(&diffLBAdd)) + + // Different LB removals. + diffLBRem := mk() + diffLBRem.LoadBalancerUpdates.Removals = sets.NewString("x") + assert.False(t, base.Equals(&diffLBRem)) + + // Different NATGW additions. + diffNGAdd := mk() + diffNGAdd.NATGatewayUpdates.Additions = sets.NewString("x") + assert.False(t, base.Equals(&diffNGAdd)) + + // Different NATGW removals. + diffNGRem := mk() + diffNGRem.NATGatewayUpdates.Removals = sets.NewString("x") + assert.False(t, base.Equals(&diffNGRem)) + + // Different LocationData. + diffLoc := mk() + diffLoc.LocationData.Action = FullUpdate + assert.False(t, base.Equals(&diffLoc)) +} + +// TestDiffTrackerEqualsMoreCases covers additional DiffTracker.Equals branches. +func TestDiffTrackerEqualsMoreCases(t *testing.T) { + mk := func() *DiffTracker { + return &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egr1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1"), + NATGateways: sets.NewString("egr1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }}, + }, + }, + } + } + + assert.True(t, mk().Equals(mk())) + + // Different Services. + a := mk() + b := mk() + b.K8sResources.Services = sets.NewString("other") + assert.False(t, a.Equals(b)) + + // Different Egresses. + b = mk() + b.K8sResources.Egresses = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different node count. + b = mk() + b.K8sResources.Nodes["node2"] = Node{Pods: map[string]Pod{}} + assert.False(t, mk().Equals(b)) + + // Missing node name. + b = mk() + delete(b.K8sResources.Nodes, "node1") + b.K8sResources.Nodes["nodeX"] = Node{Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }} + assert.False(t, mk().Equals(b)) + + // Different pod count. + b = mk() + n := b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.2"] = Pod{InboundIdentities: sets.NewString()} + assert.False(t, mk().Equals(b)) + + // Missing pod address. + b = mk() + n = b.K8sResources.Nodes["node1"] + delete(n.Pods, "10.0.0.1") + n.Pods["10.0.0.9"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"} + assert.False(t, mk().Equals(b)) + + // Different InboundIdentities. + b = mk() + n = b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("other"), PublicOutboundIdentity: "egr1"} + assert.False(t, mk().Equals(b)) + + // Different PublicOutboundIdentity. + b = mk() + n = b.K8sResources.Nodes["node1"] + n.Pods["10.0.0.1"] = Pod{InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "other"} + assert.False(t, mk().Equals(b)) + + // Different NRP LoadBalancers. + b = mk() + b.NRPResources.LoadBalancers = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different NRP NATGateways. + b = mk() + b.NRPResources.NATGateways = sets.NewString("other") + assert.False(t, mk().Equals(b)) + + // Different NRP location count. + b = mk() + b.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} + assert.False(t, mk().Equals(b)) +} + +// TestDeepEqualMoreCases covers the node/pod/identity mismatch branches of DeepEqual. +func TestDeepEqualMoreCases(t *testing.T) { + // Helper to build a DiffTracker that is in-sync by construction. + inSync := func() *DiffTracker { + return &DiffTracker{ + K8sResources: K8sState{ + Services: sets.NewString("svc1"), + Egresses: sets.NewString("egr1"), + Nodes: map[string]Node{ + "node1": {Pods: map[string]Pod{ + "10.0.0.1": {InboundIdentities: sets.NewString("svc1"), PublicOutboundIdentity: "egr1"}, + }}, + }, + }, + NRPResources: NRPState{ + LoadBalancers: sets.NewString("svc1"), + NATGateways: sets.NewString("egr1"), + Locations: map[string]NRPLocation{ + "node1": {Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }}, + }, + }, + } + } + + assert.True(t, inSync().deepEqualLocked()) + + // LoadBalancer present in NRP but not in K8s Services (reverse-direction check). + d := inSync() + d.NRPResources.LoadBalancers = sets.NewString("svc1", "extra") + d.K8sResources.Services = sets.NewString("svc1", "different") + assert.False(t, d.deepEqualLocked()) + + // Egress name mismatch (reverse direction): lengths equal (1==1) but names + // differ -> mismatch. + d = inSync() + d.K8sResources.Egresses = sets.NewString("egr2") + d.NRPResources.NATGateways = sets.NewString("egr2x") + assert.False(t, d.deepEqualLocked()) + + // Nodes vs Locations length mismatch. + d = inSync() + d.NRPResources.Locations["node2"] = NRPLocation{Addresses: map[string]NRPAddress{}} + assert.False(t, d.deepEqualLocked()) + + // Node missing in Locations (same count, different key). + d = inSync() + delete(d.NRPResources.Locations, "node1") + d.NRPResources.Locations["nodeX"] = NRPLocation{Addresses: map[string]NRPAddress{ + "10.0.0.1": {Services: sets.NewString("svc1", "egr1")}, + }} + assert.False(t, d.deepEqualLocked()) + + // Pods vs Addresses length mismatch. + d = inSync() + loc := d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.2"] = NRPAddress{Services: sets.NewString("svc1")} + assert.False(t, d.deepEqualLocked()) + + // Pod missing in Addresses (same count, different key). + d = inSync() + loc = d.NRPResources.Locations["node1"] + delete(loc.Addresses, "10.0.0.1") + loc.Addresses["10.0.0.9"] = NRPAddress{Services: sets.NewString("svc1", "egr1")} + assert.False(t, d.deepEqualLocked()) + + // Combined identities length mismatch. + d = inSync() + loc = d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1")} + assert.False(t, d.deepEqualLocked()) + + // Identity not found in Services (same count, different identity). + d = inSync() + loc = d.NRPResources.Locations["node1"] + loc.Addresses["10.0.0.1"] = NRPAddress{Services: sets.NewString("svc1", "egrX")} + assert.False(t, d.deepEqualLocked()) +} diff --git a/pkg/util/sets/string.go b/pkg/util/sets/string.go index 2562fcf6ea..83f881ca85 100644 --- a/pkg/util/sets/string.go +++ b/pkg/util/sets/string.go @@ -17,6 +17,7 @@ limitations under the License. package sets import ( + "encoding/json" "strings" "k8s.io/apimachinery/pkg/util/sets" @@ -37,14 +38,13 @@ func NewString(items ...string) *IgnoreCaseSet { return &IgnoreCaseSet{set: set} } -// Insert adds the given items to the set. It only works if the set is initialized. +// Insert adds the given items to the set, initializing the underlying set if needed. func (s *IgnoreCaseSet) Insert(items ...string) { - var lowerItems []string - for _, item := range items { - lowerItems = append(lowerItems, strings.ToLower(item)) + if s.set == nil { + s.set = sets.New[string]() } - for _, item := range lowerItems { - s.set.Insert(item) + for _, item := range items { + s.set.Insert(strings.ToLower(item)) } } @@ -98,3 +98,42 @@ func (s *IgnoreCaseSet) Len() int { } return s.set.Len() } + +// MarshalJSON marshals the set to JSON as an array of strings. +func (s *IgnoreCaseSet) MarshalJSON() ([]byte, error) { + if s == nil { + return []byte("null"), nil + } + if s.Len() == 0 { + return []byte("[]"), nil + } + return json.Marshal(s.UnsortedList()) +} + +// Equals returns true if the two sets are equal. +func (s *IgnoreCaseSet) Equals(other *IgnoreCaseSet) bool { + // Two sets of equal size are equal iff every item in one is contained in the + // other, so a single containment check in one direction is sufficient. + if s.Len() != other.Len() { + return false + } + for _, item := range s.UnsortedList() { + if !other.Has(item) { + return false + } + } + return true +} + +// Difference returns a new IgnoreCaseSet containing the items in s that are not +// present in other (i.e. the set difference s \ other). It is safe to call on +// nil or uninitialized sets. +func (s *IgnoreCaseSet) Difference(other *IgnoreCaseSet) *IgnoreCaseSet { + result := NewString() + for _, item := range s.UnsortedList() { + if !other.Has(item) { + result.Insert(item) + } + } + return result +} diff --git a/pkg/util/sets/string_test.go b/pkg/util/sets/string_test.go index b0286b7d3d..79d640d0a2 100644 --- a/pkg/util/sets/string_test.go +++ b/pkg/util/sets/string_test.go @@ -367,3 +367,147 @@ func TestLen(t *testing.T) { }) } } +func TestEquals(t *testing.T) { + tests := []struct { + name string + s1 *IgnoreCaseSet + s2 *IgnoreCaseSet + want bool + }{ + { + name: "both nil", + s1: nil, + s2: nil, + want: true, + }, + { + name: "first nil", + s1: nil, + s2: NewString("foo"), + want: false, + }, + { + name: "second nil", + s1: NewString("foo"), + s2: nil, + want: false, + }, + { + name: "empty sets", + s1: NewString(), + s2: NewString(), + want: true, + }, + { + name: "same elements", + s1: NewString("foo", "bar"), + s2: NewString("foo", "bar"), + want: true, + }, + { + name: "same elements with different case", + s1: NewString("foo", "bar"), + s2: NewString("FOO", "BAR"), + want: true, + }, + { + name: "different sizes", + s1: NewString("foo", "bar"), + s2: NewString("foo"), + want: false, + }, + { + name: "same size but different elements", + s1: NewString("foo", "bar"), + s2: NewString("foo", "baz"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Equals is nil-safe on both the receiver and the argument, so call it + // directly for every case (including the nil ones) to exercise the + // production code path rather than asserting an expectation against itself. + if got := tt.s1.Equals(tt.s2); got != tt.want { + t.Errorf("Equals() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDifference(t *testing.T) { + tests := []struct { + name string + s *IgnoreCaseSet + other *IgnoreCaseSet + want *IgnoreCaseSet + }{ + { + name: "disjoint sets", + s: NewString("foo", "bar"), + other: NewString("baz"), + want: NewString("foo", "bar"), + }, + { + name: "partial overlap", + s: NewString("foo", "bar", "baz"), + other: NewString("bar"), + want: NewString("foo", "baz"), + }, + { + name: "all removed", + s: NewString("foo", "bar"), + other: NewString("foo", "bar"), + want: NewString(), + }, + { + name: "case-insensitive", + s: NewString("Foo", "BAR"), + other: NewString("foo"), + want: NewString("bar"), + }, + { + name: "other empty", + s: NewString("foo", "bar"), + other: NewString(), + want: NewString("foo", "bar"), + }, + { + name: "s empty", + s: NewString(), + other: NewString("foo"), + want: NewString(), + }, + { + name: "other nil", + s: NewString("foo"), + other: nil, + want: NewString("foo"), + }, + { + name: "s nil", + s: nil, + other: NewString("foo"), + want: NewString(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.s.Difference(tt.other) + if !got.Equals(tt.want) { + t.Errorf("Difference() = %v, want %v", got.UnsortedList(), tt.want.UnsortedList()) + } + }) + } +} + +func TestInsertOnUninitializedSet(t *testing.T) { + s := &IgnoreCaseSet{} + s.Insert("Foo", "BAR") + if !s.Has("foo") || !s.Has("bar") { + t.Errorf("Insert on uninitialized set should lazily initialize and add items, got %v", s.UnsortedList()) + } + if s.Len() != 2 { + t.Errorf("expected len 2, got %d", s.Len()) + } +}