From e7651b34089002fa83ada00ce7ad75e7bf4532a0 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Wed, 16 Apr 2025 16:57:05 +0100 Subject: [PATCH 01/24] Add Pod Sharing --- Makefile | 8 + ...si.aws.com_mountpoints3podattachments.yaml | 101 +++ .../serviceaccount-csi-controller.yaml | 5 +- .../templates/serviceaccount-csi-node.yaml | 5 +- .../csicontroller/expectations.go | 71 ++ .../csicontroller/reconciler.go | 349 ++++++-- cmd/aws-s3-csi-controller/main.go | 49 +- hack/boilerplate.go.txt | 0 pkg/api/v1/groupversion_info.go | 20 + pkg/api/v1/mountpoints3podattachment_types.go | 79 ++ pkg/api/v1/zz_generated.deepcopy.go | 98 +++ pkg/driver/driver.go | 84 +- .../node/credentialprovider/provider.go | 14 +- .../node/credentialprovider/provider_pod.go | 5 + pkg/driver/node/mounter/fake_cache.go | 43 + pkg/driver/node/mounter/fake_mounter.go | 2 +- pkg/driver/node/mounter/mocks/mock_mount.go | 8 +- pkg/driver/node/mounter/mounter.go | 70 +- pkg/driver/node/mounter/mppod_lock.go | 55 ++ pkg/driver/node/mounter/pod_mounter.go | 317 ++++--- pkg/driver/node/mounter/pod_mounter_darwin.go | 4 + pkg/driver/node/mounter/pod_mounter_linux.go | 8 + pkg/driver/node/mounter/pod_mounter_test.go | 69 +- pkg/driver/node/mounter/pod_unmounter.go | 197 +++++ pkg/driver/node/mounter/systemd_mounter.go | 2 +- .../node/mounter/systemd_mounter_test.go | 5 +- pkg/driver/node/node.go | 16 +- pkg/driver/node/node_test.go | 75 +- pkg/podmounter/mppod/creator.go | 22 +- pkg/podmounter/mppod/creator_test.go | 26 +- pkg/podmounter/mppod/mppod.go | 17 - pkg/podmounter/mppod/mppod_test.go | 56 -- pkg/podmounter/mppod/watcher/watcher.go | 24 +- pkg/podmounter/mppod/watcher/watcher_test.go | 2 +- tests/controller/controller_test.go | 791 ++++++++++++++---- tests/controller/suite_test.go | 52 +- tests/e2e-kubernetes/e2e_test.go | 1 + tests/e2e-kubernetes/go.mod | 1 + tests/e2e-kubernetes/go.sum | 2 + .../e2e-kubernetes/testsuites/pod_sharing.go | 304 +++++++ 40 files changed, 2550 insertions(+), 507 deletions(-) create mode 100644 charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml create mode 100644 cmd/aws-s3-csi-controller/csicontroller/expectations.go create mode 100644 hack/boilerplate.go.txt create mode 100644 pkg/api/v1/groupversion_info.go create mode 100644 pkg/api/v1/mountpoints3podattachment_types.go create mode 100644 pkg/api/v1/zz_generated.deepcopy.go create mode 100644 pkg/driver/node/mounter/fake_cache.go create mode 100644 pkg/driver/node/mounter/mppod_lock.go create mode 100644 pkg/driver/node/mounter/pod_unmounter.go delete mode 100644 pkg/podmounter/mppod/mppod.go delete mode 100644 pkg/podmounter/mppod/mppod_test.go create mode 100644 tests/e2e-kubernetes/testsuites/pod_sharing.go diff --git a/Makefile b/Makefile index 255129a1..5ffcb209 100644 --- a/Makefile +++ b/Makefile @@ -174,6 +174,14 @@ check_style: clean: rm -rf bin/ && docker system prune +# Generate files for Custom Resources (`zz_generated.deepcopy.go` and CustomResourceDefinition YAML file). +# TODO: Wrap CRD YAML file with experimental.podMounter=true Helm flag +# POD_ATTACHMENT_CRD_FILE ?= "./charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml" +.PHONY: generate +generate: + controller-gen object:headerFile="hack/boilerplate.go.txt" paths="./pkg/api/..." + controller-gen crd paths="./pkg/api/..." output:crd:dir=./charts/aws-mountpoint-s3-csi-driver/templates + ## Binaries used in tests. TESTBIN ?= $(shell pwd)/tests/bin diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml new file mode 100644 index 00000000..16621492 --- /dev/null +++ b/charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml @@ -0,0 +1,101 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.17.3 + name: mountpoints3podattachments.s3.csi.aws.com +spec: + group: s3.csi.aws.com + names: + kind: MountpointS3PodAttachment + listKind: MountpointS3PodAttachmentList + plural: mountpoints3podattachments + shortNames: + - s3pa + singular: mountpoints3podattachment + scope: Cluster + versions: + - name: v1 + schema: + openAPIV3Schema: + description: MountpointS3PodAttachment is the Schema for the mountpoints3podattachments + API. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: MountpointS3PodAttachmentSpec defines the desired state of + MountpointS3PodAttachment. + properties: + authenticationSource: + description: Authentication source taken from volume attribute field + `authenticationSource`. + type: string + mountOptions: + description: Comma separated mount options taken from volume. + type: string + mountpointS3PodToWorkloadPodUIDs: + additionalProperties: + items: + type: string + type: array + description: Maps each Mountpoint S3 pod name to the list of workload + pod UIDs it is attached to. + type: object + nodeName: + description: Name of the node. + type: string + persistentVolumeName: + description: Name of the Persistent Volume. + type: string + volumeID: + description: Volume ID. + type: string + workloadFSGroup: + description: Workload pod's `fsGroup` from pod security context + type: string + workloadNamespace: + description: 'Workload pod''s namespace. Exists only if `authenticationSource: + pod`.' + type: string + workloadServiceAccountIAMRoleARN: + description: 'EKS IAM Role ARN from workload pod''s service account + annotation (IRSA). Exists only if `authenticationSource: pod` and + service account has `eks.amazonaws.com/role-arn` annotation.' + type: string + workloadServiceAccountName: + description: 'Workload pod''s service account name. Exists only if + `authenticationSource: pod`.' + type: string + required: + - authenticationSource + - mountOptions + - mountpointS3PodToWorkloadPodUIDs + - nodeName + - persistentVolumeName + - volumeID + - workloadFSGroup + type: object + type: object + selectableFields: + - jsonPath: .spec.nodeName + served: true + storage: true + subresources: + status: {} diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml index c5469692..529ee41e 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml @@ -49,8 +49,11 @@ metadata: {{- include "aws-mountpoint-s3-csi-driver.labels" . | nindent 4 }} rules: - apiGroups: [""] - resources: ["pods", "persistentvolumeclaims", "persistentvolumes"] + resources: ["pods", "persistentvolumeclaims", "persistentvolumes", "serviceaccounts"] verbs: ["get", "watch", "list"] + - apiGroups: ["s3.csi.aws.com"] + resources: ["mountpoints3podattachments"] + verbs: ["create", "delete", "update", "get", "watch", "list"] --- kind: ClusterRoleBinding apiVersion: rbac.authorization.k8s.io/v1 diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml index 9e6baa8e..96606310 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml @@ -23,8 +23,11 @@ metadata: app.kubernetes.io/name: aws-mountpoint-s3-csi-driver rules: - apiGroups: [""] - resources: ["serviceaccounts"] + resources: ["serviceaccounts"] # TODO: Remove once we stop supporting systemd mounts. verbs: ["get"] + - apiGroups: ["s3.csi.aws.com"] + resources: ["mountpoints3podattachments"] + verbs: ["get", "list", "watch"] --- kind: ClusterRoleBinding apiVersion: rbac.authorization.k8s.io/v1 diff --git a/cmd/aws-s3-csi-controller/csicontroller/expectations.go b/cmd/aws-s3-csi-controller/csicontroller/expectations.go new file mode 100644 index 00000000..1fbb8967 --- /dev/null +++ b/cmd/aws-s3-csi-controller/csicontroller/expectations.go @@ -0,0 +1,71 @@ +package csicontroller + +import ( + "sort" + "strings" + "sync" + + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// Expectations is a structure that manages pending expectations for Kubernetes resources. +// It uses field filters as keys to track resources that are expected to be created +// helping to reduce unnecessary processing and API server load. +type Expectations struct { + pending sync.Map +} + +// NewExpectations creates and returns a new Expectations instance. +func NewExpectations() *Expectations { + return &Expectations{} +} + +// SetPending marks a resource as pending based on the given field filters. +// This is typically used when a create operation is initiated. +func (e *Expectations) SetPending(fieldFilters client.MatchingFields) { + key := deriveExpectationKeyFromFilters(fieldFilters) + e.pending.Store(key, struct{}{}) +} + +// IsPending checks if a resource is marked as pending based on the given field filters. +// Returns true if the resource is pending, false otherwise. +func (e *Expectations) IsPending(fieldFilters client.MatchingFields) bool { + key := deriveExpectationKeyFromFilters(fieldFilters) + _, ok := e.pending.Load(key) + return ok +} + +// Clear removes the pending mark for a resource based on the given field filters. +// This is typically called when an expected operation has been confirmed as completed. +func (e *Expectations) Clear(fieldFilters client.MatchingFields) { + key := deriveExpectationKeyFromFilters(fieldFilters) + e.pending.Delete(key) +} + +// deriveExpectationKeyFromFilters generates a deterministic string key from a map of field filters. +// It creates a consistent string representation of the filters by: +// 1. Sorting the filter keys alphabetically +// 2. Concatenating each key-value pair in the format "key=value;" +// +// For example, given filters {"foo": "bar", "baz": "qux"}, it will produce "baz=qux;foo=bar;" +// +// Parameters: +// - fieldFilters: A map of field names to their values used for filtering Kubernetes resources +// +// Returns: +// - A string that uniquely represents the combination of field filters +func deriveExpectationKeyFromFilters(fieldFilters client.MatchingFields) string { + keys := make([]string, 0, len(fieldFilters)) + for k := range fieldFilters { + keys = append(keys, k) + } + sort.Strings(keys) + var sb strings.Builder + for _, k := range keys { + sb.WriteString(k) + sb.WriteString("=") + sb.WriteString(fieldFilters[k]) + sb.WriteString(";") + } + return sb.String() +} diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index 8059e54e..7483abe1 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -4,26 +4,39 @@ import ( "context" "errors" "fmt" + "strconv" + "strings" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/reconcile" + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" + "github.com/go-logr/logr" ) const debugLevel = 4 const mountpointCSIDriverName = "s3.csi.aws.com" +const defaultServiceAccount = "default" + +const ( + AnnotationServiceAccountRole = "eks.amazonaws.com/role-arn" + LabelCSIDriverVersion = "s3.csi.aws.com/created-by-csi-driver-version" +) // A Reconciler reconciles Mountpoint Pods by watching other workload Pods thats using S3 CSI Driver. type Reconciler struct { mountpointPodConfig mppod.Config mountpointPodCreator *mppod.Creator + s3paExpectations *Expectations client.Client } @@ -31,7 +44,7 @@ type Reconciler struct { // NewReconciler returns a new reconciler created from `client` and `podConfig`. func NewReconciler(client client.Client, podConfig mppod.Config) *Reconciler { creator := mppod.NewCreator(podConfig) - return &Reconciler{Client: client, mountpointPodConfig: podConfig, mountpointPodCreator: creator} + return &Reconciler{Client: client, mountpointPodConfig: podConfig, mountpointPodCreator: creator, s3paExpectations: NewExpectations()} } // SetupWithManager configures reconciler to run with given `mgr`. @@ -139,7 +152,8 @@ func (r *Reconciler) reconcileWorkloadPod(ctx context.Context, pod *corev1.Pod) log.V(debugLevel).Info("Found bound PV for PVC", "pvc", pvc.Name, "volumeName", pv.Name) - err = r.spawnOrDeleteMountpointPodIfNeeded(ctx, pod, pvc, pv, csiSpec) + needsRequeue, err := r.spawnOrDeleteMountpointPodIfNeeded(ctx, pod, pvc, pv, csiSpec) + requeue = requeue || needsRequeue if err != nil { errs = append(errs, err) continue @@ -163,56 +177,269 @@ func (r *Reconciler) spawnOrDeleteMountpointPodIfNeeded( pvc *corev1.PersistentVolumeClaim, pv *corev1.PersistentVolume, csiSpec *corev1.CSIPersistentVolumeSource, -) error { - mpPodName := mppod.MountpointPodNameFor(string(workloadPod.UID), pvc.Spec.VolumeName) +) (bool, error) { + workloadUID := string(workloadPod.UID) + roleArn, err := r.findIRSAServiceAccountRole(ctx, workloadPod) + if err != nil { + return false, err + } + fieldFilters := r.buildFieldFilters(workloadPod, pv, roleArn) + log := r.setupLogger(ctx, workloadPod, pvc, pv, workloadUID, fieldFilters) + + s3paList, err := r.getExistingS3PodAttachments(ctx, fieldFilters) + if err != nil { + return false, err + } + + if !isPodActive(workloadPod) { + return r.handleInactivePod(ctx, workloadPod, s3paList, workloadUID, log) + } + + if len(s3paList.Items) == 1 { + return r.handleExistingS3PodAttachment(ctx, s3paList, workloadUID, fieldFilters, log) + } + + return r.handleNewS3PodAttachment(ctx, workloadPod, pv, fieldFilters, log) +} - log := logf.FromContext(ctx).WithValues( +func (r *Reconciler) setupLogger(ctx context.Context, workloadPod *corev1.Pod, pvc *corev1.PersistentVolumeClaim, pv *corev1.PersistentVolume, workloadUID string, fieldFilters client.MatchingFields) logr.Logger { + logger := logf.FromContext(ctx).WithValues( "workloadPod", types.NamespacedName{Namespace: workloadPod.Namespace, Name: workloadPod.Name}, - "mountpointPod", mpPodName, - "pvc", pvc.Name, "volumeName", pv.Name) + "pvc", pvc.Name, + "workloadUID", workloadUID, + ) - mpPod := &corev1.Pod{} - err := r.Get(ctx, types.NamespacedName{Namespace: r.mountpointPodConfig.Namespace, Name: mpPodName}, mpPod) - if err != nil && !apierrors.IsNotFound(err) { - log.Error(err, "Failed to get Mountpoint Pod") - return err + var keyValues []interface{} + for k, v := range fieldFilters { + keyValues = append(keyValues, k, v) } - isMountpointPodExists := err == nil + if len(keyValues) > 0 { + logger = logger.WithValues(keyValues...) + } - // `workloadPod` is not active, its either terminated (i.e., `phase == Succeeded or phase == Failed`) or - // its scheduled for termination (i.e., `DeletionTimestamp != nil`) - if !isPodActive(workloadPod) { - // if its scheduled for termination and its still in `Pending` phase, - // delete if there is an existing Mountpoint Pod as otherwise this - // Mountpoint Pod might take some time to terminate on its own. - if isMountpointPodExists && workloadPod.Status.Phase == corev1.PodPending { - log.Info("Deleting scheduled Mountpoint Pod") - err := r.deleteMountpointPod(ctx, mpPod) - if err != nil { - log.Error(err, "Failed to delete scheduled Mountpoint Pod") - return err + return logger +} + +func (r *Reconciler) buildFieldFilters(workloadPod *corev1.Pod, pv *corev1.PersistentVolume, roleArn string) client.MatchingFields { + authSource := r.getAuthSource(pv) + fsGroup := r.getFSGroup(workloadPod) + + fieldFilters := client.MatchingFields{ + crdv1.FieldNodeName: workloadPod.Spec.NodeName, + crdv1.FieldPersistentVolumeName: pv.Name, + crdv1.FieldVolumeID: pv.Spec.CSI.VolumeHandle, + crdv1.FieldMountOptions: strings.Join(pv.Spec.MountOptions, ","), + crdv1.FieldWorkloadFSGroup: fsGroup, + crdv1.FieldAuthenticationSource: authSource, + } + + if authSource == "pod" { + fieldFilters[crdv1.FieldWorkloadNamespace] = workloadPod.Namespace + fieldFilters[crdv1.FieldWorkloadServiceAccountName] = getServiceAccountName(workloadPod) + fieldFilters[crdv1.FieldWorkloadServiceAccountIAMRoleARN] = roleArn + } + + return fieldFilters +} + +func (r *Reconciler) getAuthSource(pv *corev1.PersistentVolume) string { + volumeAttributes := mppod.ExtractVolumeAttributes(pv) + authSource := volumeAttributes[volumecontext.AuthenticationSource] + if authSource == "" { + return "driver" + } + return authSource +} + +func (r *Reconciler) getFSGroup(workloadPod *corev1.Pod) string { + if workloadPod.Spec.SecurityContext.FSGroup != nil { + return strconv.FormatInt(*workloadPod.Spec.SecurityContext.FSGroup, 10) + } + return "" +} + +func (r *Reconciler) getExistingS3PodAttachments(ctx context.Context, fieldFilters client.MatchingFields) (*crdv1.MountpointS3PodAttachmentList, error) { + s3paList := &crdv1.MountpointS3PodAttachmentList{} + if err := r.List(ctx, s3paList, fieldFilters); err != nil { + return nil, err + } + + if len(s3paList.Items) > 1 { + return nil, fmt.Errorf("found %d MountpointS3PodAttachments instead of 1", len(s3paList.Items)) + } + + return s3paList, nil +} + +func (r *Reconciler) handleInactivePod(ctx context.Context, workloadPod *corev1.Pod, s3paList *crdv1.MountpointS3PodAttachmentList, workloadUID string, log logr.Logger) (bool, error) { + if len(s3paList.Items) != 1 { + log.Info("Workload pod is not active. Did not find any MountpointS3PodAttachments.") + return false, nil + } + + return r.removeWorkloadFromS3PodAttachment(ctx, &s3paList.Items[0], workloadUID, log) +} + +func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3paList *crdv1.MountpointS3PodAttachmentList, workloadUID string, fieldFilters client.MatchingFields, log logr.Logger) (bool, error) { + s3pa := &s3paList.Items[0] + + if r.s3paExpectations.IsPending(fieldFilters) { + log.Info("MountpointS3PodAttachment creation is pending, removing from pending") + r.s3paExpectations.Clear(fieldFilters) + } + + if s3paContainsWorkload(s3pa, workloadUID) { + log.Info("MountpointS3PodAttachment already has this workload UID") + return false, nil + } + + return r.addWorkloadToS3PodAttachment(ctx, s3pa, workloadUID, log) +} + +func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { + log.Info("Adding workload UID to MountpointS3PodAttachment", "workloadUID", workloadUID) + + for key := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[key] = append(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[key], workloadUID) + break + } + + err := r.Update(ctx, s3pa) + if apierrors.IsConflict(err) { + log.Info("Failed to update MountpointS3PodAttachment - resource conflict - requeue", "workloadUID", workloadUID) + return true, nil + } + + return false, nil +} + +func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { + // Remove workload UID from mountpoint pods + for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + filteredUIDs := []string{} + found := false + for _, uid := range uids { + if uid == workloadUID { + found = true + continue } + filteredUIDs = append(filteredUIDs, uid) + } + if found { + s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[mpPodName] = filteredUIDs + err := r.Update(ctx, s3pa) + if apierrors.IsConflict(err) { + log.Info("Failed to remove workload pod UID from existing MountpointS3PodAttachment due to resource conflict, requeueing") + return true, nil + } + log.Info("Successfully removed workload pod UID from MountpointS3PodAttachment") + break + } + } - log.Info("Scheduled Mountpoint Pod deleted") - return err + // Remove Mountpoint pods with zero workloads + for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + if len(uids) == 0 { + log.Info("Mountpoint pod has zero workload UIDs. Will remove it from MountpointS3PodAttachment", + "mountpointPodName", mpPodName) + delete(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs, mpPodName) + err := r.Update(ctx, s3pa) + if apierrors.IsConflict(err) { + log.Info("Failed to remove Mountpoint pod from MountpointS3PodAttachment due to resource conflict, requeueing", + "mountpointPodName", mpPodName) + return true, nil + } } + } - // No need to do anything - either there was no Mountpoint Pod for `pod` or it was in `Running` state, - // so a clean unmount operation will be performed and Mountpoint Pod will cleany exit (and get deleted by `reconcileMountpointPod`). - return nil + // Delete MountpointS3PodAttachment if map is empty + if len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs) == 0 { + log.Info("MountpointS3PodAttachment has zero Mountpoint Pods. Will delete it") + err := r.Delete(ctx, s3pa) + if apierrors.IsConflict(err) { + log.Info("Failed to delete MountpointS3PodAttachment due to resource conflict, requeueing") + return true, nil + } } - if isMountpointPodExists { - log.V(debugLevel).Info("Mountpoint Pod already exists - ignoring") - return nil + return false, nil +} + +func (r *Reconciler) handleNewS3PodAttachment( + ctx context.Context, + workloadPod *corev1.Pod, + pv *corev1.PersistentVolume, + fieldFilters client.MatchingFields, + log logr.Logger, +) (bool, error) { + if r.s3paExpectations.IsPending(fieldFilters) { + log.Info("MountpointS3PodAttachment creation is pending, requeueing") + return true, nil } - if err := r.spawnMountpointPod(ctx, workloadPod, pvc, pv, csiSpec, mpPodName); err != nil { + if err := r.createS3PodAttachmentWithMPPod(ctx, workloadPod, pv, log); err != nil { + return false, err + } + + r.s3paExpectations.SetPending(fieldFilters) + return true, nil +} + +func (r *Reconciler) createS3PodAttachmentWithMPPod( + ctx context.Context, + workloadPod *corev1.Pod, + pv *corev1.PersistentVolume, + log logr.Logger, +) error { + authSource := r.getAuthSource(pv) + mpPodName, err := r.spawnMountpointPod(ctx, workloadPod, pv, log) + if err != nil { log.Error(err, "Failed to spawn Mountpoint Pod") return err } + fsGroup := "" + if workloadPod.Spec.SecurityContext.FSGroup != nil { + fsGroup = strconv.FormatInt(*workloadPod.Spec.SecurityContext.FSGroup, 10) + } + s3pa := &crdv1.MountpointS3PodAttachment{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "s3pa-", + Labels: map[string]string{ + LabelCSIDriverVersion: r.mountpointPodConfig.CSIDriverVersion, + }, + }, + Spec: crdv1.MountpointS3PodAttachmentSpec{ + NodeName: workloadPod.Spec.NodeName, + PersistentVolumeName: pv.Name, + VolumeID: pv.Spec.CSI.VolumeHandle, + MountOptions: strings.Join(pv.Spec.MountOptions, ","), + WorkloadFSGroup: fsGroup, + AuthenticationSource: authSource, + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + mpPodName: {string(workloadPod.UID)}, + }, + }, + } + if authSource == "pod" { + s3pa.Spec.WorkloadNamespace = workloadPod.Namespace + s3pa.Spec.WorkloadServiceAccountName = getServiceAccountName(workloadPod) + + roleARN, err := r.findIRSAServiceAccountRole(ctx, workloadPod) + if err != nil { + return err + } + s3pa.Spec.WorkloadServiceAccountIAMRoleARN = roleARN + } + + err = r.Create(ctx, s3pa) + if err != nil { + log.Error(err, "Failed to create MountpointS3PodAttachment") + return err + } + + log.Info("MountpointS3PodAttachment is created", "s3paName", s3pa.Name) return nil } @@ -222,33 +449,20 @@ func (r *Reconciler) spawnOrDeleteMountpointPodIfNeeded( func (r *Reconciler) spawnMountpointPod( ctx context.Context, workloadPod *corev1.Pod, - pvc *corev1.PersistentVolumeClaim, pv *corev1.PersistentVolume, - _ *corev1.CSIPersistentVolumeSource, - name string, -) error { - log := logf.FromContext(ctx).WithValues( - "workloadPod", types.NamespacedName{Namespace: workloadPod.Namespace, Name: workloadPod.Name}, - "mountpointPod", name, - "pvc", pvc.Name, "volumeName", pv.Name) - + log logr.Logger, +) (string, error) { log.Info("Spawning Mountpoint Pod") - mpPod := r.mountpointPodCreator.Create(workloadPod, pv) - if mpPod.Name != name { - err := fmt.Errorf("Mountpoint Pod name mismatch %s vs %s", mpPod.Name, name) - log.Error(err, "Name mismatch on Mountpoint Pod") - return err - } + mpPod := r.mountpointPodCreator.Create(workloadPod.Spec.NodeName, pv) err := r.Create(ctx, mpPod) if err != nil { - log.Error(err, "Failed to create Mountpoint Pod") - return err + return "", err } - log.Info("Mountpoint Pod spawned", "mountpointPodUID", mpPod.UID) - return nil + log.Info("Mountpoint Pod spawned", "mountpointPodName", mpPod.Name) + return mpPod.Name, nil } // deleteMountpointPod deletes given `mountpointPod`. @@ -314,6 +528,16 @@ func (r *Reconciler) getBoundPVForPodClaim( return pvc, pv, nil } +func (r *Reconciler) findIRSAServiceAccountRole(ctx context.Context, pod *corev1.Pod) (string, error) { + sa := &corev1.ServiceAccount{} + err := r.Get(ctx, types.NamespacedName{Namespace: pod.Namespace, Name: getServiceAccountName(pod)}, sa) + if err != nil { + return "", fmt.Errorf("Failed to find workload pod's service account %s", getServiceAccountName(pod)) + } + + return sa.Annotations[AnnotationServiceAccountRole], nil +} + // isMountpointPod returns whether given `pod` is a Mountpoint Pod. // It currently checks namespace of `pod`. func (r *Reconciler) isMountpointPod(pod *corev1.Pod) bool { @@ -338,3 +562,22 @@ func isPodActive(p *corev1.Pod) bool { corev1.PodFailed != p.Status.Phase && p.DeletionTimestamp == nil } + +func s3paContainsWorkload(s3pa *crdv1.MountpointS3PodAttachment, workloadUID string) bool { + for _, workloads := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for _, workload := range workloads { + if workload == workloadUID { + return true + } + } + } + return false +} + +// getServiceAccountName returns the pod's service account name or "default" if not specified +func getServiceAccountName(pod *corev1.Pod) string { + if pod.Spec.ServiceAccountName != "" { + return pod.Spec.ServiceAccountName + } + return defaultServiceAccount +} diff --git a/cmd/aws-s3-csi-controller/main.go b/cmd/aws-s3-csi-controller/main.go index 31c46f00..28ab23bc 100644 --- a/cmd/aws-s3-csi-controller/main.go +++ b/cmd/aws-s3-csi-controller/main.go @@ -7,10 +7,16 @@ package main import ( + "context" "flag" + "fmt" "os" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/config" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" @@ -18,9 +24,11 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager/signals" "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" + "github.com/go-logr/logr" ) var mountpointNamespace = flag.String("mountpoint-namespace", os.Getenv("MOUNTPOINT_NAMESPACE"), "Namespace to spawn Mountpoint Pods in.") @@ -30,20 +38,33 @@ var mountpointImage = flag.String("mountpoint-image", os.Getenv("MOUNTPOINT_IMAG var mountpointImagePullPolicy = flag.String("mountpoint-image-pull-policy", os.Getenv("MOUNTPOINT_IMAGE_PULL_POLICY"), "Pull policy of Mountpoint images.") var mountpointContainerCommand = flag.String("mountpoint-container-command", "/bin/aws-s3-csi-mounter", "Entrypoint command of the Mountpoint Pods.") +var ( + scheme = runtime.NewScheme() +) + +func init() { + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(crdv1.AddToScheme(scheme)) +} + func main() { flag.Parse() logf.SetLogger(zap.New()) log := logf.Log.WithName(csicontroller.Name) - client := config.GetConfigOrDie() + conf := config.GetConfigOrDie() - mgr, err := manager.New(client, manager.Options{}) + mgr, err := manager.New(conf, manager.Options{ + Scheme: scheme, + }) if err != nil { log.Error(err, "Failed to create a new manager") os.Exit(1) } + IndexMountpointS3PodAttachmentFields(log, mgr) + err = csicontroller.NewReconciler(mgr.GetClient(), mppod.Config{ Namespace: *mountpointNamespace, MountpointVersion: *mountpointVersion, @@ -54,7 +75,7 @@ func main() { ImagePullPolicy: corev1.PullPolicy(*mountpointImagePullPolicy), }, CSIDriverVersion: version.GetVersion().DriverVersion, - ClusterVariant: cluster.DetectVariant(client, log), + ClusterVariant: cluster.DetectVariant(conf, log), }).SetupWithManager(mgr) if err != nil { log.Error(err, "Failed to create controller") @@ -66,3 +87,25 @@ func main() { os.Exit(1) } } + +func IndexMountpointS3PodAttachmentFields(log logr.Logger, mgr manager.Manager) { + indexField(log, mgr, crdv1.FieldNodeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) + indexField(log, mgr, crdv1.FieldPersistentVolumeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) + indexField(log, mgr, crdv1.FieldVolumeID, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) + indexField(log, mgr, crdv1.FieldMountOptions, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) + indexField(log, mgr, crdv1.FieldAuthenticationSource, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) + indexField(log, mgr, crdv1.FieldWorkloadFSGroup, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) + indexField(log, mgr, crdv1.FieldWorkloadServiceAccountName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) + indexField(log, mgr, crdv1.FieldWorkloadNamespace, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) + indexField(log, mgr, crdv1.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) +} + +func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1.MountpointS3PodAttachment) string) { + err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { + return []string{extractor(obj.(*crdv1.MountpointS3PodAttachment))} + }) + if err != nil { + log.Error(err, fmt.Sprintf("Failed to create a %s field indexer", field)) + os.Exit(1) + } +} diff --git a/hack/boilerplate.go.txt b/hack/boilerplate.go.txt new file mode 100644 index 00000000..e69de29b diff --git a/pkg/api/v1/groupversion_info.go b/pkg/api/v1/groupversion_info.go new file mode 100644 index 00000000..8c6d8b72 --- /dev/null +++ b/pkg/api/v1/groupversion_info.go @@ -0,0 +1,20 @@ +// Package v1 contains API Schema definitions for the s3.csi.aws.com v1 API group. +// +kubebuilder:object:generate=true +// +groupName=s3.csi.aws.com +package v1 + +import ( + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/scheme" +) + +var ( + // GroupVersion is group version used to register these objects. + GroupVersion = schema.GroupVersion{Group: "s3.csi.aws.com", Version: "v1"} + + // SchemeBuilder is used to add go types to the GroupVersionKind scheme. + SchemeBuilder = &scheme.Builder{GroupVersion: GroupVersion} + + // AddToScheme adds the types in this group-version to the given scheme. + AddToScheme = SchemeBuilder.AddToScheme +) diff --git a/pkg/api/v1/mountpoints3podattachment_types.go b/pkg/api/v1/mountpoints3podattachment_types.go new file mode 100644 index 00000000..cb623538 --- /dev/null +++ b/pkg/api/v1/mountpoints3podattachment_types.go @@ -0,0 +1,79 @@ +package v1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// The following fields are used as matching criteria to determine if a mountpoint s3 pod can be shared by having the same MountpointS3PodAttachment resource: +const ( + FieldNodeName = "spec.nodeName" + FieldPersistentVolumeName = "spec.persistentVolumeName" + FieldVolumeID = "spec.volumeID" + FieldMountOptions = "spec.mountOptions" + FieldAuthenticationSource = "spec.authenticationSource" + FieldWorkloadFSGroup = "spec.workloadFSGroup" + FieldWorkloadServiceAccountName = "spec.workloadServiceAccountName" + FieldWorkloadNamespace = "spec.workloadNamespace" + FieldWorkloadServiceAccountIAMRoleARN = "spec.workloadServiceAccountIAMRoleARN" +) + +// MountpointS3PodAttachmentSpec defines the desired state of MountpointS3PodAttachment. +type MountpointS3PodAttachmentSpec struct { + // Important: Run "make generate" to regenerate code after modifying this file + + // Name of the node. + NodeName string `json:"nodeName"` + + // Name of the Persistent Volume. + PersistentVolumeName string `json:"persistentVolumeName"` + + // Volume ID. + VolumeID string `json:"volumeID"` + + // Comma separated mount options taken from volume. + MountOptions string `json:"mountOptions"` + + // Authentication source taken from volume attribute field `authenticationSource`. + AuthenticationSource string `json:"authenticationSource"` + + // Workload pod's `fsGroup` from pod security context + WorkloadFSGroup string `json:"workloadFSGroup"` + + // Workload pod's service account name. Exists only if `authenticationSource: pod`. + WorkloadServiceAccountName string `json:"workloadServiceAccountName,omitempty"` + + // Workload pod's namespace. Exists only if `authenticationSource: pod`. + WorkloadNamespace string `json:"workloadNamespace,omitempty"` + + // EKS IAM Role ARN from workload pod's service account annotation (IRSA). Exists only if `authenticationSource: pod` and service account has `eks.amazonaws.com/role-arn` annotation. + WorkloadServiceAccountIAMRoleARN string `json:"workloadServiceAccountIAMRoleARN,omitempty"` + + // Maps each Mountpoint S3 pod name to the list of workload pod UIDs it is attached to. + MountpointS3PodToWorkloadPodUIDs map[string][]string `json:"mountpointS3PodToWorkloadPodUIDs"` +} + +// +kubebuilder:object:root=true +// +kubebuilder:subresource:status +// +kubebuilder:resource:scope=Cluster,shortName=s3pa +// +kubebuilder:selectablefield:JSONPath=`.spec.nodeName` + +// MountpointS3PodAttachment is the Schema for the mountpoints3podattachments API. +type MountpointS3PodAttachment struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec MountpointS3PodAttachmentSpec `json:"spec,omitempty"` +} + +// +kubebuilder:object:root=true + +// MountpointS3PodAttachmentList contains a list of MountpointS3PodAttachment. +type MountpointS3PodAttachmentList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []MountpointS3PodAttachment `json:"items"` +} + +func init() { + SchemeBuilder.Register(&MountpointS3PodAttachment{}, &MountpointS3PodAttachmentList{}) +} diff --git a/pkg/api/v1/zz_generated.deepcopy.go b/pkg/api/v1/zz_generated.deepcopy.go new file mode 100644 index 00000000..32041d4a --- /dev/null +++ b/pkg/api/v1/zz_generated.deepcopy.go @@ -0,0 +1,98 @@ +//go:build !ignore_autogenerated + +// Code generated by controller-gen. DO NOT EDIT. + +package v1 + +import ( + runtime "k8s.io/apimachinery/pkg/runtime" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MountpointS3PodAttachment) DeepCopyInto(out *MountpointS3PodAttachment) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + in.Spec.DeepCopyInto(&out.Spec) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MountpointS3PodAttachment. +func (in *MountpointS3PodAttachment) DeepCopy() *MountpointS3PodAttachment { + if in == nil { + return nil + } + out := new(MountpointS3PodAttachment) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MountpointS3PodAttachment) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MountpointS3PodAttachmentList) DeepCopyInto(out *MountpointS3PodAttachmentList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]MountpointS3PodAttachment, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MountpointS3PodAttachmentList. +func (in *MountpointS3PodAttachmentList) DeepCopy() *MountpointS3PodAttachmentList { + if in == nil { + return nil + } + out := new(MountpointS3PodAttachmentList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MountpointS3PodAttachmentList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MountpointS3PodAttachmentSpec) DeepCopyInto(out *MountpointS3PodAttachmentSpec) { + *out = *in + if in.MountpointS3PodToWorkloadPodUIDs != nil { + in, out := &in.MountpointS3PodToWorkloadPodUIDs, &out.MountpointS3PodToWorkloadPodUIDs + *out = make(map[string][]string, len(*in)) + for key, val := range *in { + var outVal []string + if val == nil { + (*out)[key] = nil + } else { + inVal := (*in)[key] + in, out := &inVal, &outVal + *out = make([]string, len(*in)) + copy(*out, *in) + } + (*out)[key] = outVal + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MountpointS3PodAttachmentSpec. +func (in *MountpointS3PodAttachmentSpec) DeepCopy() *MountpointS3PodAttachmentSpec { + if in == nil { + return nil + } + out := new(MountpointS3PodAttachmentSpec) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 3b20d2d2..5d0c776f 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -23,6 +23,7 @@ import ( "os" "time" + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" @@ -31,10 +32,18 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/util" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/client-go/kubernetes" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" "k8s.io/mount-utils" + ctrlcache "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager/signals" ) const ( @@ -43,11 +52,18 @@ const ( grpcServerMaxReceiveMessageSize = 1024 * 1024 * 2 // 2MB unixSocketPerm = os.FileMode(0700) // only owner can write and read. +) +var ( + mountpointPodNamespace = os.Getenv("MOUNTPOINT_NAMESPACE") podWatcherResyncPeriod = time.Minute + scheme = runtime.NewScheme() ) -var mountpointPodNamespace = os.Getenv("MOUNTPOINT_NAMESPACE") +func init() { + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(crdv1.AddToScheme(scheme)) +} type Driver struct { Endpoint string @@ -91,13 +107,53 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error var mounterImpl mounter.Mounter if util.UsePodMounter() { - podWatcher := watcher.New(clientset, mountpointPodNamespace, podWatcherResyncPeriod) + mountUtil := mount.New("") + podWatcher := watcher.New(clientset, mountpointPodNamespace, nodeID, podWatcherResyncPeriod) err = podWatcher.Start(stopCh) if err != nil { klog.Fatalf("Failed to start Pod watcher: %v\n", err) } - mounterImpl, err = mounter.NewPodMounter(podWatcher, credProvider, mount.New(""), nil, kubernetesVersion) + s3paCache, err := ctrlcache.New(config, ctrlcache.Options{ + Scheme: scheme, + SyncPeriod: &podWatcherResyncPeriod, + ReaderFailOnMissingInformer: true, + ByObject: map[client.Object]ctrlcache.ByObject{ + &crdv1.MountpointS3PodAttachment{}: { + Field: fields.OneTermEqualSelector("spec.nodeName", nodeID), + }, + }, + }) + if err != nil { + klog.Fatalf("Failed to create cache: %v\n", err) + } + + indexMountpointS3PodAttachmentFields(s3paCache) + + s3podAttachmentInformer, err := s3paCache.GetInformer(context.Background(), &crdv1.MountpointS3PodAttachment{}) + if err != nil { + klog.Fatalf("Failed to create informer for MountpointS3PodAttachment: %v\n", err) + } + + go func() { + if err := s3paCache.Start(signals.SetupSignalHandler()); err != nil { + klog.Fatalf("Failed to start cache: %v\n", err) + } + }() + + unmounter := mounter.NewPodUnmounter(nodeID, mountUtil, podWatcher, s3paCache, credProvider) + + s3podAttachmentInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + UpdateFunc: unmounter.HandleS3PodAttachmentUpdate, + }) + + if !cache.WaitForCacheSync(stopCh, s3podAttachmentInformer.HasSynced) { + klog.Fatalf("Failed to sync informer cache within the timeout: %v\n", err) + } + + unmounter.CleanupDanglingMounts() + + mounterImpl, err = mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mountUtil, nil, kubernetesVersion) if err != nil { klog.Fatalln(err) } @@ -185,3 +241,25 @@ func kubernetesVersion(clientset *kubernetes.Clientset) (string, error) { return version.String(), nil } + +// TODO: This is duplicated multiple times +func indexMountpointS3PodAttachmentFields(s3paCache ctrlcache.Cache) { + indexField(s3paCache, crdv1.FieldNodeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) + indexField(s3paCache, crdv1.FieldPersistentVolumeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) + indexField(s3paCache, crdv1.FieldVolumeID, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) + indexField(s3paCache, crdv1.FieldMountOptions, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) + indexField(s3paCache, crdv1.FieldAuthenticationSource, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) + indexField(s3paCache, crdv1.FieldWorkloadFSGroup, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) + indexField(s3paCache, crdv1.FieldWorkloadServiceAccountName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) + indexField(s3paCache, crdv1.FieldWorkloadNamespace, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) + indexField(s3paCache, crdv1.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) +} + +func indexField(cache ctrlcache.Cache, field string, extractor func(*crdv1.MountpointS3PodAttachment) string) { + err := cache.IndexField(context.Background(), &crdv1.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { + return []string{extractor(obj.(*crdv1.MountpointS3PodAttachment))} + }) + if err != nil { + klog.Fatalf("Failed to create a %s field indexer: %v", field, err) + } +} diff --git a/pkg/driver/node/credentialprovider/provider.go b/pkg/driver/node/credentialprovider/provider.go index 06070f84..9fa9797b 100644 --- a/pkg/driver/node/credentialprovider/provider.go +++ b/pkg/driver/node/credentialprovider/provider.go @@ -62,10 +62,11 @@ type ProvideContext struct { VolumeID string // The following values are provided from CSI volume context. - AuthenticationSource AuthenticationSource - PodNamespace string - ServiceAccountTokens string - ServiceAccountName string + AuthenticationSource AuthenticationSource + PodNamespace string + ServiceAccountTokens string + ServiceAccountName string + ServiceAccountEKSRoleARN string // StsRegion is the `stsRegion` parameter passed via volume attribute. StsRegion string // BucketRegion is the `--region` parameter passed via mount options. @@ -78,6 +79,11 @@ func (ctx *ProvideContext) SetWriteAndEnvPath(writePath, envPath string) { ctx.EnvPath = envPath } +// SetServiceAccountEKSRoleARN sets `ServiceAccountEKSRoleARN` for `ctx`. +func (ctx *ProvideContext) SetServiceAccountEKSRoleARN(roleArn string) { + ctx.ServiceAccountEKSRoleARN = roleArn +} + // A CleanupContext contains parameters needed to clean up credentials after volume unmount. type CleanupContext struct { // WritePath is basepath where credentials previously written into. diff --git a/pkg/driver/node/credentialprovider/provider_pod.go b/pkg/driver/node/credentialprovider/provider_pod.go index d011a942..832fc8b8 100644 --- a/pkg/driver/node/credentialprovider/provider_pod.go +++ b/pkg/driver/node/credentialprovider/provider_pod.go @@ -113,6 +113,11 @@ func (c *Provider) cleanupFromPod(cleanupCtx CleanupContext) error { // findPodServiceAccountRole tries to provide associated AWS IAM role for service account specified in the volume context. func (c *Provider) findPodServiceAccountRole(ctx context.Context, provideCtx ProvideContext) (string, error) { + // In PodMounter we get IAM Role ARN from MountpointS3PodAttachment custom resource + if provideCtx.ServiceAccountEKSRoleARN != "" { + return provideCtx.ServiceAccountEKSRoleARN, nil + } + podNamespace := provideCtx.PodNamespace podServiceAccount := provideCtx.ServiceAccountName if podNamespace == "" || podServiceAccount == "" { diff --git a/pkg/driver/node/mounter/fake_cache.go b/pkg/driver/node/mounter/fake_cache.go new file mode 100644 index 00000000..2346be02 --- /dev/null +++ b/pkg/driver/node/mounter/fake_cache.go @@ -0,0 +1,43 @@ +package mounter + +import ( + "context" + + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type FakeCache struct{} + +func (f *FakeCache) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return nil +} + +func (f *FakeCache) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + return nil +} + +func (f *FakeCache) GetInformer(ctx context.Context, obj client.Object, opts ...cache.InformerGetOption) (cache.Informer, error) { + return nil, nil +} + +func (f *FakeCache) GetInformerForKind(ctx context.Context, gvk schema.GroupVersionKind, opts ...cache.InformerGetOption) (cache.Informer, error) { + return nil, nil +} + +func (f *FakeCache) RemoveInformer(ctx context.Context, obj client.Object) error { + return nil +} + +func (f *FakeCache) IndexField(ctx context.Context, obj client.Object, field string, extractValue client.IndexerFunc) error { + return nil +} + +func (f *FakeCache) Start(ctx context.Context) error { + return nil +} + +func (f *FakeCache) WaitForCacheSync(ctx context.Context) bool { + return true +} diff --git a/pkg/driver/node/mounter/fake_mounter.go b/pkg/driver/node/mounter/fake_mounter.go index 8cfaa603..10f07dd7 100644 --- a/pkg/driver/node/mounter/fake_mounter.go +++ b/pkg/driver/node/mounter/fake_mounter.go @@ -10,7 +10,7 @@ import ( type FakeMounter struct{} func (m *FakeMounter) Mount(ctx context.Context, bucketName string, target string, - credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { + credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error { return nil } diff --git a/pkg/driver/node/mounter/mocks/mock_mount.go b/pkg/driver/node/mounter/mocks/mock_mount.go index 27fc6fbd..c8539a9f 100644 --- a/pkg/driver/node/mounter/mocks/mock_mount.go +++ b/pkg/driver/node/mounter/mocks/mock_mount.go @@ -106,17 +106,17 @@ func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call } // Mount mocks base method. -func (m *MockMounter) Mount(ctx context.Context, bucketName, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (m *MockMounter) Mount(ctx context.Context, bucketName, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", ctx, bucketName, target, credentialCtx, args) + ret := m.ctrl.Call(m, "Mount", ctx, bucketName, target, credentialCtx, args, fsGroup, pvMountOptions) ret0, _ := ret[0].(error) return ret0 } // Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(ctx, bucketName, target, credentialCtx, args interface{}) *gomock.Call { +func (mr *MockMounterMockRecorder) Mount(ctx, bucketName, target, credentialCtx, args, fsGroup, pvMountOptions interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), ctx, bucketName, target, credentialCtx, args) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), ctx, bucketName, target, credentialCtx, args, fsGroup, pvMountOptions) } // Unmount mocks base method. diff --git a/pkg/driver/node/mounter/mounter.go b/pkg/driver/node/mounter/mounter.go index d9ac4e99..7258f3a2 100644 --- a/pkg/driver/node/mounter/mounter.go +++ b/pkg/driver/node/mounter/mounter.go @@ -5,6 +5,8 @@ import ( "context" "fmt" "os" + "strings" + "syscall" "k8s.io/klog/v2" "k8s.io/mount-utils" @@ -24,11 +26,13 @@ type ServiceRunner interface { // Mounter is an interface for mount operations type Mounter interface { - Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error + Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error Unmount(ctx context.Context, target string, credentialCtx credentialprovider.CleanupContext) error IsMountPoint(target string) (bool, error) } +// Internal S3 CSI Driver directory for source mount points +const SourceMountDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/mnt/" const MountS3PathEnv = "MOUNT_S3_PATH" const defaultMountS3Path = "/usr/bin/mount-s3" @@ -65,3 +69,67 @@ func isMountPoint(mounter mount.Interface, target string) (bool, error) { } return false, nil } + +// findSourceMountPoint locates the source S3 mount point for a given target path by comparing +// device IDs and inodes with all S3 mount points at driver source directory `SourceMountDir`. +// +// Parameters: +// - mounter: Interface providing mounting operations and mount point listing capabilities +// - target: The target path whose source mount point needs to be found +// +// Returns: +// - string: The path of the source mount point if found +// - error: An error if the operation fails +// +// The function works by: +// 1. Getting the device ID and inode of the target path +// 2. Listing all mount points in the system that has "mountpoint-s3" as device name and prefix `SourceMountDir` +// 3. Finding a mount point that matches both the device ID and inode of the target +func findSourceMountPoint(mounter mount.Interface, target string) (string, error) { + if mounter == nil { + return "", fmt.Errorf("mounter interface cannot be nil") + } + + targetFileInfo, err := os.Stat(target) + if err != nil { + return "", fmt.Errorf("failed to stat %q: %w", target, err) + } + + targetSysInfo, ok := targetFileInfo.Sys().(*syscall.Stat_t) + if !ok { + return "", fmt.Errorf("failed to get system info for target %q", target) + } + + targetDevID := targetSysInfo.Dev + targetInodeID := targetSysInfo.Ino + + mountPoints, err := mounter.List() + if err != nil { + return "", fmt.Errorf("failed to list mount points: %w", err) + } + + for _, mountPoint := range mountPoints { + if mountPoint.Device != mountpointDeviceName || !strings.HasPrefix(mountPoint.Path, SourceMountDir) { + continue + } + + mountPathInfo, err := os.Stat(mountPoint.Path) + if err != nil { + klog.V(4).Infof("Skipping mount point %q: unable to stat %v", mountPoint.Path, err) + continue + } + + mountSysInfo, ok := mountPathInfo.Sys().(*syscall.Stat_t) + if !ok { + klog.V(4).Infof("Skipping mount point %q: unable to get system info", mountPoint.Path) + continue + } + + if targetDevID == mountSysInfo.Dev && targetInodeID == mountSysInfo.Ino { + return mountPoint.Path, nil + } + } + + return "", fmt.Errorf("no source mount point found for path %q (device: %d, inode: %d)", + target, targetDevID, targetInodeID) +} diff --git a/pkg/driver/node/mounter/mppod_lock.go b/pkg/driver/node/mounter/mppod_lock.go new file mode 100644 index 00000000..e9b2f1fe --- /dev/null +++ b/pkg/driver/node/mounter/mppod_lock.go @@ -0,0 +1,55 @@ +package mounter + +import "sync" + +// MPPodLock represents a reference-counted mutex lock for Mountpoint Pod. +// It ensures synchronized access to pod-specific resources. +type MPPodLock struct { + mutex sync.Mutex + refCount int +} + +var ( + // mpPodLocks maps pod UIDs to their corresponding locks. + mpPodLocks = make(map[string]*MPPodLock) + + // mpPodLocksMutex guards access to the mpPodLocks map. + mpPodLocksMutex sync.Mutex +) + +// getMPPodLock retrieves or creates a lock for the specified pod UID. +// It increments the reference count for existing locks. +// The caller is responsible for calling releaseMPPodLock when the lock is no longer needed. +func getMPPodLock(mpPodUID string) *MPPodLock { + mpPodLocksMutex.Lock() + defer mpPodLocksMutex.Unlock() + + lock, exists := mpPodLocks[mpPodUID] + if !exists { + lock = &MPPodLock{refCount: 1} + mpPodLocks[mpPodUID] = lock + } else { + lock.refCount++ + } + return lock +} + +// releaseMPPodLock decrements the reference count for a pod's lock. +// When the reference count reaches zero, the lock is removed from the map. +// If the lock doesn't exist, the function returns silently. +func releaseMPPodLock(mpPodUID string) { + mpPodLocksMutex.Lock() + defer mpPodLocksMutex.Unlock() + + lock, exists := mpPodLocks[mpPodUID] + if !exists { + // Should never happen + return + } + + lock.refCount-- + + if lock.refCount <= 0 { + delete(mpPodLocks, mpPodUID) + } +} diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index 5cfd2799..85409466 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -15,6 +15,7 @@ import ( "k8s.io/klog/v2" "k8s.io/mount-utils" + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath" @@ -23,11 +24,10 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" "github.com/awslabs/aws-s3-csi-driver/pkg/util" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" ) -const mountpointPodReadinessTimeout = 10 * time.Second -const mountpointPodReadinessCheckInterval = 100 * time.Millisecond - // targetDirPerm is the permission to use while creating target directory if its not exists. const targetDirPerm = fs.FileMode(0755) @@ -35,21 +35,25 @@ const targetDirPerm = fs.FileMode(0755) // It returns mounted FUSE file descriptor as a result. // This is mainly exposed for testing, in production platform-native function (`mountSyscallDefault`) will be used. type mountSyscall func(target string, args mountpoint.Args) (fd int, err error) +type bindMountSyscall func(source, target string) (err error) // A PodMounter is a [Mounter] that mounts Mountpoint on pre-created Kubernetes Pod running in the same node. type PodMounter struct { podWatcher *watcher.Watcher + s3paCache cache.Cache mount mount.Interface kubeletPath string mountSyscall mountSyscall + bindMountSyscall bindMountSyscall kubernetesVersion string credProvider *credentialprovider.Provider } // NewPodMounter creates a new [PodMounter] with given Kubernetes client. -func NewPodMounter(podWatcher *watcher.Watcher, credProvider *credentialprovider.Provider, mount mount.Interface, mountSyscall mountSyscall, kubernetesVersion string) (*PodMounter, error) { +func NewPodMounter(podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvider *credentialprovider.Provider, mount mount.Interface, mountSyscall mountSyscall, kubernetesVersion string) (*PodMounter, error) { return &PodMounter{ podWatcher: podWatcher, + s3paCache: s3paCache, credProvider: credProvider, mount: mount, kubeletPath: util.KubeletPath(), @@ -69,36 +73,96 @@ func NewPodMounter(podWatcher *watcher.Watcher, credProvider *credentialprovider // 6. Wait until Mountpoint successfully mounts at `target` // // If Mountpoint is already mounted at `target`, it will return early at step 2 to ensure credentials are up-to-date. -func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup string, pvMountOptions string) error { volumeName, err := pm.volumeNameFromTargetPath(target) if err != nil { return fmt.Errorf("Failed to extract volume name from %q: %w", target, err) } - podID := credentialCtx.PodID - err = pm.verifyOrSetupMountTarget(target) if err != nil { return fmt.Errorf("Failed to verify target path can be used as a mount point %q: %w", target, err) } - isMountPoint, err := pm.IsMountPoint(target) + isTargetMountPoint, err := pm.IsMountPoint(target) if err != nil { return fmt.Errorf("Could not check if %q is already a mount point: %w", target, err) } + if isTargetMountPoint { + klog.V(4).Infof("Target path %q is already mounted. Only refreshing credentials.", target) + source, err := pm.findSourceMountPoint(target) + if err != nil { + klog.Errorf("Failed to find source mount point for %q: %v", target, err) + return fmt.Errorf("Failed to find source mount point for %q: %w", target, err) + } + mpPodUID := filepath.Base(source) + podPath := pm.podPath(mpPodUID) + + mpPodLock := getMPPodLock(mpPodUID) + mpPodLock.mutex.Lock() + defer func() { + mpPodLock.mutex.Unlock() + releaseMPPodLock(mpPodUID) + }() + + podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) + if err != nil { + klog.Errorf("Failed to create credentials directory for %q: %v", target, err) + return fmt.Errorf("Failed to create credentials directory for %q: %w", target, err) + } + + credentialCtx.SetWriteAndEnvPath(podCredentialsPath, mppod.PathInsideMountpointPod(mppod.KnownPathCredentials)) + + _, _, err = pm.credProvider.Provide(ctx, credentialCtx) + if err != nil { + klog.Errorf("Failed to provide credentials for %s: %v\n%s", target, err, "TODO: pm.helpMessageForGettingMountpointLogs(pod)") + return fmt.Errorf("Failed to provide credentials for %q: %w\n%s", target, err, "TODO: pm.helpMessageForGettingMountpointLogs(pod)") + } + + return nil + } + + s3PodAttachment, mpPodName, err := pm.getS3PodAttachmentWithRetry(ctx, volumeName, credentialCtx, fsGroup, pvMountOptions) + if err != nil { + klog.Errorf("Failed to find corresponding MountpointS3PodAttachment custom resource %q: %v", target, err) + return fmt.Errorf("Failed to find corresponding MountpointS3PodAttachment custom resource %q: %w", target, err) + } + + if s3PodAttachment.Spec.WorkloadServiceAccountIAMRoleARN != "" { + credentialCtx.SetServiceAccountEKSRoleARN(s3PodAttachment.Spec.WorkloadServiceAccountIAMRoleARN) + } + // TODO: If `target` is a `systemd`-mounted Mountpoint, this would return an error, // but we should still update the credentials for it by calling `credProvider.Provide`. - pod, podPath, err := pm.waitForMountpointPod(ctx, podID, volumeName) + pod, podPath, err := pm.waitForMountpointPod(ctx, mpPodName) if err != nil { klog.Errorf("Failed to wait for Mountpoint Pod to be ready for %q: %v", target, err) return fmt.Errorf("Failed to wait for Mountpoint Pod to be ready for %q: %w", target, err) } + mpPodUID := string(pod.UID) + mpPodLock := getMPPodLock(mpPodUID) + mpPodLock.mutex.Lock() + defer func() { + mpPodLock.mutex.Unlock() + releaseMPPodLock(mpPodUID) + }() + + source := filepath.Join(SourceMountDir, mpPodUID) + err = pm.verifyOrSetupMountTarget(source) + if err != nil { + return fmt.Errorf("Failed to verify source path can be used as a mount point %q: %w", source, err) + } + + isSourceMountPoint, err := pm.IsMountPoint(source) + if err != nil { + return fmt.Errorf("Could not check if %q is already a mount point: %w", source, err) + } podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) if err != nil { - klog.Errorf("Failed to create credentials directory for %q: %v", target, err) - return fmt.Errorf("Failed to create credentials directory for %q: %w", target, err) + klog.Errorf("Failed to create credentials directory for %q: %v", source, err) + return fmt.Errorf("Failed to create credentials directory for %q: %w", source, err) } credentialCtx.SetWriteAndEnvPath(podCredentialsPath, mppod.PathInsideMountpointPod(mppod.KnownPathCredentials)) @@ -107,126 +171,98 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin // there is an existing mount point at `target`. credEnv, authenticationSource, err := pm.credProvider.Provide(ctx, credentialCtx) if err != nil { - klog.Errorf("Failed to provide credentials for %s: %v\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to provide credentials for %q: %w\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) + klog.Errorf("Failed to provide credentials for %s: %v\n%s", source, err, pm.helpMessageForGettingMountpointLogs(pod)) + return fmt.Errorf("Failed to provide credentials for %q: %w\n%s", source, err, pm.helpMessageForGettingMountpointLogs(pod)) } - if isMountPoint { - klog.V(4).Infof("Target path %q is already mounted", target) - return nil - } - - env := envprovider.Default() - env.Merge(credEnv) + if !isSourceMountPoint { + env := envprovider.Default() + env.Merge(credEnv) - // Move `--aws-max-attempts` to env if provided - if maxAttempts, ok := args.Remove(mountpoint.ArgAWSMaxAttempts); ok { - env.Set(envprovider.EnvMaxAttempts, maxAttempts) - } + // Move `--aws-max-attempts` to env if provided + if maxAttempts, ok := args.Remove(mountpoint.ArgAWSMaxAttempts); ok { + env.Set(envprovider.EnvMaxAttempts, maxAttempts) + } - args.Set(mountpoint.ArgUserAgentPrefix, UserAgent(authenticationSource, pm.kubernetesVersion)) + args.Set(mountpoint.ArgUserAgentPrefix, UserAgent(authenticationSource, pm.kubernetesVersion)) - podMountSockPath := mppod.PathOnHost(podPath, mppod.KnownPathMountSock) - podMountErrorPath := mppod.PathOnHost(podPath, mppod.KnownPathMountError) + podMountSockPath := mppod.PathOnHost(podPath, mppod.KnownPathMountSock) + podMountErrorPath := mppod.PathOnHost(podPath, mppod.KnownPathMountError) - klog.V(4).Infof("Mounting %s for %s", target, pod.Name) + klog.V(4).Infof("Mounting %s for %s", source, pod.Name) - fuseDeviceFD, err := pm.mountSyscallWithDefault(target, args) - if err != nil { - klog.Errorf("Failed to mount %s: %v", target, err) - return fmt.Errorf("Failed to mount %s: %w", target, err) - } - - // Remove the read-only argument from the list as mount-s3 does not support it when using FUSE - // file descriptor (we already pass MS_RDONLY flag during mount syscall in `pod_mounter_linux.go`) - if args.Has(mountpoint.ArgReadOnly) { - args.Remove(mountpoint.ArgReadOnly) - } + fuseDeviceFD, err := pm.mountSyscallWithDefault(source, args) + if err != nil { + klog.Errorf("Failed to mount %s: %v", source, err) + return fmt.Errorf("Failed to mount %s: %w", source, err) + } - // This will set to false in the success condition. This is set to `true` by default to - // ensure we don't leave `target` mounted if Mountpoint is not started to serve requests for it. - unmount := true - defer func() { - if unmount { - if err := pm.unmountTarget(target); err != nil { - klog.V(4).ErrorS(err, "Failed to unmount mounted target %s\n", target) - } else { - klog.V(4).Infof("Target %s unmounted successfully\n", target) - } + // Remove the read-only argument from the list as mount-s3 does not support it when using FUSE + // file descriptor (we already pass MS_RDONLY flag during mount syscall in `pod_mounter_linux.go`) + if args.Has(mountpoint.ArgReadOnly) { + args.Remove(mountpoint.ArgReadOnly) } - }() - // This function can either fail or successfully send mount options to Mountpoint Pod - in which - // Mountpoint Pod will get its own fd referencing the same underlying file description. - // In both case we need to close the fd in this process. - defer pm.closeFUSEDevFD(fuseDeviceFD) + // This will set to false in the success condition. This is set to `true` by default to + // ensure we don't leave `source` mounted if Mountpoint is not started to serve requests for it. + unmount := true + defer func() { + if unmount { + if err := pm.unmountTarget(source); err != nil { + klog.V(4).ErrorS(err, "Failed to unmount mounted source %s\n", source) + } else { + klog.V(4).Infof("Source %s unmounted successfully\n", source) + } + } + }() - // Remove old mount error file if exists - _ = os.Remove(podMountErrorPath) + // This function can either fail or successfully send mount options to Mountpoint Pod - in which + // Mountpoint Pod will get its own fd referencing the same underlying file description. + // In both case we need to close the fd in this process. + defer pm.closeFUSEDevFD(fuseDeviceFD) - klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", pod.Name, podMountSockPath) + // Remove old mount error file if exists + _ = os.Remove(podMountErrorPath) - err = mountoptions.Send(ctx, podMountSockPath, mountoptions.Options{ - Fd: fuseDeviceFD, - BucketName: bucketName, - Args: args.SortedList(), - Env: env.List(), - }) - if err != nil { - klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) - } + klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", pod.Name, podMountSockPath) - err = pm.waitForMount(ctx, target, pod.Name, podMountErrorPath) - if err != nil { - klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) - } + err = mountoptions.Send(ctx, podMountSockPath, mountoptions.Options{ + Fd: fuseDeviceFD, + BucketName: bucketName, + Args: args.SortedList(), + Env: env.List(), + }) + if err != nil { + klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + } - // Mountpoint successfully started, so don't unmount the filesystem - unmount = false - return nil -} + err = pm.waitForMount(ctx, source, pod.Name, podMountErrorPath) + if err != nil { + klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + } -// Unmount unmounts the mount point at `target` and cleans all credentials. -func (pm *PodMounter) Unmount(ctx context.Context, target string, credentialCtx credentialprovider.CleanupContext) error { - volumeName, err := pm.volumeNameFromTargetPath(target) - if err != nil { - return fmt.Errorf("Failed to extract volume name from %q: %w", target, err) + // Mountpoint successfully started, so don't unmount the filesystem + unmount = false } - podID := credentialCtx.PodID - - // TODO: If `target` is a `systemd`-mounted Mountpoint, this would return an error, - // but we should still unmount it and clean the credentials. - pod, podPath, err := pm.waitForMountpointPod(ctx, podID, volumeName) + err = pm.bindMountSyscallWithDefault(source, target) if err != nil { - klog.Errorf("Failed to wait for Mountpoint Pod to be ready for %q: %v", target, err) - return fmt.Errorf("Failed to wait for Mountpoint Pod for %q: %w", target, err) + klog.Errorf("Failed to bind mount %s to target %s: %v", source, target, err) + return fmt.Errorf("Failed to bind mount %s to target %s: %w", source, target, err) } - credentialCtx.WritePath = pm.credentialsDir(podPath) - - // Write `mount.exit` file to indicate Mountpoint Pod to cleanly exit. - podMountExitPath := mppod.PathOnHost(podPath, mppod.KnownPathMountExit) - _, err = os.OpenFile(podMountExitPath, os.O_RDONLY|os.O_CREATE, credentialprovider.CredentialFilePerm) - if err != nil { - klog.Errorf("Failed to send a exit message to Mountpoint Pod for %q: %s\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to send a exit message to Mountpoint Pod for %q: %w\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - } + return nil +} - err = pm.unmountTarget(target) +// Unmount unmounts only the bind mount point at `target`. +func (pm *PodMounter) Unmount(_ context.Context, target string, _ credentialprovider.CleanupContext) error { + err := pm.unmountTarget(target) if err != nil { klog.Errorf("Failed to unmount %q: %v", target, err) return fmt.Errorf("Failed to unmount %q: %w", target, err) } - - err = pm.credProvider.Cleanup(credentialCtx) - if err != nil { - klog.Errorf("Failed to clean up credentials for %s: %v\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to clean up credentials for %q: %w\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - } - return nil } @@ -236,11 +272,14 @@ func (pm *PodMounter) IsMountPoint(target string) (bool, error) { return isMountPoint(pm.mount, target) } -// waitForMountpointPod waints until Mountpoint Pod for given `podID` and `volumeName` is in `Running` state. -// It returns found Mountpoint Pod and it's base directory. -func (pm *PodMounter) waitForMountpointPod(ctx context.Context, podID, volumeName string) (*corev1.Pod, string, error) { - podName := mppod.MountpointPodNameFor(podID, volumeName) +// findSourceMountPoint calls `findSourceMountPoint` on `target`. +func (pm *PodMounter) findSourceMountPoint(target string) (string, error) { + return findSourceMountPoint(pm.mount, target) +} +// waitForMountpointPod waits until Mountpoint Pod for given `podName` is in `Running` state. +// It returns found Mountpoint Pod and it's base directory. +func (pm *PodMounter) waitForMountpointPod(ctx context.Context, podName string) (*corev1.Pod, string, error) { pod, err := pm.podWatcher.Wait(ctx, podName) if err != nil { return nil, "", err @@ -248,7 +287,7 @@ func (pm *PodMounter) waitForMountpointPod(ctx context.Context, podID, volumeNam klog.V(4).Infof("Mountpoint Pod %s/%s is running with id %s", pod.Namespace, podName, pod.UID) - return pod, pm.podPath(pod), nil + return pod, pm.podPath(string(pod.UID)), nil } // waitForMount waits until Mountpoint is successfully mounted at `target`. @@ -354,8 +393,8 @@ func (pm *PodMounter) credentialsDir(podPath string) string { } // podPath returns `pod`'s basepath inside kubelet's path. -func (pm *PodMounter) podPath(pod *corev1.Pod) string { - return filepath.Join(pm.kubeletPath, "pods", string(pod.UID)) +func (pm *PodMounter) podPath(podUID string) string { + return filepath.Join(pm.kubeletPath, "pods", podUID) } // mountSyscallWithDefault delegates to `mountSyscall` if set, or fallbacks to platform-native `mountSyscallDefault`. @@ -367,6 +406,15 @@ func (pm *PodMounter) mountSyscallWithDefault(target string, args mountpoint.Arg return pm.mountSyscallDefault(target, args) } +// bindMountWithDefault delegates to `bindMountSyscall` if set, or fallbacks to platform-native `bindMountSyscallDefault`. +func (pm *PodMounter) bindMountSyscallWithDefault(source, target string) error { + if pm.bindMountSyscall != nil { + return pm.bindMountSyscall(source, target) + } + + return pm.bindMountSyscallDefault(source, target) +} + // unmountTarget calls `unmount` syscall on `target`. func (pm *PodMounter) unmountTarget(target string) error { return pm.mount.Unmount(target) @@ -384,3 +432,48 @@ func (pm *PodMounter) volumeNameFromTargetPath(target string) (string, error) { func (pm *PodMounter) helpMessageForGettingMountpointLogs(pod *corev1.Pod) string { return fmt.Sprintf("You can see Mountpoint logs by running: `kubectl logs -n %s %s`. If the Mountpoint Pod already restarted, you can also pass `--previous` to get logs from the previous run.", pod.Namespace, pod.Name) } + +func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeName string, credentialCtx credentialprovider.ProvideContext, fsGroup, pvMountOptions string) (*crdv1.MountpointS3PodAttachment, string, error) { + fieldFilters := client.MatchingFields{ + crdv1.FieldNodeName: os.Getenv("CSI_NODE_NAME"), // TODO + crdv1.FieldPersistentVolumeName: volumeName, + crdv1.FieldVolumeID: credentialCtx.VolumeID, + crdv1.FieldMountOptions: pvMountOptions, + crdv1.FieldWorkloadFSGroup: fsGroup, + crdv1.FieldAuthenticationSource: credentialCtx.AuthenticationSource, + } + if credentialCtx.AuthenticationSource == credentialprovider.AuthenticationSourcePod { + fieldFilters[crdv1.FieldWorkloadNamespace] = credentialCtx.PodNamespace + fieldFilters[crdv1.FieldWorkloadServiceAccountName] = credentialCtx.ServiceAccountName + } + + for { + select { + case <-ctx.Done(): + return nil, "", ctx.Err() + default: + } + + s3paList := &crdv1.MountpointS3PodAttachmentList{} + err := pm.s3paCache.List(ctx, s3paList, fieldFilters) + if err != nil { + klog.Errorf("Failed to list MountpointS3PodAttachments: %v", err) + return nil, "", err + } + for _, s3pa := range s3paList.Items { + for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for _, uid := range uids { + if uid == credentialCtx.PodID { + return &s3pa, mpPodName, nil + } + } + } + } + + select { + case <-ctx.Done(): + return nil, "", ctx.Err() + case <-time.After(100 * time.Millisecond): + } + } +} diff --git a/pkg/driver/node/mounter/pod_mounter_darwin.go b/pkg/driver/node/mounter/pod_mounter_darwin.go index 5ca4f95d..5511f49b 100644 --- a/pkg/driver/node/mounter/pod_mounter_darwin.go +++ b/pkg/driver/node/mounter/pod_mounter_darwin.go @@ -11,6 +11,10 @@ func (pm *PodMounter) mountSyscallDefault(_ string, _ mountpoint.Args) (int, err return 0, errors.New("Only supported on Linux") } +func (pm *PodMounter) bindMountSyscallDefault(source, target string) error { + return errors.New("Only supported on Linux") +} + func verifyMountPointStatx(path string) error { // statx is a Linux-specific syscall, let's simulate with os.Stat _, err := os.Stat(path) diff --git a/pkg/driver/node/mounter/pod_mounter_linux.go b/pkg/driver/node/mounter/pod_mounter_linux.go index eeeccc4a..392f4c9b 100644 --- a/pkg/driver/node/mounter/pod_mounter_linux.go +++ b/pkg/driver/node/mounter/pod_mounter_linux.go @@ -65,6 +65,14 @@ func (pm *PodMounter) mountSyscallDefault(target string, args mountpoint.Args) ( return fd, nil } +// bindMountSyscallDefault performs a bind mount syscall from `source` to `target`. +func (pm *PodMounter) bindMountSyscallDefault(source, target string) error { + if err := unix.Mount(source, target, "", unix.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to bind mount from %s to %s: %v", source, target, err) + } + return nil +} + func verifyMountPointStatx(path string) error { var stat unix.Statx_t if err := unix.Statx(unix.AT_FDCWD, path, unix.AT_STATX_FORCE_SYNC, 0, &stat); err != nil { diff --git a/pkg/driver/node/mounter/pod_mounter_test.go b/pkg/driver/node/mounter/pod_mounter_test.go index bdcae9c8..6f35c0e7 100644 --- a/pkg/driver/node/mounter/pod_mounter_test.go +++ b/pkg/driver/node/mounter/pod_mounter_test.go @@ -42,14 +42,18 @@ type testCtx struct { client *fake.Clientset mount *mount.FakeMounter + s3paCache *mounter.FakeCache mountSyscall func(target string, args mountpoint.Args) (fd int, err error) - bucketName string - kubeletPath string - targetPath string - podUID string - volumeID string - pvName string + bucketName string + kubeletPath string + targetPath string + podUID string + volumeID string + pvName string + nodeName string + fsGroup string + pvMountOptions string } func setup(t *testing.T) *testCtx { @@ -67,6 +71,10 @@ func setup(t *testing.T) *testCtx { podUID := uuid.New().String() volumeID := "s3-csi-driver-volume" pvName := "s3-csi-driver-pv" + nodeName := "test-node" + fsGroup := "1000" + pvMountOptions := "--fake-mountoption" + s3paCache := &mounter.FakeCache{} targetPath := filepath.Join( kubeletPath, fmt.Sprintf("pods/%s/volumes/kubernetes.io~csi/%s/mount", podUID, pvName), @@ -86,16 +94,20 @@ func setup(t *testing.T) *testCtx { mount := mount.NewFakeMounter(nil) testCtx := &testCtx{ - t: t, - ctx: ctx, - client: client, - mount: mount, - bucketName: bucketName, - kubeletPath: kubeletPath, - targetPath: targetPath, - podUID: podUID, - volumeID: volumeID, - pvName: pvName, + t: t, + ctx: ctx, + client: client, + mount: mount, + bucketName: bucketName, + kubeletPath: kubeletPath, + targetPath: targetPath, + podUID: podUID, + volumeID: volumeID, + pvName: pvName, + nodeName: nodeName, + fsGroup: fsGroup, + s3paCache: s3paCache, + pvMountOptions: pvMountOptions, } mountSyscall := func(target string, args mountpoint.Args) (fd int, err error) { @@ -111,7 +123,7 @@ func setup(t *testing.T) *testCtx { return dummyIMDSRegion, nil }) - podWatcher := watcher.New(client, mountpointPodNamespace, 10*time.Second) + podWatcher := watcher.New(client, mountpointPodNamespace, nodeName, 10*time.Second) stopCh := make(chan struct{}) t.Cleanup(func() { close(stopCh) @@ -119,7 +131,7 @@ func setup(t *testing.T) *testCtx { err = podWatcher.Start(stopCh) assert.NoError(t, err) - podMounter, err := mounter.NewPodMounter(podWatcher, credProvider, mount, mountSyscall, testK8sVersion) + podMounter, err := mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mount, mountSyscall, testK8sVersion) assert.NoError(t, err) testCtx.podMounter = podMounter @@ -154,7 +166,7 @@ func TestPodMounter(t *testing.T) { AuthenticationSource: credentialprovider.AuthenticationSourceDriver, VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, args) + }, args, testCtx.fsGroup, testCtx.pvMountOptions) if err != nil { log.Println("Mount failed", err) } @@ -200,7 +212,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) }) @@ -214,7 +226,7 @@ func TestPodMounter(t *testing.T) { AuthenticationSource: credentialprovider.AuthenticationSourceDriver, VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, args) + }, args, testCtx.fsGroup, testCtx.pvMountOptions) if err != nil { log.Println("Mount failed", err) } @@ -254,7 +266,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) } @@ -277,7 +289,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) if err == nil { t.Errorf("mount shouldn't succeeded if Mountpoint does not receive the mount options") } @@ -311,7 +323,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) if err == nil { t.Errorf("mount shouldn't succeeded if Mountpoint fails to start") } @@ -346,7 +358,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) if err == nil { t.Errorf("mount shouldn't succeeded if Mountpoint fails to start") } @@ -379,7 +391,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) ok, err = testCtx.podMounter.IsMountPoint(testCtx.targetPath) @@ -399,7 +411,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) ok, err := testCtx.podMounter.IsMountPoint(testCtx.targetPath) @@ -430,8 +442,7 @@ func createMountpointPod(testCtx *testCtx) *mountpointPod { pod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(uuid.New().String()), - Name: mppod.MountpointPodNameFor(testCtx.podUID, testCtx.pvName), + UID: types.UID(uuid.New().String()), }, } pod, err := testCtx.client.CoreV1().Pods(mountpointPodNamespace).Create(context.TODO(), pod, metav1.CreateOptions{}) diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go new file mode 100644 index 00000000..1b75ee92 --- /dev/null +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -0,0 +1,197 @@ +package mounter + +import ( + "context" + "fmt" + "os" + "path/filepath" + + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" + "github.com/awslabs/aws-s3-csi-driver/pkg/util" + corev1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" + "k8s.io/mount-utils" + "sigs.k8s.io/controller-runtime/pkg/cache" +) + +type PodUnmounter struct { + nodeID string + mountUtil mount.Interface + podWatcher *watcher.Watcher + s3paCache cache.Cache + credProvider *credentialprovider.Provider +} + +func NewPodUnmounter( + nodeID string, + mountUtil mount.Interface, + podWatcher *watcher.Watcher, + s3paCache cache.Cache, + credProvider *credentialprovider.Provider, +) *PodUnmounter { + return &PodUnmounter{ + nodeID: nodeID, + mountUtil: mountUtil, + podWatcher: podWatcher, + s3paCache: s3paCache, + credProvider: credProvider, + } +} + +func (u *PodUnmounter) HandleS3PodAttachmentUpdate(old, new any) { + s3pa := new.(*crdv1.MountpointS3PodAttachment) + if s3pa.Spec.NodeName != u.nodeID { + return + } + + for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + if len(uids) == 0 { + u.unmountSourceForPod(s3pa, mpPodName) + } + } +} + +func (u *PodUnmounter) unmountSourceForPod(s3pa *crdv1.MountpointS3PodAttachment, mpPodName string) { + klog.Infof("Found Mountpoint pod with zero workload pods, unmounting it - %s", mpPodName) + mpPod, err := u.podWatcher.Get(mpPodName) + if err != nil { + klog.Infof("failed to find mpPodName during update event") + return + } + + mpPodUID := string(mpPod.UID) + podPath := filepath.Join(util.KubeletPath(), "pods", mpPodUID) + source := filepath.Join(SourceMountDir, mpPodUID) + + if err := u.writeExitFile(podPath, mpPod); err != nil { + return + } + + if err := u.unmountAndCleanup(source); err != nil { + return + } + klog.Infof("Successfully unmounted Mountpoint Pod - %s", mpPodName) + + if err := u.cleanupCredentials(s3pa, mpPodUID, podPath, source, mpPod); err != nil { + return + } +} + +func (u *PodUnmounter) writeExitFile(podPath string, mpPod *corev1.Pod) error { + podMountExitPath := mppod.PathOnHost(podPath, mppod.KnownPathMountExit) + _, err := os.OpenFile(podMountExitPath, os.O_RDONLY|os.O_CREATE, credentialprovider.CredentialFilePerm) + if err != nil { + klog.Errorf("Failed to send a exit message to Mountpoint Pod: %s", err) + return err + } + return nil +} + +func (u *PodUnmounter) unmountAndCleanup(source string) error { + if err := u.mountUtil.Unmount(source); err != nil { + klog.Errorf("Failed to unmount source %q: %v", source, err) + return err + } + + if err := os.Remove(source); err != nil { + klog.Errorf("Failed to remove source directory %q: %v", source, err) + return err + } + return nil +} + +func (u *PodUnmounter) cleanupCredentials(s3pa *crdv1.MountpointS3PodAttachment, mpPodUID, podPath, source string, mpPod *corev1.Pod) error { + err := u.credProvider.Cleanup(credentialprovider.CleanupContext{ + VolumeID: s3pa.Spec.VolumeID, + PodID: mpPodUID, + WritePath: filepath.Join(util.KubeletPath(), "pods", mpPodUID), + }) + if err != nil { + klog.Errorf("Failed to clean up credentials for %s: %v", source, err) + return err + } + return nil +} + +func (u *PodUnmounter) CleanupDanglingMounts() { + entries, err := os.ReadDir(SourceMountDir) + if err != nil { + klog.Errorf("Failed to read source mount directory (`%s`): %v", SourceMountDir, err) + return + } + + for _, file := range entries { + if !file.IsDir() { + continue + } + + mpPodUID := file.Name() + source := filepath.Join(SourceMountDir, mpPodUID) + // Try to find corresponding pod + mpPod, err := u.findPodByUID(mpPodUID) + if err != nil { + klog.V(4).Infof("Mountpoint Pod not found for UID %s, will only unmount and delete folder: %v", mpPodUID, err) + if err := u.unmountAndCleanup(source); err != nil { + klog.Errorf("Failed to cleanup dangling mount for Mountpoint Pod %s: %v", mpPod.Name, err) + } + continue + } + + // Check if pod has an S3PodAttachment + hasWorkloads, err := u.checkForWorkloads(mpPod) + if err != nil { + klog.Errorf("Failed to check workloads for Mountpoint Pod %s: %v", mpPod.Name, err) + continue + } + + if !hasWorkloads { + klog.Infof("Found dangling mount for Mountpoint Pod %s (UID: %s), cleaning up", mpPod.Name, mpPodUID) + podPath := filepath.Join(util.KubeletPath(), "pods", mpPodUID) + if err := u.writeExitFile(podPath, mpPod); err != nil { + return + } + + if err := u.unmountAndCleanup(source); err != nil { + klog.Errorf("Failed to cleanup dangling mount for Mountpoint Pod %s: %v", mpPod.Name, err) + continue + } + + // TODO: Skip credential clean up as we do not know volumeID OR delete all files in credential folder? + } + } +} + +func (u *PodUnmounter) findPodByUID(mpPodUID string) (*corev1.Pod, error) { + pods, err := u.podWatcher.List() + if err != nil { + return nil, err + } + + for _, pod := range pods { + if string(pod.UID) == mpPodUID { + return pod, nil + } + } + return nil, fmt.Errorf("Mountpoint Pod not found for UID %s", mpPodUID) +} + +func (u *PodUnmounter) checkForWorkloads(mpPod *corev1.Pod) (bool, error) { + s3paList := &crdv1.MountpointS3PodAttachmentList{} + err := u.s3paCache.List(context.Background(), s3paList) + if err != nil { + return false, err + } + + // Find attachment for this pod and check if it has workloads + for _, s3pa := range s3paList.Items { + for mpPodName, workloadUIDs := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + if mpPodName == mpPod.Name { + return len(workloadUIDs) > 0, nil + } + } + } + return false, nil +} diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index f7552eeb..40c9abe4 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -55,7 +55,7 @@ func (m *SystemdMounter) IsMountPoint(target string) (bool, error) { // // This method will create the target path if it does not exist and if there is an existing corrupt // mount, it will attempt an unmount before attempting the mount. -func (m *SystemdMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (m *SystemdMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, _, _ string) error { if bucketName == "" { return fmt.Errorf("bucket name is empty") } diff --git a/pkg/driver/node/mounter/systemd_mounter_test.go b/pkg/driver/node/mounter/systemd_mounter_test.go index 4a013689..45c2ab08 100644 --- a/pkg/driver/node/mounter/systemd_mounter_test.go +++ b/pkg/driver/node/mounter/systemd_mounter_test.go @@ -8,6 +8,8 @@ import ( "strings" "testing" + "slices" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" @@ -16,7 +18,6 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" "github.com/golang/mock/gomock" "k8s.io/mount-utils" - "slices" ) type mounterTestEnv struct { @@ -150,7 +151,7 @@ func TestS3MounterMount(t *testing.T) { testCase.before(t, env) } err := env.mounter.Mount(env.ctx, testCase.bucketName, testCase.targetPath, - testCase.provideCtx, mountpoint.ParseArgs(testCase.options)) + testCase.provideCtx, mountpoint.ParseArgs(testCase.options), "", "") env.mockCtl.Finish() if err != nil && !testCase.expectedErr { t.Fatal(err) diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index 121eeaad..946d082f 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -124,8 +124,11 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl args := mountpoint.ParseArgs(mountpointArgs) + fsGroup := "" + pvMountOptions := "" if capMount := volCap.GetMount(); capMount != nil && util.UsePodMounter() { if volumeMountGroup := capMount.GetVolumeMountGroup(); volumeMountGroup != "" { + fsGroup = volumeMountGroup // We need to add the following flags to support fsGroup // If these flags were already set by customer in PV mountOptions then we won't override them args.SetIfAbsent(mountpoint.ArgGid, volumeMountGroup) @@ -133,6 +136,10 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl args.SetIfAbsent(mountpoint.ArgDirMode, filePerm770) args.SetIfAbsent(mountpoint.ArgFileMode, filePerm660) } + + if mountFlags := capMount.GetMountFlags(); mountFlags != nil { + pvMountOptions = strings.Join(mountFlags, ",") + } } if util.UsePodMounter() && !args.Has(mountpoint.ArgAllowOther) { @@ -144,7 +151,7 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl credentialCtx := credentialProvideContextFromPublishRequest(req, args) - if err := ns.Mounter.Mount(ctx, bucket, target, credentialCtx, args); err != nil { + if err := ns.Mounter.Mount(ctx, bucket, target, credentialCtx, args, fsGroup, pvMountOptions); err != nil { os.Remove(target) return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } @@ -257,12 +264,17 @@ func credentialProvideContextFromPublishRequest(req *csi.NodePublishVolumeReques podID, _ = podIDFromTargetPath(req.GetTargetPath()) } + authSource := credentialprovider.AuthenticationSourceDriver + if volumeCtx[volumecontext.AuthenticationSource] != "" { + authSource = volumeCtx[volumecontext.AuthenticationSource] + } + bucketRegion, _ := args.Value(mountpoint.ArgRegion) return credentialprovider.ProvideContext{ PodID: podID, VolumeID: req.GetVolumeId(), - AuthenticationSource: volumeCtx[volumecontext.AuthenticationSource], + AuthenticationSource: authSource, PodNamespace: volumeCtx[volumecontext.CSIPodNamespace], ServiceAccountTokens: volumeCtx[volumecontext.CSIServiceAccountTokens], ServiceAccountName: volumeCtx[volumecontext.CSIServiceAccountName], diff --git a/pkg/driver/node/node_test.go b/pkg/driver/node/node_test.go index ad95fb38..52bee130 100644 --- a/pkg/driver/node/node_test.go +++ b/pkg/driver/node/node_test.go @@ -3,6 +3,7 @@ package node_test import ( "errors" "io/fs" + "strings" "testing" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -69,9 +70,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Any()) + gomock.Any(), + gomock.Eq(""), + gomock.Eq(""), + ) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -104,9 +109,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"}))) + gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"})), + gomock.Eq(""), + gomock.Eq(""), + ) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -142,9 +151,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"}))) + gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"})), + gomock.Eq(""), + gomock.Eq(""), + ) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -180,9 +193,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--read-only", "--test=123"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--read-only", "--test=123"})), + gomock.Eq(""), + gomock.Eq(""), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -254,9 +271,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"})), + gomock.Eq("123"), + gomock.Eq(""), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -292,9 +313,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"})), + gomock.Eq("123"), + gomock.Eq("--allow-other"), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -330,9 +355,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--allow-root"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--allow-root"})), + gomock.Eq(""), + gomock.Eq(""), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -368,9 +397,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--allow-other"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--allow-other"})), + gomock.Eq(""), + gomock.Eq("--allow-other"), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -407,9 +440,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs(mountFlags))).Return(nil) + gomock.Eq(mountpoint.ParseArgs(mountFlags)), + gomock.Eq("123"), + gomock.Eq(strings.Join(mountFlags, ",")), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -572,7 +609,7 @@ var _ mounter.Mounter = &dummyMounter{} type dummyMounter struct{} -func (d *dummyMounter) Mount(ctx context.Context, bucketName string, target string, provideCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (d *dummyMounter) Mount(ctx context.Context, bucketName string, target string, provideCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error { return nil } diff --git a/pkg/podmounter/mppod/creator.go b/pkg/podmounter/mppod/creator.go index 64daa33c..cc3bef11 100644 --- a/pkg/podmounter/mppod/creator.go +++ b/pkg/podmounter/mppod/creator.go @@ -14,7 +14,6 @@ import ( // Labels populated on spawned Mountpoint Pods. const ( LabelMountpointVersion = "s3.csi.aws.com/mountpoint-version" - LabelPodUID = "s3.csi.aws.com/pod-uid" LabelVolumeName = "s3.csi.aws.com/volume-name" LabelCSIDriverVersion = "s3.csi.aws.com/mounted-by-csi-driver-version" ) @@ -46,21 +45,14 @@ func NewCreator(config Config) *Creator { return &Creator{config: config} } -// Create returns a new Mountpoint Pod spec to schedule for given `pod` and `pv`. -// -// It automatically assigns Mountpoint Pod to `pod`'s node. -// The name of the Mountpoint Pod is consistently generated from `pod` and `pv` using `MountpointPodNameFor` function. -func (c *Creator) Create(pod *corev1.Pod, pv *corev1.PersistentVolume) *corev1.Pod { - node := pod.Spec.NodeName - name := MountpointPodNameFor(string(pod.UID), pv.Name) - +// Create returns a new Mountpoint Pod spec to schedule for given `node` and `pv`. +func (c *Creator) Create(node string, pv *corev1.PersistentVolume) *corev1.Pod { mpPod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: c.config.Namespace, + GenerateName: "mp-", + Namespace: c.config.Namespace, Labels: map[string]string{ LabelMountpointVersion: c.config.MountpointVersion, - LabelPodUID: string(pod.UID), LabelVolumeName: pv.Name, LabelCSIDriverVersion: c.config.CSIDriverVersion, }, @@ -133,7 +125,7 @@ func (c *Creator) Create(pod *corev1.Pod, pv *corev1.PersistentVolume) *corev1.P }, } - volumeAttributes := extractVolumeAttributes(pv) + volumeAttributes := ExtractVolumeAttributes(pv) if saName := volumeAttributes[volumecontext.MountpointPodServiceAccountName]; saName != "" { mpPod.Spec.ServiceAccountName = saName @@ -142,9 +134,9 @@ func (c *Creator) Create(pod *corev1.Pod, pv *corev1.PersistentVolume) *corev1.P return mpPod } -// extractVolumeAttributes extracts volume attributes from given `pv`. +// ExtractVolumeAttributes extracts volume attributes from given `pv`. // It always returns a non-nil map, and it's safe to use even though `pv` doesn't contain any volume attributes. -func extractVolumeAttributes(pv *corev1.PersistentVolume) map[string]string { +func ExtractVolumeAttributes(pv *corev1.PersistentVolume) map[string]string { csiSpec := pv.Spec.CSI if csiSpec == nil { return map[string]string{} diff --git a/pkg/podmounter/mppod/creator_test.go b/pkg/podmounter/mppod/creator_test.go index 62266deb..13fb31ad 100644 --- a/pkg/podmounter/mppod/creator_test.go +++ b/pkg/podmounter/mppod/creator_test.go @@ -5,7 +5,6 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" @@ -45,14 +44,13 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu creator := mppod.NewCreator(createTestConfig(clusterVariant)) verifyDefaultValues := func(mpPod *corev1.Pod) { - // This is a hash of `testPodUID` + `testVolName` - assert.Equals(t, "mp-8ef7856a0c7f1d5706bd6af93fdc4bc90b33cf2ceb6769b4afd62586", mpPod.Name) + assert.Equals(t, "mp-", mpPod.GenerateName) + assert.Equals(t, "", mpPod.Name) assert.Equals(t, namespace, mpPod.Namespace) assert.Equals(t, map[string]string{ mppod.LabelMountpointVersion: mountpointVersion, - mppod.LabelPodUID: testPodUID, - mppod.LabelVolumeName: testVolName, mppod.LabelCSIDriverVersion: csiDriverVersion, + mppod.LabelVolumeName: testVolName, }, mpPod.Labels) assert.Equals(t, priorityClassName, mpPod.Spec.PriorityClassName) @@ -105,14 +103,7 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu } t.Run("Empty PV", func(t *testing.T) { - mpPod := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + mpPod := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, @@ -122,14 +113,7 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu }) t.Run("With ServiceAccountName specified in PV", func(t *testing.T) { - mpPod := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + mpPod := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, diff --git a/pkg/podmounter/mppod/mppod.go b/pkg/podmounter/mppod/mppod.go deleted file mode 100644 index 05d3b4d2..00000000 --- a/pkg/podmounter/mppod/mppod.go +++ /dev/null @@ -1,17 +0,0 @@ -// Package mppod provides utilities for creating and accessing Mountpoint Pods. -package mppod - -import ( - "crypto/sha256" - "fmt" -) - -// MountpointPodNameFor returns a consistent and unique Pod name for -// Mountpoint Pod for given `podUID` and `volumeName`. -// -// Changing output of this function might cause duplicate Mountpoint Pods to be spawned, -// ideally multiple implementation of this function shouldn't co-exists in the same cluster -// unless there is a clean install of the CSI Driver. -func MountpointPodNameFor(podUID string, volumeName string) string { - return fmt.Sprintf("mp-%x", sha256.Sum224(fmt.Appendf(nil, "%s%s", podUID, volumeName))) -} diff --git a/pkg/podmounter/mppod/mppod_test.go b/pkg/podmounter/mppod/mppod_test.go deleted file mode 100644 index f24a5d7b..00000000 --- a/pkg/podmounter/mppod/mppod_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package mppod_test - -import ( - "testing" - - "github.com/google/uuid" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" - - "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" - "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" -) - -func TestGeneratingMountpointPodName(t *testing.T) { - t.Run("Consistency", func(t *testing.T) { - pod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{UID: types.UID(uuid.New().String())}, - } - pvc := &corev1.PersistentVolumeClaim{ - Spec: corev1.PersistentVolumeClaimSpec{VolumeName: "test-vol"}, - } - - // Ensure `MountpointPodNameFor` returns a consistent output for the same Pod and PVC. - assert.Equals(t, - mppod.MountpointPodNameFor(string(pod.UID), pvc.Spec.VolumeName), - mppod.MountpointPodNameFor(string(pod.UID), pvc.Spec.VolumeName)) - }) - - t.Run("Uniqueness", func(t *testing.T) { - pod1 := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{UID: types.UID(uuid.New().String())}, - } - pvc1 := &corev1.PersistentVolumeClaim{ - Spec: corev1.PersistentVolumeClaimSpec{VolumeName: "test-vol-1"}, - } - pod2 := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{UID: types.UID(uuid.New().String())}, - } - pvc2 := &corev1.PersistentVolumeClaim{ - Spec: corev1.PersistentVolumeClaimSpec{VolumeName: "test-vol-2"}, - } - - if mppod.MountpointPodNameFor(string(pod1.UID), pvc1.Spec.VolumeName) == mppod.MountpointPodNameFor(string(pod1.UID), pvc2.Spec.VolumeName) { - t.Error("Different PVCs with same Pod should return a different Mountpoint Pod name") - } - if mppod.MountpointPodNameFor(string(pod1.UID), pvc1.Spec.VolumeName) == mppod.MountpointPodNameFor(string(pod2.UID), pvc1.Spec.VolumeName) { - t.Error("Different Pods with same PVC should return a different Mountpoint Pod name") - } - }) - - t.Run("Snapshot", func(t *testing.T) { - mountpointPodName := mppod.MountpointPodNameFor("a4509011-bd2a-4f37-b1b0-05d715087852", "test-vol") - assert.Equals(t, "mp-55f7d2331f3149f00d62d7af839d4cee895e1c68a2f0d96ffd359f79", mountpointPodName) - }) -} diff --git a/pkg/podmounter/mppod/watcher/watcher.go b/pkg/podmounter/mppod/watcher/watcher.go index c33f5692..5fffc9c2 100644 --- a/pkg/podmounter/mppod/watcher/watcher.go +++ b/pkg/podmounter/mppod/watcher/watcher.go @@ -10,6 +10,9 @@ import ( corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" listerv1 "k8s.io/client-go/listers/core/v1" @@ -33,8 +36,15 @@ type Watcher struct { } // New creates a new [Watcher] with the given Kubernetes client, Mountpoint Pod namespace, and resync duration. -func New(client kubernetes.Interface, namespace string, defaultResync time.Duration) *Watcher { - factory := informers.NewSharedInformerFactoryWithOptions(client, defaultResync, informers.WithNamespace(namespace)) +func New(client kubernetes.Interface, namespace, nodeName string, defaultResync time.Duration) *Watcher { + factory := informers.NewSharedInformerFactoryWithOptions( + client, + defaultResync, + informers.WithNamespace(namespace), + informers.WithTweakListOptions(func(options *metav1.ListOptions) { + options.FieldSelector = fields.OneTermEqualSelector("spec.nodeName", nodeName).String() + }), + ) informer := factory.Core().V1().Pods().Informer() lister := factory.Core().V1().Pods().Lister().Pods(namespace) return &Watcher{informer, lister} @@ -51,6 +61,16 @@ func (w *Watcher) Start(stopCh <-chan struct{}) error { return nil } +// Get returns pod from watcher's cache. +func (w *Watcher) Get(name string) (*corev1.Pod, error) { + return w.lister.Get(name) +} + +// List returns all pods from watcher's cache. +func (w *Watcher) List() ([]*corev1.Pod, error) { + return w.lister.List(labels.Everything()) +} + // Wait blocks until the specified Mountpoint Pod is found and ready, or until the context is cancelled. func (w *Watcher) Wait(ctx context.Context, name string) (*corev1.Pod, error) { // Set a watcher for Pod create & update events diff --git a/pkg/podmounter/mppod/watcher/watcher_test.go b/pkg/podmounter/mppod/watcher/watcher_test.go index 9055c3f9..716a85be 100644 --- a/pkg/podmounter/mppod/watcher/watcher_test.go +++ b/pkg/podmounter/mppod/watcher/watcher_test.go @@ -108,7 +108,7 @@ func TestGettingPodsConcurrently(t *testing.T) { } func createAndStartWatcher(t *testing.T, client kubernetes.Interface) *watcher.Watcher { - mpPodWatcher := watcher.New(client, testMountpointPodNamespace, 10*time.Second) + mpPodWatcher := watcher.New(client, testMountpointPodNamespace, "test-node", 10*time.Second) stopCh := make(chan struct{}) t.Cleanup(func() { diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index 64606852..5e3e5853 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -1,6 +1,11 @@ package controller_test import ( + "fmt" + "strconv" + "strings" + + "github.com/google/uuid" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -10,12 +15,21 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" ) var _ = Describe("Mountpoint Controller", func() { + var testNode string + + BeforeEach(func() { + testNode = generateRandomNodeName() + }) + Context("Static Provisioning", func() { Context("Scheduled Pod with pre-bound PV and PVC", func() { It("should schedule a Mountpoint Pod", func() { @@ -23,9 +37,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -35,10 +49,10 @@ var _ = Describe("Mountpoint Controller", func() { vol2.bind() pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - waitAndVerifyMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { @@ -46,9 +60,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -58,10 +72,10 @@ var _ = Describe("Mountpoint Controller", func() { vol2.bind() pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) @@ -70,13 +84,13 @@ var _ = Describe("Mountpoint Controller", func() { vol := createVolume() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) vol.bind() - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -84,32 +98,32 @@ var _ = Describe("Mountpoint Controller", func() { vol2 := createVolume() pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol1.bind() - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol2.bind() - waitAndVerifyMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { vol := createVolume(withCSIDriver(ebsCSIDriver)) pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) vol.bind() - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -117,20 +131,20 @@ var _ = Describe("Mountpoint Controller", func() { vol2 := createVolume(withCSIDriver(ebsCSIDriver)) pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol2.bind() - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol1.bind() - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) @@ -141,11 +155,11 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -156,13 +170,13 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - waitAndVerifyMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { @@ -171,11 +185,11 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -186,13 +200,13 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) @@ -202,15 +216,15 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) vol.bind() - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -219,18 +233,18 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) vol2.bind() - expectNoMountpointPodFor(pod, vol1) - waitAndVerifyMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) vol1.bind() - waitAndVerifyMountpointPodFor(pod, vol1) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { @@ -238,15 +252,15 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) vol.bind() - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -255,61 +269,395 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) vol1.bind() vol2.bind() - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) Context("Multiple Pods using the same PV and PVC", func() { Context("Same Node", func() { Context("Pre-bound PV and PVC", func() { - It("should schedule a Mountpoint Pod per Workload Pod", func() { + It("should schedule single Mountpoint Pod", func() { vol := createVolume() vol.bind() pod1 := createPod(withPVC(vol.pvc)) pod2 := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) + + pod1.schedule(testNode) - pod1.schedule("test-node") + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod1) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + pod2.schedule(testNode) - pod2.schedule("test-node") + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod1, s3pa.ResourceVersion) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod2, s3pa.ResourceVersion) - waitAndVerifyMountpointPodFor(pod2, vol) + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") }) }) Context("Late PV and PVC binding", func() { - It("should schedule a Mountpoint Pod per Workload Pod", func() { + It("should schedule single Mountpoint Pod", func() { + vol := createVolume() + + pod1 := createPod(withPVC(vol.pvc)) + pod2 := createPod(withPVC(vol.pvc)) + pod1.schedule(testNode) + + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) + + vol.bind() + + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod1) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) + + pod2.schedule(testNode) + + waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod1, s3pa.ResourceVersion) + waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod2, s3pa.ResourceVersion) + }) + }) + + Context("MountOptions", func() { + It("should schedule different Mountpoint Pods if mountOptions were modified", func() { vol := createVolume() + vol.bind() + pv := vol.pv pod1 := createPod(withPVC(vol.pvc)) + pod1.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod1) + + pv.Spec.MountOptions = []string{"--allow-delete"} + + Expect(k8sClient.Update(ctx, pv)).To(Succeed()) + pod2 := createPod(withPVC(vol.pvc)) - pod1.schedule("test-node") + pod2.schedule(testNode) + + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod2) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + }) + }) + + Context("FSGroup", func() { + It("should schedule single Mountpoint Pod if workload pods have the same FSGroup", func() { + vol := createVolume() + vol.bind() + + pod1 := createPod(withPVC(vol.pvc), withFSGroup(1111)) + pod2 := createPod(withPVC(vol.pvc), withFSGroup(1111)) + pod1.schedule(testNode) + pod2.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["WorkloadFSGroup"] = "1111" + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") + }) + + It("should schedule different Mountpoint Pods if workload pods have different FSGroup", func() { + vol := createVolume() + vol.bind() + + pod1 := createPod(withPVC(vol.pvc)) // no fsGroup + pod2 := createPod(withPVC(vol.pvc), withFSGroup(1111)) + pod3 := createPod(withPVC(vol.pvc), withFSGroup(2222)) + pod1.schedule(testNode) + pod2.schedule(testNode) + pod3.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadFSGroup"] = "1111" + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadFSGroup"] = "2222" + s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(s3pa1.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + Expect(s3pa2.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod2.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + }) + }) + + Context("authenticationSource=pod", func() { + It("should schedule single Mountpoint Pod if workload pods have the same namespace and service account", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() + + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod1.schedule(testNode) + pod2.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadServiceAccountName"] = sa.Name + expectedFields["WorkloadNamespace"] = defaultNamespace + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") + }) + + It("should schedule single Mountpoint Pod if workload pods have the same namespace, service account and IRSA role annotation", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + Annotations: map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role"}, + }, + } + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod1.schedule(testNode) + pod2.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadServiceAccountName"] = sa.Name + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role" + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") + }) + It("should schedule different Mountpoint Pods if workload pods have different IRSA role annotations for the same service account", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) vol.bind() - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) // no IRSA annotation + pod1.schedule(testNode) + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountName"] = sa.Name + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "" + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + + sa.Annotations = map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role-1"} + Expect(k8sClient.Update(ctx, sa)).To(Succeed()) + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod2.schedule(testNode) + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role-1" + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + + sa.Annotations = map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role-2"} + Expect(k8sClient.Update(ctx, sa)).To(Succeed()) + pod3 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod3.schedule(testNode) + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role-2" + s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(s3pa1.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + Expect(s3pa2.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod2.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + }) + + It("should schedule different Mountpoint Pods if workload pods have different service account names in the same namespace", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() - pod2.schedule("test-node") + sa1 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + sa2 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa1)).To(Succeed()) + Expect(k8sClient.Create(ctx, sa2)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc)) // default service account + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa1.Name)) + pod3 := createPod(withPVC(vol.pvc), withServiceAccount(sa2.Name)) + pod1.schedule(testNode) + pod2.schedule(testNode) + pod3.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountName"] = "default" + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadServiceAccountName"] = sa1.Name + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadServiceAccountName"] = sa2.Name + s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(s3pa1.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + Expect(s3pa2.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod2.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + }) + + It("should schedule different Mountpoint Pods if workload pods have same service account names in the different namespace", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() - waitAndVerifyMountpointPodFor(pod2, vol) + ns := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "ns-", + }, + } + Expect(k8sClient.Create(ctx, ns)).To(Succeed()) + sa1 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa1)).To(Succeed()) + sa2 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: sa1.Name, + Namespace: ns.Name, + }, + } + Expect(k8sClient.Create(ctx, sa2)).To(Succeed()) + _, err := controllerutil.CreateOrUpdate(ctx, k8sClient, sa2, func() error { return nil }) + Expect(err).To(Succeed()) + pvc2 := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "test-pvc", + Namespace: ns.Name, + }, + Spec: corev1.PersistentVolumeClaimSpec{ + StorageClassName: vol.pvc.Spec.StorageClassName, + AccessModes: vol.pvc.Spec.AccessModes, + Resources: vol.pvc.Spec.Resources, + }, + } + Expect(k8sClient.Create(ctx, pvc2)).To(Succeed()) + vol2 := testVolume{pv: vol.pv, pvc: pvc2} + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa1.Name)) + pod1.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountName"] = sa1.Name + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + + pod2 := createPod(withPVC(pvc2), withServiceAccount(sa1.Name), withNamespace(ns.Name)) + vol2.bind() + pod2.schedule(testNode) + + expectedFields["WorkloadNamespace"] = ns.Name + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") }) }) }) @@ -323,17 +671,16 @@ var _ = Describe("Mountpoint Controller", func() { pod1 := createPod(withPVC(vol.pvc)) pod2 := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) pod1.schedule("test-node1") - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node1", vol, pod1) + expectNoS3PodAttachmentWithFields(defaultExpectedFields("test-node2", vol.pv)) pod2.schedule("test-node2") - waitAndVerifyMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node2", vol, pod2) }) }) @@ -345,17 +692,16 @@ var _ = Describe("Mountpoint Controller", func() { pod2 := createPod(withPVC(vol.pvc)) pod1.schedule("test-node1") - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) vol.bind() - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node1", vol, pod1) + expectNoS3PodAttachmentWithFields(defaultExpectedFields("test-node2", vol.pv)) pod2.schedule("test-node2") - waitAndVerifyMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node2", vol, pod2) }) }) }) @@ -367,9 +713,9 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withVolume("empty-dir", corev1.VolumeSource{ EmptyDir: &corev1.EmptyDirVolumeSource{}, })) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodForWorkloadPod(pod) + expectNoS3PodAttachmentWithFields(map[string]string{"NodeName": testNode}) }) It("should not schedule a Mountpoint Pod if the Pod only uses a hostPath volume", func() { @@ -379,9 +725,9 @@ var _ = Describe("Mountpoint Controller", func() { Type: ptr.To(corev1.HostPathDirectoryOrCreate), }, })) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodForWorkloadPod(pod) + expectNoS3PodAttachmentWithFields(map[string]string{"NodeName": testNode}) }) It("should not schedule a Mountpoint Pod if the Pod only different volume-types/CSI-drivers", func() { @@ -400,9 +746,9 @@ var _ = Describe("Mountpoint Controller", func() { EmptyDir: &corev1.EmptyDirVolumeSource{}, }), ) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodForWorkloadPod(pod) + expectNoS3PodAttachmentWithFields(map[string]string{"NodeName": testNode}) }) }) @@ -412,10 +758,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + _, mountpointPod := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) mountpointPod.succeed() @@ -427,11 +772,10 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) // `pod` got a `mountpointPod` - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + s3pa, mountpointPod := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) // `mountpointPod` got terminated mountpointPod.succeed() @@ -439,30 +783,28 @@ var _ = Describe("Mountpoint Controller", func() { // `pod` got terminated pod.terminate() + waitForObjectToDisappear(s3pa) // Since `pod` was in `Pending` state, termination of Pod will still keep that in // `Pending` state but will populate `DeletionTimestamp` to indicate this Pod is terminating. // In this case, there shouldn't be a new Mountpoint Pod spawned for it. - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) - It("should delete Mountpoint Pod if the Workload Pod is terminated", func() { + It("should delete S3 Pod Attachment if the Workload Pod is terminated", func() { vol := createVolume() vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) // `pod` got a `mountpointPod` - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) // `pod` got terminated pod.terminate() waitForObjectToDisappear(pod.Pod) - - // `mountpointPod` scheduled for `pod` should also get terminated - waitForObjectToDisappear(mountpointPod.Pod) + waitForObjectToDisappear(s3pa) }) }) @@ -481,10 +823,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + _, mountpointPod := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) Expect(mountpointPod.Spec.ServiceAccountName).To(Equal(sa.Name)) @@ -550,6 +891,30 @@ func withVolume(name string, vol corev1.VolumeSource) podModifier { } } +// withFSGroup returns a `podModifier` that sets fsGroup in the Pod's security context. +func withFSGroup(fsGroup int64) podModifier { + return func(pod *corev1.Pod) { + if pod.Spec.SecurityContext == nil { + pod.Spec.SecurityContext = &corev1.PodSecurityContext{} + } + pod.Spec.SecurityContext.FSGroup = &fsGroup + } +} + +// withServiceAccount returns a `podModifier` that sets ServiceAccountName. +func withServiceAccount(saName string) podModifier { + return func(pod *corev1.Pod) { + pod.Spec.ServiceAccountName = saName + } +} + +// withNamespace returns a `podModifier` that sets Namespace. +func withNamespace(namespace string) podModifier { + return func(pod *corev1.Pod) { + pod.Namespace = namespace + } +} + // A podModifier is a function for modifying Pod to be created. type podModifier func(*corev1.Pod) @@ -674,54 +1039,101 @@ func createVolume(modifiers ...volumeModifier) *testVolume { return testVolume } -// waitForMountpointPodFor waits and returns the Mountpoint Pod scheduled for given `pod` and `vol`. -func waitForMountpointPodFor(pod *testPod, vol *testVolume) *testPod { - mountpointPodKey := mountpointPodNameFor(pod, vol) +// waitForMountpointPodWithName waits and returns the Mountpoint Pod scheduled for given `mpPodName` +func waitForMountpointPodWithName(mpPodName string) *testPod { mountpointPod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - Name: mountpointPodKey.Name, - Namespace: mountpointPodKey.Namespace, + Name: mpPodName, + Namespace: mountpointNamespace, }, } waitForObject(mountpointPod) return &testPod{Pod: mountpointPod} } -// expectNoMountpointPodFor verifies that there is no Mountpoint Pod scheduled for given `pod` and `vol`. -func expectNoMountpointPodFor(pod *testPod, vol *testVolume) { - mountpointPodKey := mountpointPodNameFor(pod, vol) - expectNoObject(&corev1.Pod{ObjectMeta: metav1.ObjectMeta{ - Name: mountpointPodKey.Name, - Namespace: mountpointPodKey.Namespace, - }}) +// expectNoS3PodAttachmentWithFields verifies that no MountpointS3PodAttachment matching specified fields exists within a time period +func expectNoS3PodAttachmentWithFields(expectedFields map[string]string) { + Consistently(func(g Gomega) { + list := &crdv1.MountpointS3PodAttachmentList{} + g.Expect(k8sClient.List(ctx, list)).To(Succeed()) + + for i := range list.Items { + cr := &list.Items[i] + if matchesSpec(cr.Spec, expectedFields) { + g.Expect(false).To(BeTrue(), "Found matching MountpointS3PodAttachment when none was expected: %#v", cr) + } + } + }, defaultWaitTimeout/2, defaultWaitTimeout/4).Should(Succeed()) } -// expectNoMountpointPodForWorkloadPod verifies that there is no Mountpoint Pod scheduled for given `pod`. -// `expectNoMountpointPodFor` is preferable to this method if the `vol` is known as this performs a slower list operation. -func expectNoMountpointPodForWorkloadPod(pod *testPod) { - Consistently(func(g Gomega) { - podList := &corev1.PodList{} - g.Expect(k8sClient.List(ctx, podList, - client.InNamespace(mountpointNamespace), client.MatchingLabels{ - mppod.LabelPodUID: string(pod.UID), - }, - )).To(Succeed()) +// expectNoPodUIDInS3PodAttachment validates that pod UID does not exist in MountpointS3PodToWorkloadPodUIDs map +func expectNoPodUIDInS3PodAttachment(s3pa *crdv1.MountpointS3PodAttachment, podUID string) { + for _, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for _, uid := range uids { + if uid == podUID { + Expect(false).To(BeTrue(), "Found pod UID %s in S3PodAttachment when none was expected: %#v", podUID, s3pa) + } + } + } +} - g.Expect(podList.Items).To(BeEmpty(), "Expected empty list but got: %#v", podList) - g.Expect(podList.Continue).To(BeEmpty(), "Continue token on list must be empty but got: %s", podList.Continue) - }, defaultWaitTimeout/2, defaultWaitTimeout/4).Should(Succeed()) +// waitAndVerifyS3PodAttachmentAndMountpointPod waits and verifies that MountpointS3PodAttachment and Mountpoint Pod +// are created for given `node`, `vol` and `pod` +func waitAndVerifyS3PodAttachmentAndMountpointPod( + node string, + vol *testVolume, + pod *testPod, +) (*crdv1.MountpointS3PodAttachment, *testPod) { + s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(node, vol.pv), "") + Expect(len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) + return s3pa, mpPod } -// waitAndVerifyMountpointPodFor waits and verifies Mountpoint Pod scheduled for given `pod` and `vol.` -func waitAndVerifyMountpointPodFor(pod *testPod, vol *testVolume) { - mountpointPod := waitForMountpointPodFor(pod, vol) +// waitAndVerifyS3PodAttachmentAndMountpointPod waits and verifies that MountpointS3PodAttachment with `minVersion` and Mountpoint Pod +// are created for given `node`, `vol` and `pod` +func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion( + testNode string, + vol *testVolume, + pod *testPod, + minVersion string, +) (*crdv1.MountpointS3PodAttachment, *testPod) { + s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv), minVersion) + Expect(len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) + return s3pa, mpPod +} + +// waitAndVerifyMountpointPodFromPodAttachment waits and verifies Mountpoint Pod scheduled for given `s3pa`, `pod` and `vol.` +func waitAndVerifyMountpointPodFromPodAttachment(s3pa *crdv1.MountpointS3PodAttachment, pod *testPod, vol *testVolume) *testPod { + // Find the mpPodName where pod.UID exists in the value slice + var mpPodName string + podUID := string(pod.UID) + + for k, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for _, uid := range uids { + if uid == podUID { + mpPodName = k + break + } + } + if mpPodName != "" { + break + } + } + + Expect(mpPodName).NotTo(BeEmpty(), "No Mountpoint Pod found for pod UID %s in MountpointS3PodAttachment: %#v", podUID, s3pa) + Expect(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[mpPodName]).To(ContainElement(podUID)) + + mountpointPod := waitForMountpointPodWithName(mpPodName) verifyMountpointPodFor(pod, vol, mountpointPod) + + return mountpointPod } // verifyMountpointPodFor verifies given `mountpointPod` for given `pod` and `vol`. func verifyMountpointPodFor(pod *testPod, vol *testVolume, mountpointPod *testPod) { Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelMountpointVersion, mountpointVersion)) - Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelPodUID, string(pod.UID))) Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelVolumeName, vol.pvc.Spec.VolumeName)) Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelCSIDriverVersion, version.GetVersion().DriverVersion)) @@ -759,6 +1171,82 @@ func waitForObject[Obj client.Object](obj Obj, verifiers ...func(Gomega, Obj)) { }, defaultWaitTimeout, defaultWaitRetryPeriod).Should(Succeed()) } +// waitForS3PodAttachmentWithFields waits until a MountpointS3PodAttachment matching specified node and pv appears in the cluster +func waitForS3PodAttachmentWithFields( + expectedFields map[string]string, + minResourceVersion string, + verifiers ...func(Gomega, *crdv1.MountpointS3PodAttachment), +) *crdv1.MountpointS3PodAttachment { + var matchedCR *crdv1.MountpointS3PodAttachment + + Eventually(func(g Gomega) { + list := &crdv1.MountpointS3PodAttachmentList{} + g.Expect(k8sClient.List(ctx, list)).To(Succeed()) + + for i := range list.Items { + cr := &list.Items[i] + if matchesSpec(cr.Spec, expectedFields) { + // Skip if the resource version isn't newer than the minimum + if minResourceVersion != "" { + minVersion, err := strconv.ParseInt(minResourceVersion, 10, 64) + g.Expect(err).NotTo(HaveOccurred()) + + currentVersion, err := strconv.ParseInt(cr.ResourceVersion, 10, 64) + g.Expect(err).NotTo(HaveOccurred()) + + if currentVersion <= minVersion { + continue + } + } + + for _, verifier := range verifiers { + verifier(g, cr) + } + matchedCR = cr + return + } + } + + g.Expect(false).To(BeTrue(), "No matching MountpointS3PodAttachment found") + }, defaultWaitTimeout, defaultWaitRetryPeriod).Should(Succeed()) + + return matchedCR +} + +// matchesSpec checks whether MountpointS3PodAttachmentSpec matches `expected` fields +func matchesSpec(spec crdv1.MountpointS3PodAttachmentSpec, expected map[string]string) bool { + specValues := map[string]string{ + "NodeName": spec.NodeName, + "PersistentVolumeName": spec.PersistentVolumeName, + "VolumeID": spec.VolumeID, + "MountOptions": spec.MountOptions, + "AuthenticationSource": spec.AuthenticationSource, + "WorkloadFSGroup": spec.WorkloadFSGroup, + "WorkloadServiceAccountName": spec.WorkloadServiceAccountName, + "WorkloadNamespace": spec.WorkloadNamespace, + "WorkloadServiceAccountIAMRoleARN": spec.WorkloadServiceAccountIAMRoleARN, + } + + for k, v := range expected { + if specValues[k] != v { + return false + } + } + return true +} + +// defaultExpectedFields return default test expected fields for MountpointS3PodAttachmentSpec matching +func defaultExpectedFields(nodeName string, pv *corev1.PersistentVolume) map[string]string { + return map[string]string{ + "NodeName": nodeName, + "PersistentVolumeName": pv.Name, + "VolumeID": pv.Spec.CSI.VolumeHandle, + "MountOptions": strings.Join(pv.Spec.MountOptions, ","), + "AuthenticationSource": "driver", + "WorkloadFSGroup": "", + } +} + // waitForObjectToDisappear waits until `obj` disappears in the control plane. func waitForObjectToDisappear(obj client.Object) { key := types.NamespacedName{Name: obj.GetName(), Namespace: obj.GetNamespace()} @@ -773,20 +1261,7 @@ func waitForObjectToDisappear(obj client.Object) { }, defaultWaitTimeout, defaultWaitRetryPeriod).Should(Succeed()) } -// expectNoObject verifies object with given key does not exists within a time period. -func expectNoObject(obj client.Object) { - key := types.NamespacedName{Name: obj.GetName(), Namespace: obj.GetNamespace()} - Consistently(func(g Gomega) { - err := k8sClient.Get(ctx, key, obj) - g.Expect(err).ToNot(BeNil(), "The object expected not to exists but its found: %#v", obj) - g.Expect(apierrors.IsNotFound(err)).To(BeTrue(), "Expected not found error but fond: %v", err) - }, defaultWaitTimeout/2, defaultWaitTimeout/4).Should(Succeed()) -} - -// mountpointPodNameFor returns namespaced name of Mountpoint Pod for given `pod` and `vol`. -func mountpointPodNameFor(pod *testPod, vol *testVolume) types.NamespacedName { - return types.NamespacedName{ - Name: mppod.MountpointPodNameFor(string(pod.Pod.UID), vol.pvc.Spec.VolumeName), - Namespace: mountpointNamespace, - } +// generateRandomNodeName generates random node name +func generateRandomNodeName() string { + return fmt.Sprintf("test-node-%s", uuid.New().String()[:8]) } diff --git a/tests/controller/suite_test.go b/tests/controller/suite_test.go index 3ae0381c..4b090f2d 100644 --- a/tests/controller/suite_test.go +++ b/tests/controller/suite_test.go @@ -3,9 +3,12 @@ package controller_test import ( "context" "fmt" + "os" "testing" "time" + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + "github.com/go-logr/logr" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -18,6 +21,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + "sigs.k8s.io/controller-runtime/pkg/manager" "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" @@ -65,7 +69,14 @@ var _ = BeforeSuite(func() { ctx, cancel = context.WithCancel(context.TODO()) By("Bootstrapping test environment") - testEnv = &envtest.Environment{} + + crdv1.AddToScheme(scheme.Scheme) + testEnv = &envtest.Environment{ + CRDInstallOptions: envtest.CRDInstallOptions{ + Paths: []string{"../../charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml"}, + }, + ErrorIfCRDPathMissing: true, + } var err error cfg, err = testEnv.Start() @@ -79,6 +90,8 @@ var _ = BeforeSuite(func() { k8sManager, err := ctrl.NewManager(cfg, ctrl.Options{Scheme: scheme.Scheme}) Expect(err).ToNot(HaveOccurred()) + IndexMountpointS3PodAttachmentFields(logf.Log.WithName("controller-test"), k8sManager) + err = csicontroller.NewReconciler(k8sManager.GetClient(), mppod.Config{ Namespace: mountpointNamespace, MountpointVersion: mountpointVersion, @@ -99,6 +112,7 @@ var _ = BeforeSuite(func() { }() createMountpointNamespace() + createDefaultServiceAccount() createMountpointPriorityClass() }) @@ -117,6 +131,20 @@ func createMountpointNamespace() { waitForObject(namespace) } +// createDefaultServiceAccount creates default service account in the control plane. +func createDefaultServiceAccount() { + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "default", + Namespace: defaultNamespace, + }, + } + + By(fmt.Sprintf("Creating default service account in %q", mountpointNamespace)) + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + waitForObject(sa) +} + // createMountpointPriorityClass creates priority class for Mountpoint Pods. func createMountpointPriorityClass() { By(fmt.Sprintf("Creating priority class %q for Mountpoint Pods", mountpointPriorityClassName)) @@ -127,3 +155,25 @@ func createMountpointPriorityClass() { Expect(k8sClient.Create(ctx, priorityClass)).To(Succeed()) waitForObject(priorityClass) } + +func IndexMountpointS3PodAttachmentFields(log logr.Logger, mgr manager.Manager) { + indexField(log, mgr, crdv1.FieldNodeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) + indexField(log, mgr, crdv1.FieldPersistentVolumeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) + indexField(log, mgr, crdv1.FieldVolumeID, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) + indexField(log, mgr, crdv1.FieldMountOptions, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) + indexField(log, mgr, crdv1.FieldAuthenticationSource, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) + indexField(log, mgr, crdv1.FieldWorkloadFSGroup, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) + indexField(log, mgr, crdv1.FieldWorkloadServiceAccountName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) + indexField(log, mgr, crdv1.FieldWorkloadNamespace, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) + indexField(log, mgr, crdv1.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) +} + +func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1.MountpointS3PodAttachment) string) { + err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { + return []string{extractor(obj.(*crdv1.MountpointS3PodAttachment))} + }) + if err != nil { + log.Error(err, fmt.Sprintf("Failed to create a %s field indexer", field)) + os.Exit(1) + } +} diff --git a/tests/e2e-kubernetes/e2e_test.go b/tests/e2e-kubernetes/e2e_test.go index afe2bbc0..9ee0f99e 100644 --- a/tests/e2e-kubernetes/e2e_test.go +++ b/tests/e2e-kubernetes/e2e_test.go @@ -61,6 +61,7 @@ var CSITestSuites = []func() framework.TestSuite{ custom_testsuites.InitS3MountOptionsTestSuite, custom_testsuites.InitS3CSICredentialsTestSuite, custom_testsuites.InitS3CSICacheTestSuite, + custom_testsuites.InitS3CSIPodSharingTestSuite, } // This executes testSuites for csi volumes. diff --git a/tests/e2e-kubernetes/go.mod b/tests/e2e-kubernetes/go.mod index 1625d5db..cc84b6f6 100644 --- a/tests/e2e-kubernetes/go.mod +++ b/tests/e2e-kubernetes/go.mod @@ -145,6 +145,7 @@ require ( k8s.io/kubelet v0.31.3 // indirect k8s.io/mount-utils v0.31.3 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3 // indirect + sigs.k8s.io/controller-runtime v0.19.2 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect sigs.k8s.io/yaml v1.4.0 // indirect diff --git a/tests/e2e-kubernetes/go.sum b/tests/e2e-kubernetes/go.sum index 0f84289c..464fa63f 100644 --- a/tests/e2e-kubernetes/go.sum +++ b/tests/e2e-kubernetes/go.sum @@ -438,6 +438,8 @@ k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1 k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3 h1:2770sDpzrjjsAtVhSeUFseziht227YAWYHLGNM8QPwY= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= +sigs.k8s.io/controller-runtime v0.19.2 h1:3sPrF58XQEPzbE8T81TN6selQIMGbtYwuaJ6eDssDF8= +sigs.k8s.io/controller-runtime v0.19.2/go.mod h1:iRmWllt8IlaLjvTTDLhRBXIEtkCK6hwVBJJsYS9Ajf4= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= diff --git a/tests/e2e-kubernetes/testsuites/pod_sharing.go b/tests/e2e-kubernetes/testsuites/pod_sharing.go new file mode 100644 index 00000000..bb82adc8 --- /dev/null +++ b/tests/e2e-kubernetes/testsuites/pod_sharing.go @@ -0,0 +1,304 @@ +package custom_testsuites + +import ( + "context" + "fmt" + "path/filepath" + "strings" + "time" + + crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/kubernetes/test/e2e/framework" + e2epod "k8s.io/kubernetes/test/e2e/framework/pod" + storageframework "k8s.io/kubernetes/test/e2e/storage/framework" + admissionapi "k8s.io/pod-security-admission/api" + "k8s.io/utils/ptr" +) + +var s3paGVR = schema.GroupVersionResource{Group: "s3.csi.aws.com", Version: "v1", Resource: "mountpoints3podattachments"} + +type s3CSIPodSharingTestSuite struct { + tsInfo storageframework.TestSuiteInfo +} + +func InitS3CSIPodSharingTestSuite() storageframework.TestSuite { + return &s3CSIPodSharingTestSuite{ + tsInfo: storageframework.TestSuiteInfo{ + Name: "multivolume", + TestPatterns: []storageframework.TestPattern{ + storageframework.DefaultFsPreprovisionedPV, + }, + }, + } +} + +func (t *s3CSIPodSharingTestSuite) GetTestSuiteInfo() storageframework.TestSuiteInfo { + return t.tsInfo +} + +func (t *s3CSIPodSharingTestSuite) SkipUnsupportedTests(_ storageframework.TestDriver, _ storageframework.TestPattern) { +} + +func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDriver, pattern storageframework.TestPattern) { + type local struct { + resources []*storageframework.VolumeResource + config *storageframework.PerTestConfig + } + var ( + l local + ) + + f := framework.NewFrameworkWithCustomTimeouts(NamespacePrefix+"multivolume", storageframework.GetDriverTimeouts(driver)) + f.NamespacePodSecurityLevel = admissionapi.LevelBaseline + + cleanup := func(ctx context.Context) { + var errs []error + for _, resource := range l.resources { + errs = append(errs, resource.CleanupResource(ctx)) + } + framework.ExpectNoError(errors.NewAggregate(errs), "while cleanup resource") + } + ginkgo.BeforeEach(func(ctx context.Context) { + l = local{} + l.config = driver.PrepareTest(ctx, f) + ginkgo.DeferCleanup(cleanup) + }) + + ginkgo.It("should concurrently access the single volume from pods on the same node using the same Mountpoint Pod", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var pods []*v1.Pod + node := l.config.ClientNodeSelection + // Create each pod with pvc + for i := 0; i < 2; i++ { + index := i + 1 + ginkgo.By(fmt.Sprintf("Creating pod%d with a volume on %+v", index, node)) + pod, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, nil, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + framework.ExpectNoError(err) + // The pod must get deleted before this function returns because the caller may try to + // delete volumes as part of the tests. Keeping the pod running would block that. + // If the test times out, then the namespace deletion will take care of it. + defer func() { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + }() + pods = append(pods, pod) + e2epod.SetAffinity(&node, pod.Spec.NodeName) + } + + verifyPodsShareMountpointPod(ctx, f, pods, resource.Pv) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should concurrently access the single volume from pods on the same node using different Mountpoint Pods if fsGroup is different", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var pods []*v1.Pod + var targetNode string + for i := 0; i < 2; i++ { + index := i + 1 + podConfig := &e2epod.Config{ + NS: f.Namespace.Name, + PVCs: []*v1.PersistentVolumeClaim{resource.Pvc}, + SecurityLevel: admissionapi.LevelBaseline, + FsGroup: ptr.To(int64(1000 + i)), + } + + // For the first pod, let it schedule anywhere + // For subsequent pods, force them to the same node as the first pod + if i > 0 && targetNode != "" { + podConfig.NodeSelection = e2epod.NodeSelection{ + Name: targetNode, + } + } + + ginkgo.By(fmt.Sprintf("Creating pod%d", index)) + pod, err := e2epod.CreateSecPod(ctx, f.ClientSet, podConfig, 10*time.Second) + framework.ExpectNoError(err) + + // Store the node name from the first pod + if i == 0 { + targetNode = pod.Spec.NodeName + } + + defer func() { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + }() + pods = append(pods, pod) + } + + verifyPodsHaveDifferentMountpointPods(ctx, f, pods, resource.Pv, func(pod *v1.Pod) map[string]string { + expectedFields := defaultExpectedFields(pod.Spec.NodeName, resource.Pv) + expectedFields["WorkloadFSGroup"] = fmt.Sprintf("%d", pod.Spec.SecurityContext.FSGroup) + return expectedFields + }) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + // TODO: Add more test cases +} + +func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume) { + var s3paList *crdv1.MountpointS3PodAttachmentList + framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) + if err != nil { + return false, err + } + s3paList, err = convertToCustomResourceList(list) + if err != nil { + return false, err + } + for _, s3pa := range s3paList.Items { + if matchesSpec(s3pa.Spec, defaultExpectedFields(pods[0].Spec.NodeName, pv)) { + allUIDs := make(map[string]bool) + for _, uids := range s3paList.Items[0].Spec.MountpointS3PodToWorkloadPodUIDs { + for _, uid := range uids { + allUIDs[uid] = true + } + } + for _, pod := range pods { + podUID := string(pod.UID) + if _, exists := allUIDs[podUID]; !exists { + return false, fmt.Errorf("pod UID %s not found in MountpointS3PodAttachment", podUID) + } + } + + return true, nil + } + } + + return false, err + })).WithTimeout(10 * time.Second).WithPolling(1 * time.Second).Should(gomega.BeTrue()) + +} + +func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume, expectedFieldsFunc func(pod *v1.Pod) map[string]string) { + var s3paList *crdv1.MountpointS3PodAttachmentList + framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) + if err != nil { + return false, fmt.Errorf("failed to list S3PodAttachments: %w", err) + } + s3paList, err = convertToCustomResourceList(list) + if err != nil { + return false, fmt.Errorf("failed to convert to custom resource list: %w", err) + } + + matchCount := 0 + for _, s3pa := range s3paList.Items { + for _, pod := range pods { + if matchesSpec(s3pa.Spec, expectedFieldsFunc(pod)) { + matchCount++ + break + } + } + } + + return matchCount == len(pods), nil + })).WithTimeout(10 * time.Second).WithPolling(1 * time.Second).Should(gomega.BeTrue()) + + podToMountpointPod := make(map[string]string) + for _, s3pa := range s3paList.Items { + for mpPodName, workloadPodUIDs := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for _, uid := range workloadPodUIDs { + podToMountpointPod[uid] = mpPodName + } + } + } + + seenMountpointPods := make(map[string]bool) + for _, pod := range pods { + podUID := string(pod.UID) + mpPodName, exists := podToMountpointPod[podUID] + + framework.Gomega().Expect(exists).To(gomega.BeTrue()) + + _, alreadySeen := seenMountpointPods[mpPodName] + framework.Gomega().Expect(alreadySeen).To(gomega.BeFalse()) + + seenMountpointPods[mpPodName] = true + } + + framework.Gomega().Expect(len(seenMountpointPods)).To(gomega.Equal(len(pods))) +} + +// Convert UnstructuredList to MountpointS3PodAttachmentList +func convertToCustomResourceList(list *unstructured.UnstructuredList) (*crdv1.MountpointS3PodAttachmentList, error) { + crList := &crdv1.MountpointS3PodAttachmentList{ + Items: make([]crdv1.MountpointS3PodAttachment, 0, len(list.Items)), + } + + for _, item := range list.Items { + cr := &crdv1.MountpointS3PodAttachment{} + err := runtime.DefaultUnstructuredConverter.FromUnstructured(item.Object, cr) + if err != nil { + return nil, fmt.Errorf("failed to convert item to MountpointS3PodAttachment: %v", err) + } + crList.Items = append(crList.Items, *cr) + } + + return crList, nil +} + +// matchesSpec checks whether MountpointS3PodAttachmentSpec matches `expected` fields +func matchesSpec(spec crdv1.MountpointS3PodAttachmentSpec, expected map[string]string) bool { + specValues := map[string]string{ + "NodeName": spec.NodeName, + "PersistentVolumeName": spec.PersistentVolumeName, + "VolumeID": spec.VolumeID, + "MountOptions": spec.MountOptions, + "AuthenticationSource": spec.AuthenticationSource, + "WorkloadFSGroup": spec.WorkloadFSGroup, + "WorkloadServiceAccountName": spec.WorkloadServiceAccountName, + "WorkloadNamespace": spec.WorkloadNamespace, + "WorkloadServiceAccountIAMRoleARN": spec.WorkloadServiceAccountIAMRoleARN, + } + + for k, v := range expected { + if specValues[k] != v { + return false + } + } + return true +} + +// defaultExpectedFields return default test expected fields for MountpointS3PodAttachmentSpec matching +func defaultExpectedFields(nodeName string, pv *v1.PersistentVolume) map[string]string { + return map[string]string{ + "NodeName": nodeName, + "PersistentVolumeName": pv.Name, + "VolumeID": pv.Spec.CSI.VolumeHandle, + "MountOptions": strings.Join(pv.Spec.MountOptions, ","), + "AuthenticationSource": "driver", + "WorkloadFSGroup": "", + } +} + +func checkCrossReadWrite(f *framework.Framework, pod1, pod2 *v1.Pod) { + toWrite := 1024 // 1KB + path := "/mnt/volume1" + + // Check write from pod1 and read from pod2 + checkPodWriteAndOtherPodRead(f, pod1, pod2, path, "file1.txt", toWrite) + + // Check write from pod2 and read from pod1 + checkPodWriteAndOtherPodRead(f, pod2, pod1, path, "file2.txt", toWrite) +} + +func checkPodWriteAndOtherPodRead(f *framework.Framework, writerPod, readerPod *v1.Pod, basePath, filename string, size int) { + filePath := filepath.Join(basePath, filename) + seed := time.Now().UTC().UnixNano() + + checkWriteToPath(f, writerPod, filePath, size, seed) + checkReadFromPath(f, readerPod, filePath, size, seed) +} From 0ba8d12b0a5b4c78ec17bb76b80c23367203900f Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Mon, 28 Apr 2025 16:44:41 +0100 Subject: [PATCH 02/24] Add `{{- if .Values.experimental.podMounter -}}` wrapper to CRD --- Makefile | 15 ++- .../mountpoints3podattachments-crd.yaml | 104 ++++++++++++++++++ tests/controller/suite_test.go | 2 +- .../crd/mountpoints3podattachments-crd.yaml | 1 + 4 files changed, 118 insertions(+), 4 deletions(-) create mode 100644 charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml rename charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml => tests/crd/mountpoints3podattachments-crd.yaml (98%) diff --git a/Makefile b/Makefile index 5ffcb209..b08d75a1 100644 --- a/Makefile +++ b/Makefile @@ -175,12 +175,21 @@ clean: rm -rf bin/ && docker system prune # Generate files for Custom Resources (`zz_generated.deepcopy.go` and CustomResourceDefinition YAML file). -# TODO: Wrap CRD YAML file with experimental.podMounter=true Helm flag -# POD_ATTACHMENT_CRD_FILE ?= "./charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml" +TMP_POD_ATTACHMENT_CRD_FILE ?= "./hack/s3.csi.aws.com_mountpoints3podattachments.yaml" +# Helm CRD file needs extra `{{- if .Values.experimental.podMounter -}}` wrapper, while we don't need it for tests. +HELM_POD_ATTACHMENT_CRD_FILE ?= "./charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml" +TEST_POD_ATTACHMENT_CRD_FILE ?= "./tests/crd/mountpoints3podattachments-crd.yaml" .PHONY: generate generate: controller-gen object:headerFile="hack/boilerplate.go.txt" paths="./pkg/api/..." - controller-gen crd paths="./pkg/api/..." output:crd:dir=./charts/aws-mountpoint-s3-csi-driver/templates + controller-gen crd paths="./pkg/api/..." output:crd:dir=./hack/ + echo '# Auto-generated file via `make generate`. Do not edit.' > $(HELM_POD_ATTACHMENT_CRD_FILE) + echo '{{- if .Values.experimental.podMounter -}}' >> $(HELM_POD_ATTACHMENT_CRD_FILE) + cat $(TMP_POD_ATTACHMENT_CRD_FILE) >> $(HELM_POD_ATTACHMENT_CRD_FILE) + echo '{{- end -}}' >> $(HELM_POD_ATTACHMENT_CRD_FILE) + echo '# Auto-generated file via `make generate`. Do not edit.' > $(TEST_POD_ATTACHMENT_CRD_FILE) + cat $(TMP_POD_ATTACHMENT_CRD_FILE) >> $(TEST_POD_ATTACHMENT_CRD_FILE) + rm $(TMP_POD_ATTACHMENT_CRD_FILE) ## Binaries used in tests. diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml new file mode 100644 index 00000000..b5de87e8 --- /dev/null +++ b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml @@ -0,0 +1,104 @@ +# Auto-generated file via `make generate`. Do not edit. +{{- if .Values.experimental.podMounter -}} +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.17.3 + name: mountpoints3podattachments.s3.csi.aws.com +spec: + group: s3.csi.aws.com + names: + kind: MountpointS3PodAttachment + listKind: MountpointS3PodAttachmentList + plural: mountpoints3podattachments + shortNames: + - s3pa + singular: mountpoints3podattachment + scope: Cluster + versions: + - name: v1 + schema: + openAPIV3Schema: + description: MountpointS3PodAttachment is the Schema for the mountpoints3podattachments + API. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: MountpointS3PodAttachmentSpec defines the desired state of + MountpointS3PodAttachment. + properties: + authenticationSource: + description: Authentication source taken from volume attribute field + `authenticationSource`. + type: string + mountOptions: + description: Comma separated mount options taken from volume. + type: string + mountpointS3PodToWorkloadPodUIDs: + additionalProperties: + items: + type: string + type: array + description: Maps each Mountpoint S3 pod name to the list of workload + pod UIDs it is attached to. + type: object + nodeName: + description: Name of the node. + type: string + persistentVolumeName: + description: Name of the Persistent Volume. + type: string + volumeID: + description: Volume ID. + type: string + workloadFSGroup: + description: Workload pod's `fsGroup` from pod security context + type: string + workloadNamespace: + description: 'Workload pod''s namespace. Exists only if `authenticationSource: + pod`.' + type: string + workloadServiceAccountIAMRoleARN: + description: 'EKS IAM Role ARN from workload pod''s service account + annotation (IRSA). Exists only if `authenticationSource: pod` and + service account has `eks.amazonaws.com/role-arn` annotation.' + type: string + workloadServiceAccountName: + description: 'Workload pod''s service account name. Exists only if + `authenticationSource: pod`.' + type: string + required: + - authenticationSource + - mountOptions + - mountpointS3PodToWorkloadPodUIDs + - nodeName + - persistentVolumeName + - volumeID + - workloadFSGroup + type: object + type: object + selectableFields: + - jsonPath: .spec.nodeName + served: true + storage: true + subresources: + status: {} +{{- end -}} diff --git a/tests/controller/suite_test.go b/tests/controller/suite_test.go index 4b090f2d..f130ea35 100644 --- a/tests/controller/suite_test.go +++ b/tests/controller/suite_test.go @@ -73,7 +73,7 @@ var _ = BeforeSuite(func() { crdv1.AddToScheme(scheme.Scheme) testEnv = &envtest.Environment{ CRDInstallOptions: envtest.CRDInstallOptions{ - Paths: []string{"../../charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml"}, + Paths: []string{"../crd/mountpoints3podattachments-crd.yaml"}, }, ErrorIfCRDPathMissing: true, } diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml b/tests/crd/mountpoints3podattachments-crd.yaml similarity index 98% rename from charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml rename to tests/crd/mountpoints3podattachments-crd.yaml index 16621492..c7f35fb9 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/s3.csi.aws.com_mountpoints3podattachments.yaml +++ b/tests/crd/mountpoints3podattachments-crd.yaml @@ -1,3 +1,4 @@ +# Auto-generated file via `make generate`. Do not edit. --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition From 1e1e761bab8d120a9ed2def590935ac299c4bcae Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Mon, 28 Apr 2025 16:51:04 +0100 Subject: [PATCH 03/24] Make Expectations private, fix comments --- .../templates/serviceaccount-csi-node.yaml | 2 +- .../csicontroller/expectations.go | 24 +++++++++---------- .../csicontroller/reconciler.go | 12 +++++----- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml index 96606310..5c68a202 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml @@ -23,7 +23,7 @@ metadata: app.kubernetes.io/name: aws-mountpoint-s3-csi-driver rules: - apiGroups: [""] - resources: ["serviceaccounts"] # TODO: Remove once we stop supporting systemd mounts. + resources: ["serviceaccounts"] # TODO: Remove once we stop supporting systemd mounts because in PodMounter we get IRSA Role ARN from MountpointS3PodAttachment verbs: ["get"] - apiGroups: ["s3.csi.aws.com"] resources: ["mountpoints3podattachments"] diff --git a/cmd/aws-s3-csi-controller/csicontroller/expectations.go b/cmd/aws-s3-csi-controller/csicontroller/expectations.go index 1fbb8967..a444613a 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/expectations.go +++ b/cmd/aws-s3-csi-controller/csicontroller/expectations.go @@ -8,36 +8,36 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -// Expectations is a structure that manages pending expectations for Kubernetes resources. +// expectations is a structure that manages pending expectations for Kubernetes resources. // It uses field filters as keys to track resources that are expected to be created -// helping to reduce unnecessary processing and API server load. -type Expectations struct { +// helping with eventual consistency, reducing unnecessary processing and API server load. +type expectations struct { pending sync.Map } -// NewExpectations creates and returns a new Expectations instance. -func NewExpectations() *Expectations { - return &Expectations{} +// newExpectations creates and returns a new Expectations instance. +func newExpectations() *expectations { + return &expectations{} } -// SetPending marks a resource as pending based on the given field filters. +// setPending marks a resource as pending based on the given field filters. // This is typically used when a create operation is initiated. -func (e *Expectations) SetPending(fieldFilters client.MatchingFields) { +func (e *expectations) setPending(fieldFilters client.MatchingFields) { key := deriveExpectationKeyFromFilters(fieldFilters) e.pending.Store(key, struct{}{}) } -// IsPending checks if a resource is marked as pending based on the given field filters. +// isPending checks if a resource is marked as pending based on the given field filters. // Returns true if the resource is pending, false otherwise. -func (e *Expectations) IsPending(fieldFilters client.MatchingFields) bool { +func (e *expectations) isPending(fieldFilters client.MatchingFields) bool { key := deriveExpectationKeyFromFilters(fieldFilters) _, ok := e.pending.Load(key) return ok } -// Clear removes the pending mark for a resource based on the given field filters. +// clear removes the pending mark for a resource based on the given field filters. // This is typically called when an expected operation has been confirmed as completed. -func (e *Expectations) Clear(fieldFilters client.MatchingFields) { +func (e *expectations) clear(fieldFilters client.MatchingFields) { key := deriveExpectationKeyFromFilters(fieldFilters) e.pending.Delete(key) } diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index 7483abe1..bc849e3a 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -36,7 +36,7 @@ const ( type Reconciler struct { mountpointPodConfig mppod.Config mountpointPodCreator *mppod.Creator - s3paExpectations *Expectations + s3paExpectations *expectations client.Client } @@ -44,7 +44,7 @@ type Reconciler struct { // NewReconciler returns a new reconciler created from `client` and `podConfig`. func NewReconciler(client client.Client, podConfig mppod.Config) *Reconciler { creator := mppod.NewCreator(podConfig) - return &Reconciler{Client: client, mountpointPodConfig: podConfig, mountpointPodCreator: creator, s3paExpectations: NewExpectations()} + return &Reconciler{Client: client, mountpointPodConfig: podConfig, mountpointPodCreator: creator, s3paExpectations: newExpectations()} } // SetupWithManager configures reconciler to run with given `mgr`. @@ -284,9 +284,9 @@ func (r *Reconciler) handleInactivePod(ctx context.Context, workloadPod *corev1. func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3paList *crdv1.MountpointS3PodAttachmentList, workloadUID string, fieldFilters client.MatchingFields, log logr.Logger) (bool, error) { s3pa := &s3paList.Items[0] - if r.s3paExpectations.IsPending(fieldFilters) { + if r.s3paExpectations.isPending(fieldFilters) { log.Info("MountpointS3PodAttachment creation is pending, removing from pending") - r.s3paExpectations.Clear(fieldFilters) + r.s3paExpectations.clear(fieldFilters) } if s3paContainsWorkload(s3pa, workloadUID) { @@ -373,7 +373,7 @@ func (r *Reconciler) handleNewS3PodAttachment( fieldFilters client.MatchingFields, log logr.Logger, ) (bool, error) { - if r.s3paExpectations.IsPending(fieldFilters) { + if r.s3paExpectations.isPending(fieldFilters) { log.Info("MountpointS3PodAttachment creation is pending, requeueing") return true, nil } @@ -382,7 +382,7 @@ func (r *Reconciler) handleNewS3PodAttachment( return false, err } - r.s3paExpectations.SetPending(fieldFilters) + r.s3paExpectations.setPending(fieldFilters) return true, nil } From 8a53b098d422a24bd9d6e06e09fac96ab5b32395 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Mon, 28 Apr 2025 18:13:43 +0100 Subject: [PATCH 04/24] Reconcile fixes --- .../csicontroller/reconciler.go | 185 ++++++++++++------ 1 file changed, 122 insertions(+), 63 deletions(-) diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index bc849e3a..e23e3da0 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -32,6 +32,11 @@ const ( LabelCSIDriverVersion = "s3.csi.aws.com/created-by-csi-driver-version" ) +const ( + Requeue = true + DontRequeue = false +) + // A Reconciler reconciles Mountpoint Pods by watching other workload Pods thats using S3 CSI Driver. type Reconciler struct { mountpointPodConfig mppod.Config @@ -181,34 +186,46 @@ func (r *Reconciler) spawnOrDeleteMountpointPodIfNeeded( workloadUID := string(workloadPod.UID) roleArn, err := r.findIRSAServiceAccountRole(ctx, workloadPod) if err != nil { - return false, err + return Requeue, err } fieldFilters := r.buildFieldFilters(workloadPod, pv, roleArn) - log := r.setupLogger(ctx, workloadPod, pvc, pv, workloadUID, fieldFilters) - - s3paList, err := r.getExistingS3PodAttachments(ctx, fieldFilters) + s3pa, err := r.getExistingS3PodAttachment(ctx, fieldFilters) if err != nil { - return false, err + return Requeue, err } + log := r.setupLogger(ctx, workloadPod, pvc, workloadUID, fieldFilters, s3pa) if !isPodActive(workloadPod) { - return r.handleInactivePod(ctx, workloadPod, s3paList, workloadUID, log) + return r.handleInactivePod(ctx, s3pa, workloadUID, log) } - if len(s3paList.Items) == 1 { - return r.handleExistingS3PodAttachment(ctx, s3paList, workloadUID, fieldFilters, log) + if s3pa != nil { + return r.handleExistingS3PodAttachment(ctx, s3pa, workloadUID, fieldFilters, log) + } else { + return r.handleNewS3PodAttachment(ctx, workloadPod, pv, roleArn, fieldFilters, log) } - - return r.handleNewS3PodAttachment(ctx, workloadPod, pv, fieldFilters, log) } -func (r *Reconciler) setupLogger(ctx context.Context, workloadPod *corev1.Pod, pvc *corev1.PersistentVolumeClaim, pv *corev1.PersistentVolume, workloadUID string, fieldFilters client.MatchingFields) logr.Logger { +// setupLogger creates and configures logger that includes pod namespace/name, PVC name, and workload UID fields. +// If an S3PodAttachment is provided, its name is added. All fieldFilters are appended as additional key-value pairs. +func (r *Reconciler) setupLogger( + ctx context.Context, + workloadPod *corev1.Pod, + pvc *corev1.PersistentVolumeClaim, + workloadUID string, + fieldFilters client.MatchingFields, + s3pa *crdv1.MountpointS3PodAttachment, +) logr.Logger { logger := logf.FromContext(ctx).WithValues( "workloadPod", types.NamespacedName{Namespace: workloadPod.Namespace, Name: workloadPod.Name}, "pvc", pvc.Name, "workloadUID", workloadUID, ) + if s3pa != nil { + logger = logger.WithValues("s3pa", s3pa.Name) + } + var keyValues []interface{} for k, v := range fieldFilters { keyValues = append(keyValues, k, v) @@ -221,6 +238,7 @@ func (r *Reconciler) setupLogger(ctx context.Context, workloadPod *corev1.Pod, p return logger } +// buildFieldFilters build appropriate matching field filters for List operation on MountpointS3PodAttachments func (r *Reconciler) buildFieldFilters(workloadPod *corev1.Pod, pv *corev1.PersistentVolume, roleArn string) client.MatchingFields { authSource := r.getAuthSource(pv) fsGroup := r.getFSGroup(workloadPod) @@ -243,6 +261,8 @@ func (r *Reconciler) buildFieldFilters(workloadPod *corev1.Pod, pv *corev1.Persi return fieldFilters } +// getAuthSource returns authentication source from given PV. +// Defaults to `driver` if `authenticationSource` is not found in volume attributes. func (r *Reconciler) getAuthSource(pv *corev1.PersistentVolume) string { volumeAttributes := mppod.ExtractVolumeAttributes(pv) authSource := volumeAttributes[volumecontext.AuthenticationSource] @@ -252,6 +272,8 @@ func (r *Reconciler) getAuthSource(pv *corev1.PersistentVolume) string { return authSource } +// getFSGroup returns the FSGroup value from the pod's security context as a string. +// If FSGroup is not set, it returns an empty string. func (r *Reconciler) getFSGroup(workloadPod *corev1.Pod) string { if workloadPod.Spec.SecurityContext.FSGroup != nil { return strconv.FormatInt(*workloadPod.Spec.SecurityContext.FSGroup, 10) @@ -259,31 +281,39 @@ func (r *Reconciler) getFSGroup(workloadPod *corev1.Pod) string { return "" } -func (r *Reconciler) getExistingS3PodAttachments(ctx context.Context, fieldFilters client.MatchingFields) (*crdv1.MountpointS3PodAttachmentList, error) { +// getExistingS3PodAttachment retrieves a MountpointS3PodAttachment resource that matches the provided field filters. +// It returns: +// - The matching MountpointS3PodAttachment if exactly one is found +// - nil if no matching resource is found +// - An error if multiple matching resources are found or if the list operation fails +func (r *Reconciler) getExistingS3PodAttachment(ctx context.Context, fieldFilters client.MatchingFields) (*crdv1.MountpointS3PodAttachment, error) { s3paList := &crdv1.MountpointS3PodAttachmentList{} if err := r.List(ctx, s3paList, fieldFilters); err != nil { - return nil, err + return nil, fmt.Errorf("failed to list MountpointS3PodAttachments: %w", err) } - if len(s3paList.Items) > 1 { - return nil, fmt.Errorf("found %d MountpointS3PodAttachments instead of 1", len(s3paList.Items)) + switch len(s3paList.Items) { + case 0: + return nil, nil + case 1: + return &s3paList.Items[0], nil + default: + return nil, fmt.Errorf("found %d MountpointS3PodAttachments when expecting 0 or 1", len(s3paList.Items)) } - - return s3paList, nil } -func (r *Reconciler) handleInactivePod(ctx context.Context, workloadPod *corev1.Pod, s3paList *crdv1.MountpointS3PodAttachmentList, workloadUID string, log logr.Logger) (bool, error) { - if len(s3paList.Items) != 1 { +// handleInactivePod handles inactive workload pod. +func (r *Reconciler) handleInactivePod(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { + if s3pa == nil { log.Info("Workload pod is not active. Did not find any MountpointS3PodAttachments.") - return false, nil + return DontRequeue, nil } - return r.removeWorkloadFromS3PodAttachment(ctx, &s3paList.Items[0], workloadUID, log) + return r.removeWorkloadFromS3PodAttachment(ctx, s3pa, workloadUID, log) } -func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3paList *crdv1.MountpointS3PodAttachmentList, workloadUID string, fieldFilters client.MatchingFields, log logr.Logger) (bool, error) { - s3pa := &s3paList.Items[0] - +// handleExistingS3PodAttachment handles existing S3 Pod Attachment. +func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, fieldFilters client.MatchingFields, log logr.Logger) (bool, error) { if r.s3paExpectations.isPending(fieldFilters) { log.Info("MountpointS3PodAttachment creation is pending, removing from pending") r.s3paExpectations.clear(fieldFilters) @@ -291,14 +321,16 @@ func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3paList if s3paContainsWorkload(s3pa, workloadUID) { log.Info("MountpointS3PodAttachment already has this workload UID") - return false, nil + return DontRequeue, nil } return r.addWorkloadToS3PodAttachment(ctx, s3pa, workloadUID, log) } +// addWorkloadToS3PodAttachment adds workload UID to the first Mountpoint Pod in the map +// TODO: We will later add extra logic for selecting/creating MPPod if existing MP Pods are using old CSI Driver version or have some "no-new-attachments" annotation func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { - log.Info("Adding workload UID to MountpointS3PodAttachment", "workloadUID", workloadUID) + log.Info("Adding workload UID to MountpointS3PodAttachment") for key := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[key] = append(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[key], workloadUID) @@ -306,14 +338,20 @@ func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crd } err := r.Update(ctx, s3pa) - if apierrors.IsConflict(err) { - log.Info("Failed to update MountpointS3PodAttachment - resource conflict - requeue", "workloadUID", workloadUID) - return true, nil + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to update MountpointS3PodAttachment - resource conflict - requeue") + return Requeue, nil + } + log.Error(err, "Failed to update MountpointS3PodAttachment") + return Requeue, err } - return false, nil + return DontRequeue, nil } +// removeWorkloadFromS3PodAttachment removes workload UID from MountpointS3PodAttachment map. +// It will delete MountpointS3PodAttachment if map becomes empty. func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { // Remove workload UID from mountpoint pods for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { @@ -329,9 +367,13 @@ func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa if found { s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[mpPodName] = filteredUIDs err := r.Update(ctx, s3pa) - if apierrors.IsConflict(err) { - log.Info("Failed to remove workload pod UID from existing MountpointS3PodAttachment due to resource conflict, requeueing") - return true, nil + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to remove workload pod UID from existing MountpointS3PodAttachment due to resource conflict, requeueing") + return Requeue, nil + } + log.Error(err, "Failed to update MountpointS3PodAttachment") + return Requeue, err } log.Info("Successfully removed workload pod UID from MountpointS3PodAttachment") break @@ -345,10 +387,14 @@ func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa "mountpointPodName", mpPodName) delete(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs, mpPodName) err := r.Update(ctx, s3pa) - if apierrors.IsConflict(err) { - log.Info("Failed to remove Mountpoint pod from MountpointS3PodAttachment due to resource conflict, requeueing", - "mountpointPodName", mpPodName) - return true, nil + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to remove Mountpoint pod from MountpointS3PodAttachment due to resource conflict, requeueing", + "mountpointPodName", mpPodName) + return Requeue, nil + } + log.Error(err, "Failed to update MountpointS3PodAttachment") + return Requeue, err } } } @@ -357,52 +403,55 @@ func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa if len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs) == 0 { log.Info("MountpointS3PodAttachment has zero Mountpoint Pods. Will delete it") err := r.Delete(ctx, s3pa) - if apierrors.IsConflict(err) { - log.Info("Failed to delete MountpointS3PodAttachment due to resource conflict, requeueing") - return true, nil + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to delete MountpointS3PodAttachment due to resource conflict, requeueing") + return Requeue, nil + } + log.Error(err, "Failed to delete MountpointS3PodAttachment") + return Requeue, err } } - return false, nil + return DontRequeue, nil } +// handleNewS3PodAttachment handles new S3 pod attachment in case none were found. func (r *Reconciler) handleNewS3PodAttachment( ctx context.Context, workloadPod *corev1.Pod, pv *corev1.PersistentVolume, + roleArn string, fieldFilters client.MatchingFields, log logr.Logger, ) (bool, error) { if r.s3paExpectations.isPending(fieldFilters) { log.Info("MountpointS3PodAttachment creation is pending, requeueing") - return true, nil + return Requeue, nil } - if err := r.createS3PodAttachmentWithMPPod(ctx, workloadPod, pv, log); err != nil { - return false, err + if err := r.createS3PodAttachmentWithMPPod(ctx, workloadPod, pv, roleArn, log); err != nil { + return Requeue, err } r.s3paExpectations.setPending(fieldFilters) - return true, nil + return Requeue, nil } +// createS3PodAttachmentWithMPPod creates new MountpointS3PodAttachment resource and Mountpoint Pod for given workload and PV. func (r *Reconciler) createS3PodAttachmentWithMPPod( ctx context.Context, workloadPod *corev1.Pod, pv *corev1.PersistentVolume, + roleArn string, log logr.Logger, ) error { authSource := r.getAuthSource(pv) - mpPodName, err := r.spawnMountpointPod(ctx, workloadPod, pv, log) + mpPod, err := r.spawnMountpointPod(ctx, workloadPod, pv, log) if err != nil { log.Error(err, "Failed to spawn Mountpoint Pod") return err } - - fsGroup := "" - if workloadPod.Spec.SecurityContext.FSGroup != nil { - fsGroup = strconv.FormatInt(*workloadPod.Spec.SecurityContext.FSGroup, 10) - } s3pa := &crdv1.MountpointS3PodAttachment{ ObjectMeta: metav1.ObjectMeta{ GenerateName: "s3pa-", @@ -415,31 +464,31 @@ func (r *Reconciler) createS3PodAttachmentWithMPPod( PersistentVolumeName: pv.Name, VolumeID: pv.Spec.CSI.VolumeHandle, MountOptions: strings.Join(pv.Spec.MountOptions, ","), - WorkloadFSGroup: fsGroup, + WorkloadFSGroup: r.getFSGroup(workloadPod), AuthenticationSource: authSource, MountpointS3PodToWorkloadPodUIDs: map[string][]string{ - mpPodName: {string(workloadPod.UID)}, + mpPod.Name: {string(workloadPod.UID)}, }, }, } if authSource == "pod" { s3pa.Spec.WorkloadNamespace = workloadPod.Namespace s3pa.Spec.WorkloadServiceAccountName = getServiceAccountName(workloadPod) - - roleARN, err := r.findIRSAServiceAccountRole(ctx, workloadPod) - if err != nil { - return err - } - s3pa.Spec.WorkloadServiceAccountIAMRoleARN = roleARN + s3pa.Spec.WorkloadServiceAccountIAMRoleARN = roleArn } err = r.Create(ctx, s3pa) if err != nil { log.Error(err, "Failed to create MountpointS3PodAttachment") + if deleteErr := r.Delete(ctx, mpPod); deleteErr != nil { + log.Error(deleteErr, "Failed to cleanup Mountpoint Pod after MountpointS3PodAttachment creation failure", "mountpointPodName", mpPod.Name) + } else { + log.Info("Successfully cleaned up Mountpoint Pod after S3PodAttachment creation failure", "mountpointPodName", mpPod.Name) + } return err } - log.Info("MountpointS3PodAttachment is created", "s3paName", s3pa.Name) + log.Info("MountpointS3PodAttachment is created", "s3pa", s3pa.Name) return nil } @@ -451,18 +500,18 @@ func (r *Reconciler) spawnMountpointPod( workloadPod *corev1.Pod, pv *corev1.PersistentVolume, log logr.Logger, -) (string, error) { +) (*corev1.Pod, error) { log.Info("Spawning Mountpoint Pod") mpPod := r.mountpointPodCreator.Create(workloadPod.Spec.NodeName, pv) err := r.Create(ctx, mpPod) if err != nil { - return "", err + return nil, err } log.Info("Mountpoint Pod spawned", "mountpointPodName", mpPod.Name) - return mpPod.Name, nil + return mpPod, nil } // deleteMountpointPod deletes given `mountpointPod`. @@ -528,6 +577,16 @@ func (r *Reconciler) getBoundPVForPodClaim( return pvc, pv, nil } +// findIRSAServiceAccountRole retrieves the IAM role ARN associated with a pod's service account +// through IRSA (IAM Roles for Service Accounts) annotation ("eks.amazonaws.com/role-arn"). +// +// Parameters: +// - ctx: Context for the request +// - pod: The Kubernetes pod whose service account role should be retrieved +// +// Returns: +// - string: The IAM role ARN from the service account's annotation, empty string if not found +// - error: Error if the service account cannot be retrieved func (r *Reconciler) findIRSAServiceAccountRole(ctx context.Context, pod *corev1.Pod) (string, error) { sa := &corev1.ServiceAccount{} err := r.Get(ctx, types.NamespacedName{Namespace: pod.Namespace, Name: getServiceAccountName(pod)}, sa) From 75ec928572cd5431a38c912c9bf6ca85dfc67f79 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Mon, 28 Apr 2025 18:26:20 +0100 Subject: [PATCH 05/24] Rename CRD version from v1 to v1beta --- .../mountpoints3podattachments-crd.yaml | 2 +- .../csicontroller/reconciler.go | 40 +++++++++---------- cmd/aws-s3-csi-controller/main.go | 31 +++++++------- pkg/api/{v1 => v1beta}/groupversion_info.go | 6 +-- .../mountpoints3podattachment_types.go | 2 +- .../{v1 => v1beta}/zz_generated.deepcopy.go | 2 +- pkg/driver/driver.go | 32 +++++++-------- pkg/driver/node/mounter/pod_mounter.go | 22 +++++----- pkg/driver/node/mounter/pod_unmounter.go | 10 ++--- tests/controller/controller_test.go | 22 +++++----- tests/controller/suite_test.go | 28 ++++++------- tests/crd/mountpoints3podattachments-crd.yaml | 2 +- .../e2e-kubernetes/testsuites/pod_sharing.go | 16 ++++---- 13 files changed, 109 insertions(+), 106 deletions(-) rename pkg/api/{v1 => v1beta}/groupversion_info.go (82%) rename pkg/api/{v1 => v1beta}/mountpoints3podattachment_types.go (99%) rename pkg/api/{v1 => v1beta}/zz_generated.deepcopy.go (99%) diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml index b5de87e8..31780b01 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml @@ -18,7 +18,7 @@ spec: singular: mountpoints3podattachment scope: Cluster versions: - - name: v1 + - name: v1beta schema: openAPIV3Schema: description: MountpointS3PodAttachment is the Schema for the mountpoints3podattachments diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index e23e3da0..7faaae45 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -16,7 +16,7 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/reconcile" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" "github.com/go-logr/logr" @@ -214,7 +214,7 @@ func (r *Reconciler) setupLogger( pvc *corev1.PersistentVolumeClaim, workloadUID string, fieldFilters client.MatchingFields, - s3pa *crdv1.MountpointS3PodAttachment, + s3pa *crdv1beta.MountpointS3PodAttachment, ) logr.Logger { logger := logf.FromContext(ctx).WithValues( "workloadPod", types.NamespacedName{Namespace: workloadPod.Namespace, Name: workloadPod.Name}, @@ -244,18 +244,18 @@ func (r *Reconciler) buildFieldFilters(workloadPod *corev1.Pod, pv *corev1.Persi fsGroup := r.getFSGroup(workloadPod) fieldFilters := client.MatchingFields{ - crdv1.FieldNodeName: workloadPod.Spec.NodeName, - crdv1.FieldPersistentVolumeName: pv.Name, - crdv1.FieldVolumeID: pv.Spec.CSI.VolumeHandle, - crdv1.FieldMountOptions: strings.Join(pv.Spec.MountOptions, ","), - crdv1.FieldWorkloadFSGroup: fsGroup, - crdv1.FieldAuthenticationSource: authSource, + crdv1beta.FieldNodeName: workloadPod.Spec.NodeName, + crdv1beta.FieldPersistentVolumeName: pv.Name, + crdv1beta.FieldVolumeID: pv.Spec.CSI.VolumeHandle, + crdv1beta.FieldMountOptions: strings.Join(pv.Spec.MountOptions, ","), + crdv1beta.FieldWorkloadFSGroup: fsGroup, + crdv1beta.FieldAuthenticationSource: authSource, } if authSource == "pod" { - fieldFilters[crdv1.FieldWorkloadNamespace] = workloadPod.Namespace - fieldFilters[crdv1.FieldWorkloadServiceAccountName] = getServiceAccountName(workloadPod) - fieldFilters[crdv1.FieldWorkloadServiceAccountIAMRoleARN] = roleArn + fieldFilters[crdv1beta.FieldWorkloadNamespace] = workloadPod.Namespace + fieldFilters[crdv1beta.FieldWorkloadServiceAccountName] = getServiceAccountName(workloadPod) + fieldFilters[crdv1beta.FieldWorkloadServiceAccountIAMRoleARN] = roleArn } return fieldFilters @@ -286,8 +286,8 @@ func (r *Reconciler) getFSGroup(workloadPod *corev1.Pod) string { // - The matching MountpointS3PodAttachment if exactly one is found // - nil if no matching resource is found // - An error if multiple matching resources are found or if the list operation fails -func (r *Reconciler) getExistingS3PodAttachment(ctx context.Context, fieldFilters client.MatchingFields) (*crdv1.MountpointS3PodAttachment, error) { - s3paList := &crdv1.MountpointS3PodAttachmentList{} +func (r *Reconciler) getExistingS3PodAttachment(ctx context.Context, fieldFilters client.MatchingFields) (*crdv1beta.MountpointS3PodAttachment, error) { + s3paList := &crdv1beta.MountpointS3PodAttachmentList{} if err := r.List(ctx, s3paList, fieldFilters); err != nil { return nil, fmt.Errorf("failed to list MountpointS3PodAttachments: %w", err) } @@ -303,7 +303,7 @@ func (r *Reconciler) getExistingS3PodAttachment(ctx context.Context, fieldFilter } // handleInactivePod handles inactive workload pod. -func (r *Reconciler) handleInactivePod(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { +func (r *Reconciler) handleInactivePod(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { if s3pa == nil { log.Info("Workload pod is not active. Did not find any MountpointS3PodAttachments.") return DontRequeue, nil @@ -313,7 +313,7 @@ func (r *Reconciler) handleInactivePod(ctx context.Context, s3pa *crdv1.Mountpoi } // handleExistingS3PodAttachment handles existing S3 Pod Attachment. -func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, fieldFilters client.MatchingFields, log logr.Logger) (bool, error) { +func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, fieldFilters client.MatchingFields, log logr.Logger) (bool, error) { if r.s3paExpectations.isPending(fieldFilters) { log.Info("MountpointS3PodAttachment creation is pending, removing from pending") r.s3paExpectations.clear(fieldFilters) @@ -329,7 +329,7 @@ func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3pa *cr // addWorkloadToS3PodAttachment adds workload UID to the first Mountpoint Pod in the map // TODO: We will later add extra logic for selecting/creating MPPod if existing MP Pods are using old CSI Driver version or have some "no-new-attachments" annotation -func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { +func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { log.Info("Adding workload UID to MountpointS3PodAttachment") for key := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { @@ -352,7 +352,7 @@ func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crd // removeWorkloadFromS3PodAttachment removes workload UID from MountpointS3PodAttachment map. // It will delete MountpointS3PodAttachment if map becomes empty. -func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa *crdv1.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { +func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { // Remove workload UID from mountpoint pods for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { filteredUIDs := []string{} @@ -452,14 +452,14 @@ func (r *Reconciler) createS3PodAttachmentWithMPPod( log.Error(err, "Failed to spawn Mountpoint Pod") return err } - s3pa := &crdv1.MountpointS3PodAttachment{ + s3pa := &crdv1beta.MountpointS3PodAttachment{ ObjectMeta: metav1.ObjectMeta{ GenerateName: "s3pa-", Labels: map[string]string{ LabelCSIDriverVersion: r.mountpointPodConfig.CSIDriverVersion, }, }, - Spec: crdv1.MountpointS3PodAttachmentSpec{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ NodeName: workloadPod.Spec.NodeName, PersistentVolumeName: pv.Name, VolumeID: pv.Spec.CSI.VolumeHandle, @@ -622,7 +622,7 @@ func isPodActive(p *corev1.Pod) bool { p.DeletionTimestamp == nil } -func s3paContainsWorkload(s3pa *crdv1.MountpointS3PodAttachment, workloadUID string) bool { +func s3paContainsWorkload(s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string) bool { for _, workloads := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { for _, workload := range workloads { if workload == workloadUID { diff --git a/cmd/aws-s3-csi-controller/main.go b/cmd/aws-s3-csi-controller/main.go index 28ab23bc..eb0c5c3c 100644 --- a/cmd/aws-s3-csi-controller/main.go +++ b/cmd/aws-s3-csi-controller/main.go @@ -24,7 +24,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager/signals" "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" @@ -44,7 +44,7 @@ var ( func init() { utilruntime.Must(clientgoscheme.AddToScheme(scheme)) - utilruntime.Must(crdv1.AddToScheme(scheme)) + utilruntime.Must(crdv1beta.AddToScheme(scheme)) } func main() { @@ -88,21 +88,24 @@ func main() { } } +// IndexMountpointS3PodAttachmentFields adds internal index on fields for our custom resource. +// This is needed for `List()` method to work with field filters. func IndexMountpointS3PodAttachmentFields(log logr.Logger, mgr manager.Manager) { - indexField(log, mgr, crdv1.FieldNodeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) - indexField(log, mgr, crdv1.FieldPersistentVolumeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) - indexField(log, mgr, crdv1.FieldVolumeID, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) - indexField(log, mgr, crdv1.FieldMountOptions, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) - indexField(log, mgr, crdv1.FieldAuthenticationSource, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) - indexField(log, mgr, crdv1.FieldWorkloadFSGroup, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) - indexField(log, mgr, crdv1.FieldWorkloadServiceAccountName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) - indexField(log, mgr, crdv1.FieldWorkloadNamespace, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) - indexField(log, mgr, crdv1.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) + indexField(log, mgr, crdv1beta.FieldNodeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) + indexField(log, mgr, crdv1beta.FieldPersistentVolumeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) + indexField(log, mgr, crdv1beta.FieldVolumeID, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) + indexField(log, mgr, crdv1beta.FieldMountOptions, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) + indexField(log, mgr, crdv1beta.FieldAuthenticationSource, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) + indexField(log, mgr, crdv1beta.FieldWorkloadFSGroup, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) + indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) + indexField(log, mgr, crdv1beta.FieldWorkloadNamespace, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) + indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) } -func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1.MountpointS3PodAttachment) string) { - err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { - return []string{extractor(obj.(*crdv1.MountpointS3PodAttachment))} +// indexField adds index on a field. +func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1beta.MountpointS3PodAttachment) string) { + err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1beta.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { + return []string{extractor(obj.(*crdv1beta.MountpointS3PodAttachment))} }) if err != nil { log.Error(err, fmt.Sprintf("Failed to create a %s field indexer", field)) diff --git a/pkg/api/v1/groupversion_info.go b/pkg/api/v1beta/groupversion_info.go similarity index 82% rename from pkg/api/v1/groupversion_info.go rename to pkg/api/v1beta/groupversion_info.go index 8c6d8b72..1efe5dfc 100644 --- a/pkg/api/v1/groupversion_info.go +++ b/pkg/api/v1beta/groupversion_info.go @@ -1,7 +1,7 @@ -// Package v1 contains API Schema definitions for the s3.csi.aws.com v1 API group. +// Package v1beta contains API Schema definitions for the s3.csi.aws.com v1beta API group. // +kubebuilder:object:generate=true // +groupName=s3.csi.aws.com -package v1 +package v1beta import ( "k8s.io/apimachinery/pkg/runtime/schema" @@ -10,7 +10,7 @@ import ( var ( // GroupVersion is group version used to register these objects. - GroupVersion = schema.GroupVersion{Group: "s3.csi.aws.com", Version: "v1"} + GroupVersion = schema.GroupVersion{Group: "s3.csi.aws.com", Version: "v1beta"} // SchemeBuilder is used to add go types to the GroupVersionKind scheme. SchemeBuilder = &scheme.Builder{GroupVersion: GroupVersion} diff --git a/pkg/api/v1/mountpoints3podattachment_types.go b/pkg/api/v1beta/mountpoints3podattachment_types.go similarity index 99% rename from pkg/api/v1/mountpoints3podattachment_types.go rename to pkg/api/v1beta/mountpoints3podattachment_types.go index cb623538..83ed93c1 100644 --- a/pkg/api/v1/mountpoints3podattachment_types.go +++ b/pkg/api/v1beta/mountpoints3podattachment_types.go @@ -1,4 +1,4 @@ -package v1 +package v1beta import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/pkg/api/v1/zz_generated.deepcopy.go b/pkg/api/v1beta/zz_generated.deepcopy.go similarity index 99% rename from pkg/api/v1/zz_generated.deepcopy.go rename to pkg/api/v1beta/zz_generated.deepcopy.go index 32041d4a..93c5d922 100644 --- a/pkg/api/v1/zz_generated.deepcopy.go +++ b/pkg/api/v1beta/zz_generated.deepcopy.go @@ -2,7 +2,7 @@ // Code generated by controller-gen. DO NOT EDIT. -package v1 +package v1beta import ( runtime "k8s.io/apimachinery/pkg/runtime" diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 5d0c776f..21be5bed 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -23,7 +23,7 @@ import ( "os" "time" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" @@ -62,7 +62,7 @@ var ( func init() { utilruntime.Must(clientgoscheme.AddToScheme(scheme)) - utilruntime.Must(crdv1.AddToScheme(scheme)) + utilruntime.Must(crdv1beta.AddToScheme(scheme)) } type Driver struct { @@ -119,7 +119,7 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error SyncPeriod: &podWatcherResyncPeriod, ReaderFailOnMissingInformer: true, ByObject: map[client.Object]ctrlcache.ByObject{ - &crdv1.MountpointS3PodAttachment{}: { + &crdv1beta.MountpointS3PodAttachment{}: { Field: fields.OneTermEqualSelector("spec.nodeName", nodeID), }, }, @@ -130,7 +130,7 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error indexMountpointS3PodAttachmentFields(s3paCache) - s3podAttachmentInformer, err := s3paCache.GetInformer(context.Background(), &crdv1.MountpointS3PodAttachment{}) + s3podAttachmentInformer, err := s3paCache.GetInformer(context.Background(), &crdv1beta.MountpointS3PodAttachment{}) if err != nil { klog.Fatalf("Failed to create informer for MountpointS3PodAttachment: %v\n", err) } @@ -244,20 +244,20 @@ func kubernetesVersion(clientset *kubernetes.Clientset) (string, error) { // TODO: This is duplicated multiple times func indexMountpointS3PodAttachmentFields(s3paCache ctrlcache.Cache) { - indexField(s3paCache, crdv1.FieldNodeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) - indexField(s3paCache, crdv1.FieldPersistentVolumeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) - indexField(s3paCache, crdv1.FieldVolumeID, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) - indexField(s3paCache, crdv1.FieldMountOptions, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) - indexField(s3paCache, crdv1.FieldAuthenticationSource, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) - indexField(s3paCache, crdv1.FieldWorkloadFSGroup, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) - indexField(s3paCache, crdv1.FieldWorkloadServiceAccountName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) - indexField(s3paCache, crdv1.FieldWorkloadNamespace, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) - indexField(s3paCache, crdv1.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) + indexField(s3paCache, crdv1beta.FieldNodeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) + indexField(s3paCache, crdv1beta.FieldPersistentVolumeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) + indexField(s3paCache, crdv1beta.FieldVolumeID, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) + indexField(s3paCache, crdv1beta.FieldMountOptions, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) + indexField(s3paCache, crdv1beta.FieldAuthenticationSource, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) + indexField(s3paCache, crdv1beta.FieldWorkloadFSGroup, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) + indexField(s3paCache, crdv1beta.FieldWorkloadServiceAccountName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) + indexField(s3paCache, crdv1beta.FieldWorkloadNamespace, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) + indexField(s3paCache, crdv1beta.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) } -func indexField(cache ctrlcache.Cache, field string, extractor func(*crdv1.MountpointS3PodAttachment) string) { - err := cache.IndexField(context.Background(), &crdv1.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { - return []string{extractor(obj.(*crdv1.MountpointS3PodAttachment))} +func indexField(cache ctrlcache.Cache, field string, extractor func(*crdv1beta.MountpointS3PodAttachment) string) { + err := cache.IndexField(context.Background(), &crdv1beta.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { + return []string{extractor(obj.(*crdv1beta.MountpointS3PodAttachment))} }) if err != nil { klog.Fatalf("Failed to create a %s field indexer: %v", field, err) diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index 85409466..9d546204 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -15,7 +15,7 @@ import ( "k8s.io/klog/v2" "k8s.io/mount-utils" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath" @@ -433,18 +433,18 @@ func (pm *PodMounter) helpMessageForGettingMountpointLogs(pod *corev1.Pod) strin return fmt.Sprintf("You can see Mountpoint logs by running: `kubectl logs -n %s %s`. If the Mountpoint Pod already restarted, you can also pass `--previous` to get logs from the previous run.", pod.Namespace, pod.Name) } -func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeName string, credentialCtx credentialprovider.ProvideContext, fsGroup, pvMountOptions string) (*crdv1.MountpointS3PodAttachment, string, error) { +func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeName string, credentialCtx credentialprovider.ProvideContext, fsGroup, pvMountOptions string) (*crdv1beta.MountpointS3PodAttachment, string, error) { fieldFilters := client.MatchingFields{ - crdv1.FieldNodeName: os.Getenv("CSI_NODE_NAME"), // TODO - crdv1.FieldPersistentVolumeName: volumeName, - crdv1.FieldVolumeID: credentialCtx.VolumeID, - crdv1.FieldMountOptions: pvMountOptions, - crdv1.FieldWorkloadFSGroup: fsGroup, - crdv1.FieldAuthenticationSource: credentialCtx.AuthenticationSource, + crdv1beta.FieldNodeName: os.Getenv("CSI_NODE_NAME"), // TODO + crdv1beta.FieldPersistentVolumeName: volumeName, + crdv1beta.FieldVolumeID: credentialCtx.VolumeID, + crdv1beta.FieldMountOptions: pvMountOptions, + crdv1beta.FieldWorkloadFSGroup: fsGroup, + crdv1beta.FieldAuthenticationSource: credentialCtx.AuthenticationSource, } if credentialCtx.AuthenticationSource == credentialprovider.AuthenticationSourcePod { - fieldFilters[crdv1.FieldWorkloadNamespace] = credentialCtx.PodNamespace - fieldFilters[crdv1.FieldWorkloadServiceAccountName] = credentialCtx.ServiceAccountName + fieldFilters[crdv1beta.FieldWorkloadNamespace] = credentialCtx.PodNamespace + fieldFilters[crdv1beta.FieldWorkloadServiceAccountName] = credentialCtx.ServiceAccountName } for { @@ -454,7 +454,7 @@ func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeNam default: } - s3paList := &crdv1.MountpointS3PodAttachmentList{} + s3paList := &crdv1beta.MountpointS3PodAttachmentList{} err := pm.s3paCache.List(ctx, s3paList, fieldFilters) if err != nil { klog.Errorf("Failed to list MountpointS3PodAttachments: %v", err) diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go index 1b75ee92..6ca86762 100644 --- a/pkg/driver/node/mounter/pod_unmounter.go +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -6,7 +6,7 @@ import ( "os" "path/filepath" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" @@ -42,7 +42,7 @@ func NewPodUnmounter( } func (u *PodUnmounter) HandleS3PodAttachmentUpdate(old, new any) { - s3pa := new.(*crdv1.MountpointS3PodAttachment) + s3pa := new.(*crdv1beta.MountpointS3PodAttachment) if s3pa.Spec.NodeName != u.nodeID { return } @@ -54,7 +54,7 @@ func (u *PodUnmounter) HandleS3PodAttachmentUpdate(old, new any) { } } -func (u *PodUnmounter) unmountSourceForPod(s3pa *crdv1.MountpointS3PodAttachment, mpPodName string) { +func (u *PodUnmounter) unmountSourceForPod(s3pa *crdv1beta.MountpointS3PodAttachment, mpPodName string) { klog.Infof("Found Mountpoint pod with zero workload pods, unmounting it - %s", mpPodName) mpPod, err := u.podWatcher.Get(mpPodName) if err != nil { @@ -103,7 +103,7 @@ func (u *PodUnmounter) unmountAndCleanup(source string) error { return nil } -func (u *PodUnmounter) cleanupCredentials(s3pa *crdv1.MountpointS3PodAttachment, mpPodUID, podPath, source string, mpPod *corev1.Pod) error { +func (u *PodUnmounter) cleanupCredentials(s3pa *crdv1beta.MountpointS3PodAttachment, mpPodUID, podPath, source string, mpPod *corev1.Pod) error { err := u.credProvider.Cleanup(credentialprovider.CleanupContext{ VolumeID: s3pa.Spec.VolumeID, PodID: mpPodUID, @@ -179,7 +179,7 @@ func (u *PodUnmounter) findPodByUID(mpPodUID string) (*corev1.Pod, error) { } func (u *PodUnmounter) checkForWorkloads(mpPod *corev1.Pod) (bool, error) { - s3paList := &crdv1.MountpointS3PodAttachmentList{} + s3paList := &crdv1beta.MountpointS3PodAttachmentList{} err := u.s3paCache.List(context.Background(), s3paList) if err != nil { return false, err diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index 5e3e5853..2b28f89b 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -18,7 +18,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" ) @@ -1054,7 +1054,7 @@ func waitForMountpointPodWithName(mpPodName string) *testPod { // expectNoS3PodAttachmentWithFields verifies that no MountpointS3PodAttachment matching specified fields exists within a time period func expectNoS3PodAttachmentWithFields(expectedFields map[string]string) { Consistently(func(g Gomega) { - list := &crdv1.MountpointS3PodAttachmentList{} + list := &crdv1beta.MountpointS3PodAttachmentList{} g.Expect(k8sClient.List(ctx, list)).To(Succeed()) for i := range list.Items { @@ -1067,7 +1067,7 @@ func expectNoS3PodAttachmentWithFields(expectedFields map[string]string) { } // expectNoPodUIDInS3PodAttachment validates that pod UID does not exist in MountpointS3PodToWorkloadPodUIDs map -func expectNoPodUIDInS3PodAttachment(s3pa *crdv1.MountpointS3PodAttachment, podUID string) { +func expectNoPodUIDInS3PodAttachment(s3pa *crdv1beta.MountpointS3PodAttachment, podUID string) { for _, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { for _, uid := range uids { if uid == podUID { @@ -1083,7 +1083,7 @@ func waitAndVerifyS3PodAttachmentAndMountpointPod( node string, vol *testVolume, pod *testPod, -) (*crdv1.MountpointS3PodAttachment, *testPod) { +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(node, vol.pv), "") Expect(len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) @@ -1097,7 +1097,7 @@ func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion( vol *testVolume, pod *testPod, minVersion string, -) (*crdv1.MountpointS3PodAttachment, *testPod) { +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv), minVersion) Expect(len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) @@ -1105,7 +1105,7 @@ func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion( } // waitAndVerifyMountpointPodFromPodAttachment waits and verifies Mountpoint Pod scheduled for given `s3pa`, `pod` and `vol.` -func waitAndVerifyMountpointPodFromPodAttachment(s3pa *crdv1.MountpointS3PodAttachment, pod *testPod, vol *testVolume) *testPod { +func waitAndVerifyMountpointPodFromPodAttachment(s3pa *crdv1beta.MountpointS3PodAttachment, pod *testPod, vol *testVolume) *testPod { // Find the mpPodName where pod.UID exists in the value slice var mpPodName string podUID := string(pod.UID) @@ -1175,12 +1175,12 @@ func waitForObject[Obj client.Object](obj Obj, verifiers ...func(Gomega, Obj)) { func waitForS3PodAttachmentWithFields( expectedFields map[string]string, minResourceVersion string, - verifiers ...func(Gomega, *crdv1.MountpointS3PodAttachment), -) *crdv1.MountpointS3PodAttachment { - var matchedCR *crdv1.MountpointS3PodAttachment + verifiers ...func(Gomega, *crdv1beta.MountpointS3PodAttachment), +) *crdv1beta.MountpointS3PodAttachment { + var matchedCR *crdv1beta.MountpointS3PodAttachment Eventually(func(g Gomega) { - list := &crdv1.MountpointS3PodAttachmentList{} + list := &crdv1beta.MountpointS3PodAttachmentList{} g.Expect(k8sClient.List(ctx, list)).To(Succeed()) for i := range list.Items { @@ -1214,7 +1214,7 @@ func waitForS3PodAttachmentWithFields( } // matchesSpec checks whether MountpointS3PodAttachmentSpec matches `expected` fields -func matchesSpec(spec crdv1.MountpointS3PodAttachmentSpec, expected map[string]string) bool { +func matchesSpec(spec crdv1beta.MountpointS3PodAttachmentSpec, expected map[string]string) bool { specValues := map[string]string{ "NodeName": spec.NodeName, "PersistentVolumeName": spec.PersistentVolumeName, diff --git a/tests/controller/suite_test.go b/tests/controller/suite_test.go index f130ea35..160a936f 100644 --- a/tests/controller/suite_test.go +++ b/tests/controller/suite_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/go-logr/logr" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -70,7 +70,7 @@ var _ = BeforeSuite(func() { By("Bootstrapping test environment") - crdv1.AddToScheme(scheme.Scheme) + crdv1beta.AddToScheme(scheme.Scheme) testEnv = &envtest.Environment{ CRDInstallOptions: envtest.CRDInstallOptions{ Paths: []string{"../crd/mountpoints3podattachments-crd.yaml"}, @@ -157,20 +157,20 @@ func createMountpointPriorityClass() { } func IndexMountpointS3PodAttachmentFields(log logr.Logger, mgr manager.Manager) { - indexField(log, mgr, crdv1.FieldNodeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) - indexField(log, mgr, crdv1.FieldPersistentVolumeName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) - indexField(log, mgr, crdv1.FieldVolumeID, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) - indexField(log, mgr, crdv1.FieldMountOptions, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) - indexField(log, mgr, crdv1.FieldAuthenticationSource, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) - indexField(log, mgr, crdv1.FieldWorkloadFSGroup, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) - indexField(log, mgr, crdv1.FieldWorkloadServiceAccountName, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) - indexField(log, mgr, crdv1.FieldWorkloadNamespace, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) - indexField(log, mgr, crdv1.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) + indexField(log, mgr, crdv1beta.FieldNodeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) + indexField(log, mgr, crdv1beta.FieldPersistentVolumeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) + indexField(log, mgr, crdv1beta.FieldVolumeID, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) + indexField(log, mgr, crdv1beta.FieldMountOptions, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) + indexField(log, mgr, crdv1beta.FieldAuthenticationSource, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) + indexField(log, mgr, crdv1beta.FieldWorkloadFSGroup, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) + indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) + indexField(log, mgr, crdv1beta.FieldWorkloadNamespace, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) + indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) } -func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1.MountpointS3PodAttachment) string) { - err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { - return []string{extractor(obj.(*crdv1.MountpointS3PodAttachment))} +func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1beta.MountpointS3PodAttachment) string) { + err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1beta.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { + return []string{extractor(obj.(*crdv1beta.MountpointS3PodAttachment))} }) if err != nil { log.Error(err, fmt.Sprintf("Failed to create a %s field indexer", field)) diff --git a/tests/crd/mountpoints3podattachments-crd.yaml b/tests/crd/mountpoints3podattachments-crd.yaml index c7f35fb9..91fcde59 100644 --- a/tests/crd/mountpoints3podattachments-crd.yaml +++ b/tests/crd/mountpoints3podattachments-crd.yaml @@ -17,7 +17,7 @@ spec: singular: mountpoints3podattachment scope: Cluster versions: - - name: v1 + - name: v1beta schema: openAPIV3Schema: description: MountpointS3PodAttachment is the Schema for the mountpoints3podattachments diff --git a/tests/e2e-kubernetes/testsuites/pod_sharing.go b/tests/e2e-kubernetes/testsuites/pod_sharing.go index bb82adc8..25e09a52 100644 --- a/tests/e2e-kubernetes/testsuites/pod_sharing.go +++ b/tests/e2e-kubernetes/testsuites/pod_sharing.go @@ -7,7 +7,7 @@ import ( "strings" "time" - crdv1 "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" v1 "k8s.io/api/core/v1" @@ -148,7 +148,7 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive } func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume) { - var s3paList *crdv1.MountpointS3PodAttachmentList + var s3paList *crdv1beta.MountpointS3PodAttachmentList framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) if err != nil { @@ -183,7 +183,7 @@ func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, p } func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume, expectedFieldsFunc func(pod *v1.Pod) map[string]string) { - var s3paList *crdv1.MountpointS3PodAttachmentList + var s3paList *crdv1beta.MountpointS3PodAttachmentList framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) if err != nil { @@ -233,13 +233,13 @@ func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Fra } // Convert UnstructuredList to MountpointS3PodAttachmentList -func convertToCustomResourceList(list *unstructured.UnstructuredList) (*crdv1.MountpointS3PodAttachmentList, error) { - crList := &crdv1.MountpointS3PodAttachmentList{ - Items: make([]crdv1.MountpointS3PodAttachment, 0, len(list.Items)), +func convertToCustomResourceList(list *unstructured.UnstructuredList) (*crdv1beta.MountpointS3PodAttachmentList, error) { + crList := &crdv1beta.MountpointS3PodAttachmentList{ + Items: make([]crdv1beta.MountpointS3PodAttachment, 0, len(list.Items)), } for _, item := range list.Items { - cr := &crdv1.MountpointS3PodAttachment{} + cr := &crdv1beta.MountpointS3PodAttachment{} err := runtime.DefaultUnstructuredConverter.FromUnstructured(item.Object, cr) if err != nil { return nil, fmt.Errorf("failed to convert item to MountpointS3PodAttachment: %v", err) @@ -251,7 +251,7 @@ func convertToCustomResourceList(list *unstructured.UnstructuredList) (*crdv1.Mo } // matchesSpec checks whether MountpointS3PodAttachmentSpec matches `expected` fields -func matchesSpec(spec crdv1.MountpointS3PodAttachmentSpec, expected map[string]string) bool { +func matchesSpec(spec crdv1beta.MountpointS3PodAttachmentSpec, expected map[string]string) bool { specValues := map[string]string{ "NodeName": spec.NodeName, "PersistentVolumeName": spec.PersistentVolumeName, From 19ebfa19c39512984831b17b9bbff0515dd0384a Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Mon, 28 Apr 2025 19:20:39 +0100 Subject: [PATCH 06/24] PodMounter fixes --- pkg/driver/driver.go | 2 +- pkg/driver/node/mounter/mppod_lock.go | 7 +- pkg/driver/node/mounter/pod_mounter.go | 173 ++++++++++---------- pkg/driver/node/mounter/pod_mounter_test.go | 2 +- 4 files changed, 99 insertions(+), 85 deletions(-) diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 21be5bed..9947acf6 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -153,7 +153,7 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error unmounter.CleanupDanglingMounts() - mounterImpl, err = mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mountUtil, nil, kubernetesVersion) + mounterImpl, err = mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mountUtil, nil, kubernetesVersion, nodeID) if err != nil { klog.Fatalln(err) } diff --git a/pkg/driver/node/mounter/mppod_lock.go b/pkg/driver/node/mounter/mppod_lock.go index e9b2f1fe..5bb6e5b9 100644 --- a/pkg/driver/node/mounter/mppod_lock.go +++ b/pkg/driver/node/mounter/mppod_lock.go @@ -1,6 +1,10 @@ package mounter -import "sync" +import ( + "sync" + + "k8s.io/klog/v2" +) // MPPodLock represents a reference-counted mutex lock for Mountpoint Pod. // It ensures synchronized access to pod-specific resources. @@ -44,6 +48,7 @@ func releaseMPPodLock(mpPodUID string) { lock, exists := mpPodLocks[mpPodUID] if !exists { // Should never happen + klog.Errorf("Attempted to release non-existent lock for Mountpoint Pod UID %s", mpPodUID) return } diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index 9d546204..10464a22 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -47,10 +47,11 @@ type PodMounter struct { bindMountSyscall bindMountSyscall kubernetesVersion string credProvider *credentialprovider.Provider + nodeID string } // NewPodMounter creates a new [PodMounter] with given Kubernetes client. -func NewPodMounter(podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvider *credentialprovider.Provider, mount mount.Interface, mountSyscall mountSyscall, kubernetesVersion string) (*PodMounter, error) { +func NewPodMounter(podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvider *credentialprovider.Provider, mount mount.Interface, mountSyscall mountSyscall, kubernetesVersion, nodeID string) (*PodMounter, error) { return &PodMounter{ podWatcher: podWatcher, s3paCache: s3paCache, @@ -59,6 +60,7 @@ func NewPodMounter(podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvi kubeletPath: util.KubeletPath(), mountSyscall: mountSyscall, kubernetesVersion: kubernetesVersion, + nodeID: nodeID, }, nil } @@ -106,19 +108,7 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin releaseMPPodLock(mpPodUID) }() - podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) - if err != nil { - klog.Errorf("Failed to create credentials directory for %q: %v", target, err) - return fmt.Errorf("Failed to create credentials directory for %q: %w", target, err) - } - - credentialCtx.SetWriteAndEnvPath(podCredentialsPath, mppod.PathInsideMountpointPod(mppod.KnownPathCredentials)) - - _, _, err = pm.credProvider.Provide(ctx, credentialCtx) - if err != nil { - klog.Errorf("Failed to provide credentials for %s: %v\n%s", target, err, "TODO: pm.helpMessageForGettingMountpointLogs(pod)") - return fmt.Errorf("Failed to provide credentials for %q: %w\n%s", target, err, "TODO: pm.helpMessageForGettingMountpointLogs(pod)") - } + pm.provideCredentials(ctx, podPath, credentialCtx) return nil } @@ -159,100 +149,102 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin return fmt.Errorf("Could not check if %q is already a mount point: %w", source, err) } - podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) - if err != nil { - klog.Errorf("Failed to create credentials directory for %q: %v", source, err) - return fmt.Errorf("Failed to create credentials directory for %q: %w", source, err) - } - - credentialCtx.SetWriteAndEnvPath(podCredentialsPath, mppod.PathInsideMountpointPod(mppod.KnownPathCredentials)) - // Note that this part happens before `isMountPoint` check, as we want to update credentials even though // there is an existing mount point at `target`. - credEnv, authenticationSource, err := pm.credProvider.Provide(ctx, credentialCtx) + credEnv, authenticationSource, err := pm.provideCredentials(ctx, podPath, credentialCtx) if err != nil { klog.Errorf("Failed to provide credentials for %s: %v\n%s", source, err, pm.helpMessageForGettingMountpointLogs(pod)) return fmt.Errorf("Failed to provide credentials for %q: %w\n%s", source, err, pm.helpMessageForGettingMountpointLogs(pod)) } if !isSourceMountPoint { - env := envprovider.Default() - env.Merge(credEnv) - - // Move `--aws-max-attempts` to env if provided - if maxAttempts, ok := args.Remove(mountpoint.ArgAWSMaxAttempts); ok { - env.Set(envprovider.EnvMaxAttempts, maxAttempts) + err = pm.mountS3AtSource(ctx, source, pod, podPath, bucketName, credEnv, authenticationSource, args) + if err != nil { + return fmt.Errorf("Failed to mount at source %s: %v", source, err) } + } - args.Set(mountpoint.ArgUserAgentPrefix, UserAgent(authenticationSource, pm.kubernetesVersion)) + err = pm.bindMountSyscallWithDefault(source, target) + if err != nil { + klog.Errorf("Failed to bind mount %s to target %s: %v", source, target, err) + return fmt.Errorf("Failed to bind mount %s to target %s: %w", source, target, err) + } - podMountSockPath := mppod.PathOnHost(podPath, mppod.KnownPathMountSock) - podMountErrorPath := mppod.PathOnHost(podPath, mppod.KnownPathMountError) + return nil +} - klog.V(4).Infof("Mounting %s for %s", source, pod.Name) +func (pm *PodMounter) mountS3AtSource(ctx context.Context, source string, pod *corev1.Pod, podPath string, + bucketName string, credEnv envprovider.Environment, authenticationSource credentialprovider.AuthenticationSource, + args mountpoint.Args) error { + env := envprovider.Default() + env.Merge(credEnv) - fuseDeviceFD, err := pm.mountSyscallWithDefault(source, args) - if err != nil { - klog.Errorf("Failed to mount %s: %v", source, err) - return fmt.Errorf("Failed to mount %s: %w", source, err) - } + // Move `--aws-max-attempts` to env if provided + if maxAttempts, ok := args.Remove(mountpoint.ArgAWSMaxAttempts); ok { + env.Set(envprovider.EnvMaxAttempts, maxAttempts) + } - // Remove the read-only argument from the list as mount-s3 does not support it when using FUSE - // file descriptor (we already pass MS_RDONLY flag during mount syscall in `pod_mounter_linux.go`) - if args.Has(mountpoint.ArgReadOnly) { - args.Remove(mountpoint.ArgReadOnly) - } + args.Set(mountpoint.ArgUserAgentPrefix, UserAgent(authenticationSource, pm.kubernetesVersion)) - // This will set to false in the success condition. This is set to `true` by default to - // ensure we don't leave `source` mounted if Mountpoint is not started to serve requests for it. - unmount := true - defer func() { - if unmount { - if err := pm.unmountTarget(source); err != nil { - klog.V(4).ErrorS(err, "Failed to unmount mounted source %s\n", source) - } else { - klog.V(4).Infof("Source %s unmounted successfully\n", source) - } - } - }() + podMountSockPath := mppod.PathOnHost(podPath, mppod.KnownPathMountSock) + podMountErrorPath := mppod.PathOnHost(podPath, mppod.KnownPathMountError) - // This function can either fail or successfully send mount options to Mountpoint Pod - in which - // Mountpoint Pod will get its own fd referencing the same underlying file description. - // In both case we need to close the fd in this process. - defer pm.closeFUSEDevFD(fuseDeviceFD) + klog.V(4).Infof("Mounting %s for %s", source, pod.Name) - // Remove old mount error file if exists - _ = os.Remove(podMountErrorPath) + fuseDeviceFD, err := pm.mountSyscallWithDefault(source, args) + if err != nil { + klog.Errorf("Failed to mount %s: %v", source, err) + return fmt.Errorf("Failed to mount %s: %w", source, err) + } - klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", pod.Name, podMountSockPath) + // Remove the read-only argument from the list as mount-s3 does not support it when using FUSE + // file descriptor (we already pass MS_RDONLY flag during mount syscall in `pod_mounter_linux.go`) + if args.Has(mountpoint.ArgReadOnly) { + args.Remove(mountpoint.ArgReadOnly) + } - err = mountoptions.Send(ctx, podMountSockPath, mountoptions.Options{ - Fd: fuseDeviceFD, - BucketName: bucketName, - Args: args.SortedList(), - Env: env.List(), - }) - if err != nil { - klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + // This will set to false in the success condition. This is set to `true` by default to + // ensure we don't leave `source` mounted if Mountpoint is not started to serve requests for it. + unmount := true + defer func() { + if unmount { + if err := pm.unmountTarget(source); err != nil { + klog.V(4).ErrorS(err, "Failed to unmount mounted source %s\n", source) + } else { + klog.V(4).Infof("Source %s unmounted successfully\n", source) + } } + }() - err = pm.waitForMount(ctx, source, pod.Name, podMountErrorPath) - if err != nil { - klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) - } + // This function can either fail or successfully send mount options to Mountpoint Pod - in which + // Mountpoint Pod will get its own fd referencing the same underlying file description. + // In both case we need to close the fd in this process. + defer pm.closeFUSEDevFD(fuseDeviceFD) + + // Remove old mount error file if exists + _ = os.Remove(podMountErrorPath) - // Mountpoint successfully started, so don't unmount the filesystem - unmount = false + klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", pod.Name, podMountSockPath) + + err = mountoptions.Send(ctx, podMountSockPath, mountoptions.Options{ + Fd: fuseDeviceFD, + BucketName: bucketName, + Args: args.SortedList(), + Env: env.List(), + }) + if err != nil { + klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) } - err = pm.bindMountSyscallWithDefault(source, target) + err = pm.waitForMount(ctx, source, pod.Name, podMountErrorPath) if err != nil { - klog.Errorf("Failed to bind mount %s to target %s: %v", source, target, err) - return fmt.Errorf("Failed to bind mount %s to target %s: %w", source, target, err) + klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) } + // Mountpoint successfully started, so don't unmount the filesystem + unmount = false return nil } @@ -374,6 +366,18 @@ func (pm *PodMounter) verifyOrSetupMountTarget(target string) error { return err } +func (pm *PodMounter) provideCredentials(ctx context.Context, podPath string, credentialCtx credentialprovider.ProvideContext) (envprovider.Environment, credentialprovider.AuthenticationSource, error) { + podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) + if err != nil { + klog.Errorf("Failed to create credentials directory: %v", err) + return nil, "", fmt.Errorf("Failed to create credentials directory: %w", err) + } + + credentialCtx.SetWriteAndEnvPath(podCredentialsPath, mppod.PathInsideMountpointPod(mppod.KnownPathCredentials)) + + return pm.credProvider.Provide(ctx, credentialCtx) +} + // ensureCredentialsDirExists ensures credentials dir for `podPath` is exists. // It returns credentials dir and any error. func (pm *PodMounter) ensureCredentialsDirExists(podPath string) (string, error) { @@ -433,9 +437,11 @@ func (pm *PodMounter) helpMessageForGettingMountpointLogs(pod *corev1.Pod) strin return fmt.Sprintf("You can see Mountpoint logs by running: `kubectl logs -n %s %s`. If the Mountpoint Pod already restarted, you can also pass `--previous` to get logs from the previous run.", pod.Namespace, pod.Name) } +// getS3PodAttachmentWithRetry retrieves a MountpointS3PodAttachment resource that matches the given volume and credential context. +// It continuously retries the operation until either a matching attachment is found or the context is canceled. func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeName string, credentialCtx credentialprovider.ProvideContext, fsGroup, pvMountOptions string) (*crdv1beta.MountpointS3PodAttachment, string, error) { fieldFilters := client.MatchingFields{ - crdv1beta.FieldNodeName: os.Getenv("CSI_NODE_NAME"), // TODO + crdv1beta.FieldNodeName: pm.nodeID, crdv1beta.FieldPersistentVolumeName: volumeName, crdv1beta.FieldVolumeID: credentialCtx.VolumeID, crdv1beta.FieldMountOptions: pvMountOptions, @@ -445,6 +451,9 @@ func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeNam if credentialCtx.AuthenticationSource == credentialprovider.AuthenticationSourcePod { fieldFilters[crdv1beta.FieldWorkloadNamespace] = credentialCtx.PodNamespace fieldFilters[crdv1beta.FieldWorkloadServiceAccountName] = credentialCtx.ServiceAccountName + // Note that we intentionally do not include `FieldWorkloadServiceAccountIAMRoleARN` to list filters because + // CSI Driver Node does not know which role ARN to use (if any). + // Role ARN is determined by reconciler and passed to node via MountpointS3PodAttachment. } for { diff --git a/pkg/driver/node/mounter/pod_mounter_test.go b/pkg/driver/node/mounter/pod_mounter_test.go index 6f35c0e7..76840d32 100644 --- a/pkg/driver/node/mounter/pod_mounter_test.go +++ b/pkg/driver/node/mounter/pod_mounter_test.go @@ -131,7 +131,7 @@ func setup(t *testing.T) *testCtx { err = podWatcher.Start(stopCh) assert.NoError(t, err) - podMounter, err := mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mount, mountSyscall, testK8sVersion) + podMounter, err := mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mount, mountSyscall, testK8sVersion, nodeName) assert.NoError(t, err) testCtx.podMounter = podMounter From 161ea2d61fe2632977364155752de3239dd5a517 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Mon, 28 Apr 2025 19:33:43 +0100 Subject: [PATCH 07/24] Add mppod_lock unit test --- pkg/driver/node/mounter/mppod_lock_test.go | 59 ++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 pkg/driver/node/mounter/mppod_lock_test.go diff --git a/pkg/driver/node/mounter/mppod_lock_test.go b/pkg/driver/node/mounter/mppod_lock_test.go new file mode 100644 index 00000000..d4361019 --- /dev/null +++ b/pkg/driver/node/mounter/mppod_lock_test.go @@ -0,0 +1,59 @@ +package mounter + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func TestGetMPPodLock(t *testing.T) { + // Clear the map before testing + mpPodLocks = make(map[string]*MPPodLock) + + t.Run("New lock creation", func(t *testing.T) { + podUID := "pod1" + lock := getMPPodLock(podUID) + + assert.Equals(t, 1, lock.refCount) + assert.Equals(t, 1, len(mpPodLocks)) + }) + + t.Run("Existing lock retrieval", func(t *testing.T) { + podUID := "pod2" + firstLock := getMPPodLock(podUID) + secondLock := getMPPodLock(podUID) + + if firstLock != secondLock { + t.Fatal("Expected to get the same lock instance") + } + assert.Equals(t, 2, firstLock.refCount) + }) +} + +func TestReleaseMPPodLock(t *testing.T) { + // Clear the map before testing + mpPodLocks = make(map[string]*MPPodLock) + + t.Run("Release existing lock", func(t *testing.T) { + podUID := "pod3" + getMPPodLock(podUID) + getMPPodLock(podUID) + + releaseMPPodLock(podUID) + + lock, exists := mpPodLocks[podUID] + assert.Equals(t, true, exists) + assert.Equals(t, 1, lock.refCount) + + releaseMPPodLock(podUID) + + _, exists = mpPodLocks[podUID] + assert.Equals(t, false, exists) + }) + + t.Run("Release non-existent lock", func(t *testing.T) { + podUID := "non-existent-pod" + releaseMPPodLock(podUID) + // This test passes if no panic occurs + }) +} From 3fd56c0af83455bd07702696090bb030fcbdd8a2 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Mon, 28 Apr 2025 19:39:12 +0100 Subject: [PATCH 08/24] Add expectations unit test --- .../csicontroller/expectations_test.go | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 cmd/aws-s3-csi-controller/csicontroller/expectations_test.go diff --git a/cmd/aws-s3-csi-controller/csicontroller/expectations_test.go b/cmd/aws-s3-csi-controller/csicontroller/expectations_test.go new file mode 100644 index 00000000..73f6b4f7 --- /dev/null +++ b/cmd/aws-s3-csi-controller/csicontroller/expectations_test.go @@ -0,0 +1,105 @@ +package csicontroller + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func TestDeriveExpectationKeyFromFilters(t *testing.T) { + tests := []struct { + name string + fieldFilters client.MatchingFields + want string + }{ + { + name: "empty filters", + fieldFilters: client.MatchingFields{}, + want: "", + }, + { + name: "single filter", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + want: "key1=value1;", + }, + { + name: "multiple filters", + fieldFilters: client.MatchingFields{ + "key2": "value2", + "key1": "value1", + "key3": "value3", + }, + want: "key1=value1;key2=value2;key3=value3;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := deriveExpectationKeyFromFilters(tt.fieldFilters) + assert.Equals(t, tt.want, got) + }) + } +} + +func TestExpectations(t *testing.T) { + tests := []struct { + name string + fieldFilters client.MatchingFields + operations func(*expectations) + wantPending bool + }{ + { + name: "set and check pending", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + operations: func(e *expectations) { + e.setPending(client.MatchingFields{"key1": "value1"}) + }, + wantPending: true, + }, + { + name: "set and clear pending", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + operations: func(e *expectations) { + e.setPending(client.MatchingFields{"key1": "value1"}) + e.clear(client.MatchingFields{"key1": "value1"}) + }, + wantPending: false, + }, + { + name: "check non-existent pending", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + operations: func(e *expectations) {}, + wantPending: false, + }, + { + name: "multiple operations", + fieldFilters: client.MatchingFields{ + "key1": "value1", + "key2": "value2", + }, + operations: func(e *expectations) { + e.setPending(client.MatchingFields{"key1": "value1", "key2": "value2"}) + e.clear(client.MatchingFields{"key1": "value1"}) // Different key, shouldn't affect the test + }, + wantPending: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := newExpectations() + tt.operations(e) + got := e.isPending(tt.fieldFilters) + assert.Equals(t, tt.wantPending, got) + }) + } +} From faa33e65409868b85a6aa3d8bcea54b9348c5a96 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Tue, 29 Apr 2025 11:05:51 +0100 Subject: [PATCH 09/24] Add PodMounter unit tests --- pkg/driver/driver.go | 3 +- pkg/driver/node/mounter/fake_cache.go | 7 +- pkg/driver/node/mounter/mounter.go | 11 +- pkg/driver/node/mounter/pod_mounter.go | 61 +++--- pkg/driver/node/mounter/pod_mounter_test.go | 199 ++++++++++++++++++-- 5 files changed, 230 insertions(+), 51 deletions(-) diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 9947acf6..78a9d329 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -153,7 +153,8 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error unmounter.CleanupDanglingMounts() - mounterImpl, err = mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mountUtil, nil, kubernetesVersion, nodeID) + mounterImpl, err = mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mountUtil, nil, nil, nil, + kubernetesVersion, nodeID, mounter.SourceMountDir) if err != nil { klog.Fatalln(err) } diff --git a/pkg/driver/node/mounter/fake_cache.go b/pkg/driver/node/mounter/fake_cache.go index 2346be02..d37a2c74 100644 --- a/pkg/driver/node/mounter/fake_cache.go +++ b/pkg/driver/node/mounter/fake_cache.go @@ -3,18 +3,23 @@ package mounter import ( "context" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "k8s.io/apimachinery/pkg/runtime/schema" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" ) -type FakeCache struct{} +type FakeCache struct { + TestItems []crdv1beta.MountpointS3PodAttachment +} func (f *FakeCache) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { return nil } func (f *FakeCache) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + s3paList := list.(*crdv1beta.MountpointS3PodAttachmentList) + s3paList.Items = f.TestItems return nil } diff --git a/pkg/driver/node/mounter/mounter.go b/pkg/driver/node/mounter/mounter.go index 7258f3a2..06942b98 100644 --- a/pkg/driver/node/mounter/mounter.go +++ b/pkg/driver/node/mounter/mounter.go @@ -31,8 +31,6 @@ type Mounter interface { IsMountPoint(target string) (bool, error) } -// Internal S3 CSI Driver directory for source mount points -const SourceMountDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/mnt/" const MountS3PathEnv = "MOUNT_S3_PATH" const defaultMountS3Path = "/usr/bin/mount-s3" @@ -71,11 +69,12 @@ func isMountPoint(mounter mount.Interface, target string) (bool, error) { } // findSourceMountPoint locates the source S3 mount point for a given target path by comparing -// device IDs and inodes with all S3 mount points at driver source directory `SourceMountDir`. +// device IDs and inodes with all S3 mount points at driver source directory `sourceMountDir`. // // Parameters: // - mounter: Interface providing mounting operations and mount point listing capabilities // - target: The target path whose source mount point needs to be found +// - sourceMountDir: directory where to find source mount points // // Returns: // - string: The path of the source mount point if found @@ -83,9 +82,9 @@ func isMountPoint(mounter mount.Interface, target string) (bool, error) { // // The function works by: // 1. Getting the device ID and inode of the target path -// 2. Listing all mount points in the system that has "mountpoint-s3" as device name and prefix `SourceMountDir` +// 2. Listing all mount points in the system that has "mountpoint-s3" as device name and prefix `sourceMountDir` // 3. Finding a mount point that matches both the device ID and inode of the target -func findSourceMountPoint(mounter mount.Interface, target string) (string, error) { +func findSourceMountPoint(mounter mount.Interface, target, sourceMountDir string) (string, error) { if mounter == nil { return "", fmt.Errorf("mounter interface cannot be nil") } @@ -109,7 +108,7 @@ func findSourceMountPoint(mounter mount.Interface, target string) (string, error } for _, mountPoint := range mountPoints { - if mountPoint.Device != mountpointDeviceName || !strings.HasPrefix(mountPoint.Path, SourceMountDir) { + if mountPoint.Device != mountpointDeviceName || !strings.HasPrefix(mountPoint.Path, sourceMountDir) { continue } diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index 10464a22..98b6480f 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -28,6 +28,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) +// Internal S3 CSI Driver directory for source mount points +const SourceMountDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/mnt/" + // targetDirPerm is the permission to use while creating target directory if its not exists. const targetDirPerm = fs.FileMode(0755) @@ -36,31 +39,39 @@ const targetDirPerm = fs.FileMode(0755) // This is mainly exposed for testing, in production platform-native function (`mountSyscallDefault`) will be used. type mountSyscall func(target string, args mountpoint.Args) (fd int, err error) type bindMountSyscall func(source, target string) (err error) +type sourceMountPointFinder func(mounter mount.Interface, target, sourceMountDir string) (string, error) // A PodMounter is a [Mounter] that mounts Mountpoint on pre-created Kubernetes Pod running in the same node. type PodMounter struct { - podWatcher *watcher.Watcher - s3paCache cache.Cache - mount mount.Interface - kubeletPath string - mountSyscall mountSyscall - bindMountSyscall bindMountSyscall - kubernetesVersion string - credProvider *credentialprovider.Provider - nodeID string + podWatcher *watcher.Watcher + s3paCache cache.Cache + mount mount.Interface + kubeletPath string + sourceMountDir string + mountSyscall mountSyscall + bindMountSyscall bindMountSyscall + sourceMountPointFinder sourceMountPointFinder + kubernetesVersion string + credProvider *credentialprovider.Provider + nodeID string } // NewPodMounter creates a new [PodMounter] with given Kubernetes client. -func NewPodMounter(podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvider *credentialprovider.Provider, mount mount.Interface, mountSyscall mountSyscall, kubernetesVersion, nodeID string) (*PodMounter, error) { +func NewPodMounter(podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvider *credentialprovider.Provider, mount mount.Interface, + mountSyscall mountSyscall, bindMountSyscall bindMountSyscall, sourceMountPointFinder sourceMountPointFinder, kubernetesVersion, nodeID, + sourceMountDir string) (*PodMounter, error) { return &PodMounter{ - podWatcher: podWatcher, - s3paCache: s3paCache, - credProvider: credProvider, - mount: mount, - kubeletPath: util.KubeletPath(), - mountSyscall: mountSyscall, - kubernetesVersion: kubernetesVersion, - nodeID: nodeID, + podWatcher: podWatcher, + s3paCache: s3paCache, + credProvider: credProvider, + mount: mount, + kubeletPath: util.KubeletPath(), + sourceMountDir: sourceMountDir, + mountSyscall: mountSyscall, + bindMountSyscall: bindMountSyscall, + sourceMountPointFinder: sourceMountPointFinder, + kubernetesVersion: kubernetesVersion, + nodeID: nodeID, }, nil } @@ -93,7 +104,7 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin if isTargetMountPoint { klog.V(4).Infof("Target path %q is already mounted. Only refreshing credentials.", target) - source, err := pm.findSourceMountPoint(target) + source, err := pm.findSourceMountPointWithDefault(target) if err != nil { klog.Errorf("Failed to find source mount point for %q: %v", target, err) return fmt.Errorf("Failed to find source mount point for %q: %w", target, err) @@ -138,7 +149,7 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin releaseMPPodLock(mpPodUID) }() - source := filepath.Join(SourceMountDir, mpPodUID) + source := filepath.Join(pm.sourceMountDir, mpPodUID) err = pm.verifyOrSetupMountTarget(source) if err != nil { return fmt.Errorf("Failed to verify source path can be used as a mount point %q: %w", source, err) @@ -264,9 +275,13 @@ func (pm *PodMounter) IsMountPoint(target string) (bool, error) { return isMountPoint(pm.mount, target) } -// findSourceMountPoint calls `findSourceMountPoint` on `target`. -func (pm *PodMounter) findSourceMountPoint(target string) (string, error) { - return findSourceMountPoint(pm.mount, target) +// findSourceMountPointWithDefault calls `findSourceMountPoint` on `target`. +func (pm *PodMounter) findSourceMountPointWithDefault(target string) (string, error) { + if pm.sourceMountPointFinder != nil { + return pm.sourceMountPointFinder(pm.mount, target, pm.sourceMountDir) + } + + return findSourceMountPoint(pm.mount, target, pm.sourceMountDir) } // waitForMountpointPod waits until Mountpoint Pod for given `podName` is in `Running` state. diff --git a/pkg/driver/node/mounter/pod_mounter_test.go b/pkg/driver/node/mounter/pod_mounter_test.go index 76840d32..4af06c56 100644 --- a/pkg/driver/node/mounter/pod_mounter_test.go +++ b/pkg/driver/node/mounter/pod_mounter_test.go @@ -19,6 +19,7 @@ import ( "k8s.io/client-go/kubernetes/fake" "k8s.io/mount-utils" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" @@ -40,13 +41,15 @@ type testCtx struct { podMounter *mounter.PodMounter - client *fake.Clientset - mount *mount.FakeMounter - s3paCache *mounter.FakeCache - mountSyscall func(target string, args mountpoint.Args) (fd int, err error) + client *fake.Clientset + mount *mount.FakeMounter + s3paCache *mounter.FakeCache + mountSyscall func(target string, args mountpoint.Args) (fd int, err error) + mountBindSyscall func(source, target string) (err error) bucketName string kubeletPath string + sourcePath string targetPath string podUID string volumeID string @@ -54,6 +57,8 @@ type testCtx struct { nodeName string fsGroup string pvMountOptions string + mpPodName string + mpPodUID string } func setup(t *testing.T) *testCtx { @@ -67,8 +72,12 @@ func setup(t *testing.T) *testCtx { // to overcome `bind: invalid argument`. t.Chdir(kubeletPath) + sourceMountDir := t.TempDir() + bucketName := "test-bucket" podUID := uuid.New().String() + mpPodName := "test-mppod" + mpPodUID := uuid.New().String() volumeID := "s3-csi-driver-volume" pvName := "s3-csi-driver-pv" nodeName := "test-node" @@ -89,15 +98,18 @@ func setup(t *testing.T) *testCtx { parentDir, err := filepath.EvalSymlinks(filepath.Dir(targetPath)) assert.NoError(t, err) targetPath = filepath.Join(parentDir, filepath.Base(targetPath)) + parentDir, err = filepath.EvalSymlinks(filepath.Dir(sourceMountDir)) + assert.NoError(t, err) + sourceMountDir = filepath.Join(parentDir, filepath.Base(sourceMountDir)) client := fake.NewClientset() - mount := mount.NewFakeMounter(nil) + fakeMounter := mount.NewFakeMounter(nil) testCtx := &testCtx{ t: t, ctx: ctx, client: client, - mount: mount, + mount: fakeMounter, bucketName: bucketName, kubeletPath: kubeletPath, targetPath: targetPath, @@ -108,17 +120,61 @@ func setup(t *testing.T) *testCtx { fsGroup: fsGroup, s3paCache: s3paCache, pvMountOptions: pvMountOptions, + mpPodName: mpPodName, + mpPodUID: mpPodUID, + sourcePath: filepath.Join(sourceMountDir, mpPodUID), + } + + testCrd := crdv1beta.MountpointS3PodAttachment{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: testCtx.nodeName, + PersistentVolumeName: testCtx.pvName, + VolumeID: testCtx.volumeID, + WorkloadFSGroup: testCtx.fsGroup, + MountOptions: testCtx.pvMountOptions, + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + testCtx.mpPodName: {testCtx.podUID}, + }, + }, } + testCtx.s3paCache.TestItems = []crdv1beta.MountpointS3PodAttachment{testCrd} mountSyscall := func(target string, args mountpoint.Args) (fd int, err error) { if testCtx.mountSyscall != nil { return testCtx.mountSyscall(target, args) } - mount.Mount("mountpoint-s3", target, "fuse", nil) + fakeMounter.Mount("mountpoint-s3", target, "fuse", nil) return int(mountertest.OpenDevNull(t).Fd()), nil } + mountBindSyscall := func(source, target string) (err error) { + if testCtx.mountBindSyscall != nil { + return testCtx.mountBindSyscall(source, target) + } + + fakeMounter.Mount(source, target, "fuse", []string{"bind"}) + return nil + } + + findSourceMountPoint := func(mounter mount.Interface, target, sourceMountDir string) (string, error) { + fakeMounter := mounter.(*mount.FakeMounter) + mountPoints, err := fakeMounter.List() + if err != nil { + return "", fmt.Errorf("failed to list mount points: %w", err) + } + + for _, mp := range mountPoints { + if mp.Device == "mountpoint-s3" && + strings.HasPrefix(mp.Path, sourceMountDir) && + mp.Path != target { + return mp.Path, nil + } + } + + return "", fmt.Errorf("no source mount point found for target %q", target) + } + credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { return dummyIMDSRegion, nil }) @@ -131,7 +187,8 @@ func setup(t *testing.T) *testCtx { err = podWatcher.Start(stopCh) assert.NoError(t, err) - podMounter, err := mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mount, mountSyscall, testK8sVersion, nodeName) + podMounter, err := mounter.NewPodMounter(podWatcher, s3paCache, credProvider, fakeMounter, mountSyscall, + mountBindSyscall, findSourceMountPoint, testK8sVersion, nodeName, sourceMountDir) assert.NoError(t, err) testCtx.podMounter = podMounter @@ -244,11 +301,11 @@ func TestPodMounter(t *testing.T) { assert.Equals(t, true, credDirInfo.IsDir()) assert.Equals(t, credentialprovider.CredentialDirPerm, credDirInfo.Mode().Perm()) }) - t.Run("Does not duplicate mounts if target is already mounted", func(t *testing.T) { testCtx := setup(t) var mountCount atomic.Int32 + var bindMountCount atomic.Int32 testCtx.mountSyscall = func(target string, args mountpoint.Args) (fd int, err error) { mountCount.Add(1) @@ -256,6 +313,12 @@ func TestPodMounter(t *testing.T) { return int(mountertest.OpenDevNull(t).Fd()), nil } + testCtx.mountBindSyscall = func(source, target string) (err error) { + bindMountCount.Add(1) + testCtx.mount.Mount(source, target, "fuse", []string{"bind"}) + return nil + } + go func() { mpPod := createMountpointPod(testCtx) mpPod.run() @@ -263,17 +326,99 @@ func TestPodMounter(t *testing.T) { }() for range 5 { - err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ - VolumeID: testCtx.volumeID, - PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) + err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, + credentialprovider.ProvideContext{ + VolumeID: testCtx.volumeID, + PodID: testCtx.podUID, + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) } assert.Equals(t, int32(1), mountCount.Load()) + assert.Equals(t, int32(1), bindMountCount.Load()) }) - t.Run("Unmounts target if Mountpoint Pod does not receive mount options", func(t *testing.T) { + t.Run("Re-uses the same source mount for different targets if they share same Mountpoint Pod", func(t *testing.T) { + // First Pod + testCtx := setup(t) + + ok, _ := testCtx.podMounter.IsMountPoint(testCtx.targetPath) + assert.Equals(t, false, ok) + + var mountCount atomic.Int32 + var bindMountCount atomic.Int32 + + testCtx.mountSyscall = func(target string, args mountpoint.Args) (fd int, err error) { + mountCount.Add(1) + testCtx.mount.Mount("mountpoint-s3", target, "fuse", nil) + return int(mountertest.OpenDevNull(t).Fd()), nil + } + testCtx.mountBindSyscall = func(source, target string) (err error) { + bindMountCount.Add(1) + testCtx.mount.Mount(source, target, "fuse", []string{"bind"}) + return nil + } + + go func() { + mpPod := createMountpointPod(testCtx) + mpPod.run() + mpPod.receiveMountOptions(testCtx.ctx) + }() + + err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ + VolumeID: testCtx.volumeID, + PodID: testCtx.podUID, + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) + assert.NoError(t, err) + + ok, err = testCtx.podMounter.IsMountPoint(testCtx.sourcePath) + assert.NoError(t, err) + assert.Equals(t, true, ok) + ok, err = testCtx.podMounter.IsMountPoint(testCtx.targetPath) + assert.NoError(t, err) + assert.Equals(t, true, ok) + + // Second Pod + testCtx.podUID = uuid.New().String() + targetPath2 := filepath.Join( + testCtx.kubeletPath, + fmt.Sprintf("pods/%s/volumes/kubernetes.io~csi/%s/mount", testCtx.podUID, testCtx.pvName), + ) + err = os.MkdirAll(filepath.Dir(targetPath2), 0750) + assert.NoError(t, err) + parentDir, err := filepath.EvalSymlinks(filepath.Dir(targetPath2)) + assert.NoError(t, err) + targetPath2 = filepath.Join(parentDir, filepath.Base(targetPath2)) + testCtx.targetPath = targetPath2 + testCrd2 := crdv1beta.MountpointS3PodAttachment{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: testCtx.nodeName, + PersistentVolumeName: testCtx.pvName, + VolumeID: testCtx.volumeID, + WorkloadFSGroup: testCtx.fsGroup, + MountOptions: testCtx.pvMountOptions, + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + testCtx.mpPodName: {testCtx.podUID}, + }, + }, + } + testCtx.s3paCache.TestItems = []crdv1beta.MountpointS3PodAttachment{testCrd2} + + err = testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ + VolumeID: testCtx.volumeID, + PodID: testCtx.podUID, + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) + assert.NoError(t, err) + + ok, err = testCtx.podMounter.IsMountPoint(testCtx.targetPath) + assert.NoError(t, err) + assert.Equals(t, true, ok) + + assert.Equals(t, int32(1), mountCount.Load()) + assert.Equals(t, int32(2), bindMountCount.Load()) + }) + + t.Run("Unmounts source if Mountpoint Pod does not receive mount options", func(t *testing.T) { testCtx := setup(t) go func() { @@ -294,14 +439,19 @@ func TestPodMounter(t *testing.T) { t.Errorf("mount shouldn't succeeded if Mountpoint does not receive the mount options") } - ok, err := testCtx.mount.IsMountPoint(testCtx.targetPath) + ok, err := testCtx.mount.IsMountPoint(testCtx.sourcePath) + assert.NoError(t, err) + if ok { + t.Errorf("it should unmount the source path if Mountpoint does not receive the mount options") + } + ok, err = testCtx.mount.IsMountPoint(testCtx.targetPath) assert.NoError(t, err) if ok { - t.Errorf("it should unmount the target path if Mountpoint does not receive the mount options") + t.Errorf("it should not bind mount the target path if Mountpoint does not receive the mount options") } }) - t.Run("Unmounts target if Mountpoint Pod fails to start", func(t *testing.T) { + t.Run("Unmounts source if Mountpoint Pod fails to start", func(t *testing.T) { testCtx := setup(t) testCtx.mountSyscall = func(target string, args mountpoint.Args) (fd int, err error) { @@ -328,10 +478,15 @@ func TestPodMounter(t *testing.T) { t.Errorf("mount shouldn't succeeded if Mountpoint fails to start") } - ok, err := testCtx.mount.IsMountPoint(testCtx.targetPath) + ok, err := testCtx.mount.IsMountPoint(testCtx.sourcePath) assert.NoError(t, err) if ok { - t.Errorf("it should unmount the target path if Mountpoint fails to start") + t.Errorf("it should unmount the source path if Mountpoint fails to start") + } + ok, err = testCtx.mount.IsMountPoint(testCtx.targetPath) + assert.NoError(t, err) + if ok { + t.Errorf("it should not bind mount the target path if Mountpoint fails to start") } }) @@ -394,6 +549,9 @@ func TestPodMounter(t *testing.T) { }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) + ok, err = testCtx.podMounter.IsMountPoint(testCtx.sourcePath) + assert.NoError(t, err) + assert.Equals(t, true, ok) ok, err = testCtx.podMounter.IsMountPoint(testCtx.targetPath) assert.NoError(t, err) assert.Equals(t, true, ok) @@ -442,7 +600,8 @@ func createMountpointPod(testCtx *testCtx) *mountpointPod { pod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(uuid.New().String()), + UID: types.UID(testCtx.mpPodUID), + Name: testCtx.mpPodName, }, } pod, err := testCtx.client.CoreV1().Pods(mountpointPodNamespace).Create(context.TODO(), pod, metav1.CreateOptions{}) From 5ae555fa5ee7060a639d2d3fbc19608e62517949 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:46:01 +0100 Subject: [PATCH 10/24] Add PodUnmounter unit tests --- pkg/driver/driver.go | 2 +- pkg/driver/node/mounter/pod_unmounter.go | 43 +-- pkg/driver/node/mounter/pod_unmounter_test.go | 282 ++++++++++++++++++ 3 files changed, 307 insertions(+), 20 deletions(-) create mode 100644 pkg/driver/node/mounter/pod_unmounter_test.go diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 78a9d329..39ed663a 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -141,7 +141,7 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error } }() - unmounter := mounter.NewPodUnmounter(nodeID, mountUtil, podWatcher, s3paCache, credProvider) + unmounter := mounter.NewPodUnmounter(nodeID, mountUtil, podWatcher, s3paCache, credProvider, mounter.SourceMountDir) s3podAttachmentInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ UpdateFunc: unmounter.HandleS3PodAttachmentUpdate, diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go index 6ca86762..bdaec293 100644 --- a/pkg/driver/node/mounter/pod_unmounter.go +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -18,11 +18,13 @@ import ( ) type PodUnmounter struct { - nodeID string - mountUtil mount.Interface - podWatcher *watcher.Watcher - s3paCache cache.Cache - credProvider *credentialprovider.Provider + nodeID string + mountUtil mount.Interface + kubeletPath string + sourceMountDir string + podWatcher *watcher.Watcher + s3paCache cache.Cache + credProvider *credentialprovider.Provider } func NewPodUnmounter( @@ -31,13 +33,16 @@ func NewPodUnmounter( podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvider *credentialprovider.Provider, + sourceMountDir string, ) *PodUnmounter { return &PodUnmounter{ - nodeID: nodeID, - mountUtil: mountUtil, - podWatcher: podWatcher, - s3paCache: s3paCache, - credProvider: credProvider, + nodeID: nodeID, + mountUtil: mountUtil, + kubeletPath: util.KubeletPath(), + sourceMountDir: sourceMountDir, + podWatcher: podWatcher, + s3paCache: s3paCache, + credProvider: credProvider, } } @@ -55,16 +60,16 @@ func (u *PodUnmounter) HandleS3PodAttachmentUpdate(old, new any) { } func (u *PodUnmounter) unmountSourceForPod(s3pa *crdv1beta.MountpointS3PodAttachment, mpPodName string) { - klog.Infof("Found Mountpoint pod with zero workload pods, unmounting it - %s", mpPodName) + klog.Infof("Found Mountpoint Pod with zero workload pods, unmounting it - %s", mpPodName) mpPod, err := u.podWatcher.Get(mpPodName) if err != nil { - klog.Infof("failed to find mpPodName during update event") + klog.Infof("failed to find Mountpoint Pod %s during update event", mpPodName) return } mpPodUID := string(mpPod.UID) - podPath := filepath.Join(util.KubeletPath(), "pods", mpPodUID) - source := filepath.Join(SourceMountDir, mpPodUID) + podPath := filepath.Join(u.kubeletPath, "pods", mpPodUID) + source := filepath.Join(u.sourceMountDir, mpPodUID) if err := u.writeExitFile(podPath, mpPod); err != nil { return @@ -107,7 +112,7 @@ func (u *PodUnmounter) cleanupCredentials(s3pa *crdv1beta.MountpointS3PodAttachm err := u.credProvider.Cleanup(credentialprovider.CleanupContext{ VolumeID: s3pa.Spec.VolumeID, PodID: mpPodUID, - WritePath: filepath.Join(util.KubeletPath(), "pods", mpPodUID), + WritePath: filepath.Join(u.kubeletPath, "pods", mpPodUID), }) if err != nil { klog.Errorf("Failed to clean up credentials for %s: %v", source, err) @@ -117,9 +122,9 @@ func (u *PodUnmounter) cleanupCredentials(s3pa *crdv1beta.MountpointS3PodAttachm } func (u *PodUnmounter) CleanupDanglingMounts() { - entries, err := os.ReadDir(SourceMountDir) + entries, err := os.ReadDir(u.sourceMountDir) if err != nil { - klog.Errorf("Failed to read source mount directory (`%s`): %v", SourceMountDir, err) + klog.Errorf("Failed to read source mount directory (`%s`): %v", u.sourceMountDir, err) return } @@ -129,7 +134,7 @@ func (u *PodUnmounter) CleanupDanglingMounts() { } mpPodUID := file.Name() - source := filepath.Join(SourceMountDir, mpPodUID) + source := filepath.Join(u.sourceMountDir, mpPodUID) // Try to find corresponding pod mpPod, err := u.findPodByUID(mpPodUID) if err != nil { @@ -149,7 +154,7 @@ func (u *PodUnmounter) CleanupDanglingMounts() { if !hasWorkloads { klog.Infof("Found dangling mount for Mountpoint Pod %s (UID: %s), cleaning up", mpPod.Name, mpPodUID) - podPath := filepath.Join(util.KubeletPath(), "pods", mpPodUID) + podPath := filepath.Join(u.kubeletPath, "pods", mpPodUID) if err := u.writeExitFile(podPath, mpPod); err != nil { return } diff --git a/pkg/driver/node/mounter/pod_unmounter_test.go b/pkg/driver/node/mounter/pod_unmounter_test.go new file mode 100644 index 00000000..974fdea1 --- /dev/null +++ b/pkg/driver/node/mounter/pod_unmounter_test.go @@ -0,0 +1,282 @@ +package mounter_test + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" + + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/mount-utils" +) + +const ( + nodeName = "test-node" +) + +func setupPodWatcher(t *testing.T, pods ...*corev1.Pod) (*watcher.Watcher, *fake.Clientset) { + client := fake.NewClientset() + podWatcher := watcher.New(client, mountpointPodNamespace, nodeName, 10*time.Second) + stopCh := make(chan struct{}) + t.Cleanup(func() { + close(stopCh) + }) + + for _, pod := range pods { + if pod != nil { + _, err := client.CoreV1().Pods(mountpointPodNamespace).Create(context.Background(), pod, metav1.CreateOptions{}) + assert.NoError(t, err) + } + } + + err := podWatcher.Start(stopCh) + assert.NoError(t, err) + + return podWatcher, client +} + +func countUnmountCalls(mounter *mount.FakeMounter) int { + unmountCalls := 0 + for _, action := range mounter.GetLog() { + if action.Action == mount.FakeActionUnmount { + unmountCalls++ + } + } + return unmountCalls +} + +func TestHandleS3PodAttachmentUpdate(t *testing.T) { + tests := []struct { + name string + nodeID string + s3pa *crdv1beta.MountpointS3PodAttachment + pod *corev1.Pod + unmountError error + expectUnmount bool + }{ + { + name: "different node", + nodeID: "node1", + s3pa: &crdv1beta.MountpointS3PodAttachment{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: "node2", + }, + }, + expectUnmount: false, + }, + { + name: "same node with empty workload", + nodeID: nodeName, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + }, + }, + s3pa: &crdv1beta.MountpointS3PodAttachment{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: nodeName, + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + "pod1": {}, + }, + }, + }, + expectUnmount: true, + }, + { + name: "same node with empty workload and unmount error", + nodeID: nodeName, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + }, + }, + s3pa: &crdv1beta.MountpointS3PodAttachment{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: nodeName, + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + "pod1": {}, + }, + }, + }, + unmountError: errors.New("unmount error"), + expectUnmount: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + kubeletPath := t.TempDir() + t.Setenv("KUBELET_PATH", kubeletPath) + t.Chdir(kubeletPath) + + sourceMountDir := t.TempDir() + + podWatcher, client := setupPodWatcher(t, tt.pod) + + if tt.pod != nil { + podPath := filepath.Join(kubeletPath, "pods", string(tt.pod.UID)) + commDir := mppod.PathOnHost(podPath) + err := os.MkdirAll(commDir, 0750) + assert.NoError(t, err) + + err = os.MkdirAll(filepath.Join(sourceMountDir, string(tt.pod.UID)), 0750) + assert.NoError(t, err) + } + + fakeMounter := mount.NewFakeMounter(nil) + if tt.unmountError != nil { + fakeMounter.UnmountFunc = func(path string) error { + return tt.unmountError + } + } + + credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { + return dummyIMDSRegion, nil + }) + s3paCache := &mounter.FakeCache{} + + unmounter := mounter.NewPodUnmounter(tt.nodeID, fakeMounter, podWatcher, s3paCache, credProvider, sourceMountDir) + unmounter.HandleS3PodAttachmentUpdate(nil, tt.s3pa) + + unmountCalls := countUnmountCalls(fakeMounter) + expectedUnmounts := 0 + if tt.expectUnmount { + expectedUnmounts = 1 + } + assert.Equals(t, expectedUnmounts, unmountCalls) + }) + } +} + +func TestCleanupDanglingMounts(t *testing.T) { + tests := []struct { + name string + pods []*corev1.Pod + s3paItems []crdv1beta.MountpointS3PodAttachment + unmountError error + expectedCalls int + }{ + { + name: "no dangling mounts", + pods: []*corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + }, + }, + }, + s3paItems: []crdv1beta.MountpointS3PodAttachment{ + { + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + "pod1": {"workload1"}, + }, + }, + }, + }, + expectedCalls: 0, + }, + { + name: "with dangling mount", + pods: []*corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + }, + }, + }, + s3paItems: []crdv1beta.MountpointS3PodAttachment{ + { + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + "pod1": {}, + }, + }, + }, + }, + expectedCalls: 1, + }, + { + name: "with dangling mount and unmount error", + pods: []*corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + }, + }, + }, + s3paItems: []crdv1beta.MountpointS3PodAttachment{ + { + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + "pod1": {}, + }, + }, + }, + }, + unmountError: errors.New("unmount error"), + expectedCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + podWatcher, client := setupPodWatcher(t, tt.pods...) + kubeletPath := t.TempDir() + t.Setenv("KUBELET_PATH", kubeletPath) + t.Chdir(kubeletPath) + sourceMountDir := t.TempDir() + + for _, pod := range tt.pods { + podPath := filepath.Join(kubeletPath, "pods", string(pod.UID)) + commDir := mppod.PathOnHost(podPath) + err := os.MkdirAll(commDir, 0750) + assert.NoError(t, err) + + err = os.MkdirAll(filepath.Join(sourceMountDir, string(pod.UID)), 0750) + assert.NoError(t, err) + } + + fakeMounter := mount.NewFakeMounter(nil) + if tt.unmountError != nil { + fakeMounter.UnmountFunc = func(path string) error { + return tt.unmountError + } + } + + s3paCache := &mounter.FakeCache{ + TestItems: tt.s3paItems, + } + + credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { + return dummyIMDSRegion, nil + }) + + unmounter := mounter.NewPodUnmounter(nodeName, fakeMounter, podWatcher, s3paCache, credProvider, sourceMountDir) + unmounter.CleanupDanglingMounts() + + unmountCalls := countUnmountCalls(fakeMounter) + assert.Equals(t, tt.expectedCalls, unmountCalls) + }) + } +} From 8fde1c4a4faa0d43d7ff00d2a6f6d08f0bcb6637 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Tue, 29 Apr 2025 17:53:37 +0100 Subject: [PATCH 11/24] Improve PodSharing e2e tests --- .../e2e-kubernetes/testsuites/pod_sharing.go | 116 ++++++++++++++---- 1 file changed, 90 insertions(+), 26 deletions(-) diff --git a/tests/e2e-kubernetes/testsuites/pod_sharing.go b/tests/e2e-kubernetes/testsuites/pod_sharing.go index 25e09a52..73e74394 100644 --- a/tests/e2e-kubernetes/testsuites/pod_sharing.go +++ b/tests/e2e-kubernetes/testsuites/pod_sharing.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "path/filepath" + "strconv" "strings" "time" @@ -11,6 +12,7 @@ import ( "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" @@ -23,7 +25,12 @@ import ( "k8s.io/utils/ptr" ) -var s3paGVR = schema.GroupVersionResource{Group: "s3.csi.aws.com", Version: "v1", Resource: "mountpoints3podattachments"} +var s3paGVR = schema.GroupVersionResource{Group: "s3.csi.aws.com", Version: "v1beta", Resource: "mountpoints3podattachments"} + +const mountpointNamespace = "mount-s3" + +const defaultTimeout = 10 * time.Second +const defaultInterval = 1 * time.Second type s3CSIPodSharingTestSuite struct { tsInfo storageframework.TestSuiteInfo @@ -32,7 +39,7 @@ type s3CSIPodSharingTestSuite struct { func InitS3CSIPodSharingTestSuite() storageframework.TestSuite { return &s3CSIPodSharingTestSuite{ tsInfo: storageframework.TestSuiteInfo{ - Name: "multivolume", + Name: "podsharing", TestPatterns: []storageframework.TestPattern{ storageframework.DefaultFsPreprovisionedPV, }, @@ -56,7 +63,7 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive l local ) - f := framework.NewFrameworkWithCustomTimeouts(NamespacePrefix+"multivolume", storageframework.GetDriverTimeouts(driver)) + f := framework.NewFrameworkWithCustomTimeouts(NamespacePrefix+"podsharing", storageframework.GetDriverTimeouts(driver)) f.NamespacePodSecurityLevel = admissionapi.LevelBaseline cleanup := func(ctx context.Context) { @@ -76,25 +83,35 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) l.resources = append(l.resources, resource) + var s3paNames []string + var mountpointPodNames []string var pods []*v1.Pod - node := l.config.ClientNodeSelection - // Create each pod with pvc + var targetNode string + var nodeSelector map[string]string for i := 0; i < 2; i++ { index := i + 1 - ginkgo.By(fmt.Sprintf("Creating pod%d with a volume on %+v", index, node)) - pod, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, nil, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + + if i > 0 && targetNode != "" { + nodeSelector = map[string]string{"kubernetes.io/hostname": targetNode} + } + + ginkgo.By(fmt.Sprintf("Creating pod%d with a volume", index)) + pod, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, nodeSelector, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") framework.ExpectNoError(err) - // The pod must get deleted before this function returns because the caller may try to - // delete volumes as part of the tests. Keeping the pod running would block that. - // If the test times out, then the namespace deletion will take care of it. - defer func() { - framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) - }() + + if i == 0 { + targetNode = pod.Spec.NodeName + } pods = append(pods, pod) - e2epod.SetAffinity(&node, pod.Spec.NodeName) } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() - verifyPodsShareMountpointPod(ctx, f, pods, resource.Pv) + s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, resource.Pv) checkCrossReadWrite(f, pods[0], pods[1]) }) @@ -102,6 +119,8 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) l.resources = append(l.resources, resource) + var s3paNames []string + var mountpointPodNames []string var pods []*v1.Pod var targetNode string for i := 0; i < 2; i++ { @@ -129,16 +148,18 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive if i == 0 { targetNode = pod.Spec.NodeName } - - defer func() { - framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) - }() pods = append(pods, pod) } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() - verifyPodsHaveDifferentMountpointPods(ctx, f, pods, resource.Pv, func(pod *v1.Pod) map[string]string { + s3paNames, mountpointPodNames = verifyPodsHaveDifferentMountpointPods(ctx, f, pods, func(pod *v1.Pod) map[string]string { expectedFields := defaultExpectedFields(pod.Spec.NodeName, resource.Pv) - expectedFields["WorkloadFSGroup"] = fmt.Sprintf("%d", pod.Spec.SecurityContext.FSGroup) + expectedFields["WorkloadFSGroup"] = strconv.FormatInt(*pod.Spec.SecurityContext.FSGroup, 10) return expectedFields }) checkCrossReadWrite(f, pods[0], pods[1]) @@ -147,7 +168,9 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive // TODO: Add more test cases } -func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume) { +func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume) ([]string, []string) { + var s3paNames []string + var mountpointPodNames []string var s3paList *crdv1beta.MountpointS3PodAttachmentList framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) @@ -160,8 +183,10 @@ func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, p } for _, s3pa := range s3paList.Items { if matchesSpec(s3pa.Spec, defaultExpectedFields(pods[0].Spec.NodeName, pv)) { + s3paNames = append(s3paNames, s3pa.Name) allUIDs := make(map[string]bool) - for _, uids := range s3paList.Items[0].Spec.MountpointS3PodToWorkloadPodUIDs { + for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + mountpointPodNames = append(mountpointPodNames, mpPodName) for _, uid := range uids { allUIDs[uid] = true } @@ -178,11 +203,14 @@ func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, p } return false, err - })).WithTimeout(10 * time.Second).WithPolling(1 * time.Second).Should(gomega.BeTrue()) + })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) + return s3paNames, mountpointPodNames } -func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume, expectedFieldsFunc func(pod *v1.Pod) map[string]string) { +func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Framework, pods []*v1.Pod, expectedFieldsFunc func(pod *v1.Pod) map[string]string) ([]string, []string) { + var s3paNames []string + var mountpointPodNames []string var s3paList *crdv1beta.MountpointS3PodAttachmentList framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) @@ -198,6 +226,7 @@ func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Fra for _, s3pa := range s3paList.Items { for _, pod := range pods { if matchesSpec(s3pa.Spec, expectedFieldsFunc(pod)) { + s3paNames = append(s3paNames, s3pa.Name) matchCount++ break } @@ -205,7 +234,7 @@ func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Fra } return matchCount == len(pods), nil - })).WithTimeout(10 * time.Second).WithPolling(1 * time.Second).Should(gomega.BeTrue()) + })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) podToMountpointPod := make(map[string]string) for _, s3pa := range s3paList.Items { @@ -227,9 +256,44 @@ func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Fra framework.Gomega().Expect(alreadySeen).To(gomega.BeFalse()) seenMountpointPods[mpPodName] = true + mountpointPodNames = append(mountpointPodNames, mpPodName) } framework.Gomega().Expect(len(seenMountpointPods)).To(gomega.Equal(len(pods))) + + return s3paNames, mountpointPodNames +} + +func verifyMountpointResourcesCleanup(ctx context.Context, f *framework.Framework, s3paNames []string, mountpointPodNames []string) { + // Verify specific MountpointS3PodAttachments are deleted + framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + for _, s3paName := range s3paNames { + _, err := f.DynamicClient.Resource(s3paGVR).Get(ctx, s3paName, metav1.GetOptions{}) + if err == nil { + // S3PodAttachment still exists + return false, nil + } + if !apierrors.IsNotFound(err) { + return false, err + } + } + return true, nil + })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) + + // Verify specific Mountpoint Pods are deleted + framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + for _, mpPodName := range mountpointPodNames { + _, err := f.ClientSet.CoreV1().Pods(mountpointNamespace).Get(ctx, mpPodName, metav1.GetOptions{}) + if err == nil { + // Pod still exists + return false, nil + } + if !apierrors.IsNotFound(err) { + return false, err + } + } + return true, nil + })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) } // Convert UnstructuredList to MountpointS3PodAttachmentList From 87f4edb6e0869487f5ca85ff51ecc7b26a6647d1 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Tue, 29 Apr 2025 20:15:35 +0100 Subject: [PATCH 12/24] Add more e2e tests --- .../e2e-kubernetes/testsuites/pod_sharing.go | 278 ++++++++++++++++-- 1 file changed, 260 insertions(+), 18 deletions(-) diff --git a/tests/e2e-kubernetes/testsuites/pod_sharing.go b/tests/e2e-kubernetes/testsuites/pod_sharing.go index 73e74394..c00a3137 100644 --- a/tests/e2e-kubernetes/testsuites/pod_sharing.go +++ b/tests/e2e-kubernetes/testsuites/pod_sharing.go @@ -79,7 +79,7 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive ginkgo.DeferCleanup(cleanup) }) - ginkgo.It("should concurrently access the single volume from pods on the same node using the same Mountpoint Pod", func(ctx context.Context) { + ginkgo.It("should share Mountpoint Pod (authenticationSource=driver)", func(ctx context.Context) { resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) l.resources = append(l.resources, resource) @@ -111,11 +111,56 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) }() - s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, resource.Pv) + s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, defaultExpectedFields(targetNode, resource.Pv)) checkCrossReadWrite(f, pods[0], pods[1]) }) - ginkgo.It("should concurrently access the single volume from pods on the same node using different Mountpoint Pods if fsGroup is different", func(ctx context.Context) { + ginkgo.It("should share Mountpoint Pod if pods have the same fsGroup", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + for i := 0; i < 2; i++ { + index := i + 1 + podConfig := &e2epod.Config{ + NS: f.Namespace.Name, + PVCs: []*v1.PersistentVolumeClaim{resource.Pvc}, + SecurityLevel: admissionapi.LevelBaseline, + FsGroup: ptr.To(int64(1000)), + } + + if i > 0 && targetNode != "" { + podConfig.NodeSelection = e2epod.NodeSelection{ + Name: targetNode, + } + } + + ginkgo.By(fmt.Sprintf("Creating pod%d", index)) + pod, err := e2epod.CreateSecPod(ctx, f.ClientSet, podConfig, 10*time.Second) + framework.ExpectNoError(err) + + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + expectedFields := defaultExpectedFields(targetNode, resource.Pv) + expectedFields["WorkloadFSGroup"] = "1000" + s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, expectedFields) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should not share Mountpoint Pod if pods have different fsGroup", func(ctx context.Context) { resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) l.resources = append(l.resources, resource) @@ -165,10 +210,152 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive checkCrossReadWrite(f, pods[0], pods[1]) }) - // TODO: Add more test cases + ginkgo.It("should not share Mountpoint Pod if mountOptions are different", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + + // First Pod + ginkgo.By("Creating pod1 with a volume") + pod1, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, nil, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + framework.ExpectNoError(err) + targetNode = pod1.Spec.NodeName + + resource.Pv, err = f.ClientSet.CoreV1().PersistentVolumes().Get(ctx, resource.Pv.Name, metav1.GetOptions{}) + framework.ExpectNoError(err) + firstMountOptions := strings.Join(resource.Pv.Spec.MountOptions, ",") + resource.Pv.Spec.MountOptions = []string{"--allow-delete"} + resource.Pv, err = f.ClientSet.CoreV1().PersistentVolumes().Update(ctx, resource.Pv, metav1.UpdateOptions{}) + framework.ExpectNoError(err) + + // Second Pod + pods = append(pods, pod1) + ginkgo.By("Creating pod2 with a volume") + pod2, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, map[string]string{"kubernetes.io/hostname": targetNode}, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + framework.ExpectNoError(err) + pods = append(pods, pod2) + + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + s3paNames, mountpointPodNames = verifyPodsHaveDifferentMountpointPods(ctx, f, pods, func(pod *v1.Pod) map[string]string { + expectedFields := defaultExpectedFields(pod.Spec.NodeName, resource.Pv) + if pod.Name == pod1.Name { + expectedFields["MountOptions"] = firstMountOptions + } else { + expectedFields["MountOptions"] = "--allow-delete" + } + return expectedFields + }) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should share Mountpoint Pod if pod namespaces and service accounts are the same (authenticationSource=pod)", func(ctx context.Context) { + idConfig, err := setupPodLevelIdentity(ctx, f) + framework.ExpectNoError(err) + defer idConfig.Cleanup(ctx) + resource := createVolumeResourceWithMountOptions(contextWithAuthenticationSource(ctx, "pod"), l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + var nodeSelector map[string]string + for i := 0; i < 2; i++ { + index := i + 1 + + if i > 0 && targetNode != "" { + nodeSelector = map[string]string{"kubernetes.io/hostname": targetNode} + } + + ginkgo.By(fmt.Sprintf("Creating pod%d with a volume", index)) + pod := e2epod.MakePod(f.Namespace.Name, nodeSelector, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + pod.Spec.ServiceAccountName = idConfig.ServiceAccount.Name + pod, err := createPod(ctx, f.ClientSet, f.Namespace.Name, pod) + framework.ExpectNoError(err) + + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + expectedFields := defaultExpectedFields(targetNode, resource.Pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = f.Namespace.Name + expectedFields["WorkloadServiceAccountName"] = idConfig.ServiceAccount.Name + s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, expectedFields) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should not share Mountpoint Pod if pod service accounts are the different (authenticationSource=pod)", func(ctx context.Context) { + idConfig1, err := setupPodLevelIdentity(ctx, f) + framework.ExpectNoError(err) + defer idConfig1.Cleanup(ctx) + idConfig2, err := setupPodLevelIdentity(ctx, f) + framework.ExpectNoError(err) + defer idConfig2.Cleanup(ctx) + saNames := []string{idConfig1.ServiceAccount.Name, idConfig2.ServiceAccount.Name} + resource := createVolumeResourceWithMountOptions(contextWithAuthenticationSource(ctx, "pod"), l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + var nodeSelector map[string]string + for i := 0; i < 2; i++ { + index := i + 1 + + if i > 0 && targetNode != "" { + nodeSelector = map[string]string{"kubernetes.io/hostname": targetNode} + } + + ginkgo.By(fmt.Sprintf("Creating pod%d with a volume", index)) + pod := e2epod.MakePod(f.Namespace.Name, nodeSelector, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + pod.Spec.ServiceAccountName = saNames[i] + pod, err := createPod(ctx, f.ClientSet, f.Namespace.Name, pod) + framework.ExpectNoError(err) + + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + s3paNames, mountpointPodNames = verifyPodsHaveDifferentMountpointPods(ctx, f, pods, func(pod *v1.Pod) map[string]string { + expectedFields := defaultExpectedFields(pod.Spec.NodeName, resource.Pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = f.Namespace.Name + expectedFields["WorkloadServiceAccountName"] = pod.Spec.ServiceAccountName + return expectedFields + }) + checkCrossReadWrite(f, pods[0], pods[1]) + }) } -func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, pods []*v1.Pod, pv *v1.PersistentVolume) ([]string, []string) { +func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, pods []*v1.Pod, expectedFields map[string]string) ([]string, []string) { var s3paNames []string var mountpointPodNames []string var s3paList *crdv1beta.MountpointS3PodAttachmentList @@ -182,7 +369,7 @@ func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, p return false, err } for _, s3pa := range s3paList.Items { - if matchesSpec(s3pa.Spec, defaultExpectedFields(pods[0].Spec.NodeName, pv)) { + if matchesSpec(s3pa.Spec, expectedFields) { s3paNames = append(s3paNames, s3pa.Name) allUIDs := make(map[string]bool) for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { @@ -264,36 +451,38 @@ func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Fra return s3paNames, mountpointPodNames } +// TODO: This does not fail for some reason after timeout func verifyMountpointResourcesCleanup(ctx context.Context, f *framework.Framework, s3paNames []string, mountpointPodNames []string) { - // Verify specific MountpointS3PodAttachments are deleted - framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + framework.Logf("Verifying MountpointS3PodAttachments are deleted: %v", s3paNames) + framework.Gomega().Eventually(ctx, func() bool { for _, s3paName := range s3paNames { _, err := f.DynamicClient.Resource(s3paGVR).Get(ctx, s3paName, metav1.GetOptions{}) if err == nil { // S3PodAttachment still exists - return false, nil + return false } if !apierrors.IsNotFound(err) { - return false, err + return false } } - return true, nil - })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) + return true + }).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) - // Verify specific Mountpoint Pods are deleted - framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + framework.Logf("Verifying Mountpoint Pods are deleted: %v", mountpointPodNames) + framework.Gomega().Eventually(ctx, func() bool { for _, mpPodName := range mountpointPodNames { _, err := f.ClientSet.CoreV1().Pods(mountpointNamespace).Get(ctx, mpPodName, metav1.GetOptions{}) if err == nil { // Pod still exists - return false, nil + return false } if !apierrors.IsNotFound(err) { - return false, err + framework.Logf("Error checking pod %s: %v", mpPodName, err) + return false } } - return true, nil - })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) + return true + }).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) } // Convert UnstructuredList to MountpointS3PodAttachmentList @@ -366,3 +555,56 @@ func checkPodWriteAndOtherPodRead(f *framework.Framework, writerPod, readerPod * checkWriteToPath(f, writerPod, filePath, size, seed) checkReadFromPath(f, readerPod, filePath, size, seed) } + +type PodLevelIdentityConfig struct { + OIDCProvider string + ServiceAccount *v1.ServiceAccount + IAMRole string + Cleanup func(context.Context) error +} + +// setupPodLevelIdentity creates necessary resources for pod-level identity tests +func setupPodLevelIdentity(ctx context.Context, f *framework.Framework) (*PodLevelIdentityConfig, error) { + config := &PodLevelIdentityConfig{} + var cleanupFuncs []func(context.Context) error + + // Get OIDC Provider + config.OIDCProvider = oidcProviderForCluster(ctx, f) + if config.OIDCProvider == "" { + return nil, fmt.Errorf("OIDC provider is not configured") + } + + // Create Service Account + sa, removeSA := createServiceAccount(ctx, f) + config.ServiceAccount = sa + cleanupFuncs = append(cleanupFuncs, removeSA) + + // Create IAM Role with full access policy + role, removeRole := createRole(ctx, f, + assumeRoleWithWebIdentityPolicyDocument(ctx, config.OIDCProvider, sa), + iamPolicyS3FullAccess) + config.IAMRole = *role.Arn + cleanupFuncs = append(cleanupFuncs, removeRole) + + // Annotate Service Account with Role ARN + sa, restoreServiceAccountRole := overrideServiceAccountRole(ctx, f, sa, config.IAMRole) + config.ServiceAccount = sa + cleanupFuncs = append(cleanupFuncs, restoreServiceAccountRole) + + // Wait for role to be assumable + waitUntilRoleIsAssumableWithWebIdentity(ctx, f, sa) + + // Combine cleanup functions + config.Cleanup = func(ctx context.Context) error { + var errs []error + // Execute cleanup functions in reverse order + for i := len(cleanupFuncs) - 1; i >= 0; i-- { + if err := cleanupFuncs[i](ctx); err != nil { + errs = append(errs, err) + } + } + return errors.NewAggregate(errs) + } + + return config, nil +} From 75892bdc0c5dce4bcade394cc4a643fc7192a85a Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Wed, 30 Apr 2025 15:13:44 +0100 Subject: [PATCH 13/24] Add attachment time to CRD --- .../mountpoints3podattachments-crd.yaml | 24 +++++++-- .../csicontroller/reconciler.go | 36 +++++++------ .../v1beta/mountpoints3podattachment_types.go | 13 ++++- pkg/api/v1beta/zz_generated.deepcopy.go | 30 ++++++++--- pkg/driver/node/mounter/pod_mounter.go | 6 +-- pkg/driver/node/mounter/pod_mounter_test.go | 8 +-- pkg/driver/node/mounter/pod_unmounter.go | 4 +- pkg/driver/node/mounter/pod_unmounter_test.go | 14 ++--- tests/controller/controller_test.go | 54 +++++++++---------- tests/crd/mountpoints3podattachments-crd.yaml | 24 +++++++-- .../e2e-kubernetes/testsuites/pod_sharing.go | 12 ++--- 11 files changed, 143 insertions(+), 82 deletions(-) diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml index 31780b01..14c7abac 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml @@ -52,13 +52,27 @@ spec: mountOptions: description: Comma separated mount options taken from volume. type: string - mountpointS3PodToWorkloadPodUIDs: + mountpointS3PodAttachments: additionalProperties: items: - type: string + description: WorkloadAttachment represents the attachment details + of a workload pod to a Mountpoint S3 pod. + properties: + attachmentTime: + description: AttachmentTime represents when the workload pod + was attached to the Mountpoint S3 pod + format: date-time + type: string + workloadPodUID: + description: WorkloadPodUID is the unique identifier of the + attached workload pod + type: string + required: + - attachmentTime + - workloadPodUID + type: object type: array - description: Maps each Mountpoint S3 pod name to the list of workload - pod UIDs it is attached to. + description: Maps each Mountpoint S3 pod name to its workload attachments type: object nodeName: description: Name of the node. @@ -88,7 +102,7 @@ spec: required: - authenticationSource - mountOptions - - mountpointS3PodToWorkloadPodUIDs + - mountpointS3PodAttachments - nodeName - persistentVolumeName - volumeID diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index 7faaae45..f95dc369 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" "strings" + "time" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -332,8 +333,11 @@ func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3pa *cr func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { log.Info("Adding workload UID to MountpointS3PodAttachment") - for key := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { - s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[key] = append(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[key], workloadUID) + for key := range s3pa.Spec.MountpointS3PodAttachments { + s3pa.Spec.MountpointS3PodAttachments[key] = append(s3pa.Spec.MountpointS3PodAttachments[key], crdv1beta.WorkloadAttachment{ + WorkloadPodUID: workloadUID, + AttachmentTime: metav1.NewTime(time.Now().UTC()), + }) break } @@ -354,18 +358,18 @@ func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crd // It will delete MountpointS3PodAttachment if map becomes empty. func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { // Remove workload UID from mountpoint pods - for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { - filteredUIDs := []string{} + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + filteredUIDs := []crdv1beta.WorkloadAttachment{} found := false - for _, uid := range uids { - if uid == workloadUID { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == workloadUID { found = true continue } - filteredUIDs = append(filteredUIDs, uid) + filteredUIDs = append(filteredUIDs, attachment) } if found { - s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[mpPodName] = filteredUIDs + s3pa.Spec.MountpointS3PodAttachments[mpPodName] = filteredUIDs err := r.Update(ctx, s3pa) if err != nil { if apierrors.IsConflict(err) { @@ -381,11 +385,11 @@ func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa } // Remove Mountpoint pods with zero workloads - for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for mpPodName, uids := range s3pa.Spec.MountpointS3PodAttachments { if len(uids) == 0 { log.Info("Mountpoint pod has zero workload UIDs. Will remove it from MountpointS3PodAttachment", "mountpointPodName", mpPodName) - delete(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs, mpPodName) + delete(s3pa.Spec.MountpointS3PodAttachments, mpPodName) err := r.Update(ctx, s3pa) if err != nil { if apierrors.IsConflict(err) { @@ -400,7 +404,7 @@ func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa } // Delete MountpointS3PodAttachment if map is empty - if len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs) == 0 { + if len(s3pa.Spec.MountpointS3PodAttachments) == 0 { log.Info("MountpointS3PodAttachment has zero Mountpoint Pods. Will delete it") err := r.Delete(ctx, s3pa) if err != nil { @@ -466,8 +470,8 @@ func (r *Reconciler) createS3PodAttachmentWithMPPod( MountOptions: strings.Join(pv.Spec.MountOptions, ","), WorkloadFSGroup: r.getFSGroup(workloadPod), AuthenticationSource: authSource, - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ - mpPod.Name: {string(workloadPod.UID)}, + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ + mpPod.Name: {{WorkloadPodUID: string(workloadPod.UID), AttachmentTime: metav1.NewTime(time.Now().UTC())}}, }, }, } @@ -623,9 +627,9 @@ func isPodActive(p *corev1.Pod) bool { } func s3paContainsWorkload(s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string) bool { - for _, workloads := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { - for _, workload := range workloads { - if workload == workloadUID { + for _, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == workloadUID { return true } } diff --git a/pkg/api/v1beta/mountpoints3podattachment_types.go b/pkg/api/v1beta/mountpoints3podattachment_types.go index 83ed93c1..6e4a4e46 100644 --- a/pkg/api/v1beta/mountpoints3podattachment_types.go +++ b/pkg/api/v1beta/mountpoints3podattachment_types.go @@ -48,8 +48,17 @@ type MountpointS3PodAttachmentSpec struct { // EKS IAM Role ARN from workload pod's service account annotation (IRSA). Exists only if `authenticationSource: pod` and service account has `eks.amazonaws.com/role-arn` annotation. WorkloadServiceAccountIAMRoleARN string `json:"workloadServiceAccountIAMRoleARN,omitempty"` - // Maps each Mountpoint S3 pod name to the list of workload pod UIDs it is attached to. - MountpointS3PodToWorkloadPodUIDs map[string][]string `json:"mountpointS3PodToWorkloadPodUIDs"` + // Maps each Mountpoint S3 pod name to its workload attachments + MountpointS3PodAttachments map[string][]WorkloadAttachment `json:"mountpointS3PodAttachments"` +} + +// WorkloadAttachment represents the attachment details of a workload pod to a Mountpoint S3 pod. +type WorkloadAttachment struct { + // WorkloadPodUID is the unique identifier of the attached workload pod + WorkloadPodUID string `json:"workloadPodUID"` + + // AttachmentTime represents when the workload pod was attached to the Mountpoint S3 pod + AttachmentTime metav1.Time `json:"attachmentTime"` } // +kubebuilder:object:root=true diff --git a/pkg/api/v1beta/zz_generated.deepcopy.go b/pkg/api/v1beta/zz_generated.deepcopy.go index 93c5d922..2e79df80 100644 --- a/pkg/api/v1beta/zz_generated.deepcopy.go +++ b/pkg/api/v1beta/zz_generated.deepcopy.go @@ -69,18 +69,20 @@ func (in *MountpointS3PodAttachmentList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MountpointS3PodAttachmentSpec) DeepCopyInto(out *MountpointS3PodAttachmentSpec) { *out = *in - if in.MountpointS3PodToWorkloadPodUIDs != nil { - in, out := &in.MountpointS3PodToWorkloadPodUIDs, &out.MountpointS3PodToWorkloadPodUIDs - *out = make(map[string][]string, len(*in)) + if in.MountpointS3PodAttachments != nil { + in, out := &in.MountpointS3PodAttachments, &out.MountpointS3PodAttachments + *out = make(map[string][]WorkloadAttachment, len(*in)) for key, val := range *in { - var outVal []string + var outVal []WorkloadAttachment if val == nil { (*out)[key] = nil } else { inVal := (*in)[key] in, out := &inVal, &outVal - *out = make([]string, len(*in)) - copy(*out, *in) + *out = make([]WorkloadAttachment, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } } (*out)[key] = outVal } @@ -96,3 +98,19 @@ func (in *MountpointS3PodAttachmentSpec) DeepCopy() *MountpointS3PodAttachmentSp in.DeepCopyInto(out) return out } + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkloadAttachment) DeepCopyInto(out *WorkloadAttachment) { + *out = *in + in.AttachmentTime.DeepCopyInto(&out.AttachmentTime) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkloadAttachment. +func (in *WorkloadAttachment) DeepCopy() *WorkloadAttachment { + if in == nil { + return nil + } + out := new(WorkloadAttachment) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index 98b6480f..859059f2 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -485,9 +485,9 @@ func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeNam return nil, "", err } for _, s3pa := range s3paList.Items { - for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { - for _, uid := range uids { - if uid == credentialCtx.PodID { + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == credentialCtx.PodID { return &s3pa, mpPodName, nil } } diff --git a/pkg/driver/node/mounter/pod_mounter_test.go b/pkg/driver/node/mounter/pod_mounter_test.go index 4af06c56..43e9b4d5 100644 --- a/pkg/driver/node/mounter/pod_mounter_test.go +++ b/pkg/driver/node/mounter/pod_mounter_test.go @@ -132,8 +132,8 @@ func setup(t *testing.T) *testCtx { VolumeID: testCtx.volumeID, WorkloadFSGroup: testCtx.fsGroup, MountOptions: testCtx.pvMountOptions, - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ - testCtx.mpPodName: {testCtx.podUID}, + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ + testCtx.mpPodName: []crdv1beta.WorkloadAttachment{{WorkloadPodUID: testCtx.podUID}}, }, }, } @@ -397,8 +397,8 @@ func TestPodMounter(t *testing.T) { VolumeID: testCtx.volumeID, WorkloadFSGroup: testCtx.fsGroup, MountOptions: testCtx.pvMountOptions, - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ - testCtx.mpPodName: {testCtx.podUID}, + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ + testCtx.mpPodName: []crdv1beta.WorkloadAttachment{{WorkloadPodUID: testCtx.podUID}}, }, }, } diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go index bdaec293..a53666a8 100644 --- a/pkg/driver/node/mounter/pod_unmounter.go +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -52,7 +52,7 @@ func (u *PodUnmounter) HandleS3PodAttachmentUpdate(old, new any) { return } - for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for mpPodName, uids := range s3pa.Spec.MountpointS3PodAttachments { if len(uids) == 0 { u.unmountSourceForPod(s3pa, mpPodName) } @@ -192,7 +192,7 @@ func (u *PodUnmounter) checkForWorkloads(mpPod *corev1.Pod) (bool, error) { // Find attachment for this pod and check if it has workloads for _, s3pa := range s3paList.Items { - for mpPodName, workloadUIDs := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for mpPodName, workloadUIDs := range s3pa.Spec.MountpointS3PodAttachments { if mpPodName == mpPod.Name { return len(workloadUIDs) > 0, nil } diff --git a/pkg/driver/node/mounter/pod_unmounter_test.go b/pkg/driver/node/mounter/pod_unmounter_test.go index 974fdea1..8e56cffd 100644 --- a/pkg/driver/node/mounter/pod_unmounter_test.go +++ b/pkg/driver/node/mounter/pod_unmounter_test.go @@ -87,7 +87,7 @@ func TestHandleS3PodAttachmentUpdate(t *testing.T) { s3pa: &crdv1beta.MountpointS3PodAttachment{ Spec: crdv1beta.MountpointS3PodAttachmentSpec{ NodeName: nodeName, - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ "pod1": {}, }, }, @@ -107,7 +107,7 @@ func TestHandleS3PodAttachmentUpdate(t *testing.T) { s3pa: &crdv1beta.MountpointS3PodAttachment{ Spec: crdv1beta.MountpointS3PodAttachmentSpec{ NodeName: nodeName, - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ "pod1": {}, }, }, @@ -184,8 +184,10 @@ func TestCleanupDanglingMounts(t *testing.T) { s3paItems: []crdv1beta.MountpointS3PodAttachment{ { Spec: crdv1beta.MountpointS3PodAttachmentSpec{ - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ - "pod1": {"workload1"}, + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ + "pod1": []crdv1beta.WorkloadAttachment{crdv1beta.WorkloadAttachment{ + WorkloadPodUID: "workload1", + }}, }, }, }, @@ -206,7 +208,7 @@ func TestCleanupDanglingMounts(t *testing.T) { s3paItems: []crdv1beta.MountpointS3PodAttachment{ { Spec: crdv1beta.MountpointS3PodAttachmentSpec{ - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ "pod1": {}, }, }, @@ -228,7 +230,7 @@ func TestCleanupDanglingMounts(t *testing.T) { s3paItems: []crdv1beta.MountpointS3PodAttachment{ { Spec: crdv1beta.MountpointS3PodAttachmentSpec{ - MountpointS3PodToWorkloadPodUIDs: map[string][]string{ + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ "pod1": {}, }, }, diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index 2b28f89b..3a008d5b 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -370,8 +370,8 @@ var _ = Describe("Mountpoint Controller", func() { s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) @@ -397,9 +397,9 @@ var _ = Describe("Mountpoint Controller", func() { expectedFields["WorkloadFSGroup"] = "2222" s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa3.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) @@ -441,8 +441,8 @@ var _ = Describe("Mountpoint Controller", func() { s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) @@ -478,8 +478,8 @@ var _ = Describe("Mountpoint Controller", func() { s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) @@ -524,9 +524,9 @@ var _ = Describe("Mountpoint Controller", func() { expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role-2" s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa3.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) @@ -578,9 +578,9 @@ var _ = Describe("Mountpoint Controller", func() { expectedFields["WorkloadServiceAccountName"] = sa2.Name s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) - Expect(len(s3pa3.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) @@ -644,7 +644,7 @@ var _ = Describe("Mountpoint Controller", func() { expectedFields["WorkloadNamespace"] = defaultNamespace expectedFields["WorkloadServiceAccountName"] = sa1.Name s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa1.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) pod2 := createPod(withPVC(pvc2), withServiceAccount(sa1.Name), withNamespace(ns.Name)) @@ -653,7 +653,7 @@ var _ = Describe("Mountpoint Controller", func() { expectedFields["WorkloadNamespace"] = ns.Name s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") - Expect(len(s3pa2.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") @@ -1066,11 +1066,11 @@ func expectNoS3PodAttachmentWithFields(expectedFields map[string]string) { }, defaultWaitTimeout/2, defaultWaitTimeout/4).Should(Succeed()) } -// expectNoPodUIDInS3PodAttachment validates that pod UID does not exist in MountpointS3PodToWorkloadPodUIDs map +// expectNoPodUIDInS3PodAttachment validates that pod UID does not exist in MountpointS3PodAttachments map func expectNoPodUIDInS3PodAttachment(s3pa *crdv1beta.MountpointS3PodAttachment, podUID string) { - for _, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { - for _, uid := range uids { - if uid == podUID { + for _, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == podUID { Expect(false).To(BeTrue(), "Found pod UID %s in S3PodAttachment when none was expected: %#v", podUID, s3pa) } } @@ -1085,7 +1085,7 @@ func waitAndVerifyS3PodAttachmentAndMountpointPod( pod *testPod, ) (*crdv1beta.MountpointS3PodAttachment, *testPod) { s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(node, vol.pv), "") - Expect(len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) return s3pa, mpPod } @@ -1099,7 +1099,7 @@ func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion( minVersion string, ) (*crdv1beta.MountpointS3PodAttachment, *testPod) { s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv), minVersion) - Expect(len(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs)).To(Equal(1)) + Expect(len(s3pa.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) return s3pa, mpPod } @@ -1110,9 +1110,9 @@ func waitAndVerifyMountpointPodFromPodAttachment(s3pa *crdv1beta.MountpointS3Pod var mpPodName string podUID := string(pod.UID) - for k, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { - for _, uid := range uids { - if uid == podUID { + for k, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == podUID { mpPodName = k break } @@ -1123,7 +1123,7 @@ func waitAndVerifyMountpointPodFromPodAttachment(s3pa *crdv1beta.MountpointS3Pod } Expect(mpPodName).NotTo(BeEmpty(), "No Mountpoint Pod found for pod UID %s in MountpointS3PodAttachment: %#v", podUID, s3pa) - Expect(s3pa.Spec.MountpointS3PodToWorkloadPodUIDs[mpPodName]).To(ContainElement(podUID)) + Expect(s3pa.Spec.MountpointS3PodAttachments[mpPodName]).To(ContainElement(podUID)) mountpointPod := waitForMountpointPodWithName(mpPodName) verifyMountpointPodFor(pod, vol, mountpointPod) diff --git a/tests/crd/mountpoints3podattachments-crd.yaml b/tests/crd/mountpoints3podattachments-crd.yaml index 91fcde59..7da946f3 100644 --- a/tests/crd/mountpoints3podattachments-crd.yaml +++ b/tests/crd/mountpoints3podattachments-crd.yaml @@ -51,13 +51,27 @@ spec: mountOptions: description: Comma separated mount options taken from volume. type: string - mountpointS3PodToWorkloadPodUIDs: + mountpointS3PodAttachments: additionalProperties: items: - type: string + description: WorkloadAttachment represents the attachment details + of a workload pod to a Mountpoint S3 pod. + properties: + attachmentTime: + description: AttachmentTime represents when the workload pod + was attached to the Mountpoint S3 pod + format: date-time + type: string + workloadPodUID: + description: WorkloadPodUID is the unique identifier of the + attached workload pod + type: string + required: + - attachmentTime + - workloadPodUID + type: object type: array - description: Maps each Mountpoint S3 pod name to the list of workload - pod UIDs it is attached to. + description: Maps each Mountpoint S3 pod name to its workload attachments type: object nodeName: description: Name of the node. @@ -87,7 +101,7 @@ spec: required: - authenticationSource - mountOptions - - mountpointS3PodToWorkloadPodUIDs + - mountpointS3PodAttachments - nodeName - persistentVolumeName - volumeID diff --git a/tests/e2e-kubernetes/testsuites/pod_sharing.go b/tests/e2e-kubernetes/testsuites/pod_sharing.go index c00a3137..09df3fbe 100644 --- a/tests/e2e-kubernetes/testsuites/pod_sharing.go +++ b/tests/e2e-kubernetes/testsuites/pod_sharing.go @@ -372,10 +372,10 @@ func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, p if matchesSpec(s3pa.Spec, expectedFields) { s3paNames = append(s3paNames, s3pa.Name) allUIDs := make(map[string]bool) - for mpPodName, uids := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { mountpointPodNames = append(mountpointPodNames, mpPodName) - for _, uid := range uids { - allUIDs[uid] = true + for _, attachment := range attachments { + allUIDs[attachment.WorkloadPodUID] = true } } for _, pod := range pods { @@ -425,9 +425,9 @@ func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Fra podToMountpointPod := make(map[string]string) for _, s3pa := range s3paList.Items { - for mpPodName, workloadPodUIDs := range s3pa.Spec.MountpointS3PodToWorkloadPodUIDs { - for _, uid := range workloadPodUIDs { - podToMountpointPod[uid] = mpPodName + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + podToMountpointPod[attachment.WorkloadPodUID] = mpPodName } } } From d192ded2c81c171c3acfb14afecbb780804874f4 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Wed, 30 Apr 2025 17:01:29 +0100 Subject: [PATCH 14/24] Add StaleAttachmentCleaner --- .../csicontroller/stale_attachment_cleaner.go | 159 ++++++++++++++++++ cmd/aws-s3-csi-controller/main.go | 12 +- 2 files changed, 168 insertions(+), 3 deletions(-) create mode 100644 cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go diff --git a/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go new file mode 100644 index 00000000..8f3017f0 --- /dev/null +++ b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go @@ -0,0 +1,159 @@ +package csicontroller + +import ( + "context" + "sync" + "time" + + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + logf "sigs.k8s.io/controller-runtime/pkg/log" +) + +const ( + cleanupInterval = 10 * time.Second + staleAttachmentThreshold = 10 * time.Second +) + +// StaleAttachmentCleaner handles periodic cleanup of stale workload attachments in case reconciler missed pod deletion event. +type StaleAttachmentCleaner struct { + reconciler *Reconciler + mutex sync.Mutex + stopCh chan struct{} +} + +// NewStaleAttachmentCleaner creates a new StaleAttachmentCleaner +func NewStaleAttachmentCleaner(reconciler *Reconciler) *StaleAttachmentCleaner { + return &StaleAttachmentCleaner{ + reconciler: reconciler, + stopCh: make(chan struct{}), + } +} + +// Start begins the periodic cleanup process +func (cm *StaleAttachmentCleaner) Start(ctx context.Context) error { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-cm.stopCh: + return nil + case <-ticker.C: + if err := cm.runCleanup(ctx); err != nil { + log := logf.FromContext(ctx) + log.Error(err, "Failed to run cleanup") + } + } + } +} + +// runCleanup performs cleanup operation +func (cm *StaleAttachmentCleaner) runCleanup(ctx context.Context) error { + // Ensure only one cleanup runs at a time + if !cm.mutex.TryLock() { + return nil + } + defer cm.mutex.Unlock() + + log := logf.FromContext(ctx) + + // Get all pods in the cluster + podList := &corev1.PodList{} + if err := cm.reconciler.List(ctx, podList); err != nil { + return err + } + + // Create a map of existing pod UIDs for quick lookup + existingPods := make(map[string]struct{}) + for _, pod := range podList.Items { + existingPods[string(pod.UID)] = struct{}{} + } + + // Get all MountpointS3PodAttachments + s3paList := &crdv1beta.MountpointS3PodAttachmentList{} + if err := cm.reconciler.List(ctx, s3paList); err != nil { + return err + } + + // Check each S3PodAttachment for stale workload references + for _, s3pa := range s3paList.Items { + if err := cm.cleanupStaleWorkloads(ctx, &s3pa, existingPods); err != nil { + log.Error(err, "Error cleaning up S3PodAttachment", "s3pa", s3pa.Name) + continue + } + } + + return nil +} + +// cleanupStaleWorkloads removes stale workload references from a single S3PodAttachment. +// A workload reference is considered stale if the referenced Pod no longer exists in the cluster +// and the attachment is older than staleAttachmentThreshold (this is to avoid race condition with reconciler). +// If a Mountpoint Pod has zero attachments after cleanup, both the Pod and its entry in S3PodAttachment are deleted. +// If S3PodAttachment has no remaining Mountpoint Pods, the entire S3PodAttachment is deleted. +func (cm *StaleAttachmentCleaner) cleanupStaleWorkloads(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, existingPods map[string]struct{}) error { + log := logf.FromContext(ctx).WithValues("s3pa", s3pa.Name) + modified := false + + now := time.Now().UTC() + + // Check each mountpoint pod's attachments + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + var validAttachments []crdv1beta.WorkloadAttachment + + for _, attachment := range attachments { + // Check if pod exists and attachment is not too new + _, exists := existingPods[attachment.WorkloadPodUID] + isStale := now.Sub(attachment.AttachmentTime.Time) > staleAttachmentThreshold + + if exists || !isStale { + validAttachments = append(validAttachments, attachment) + } else { + modified = true + log.Info("Removing stale workload reference", + "workloadUID", attachment.WorkloadPodUID, + "mountpointPod", mpPodName, + "attachmentAge", now.Sub(attachment.AttachmentTime.Time)) + } + } + + if len(validAttachments) == 0 { + // Delete the Mountpoint Pod + mpPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: mpPodName, + Namespace: cm.reconciler.mountpointPodConfig.Namespace, + }, + } + if err := cm.reconciler.Delete(ctx, mpPod); err != nil { + if !apierrors.IsNotFound(err) { + log.Error(err, "Failed to delete Mountpoint Pod", "mountpointPod", mpPodName) + return err + } + // If pod is not found, that's fine - continue with removing it from s3pa + log.Info("Mountpoint Pod does not exist", "mountpointPod", mpPodName) + } else { + log.Info("Deleted Mountpoint Pod with no attachments", "mountpointPod", mpPodName) + } + + delete(s3pa.Spec.MountpointS3PodAttachments, mpPodName) + } else { + s3pa.Spec.MountpointS3PodAttachments[mpPodName] = validAttachments + } + } + + // Update the S3PodAttachment if modified + if modified { + if len(s3pa.Spec.MountpointS3PodAttachments) == 0 { + cm.reconciler.Delete(ctx, s3pa) + } + return cm.reconciler.Update(ctx, s3pa) + } + + return nil +} diff --git a/cmd/aws-s3-csi-controller/main.go b/cmd/aws-s3-csi-controller/main.go index eb0c5c3c..6b270ac1 100644 --- a/cmd/aws-s3-csi-controller/main.go +++ b/cmd/aws-s3-csi-controller/main.go @@ -65,7 +65,7 @@ func main() { IndexMountpointS3PodAttachmentFields(log, mgr) - err = csicontroller.NewReconciler(mgr.GetClient(), mppod.Config{ + reconciler := csicontroller.NewReconciler(mgr.GetClient(), mppod.Config{ Namespace: *mountpointNamespace, MountpointVersion: *mountpointVersion, PriorityClassName: *mountpointPriorityClassName, @@ -76,12 +76,18 @@ func main() { }, CSIDriverVersion: version.GetVersion().DriverVersion, ClusterVariant: cluster.DetectVariant(conf, log), - }).SetupWithManager(mgr) - if err != nil { + }) + + if err := reconciler.SetupWithManager(mgr); err != nil { log.Error(err, "Failed to create controller") os.Exit(1) } + if err := mgr.Add(csicontroller.NewStaleAttachmentCleaner(reconciler)); err != nil { + log.Error(err, "Failed to add stale attachment cleaner to manager") + os.Exit(1) + } + if err := mgr.Start(signals.SetupSignalHandler()); err != nil { log.Error(err, "Failed to start manager") os.Exit(1) From 84b88cfbcc19c2607615c2df403ff33384e3346a Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Wed, 30 Apr 2025 17:38:52 +0100 Subject: [PATCH 15/24] Fix controller test after changing CRD --- tests/controller/controller_test.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index 3a008d5b..13149a54 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + . "github.com/onsi/gomega/gstruct" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" @@ -1123,7 +1124,12 @@ func waitAndVerifyMountpointPodFromPodAttachment(s3pa *crdv1beta.MountpointS3Pod } Expect(mpPodName).NotTo(BeEmpty(), "No Mountpoint Pod found for pod UID %s in MountpointS3PodAttachment: %#v", podUID, s3pa) - Expect(s3pa.Spec.MountpointS3PodAttachments[mpPodName]).To(ContainElement(podUID)) + Expect(s3pa.Spec.MountpointS3PodAttachments[mpPodName]).To(ContainElement( + MatchFields(IgnoreExtras, Fields{ + "WorkloadPodUID": Equal(podUID), + "AttachmentTime": Not(BeZero()), + }), + )) mountpointPod := waitForMountpointPodWithName(mpPodName) verifyMountpointPodFor(pod, vol, mountpointPod) From a5887d0fb97001bdb2ffaf47b6af340e9635bec2 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Wed, 30 Apr 2025 21:29:36 +0100 Subject: [PATCH 16/24] Fix CI tests --- tests/controller/controller_test.go | 72 ++++++++++++------- tests/e2e-kubernetes/e2e_test.go | 2 + tests/e2e-kubernetes/scripts/run.sh | 9 ++- tests/e2e-kubernetes/testdriver.go | 1 + .../e2e-kubernetes/testsuites/pod_sharing.go | 6 ++ 5 files changed, 62 insertions(+), 28 deletions(-) diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index 13149a54..01867c31 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -364,17 +364,16 @@ var _ = Describe("Mountpoint Controller", func() { pod1 := createPod(withPVC(vol.pvc), withFSGroup(1111)) pod2 := createPod(withPVC(vol.pvc), withFSGroup(1111)) pod1.schedule(testNode) - pod2.schedule(testNode) expectedFields := defaultExpectedFields(testNode, vol.pv) expectedFields["WorkloadFSGroup"] = "1111" - s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") - s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(testNode, vol, pod1, expectedFields) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) - Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) - mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) - mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + pod2.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod1, s3pa.ResourceVersion, expectedFields) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod2, s3pa.ResourceVersion, expectedFields) Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") @@ -433,19 +432,18 @@ var _ = Describe("Mountpoint Controller", func() { pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) pod1.schedule(testNode) - pod2.schedule(testNode) expectedFields := defaultExpectedFields(testNode, vol.pv) expectedFields["AuthenticationSource"] = "pod" expectedFields["WorkloadServiceAccountName"] = sa.Name expectedFields["WorkloadNamespace"] = defaultNamespace - s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") - s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(testNode, vol, pod1, expectedFields) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) - Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) - mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) - mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + pod2.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod1, s3pa.ResourceVersion, expectedFields) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod2, s3pa.ResourceVersion, expectedFields) Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") @@ -469,20 +467,19 @@ var _ = Describe("Mountpoint Controller", func() { pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) pod1.schedule(testNode) - pod2.schedule(testNode) expectedFields := defaultExpectedFields(testNode, vol.pv) expectedFields["AuthenticationSource"] = "pod" expectedFields["WorkloadServiceAccountName"] = sa.Name expectedFields["WorkloadNamespace"] = defaultNamespace expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role" - s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") - s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(testNode, vol, pod1, expectedFields) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) - Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) - Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) - mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) - mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + pod2.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod1, s3pa.ResourceVersion, expectedFields) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod2, s3pa.ResourceVersion, expectedFields) Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") @@ -1078,6 +1075,20 @@ func expectNoPodUIDInS3PodAttachment(s3pa *crdv1beta.MountpointS3PodAttachment, } } +// waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields waits and verifies that MountpointS3PodAttachment and Mountpoint Pod +// are created for given `node`, `vol`, `pod` and `expectedFields` +func waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields( + node string, + vol *testVolume, + pod *testPod, + expectedFields map[string]string, +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { + s3pa := waitForS3PodAttachmentWithFields(expectedFields, "") + Expect(len(s3pa.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) + return s3pa, mpPod +} + // waitAndVerifyS3PodAttachmentAndMountpointPod waits and verifies that MountpointS3PodAttachment and Mountpoint Pod // are created for given `node`, `vol` and `pod` func waitAndVerifyS3PodAttachmentAndMountpointPod( @@ -1085,7 +1096,19 @@ func waitAndVerifyS3PodAttachmentAndMountpointPod( vol *testVolume, pod *testPod, ) (*crdv1beta.MountpointS3PodAttachment, *testPod) { - s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(node, vol.pv), "") + return waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(node, vol, pod, defaultExpectedFields(node, vol.pv)) +} + +// waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField waits and verifies that MountpointS3PodAttachment with `minVersion` and Mountpoint Pod +// are created for given `node`, `vol`, `pod` and `expectedFields` +func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField( + testNode string, + vol *testVolume, + pod *testPod, + minVersion string, + expectedFields map[string]string, +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { + s3pa := waitForS3PodAttachmentWithFields(expectedFields, minVersion) Expect(len(s3pa.Spec.MountpointS3PodAttachments)).To(Equal(1)) mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) return s3pa, mpPod @@ -1099,10 +1122,7 @@ func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion( pod *testPod, minVersion string, ) (*crdv1beta.MountpointS3PodAttachment, *testPod) { - s3pa := waitForS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv), minVersion) - Expect(len(s3pa.Spec.MountpointS3PodAttachments)).To(Equal(1)) - mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) - return s3pa, mpPod + return waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod, minVersion, defaultExpectedFields(testNode, vol.pv)) } // waitAndVerifyMountpointPodFromPodAttachment waits and verifies Mountpoint Pod scheduled for given `s3pa`, `pod` and `vol.` diff --git a/tests/e2e-kubernetes/e2e_test.go b/tests/e2e-kubernetes/e2e_test.go index 9ee0f99e..d584fc6d 100644 --- a/tests/e2e-kubernetes/e2e_test.go +++ b/tests/e2e-kubernetes/e2e_test.go @@ -26,11 +26,13 @@ func init() { flag.StringVar(&BucketPrefix, "bucket-prefix", "local", "prefix for temporary buckets") flag.BoolVar(&Performance, "performance", false, "run performance tests") flag.BoolVar(&IMDSAvailable, "imds-available", false, "indicates whether instance metadata service is available") + flag.BoolVar(&IsPodMounter, "pod-mounter", false, "indicates whether CSI Driver is installed with Pod Mounter or not") flag.Parse() s3client.DefaultRegion = BucketRegion custom_testsuites.DefaultRegion = BucketRegion custom_testsuites.IMDSAvailable = IMDSAvailable + custom_testsuites.IsPodMounter = IsPodMounter } func TestE2E(t *testing.T) { diff --git a/tests/e2e-kubernetes/scripts/run.sh b/tests/e2e-kubernetes/scripts/run.sh index fe025e53..e454c9a0 100755 --- a/tests/e2e-kubernetes/scripts/run.sh +++ b/tests/e2e-kubernetes/scripts/run.sh @@ -86,6 +86,11 @@ fi CI_ROLE_ARN=${CI_ROLE_ARN:-""} MOUNTER_KIND=${MOUNTER_KIND:-systemd} +if [ "$MOUNTER_KIND" = "pod" ]; then + USE_POD_MOUNTER=true +else + USE_POD_MOUNTER=false +fi mkdir -p ${TEST_DIR} mkdir -p ${BIN_DIR} @@ -218,14 +223,14 @@ elif [[ "${ACTION}" == "install_driver" ]]; then elif [[ "${ACTION}" == "run_tests" ]]; then set +e pushd tests/e2e-kubernetes - KUBECONFIG=${KUBECONFIG} ginkgo -p -vv -timeout 60m -- --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --imds-available=true + KUBECONFIG=${KUBECONFIG} ginkgo -p -vv -timeout 60m -- --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --imds-available=true --pod-mounter=${USE_POD_MOUNTER} EXIT_CODE=$? print_cluster_info exit $EXIT_CODE elif [[ "${ACTION}" == "run_perf" ]]; then set +e pushd tests/e2e-kubernetes - KUBECONFIG=${KUBECONFIG} go test -ginkgo.vv --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --performance=true --imds-available=true + KUBECONFIG=${KUBECONFIG} go test -ginkgo.vv --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --performance=true --imds-available=true --pod-mounter=${USE_POD_MOUNTER} EXIT_CODE=$? print_cluster_info popd diff --git a/tests/e2e-kubernetes/testdriver.go b/tests/e2e-kubernetes/testdriver.go index b0ee62d0..c4ef4456 100644 --- a/tests/e2e-kubernetes/testdriver.go +++ b/tests/e2e-kubernetes/testdriver.go @@ -19,6 +19,7 @@ var ( BucketPrefix string Performance bool IMDSAvailable bool + IsPodMounter bool ) type s3Driver struct { diff --git a/tests/e2e-kubernetes/testsuites/pod_sharing.go b/tests/e2e-kubernetes/testsuites/pod_sharing.go index 09df3fbe..a5132c17 100644 --- a/tests/e2e-kubernetes/testsuites/pod_sharing.go +++ b/tests/e2e-kubernetes/testsuites/pod_sharing.go @@ -32,6 +32,8 @@ const mountpointNamespace = "mount-s3" const defaultTimeout = 10 * time.Second const defaultInterval = 1 * time.Second +var IsPodMounter bool + type s3CSIPodSharingTestSuite struct { tsInfo storageframework.TestSuiteInfo } @@ -74,6 +76,10 @@ func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDrive framework.ExpectNoError(errors.NewAggregate(errs), "while cleanup resource") } ginkgo.BeforeEach(func(ctx context.Context) { + if !IsPodMounter { + ginkgo.Skip("Pod Mounter is not enabled, skipping pod sharing tests") + } + l = local{} l.config = driver.PrepareTest(ctx, f) ginkgo.DeferCleanup(cleanup) From b6bd5ce05b4fee75eb87abf95736bda37b3d5410 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Wed, 30 Apr 2025 21:57:47 +0100 Subject: [PATCH 17/24] Fix MountOptions controller test --- tests/controller/controller_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index 01867c31..af59e25e 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/google/uuid" . "github.com/onsi/ginkgo/v2" @@ -342,8 +343,11 @@ var _ = Describe("Mountpoint Controller", func() { s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod1) - pv.Spec.MountOptions = []string{"--allow-delete"} + // Adding some sleep time before updating PV because reconciler requeues pod1 event to clear expectation + // and it can cause transient test failure if we update PV MountOptions too quickly + time.Sleep(5 * time.Second) + pv.Spec.MountOptions = []string{"--allow-delete"} Expect(k8sClient.Update(ctx, pv)).To(Succeed()) pod2 := createPod(withPVC(vol.pvc)) From 183e517d7fb7212ab149ba05aca22482168beca8 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Thu, 1 May 2025 13:59:41 +0100 Subject: [PATCH 18/24] Add needs-unmount annotation logic --- .../serviceaccount-csi-controller.yaml | 2 +- .../csicontroller/reconciler.go | 39 +++++- .../csicontroller/stale_attachment_cleaner.go | 29 +--- pkg/driver/driver.go | 12 +- pkg/driver/node/mounter/pod_unmounter.go | 124 ++++++++---------- pkg/driver/node/mounter/pod_unmounter_test.go | 106 +++++++-------- pkg/podmounter/mppod/creator.go | 6 + pkg/podmounter/mppod/creator_test.go | 10 ++ pkg/podmounter/mppod/watcher/watcher.go | 5 + 9 files changed, 173 insertions(+), 160 deletions(-) diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml index 529ee41e..8d972b36 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml @@ -23,7 +23,7 @@ metadata: rules: - apiGroups: [""] resources: ["pods"] - verbs: ["get", "create", "watch", "delete", "list"] + verbs: ["get", "create", "watch", "delete", "list", "update"] --- kind: RoleBinding apiVersion: rbac.authorization.k8s.io/v1 diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index f95dc369..b6d23d66 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -387,10 +387,17 @@ func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa // Remove Mountpoint pods with zero workloads for mpPodName, uids := range s3pa.Spec.MountpointS3PodAttachments { if len(uids) == 0 { + log.Info("Mountpoint pod has zero workload UIDs. Adding "+mppod.AnnotationNeedsUnmount+" annotation", + "mountpointPodName", mpPodName) + err := r.addNeedsUnmountAnnotation(ctx, mpPodName, log) + if err != nil { + return Requeue, err + } + log.Info("Mountpoint pod has zero workload UIDs. Will remove it from MountpointS3PodAttachment", "mountpointPodName", mpPodName) delete(s3pa.Spec.MountpointS3PodAttachments, mpPodName) - err := r.Update(ctx, s3pa) + err = r.Update(ctx, s3pa) if err != nil { if apierrors.IsConflict(err) { log.Info("Failed to remove Mountpoint pod from MountpointS3PodAttachment due to resource conflict, requeueing", @@ -601,6 +608,36 @@ func (r *Reconciler) findIRSAServiceAccountRole(ctx context.Context, pod *corev1 return sa.Annotations[AnnotationServiceAccountRole], nil } +// addNeedsUnmountAnnotation add "s3.csi.aws.com/needs-unmount" to Mountpoint Pod. +// This will trigger CSI Driver Node to cleanly unmount and Mountpoint Pod will become 'Succeeded'. +func (r *Reconciler) addNeedsUnmountAnnotation(ctx context.Context, mpPodName string, log logr.Logger) error { + // Get the pod + mpPod := &corev1.Pod{} + err := r.Get(ctx, types.NamespacedName{Namespace: r.mountpointPodConfig.Namespace, Name: mpPodName}, mpPod) + if err != nil { + if apierrors.IsNotFound(err) { + log.Info("Failed to find Mountpoint Pod - ignoring") + return nil + } + log.Error(err, "Failed to get Pod") + return err + } + + if mpPod.Annotations == nil { + mpPod.Annotations = make(map[string]string) + } + mpPod.Annotations[mppod.AnnotationNeedsUnmount] = "true" + + // Update the pod + err = r.Update(ctx, mpPod) + if err != nil { + log.Error(err, "Failed to update Mountpoint Pod") + return err + } + + return nil +} + // isMountpointPod returns whether given `pod` is a Mountpoint Pod. // It currently checks namespace of `pod`. func (r *Reconciler) isMountpointPod(pod *corev1.Pod) bool { diff --git a/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go index 8f3017f0..24ffc935 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go +++ b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go @@ -7,8 +7,6 @@ import ( crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" corev1 "k8s.io/api/core/v1" - apierrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" logf "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -21,14 +19,12 @@ const ( type StaleAttachmentCleaner struct { reconciler *Reconciler mutex sync.Mutex - stopCh chan struct{} } // NewStaleAttachmentCleaner creates a new StaleAttachmentCleaner func NewStaleAttachmentCleaner(reconciler *Reconciler) *StaleAttachmentCleaner { return &StaleAttachmentCleaner{ reconciler: reconciler, - stopCh: make(chan struct{}), } } @@ -41,8 +37,6 @@ func (cm *StaleAttachmentCleaner) Start(ctx context.Context) error { select { case <-ctx.Done(): return nil - case <-cm.stopCh: - return nil case <-ticker.C: if err := cm.runCleanup(ctx); err != nil { log := logf.FromContext(ctx) @@ -94,7 +88,7 @@ func (cm *StaleAttachmentCleaner) runCleanup(ctx context.Context) error { // cleanupStaleWorkloads removes stale workload references from a single S3PodAttachment. // A workload reference is considered stale if the referenced Pod no longer exists in the cluster // and the attachment is older than staleAttachmentThreshold (this is to avoid race condition with reconciler). -// If a Mountpoint Pod has zero attachments after cleanup, both the Pod and its entry in S3PodAttachment are deleted. +// If a Mountpoint Pod has zero attachments after cleanup, "s3.csi.aws.com/needs-unmount" annotation is added and its entry in S3PodAttachment is deleted. // If S3PodAttachment has no remaining Mountpoint Pods, the entire S3PodAttachment is deleted. func (cm *StaleAttachmentCleaner) cleanupStaleWorkloads(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, existingPods map[string]struct{}) error { log := logf.FromContext(ctx).WithValues("s3pa", s3pa.Name) @@ -123,24 +117,9 @@ func (cm *StaleAttachmentCleaner) cleanupStaleWorkloads(ctx context.Context, s3p } if len(validAttachments) == 0 { - // Delete the Mountpoint Pod - mpPod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: mpPodName, - Namespace: cm.reconciler.mountpointPodConfig.Namespace, - }, - } - if err := cm.reconciler.Delete(ctx, mpPod); err != nil { - if !apierrors.IsNotFound(err) { - log.Error(err, "Failed to delete Mountpoint Pod", "mountpointPod", mpPodName) - return err - } - // If pod is not found, that's fine - continue with removing it from s3pa - log.Info("Mountpoint Pod does not exist", "mountpointPod", mpPodName) - } else { - log.Info("Deleted Mountpoint Pod with no attachments", "mountpointPod", mpPodName) + if err := cm.reconciler.addNeedsUnmountAnnotation(ctx, mpPodName, log); err != nil { + return err } - delete(s3pa.Spec.MountpointS3PodAttachments, mpPodName) } else { s3pa.Spec.MountpointS3PodAttachments[mpPodName] = validAttachments @@ -150,7 +129,7 @@ func (cm *StaleAttachmentCleaner) cleanupStaleWorkloads(ctx context.Context, s3p // Update the S3PodAttachment if modified if modified { if len(s3pa.Spec.MountpointS3PodAttachments) == 0 { - cm.reconciler.Delete(ctx, s3pa) + return cm.reconciler.Delete(ctx, s3pa) } return cm.reconciler.Update(ctx, s3pa) } diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 39ed663a..718efbc2 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -141,17 +141,15 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error } }() - unmounter := mounter.NewPodUnmounter(nodeID, mountUtil, podWatcher, s3paCache, credProvider, mounter.SourceMountDir) - - s3podAttachmentInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ - UpdateFunc: unmounter.HandleS3PodAttachmentUpdate, - }) - if !cache.WaitForCacheSync(stopCh, s3podAttachmentInformer.HasSynced) { klog.Fatalf("Failed to sync informer cache within the timeout: %v\n", err) } - unmounter.CleanupDanglingMounts() + unmounter := mounter.NewPodUnmounter(nodeID, mountUtil, podWatcher, credProvider, mounter.SourceMountDir) + + podWatcher.AddEventHandler(cache.ResourceEventHandlerFuncs{UpdateFunc: unmounter.HandleMountpointPodUpdate}) + + go unmounter.StartPeriodicCleanup(stopCh) mounterImpl, err = mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mountUtil, nil, nil, nil, kubernetesVersion, nodeID, mounter.SourceMountDir) diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go index a53666a8..2b3ec924 100644 --- a/pkg/driver/node/mounter/pod_unmounter.go +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -1,12 +1,12 @@ package mounter import ( - "context" "fmt" "os" "path/filepath" + "sync" + "time" - crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" @@ -14,24 +14,24 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/klog/v2" "k8s.io/mount-utils" - "sigs.k8s.io/controller-runtime/pkg/cache" ) +const cleanupInterval = 10 * time.Second + type PodUnmounter struct { nodeID string mountUtil mount.Interface kubeletPath string sourceMountDir string podWatcher *watcher.Watcher - s3paCache cache.Cache credProvider *credentialprovider.Provider + mutex sync.Mutex } func NewPodUnmounter( nodeID string, mountUtil mount.Interface, podWatcher *watcher.Watcher, - s3paCache cache.Cache, credProvider *credentialprovider.Provider, sourceMountDir string, ) *PodUnmounter { @@ -41,51 +41,51 @@ func NewPodUnmounter( kubeletPath: util.KubeletPath(), sourceMountDir: sourceMountDir, podWatcher: podWatcher, - s3paCache: s3paCache, credProvider: credProvider, } } -func (u *PodUnmounter) HandleS3PodAttachmentUpdate(old, new any) { - s3pa := new.(*crdv1beta.MountpointS3PodAttachment) - if s3pa.Spec.NodeName != u.nodeID { +func (u *PodUnmounter) HandleMountpointPodUpdate(old, new any) { + mpPod := new.(*corev1.Pod) + if mpPod.Spec.NodeName != u.nodeID { return } - for mpPodName, uids := range s3pa.Spec.MountpointS3PodAttachments { - if len(uids) == 0 { - u.unmountSourceForPod(s3pa, mpPodName) - } + if value, exists := mpPod.Annotations[mppod.AnnotationNeedsUnmount]; exists && value == "true" { + u.unmountSourceForPod(mpPod) } } -func (u *PodUnmounter) unmountSourceForPod(s3pa *crdv1beta.MountpointS3PodAttachment, mpPodName string) { - klog.Infof("Found Mountpoint Pod with zero workload pods, unmounting it - %s", mpPodName) - mpPod, err := u.podWatcher.Get(mpPodName) - if err != nil { - klog.Infof("failed to find Mountpoint Pod %s during update event", mpPodName) - return - } - +func (u *PodUnmounter) unmountSourceForPod(mpPod *corev1.Pod) { mpPodUID := string(mpPod.UID) + mpPodLock := getMPPodLock(mpPodUID) + mpPodLock.mutex.Lock() + defer func() { + mpPodLock.mutex.Unlock() + releaseMPPodLock(mpPodUID) + }() + + klog.Infof("Found Mountpoint Pod %s (UID: %s) with %s annotation, unmounting it", mpPod.Name, mpPodUID, mppod.AnnotationNeedsUnmount) + podPath := filepath.Join(u.kubeletPath, "pods", mpPodUID) source := filepath.Join(u.sourceMountDir, mpPodUID) + volumeId := mpPod.Labels[mppod.LabelVolumeId] - if err := u.writeExitFile(podPath, mpPod); err != nil { + if err := u.writeExitFile(podPath); err != nil { return } if err := u.unmountAndCleanup(source); err != nil { return } - klog.Infof("Successfully unmounted Mountpoint Pod - %s", mpPodName) + klog.Infof("Successfully unmounted Mountpoint Pod - %s", mpPod.Name) - if err := u.cleanupCredentials(s3pa, mpPodUID, podPath, source, mpPod); err != nil { + if err := u.cleanupCredentials(volumeId, mpPodUID, podPath, source, mpPod); err != nil { return } } -func (u *PodUnmounter) writeExitFile(podPath string, mpPod *corev1.Pod) error { +func (u *PodUnmounter) writeExitFile(podPath string) error { podMountExitPath := mppod.PathOnHost(podPath, mppod.KnownPathMountExit) _, err := os.OpenFile(podMountExitPath, os.O_RDONLY|os.O_CREATE, credentialprovider.CredentialFilePerm) if err != nil { @@ -108,9 +108,9 @@ func (u *PodUnmounter) unmountAndCleanup(source string) error { return nil } -func (u *PodUnmounter) cleanupCredentials(s3pa *crdv1beta.MountpointS3PodAttachment, mpPodUID, podPath, source string, mpPod *corev1.Pod) error { +func (u *PodUnmounter) cleanupCredentials(volumeId, mpPodUID, podPath, source string, mpPod *corev1.Pod) error { err := u.credProvider.Cleanup(credentialprovider.CleanupContext{ - VolumeID: s3pa.Spec.VolumeID, + VolumeID: volumeId, PodID: mpPodUID, WritePath: filepath.Join(u.kubeletPath, "pods", mpPodUID), }) @@ -121,11 +121,33 @@ func (u *PodUnmounter) cleanupCredentials(s3pa *crdv1beta.MountpointS3PodAttachm return nil } -func (u *PodUnmounter) CleanupDanglingMounts() { +func (u *PodUnmounter) StartPeriodicCleanup(stopCh <-chan struct{}) error { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-stopCh: + return nil + case <-ticker.C: + if err := u.CleanupDanglingMounts(); err != nil { + klog.Errorf("Failed to run clean up of dangling mounts: %v", err) + } + } + } +} + +func (u *PodUnmounter) CleanupDanglingMounts() error { + // Ensure only one cleanup runs at a time + if !u.mutex.TryLock() { + return nil + } + defer u.mutex.Unlock() + entries, err := os.ReadDir(u.sourceMountDir) if err != nil { klog.Errorf("Failed to read source mount directory (`%s`): %v", u.sourceMountDir, err) - return + return err } for _, file := range entries { @@ -145,30 +167,16 @@ func (u *PodUnmounter) CleanupDanglingMounts() { continue } - // Check if pod has an S3PodAttachment - hasWorkloads, err := u.checkForWorkloads(mpPod) - if err != nil { - klog.Errorf("Failed to check workloads for Mountpoint Pod %s: %v", mpPod.Name, err) - continue - } - - if !hasWorkloads { - klog.Infof("Found dangling mount for Mountpoint Pod %s (UID: %s), cleaning up", mpPod.Name, mpPodUID) - podPath := filepath.Join(u.kubeletPath, "pods", mpPodUID) - if err := u.writeExitFile(podPath, mpPod); err != nil { - return - } - - if err := u.unmountAndCleanup(source); err != nil { - klog.Errorf("Failed to cleanup dangling mount for Mountpoint Pod %s: %v", mpPod.Name, err) - continue - } - - // TODO: Skip credential clean up as we do not know volumeID OR delete all files in credential folder? + // Unmount if Mountpoint Pod is marked for unmounting + if value, exists := mpPod.Annotations[mppod.AnnotationNeedsUnmount]; exists && value == "true" { + u.unmountSourceForPod(mpPod) } } + + return nil } +// findPodByUID finds Mountpoint Pod by UID in podWatcher's cache func (u *PodUnmounter) findPodByUID(mpPodUID string) (*corev1.Pod, error) { pods, err := u.podWatcher.List() if err != nil { @@ -182,21 +190,3 @@ func (u *PodUnmounter) findPodByUID(mpPodUID string) (*corev1.Pod, error) { } return nil, fmt.Errorf("Mountpoint Pod not found for UID %s", mpPodUID) } - -func (u *PodUnmounter) checkForWorkloads(mpPod *corev1.Pod) (bool, error) { - s3paList := &crdv1beta.MountpointS3PodAttachmentList{} - err := u.s3paCache.List(context.Background(), s3paList) - if err != nil { - return false, err - } - - // Find attachment for this pod and check if it has workloads - for _, s3pa := range s3paList.Items { - for mpPodName, workloadUIDs := range s3pa.Spec.MountpointS3PodAttachments { - if mpPodName == mpPod.Name { - return len(workloadUIDs) > 0, nil - } - } - } - return false, nil -} diff --git a/pkg/driver/node/mounter/pod_unmounter_test.go b/pkg/driver/node/mounter/pod_unmounter_test.go index 8e56cffd..4cbcadec 100644 --- a/pkg/driver/node/mounter/pod_unmounter_test.go +++ b/pkg/driver/node/mounter/pod_unmounter_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" @@ -59,7 +58,6 @@ func TestHandleS3PodAttachmentUpdate(t *testing.T) { tests := []struct { name string nodeID string - s3pa *crdv1beta.MountpointS3PodAttachment pod *corev1.Pod unmountError error expectUnmount bool @@ -67,35 +65,44 @@ func TestHandleS3PodAttachmentUpdate(t *testing.T) { { name: "different node", nodeID: "node1", - s3pa: &crdv1beta.MountpointS3PodAttachment{ - Spec: crdv1beta.MountpointS3PodAttachmentSpec{ - NodeName: "node2", + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", + }, + }, + Spec: corev1.PodSpec{ + NodeName: "different-node", }, }, expectUnmount: false, }, { - name: "same node with empty workload", + name: "same node with unmount annotation", nodeID: nodeName, pod: &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: "pod1", Namespace: mountpointPodNamespace, UID: "uid1", + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", + }, + Labels: map[string]string{ + mppod.LabelVolumeId: "vol1", + }, }, - }, - s3pa: &crdv1beta.MountpointS3PodAttachment{ - Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + Spec: corev1.PodSpec{ NodeName: nodeName, - MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ - "pod1": {}, - }, }, }, expectUnmount: true, }, { - name: "same node with empty workload and unmount error", + name: "same node without unmount annotation", nodeID: nodeName, pod: &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ @@ -103,17 +110,11 @@ func TestHandleS3PodAttachmentUpdate(t *testing.T) { Namespace: mountpointPodNamespace, UID: "uid1", }, - }, - s3pa: &crdv1beta.MountpointS3PodAttachment{ - Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + Spec: corev1.PodSpec{ NodeName: nodeName, - MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ - "pod1": {}, - }, }, }, - unmountError: errors.New("unmount error"), - expectUnmount: true, + expectUnmount: false, }, } @@ -147,10 +148,9 @@ func TestHandleS3PodAttachmentUpdate(t *testing.T) { credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { return dummyIMDSRegion, nil }) - s3paCache := &mounter.FakeCache{} - unmounter := mounter.NewPodUnmounter(tt.nodeID, fakeMounter, podWatcher, s3paCache, credProvider, sourceMountDir) - unmounter.HandleS3PodAttachmentUpdate(nil, tt.s3pa) + unmounter := mounter.NewPodUnmounter(tt.nodeID, fakeMounter, podWatcher, credProvider, sourceMountDir) + unmounter.HandleMountpointPodUpdate(nil, tt.pod) unmountCalls := countUnmountCalls(fakeMounter) expectedUnmounts := 0 @@ -166,7 +166,6 @@ func TestCleanupDanglingMounts(t *testing.T) { tests := []struct { name string pods []*corev1.Pod - s3paItems []crdv1beta.MountpointS3PodAttachment unmountError error expectedCalls int }{ @@ -179,38 +178,30 @@ func TestCleanupDanglingMounts(t *testing.T) { Namespace: mountpointPodNamespace, UID: "uid1", }, - }, - }, - s3paItems: []crdv1beta.MountpointS3PodAttachment{ - { - Spec: crdv1beta.MountpointS3PodAttachmentSpec{ - MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ - "pod1": []crdv1beta.WorkloadAttachment{crdv1beta.WorkloadAttachment{ - WorkloadPodUID: "workload1", - }}, - }, + Spec: corev1.PodSpec{ + NodeName: nodeName, }, }, }, expectedCalls: 0, }, { - name: "with dangling mount", + name: "pod marked for unmount", pods: []*corev1.Pod{ { ObjectMeta: metav1.ObjectMeta{ Name: "pod1", Namespace: mountpointPodNamespace, UID: "uid1", - }, - }, - }, - s3paItems: []crdv1beta.MountpointS3PodAttachment{ - { - Spec: crdv1beta.MountpointS3PodAttachmentSpec{ - MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ - "pod1": {}, + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", }, + Labels: map[string]string{ + mppod.LabelVolumeId: "vol1", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, }, }, }, @@ -224,16 +215,16 @@ func TestCleanupDanglingMounts(t *testing.T) { Name: "pod1", Namespace: mountpointPodNamespace, UID: "uid1", - }, - }, - }, - s3paItems: []crdv1beta.MountpointS3PodAttachment{ - { - Spec: crdv1beta.MountpointS3PodAttachmentSpec{ - MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ - "pod1": {}, + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", + }, + Labels: map[string]string{ + mppod.LabelVolumeId: "vol1", }, }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, }, }, unmountError: errors.New("unmount error"), @@ -243,7 +234,6 @@ func TestCleanupDanglingMounts(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podWatcher, client := setupPodWatcher(t, tt.pods...) kubeletPath := t.TempDir() t.Setenv("KUBELET_PATH", kubeletPath) t.Chdir(kubeletPath) @@ -266,16 +256,14 @@ func TestCleanupDanglingMounts(t *testing.T) { } } - s3paCache := &mounter.FakeCache{ - TestItems: tt.s3paItems, - } - + podWatcher, client := setupPodWatcher(t, tt.pods...) credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { return dummyIMDSRegion, nil }) - unmounter := mounter.NewPodUnmounter(nodeName, fakeMounter, podWatcher, s3paCache, credProvider, sourceMountDir) - unmounter.CleanupDanglingMounts() + unmounter := mounter.NewPodUnmounter(nodeName, fakeMounter, podWatcher, credProvider, sourceMountDir) + err := unmounter.CleanupDanglingMounts() + assert.NoError(t, err) unmountCalls := countUnmountCalls(fakeMounter) assert.Equals(t, tt.expectedCalls, unmountCalls) diff --git a/pkg/podmounter/mppod/creator.go b/pkg/podmounter/mppod/creator.go index cc3bef11..d24a5ca8 100644 --- a/pkg/podmounter/mppod/creator.go +++ b/pkg/podmounter/mppod/creator.go @@ -15,9 +15,14 @@ import ( const ( LabelMountpointVersion = "s3.csi.aws.com/mountpoint-version" LabelVolumeName = "s3.csi.aws.com/volume-name" + LabelVolumeId = "s3.csi.aws.com/volume-id" LabelCSIDriverVersion = "s3.csi.aws.com/mounted-by-csi-driver-version" ) +const ( + AnnotationNeedsUnmount = "s3.csi.aws.com/needs-unmount" +) + // A ContainerConfig represents configuration for containers in the spawned Mountpoint Pods. type ContainerConfig struct { Command string @@ -54,6 +59,7 @@ func (c *Creator) Create(node string, pv *corev1.PersistentVolume) *corev1.Pod { Labels: map[string]string{ LabelMountpointVersion: c.config.MountpointVersion, LabelVolumeName: pv.Name, + LabelVolumeId: pv.Spec.CSI.VolumeHandle, LabelCSIDriverVersion: c.config.CSIDriverVersion, }, }, diff --git a/pkg/podmounter/mppod/creator_test.go b/pkg/podmounter/mppod/creator_test.go index 13fb31ad..aa8e032b 100644 --- a/pkg/podmounter/mppod/creator_test.go +++ b/pkg/podmounter/mppod/creator_test.go @@ -22,6 +22,7 @@ const ( testNode = "test-node" testPodUID = "test-pod-uid" testVolName = "test-vol" + testVolID = "test-vol-id" csiDriverVersion = "1.12.0" ) @@ -51,6 +52,7 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu mppod.LabelMountpointVersion: mountpointVersion, mppod.LabelCSIDriverVersion: csiDriverVersion, mppod.LabelVolumeName: testVolName, + mppod.LabelVolumeId: testVolID, }, mpPod.Labels) assert.Equals(t, priorityClassName, mpPod.Spec.PriorityClassName) @@ -107,6 +109,13 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, + Spec: corev1.PersistentVolumeSpec{ + PersistentVolumeSource: corev1.PersistentVolumeSource{ + CSI: &corev1.CSIPersistentVolumeSource{ + VolumeHandle: testVolID, + }, + }, + }, }) verifyDefaultValues(mpPod) @@ -120,6 +129,7 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu Spec: corev1.PersistentVolumeSpec{ PersistentVolumeSource: corev1.PersistentVolumeSource{ CSI: &corev1.CSIPersistentVolumeSource{ + VolumeHandle: testVolID, VolumeAttributes: map[string]string{ "mountpointPodServiceAccountName": "mount-s3-sa", }, diff --git a/pkg/podmounter/mppod/watcher/watcher.go b/pkg/podmounter/mppod/watcher/watcher.go index 5fffc9c2..2fc3ccf6 100644 --- a/pkg/podmounter/mppod/watcher/watcher.go +++ b/pkg/podmounter/mppod/watcher/watcher.go @@ -71,6 +71,11 @@ func (w *Watcher) List() ([]*corev1.Pod, error) { return w.lister.List(labels.Everything()) } +// AddEventHandler adds pod event handler. +func (w *Watcher) AddEventHandler(handler cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + return w.informer.AddEventHandler(handler) +} + // Wait blocks until the specified Mountpoint Pod is found and ready, or until the context is cancelled. func (w *Watcher) Wait(ctx context.Context, name string) (*corev1.Pod, error) { // Set a watcher for Pod create & update events From d04d8d1a5763457eade224a8ee2cb52d9997b799 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Thu, 1 May 2025 14:33:11 +0100 Subject: [PATCH 19/24] Conditionally support selectable fields --- pkg/cluster/cluster.go | 15 +++++++ pkg/cluster/cluster_test.go | 72 +++++++++++++++++++++++++++++++++ pkg/driver/driver.go | 81 +++++++++++++++++++++++-------------- 3 files changed, 138 insertions(+), 30 deletions(-) diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 92e60122..b399ff2c 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -2,6 +2,7 @@ package cluster import ( "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/util/version" "k8s.io/client-go/discovery" "k8s.io/client-go/rest" "k8s.io/utils/ptr" @@ -53,3 +54,17 @@ func (c Variant) MountpointPodUserID() *int64 { return defaultMountpointUID } + +// Helper function to check availability of selectableFields on CustomResourceDefinitions feature in K8s cluster +func IsSelectableFieldsSupported(serverVersion string) (bool, error) { + currentVersion, err := version.ParseGeneric(serverVersion) + if err != nil { + return false, err + } + + // Selectable fields are supported from 1.32 + // https://kubernetes.io/docs/tasks/extend-kubernetes/custom-resources/custom-resource-definitions/#crd-selectable-fields + selectableFieldsVersion := version.MustParseGeneric("v1.32.0") + + return !currentVersion.LessThan(selectableFieldsVersion), nil +} diff --git a/pkg/cluster/cluster_test.go b/pkg/cluster/cluster_test.go index 700bae61..174609d4 100644 --- a/pkg/cluster/cluster_test.go +++ b/pkg/cluster/cluster_test.go @@ -33,3 +33,75 @@ func TestMountpointPodUserID(t *testing.T) { }) } } + +func TestIsSelectableFieldsSupported(t *testing.T) { + tests := []struct { + name string + serverVersion string + want bool + wantErr bool + }{ + { + name: "version greater than minimum supported version", + serverVersion: "v1.33.0", + want: true, + wantErr: false, + }, + { + name: "version equal to minimum supported version", + serverVersion: "v1.32.0", + want: true, + wantErr: false, + }, + { + name: "version less than minimum supported version", + serverVersion: "v1.31.0", + want: false, + wantErr: false, + }, + { + name: "version with patch number", + serverVersion: "v1.32.2", + want: true, + wantErr: false, + }, + { + name: "version with release candidate", + serverVersion: "v1.32.0-rc.1", + want: true, + wantErr: false, + }, + { + name: "invalid version format", + serverVersion: "invalid.version", + want: false, + wantErr: true, + }, + { + name: "empty version string", + serverVersion: "", + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cluster.IsSelectableFieldsSupported(tt.serverVersion) + + if tt.wantErr { + // If we expect an error, we don't check the boolean result + if err == nil { + t.Error("expected error but got none") + } + return + } + + // If we don't expect an error, verify there isn't one + assert.NoError(t, err) + + // Compare the actual result with expected + assert.Equals(t, tt.want, got) + }) + } +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 718efbc2..998cbc77 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -24,6 +24,7 @@ import ( "time" crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" @@ -114,36 +115,7 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error klog.Fatalf("Failed to start Pod watcher: %v\n", err) } - s3paCache, err := ctrlcache.New(config, ctrlcache.Options{ - Scheme: scheme, - SyncPeriod: &podWatcherResyncPeriod, - ReaderFailOnMissingInformer: true, - ByObject: map[client.Object]ctrlcache.ByObject{ - &crdv1beta.MountpointS3PodAttachment{}: { - Field: fields.OneTermEqualSelector("spec.nodeName", nodeID), - }, - }, - }) - if err != nil { - klog.Fatalf("Failed to create cache: %v\n", err) - } - - indexMountpointS3PodAttachmentFields(s3paCache) - - s3podAttachmentInformer, err := s3paCache.GetInformer(context.Background(), &crdv1beta.MountpointS3PodAttachment{}) - if err != nil { - klog.Fatalf("Failed to create informer for MountpointS3PodAttachment: %v\n", err) - } - - go func() { - if err := s3paCache.Start(signals.SetupSignalHandler()); err != nil { - klog.Fatalf("Failed to start cache: %v\n", err) - } - }() - - if !cache.WaitForCacheSync(stopCh, s3podAttachmentInformer.HasSynced) { - klog.Fatalf("Failed to sync informer cache within the timeout: %v\n", err) - } + s3paCache := setupS3PodAttachmentCache(config, stopCh, nodeID, kubernetesVersion) unmounter := mounter.NewPodUnmounter(nodeID, mountUtil, podWatcher, credProvider, mounter.SourceMountDir) @@ -241,6 +213,55 @@ func kubernetesVersion(clientset *kubernetes.Clientset) (string, error) { return version.String(), nil } +// setupS3PodAttachmentCache sets up cache for MountpointS3PodAttachment custom resource +func setupS3PodAttachmentCache(config *rest.Config, stopCh <-chan struct{}, nodeID, kubernetesVersion string) ctrlcache.Cache { + options := ctrlcache.Options{ + Scheme: scheme, + SyncPeriod: &podWatcherResyncPeriod, + ReaderFailOnMissingInformer: true, + } + isSelectFieldsSupported, err := cluster.IsSelectableFieldsSupported(kubernetesVersion) + if err != nil { + klog.Fatalf("Failed to check support for selectable fields in the cluster %v\n", err) + } + if isSelectFieldsSupported { + options.ByObject = map[client.Object]ctrlcache.ByObject{ + &crdv1beta.MountpointS3PodAttachment{}: { + Field: fields.OneTermEqualSelector("spec.nodeName", nodeID), + }, + } + } else { + // TODO: We can potentially use label filter hash of nodeId for old clusters instead of field selector + options.ByObject = map[client.Object]ctrlcache.ByObject{ + &crdv1beta.MountpointS3PodAttachment{}: {}, + } + } + + s3paCache, err := ctrlcache.New(config, options) + if err != nil { + klog.Fatalf("Failed to create cache: %v\n", err) + } + + indexMountpointS3PodAttachmentFields(s3paCache) + + s3podAttachmentInformer, err := s3paCache.GetInformer(context.Background(), &crdv1beta.MountpointS3PodAttachment{}) + if err != nil { + klog.Fatalf("Failed to create informer for MountpointS3PodAttachment: %v\n", err) + } + + go func() { + if err := s3paCache.Start(signals.SetupSignalHandler()); err != nil { + klog.Fatalf("Failed to start cache: %v\n", err) + } + }() + + if !cache.WaitForCacheSync(stopCh, s3podAttachmentInformer.HasSynced) { + klog.Fatalf("Failed to sync informer cache within the timeout: %v\n", err) + } + + return s3paCache +} + // TODO: This is duplicated multiple times func indexMountpointS3PodAttachmentFields(s3paCache ctrlcache.Cache) { indexField(s3paCache, crdv1beta.FieldNodeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) From 845b508575948294baad7fed22a461f72f7e2306 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Thu, 1 May 2025 15:27:12 +0100 Subject: [PATCH 20/24] Add sleep in controller for IRSA role change test --- tests/controller/controller_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index af59e25e..f3266014 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -512,6 +512,10 @@ var _ = Describe("Mountpoint Controller", func() { expectedFields["WorkloadServiceAccountIAMRoleARN"] = "" s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + // Adding some sleep time before updating SA because reconciler requeues pod1 event to clear expectation + // and it can cause transient test failure if we update SA annotation too quickly + time.Sleep(5 * time.Second) + sa.Annotations = map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role-1"} Expect(k8sClient.Update(ctx, sa)).To(Succeed()) pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) @@ -519,6 +523,8 @@ var _ = Describe("Mountpoint Controller", func() { expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role-1" s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + time.Sleep(5 * time.Second) + sa.Annotations = map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role-2"} Expect(k8sClient.Update(ctx, sa)).To(Succeed()) pod3 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) From 26dc0d87dc02e10948c9c1a744d0be8d9d13e183 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Thu, 1 May 2025 17:00:59 +0100 Subject: [PATCH 21/24] Further refactoring, add more doc comments --- .../csicontroller/reconciler.go | 1 + cmd/aws-s3-csi-controller/main.go | 34 ++-------- .../mountpoints3podattachment_indexer.go | 67 +++++++++++++++++++ pkg/driver/driver.go | 26 +------ pkg/driver/node/mounter/pod_mounter.go | 39 ++++++++--- pkg/driver/node/mounter/pod_unmounter.go | 37 ++++++++-- tests/controller/suite_test.go | 29 +------- 7 files changed, 139 insertions(+), 94 deletions(-) create mode 100644 pkg/api/v1beta/mountpoints3podattachment_indexer.go diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index b6d23d66..a87906bb 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -663,6 +663,7 @@ func isPodActive(p *corev1.Pod) bool { p.DeletionTimestamp == nil } +// s3paContainsWorkload checks whether MountpointS3PodAttachment has `workloadUID` in it. func s3paContainsWorkload(s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string) bool { for _, attachments := range s3pa.Spec.MountpointS3PodAttachments { for _, attachment := range attachments { diff --git a/cmd/aws-s3-csi-controller/main.go b/cmd/aws-s3-csi-controller/main.go index 6b270ac1..52d0f177 100644 --- a/cmd/aws-s3-csi-controller/main.go +++ b/cmd/aws-s3-csi-controller/main.go @@ -7,16 +7,13 @@ package main import ( - "context" "flag" - "fmt" "os" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" clientgoscheme "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/config" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" @@ -28,7 +25,6 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" - "github.com/go-logr/logr" ) var mountpointNamespace = flag.String("mountpoint-namespace", os.Getenv("MOUNTPOINT_NAMESPACE"), "Namespace to spawn Mountpoint Pods in.") @@ -63,7 +59,10 @@ func main() { os.Exit(1) } - IndexMountpointS3PodAttachmentFields(log, mgr) + if err := crdv1beta.SetupManagerIndices(mgr); err != nil { + log.Error(err, "Failed to setup field indexers") + os.Exit(1) + } reconciler := csicontroller.NewReconciler(mgr.GetClient(), mppod.Config{ Namespace: *mountpointNamespace, @@ -93,28 +92,3 @@ func main() { os.Exit(1) } } - -// IndexMountpointS3PodAttachmentFields adds internal index on fields for our custom resource. -// This is needed for `List()` method to work with field filters. -func IndexMountpointS3PodAttachmentFields(log logr.Logger, mgr manager.Manager) { - indexField(log, mgr, crdv1beta.FieldNodeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) - indexField(log, mgr, crdv1beta.FieldPersistentVolumeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) - indexField(log, mgr, crdv1beta.FieldVolumeID, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) - indexField(log, mgr, crdv1beta.FieldMountOptions, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) - indexField(log, mgr, crdv1beta.FieldAuthenticationSource, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) - indexField(log, mgr, crdv1beta.FieldWorkloadFSGroup, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) - indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) - indexField(log, mgr, crdv1beta.FieldWorkloadNamespace, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) - indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) -} - -// indexField adds index on a field. -func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1beta.MountpointS3PodAttachment) string) { - err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1beta.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { - return []string{extractor(obj.(*crdv1beta.MountpointS3PodAttachment))} - }) - if err != nil { - log.Error(err, fmt.Sprintf("Failed to create a %s field indexer", field)) - os.Exit(1) - } -} diff --git a/pkg/api/v1beta/mountpoints3podattachment_indexer.go b/pkg/api/v1beta/mountpoints3podattachment_indexer.go new file mode 100644 index 00000000..12dc19a3 --- /dev/null +++ b/pkg/api/v1beta/mountpoints3podattachment_indexer.go @@ -0,0 +1,67 @@ +package v1beta + +import ( + "context" + "fmt" + + ctrlcache "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" +) + +// getIndexFields returns the set of field extractors +func getIndexFields() map[string]func(*MountpointS3PodAttachment) string { + return map[string]func(*MountpointS3PodAttachment) string{ + FieldNodeName: func(cr *MountpointS3PodAttachment) string { return cr.Spec.NodeName }, + FieldPersistentVolumeName: func(cr *MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }, + FieldVolumeID: func(cr *MountpointS3PodAttachment) string { return cr.Spec.VolumeID }, + FieldMountOptions: func(cr *MountpointS3PodAttachment) string { return cr.Spec.MountOptions }, + FieldAuthenticationSource: func(cr *MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }, + FieldWorkloadFSGroup: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }, + FieldWorkloadServiceAccountName: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }, + FieldWorkloadNamespace: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }, + FieldWorkloadServiceAccountIAMRoleARN: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }, + } +} + +// SetupManagerIndices sets up indices for a manager +func SetupManagerIndices(mgr manager.Manager) error { + for field, extractor := range getIndexFields() { + if err := setupManagerIndex(mgr, field, extractor); err != nil { + return fmt.Errorf("failed to setup index for field %s: %w", field, err) + } + } + return nil +} + +// SetupCacheIndices sets up indices for a cache +func SetupCacheIndices(cache ctrlcache.Cache) error { + for field, extractor := range getIndexFields() { + if err := setupCacheIndex(cache, field, extractor); err != nil { + return fmt.Errorf("failed to setup index for field %s: %w", field, err) + } + } + return nil +} + +func setupManagerIndex(mgr manager.Manager, field string, extractor func(*MountpointS3PodAttachment) string) error { + return mgr.GetFieldIndexer().IndexField( + context.Background(), + &MountpointS3PodAttachment{}, + field, + func(obj client.Object) []string { + return []string{extractor(obj.(*MountpointS3PodAttachment))} + }, + ) +} + +func setupCacheIndex(cache ctrlcache.Cache, field string, extractor func(*MountpointS3PodAttachment) string) error { + return cache.IndexField( + context.Background(), + &MountpointS3PodAttachment{}, + field, + func(obj client.Object) []string { + return []string{extractor(obj.(*MountpointS3PodAttachment))} + }, + ) +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 998cbc77..f5884d3c 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -242,7 +242,9 @@ func setupS3PodAttachmentCache(config *rest.Config, stopCh <-chan struct{}, node klog.Fatalf("Failed to create cache: %v\n", err) } - indexMountpointS3PodAttachmentFields(s3paCache) + if err := crdv1beta.SetupCacheIndices(s3paCache); err != nil { + klog.Fatalf("Failed to setup field indexers: %v", err) + } s3podAttachmentInformer, err := s3paCache.GetInformer(context.Background(), &crdv1beta.MountpointS3PodAttachment{}) if err != nil { @@ -261,25 +263,3 @@ func setupS3PodAttachmentCache(config *rest.Config, stopCh <-chan struct{}, node return s3paCache } - -// TODO: This is duplicated multiple times -func indexMountpointS3PodAttachmentFields(s3paCache ctrlcache.Cache) { - indexField(s3paCache, crdv1beta.FieldNodeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) - indexField(s3paCache, crdv1beta.FieldPersistentVolumeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) - indexField(s3paCache, crdv1beta.FieldVolumeID, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) - indexField(s3paCache, crdv1beta.FieldMountOptions, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) - indexField(s3paCache, crdv1beta.FieldAuthenticationSource, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) - indexField(s3paCache, crdv1beta.FieldWorkloadFSGroup, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) - indexField(s3paCache, crdv1beta.FieldWorkloadServiceAccountName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) - indexField(s3paCache, crdv1beta.FieldWorkloadNamespace, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) - indexField(s3paCache, crdv1beta.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) -} - -func indexField(cache ctrlcache.Cache, field string, extractor func(*crdv1beta.MountpointS3PodAttachment) string) { - err := cache.IndexField(context.Background(), &crdv1beta.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { - return []string{extractor(obj.(*crdv1beta.MountpointS3PodAttachment))} - }) - if err != nil { - klog.Fatalf("Failed to create a %s field indexer: %v", field, err) - } -} diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index e0b07571..c603cf3f 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -184,7 +184,29 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin return nil } -func (pm *PodMounter) mountS3AtSource(ctx context.Context, source string, pod *corev1.Pod, podPath string, +// mountS3AtSource mounts an S3 bucket at the specified source path using the Mountpoint Pod. +// +// Parameters: +// - ctx: Context for cancellation and timeout control +// - source: The path where the S3 bucket should be mounted +// - mpPod: Mountpoint Pod that will serve this mount point +// - podPath: Base path for Pod-specific files +// - bucketName: Name of the S3 bucket to mount +// - credEnv: Environment variables related to AWS credentials +// - authenticationSource: Authentication source from PV volume attribute +// - args: Mountpoint arguments +// +// Returns: +// - error: nil if successful, otherwise an error describing what went wrong +// +// The function performs the following steps: +// 1. Prepares environment and mount arguments +// 2. Performs the initial mount syscall to obtain FUSE file descriptor +// 3. Sends mount options to the Mountpoint Pod +// 4. Waits for the mount to be ready +// +// If any step fails, it ensures cleanup by unmounting the source path. +func (pm *PodMounter) mountS3AtSource(ctx context.Context, source string, mpPod *corev1.Pod, podPath string, bucketName string, credEnv envprovider.Environment, authenticationSource credentialprovider.AuthenticationSource, args mountpoint.Args) error { env := envprovider.Default() @@ -200,7 +222,7 @@ func (pm *PodMounter) mountS3AtSource(ctx context.Context, source string, pod *c podMountSockPath := mppod.PathOnHost(podPath, mppod.KnownPathMountSock) podMountErrorPath := mppod.PathOnHost(podPath, mppod.KnownPathMountError) - klog.V(4).Infof("Mounting %s for %s", source, pod.Name) + klog.V(4).Infof("Mounting %s for %s", source, mpPod.Name) fuseDeviceFD, err := pm.mountSyscallWithDefault(source, args) if err != nil { @@ -235,7 +257,7 @@ func (pm *PodMounter) mountS3AtSource(ctx context.Context, source string, pod *c // Remove old mount error file if exists _ = os.Remove(podMountErrorPath) - klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", pod.Name, podMountSockPath) + klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", mpPod.Name, podMountSockPath) err = mountoptions.Send(ctx, podMountSockPath, mountoptions.Options{ Fd: fuseDeviceFD, @@ -244,14 +266,14 @@ func (pm *PodMounter) mountS3AtSource(ctx context.Context, source string, pod *c Env: env.List(), }) if err != nil { - klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) + return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) } - err = pm.waitForMount(ctx, source, pod.Name, podMountErrorPath) + err = pm.waitForMount(ctx, source, mpPod.Name, podMountErrorPath) if err != nil { - klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", pod.Name, source, err, pm.helpMessageForGettingMountpointLogs(pod)) + klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) + return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) } // Mountpoint successfully started, so don't unmount the filesystem @@ -381,6 +403,7 @@ func (pm *PodMounter) verifyOrSetupMountTarget(target string) error { return err } +// provideCredentials provides credentials func (pm *PodMounter) provideCredentials(ctx context.Context, podPath string, credentialCtx credentialprovider.ProvideContext) (envprovider.Environment, credentialprovider.AuthenticationSource, error) { podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) if err != nil { diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go index 2b3ec924..916047f8 100644 --- a/pkg/driver/node/mounter/pod_unmounter.go +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -1,7 +1,6 @@ package mounter import ( - "fmt" "os" "path/filepath" "sync" @@ -18,6 +17,7 @@ import ( const cleanupInterval = 10 * time.Second +// PodUnmounter handles unmounting of Mountpoint Pods and cleanup of associated resources type PodUnmounter struct { nodeID string mountUtil mount.Interface @@ -28,6 +28,7 @@ type PodUnmounter struct { mutex sync.Mutex } +// NewPodUnmounter creates a new PodUnmounter instance with the given parameters func NewPodUnmounter( nodeID string, mountUtil mount.Interface, @@ -45,6 +46,8 @@ func NewPodUnmounter( } } +// HandleMountpointPodUpdate is a Pod Update handler that triggers unmounting +// if the Mountpoint Pod is marked for unmounting via annotations func (u *PodUnmounter) HandleMountpointPodUpdate(old, new any) { mpPod := new.(*corev1.Pod) if mpPod.Spec.NodeName != u.nodeID { @@ -56,6 +59,9 @@ func (u *PodUnmounter) HandleMountpointPodUpdate(old, new any) { } } +// unmountSourceForPod performs the unmounting process for a specific Mountpoint Pod +// including cleanup of associated resources +// mpPod: The Mountpoint Pod to unmount func (u *PodUnmounter) unmountSourceForPod(mpPod *corev1.Pod) { mpPodUID := string(mpPod.UID) mpPodLock := getMPPodLock(mpPodUID) @@ -85,6 +91,9 @@ func (u *PodUnmounter) unmountSourceForPod(mpPod *corev1.Pod) { } } +// writeExitFile creates an exit file in the pod's directory to signal Mountpoint Pod termination +// podPath: Path to the pod's directory +// Returns error if file creation fails func (u *PodUnmounter) writeExitFile(podPath string) error { podMountExitPath := mppod.PathOnHost(podPath, mppod.KnownPathMountExit) _, err := os.OpenFile(podMountExitPath, os.O_RDONLY|os.O_CREATE, credentialprovider.CredentialFilePerm) @@ -95,6 +104,9 @@ func (u *PodUnmounter) writeExitFile(podPath string) error { return nil } +// unmountAndCleanup unmounts the source directory and removes it +// source: Path to the source directory to unmount +// Returns error if unmounting or cleanup fails func (u *PodUnmounter) unmountAndCleanup(source string) error { if err := u.mountUtil.Unmount(source); err != nil { klog.Errorf("Failed to unmount source %q: %v", source, err) @@ -108,6 +120,7 @@ func (u *PodUnmounter) unmountAndCleanup(source string) error { return nil } +// cleanupCredentials removes credentials associated with the Mountpoint Pod func (u *PodUnmounter) cleanupCredentials(volumeId, mpPodUID, podPath, source string, mpPod *corev1.Pod) error { err := u.credProvider.Cleanup(credentialprovider.CleanupContext{ VolumeID: volumeId, @@ -121,14 +134,17 @@ func (u *PodUnmounter) cleanupCredentials(volumeId, mpPodUID, podPath, source st return nil } -func (u *PodUnmounter) StartPeriodicCleanup(stopCh <-chan struct{}) error { +// StartPeriodicCleanup begins periodic cleanup of dangling mounts +// This is needed in case when `HandleMountpointPodUpdate()` missed an update event to trigger cleanup. +// stopCh: Channel to signal stopping of the cleanup routine +func (u *PodUnmounter) StartPeriodicCleanup(stopCh <-chan struct{}) { ticker := time.NewTicker(cleanupInterval) defer ticker.Stop() for { select { case <-stopCh: - return nil + return case <-ticker.C: if err := u.CleanupDanglingMounts(); err != nil { klog.Errorf("Failed to run clean up of dangling mounts: %v", err) @@ -137,6 +153,8 @@ func (u *PodUnmounter) StartPeriodicCleanup(stopCh <-chan struct{}) error { } } +// CleanupDanglingMounts scans the source mount directory for potential dangling mounts +// and cleans them up. It also unmounts any Mountpoint Pods marked for unmounting. func (u *PodUnmounter) CleanupDanglingMounts() error { // Ensure only one cleanup runs at a time if !u.mutex.TryLock() { @@ -160,14 +178,19 @@ func (u *PodUnmounter) CleanupDanglingMounts() error { // Try to find corresponding pod mpPod, err := u.findPodByUID(mpPodUID) if err != nil { - klog.V(4).Infof("Mountpoint Pod not found for UID %s, will only unmount and delete folder: %v", mpPodUID, err) + klog.Errorf("Failed to check Mountpoint Pod (UID: %s) existence: %v", mpPodUID, err) + return err + } + + if mpPod == nil { + klog.V(4).Infof("Mountpoint Pod not found for UID %s, will unmount and delete folder", mpPodUID) if err := u.unmountAndCleanup(source); err != nil { - klog.Errorf("Failed to cleanup dangling mount for Mountpoint Pod %s: %v", mpPod.Name, err) + klog.Errorf("Failed to cleanup dangling mount for UID %s: %v", mpPodUID, err) } continue } - // Unmount if Mountpoint Pod is marked for unmounting + // Unmount only if Mountpoint Pod is marked for unmounting if value, exists := mpPod.Annotations[mppod.AnnotationNeedsUnmount]; exists && value == "true" { u.unmountSourceForPod(mpPod) } @@ -188,5 +211,5 @@ func (u *PodUnmounter) findPodByUID(mpPodUID string) (*corev1.Pod, error) { return pod, nil } } - return nil, fmt.Errorf("Mountpoint Pod not found for UID %s", mpPodUID) + return nil, nil } diff --git a/tests/controller/suite_test.go b/tests/controller/suite_test.go index 160a936f..e2947692 100644 --- a/tests/controller/suite_test.go +++ b/tests/controller/suite_test.go @@ -3,12 +3,10 @@ package controller_test import ( "context" "fmt" - "os" "testing" "time" crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" - "github.com/go-logr/logr" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -21,7 +19,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" - "sigs.k8s.io/controller-runtime/pkg/manager" "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" @@ -90,7 +87,9 @@ var _ = BeforeSuite(func() { k8sManager, err := ctrl.NewManager(cfg, ctrl.Options{Scheme: scheme.Scheme}) Expect(err).ToNot(HaveOccurred()) - IndexMountpointS3PodAttachmentFields(logf.Log.WithName("controller-test"), k8sManager) + if err := crdv1beta.SetupManagerIndices(k8sManager); err != nil { + Expect(err).NotTo(HaveOccurred()) + } err = csicontroller.NewReconciler(k8sManager.GetClient(), mppod.Config{ Namespace: mountpointNamespace, @@ -155,25 +154,3 @@ func createMountpointPriorityClass() { Expect(k8sClient.Create(ctx, priorityClass)).To(Succeed()) waitForObject(priorityClass) } - -func IndexMountpointS3PodAttachmentFields(log logr.Logger, mgr manager.Manager) { - indexField(log, mgr, crdv1beta.FieldNodeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.NodeName }) - indexField(log, mgr, crdv1beta.FieldPersistentVolumeName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }) - indexField(log, mgr, crdv1beta.FieldVolumeID, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.VolumeID }) - indexField(log, mgr, crdv1beta.FieldMountOptions, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.MountOptions }) - indexField(log, mgr, crdv1beta.FieldAuthenticationSource, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }) - indexField(log, mgr, crdv1beta.FieldWorkloadFSGroup, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }) - indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountName, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }) - indexField(log, mgr, crdv1beta.FieldWorkloadNamespace, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }) - indexField(log, mgr, crdv1beta.FieldWorkloadServiceAccountIAMRoleARN, func(cr *crdv1beta.MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }) -} - -func indexField(log logr.Logger, mgr manager.Manager, field string, extractor func(*crdv1beta.MountpointS3PodAttachment) string) { - err := mgr.GetFieldIndexer().IndexField(context.Background(), &crdv1beta.MountpointS3PodAttachment{}, field, func(obj client.Object) []string { - return []string{extractor(obj.(*crdv1beta.MountpointS3PodAttachment))} - }) - if err != nil { - log.Error(err, fmt.Sprintf("Failed to create a %s field indexer", field)) - os.Exit(1) - } -} From 258e0c826c4834cd17fc59666cded5457f6235a0 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Thu, 1 May 2025 17:30:45 +0100 Subject: [PATCH 22/24] go mod tidy --- tests/e2e-kubernetes/go.mod | 2 ++ tests/e2e-kubernetes/go.sum | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/tests/e2e-kubernetes/go.mod b/tests/e2e-kubernetes/go.mod index 8b0e6c93..79b7d043 100644 --- a/tests/e2e-kubernetes/go.mod +++ b/tests/e2e-kubernetes/go.mod @@ -49,6 +49,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.5.0 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/evanphx/json-patch/v5 v5.9.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -121,6 +122,7 @@ require ( golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/genproto v0.0.0-20231127180814-3a041ad873d4 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect diff --git a/tests/e2e-kubernetes/go.sum b/tests/e2e-kubernetes/go.sum index 58d8a51d..291ba502 100644 --- a/tests/e2e-kubernetes/go.sum +++ b/tests/e2e-kubernetes/go.sum @@ -80,6 +80,10 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch v4.12.0+incompatible h1:4onqiflcdA9EOZ4RxV643DvftH5pOlLGNtQ5lPWQu84= +github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= +github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -358,6 +362,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= +gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= From 9dfd81689536526e048466e13c5719b8a3de05a5 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Fri, 2 May 2025 08:24:44 +0100 Subject: [PATCH 23/24] Node: Handle shutdown signal --- cmd/aws-s3-csi-driver/main.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cmd/aws-s3-csi-driver/main.go b/cmd/aws-s3-csi-driver/main.go index 4bcd9e0a..ce720e2f 100644 --- a/cmd/aws-s3-csi-driver/main.go +++ b/cmd/aws-s3-csi-driver/main.go @@ -21,6 +21,8 @@ import ( "flag" "fmt" "os" + "os/signal" + "syscall" "github.com/awslabs/aws-s3-csi-driver/pkg/driver" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" @@ -70,6 +72,15 @@ func main() { klog.Fatalf("failed to create driver: %s", err) } + // Handle shutdown signals + stopCh := make(chan os.Signal, 1) + signal.Notify(stopCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-stopCh + klog.Infof("Received signal %s, shutting down...", sig) + drv.Stop() + }() + if err := drv.Run(); err != nil { klog.Fatalln(err) } From eeb1699f3bdd8681bad7d67b1a6a489ad05d32d0 Mon Sep 17 00:00:00 2001 From: Yerzhan Mazhkenov <20302932+yerzhan7@users.noreply.github.com> Date: Tue, 6 May 2025 09:19:30 +0100 Subject: [PATCH 24/24] Address small comments --- .../csicontroller/expectations.go | 4 ++-- .../csicontroller/reconciler.go | 9 ++++---- .../csicontroller/stale_attachment_cleaner.go | 8 ------- pkg/driver/node/mounter/mppod_lock.go | 22 +++++++++++++++++++ pkg/driver/node/mounter/pod_mounter.go | 16 ++++---------- pkg/driver/node/mounter/pod_unmounter.go | 10 +++------ pkg/driver/node/node.go | 2 +- 7 files changed, 37 insertions(+), 34 deletions(-) diff --git a/cmd/aws-s3-csi-controller/csicontroller/expectations.go b/cmd/aws-s3-csi-controller/csicontroller/expectations.go index a444613a..a450c85a 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/expectations.go +++ b/cmd/aws-s3-csi-controller/csicontroller/expectations.go @@ -63,9 +63,9 @@ func deriveExpectationKeyFromFilters(fieldFilters client.MatchingFields) string var sb strings.Builder for _, k := range keys { sb.WriteString(k) - sb.WriteString("=") + sb.WriteRune('=') sb.WriteString(fieldFilters[k]) - sb.WriteString(";") + sb.WriteRune(';') } return sb.String() } diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index ed3e6e42..261c1453 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -18,6 +18,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" "github.com/go-logr/logr" @@ -253,7 +254,7 @@ func (r *Reconciler) buildFieldFilters(workloadPod *corev1.Pod, pv *corev1.Persi crdv1beta.FieldAuthenticationSource: authSource, } - if authSource == "pod" { + if authSource == credentialprovider.AuthenticationSourcePod { fieldFilters[crdv1beta.FieldWorkloadNamespace] = workloadPod.Namespace fieldFilters[crdv1beta.FieldWorkloadServiceAccountName] = getServiceAccountName(workloadPod) fieldFilters[crdv1beta.FieldWorkloadServiceAccountIAMRoleARN] = roleArn @@ -267,8 +268,8 @@ func (r *Reconciler) buildFieldFilters(workloadPod *corev1.Pod, pv *corev1.Persi func (r *Reconciler) getAuthSource(pv *corev1.PersistentVolume) string { volumeAttributes := mppod.ExtractVolumeAttributes(pv) authSource := volumeAttributes[volumecontext.AuthenticationSource] - if authSource == "" { - return "driver" + if authSource == credentialprovider.AuthenticationSourceUnspecified { + return credentialprovider.AuthenticationSourceDriver } return authSource } @@ -329,7 +330,7 @@ func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3pa *cr } // addWorkloadToS3PodAttachment adds workload UID to the first Mountpoint Pod in the map -// TODO: We will later add extra logic for selecting/creating MPPod if existing MP Pods are using old CSI Driver version or have some "no-new-attachments" annotation +// TODO: We will later add extra logic for selecting/creating MPPod if existing MP Pods are using old CSI Driver version or have some "no-new-attachments" or "needs-unmount" annotation func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { log.Info("Adding workload UID to MountpointS3PodAttachment") diff --git a/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go index 24ffc935..e7f8216b 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go +++ b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go @@ -2,7 +2,6 @@ package csicontroller import ( "context" - "sync" "time" crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" @@ -18,7 +17,6 @@ const ( // StaleAttachmentCleaner handles periodic cleanup of stale workload attachments in case reconciler missed pod deletion event. type StaleAttachmentCleaner struct { reconciler *Reconciler - mutex sync.Mutex } // NewStaleAttachmentCleaner creates a new StaleAttachmentCleaner @@ -48,12 +46,6 @@ func (cm *StaleAttachmentCleaner) Start(ctx context.Context) error { // runCleanup performs cleanup operation func (cm *StaleAttachmentCleaner) runCleanup(ctx context.Context) error { - // Ensure only one cleanup runs at a time - if !cm.mutex.TryLock() { - return nil - } - defer cm.mutex.Unlock() - log := logf.FromContext(ctx) // Get all pods in the cluster diff --git a/pkg/driver/node/mounter/mppod_lock.go b/pkg/driver/node/mounter/mppod_lock.go index 5bb6e5b9..6222221f 100644 --- a/pkg/driver/node/mounter/mppod_lock.go +++ b/pkg/driver/node/mounter/mppod_lock.go @@ -21,6 +21,28 @@ var ( mpPodLocksMutex sync.Mutex ) +// lockMountpointPod acquires a lock for the specified pod UID and returns an unlock function. +// The returned function must be called to release the lock and cleanup resources. +// +// Parameters: +// - uid: The unique identifier of the Mountpoint Pod to lock +// +// Returns: +// - func(): A function that when called will unlock the pod and release associated resources +// +// Usage: +// +// unlock := lockMountpointPod(podUID) +// defer unlock() +func lockMountpointPod(uid string) func() { + mpPodLock := getMPPodLock(uid) + mpPodLock.mutex.Lock() + return func() { + mpPodLock.mutex.Unlock() + releaseMPPodLock(uid) + } +} + // getMPPodLock retrieves or creates a lock for the specified pod UID. // It increments the reference count for existing locks. // The caller is responsible for calling releaseMPPodLock when the lock is no longer needed. diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index 100d33cd..0883c803 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -109,12 +109,8 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin mpPodUID := filepath.Base(source) podPath := pm.podPath(mpPodUID) - mpPodLock := getMPPodLock(mpPodUID) - mpPodLock.mutex.Lock() - defer func() { - mpPodLock.mutex.Unlock() - releaseMPPodLock(mpPodUID) - }() + unlockMountpointPod := lockMountpointPod(mpPodUID) + defer unlockMountpointPod() pm.provideCredentials(ctx, podPath, credentialCtx) @@ -139,12 +135,8 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin return fmt.Errorf("Failed to wait for Mountpoint Pod to be ready for %q: %w", target, err) } mpPodUID := string(pod.UID) - mpPodLock := getMPPodLock(mpPodUID) - mpPodLock.mutex.Lock() - defer func() { - mpPodLock.mutex.Unlock() - releaseMPPodLock(mpPodUID) - }() + unlockMountpointPod := lockMountpointPod(mpPodUID) + defer unlockMountpointPod() source := filepath.Join(pm.sourceMountDir, mpPodUID) diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go index d1b06da5..eecb3c1f 100644 --- a/pkg/driver/node/mounter/pod_unmounter.go +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -15,7 +15,7 @@ import ( "k8s.io/klog/v2" ) -const cleanupInterval = 10 * time.Second +const cleanupInterval = 2 * time.Minute // PodUnmounter handles unmounting of Mountpoint Pods and cleanup of associated resources type PodUnmounter struct { @@ -64,12 +64,8 @@ func (u *PodUnmounter) HandleMountpointPodUpdate(old, new any) { // mpPod: The Mountpoint Pod to unmount func (u *PodUnmounter) unmountSourceForPod(mpPod *corev1.Pod) { mpPodUID := string(mpPod.UID) - mpPodLock := getMPPodLock(mpPodUID) - mpPodLock.mutex.Lock() - defer func() { - mpPodLock.mutex.Unlock() - releaseMPPodLock(mpPodUID) - }() + unlockMountpointPod := lockMountpointPod(mpPodUID) + defer unlockMountpointPod() klog.Infof("Found Mountpoint Pod %s (UID: %s) with %s annotation, unmounting it", mpPod.Name, mpPodUID, mppod.AnnotationNeedsUnmount) diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index 946d082f..b55df50e 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -265,7 +265,7 @@ func credentialProvideContextFromPublishRequest(req *csi.NodePublishVolumeReques } authSource := credentialprovider.AuthenticationSourceDriver - if volumeCtx[volumecontext.AuthenticationSource] != "" { + if volumeCtx[volumecontext.AuthenticationSource] != credentialprovider.AuthenticationSourceUnspecified { authSource = volumeCtx[volumecontext.AuthenticationSource] }