diff --git a/Makefile b/Makefile index 4dc21ba2..5ad49d28 100644 --- a/Makefile +++ b/Makefile @@ -190,6 +190,23 @@ generate_licenses: download_go_deps clean: rm -rf bin/ && docker system prune +# Generate files for Custom Resources (`zz_generated.deepcopy.go` and CustomResourceDefinition YAML file). +TMP_POD_ATTACHMENT_CRD_FILE ?= "./hack/s3.csi.aws.com_mountpoints3podattachments.yaml" +# Helm CRD file needs extra `{{- if .Values.experimental.podMounter -}}` wrapper, while we don't need it for tests. +HELM_POD_ATTACHMENT_CRD_FILE ?= "./charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml" +TEST_POD_ATTACHMENT_CRD_FILE ?= "./tests/crd/mountpoints3podattachments-crd.yaml" +.PHONY: generate +generate: + controller-gen object:headerFile="hack/boilerplate.go.txt" paths="./pkg/api/..." + controller-gen crd paths="./pkg/api/..." output:crd:dir=./hack/ + echo '# Auto-generated file via `make generate`. Do not edit.' > $(HELM_POD_ATTACHMENT_CRD_FILE) + echo '{{- if .Values.experimental.podMounter -}}' >> $(HELM_POD_ATTACHMENT_CRD_FILE) + cat $(TMP_POD_ATTACHMENT_CRD_FILE) >> $(HELM_POD_ATTACHMENT_CRD_FILE) + echo '{{- end -}}' >> $(HELM_POD_ATTACHMENT_CRD_FILE) + echo '# Auto-generated file via `make generate`. Do not edit.' > $(TEST_POD_ATTACHMENT_CRD_FILE) + cat $(TMP_POD_ATTACHMENT_CRD_FILE) >> $(TEST_POD_ATTACHMENT_CRD_FILE) + rm $(TMP_POD_ATTACHMENT_CRD_FILE) + ## Binaries used in tests. TESTBIN ?= $(shell pwd)/tests/bin diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml new file mode 100644 index 00000000..14c7abac --- /dev/null +++ b/charts/aws-mountpoint-s3-csi-driver/templates/mountpoints3podattachments-crd.yaml @@ -0,0 +1,118 @@ +# Auto-generated file via `make generate`. Do not edit. +{{- if .Values.experimental.podMounter -}} +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.17.3 + name: mountpoints3podattachments.s3.csi.aws.com +spec: + group: s3.csi.aws.com + names: + kind: MountpointS3PodAttachment + listKind: MountpointS3PodAttachmentList + plural: mountpoints3podattachments + shortNames: + - s3pa + singular: mountpoints3podattachment + scope: Cluster + versions: + - name: v1beta + schema: + openAPIV3Schema: + description: MountpointS3PodAttachment is the Schema for the mountpoints3podattachments + API. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: MountpointS3PodAttachmentSpec defines the desired state of + MountpointS3PodAttachment. + properties: + authenticationSource: + description: Authentication source taken from volume attribute field + `authenticationSource`. + type: string + mountOptions: + description: Comma separated mount options taken from volume. + type: string + mountpointS3PodAttachments: + additionalProperties: + items: + description: WorkloadAttachment represents the attachment details + of a workload pod to a Mountpoint S3 pod. + properties: + attachmentTime: + description: AttachmentTime represents when the workload pod + was attached to the Mountpoint S3 pod + format: date-time + type: string + workloadPodUID: + description: WorkloadPodUID is the unique identifier of the + attached workload pod + type: string + required: + - attachmentTime + - workloadPodUID + type: object + type: array + description: Maps each Mountpoint S3 pod name to its workload attachments + type: object + nodeName: + description: Name of the node. + type: string + persistentVolumeName: + description: Name of the Persistent Volume. + type: string + volumeID: + description: Volume ID. + type: string + workloadFSGroup: + description: Workload pod's `fsGroup` from pod security context + type: string + workloadNamespace: + description: 'Workload pod''s namespace. Exists only if `authenticationSource: + pod`.' + type: string + workloadServiceAccountIAMRoleARN: + description: 'EKS IAM Role ARN from workload pod''s service account + annotation (IRSA). Exists only if `authenticationSource: pod` and + service account has `eks.amazonaws.com/role-arn` annotation.' + type: string + workloadServiceAccountName: + description: 'Workload pod''s service account name. Exists only if + `authenticationSource: pod`.' + type: string + required: + - authenticationSource + - mountOptions + - mountpointS3PodAttachments + - nodeName + - persistentVolumeName + - volumeID + - workloadFSGroup + type: object + type: object + selectableFields: + - jsonPath: .spec.nodeName + served: true + storage: true + subresources: + status: {} +{{- end -}} diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml index c5469692..8d972b36 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-controller.yaml @@ -23,7 +23,7 @@ metadata: rules: - apiGroups: [""] resources: ["pods"] - verbs: ["get", "create", "watch", "delete", "list"] + verbs: ["get", "create", "watch", "delete", "list", "update"] --- kind: RoleBinding apiVersion: rbac.authorization.k8s.io/v1 @@ -49,8 +49,11 @@ metadata: {{- include "aws-mountpoint-s3-csi-driver.labels" . | nindent 4 }} rules: - apiGroups: [""] - resources: ["pods", "persistentvolumeclaims", "persistentvolumes"] + resources: ["pods", "persistentvolumeclaims", "persistentvolumes", "serviceaccounts"] verbs: ["get", "watch", "list"] + - apiGroups: ["s3.csi.aws.com"] + resources: ["mountpoints3podattachments"] + verbs: ["create", "delete", "update", "get", "watch", "list"] --- kind: ClusterRoleBinding apiVersion: rbac.authorization.k8s.io/v1 diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml index 9e6baa8e..5c68a202 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/serviceaccount-csi-node.yaml @@ -23,8 +23,11 @@ metadata: app.kubernetes.io/name: aws-mountpoint-s3-csi-driver rules: - apiGroups: [""] - resources: ["serviceaccounts"] + resources: ["serviceaccounts"] # TODO: Remove once we stop supporting systemd mounts because in PodMounter we get IRSA Role ARN from MountpointS3PodAttachment verbs: ["get"] + - apiGroups: ["s3.csi.aws.com"] + resources: ["mountpoints3podattachments"] + verbs: ["get", "list", "watch"] --- kind: ClusterRoleBinding apiVersion: rbac.authorization.k8s.io/v1 diff --git a/cmd/aws-s3-csi-controller/csicontroller/expectations.go b/cmd/aws-s3-csi-controller/csicontroller/expectations.go new file mode 100644 index 00000000..a450c85a --- /dev/null +++ b/cmd/aws-s3-csi-controller/csicontroller/expectations.go @@ -0,0 +1,71 @@ +package csicontroller + +import ( + "sort" + "strings" + "sync" + + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// expectations is a structure that manages pending expectations for Kubernetes resources. +// It uses field filters as keys to track resources that are expected to be created +// helping with eventual consistency, reducing unnecessary processing and API server load. +type expectations struct { + pending sync.Map +} + +// newExpectations creates and returns a new Expectations instance. +func newExpectations() *expectations { + return &expectations{} +} + +// setPending marks a resource as pending based on the given field filters. +// This is typically used when a create operation is initiated. +func (e *expectations) setPending(fieldFilters client.MatchingFields) { + key := deriveExpectationKeyFromFilters(fieldFilters) + e.pending.Store(key, struct{}{}) +} + +// isPending checks if a resource is marked as pending based on the given field filters. +// Returns true if the resource is pending, false otherwise. +func (e *expectations) isPending(fieldFilters client.MatchingFields) bool { + key := deriveExpectationKeyFromFilters(fieldFilters) + _, ok := e.pending.Load(key) + return ok +} + +// clear removes the pending mark for a resource based on the given field filters. +// This is typically called when an expected operation has been confirmed as completed. +func (e *expectations) clear(fieldFilters client.MatchingFields) { + key := deriveExpectationKeyFromFilters(fieldFilters) + e.pending.Delete(key) +} + +// deriveExpectationKeyFromFilters generates a deterministic string key from a map of field filters. +// It creates a consistent string representation of the filters by: +// 1. Sorting the filter keys alphabetically +// 2. Concatenating each key-value pair in the format "key=value;" +// +// For example, given filters {"foo": "bar", "baz": "qux"}, it will produce "baz=qux;foo=bar;" +// +// Parameters: +// - fieldFilters: A map of field names to their values used for filtering Kubernetes resources +// +// Returns: +// - A string that uniquely represents the combination of field filters +func deriveExpectationKeyFromFilters(fieldFilters client.MatchingFields) string { + keys := make([]string, 0, len(fieldFilters)) + for k := range fieldFilters { + keys = append(keys, k) + } + sort.Strings(keys) + var sb strings.Builder + for _, k := range keys { + sb.WriteString(k) + sb.WriteRune('=') + sb.WriteString(fieldFilters[k]) + sb.WriteRune(';') + } + return sb.String() +} diff --git a/cmd/aws-s3-csi-controller/csicontroller/expectations_test.go b/cmd/aws-s3-csi-controller/csicontroller/expectations_test.go new file mode 100644 index 00000000..73f6b4f7 --- /dev/null +++ b/cmd/aws-s3-csi-controller/csicontroller/expectations_test.go @@ -0,0 +1,105 @@ +package csicontroller + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func TestDeriveExpectationKeyFromFilters(t *testing.T) { + tests := []struct { + name string + fieldFilters client.MatchingFields + want string + }{ + { + name: "empty filters", + fieldFilters: client.MatchingFields{}, + want: "", + }, + { + name: "single filter", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + want: "key1=value1;", + }, + { + name: "multiple filters", + fieldFilters: client.MatchingFields{ + "key2": "value2", + "key1": "value1", + "key3": "value3", + }, + want: "key1=value1;key2=value2;key3=value3;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := deriveExpectationKeyFromFilters(tt.fieldFilters) + assert.Equals(t, tt.want, got) + }) + } +} + +func TestExpectations(t *testing.T) { + tests := []struct { + name string + fieldFilters client.MatchingFields + operations func(*expectations) + wantPending bool + }{ + { + name: "set and check pending", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + operations: func(e *expectations) { + e.setPending(client.MatchingFields{"key1": "value1"}) + }, + wantPending: true, + }, + { + name: "set and clear pending", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + operations: func(e *expectations) { + e.setPending(client.MatchingFields{"key1": "value1"}) + e.clear(client.MatchingFields{"key1": "value1"}) + }, + wantPending: false, + }, + { + name: "check non-existent pending", + fieldFilters: client.MatchingFields{ + "key1": "value1", + }, + operations: func(e *expectations) {}, + wantPending: false, + }, + { + name: "multiple operations", + fieldFilters: client.MatchingFields{ + "key1": "value1", + "key2": "value2", + }, + operations: func(e *expectations) { + e.setPending(client.MatchingFields{"key1": "value1", "key2": "value2"}) + e.clear(client.MatchingFields{"key1": "value1"}) // Different key, shouldn't affect the test + }, + wantPending: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := newExpectations() + tt.operations(e) + got := e.isPending(tt.fieldFilters) + assert.Equals(t, tt.wantPending, got) + }) + } +} diff --git a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go index 158ed320..261c1453 100644 --- a/cmd/aws-s3-csi-controller/csicontroller/reconciler.go +++ b/cmd/aws-s3-csi-controller/csicontroller/reconciler.go @@ -4,26 +4,46 @@ import ( "context" "errors" "fmt" + "strconv" + "strings" + "time" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/reconcile" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" + "github.com/go-logr/logr" ) const debugLevel = 4 const mountpointCSIDriverName = "s3.csi.aws.com" +const defaultServiceAccount = "default" + +const ( + AnnotationServiceAccountRole = "eks.amazonaws.com/role-arn" + LabelCSIDriverVersion = "s3.csi.aws.com/created-by-csi-driver-version" +) + +const ( + Requeue = true + DontRequeue = false +) // A Reconciler reconciles Mountpoint Pods by watching other workload Pods thats using S3 CSI Driver. type Reconciler struct { mountpointPodConfig mppod.Config mountpointPodCreator *mppod.Creator + s3paExpectations *expectations client.Client } @@ -31,7 +51,7 @@ type Reconciler struct { // NewReconciler returns a new reconciler created from `client` and `podConfig`. func NewReconciler(client client.Client, podConfig mppod.Config) *Reconciler { creator := mppod.NewCreator(podConfig) - return &Reconciler{Client: client, mountpointPodConfig: podConfig, mountpointPodCreator: creator} + return &Reconciler{Client: client, mountpointPodConfig: podConfig, mountpointPodCreator: creator, s3paExpectations: newExpectations()} } // SetupWithManager configures reconciler to run with given `mgr`. @@ -139,7 +159,8 @@ func (r *Reconciler) reconcileWorkloadPod(ctx context.Context, pod *corev1.Pod) log.V(debugLevel).Info("Found bound PV for PVC", "pvc", pvc.Name, "volumeName", pv.Name) - err = r.spawnOrDeleteMountpointPodIfNeeded(ctx, pod, pvc, pv, csiSpec) + needsRequeue, err := r.spawnOrDeleteMountpointPodIfNeeded(ctx, pod, pvc, pv, csiSpec) + requeue = requeue || needsRequeue if err != nil { errs = append(errs, err) continue @@ -163,56 +184,323 @@ func (r *Reconciler) spawnOrDeleteMountpointPodIfNeeded( pvc *corev1.PersistentVolumeClaim, pv *corev1.PersistentVolume, csiSpec *corev1.CSIPersistentVolumeSource, -) error { - mpPodName := mppod.MountpointPodNameFor(string(workloadPod.UID), pvc.Spec.VolumeName) +) (bool, error) { + workloadUID := string(workloadPod.UID) + roleArn, err := r.findIRSAServiceAccountRole(ctx, workloadPod) + if err != nil { + return Requeue, err + } + fieldFilters := r.buildFieldFilters(workloadPod, pv, roleArn) + s3pa, err := r.getExistingS3PodAttachment(ctx, fieldFilters) + if err != nil { + return Requeue, err + } + log := r.setupLogger(ctx, workloadPod, pvc, workloadUID, fieldFilters, s3pa) - log := logf.FromContext(ctx).WithValues( + if !isPodActive(workloadPod) { + return r.handleInactivePod(ctx, s3pa, workloadUID, log) + } + + if s3pa != nil { + return r.handleExistingS3PodAttachment(ctx, s3pa, workloadUID, fieldFilters, log) + } else { + return r.handleNewS3PodAttachment(ctx, workloadPod, pv, roleArn, fieldFilters, log) + } +} + +// setupLogger creates and configures logger that includes pod namespace/name, PVC name, and workload UID fields. +// If an S3PodAttachment is provided, its name is added. All fieldFilters are appended as additional key-value pairs. +func (r *Reconciler) setupLogger( + ctx context.Context, + workloadPod *corev1.Pod, + pvc *corev1.PersistentVolumeClaim, + workloadUID string, + fieldFilters client.MatchingFields, + s3pa *crdv1beta.MountpointS3PodAttachment, +) logr.Logger { + logger := logf.FromContext(ctx).WithValues( "workloadPod", types.NamespacedName{Namespace: workloadPod.Namespace, Name: workloadPod.Name}, - "mountpointPod", mpPodName, - "pvc", pvc.Name, "volumeName", pv.Name) + "pvc", pvc.Name, + "workloadUID", workloadUID, + ) - mpPod := &corev1.Pod{} - err := r.Get(ctx, types.NamespacedName{Namespace: r.mountpointPodConfig.Namespace, Name: mpPodName}, mpPod) - if err != nil && !apierrors.IsNotFound(err) { - log.Error(err, "Failed to get Mountpoint Pod") - return err + if s3pa != nil { + logger = logger.WithValues("s3pa", s3pa.Name) } - isMountpointPodExists := err == nil + var keyValues []interface{} + for k, v := range fieldFilters { + keyValues = append(keyValues, k, v) + } - // `workloadPod` is not active, its either terminated (i.e., `phase == Succeeded or phase == Failed`) or - // its scheduled for termination (i.e., `DeletionTimestamp != nil`) - if !isPodActive(workloadPod) { - // if its scheduled for termination and its still in `Pending` phase, - // delete if there is an existing Mountpoint Pod as otherwise this - // Mountpoint Pod might take some time to terminate on its own. - if isMountpointPodExists && workloadPod.Status.Phase == corev1.PodPending { - log.Info("Deleting scheduled Mountpoint Pod") - err := r.deleteMountpointPod(ctx, mpPod) + if len(keyValues) > 0 { + logger = logger.WithValues(keyValues...) + } + + return logger +} + +// buildFieldFilters build appropriate matching field filters for List operation on MountpointS3PodAttachments +func (r *Reconciler) buildFieldFilters(workloadPod *corev1.Pod, pv *corev1.PersistentVolume, roleArn string) client.MatchingFields { + authSource := r.getAuthSource(pv) + fsGroup := r.getFSGroup(workloadPod) + + fieldFilters := client.MatchingFields{ + crdv1beta.FieldNodeName: workloadPod.Spec.NodeName, + crdv1beta.FieldPersistentVolumeName: pv.Name, + crdv1beta.FieldVolumeID: pv.Spec.CSI.VolumeHandle, + crdv1beta.FieldMountOptions: strings.Join(pv.Spec.MountOptions, ","), + crdv1beta.FieldWorkloadFSGroup: fsGroup, + crdv1beta.FieldAuthenticationSource: authSource, + } + + if authSource == credentialprovider.AuthenticationSourcePod { + fieldFilters[crdv1beta.FieldWorkloadNamespace] = workloadPod.Namespace + fieldFilters[crdv1beta.FieldWorkloadServiceAccountName] = getServiceAccountName(workloadPod) + fieldFilters[crdv1beta.FieldWorkloadServiceAccountIAMRoleARN] = roleArn + } + + return fieldFilters +} + +// getAuthSource returns authentication source from given PV. +// Defaults to `driver` if `authenticationSource` is not found in volume attributes. +func (r *Reconciler) getAuthSource(pv *corev1.PersistentVolume) string { + volumeAttributes := mppod.ExtractVolumeAttributes(pv) + authSource := volumeAttributes[volumecontext.AuthenticationSource] + if authSource == credentialprovider.AuthenticationSourceUnspecified { + return credentialprovider.AuthenticationSourceDriver + } + return authSource +} + +// getFSGroup returns the FSGroup value from the pod's security context as a string. +// If FSGroup is not set, it returns an empty string. +func (r *Reconciler) getFSGroup(workloadPod *corev1.Pod) string { + if workloadPod.Spec.SecurityContext.FSGroup != nil { + return strconv.FormatInt(*workloadPod.Spec.SecurityContext.FSGroup, 10) + } + return "" +} + +// getExistingS3PodAttachment retrieves a MountpointS3PodAttachment resource that matches the provided field filters. +// It returns: +// - The matching MountpointS3PodAttachment if exactly one is found +// - nil if no matching resource is found +// - An error if multiple matching resources are found or if the list operation fails +func (r *Reconciler) getExistingS3PodAttachment(ctx context.Context, fieldFilters client.MatchingFields) (*crdv1beta.MountpointS3PodAttachment, error) { + s3paList := &crdv1beta.MountpointS3PodAttachmentList{} + if err := r.List(ctx, s3paList, fieldFilters); err != nil { + return nil, fmt.Errorf("failed to list MountpointS3PodAttachments: %w", err) + } + + switch len(s3paList.Items) { + case 0: + return nil, nil + case 1: + return &s3paList.Items[0], nil + default: + return nil, fmt.Errorf("found %d MountpointS3PodAttachments when expecting 0 or 1", len(s3paList.Items)) + } +} + +// handleInactivePod handles inactive workload pod. +func (r *Reconciler) handleInactivePod(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { + if s3pa == nil { + log.Info("Workload pod is not active. Did not find any MountpointS3PodAttachments.") + return DontRequeue, nil + } + + return r.removeWorkloadFromS3PodAttachment(ctx, s3pa, workloadUID, log) +} + +// handleExistingS3PodAttachment handles existing S3 Pod Attachment. +func (r *Reconciler) handleExistingS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, fieldFilters client.MatchingFields, log logr.Logger) (bool, error) { + if r.s3paExpectations.isPending(fieldFilters) { + log.Info("MountpointS3PodAttachment creation is pending, removing from pending") + r.s3paExpectations.clear(fieldFilters) + } + + if s3paContainsWorkload(s3pa, workloadUID) { + log.Info("MountpointS3PodAttachment already has this workload UID") + return DontRequeue, nil + } + + return r.addWorkloadToS3PodAttachment(ctx, s3pa, workloadUID, log) +} + +// addWorkloadToS3PodAttachment adds workload UID to the first Mountpoint Pod in the map +// TODO: We will later add extra logic for selecting/creating MPPod if existing MP Pods are using old CSI Driver version or have some "no-new-attachments" or "needs-unmount" annotation +func (r *Reconciler) addWorkloadToS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { + log.Info("Adding workload UID to MountpointS3PodAttachment") + + for key := range s3pa.Spec.MountpointS3PodAttachments { + s3pa.Spec.MountpointS3PodAttachments[key] = append(s3pa.Spec.MountpointS3PodAttachments[key], crdv1beta.WorkloadAttachment{ + WorkloadPodUID: workloadUID, + AttachmentTime: metav1.NewTime(time.Now().UTC()), + }) + break + } + + err := r.Update(ctx, s3pa) + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to update MountpointS3PodAttachment - resource conflict - requeue") + return Requeue, nil + } + log.Error(err, "Failed to update MountpointS3PodAttachment") + return Requeue, err + } + + return DontRequeue, nil +} + +// removeWorkloadFromS3PodAttachment removes workload UID from MountpointS3PodAttachment map. +// It will delete MountpointS3PodAttachment if map becomes empty. +func (r *Reconciler) removeWorkloadFromS3PodAttachment(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string, log logr.Logger) (bool, error) { + // Remove workload UID from mountpoint pods + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + filteredUIDs := []crdv1beta.WorkloadAttachment{} + found := false + for _, attachment := range attachments { + if attachment.WorkloadPodUID == workloadUID { + found = true + continue + } + filteredUIDs = append(filteredUIDs, attachment) + } + if found { + s3pa.Spec.MountpointS3PodAttachments[mpPodName] = filteredUIDs + err := r.Update(ctx, s3pa) + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to remove workload pod UID from existing MountpointS3PodAttachment due to resource conflict, requeueing") + return Requeue, nil + } + log.Error(err, "Failed to update MountpointS3PodAttachment") + return Requeue, err + } + log.Info("Successfully removed workload pod UID from MountpointS3PodAttachment") + break + } + } + + // Remove Mountpoint pods with zero workloads + for mpPodName, uids := range s3pa.Spec.MountpointS3PodAttachments { + if len(uids) == 0 { + log.Info("Mountpoint pod has zero workload UIDs. Adding "+mppod.AnnotationNeedsUnmount+" annotation", + "mountpointPodName", mpPodName) + err := r.addNeedsUnmountAnnotation(ctx, mpPodName, log) if err != nil { - log.Error(err, "Failed to delete scheduled Mountpoint Pod") - return err + return Requeue, err } - log.Info("Scheduled Mountpoint Pod deleted") - return err + log.Info("Mountpoint pod has zero workload UIDs. Will remove it from MountpointS3PodAttachment", + "mountpointPodName", mpPodName) + delete(s3pa.Spec.MountpointS3PodAttachments, mpPodName) + err = r.Update(ctx, s3pa) + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to remove Mountpoint pod from MountpointS3PodAttachment due to resource conflict, requeueing", + "mountpointPodName", mpPodName) + return Requeue, nil + } + log.Error(err, "Failed to update MountpointS3PodAttachment") + return Requeue, err + } } + } - // No need to do anything - either there was no Mountpoint Pod for `pod` or it was in `Running` state, - // so a clean unmount operation will be performed and Mountpoint Pod will cleany exit (and get deleted by `reconcileMountpointPod`). - return nil + // Delete MountpointS3PodAttachment if map is empty + if len(s3pa.Spec.MountpointS3PodAttachments) == 0 { + log.Info("MountpointS3PodAttachment has zero Mountpoint Pods. Will delete it") + err := r.Delete(ctx, s3pa) + if err != nil { + if apierrors.IsConflict(err) { + log.Info("Failed to delete MountpointS3PodAttachment due to resource conflict, requeueing") + return Requeue, nil + } + log.Error(err, "Failed to delete MountpointS3PodAttachment") + return Requeue, err + } } - if isMountpointPodExists { - log.V(debugLevel).Info("Mountpoint Pod already exists - ignoring") - return nil + return DontRequeue, nil +} + +// handleNewS3PodAttachment handles new S3 pod attachment in case none were found. +func (r *Reconciler) handleNewS3PodAttachment( + ctx context.Context, + workloadPod *corev1.Pod, + pv *corev1.PersistentVolume, + roleArn string, + fieldFilters client.MatchingFields, + log logr.Logger, +) (bool, error) { + if r.s3paExpectations.isPending(fieldFilters) { + log.Info("MountpointS3PodAttachment creation is pending, requeueing") + return Requeue, nil } - if err := r.spawnMountpointPod(ctx, workloadPod, pvc, pv, csiSpec, mpPodName); err != nil { + if err := r.createS3PodAttachmentWithMPPod(ctx, workloadPod, pv, roleArn, log); err != nil { + return Requeue, err + } + + r.s3paExpectations.setPending(fieldFilters) + return Requeue, nil +} + +// createS3PodAttachmentWithMPPod creates new MountpointS3PodAttachment resource and Mountpoint Pod for given workload and PV. +func (r *Reconciler) createS3PodAttachmentWithMPPod( + ctx context.Context, + workloadPod *corev1.Pod, + pv *corev1.PersistentVolume, + roleArn string, + log logr.Logger, +) error { + authSource := r.getAuthSource(pv) + mpPod, err := r.spawnMountpointPod(ctx, workloadPod, pv, log) + if err != nil { log.Error(err, "Failed to spawn Mountpoint Pod") return err } + s3pa := &crdv1beta.MountpointS3PodAttachment{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "s3pa-", + Labels: map[string]string{ + LabelCSIDriverVersion: r.mountpointPodConfig.CSIDriverVersion, + }, + }, + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: workloadPod.Spec.NodeName, + PersistentVolumeName: pv.Name, + VolumeID: pv.Spec.CSI.VolumeHandle, + MountOptions: strings.Join(pv.Spec.MountOptions, ","), + WorkloadFSGroup: r.getFSGroup(workloadPod), + AuthenticationSource: authSource, + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ + mpPod.Name: {{WorkloadPodUID: string(workloadPod.UID), AttachmentTime: metav1.NewTime(time.Now().UTC())}}, + }, + }, + } + if authSource == "pod" { + s3pa.Spec.WorkloadNamespace = workloadPod.Namespace + s3pa.Spec.WorkloadServiceAccountName = getServiceAccountName(workloadPod) + s3pa.Spec.WorkloadServiceAccountIAMRoleARN = roleArn + } + + err = r.Create(ctx, s3pa) + if err != nil { + log.Error(err, "Failed to create MountpointS3PodAttachment") + if deleteErr := r.Delete(ctx, mpPod); deleteErr != nil { + log.Error(deleteErr, "Failed to cleanup Mountpoint Pod after MountpointS3PodAttachment creation failure", "mountpointPodName", mpPod.Name) + } else { + log.Info("Successfully cleaned up Mountpoint Pod after S3PodAttachment creation failure", "mountpointPodName", mpPod.Name) + } + return err + } + log.Info("MountpointS3PodAttachment is created", "s3pa", s3pa.Name) return nil } @@ -222,38 +510,24 @@ func (r *Reconciler) spawnOrDeleteMountpointPodIfNeeded( func (r *Reconciler) spawnMountpointPod( ctx context.Context, workloadPod *corev1.Pod, - pvc *corev1.PersistentVolumeClaim, pv *corev1.PersistentVolume, - _ *corev1.CSIPersistentVolumeSource, - name string, -) error { - log := logf.FromContext(ctx).WithValues( - "workloadPod", types.NamespacedName{Namespace: workloadPod.Namespace, Name: workloadPod.Name}, - "mountpointPod", name, - "pvc", pvc.Name, "volumeName", pv.Name) - + log logr.Logger, +) (*corev1.Pod, error) { log.Info("Spawning Mountpoint Pod") - mpPod, err := r.mountpointPodCreator.Create(workloadPod, pv) + mpPod, err := r.mountpointPodCreator.Create(workloadPod.Spec.NodeName, pv) if err != nil { log.Error(err, "Failed to create Mountpoint Pod Spec") - return err - } - - if mpPod.Name != name { - err := fmt.Errorf("Mountpoint Pod name mismatch %s vs %s", mpPod.Name, name) - log.Error(err, "Name mismatch on Mountpoint Pod") - return err + return nil, err } err = r.Create(ctx, mpPod) if err != nil { - log.Error(err, "Failed to create Mountpoint Pod") - return err + return nil, err } - log.Info("Mountpoint Pod spawned", "mountpointPodUID", mpPod.UID) - return nil + log.Info("Mountpoint Pod spawned", "mountpointPodName", mpPod.Name) + return mpPod, nil } // deleteMountpointPod deletes given `mountpointPod`. @@ -319,6 +593,56 @@ func (r *Reconciler) getBoundPVForPodClaim( return pvc, pv, nil } +// findIRSAServiceAccountRole retrieves the IAM role ARN associated with a pod's service account +// through IRSA (IAM Roles for Service Accounts) annotation ("eks.amazonaws.com/role-arn"). +// +// Parameters: +// - ctx: Context for the request +// - pod: The Kubernetes pod whose service account role should be retrieved +// +// Returns: +// - string: The IAM role ARN from the service account's annotation, empty string if not found +// - error: Error if the service account cannot be retrieved +func (r *Reconciler) findIRSAServiceAccountRole(ctx context.Context, pod *corev1.Pod) (string, error) { + sa := &corev1.ServiceAccount{} + err := r.Get(ctx, types.NamespacedName{Namespace: pod.Namespace, Name: getServiceAccountName(pod)}, sa) + if err != nil { + return "", fmt.Errorf("Failed to find workload pod's service account %s", getServiceAccountName(pod)) + } + + return sa.Annotations[AnnotationServiceAccountRole], nil +} + +// addNeedsUnmountAnnotation add "s3.csi.aws.com/needs-unmount" to Mountpoint Pod. +// This will trigger CSI Driver Node to cleanly unmount and Mountpoint Pod will become 'Succeeded'. +func (r *Reconciler) addNeedsUnmountAnnotation(ctx context.Context, mpPodName string, log logr.Logger) error { + // Get the pod + mpPod := &corev1.Pod{} + err := r.Get(ctx, types.NamespacedName{Namespace: r.mountpointPodConfig.Namespace, Name: mpPodName}, mpPod) + if err != nil { + if apierrors.IsNotFound(err) { + log.Info("Failed to find Mountpoint Pod - ignoring") + return nil + } + log.Error(err, "Failed to get Pod") + return err + } + + if mpPod.Annotations == nil { + mpPod.Annotations = make(map[string]string) + } + mpPod.Annotations[mppod.AnnotationNeedsUnmount] = "true" + + // Update the pod + err = r.Update(ctx, mpPod) + if err != nil { + log.Error(err, "Failed to update Mountpoint Pod") + return err + } + + return nil +} + // isMountpointPod returns whether given `pod` is a Mountpoint Pod. // It currently checks namespace of `pod`. func (r *Reconciler) isMountpointPod(pod *corev1.Pod) bool { @@ -343,3 +667,23 @@ func isPodActive(p *corev1.Pod) bool { corev1.PodFailed != p.Status.Phase && p.DeletionTimestamp == nil } + +// s3paContainsWorkload checks whether MountpointS3PodAttachment has `workloadUID` in it. +func s3paContainsWorkload(s3pa *crdv1beta.MountpointS3PodAttachment, workloadUID string) bool { + for _, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == workloadUID { + return true + } + } + } + return false +} + +// getServiceAccountName returns the pod's service account name or "default" if not specified +func getServiceAccountName(pod *corev1.Pod) string { + if pod.Spec.ServiceAccountName != "" { + return pod.Spec.ServiceAccountName + } + return defaultServiceAccount +} diff --git a/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go new file mode 100644 index 00000000..e7f8216b --- /dev/null +++ b/cmd/aws-s3-csi-controller/csicontroller/stale_attachment_cleaner.go @@ -0,0 +1,130 @@ +package csicontroller + +import ( + "context" + "time" + + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + corev1 "k8s.io/api/core/v1" + logf "sigs.k8s.io/controller-runtime/pkg/log" +) + +const ( + cleanupInterval = 10 * time.Second + staleAttachmentThreshold = 10 * time.Second +) + +// StaleAttachmentCleaner handles periodic cleanup of stale workload attachments in case reconciler missed pod deletion event. +type StaleAttachmentCleaner struct { + reconciler *Reconciler +} + +// NewStaleAttachmentCleaner creates a new StaleAttachmentCleaner +func NewStaleAttachmentCleaner(reconciler *Reconciler) *StaleAttachmentCleaner { + return &StaleAttachmentCleaner{ + reconciler: reconciler, + } +} + +// Start begins the periodic cleanup process +func (cm *StaleAttachmentCleaner) Start(ctx context.Context) error { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + if err := cm.runCleanup(ctx); err != nil { + log := logf.FromContext(ctx) + log.Error(err, "Failed to run cleanup") + } + } + } +} + +// runCleanup performs cleanup operation +func (cm *StaleAttachmentCleaner) runCleanup(ctx context.Context) error { + log := logf.FromContext(ctx) + + // Get all pods in the cluster + podList := &corev1.PodList{} + if err := cm.reconciler.List(ctx, podList); err != nil { + return err + } + + // Create a map of existing pod UIDs for quick lookup + existingPods := make(map[string]struct{}) + for _, pod := range podList.Items { + existingPods[string(pod.UID)] = struct{}{} + } + + // Get all MountpointS3PodAttachments + s3paList := &crdv1beta.MountpointS3PodAttachmentList{} + if err := cm.reconciler.List(ctx, s3paList); err != nil { + return err + } + + // Check each S3PodAttachment for stale workload references + for _, s3pa := range s3paList.Items { + if err := cm.cleanupStaleWorkloads(ctx, &s3pa, existingPods); err != nil { + log.Error(err, "Error cleaning up S3PodAttachment", "s3pa", s3pa.Name) + continue + } + } + + return nil +} + +// cleanupStaleWorkloads removes stale workload references from a single S3PodAttachment. +// A workload reference is considered stale if the referenced Pod no longer exists in the cluster +// and the attachment is older than staleAttachmentThreshold (this is to avoid race condition with reconciler). +// If a Mountpoint Pod has zero attachments after cleanup, "s3.csi.aws.com/needs-unmount" annotation is added and its entry in S3PodAttachment is deleted. +// If S3PodAttachment has no remaining Mountpoint Pods, the entire S3PodAttachment is deleted. +func (cm *StaleAttachmentCleaner) cleanupStaleWorkloads(ctx context.Context, s3pa *crdv1beta.MountpointS3PodAttachment, existingPods map[string]struct{}) error { + log := logf.FromContext(ctx).WithValues("s3pa", s3pa.Name) + modified := false + + now := time.Now().UTC() + + // Check each mountpoint pod's attachments + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + var validAttachments []crdv1beta.WorkloadAttachment + + for _, attachment := range attachments { + // Check if pod exists and attachment is not too new + _, exists := existingPods[attachment.WorkloadPodUID] + isStale := now.Sub(attachment.AttachmentTime.Time) > staleAttachmentThreshold + + if exists || !isStale { + validAttachments = append(validAttachments, attachment) + } else { + modified = true + log.Info("Removing stale workload reference", + "workloadUID", attachment.WorkloadPodUID, + "mountpointPod", mpPodName, + "attachmentAge", now.Sub(attachment.AttachmentTime.Time)) + } + } + + if len(validAttachments) == 0 { + if err := cm.reconciler.addNeedsUnmountAnnotation(ctx, mpPodName, log); err != nil { + return err + } + delete(s3pa.Spec.MountpointS3PodAttachments, mpPodName) + } else { + s3pa.Spec.MountpointS3PodAttachments[mpPodName] = validAttachments + } + } + + // Update the S3PodAttachment if modified + if modified { + if len(s3pa.Spec.MountpointS3PodAttachments) == 0 { + return cm.reconciler.Delete(ctx, s3pa) + } + return cm.reconciler.Update(ctx, s3pa) + } + + return nil +} diff --git a/cmd/aws-s3-csi-controller/main.go b/cmd/aws-s3-csi-controller/main.go index 31c46f00..52d0f177 100644 --- a/cmd/aws-s3-csi-controller/main.go +++ b/cmd/aws-s3-csi-controller/main.go @@ -11,6 +11,9 @@ import ( "os" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client/config" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" @@ -18,6 +21,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager/signals" "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" @@ -30,21 +34,37 @@ var mountpointImage = flag.String("mountpoint-image", os.Getenv("MOUNTPOINT_IMAG var mountpointImagePullPolicy = flag.String("mountpoint-image-pull-policy", os.Getenv("MOUNTPOINT_IMAGE_PULL_POLICY"), "Pull policy of Mountpoint images.") var mountpointContainerCommand = flag.String("mountpoint-container-command", "/bin/aws-s3-csi-mounter", "Entrypoint command of the Mountpoint Pods.") +var ( + scheme = runtime.NewScheme() +) + +func init() { + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(crdv1beta.AddToScheme(scheme)) +} + func main() { flag.Parse() logf.SetLogger(zap.New()) log := logf.Log.WithName(csicontroller.Name) - client := config.GetConfigOrDie() + conf := config.GetConfigOrDie() - mgr, err := manager.New(client, manager.Options{}) + mgr, err := manager.New(conf, manager.Options{ + Scheme: scheme, + }) if err != nil { log.Error(err, "Failed to create a new manager") os.Exit(1) } - err = csicontroller.NewReconciler(mgr.GetClient(), mppod.Config{ + if err := crdv1beta.SetupManagerIndices(mgr); err != nil { + log.Error(err, "Failed to setup field indexers") + os.Exit(1) + } + + reconciler := csicontroller.NewReconciler(mgr.GetClient(), mppod.Config{ Namespace: *mountpointNamespace, MountpointVersion: *mountpointVersion, PriorityClassName: *mountpointPriorityClassName, @@ -54,13 +74,19 @@ func main() { ImagePullPolicy: corev1.PullPolicy(*mountpointImagePullPolicy), }, CSIDriverVersion: version.GetVersion().DriverVersion, - ClusterVariant: cluster.DetectVariant(client, log), - }).SetupWithManager(mgr) - if err != nil { + ClusterVariant: cluster.DetectVariant(conf, log), + }) + + if err := reconciler.SetupWithManager(mgr); err != nil { log.Error(err, "Failed to create controller") os.Exit(1) } + if err := mgr.Add(csicontroller.NewStaleAttachmentCleaner(reconciler)); err != nil { + log.Error(err, "Failed to add stale attachment cleaner to manager") + os.Exit(1) + } + if err := mgr.Start(signals.SetupSignalHandler()); err != nil { log.Error(err, "Failed to start manager") os.Exit(1) diff --git a/cmd/aws-s3-csi-driver/main.go b/cmd/aws-s3-csi-driver/main.go index 4bcd9e0a..ce720e2f 100644 --- a/cmd/aws-s3-csi-driver/main.go +++ b/cmd/aws-s3-csi-driver/main.go @@ -21,6 +21,8 @@ import ( "flag" "fmt" "os" + "os/signal" + "syscall" "github.com/awslabs/aws-s3-csi-driver/pkg/driver" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" @@ -70,6 +72,15 @@ func main() { klog.Fatalf("failed to create driver: %s", err) } + // Handle shutdown signals + stopCh := make(chan os.Signal, 1) + signal.Notify(stopCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-stopCh + klog.Infof("Received signal %s, shutting down...", sig) + drv.Stop() + }() + if err := drv.Run(); err != nil { klog.Fatalln(err) } diff --git a/hack/boilerplate.go.txt b/hack/boilerplate.go.txt new file mode 100644 index 00000000..e69de29b diff --git a/pkg/api/v1beta/groupversion_info.go b/pkg/api/v1beta/groupversion_info.go new file mode 100644 index 00000000..1efe5dfc --- /dev/null +++ b/pkg/api/v1beta/groupversion_info.go @@ -0,0 +1,20 @@ +// Package v1beta contains API Schema definitions for the s3.csi.aws.com v1beta API group. +// +kubebuilder:object:generate=true +// +groupName=s3.csi.aws.com +package v1beta + +import ( + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/scheme" +) + +var ( + // GroupVersion is group version used to register these objects. + GroupVersion = schema.GroupVersion{Group: "s3.csi.aws.com", Version: "v1beta"} + + // SchemeBuilder is used to add go types to the GroupVersionKind scheme. + SchemeBuilder = &scheme.Builder{GroupVersion: GroupVersion} + + // AddToScheme adds the types in this group-version to the given scheme. + AddToScheme = SchemeBuilder.AddToScheme +) diff --git a/pkg/api/v1beta/mountpoints3podattachment_indexer.go b/pkg/api/v1beta/mountpoints3podattachment_indexer.go new file mode 100644 index 00000000..12dc19a3 --- /dev/null +++ b/pkg/api/v1beta/mountpoints3podattachment_indexer.go @@ -0,0 +1,67 @@ +package v1beta + +import ( + "context" + "fmt" + + ctrlcache "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" +) + +// getIndexFields returns the set of field extractors +func getIndexFields() map[string]func(*MountpointS3PodAttachment) string { + return map[string]func(*MountpointS3PodAttachment) string{ + FieldNodeName: func(cr *MountpointS3PodAttachment) string { return cr.Spec.NodeName }, + FieldPersistentVolumeName: func(cr *MountpointS3PodAttachment) string { return cr.Spec.PersistentVolumeName }, + FieldVolumeID: func(cr *MountpointS3PodAttachment) string { return cr.Spec.VolumeID }, + FieldMountOptions: func(cr *MountpointS3PodAttachment) string { return cr.Spec.MountOptions }, + FieldAuthenticationSource: func(cr *MountpointS3PodAttachment) string { return cr.Spec.AuthenticationSource }, + FieldWorkloadFSGroup: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadFSGroup }, + FieldWorkloadServiceAccountName: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountName }, + FieldWorkloadNamespace: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadNamespace }, + FieldWorkloadServiceAccountIAMRoleARN: func(cr *MountpointS3PodAttachment) string { return cr.Spec.WorkloadServiceAccountIAMRoleARN }, + } +} + +// SetupManagerIndices sets up indices for a manager +func SetupManagerIndices(mgr manager.Manager) error { + for field, extractor := range getIndexFields() { + if err := setupManagerIndex(mgr, field, extractor); err != nil { + return fmt.Errorf("failed to setup index for field %s: %w", field, err) + } + } + return nil +} + +// SetupCacheIndices sets up indices for a cache +func SetupCacheIndices(cache ctrlcache.Cache) error { + for field, extractor := range getIndexFields() { + if err := setupCacheIndex(cache, field, extractor); err != nil { + return fmt.Errorf("failed to setup index for field %s: %w", field, err) + } + } + return nil +} + +func setupManagerIndex(mgr manager.Manager, field string, extractor func(*MountpointS3PodAttachment) string) error { + return mgr.GetFieldIndexer().IndexField( + context.Background(), + &MountpointS3PodAttachment{}, + field, + func(obj client.Object) []string { + return []string{extractor(obj.(*MountpointS3PodAttachment))} + }, + ) +} + +func setupCacheIndex(cache ctrlcache.Cache, field string, extractor func(*MountpointS3PodAttachment) string) error { + return cache.IndexField( + context.Background(), + &MountpointS3PodAttachment{}, + field, + func(obj client.Object) []string { + return []string{extractor(obj.(*MountpointS3PodAttachment))} + }, + ) +} diff --git a/pkg/api/v1beta/mountpoints3podattachment_types.go b/pkg/api/v1beta/mountpoints3podattachment_types.go new file mode 100644 index 00000000..6e4a4e46 --- /dev/null +++ b/pkg/api/v1beta/mountpoints3podattachment_types.go @@ -0,0 +1,88 @@ +package v1beta + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// The following fields are used as matching criteria to determine if a mountpoint s3 pod can be shared by having the same MountpointS3PodAttachment resource: +const ( + FieldNodeName = "spec.nodeName" + FieldPersistentVolumeName = "spec.persistentVolumeName" + FieldVolumeID = "spec.volumeID" + FieldMountOptions = "spec.mountOptions" + FieldAuthenticationSource = "spec.authenticationSource" + FieldWorkloadFSGroup = "spec.workloadFSGroup" + FieldWorkloadServiceAccountName = "spec.workloadServiceAccountName" + FieldWorkloadNamespace = "spec.workloadNamespace" + FieldWorkloadServiceAccountIAMRoleARN = "spec.workloadServiceAccountIAMRoleARN" +) + +// MountpointS3PodAttachmentSpec defines the desired state of MountpointS3PodAttachment. +type MountpointS3PodAttachmentSpec struct { + // Important: Run "make generate" to regenerate code after modifying this file + + // Name of the node. + NodeName string `json:"nodeName"` + + // Name of the Persistent Volume. + PersistentVolumeName string `json:"persistentVolumeName"` + + // Volume ID. + VolumeID string `json:"volumeID"` + + // Comma separated mount options taken from volume. + MountOptions string `json:"mountOptions"` + + // Authentication source taken from volume attribute field `authenticationSource`. + AuthenticationSource string `json:"authenticationSource"` + + // Workload pod's `fsGroup` from pod security context + WorkloadFSGroup string `json:"workloadFSGroup"` + + // Workload pod's service account name. Exists only if `authenticationSource: pod`. + WorkloadServiceAccountName string `json:"workloadServiceAccountName,omitempty"` + + // Workload pod's namespace. Exists only if `authenticationSource: pod`. + WorkloadNamespace string `json:"workloadNamespace,omitempty"` + + // EKS IAM Role ARN from workload pod's service account annotation (IRSA). Exists only if `authenticationSource: pod` and service account has `eks.amazonaws.com/role-arn` annotation. + WorkloadServiceAccountIAMRoleARN string `json:"workloadServiceAccountIAMRoleARN,omitempty"` + + // Maps each Mountpoint S3 pod name to its workload attachments + MountpointS3PodAttachments map[string][]WorkloadAttachment `json:"mountpointS3PodAttachments"` +} + +// WorkloadAttachment represents the attachment details of a workload pod to a Mountpoint S3 pod. +type WorkloadAttachment struct { + // WorkloadPodUID is the unique identifier of the attached workload pod + WorkloadPodUID string `json:"workloadPodUID"` + + // AttachmentTime represents when the workload pod was attached to the Mountpoint S3 pod + AttachmentTime metav1.Time `json:"attachmentTime"` +} + +// +kubebuilder:object:root=true +// +kubebuilder:subresource:status +// +kubebuilder:resource:scope=Cluster,shortName=s3pa +// +kubebuilder:selectablefield:JSONPath=`.spec.nodeName` + +// MountpointS3PodAttachment is the Schema for the mountpoints3podattachments API. +type MountpointS3PodAttachment struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec MountpointS3PodAttachmentSpec `json:"spec,omitempty"` +} + +// +kubebuilder:object:root=true + +// MountpointS3PodAttachmentList contains a list of MountpointS3PodAttachment. +type MountpointS3PodAttachmentList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []MountpointS3PodAttachment `json:"items"` +} + +func init() { + SchemeBuilder.Register(&MountpointS3PodAttachment{}, &MountpointS3PodAttachmentList{}) +} diff --git a/pkg/api/v1beta/zz_generated.deepcopy.go b/pkg/api/v1beta/zz_generated.deepcopy.go new file mode 100644 index 00000000..2e79df80 --- /dev/null +++ b/pkg/api/v1beta/zz_generated.deepcopy.go @@ -0,0 +1,116 @@ +//go:build !ignore_autogenerated + +// Code generated by controller-gen. DO NOT EDIT. + +package v1beta + +import ( + runtime "k8s.io/apimachinery/pkg/runtime" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MountpointS3PodAttachment) DeepCopyInto(out *MountpointS3PodAttachment) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + in.Spec.DeepCopyInto(&out.Spec) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MountpointS3PodAttachment. +func (in *MountpointS3PodAttachment) DeepCopy() *MountpointS3PodAttachment { + if in == nil { + return nil + } + out := new(MountpointS3PodAttachment) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MountpointS3PodAttachment) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MountpointS3PodAttachmentList) DeepCopyInto(out *MountpointS3PodAttachmentList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]MountpointS3PodAttachment, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MountpointS3PodAttachmentList. +func (in *MountpointS3PodAttachmentList) DeepCopy() *MountpointS3PodAttachmentList { + if in == nil { + return nil + } + out := new(MountpointS3PodAttachmentList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MountpointS3PodAttachmentList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MountpointS3PodAttachmentSpec) DeepCopyInto(out *MountpointS3PodAttachmentSpec) { + *out = *in + if in.MountpointS3PodAttachments != nil { + in, out := &in.MountpointS3PodAttachments, &out.MountpointS3PodAttachments + *out = make(map[string][]WorkloadAttachment, len(*in)) + for key, val := range *in { + var outVal []WorkloadAttachment + if val == nil { + (*out)[key] = nil + } else { + inVal := (*in)[key] + in, out := &inVal, &outVal + *out = make([]WorkloadAttachment, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + (*out)[key] = outVal + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MountpointS3PodAttachmentSpec. +func (in *MountpointS3PodAttachmentSpec) DeepCopy() *MountpointS3PodAttachmentSpec { + if in == nil { + return nil + } + out := new(MountpointS3PodAttachmentSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkloadAttachment) DeepCopyInto(out *WorkloadAttachment) { + *out = *in + in.AttachmentTime.DeepCopyInto(&out.AttachmentTime) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkloadAttachment. +func (in *WorkloadAttachment) DeepCopy() *WorkloadAttachment { + if in == nil { + return nil + } + out := new(WorkloadAttachment) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 92e60122..b399ff2c 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -2,6 +2,7 @@ package cluster import ( "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/util/version" "k8s.io/client-go/discovery" "k8s.io/client-go/rest" "k8s.io/utils/ptr" @@ -53,3 +54,17 @@ func (c Variant) MountpointPodUserID() *int64 { return defaultMountpointUID } + +// Helper function to check availability of selectableFields on CustomResourceDefinitions feature in K8s cluster +func IsSelectableFieldsSupported(serverVersion string) (bool, error) { + currentVersion, err := version.ParseGeneric(serverVersion) + if err != nil { + return false, err + } + + // Selectable fields are supported from 1.32 + // https://kubernetes.io/docs/tasks/extend-kubernetes/custom-resources/custom-resource-definitions/#crd-selectable-fields + selectableFieldsVersion := version.MustParseGeneric("v1.32.0") + + return !currentVersion.LessThan(selectableFieldsVersion), nil +} diff --git a/pkg/cluster/cluster_test.go b/pkg/cluster/cluster_test.go index 700bae61..174609d4 100644 --- a/pkg/cluster/cluster_test.go +++ b/pkg/cluster/cluster_test.go @@ -33,3 +33,75 @@ func TestMountpointPodUserID(t *testing.T) { }) } } + +func TestIsSelectableFieldsSupported(t *testing.T) { + tests := []struct { + name string + serverVersion string + want bool + wantErr bool + }{ + { + name: "version greater than minimum supported version", + serverVersion: "v1.33.0", + want: true, + wantErr: false, + }, + { + name: "version equal to minimum supported version", + serverVersion: "v1.32.0", + want: true, + wantErr: false, + }, + { + name: "version less than minimum supported version", + serverVersion: "v1.31.0", + want: false, + wantErr: false, + }, + { + name: "version with patch number", + serverVersion: "v1.32.2", + want: true, + wantErr: false, + }, + { + name: "version with release candidate", + serverVersion: "v1.32.0-rc.1", + want: true, + wantErr: false, + }, + { + name: "invalid version format", + serverVersion: "invalid.version", + want: false, + wantErr: true, + }, + { + name: "empty version string", + serverVersion: "", + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cluster.IsSelectableFieldsSupported(tt.serverVersion) + + if tt.wantErr { + // If we expect an error, we don't check the boolean result + if err == nil { + t.Error("expected error but got none") + } + return + } + + // If we don't expect an error, verify there isn't one + assert.NoError(t, err) + + // Compare the actual result with expected + assert.Equals(t, tt.want, got) + }) + } +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index be7f26b8..c2403152 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -23,19 +23,29 @@ import ( "os" "time" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" + "github.com/awslabs/aws-s3-csi-driver/pkg/util" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/client-go/kubernetes" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" + ctrlcache "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager/signals" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" mpmounter "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint/mounter" - "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" - "github.com/awslabs/aws-s3-csi-driver/pkg/util" ) const ( @@ -44,11 +54,18 @@ const ( grpcServerMaxReceiveMessageSize = 1024 * 1024 * 2 // 2MB unixSocketPerm = os.FileMode(0700) // only owner can write and read. +) +var ( + mountpointPodNamespace = os.Getenv("MOUNTPOINT_NAMESPACE") podWatcherResyncPeriod = time.Minute + scheme = runtime.NewScheme() ) -var mountpointPodNamespace = os.Getenv("MOUNTPOINT_NAMESPACE") +func init() { + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(crdv1beta.AddToScheme(scheme)) +} type Driver struct { Endpoint string @@ -93,13 +110,22 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error var mounterImpl mounter.Mounter mpMounter := mpmounter.New() if util.UsePodMounter() { - podWatcher := watcher.New(clientset, mountpointPodNamespace, podWatcherResyncPeriod) + podWatcher := watcher.New(clientset, mountpointPodNamespace, nodeID, podWatcherResyncPeriod) err = podWatcher.Start(stopCh) if err != nil { klog.Fatalf("Failed to start Pod watcher: %v\n", err) } - mounterImpl, err = mounter.NewPodMounter(podWatcher, credProvider, mpMounter, nil, kubernetesVersion) + s3paCache := setupS3PodAttachmentCache(config, stopCh, nodeID, kubernetesVersion) + + unmounter := mounter.NewPodUnmounter(nodeID, mpMounter, podWatcher, credProvider, mounter.SourceMountDir) + + podWatcher.AddEventHandler(cache.ResourceEventHandlerFuncs{UpdateFunc: unmounter.HandleMountpointPodUpdate}) + + go unmounter.StartPeriodicCleanup(stopCh) + + mounterImpl, err = mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mpMounter, nil, nil, nil, + kubernetesVersion, nodeID, mounter.SourceMountDir) if err != nil { klog.Fatalln(err) } @@ -187,3 +213,54 @@ func kubernetesVersion(clientset *kubernetes.Clientset) (string, error) { return version.String(), nil } + +// setupS3PodAttachmentCache sets up cache for MountpointS3PodAttachment custom resource +func setupS3PodAttachmentCache(config *rest.Config, stopCh <-chan struct{}, nodeID, kubernetesVersion string) ctrlcache.Cache { + options := ctrlcache.Options{ + Scheme: scheme, + SyncPeriod: &podWatcherResyncPeriod, + ReaderFailOnMissingInformer: true, + } + isSelectFieldsSupported, err := cluster.IsSelectableFieldsSupported(kubernetesVersion) + if err != nil { + klog.Fatalf("Failed to check support for selectable fields in the cluster %v\n", err) + } + if isSelectFieldsSupported { + options.ByObject = map[client.Object]ctrlcache.ByObject{ + &crdv1beta.MountpointS3PodAttachment{}: { + Field: fields.OneTermEqualSelector("spec.nodeName", nodeID), + }, + } + } else { + // TODO: We can potentially use label filter hash of nodeId for old clusters instead of field selector + options.ByObject = map[client.Object]ctrlcache.ByObject{ + &crdv1beta.MountpointS3PodAttachment{}: {}, + } + } + + s3paCache, err := ctrlcache.New(config, options) + if err != nil { + klog.Fatalf("Failed to create cache: %v\n", err) + } + + if err := crdv1beta.SetupCacheIndices(s3paCache); err != nil { + klog.Fatalf("Failed to setup field indexers: %v", err) + } + + s3podAttachmentInformer, err := s3paCache.GetInformer(context.Background(), &crdv1beta.MountpointS3PodAttachment{}) + if err != nil { + klog.Fatalf("Failed to create informer for MountpointS3PodAttachment: %v\n", err) + } + + go func() { + if err := s3paCache.Start(signals.SetupSignalHandler()); err != nil { + klog.Fatalf("Failed to start cache: %v\n", err) + } + }() + + if !cache.WaitForCacheSync(stopCh, s3podAttachmentInformer.HasSynced) { + klog.Fatalf("Failed to sync informer cache within the timeout: %v\n", err) + } + + return s3paCache +} diff --git a/pkg/driver/node/credentialprovider/provider.go b/pkg/driver/node/credentialprovider/provider.go index 06070f84..9fa9797b 100644 --- a/pkg/driver/node/credentialprovider/provider.go +++ b/pkg/driver/node/credentialprovider/provider.go @@ -62,10 +62,11 @@ type ProvideContext struct { VolumeID string // The following values are provided from CSI volume context. - AuthenticationSource AuthenticationSource - PodNamespace string - ServiceAccountTokens string - ServiceAccountName string + AuthenticationSource AuthenticationSource + PodNamespace string + ServiceAccountTokens string + ServiceAccountName string + ServiceAccountEKSRoleARN string // StsRegion is the `stsRegion` parameter passed via volume attribute. StsRegion string // BucketRegion is the `--region` parameter passed via mount options. @@ -78,6 +79,11 @@ func (ctx *ProvideContext) SetWriteAndEnvPath(writePath, envPath string) { ctx.EnvPath = envPath } +// SetServiceAccountEKSRoleARN sets `ServiceAccountEKSRoleARN` for `ctx`. +func (ctx *ProvideContext) SetServiceAccountEKSRoleARN(roleArn string) { + ctx.ServiceAccountEKSRoleARN = roleArn +} + // A CleanupContext contains parameters needed to clean up credentials after volume unmount. type CleanupContext struct { // WritePath is basepath where credentials previously written into. diff --git a/pkg/driver/node/credentialprovider/provider_pod.go b/pkg/driver/node/credentialprovider/provider_pod.go index d011a942..832fc8b8 100644 --- a/pkg/driver/node/credentialprovider/provider_pod.go +++ b/pkg/driver/node/credentialprovider/provider_pod.go @@ -113,6 +113,11 @@ func (c *Provider) cleanupFromPod(cleanupCtx CleanupContext) error { // findPodServiceAccountRole tries to provide associated AWS IAM role for service account specified in the volume context. func (c *Provider) findPodServiceAccountRole(ctx context.Context, provideCtx ProvideContext) (string, error) { + // In PodMounter we get IAM Role ARN from MountpointS3PodAttachment custom resource + if provideCtx.ServiceAccountEKSRoleARN != "" { + return provideCtx.ServiceAccountEKSRoleARN, nil + } + podNamespace := provideCtx.PodNamespace podServiceAccount := provideCtx.ServiceAccountName if podNamespace == "" || podServiceAccount == "" { diff --git a/pkg/driver/node/mounter/fake_cache.go b/pkg/driver/node/mounter/fake_cache.go new file mode 100644 index 00000000..d37a2c74 --- /dev/null +++ b/pkg/driver/node/mounter/fake_cache.go @@ -0,0 +1,48 @@ +package mounter + +import ( + "context" + + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type FakeCache struct { + TestItems []crdv1beta.MountpointS3PodAttachment +} + +func (f *FakeCache) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + return nil +} + +func (f *FakeCache) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + s3paList := list.(*crdv1beta.MountpointS3PodAttachmentList) + s3paList.Items = f.TestItems + return nil +} + +func (f *FakeCache) GetInformer(ctx context.Context, obj client.Object, opts ...cache.InformerGetOption) (cache.Informer, error) { + return nil, nil +} + +func (f *FakeCache) GetInformerForKind(ctx context.Context, gvk schema.GroupVersionKind, opts ...cache.InformerGetOption) (cache.Informer, error) { + return nil, nil +} + +func (f *FakeCache) RemoveInformer(ctx context.Context, obj client.Object) error { + return nil +} + +func (f *FakeCache) IndexField(ctx context.Context, obj client.Object, field string, extractValue client.IndexerFunc) error { + return nil +} + +func (f *FakeCache) Start(ctx context.Context) error { + return nil +} + +func (f *FakeCache) WaitForCacheSync(ctx context.Context) bool { + return true +} diff --git a/pkg/driver/node/mounter/fake_mounter.go b/pkg/driver/node/mounter/fake_mounter.go index 8cfaa603..10f07dd7 100644 --- a/pkg/driver/node/mounter/fake_mounter.go +++ b/pkg/driver/node/mounter/fake_mounter.go @@ -10,7 +10,7 @@ import ( type FakeMounter struct{} func (m *FakeMounter) Mount(ctx context.Context, bucketName string, target string, - credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { + credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error { return nil } diff --git a/pkg/driver/node/mounter/mocks/mock_mount.go b/pkg/driver/node/mounter/mocks/mock_mount.go index 27fc6fbd..c8539a9f 100644 --- a/pkg/driver/node/mounter/mocks/mock_mount.go +++ b/pkg/driver/node/mounter/mocks/mock_mount.go @@ -106,17 +106,17 @@ func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call } // Mount mocks base method. -func (m *MockMounter) Mount(ctx context.Context, bucketName, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (m *MockMounter) Mount(ctx context.Context, bucketName, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", ctx, bucketName, target, credentialCtx, args) + ret := m.ctrl.Call(m, "Mount", ctx, bucketName, target, credentialCtx, args, fsGroup, pvMountOptions) ret0, _ := ret[0].(error) return ret0 } // Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(ctx, bucketName, target, credentialCtx, args interface{}) *gomock.Call { +func (mr *MockMounterMockRecorder) Mount(ctx, bucketName, target, credentialCtx, args, fsGroup, pvMountOptions interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), ctx, bucketName, target, credentialCtx, args) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), ctx, bucketName, target, credentialCtx, args, fsGroup, pvMountOptions) } // Unmount mocks base method. diff --git a/pkg/driver/node/mounter/mounter.go b/pkg/driver/node/mounter/mounter.go index 9df722e0..fa4fe365 100644 --- a/pkg/driver/node/mounter/mounter.go +++ b/pkg/driver/node/mounter/mounter.go @@ -17,7 +17,7 @@ type ServiceRunner interface { // Mounter is an interface for mount operations type Mounter interface { - Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error + Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error Unmount(ctx context.Context, target string, credentialCtx credentialprovider.CleanupContext) error IsMountPoint(target string) (bool, error) } diff --git a/pkg/driver/node/mounter/mppod_lock.go b/pkg/driver/node/mounter/mppod_lock.go new file mode 100644 index 00000000..6222221f --- /dev/null +++ b/pkg/driver/node/mounter/mppod_lock.go @@ -0,0 +1,82 @@ +package mounter + +import ( + "sync" + + "k8s.io/klog/v2" +) + +// MPPodLock represents a reference-counted mutex lock for Mountpoint Pod. +// It ensures synchronized access to pod-specific resources. +type MPPodLock struct { + mutex sync.Mutex + refCount int +} + +var ( + // mpPodLocks maps pod UIDs to their corresponding locks. + mpPodLocks = make(map[string]*MPPodLock) + + // mpPodLocksMutex guards access to the mpPodLocks map. + mpPodLocksMutex sync.Mutex +) + +// lockMountpointPod acquires a lock for the specified pod UID and returns an unlock function. +// The returned function must be called to release the lock and cleanup resources. +// +// Parameters: +// - uid: The unique identifier of the Mountpoint Pod to lock +// +// Returns: +// - func(): A function that when called will unlock the pod and release associated resources +// +// Usage: +// +// unlock := lockMountpointPod(podUID) +// defer unlock() +func lockMountpointPod(uid string) func() { + mpPodLock := getMPPodLock(uid) + mpPodLock.mutex.Lock() + return func() { + mpPodLock.mutex.Unlock() + releaseMPPodLock(uid) + } +} + +// getMPPodLock retrieves or creates a lock for the specified pod UID. +// It increments the reference count for existing locks. +// The caller is responsible for calling releaseMPPodLock when the lock is no longer needed. +func getMPPodLock(mpPodUID string) *MPPodLock { + mpPodLocksMutex.Lock() + defer mpPodLocksMutex.Unlock() + + lock, exists := mpPodLocks[mpPodUID] + if !exists { + lock = &MPPodLock{refCount: 1} + mpPodLocks[mpPodUID] = lock + } else { + lock.refCount++ + } + return lock +} + +// releaseMPPodLock decrements the reference count for a pod's lock. +// When the reference count reaches zero, the lock is removed from the map. +// If the lock doesn't exist, the function returns silently. +func releaseMPPodLock(mpPodUID string) { + mpPodLocksMutex.Lock() + defer mpPodLocksMutex.Unlock() + + lock, exists := mpPodLocks[mpPodUID] + if !exists { + // Should never happen + klog.Errorf("Attempted to release non-existent lock for Mountpoint Pod UID %s", mpPodUID) + return + } + + lock.refCount-- + + if lock.refCount <= 0 { + delete(mpPodLocks, mpPodUID) + } +} diff --git a/pkg/driver/node/mounter/mppod_lock_test.go b/pkg/driver/node/mounter/mppod_lock_test.go new file mode 100644 index 00000000..d4361019 --- /dev/null +++ b/pkg/driver/node/mounter/mppod_lock_test.go @@ -0,0 +1,59 @@ +package mounter + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func TestGetMPPodLock(t *testing.T) { + // Clear the map before testing + mpPodLocks = make(map[string]*MPPodLock) + + t.Run("New lock creation", func(t *testing.T) { + podUID := "pod1" + lock := getMPPodLock(podUID) + + assert.Equals(t, 1, lock.refCount) + assert.Equals(t, 1, len(mpPodLocks)) + }) + + t.Run("Existing lock retrieval", func(t *testing.T) { + podUID := "pod2" + firstLock := getMPPodLock(podUID) + secondLock := getMPPodLock(podUID) + + if firstLock != secondLock { + t.Fatal("Expected to get the same lock instance") + } + assert.Equals(t, 2, firstLock.refCount) + }) +} + +func TestReleaseMPPodLock(t *testing.T) { + // Clear the map before testing + mpPodLocks = make(map[string]*MPPodLock) + + t.Run("Release existing lock", func(t *testing.T) { + podUID := "pod3" + getMPPodLock(podUID) + getMPPodLock(podUID) + + releaseMPPodLock(podUID) + + lock, exists := mpPodLocks[podUID] + assert.Equals(t, true, exists) + assert.Equals(t, 1, lock.refCount) + + releaseMPPodLock(podUID) + + _, exists = mpPodLocks[podUID] + assert.Equals(t, false, exists) + }) + + t.Run("Release non-existent lock", func(t *testing.T) { + podUID := "non-existent-pod" + releaseMPPodLock(podUID) + // This test passes if no panic occurs + }) +} diff --git a/pkg/driver/node/mounter/pod_mounter.go b/pkg/driver/node/mounter/pod_mounter.go index 678e3971..0883c803 100644 --- a/pkg/driver/node/mounter/pod_mounter.go +++ b/pkg/driver/node/mounter/pod_mounter.go @@ -13,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath" @@ -22,10 +23,12 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" "github.com/awslabs/aws-s3-csi-driver/pkg/util" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" ) -const mountpointPodReadinessTimeout = 10 * time.Second -const mountpointPodReadinessCheckInterval = 100 * time.Millisecond +// Internal S3 CSI Driver directory for source mount points +const SourceMountDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/mnt/" // targetDirPerm is the permission to use while creating target directory if its not exists. const targetDirPerm = fs.FileMode(0755) @@ -34,26 +37,40 @@ const targetDirPerm = fs.FileMode(0755) // It returns mounted FUSE file descriptor as a result. // This is mainly exposed for testing, in production platform-native function (`mpmounter.Mount`) will be used. type mountSyscall func(target string, args mountpoint.Args) (fd int, err error) +type bindMountSyscall func(source, target string) (err error) +type sourceMountPointFinder func(target, sourceMountDir string) (string, error) // A PodMounter is a [Mounter] that mounts Mountpoint on pre-created Kubernetes Pod running in the same node. type PodMounter struct { - podWatcher *watcher.Watcher - mount *mpmounter.Mounter - kubeletPath string - mountSyscall mountSyscall - kubernetesVersion string - credProvider *credentialprovider.Provider + podWatcher *watcher.Watcher + s3paCache cache.Cache + mount *mpmounter.Mounter + kubeletPath string + sourceMountDir string + mountSyscall mountSyscall + bindMountSyscall bindMountSyscall + sourceMountPointFinder sourceMountPointFinder + kubernetesVersion string + credProvider *credentialprovider.Provider + nodeID string } // NewPodMounter creates a new [PodMounter] with given Kubernetes client. -func NewPodMounter(podWatcher *watcher.Watcher, credProvider *credentialprovider.Provider, mount *mpmounter.Mounter, mountSyscall mountSyscall, kubernetesVersion string) (*PodMounter, error) { +func NewPodMounter(podWatcher *watcher.Watcher, s3paCache cache.Cache, credProvider *credentialprovider.Provider, mount *mpmounter.Mounter, + mountSyscall mountSyscall, bindMountSyscall bindMountSyscall, sourceMountPointFinder sourceMountPointFinder, kubernetesVersion, nodeID, + sourceMountDir string) (*PodMounter, error) { return &PodMounter{ - podWatcher: podWatcher, - credProvider: credProvider, - mount: mount, - kubeletPath: util.KubeletPath(), - mountSyscall: mountSyscall, - kubernetesVersion: kubernetesVersion, + podWatcher: podWatcher, + s3paCache: s3paCache, + credProvider: credProvider, + mount: mount, + kubeletPath: util.KubeletPath(), + sourceMountDir: sourceMountDir, + mountSyscall: mountSyscall, + bindMountSyscall: bindMountSyscall, + sourceMountPointFinder: sourceMountPointFinder, + kubernetesVersion: kubernetesVersion, + nodeID: nodeID, }, nil } @@ -68,15 +85,13 @@ func NewPodMounter(podWatcher *watcher.Watcher, credProvider *credentialprovider // 6. Wait until Mountpoint successfully mounts at `target` // // If Mountpoint is already mounted at `target`, it will return early at step 2 to ensure credentials are up-to-date. -func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup string, pvMountOptions string) error { volumeName, err := pm.volumeNameFromTargetPath(target) if err != nil { return fmt.Errorf("Failed to extract volume name from %q: %w", target, err) } - podID := credentialCtx.PodID - - isMountPoint, err := pm.IsMountPoint(target) + isTargetMountPoint, err := pm.IsMountPoint(target) if err != nil { err = pm.verifyOrSetupMountTarget(target, err) if err != nil { @@ -84,35 +99,104 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin } } + if isTargetMountPoint { + klog.V(4).Infof("Target path %q is already mounted. Only refreshing credentials.", target) + source, err := pm.findSourceMountPointWithDefault(target) + if err != nil { + klog.Errorf("Failed to find source mount point for %q: %v", target, err) + return fmt.Errorf("Failed to find source mount point for %q: %w", target, err) + } + mpPodUID := filepath.Base(source) + podPath := pm.podPath(mpPodUID) + + unlockMountpointPod := lockMountpointPod(mpPodUID) + defer unlockMountpointPod() + + pm.provideCredentials(ctx, podPath, credentialCtx) + + return nil + } + + s3PodAttachment, mpPodName, err := pm.getS3PodAttachmentWithRetry(ctx, volumeName, credentialCtx, fsGroup, pvMountOptions) + if err != nil { + klog.Errorf("Failed to find corresponding MountpointS3PodAttachment custom resource %q: %v", target, err) + return fmt.Errorf("Failed to find corresponding MountpointS3PodAttachment custom resource %q: %w", target, err) + } + + if s3PodAttachment.Spec.WorkloadServiceAccountIAMRoleARN != "" { + credentialCtx.SetServiceAccountEKSRoleARN(s3PodAttachment.Spec.WorkloadServiceAccountIAMRoleARN) + } + // TODO: If `target` is a `systemd`-mounted Mountpoint, this would return an error, // but we should still update the credentials for it by calling `credProvider.Provide`. - pod, podPath, err := pm.waitForMountpointPod(ctx, podID, volumeName) + pod, podPath, err := pm.waitForMountpointPod(ctx, mpPodName) if err != nil { klog.Errorf("Failed to wait for Mountpoint Pod to be ready for %q: %v", target, err) return fmt.Errorf("Failed to wait for Mountpoint Pod to be ready for %q: %w", target, err) } + mpPodUID := string(pod.UID) + unlockMountpointPod := lockMountpointPod(mpPodUID) + defer unlockMountpointPod() - podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) + source := filepath.Join(pm.sourceMountDir, mpPodUID) + + isSourceMountPoint, err := pm.IsMountPoint(source) if err != nil { - klog.Errorf("Failed to create credentials directory for %q: %v", target, err) - return fmt.Errorf("Failed to create credentials directory for %q: %w", target, err) + err = pm.verifyOrSetupMountTarget(source, err) + if err != nil { + return fmt.Errorf("Failed to verify source path can be used as a mount point %q: %w", source, err) + } } - credentialCtx.SetWriteAndEnvPath(podCredentialsPath, mppod.PathInsideMountpointPod(mppod.KnownPathCredentials)) - // Note that this part happens before `isMountPoint` check, as we want to update credentials even though // there is an existing mount point at `target`. - credEnv, authenticationSource, err := pm.credProvider.Provide(ctx, credentialCtx) + credEnv, authenticationSource, err := pm.provideCredentials(ctx, podPath, credentialCtx) if err != nil { - klog.Errorf("Failed to provide credentials for %s: %v\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to provide credentials for %q: %w\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) + klog.Errorf("Failed to provide credentials for %s: %v\n%s", source, err, pm.helpMessageForGettingMountpointLogs(pod)) + return fmt.Errorf("Failed to provide credentials for %q: %w\n%s", source, err, pm.helpMessageForGettingMountpointLogs(pod)) } - if isMountPoint { - klog.V(4).Infof("Target path %q is already mounted", target) - return nil + if !isSourceMountPoint { + err = pm.mountS3AtSource(ctx, source, pod, podPath, bucketName, credEnv, authenticationSource, args) + if err != nil { + return fmt.Errorf("Failed to mount at source %s: %v", source, err) + } + } + + err = pm.bindMountSyscallWithDefault(source, target) + if err != nil { + klog.Errorf("Failed to bind mount %s to target %s: %v", source, target, err) + return fmt.Errorf("Failed to bind mount %s to target %s: %w", source, target, err) } + return nil +} + +// mountS3AtSource mounts an S3 bucket at the specified source path using the Mountpoint Pod. +// +// Parameters: +// - ctx: Context for cancellation and timeout control +// - source: The path where the S3 bucket should be mounted +// - mpPod: Mountpoint Pod that will serve this mount point +// - podPath: Base path for Pod-specific files +// - bucketName: Name of the S3 bucket to mount +// - credEnv: Environment variables related to AWS credentials +// - authenticationSource: Authentication source from PV volume attribute +// - args: Mountpoint arguments +// +// Returns: +// - error: nil if successful, otherwise an error describing what went wrong +// +// The function performs the following steps: +// 1. Prepares environment and mount arguments +// 2. Performs the initial mount syscall to obtain FUSE file descriptor +// 3. Sends mount options to the Mountpoint Pod +// 4. Waits for the mount to be ready +// +// If any step fails, it ensures cleanup by unmounting the source path. +func (pm *PodMounter) mountS3AtSource(ctx context.Context, source string, mpPod *corev1.Pod, podPath string, + bucketName string, credEnv envprovider.Environment, authenticationSource credentialprovider.AuthenticationSource, + args mountpoint.Args) error { env := envprovider.Default() env.Merge(credEnv) @@ -126,12 +210,12 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin podMountSockPath := mppod.PathOnHost(podPath, mppod.KnownPathMountSock) podMountErrorPath := mppod.PathOnHost(podPath, mppod.KnownPathMountError) - klog.V(4).Infof("Mounting %s for %s", target, pod.Name) + klog.V(4).Infof("Mounting %s for %s", source, mpPod.Name) - fuseDeviceFD, err := pm.mountSyscallWithDefault(target, args) + fuseDeviceFD, err := pm.mountSyscallWithDefault(source, args) if err != nil { - klog.Errorf("Failed to mount %s: %v", target, err) - return fmt.Errorf("Failed to mount %s: %w", target, err) + klog.Errorf("Failed to mount %s: %v", source, err) + return fmt.Errorf("Failed to mount %s: %w", source, err) } // Remove the read-only argument from the list as mount-s3 does not support it when using FUSE @@ -141,14 +225,14 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin } // This will set to false in the success condition. This is set to `true` by default to - // ensure we don't leave `target` mounted if Mountpoint is not started to serve requests for it. + // ensure we don't leave `source` mounted if Mountpoint is not started to serve requests for it. unmount := true defer func() { if unmount { - if err := pm.unmountTarget(target); err != nil { - klog.V(4).ErrorS(err, "Failed to unmount mounted target %s\n", target) + if err := pm.unmountTarget(source); err != nil { + klog.V(4).ErrorS(err, "Failed to unmount mounted source %s\n", source) } else { - klog.V(4).Infof("Target %s unmounted successfully\n", target) + klog.V(4).Infof("Source %s unmounted successfully\n", source) } } }() @@ -161,7 +245,7 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin // Remove old mount error file if exists _ = os.Remove(podMountErrorPath) - klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", pod.Name, podMountSockPath) + klog.V(4).Infof("Sending mount options to Mountpoint Pod %s on %s", mpPod.Name, podMountSockPath) err = mountoptions.Send(ctx, podMountSockPath, mountoptions.Options{ Fd: fuseDeviceFD, @@ -170,14 +254,14 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin Env: env.List(), }) if err != nil { - klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) + klog.Errorf("Failed to send mount option to Mountpoint Pod %s for %s: %v\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) + return fmt.Errorf("Failed to send mount options to Mountpoint Pod %s for %s: %w\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) } - err = pm.waitForMount(ctx, target, pod.Name, podMountErrorPath) + err = pm.waitForMount(ctx, source, mpPod.Name, podMountErrorPath) if err != nil { - klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", pod.Name, target, err, pm.helpMessageForGettingMountpointLogs(pod)) + klog.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %v\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) + return fmt.Errorf("Failed to wait for Mountpoint Pod %s to be ready for %s: %w\n%s", mpPod.Name, source, err, pm.helpMessageForGettingMountpointLogs(mpPod)) } // Mountpoint successfully started, so don't unmount the filesystem @@ -185,45 +269,13 @@ func (pm *PodMounter) Mount(ctx context.Context, bucketName string, target strin return nil } -// Unmount unmounts the mount point at `target` and cleans all credentials. -func (pm *PodMounter) Unmount(ctx context.Context, target string, credentialCtx credentialprovider.CleanupContext) error { - volumeName, err := pm.volumeNameFromTargetPath(target) - if err != nil { - return fmt.Errorf("Failed to extract volume name from %q: %w", target, err) - } - - podID := credentialCtx.PodID - - // TODO: If `target` is a `systemd`-mounted Mountpoint, this would return an error, - // but we should still unmount it and clean the credentials. - pod, podPath, err := pm.waitForMountpointPod(ctx, podID, volumeName) - if err != nil { - klog.Errorf("Failed to wait for Mountpoint Pod to be ready for %q: %v", target, err) - return fmt.Errorf("Failed to wait for Mountpoint Pod for %q: %w", target, err) - } - - credentialCtx.WritePath = pm.credentialsDir(podPath) - - // Write `mount.exit` file to indicate Mountpoint Pod to cleanly exit. - podMountExitPath := mppod.PathOnHost(podPath, mppod.KnownPathMountExit) - _, err = os.OpenFile(podMountExitPath, os.O_RDONLY|os.O_CREATE, credentialprovider.CredentialFilePerm) - if err != nil { - klog.Errorf("Failed to send a exit message to Mountpoint Pod for %q: %s\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to send a exit message to Mountpoint Pod for %q: %w\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - } - - err = pm.unmountTarget(target) +// Unmount unmounts only the bind mount point at `target`. +func (pm *PodMounter) Unmount(_ context.Context, target string, _ credentialprovider.CleanupContext) error { + err := pm.unmountTarget(target) if err != nil { klog.Errorf("Failed to unmount %q: %v", target, err) return fmt.Errorf("Failed to unmount %q: %w", target, err) } - - err = pm.credProvider.Cleanup(credentialCtx) - if err != nil { - klog.Errorf("Failed to clean up credentials for %s: %v\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - return fmt.Errorf("Failed to clean up credentials for %q: %w\n%s", target, err, pm.helpMessageForGettingMountpointLogs(pod)) - } - return nil } @@ -232,11 +284,18 @@ func (pm *PodMounter) IsMountPoint(target string) (bool, error) { return pm.mount.CheckMountPoint(target) } -// waitForMountpointPod waints until Mountpoint Pod for given `podID` and `volumeName` is in `Running` state. -// It returns found Mountpoint Pod and it's base directory. -func (pm *PodMounter) waitForMountpointPod(ctx context.Context, podID, volumeName string) (*corev1.Pod, string, error) { - podName := mppod.MountpointPodNameFor(podID, volumeName) +// findSourceMountPointWithDefault calls `FindSourceMountPoint` on `target`. +func (pm *PodMounter) findSourceMountPointWithDefault(target string) (string, error) { + if pm.sourceMountPointFinder != nil { + return pm.sourceMountPointFinder(target, pm.sourceMountDir) + } + + return pm.mount.FindSourceMountPoint(target, pm.sourceMountDir) +} +// waitForMountpointPod waits until Mountpoint Pod for given `podName` is in `Running` state. +// It returns found Mountpoint Pod and it's base directory. +func (pm *PodMounter) waitForMountpointPod(ctx context.Context, podName string) (*corev1.Pod, string, error) { pod, err := pm.podWatcher.Wait(ctx, podName) if err != nil { return nil, "", err @@ -244,7 +303,7 @@ func (pm *PodMounter) waitForMountpointPod(ctx context.Context, podID, volumeNam klog.V(4).Infof("Mountpoint Pod %s/%s is running with id %s", pod.Namespace, podName, pod.UID) - return pod, pm.podPath(pod), nil + return pod, pm.podPath(string(pod.UID)), nil } // waitForMount waits until Mountpoint is successfully mounted at `target`. @@ -327,6 +386,19 @@ func (pm *PodMounter) verifyOrSetupMountTarget(target string, err error) error { return err } +// provideCredentials provides credentials +func (pm *PodMounter) provideCredentials(ctx context.Context, podPath string, credentialCtx credentialprovider.ProvideContext) (envprovider.Environment, credentialprovider.AuthenticationSource, error) { + podCredentialsPath, err := pm.ensureCredentialsDirExists(podPath) + if err != nil { + klog.Errorf("Failed to create credentials directory: %v", err) + return nil, "", fmt.Errorf("Failed to create credentials directory: %w", err) + } + + credentialCtx.SetWriteAndEnvPath(podCredentialsPath, mppod.PathInsideMountpointPod(mppod.KnownPathCredentials)) + + return pm.credProvider.Provide(ctx, credentialCtx) +} + // ensureCredentialsDirExists ensures credentials dir for `podPath` is exists. // It returns credentials dir and any error. func (pm *PodMounter) ensureCredentialsDirExists(podPath string) (string, error) { @@ -346,8 +418,8 @@ func (pm *PodMounter) credentialsDir(podPath string) string { } // podPath returns `pod`'s basepath inside kubelet's path. -func (pm *PodMounter) podPath(pod *corev1.Pod) string { - return filepath.Join(pm.kubeletPath, "pods", string(pod.UID)) +func (pm *PodMounter) podPath(podUID string) string { + return filepath.Join(pm.kubeletPath, "pods", podUID) } // mountSyscallWithDefault delegates to `mountSyscall` if set, or fallbacks to platform-native `mpmounter.Mount`. @@ -363,6 +435,15 @@ func (pm *PodMounter) mountSyscallWithDefault(target string, args mountpoint.Arg return pm.mount.Mount(target, opts) } +// bindMountWithDefault delegates to `bindMountSyscall` if set, or fallbacks to platform-native `mpmounter.BindMount`. +func (pm *PodMounter) bindMountSyscallWithDefault(source, target string) error { + if pm.bindMountSyscall != nil { + return pm.bindMountSyscall(source, target) + } + + return pm.mount.BindMount(source, target) +} + // unmountTarget calls `unmount` syscall on `target`. func (pm *PodMounter) unmountTarget(target string) error { return pm.mount.Unmount(target) @@ -380,3 +461,53 @@ func (pm *PodMounter) volumeNameFromTargetPath(target string) (string, error) { func (pm *PodMounter) helpMessageForGettingMountpointLogs(pod *corev1.Pod) string { return fmt.Sprintf("You can see Mountpoint logs by running: `kubectl logs -n %s %s`. If the Mountpoint Pod already restarted, you can also pass `--previous` to get logs from the previous run.", pod.Namespace, pod.Name) } + +// getS3PodAttachmentWithRetry retrieves a MountpointS3PodAttachment resource that matches the given volume and credential context. +// It continuously retries the operation until either a matching attachment is found or the context is canceled. +func (pm *PodMounter) getS3PodAttachmentWithRetry(ctx context.Context, volumeName string, credentialCtx credentialprovider.ProvideContext, fsGroup, pvMountOptions string) (*crdv1beta.MountpointS3PodAttachment, string, error) { + fieldFilters := client.MatchingFields{ + crdv1beta.FieldNodeName: pm.nodeID, + crdv1beta.FieldPersistentVolumeName: volumeName, + crdv1beta.FieldVolumeID: credentialCtx.VolumeID, + crdv1beta.FieldMountOptions: pvMountOptions, + crdv1beta.FieldWorkloadFSGroup: fsGroup, + crdv1beta.FieldAuthenticationSource: credentialCtx.AuthenticationSource, + } + if credentialCtx.AuthenticationSource == credentialprovider.AuthenticationSourcePod { + fieldFilters[crdv1beta.FieldWorkloadNamespace] = credentialCtx.PodNamespace + fieldFilters[crdv1beta.FieldWorkloadServiceAccountName] = credentialCtx.ServiceAccountName + // Note that we intentionally do not include `FieldWorkloadServiceAccountIAMRoleARN` to list filters because + // CSI Driver Node does not know which role ARN to use (if any). + // Role ARN is determined by reconciler and passed to node via MountpointS3PodAttachment. + } + + for { + select { + case <-ctx.Done(): + return nil, "", ctx.Err() + default: + } + + s3paList := &crdv1beta.MountpointS3PodAttachmentList{} + err := pm.s3paCache.List(ctx, s3paList, fieldFilters) + if err != nil { + klog.Errorf("Failed to list MountpointS3PodAttachments: %v", err) + return nil, "", err + } + for _, s3pa := range s3paList.Items { + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == credentialCtx.PodID { + return &s3pa, mpPodName, nil + } + } + } + } + + select { + case <-ctx.Done(): + return nil, "", ctx.Err() + case <-time.After(100 * time.Millisecond): + } + } +} diff --git a/pkg/driver/node/mounter/pod_mounter_test.go b/pkg/driver/node/mounter/pod_mounter_test.go index c9d515b2..f69b1b8b 100644 --- a/pkg/driver/node/mounter/pod_mounter_test.go +++ b/pkg/driver/node/mounter/pod_mounter_test.go @@ -19,6 +19,7 @@ import ( "k8s.io/client-go/kubernetes/fake" "k8s.io/mount-utils" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" @@ -41,16 +42,24 @@ type testCtx struct { podMounter *mounter.PodMounter - client *fake.Clientset - mount *mount.FakeMounter - mountSyscall func(target string, args mountpoint.Args) (fd int, err error) - - bucketName string - kubeletPath string - targetPath string - podUID string - volumeID string - pvName string + client *fake.Clientset + mount *mount.FakeMounter + s3paCache *mounter.FakeCache + mountSyscall func(target string, args mountpoint.Args) (fd int, err error) + mountBindSyscall func(source, target string) (err error) + + bucketName string + kubeletPath string + sourcePath string + targetPath string + podUID string + volumeID string + pvName string + nodeName string + fsGroup string + pvMountOptions string + mpPodName string + mpPodUID string } func setup(t *testing.T) *testCtx { @@ -64,10 +73,18 @@ func setup(t *testing.T) *testCtx { // to overcome `bind: invalid argument`. t.Chdir(kubeletPath) + sourceMountDir := t.TempDir() + bucketName := "test-bucket" podUID := uuid.New().String() + mpPodName := "test-mppod" + mpPodUID := uuid.New().String() volumeID := "s3-csi-driver-volume" pvName := "s3-csi-driver-pv" + nodeName := "test-node" + fsGroup := "1000" + pvMountOptions := "--fake-mountoption" + s3paCache := &mounter.FakeCache{} targetPath := filepath.Join( kubeletPath, fmt.Sprintf("pods/%s/volumes/kubernetes.io~csi/%s/mount", podUID, pvName), @@ -82,37 +99,87 @@ func setup(t *testing.T) *testCtx { parentDir, err := filepath.EvalSymlinks(filepath.Dir(targetPath)) assert.NoError(t, err) targetPath = filepath.Join(parentDir, filepath.Base(targetPath)) + parentDir, err = filepath.EvalSymlinks(filepath.Dir(sourceMountDir)) + assert.NoError(t, err) + sourceMountDir = filepath.Join(parentDir, filepath.Base(sourceMountDir)) client := fake.NewClientset() - mount := mount.NewFakeMounter(nil) + fakeMounter := mount.NewFakeMounter(nil) testCtx := &testCtx{ - t: t, - ctx: ctx, - client: client, - mount: mount, - bucketName: bucketName, - kubeletPath: kubeletPath, - targetPath: targetPath, - podUID: podUID, - volumeID: volumeID, - pvName: pvName, + t: t, + ctx: ctx, + client: client, + mount: fakeMounter, + bucketName: bucketName, + kubeletPath: kubeletPath, + targetPath: targetPath, + podUID: podUID, + volumeID: volumeID, + pvName: pvName, + nodeName: nodeName, + fsGroup: fsGroup, + s3paCache: s3paCache, + pvMountOptions: pvMountOptions, + mpPodName: mpPodName, + mpPodUID: mpPodUID, + sourcePath: filepath.Join(sourceMountDir, mpPodUID), } + testCrd := crdv1beta.MountpointS3PodAttachment{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: testCtx.nodeName, + PersistentVolumeName: testCtx.pvName, + VolumeID: testCtx.volumeID, + WorkloadFSGroup: testCtx.fsGroup, + MountOptions: testCtx.pvMountOptions, + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ + testCtx.mpPodName: []crdv1beta.WorkloadAttachment{{WorkloadPodUID: testCtx.podUID}}, + }, + }, + } + testCtx.s3paCache.TestItems = []crdv1beta.MountpointS3PodAttachment{testCrd} + mountSyscall := func(target string, args mountpoint.Args) (fd int, err error) { if testCtx.mountSyscall != nil { return testCtx.mountSyscall(target, args) } - mount.Mount("mountpoint-s3", target, "fuse", nil) + fakeMounter.Mount("mountpoint-s3", target, "fuse", nil) return int(mountertest.OpenDevNull(t).Fd()), nil } + mountBindSyscall := func(source, target string) (err error) { + if testCtx.mountBindSyscall != nil { + return testCtx.mountBindSyscall(source, target) + } + + fakeMounter.Mount(source, target, "fuse", []string{"bind"}) + return nil + } + + findSourceMountPoint := func(target, sourceMountDir string) (string, error) { + mountPoints, err := fakeMounter.List() + if err != nil { + return "", fmt.Errorf("failed to list mount points: %w", err) + } + + for _, mp := range mountPoints { + if mp.Device == "mountpoint-s3" && + strings.HasPrefix(mp.Path, sourceMountDir) && + mp.Path != target { + return mp.Path, nil + } + } + + return "", fmt.Errorf("no source mount point found for target %q", target) + } + credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { return dummyIMDSRegion, nil }) - podWatcher := watcher.New(client, mountpointPodNamespace, 10*time.Second) + podWatcher := watcher.New(client, mountpointPodNamespace, nodeName, 10*time.Second) stopCh := make(chan struct{}) t.Cleanup(func() { close(stopCh) @@ -120,7 +187,8 @@ func setup(t *testing.T) *testCtx { err = podWatcher.Start(stopCh) assert.NoError(t, err) - podMounter, err := mounter.NewPodMounter(podWatcher, credProvider, mpmounter.NewWithMount(mount), mountSyscall, testK8sVersion) + podMounter, err := mounter.NewPodMounter(podWatcher, s3paCache, credProvider, mpmounter.NewWithMount(fakeMounter), mountSyscall, + mountBindSyscall, findSourceMountPoint, testK8sVersion, nodeName, sourceMountDir) assert.NoError(t, err) testCtx.podMounter = podMounter @@ -155,7 +223,7 @@ func TestPodMounter(t *testing.T) { AuthenticationSource: credentialprovider.AuthenticationSourceDriver, VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, args) + }, args, testCtx.fsGroup, testCtx.pvMountOptions) if err != nil { log.Println("Mount failed", err) } @@ -201,7 +269,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) }) @@ -215,7 +283,7 @@ func TestPodMounter(t *testing.T) { AuthenticationSource: credentialprovider.AuthenticationSourceDriver, VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, args) + }, args, testCtx.fsGroup, testCtx.pvMountOptions) if err != nil { log.Println("Mount failed", err) } @@ -233,11 +301,11 @@ func TestPodMounter(t *testing.T) { assert.Equals(t, true, credDirInfo.IsDir()) assert.Equals(t, credentialprovider.CredentialDirPerm, credDirInfo.Mode().Perm()) }) - t.Run("Does not duplicate mounts if target is already mounted", func(t *testing.T) { testCtx := setup(t) var mountCount atomic.Int32 + var bindMountCount atomic.Int32 testCtx.mountSyscall = func(target string, args mountpoint.Args) (fd int, err error) { mountCount.Add(1) @@ -245,6 +313,12 @@ func TestPodMounter(t *testing.T) { return int(mountertest.OpenDevNull(t).Fd()), nil } + testCtx.mountBindSyscall = func(source, target string) (err error) { + bindMountCount.Add(1) + testCtx.mount.Mount(source, target, "fuse", []string{"bind"}) + return nil + } + go func() { mpPod := createMountpointPod(testCtx) mpPod.run() @@ -252,17 +326,99 @@ func TestPodMounter(t *testing.T) { }() for range 5 { - err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ - VolumeID: testCtx.volumeID, - PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, + credentialprovider.ProvideContext{ + VolumeID: testCtx.volumeID, + PodID: testCtx.podUID, + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) } assert.Equals(t, int32(1), mountCount.Load()) + assert.Equals(t, int32(1), bindMountCount.Load()) + }) + + t.Run("Re-uses the same source mount for different targets if they share same Mountpoint Pod", func(t *testing.T) { + // First Pod + testCtx := setup(t) + + ok, _ := testCtx.podMounter.IsMountPoint(testCtx.targetPath) + assert.Equals(t, false, ok) + + var mountCount atomic.Int32 + var bindMountCount atomic.Int32 + + testCtx.mountSyscall = func(target string, args mountpoint.Args) (fd int, err error) { + mountCount.Add(1) + testCtx.mount.Mount("mountpoint-s3", target, "fuse", nil) + return int(mountertest.OpenDevNull(t).Fd()), nil + } + testCtx.mountBindSyscall = func(source, target string) (err error) { + bindMountCount.Add(1) + testCtx.mount.Mount(source, target, "fuse", []string{"bind"}) + return nil + } + + go func() { + mpPod := createMountpointPod(testCtx) + mpPod.run() + mpPod.receiveMountOptions(testCtx.ctx) + }() + + err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ + VolumeID: testCtx.volumeID, + PodID: testCtx.podUID, + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) + assert.NoError(t, err) + + ok, err = testCtx.podMounter.IsMountPoint(testCtx.sourcePath) + assert.NoError(t, err) + assert.Equals(t, true, ok) + ok, err = testCtx.podMounter.IsMountPoint(testCtx.targetPath) + assert.NoError(t, err) + assert.Equals(t, true, ok) + + // Second Pod + testCtx.podUID = uuid.New().String() + targetPath2 := filepath.Join( + testCtx.kubeletPath, + fmt.Sprintf("pods/%s/volumes/kubernetes.io~csi/%s/mount", testCtx.podUID, testCtx.pvName), + ) + err = os.MkdirAll(filepath.Dir(targetPath2), 0750) + assert.NoError(t, err) + parentDir, err := filepath.EvalSymlinks(filepath.Dir(targetPath2)) + assert.NoError(t, err) + targetPath2 = filepath.Join(parentDir, filepath.Base(targetPath2)) + testCtx.targetPath = targetPath2 + testCrd2 := crdv1beta.MountpointS3PodAttachment{ + Spec: crdv1beta.MountpointS3PodAttachmentSpec{ + NodeName: testCtx.nodeName, + PersistentVolumeName: testCtx.pvName, + VolumeID: testCtx.volumeID, + WorkloadFSGroup: testCtx.fsGroup, + MountOptions: testCtx.pvMountOptions, + MountpointS3PodAttachments: map[string][]crdv1beta.WorkloadAttachment{ + testCtx.mpPodName: []crdv1beta.WorkloadAttachment{{WorkloadPodUID: testCtx.podUID}}, + }, + }, + } + testCtx.s3paCache.TestItems = []crdv1beta.MountpointS3PodAttachment{testCrd2} + + err = testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ + VolumeID: testCtx.volumeID, + PodID: testCtx.podUID, + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) + assert.NoError(t, err) + + ok, err = testCtx.podMounter.IsMountPoint(testCtx.targetPath) + assert.NoError(t, err) + assert.Equals(t, true, ok) + + assert.Equals(t, int32(1), mountCount.Load()) + assert.Equals(t, int32(2), bindMountCount.Load()) }) - t.Run("Unmounts target if Mountpoint Pod does not receive mount options", func(t *testing.T) { + t.Run("Unmounts source if Mountpoint Pod does not receive mount options", func(t *testing.T) { testCtx := setup(t) go func() { @@ -278,19 +434,24 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) if err == nil { t.Errorf("mount shouldn't succeeded if Mountpoint does not receive the mount options") } - ok, err := testCtx.mount.IsMountPoint(testCtx.targetPath) + ok, err := testCtx.mount.IsMountPoint(testCtx.sourcePath) + assert.NoError(t, err) + if ok { + t.Errorf("it should unmount the source path if Mountpoint does not receive the mount options") + } + ok, err = testCtx.mount.IsMountPoint(testCtx.targetPath) assert.NoError(t, err) if ok { - t.Errorf("it should unmount the target path if Mountpoint does not receive the mount options") + t.Errorf("it should not bind mount the target path if Mountpoint does not receive the mount options") } }) - t.Run("Unmounts target if Mountpoint Pod fails to start", func(t *testing.T) { + t.Run("Unmounts source if Mountpoint Pod fails to start", func(t *testing.T) { testCtx := setup(t) testCtx.mountSyscall = func(target string, args mountpoint.Args) (fd int, err error) { @@ -312,15 +473,20 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) if err == nil { t.Errorf("mount shouldn't succeeded if Mountpoint fails to start") } - ok, err := testCtx.mount.IsMountPoint(testCtx.targetPath) + ok, err := testCtx.mount.IsMountPoint(testCtx.sourcePath) assert.NoError(t, err) if ok { - t.Errorf("it should unmount the target path if Mountpoint fails to start") + t.Errorf("it should unmount the source path if Mountpoint fails to start") + } + ok, err = testCtx.mount.IsMountPoint(testCtx.targetPath) + assert.NoError(t, err) + if ok { + t.Errorf("it should not bind mount the target path if Mountpoint fails to start") } }) @@ -347,7 +513,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) if err == nil { t.Errorf("mount shouldn't succeeded if Mountpoint fails to start") } @@ -380,9 +546,12 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) + ok, err = testCtx.podMounter.IsMountPoint(testCtx.sourcePath) + assert.NoError(t, err) + assert.Equals(t, true, ok) ok, err = testCtx.podMounter.IsMountPoint(testCtx.targetPath) assert.NoError(t, err) assert.Equals(t, true, ok) @@ -400,7 +569,7 @@ func TestPodMounter(t *testing.T) { err := testCtx.podMounter.Mount(testCtx.ctx, testCtx.bucketName, testCtx.targetPath, credentialprovider.ProvideContext{ VolumeID: testCtx.volumeID, PodID: testCtx.podUID, - }, mountpoint.ParseArgs(nil)) + }, mountpoint.ParseArgs(nil), testCtx.fsGroup, testCtx.pvMountOptions) assert.NoError(t, err) ok, err := testCtx.podMounter.IsMountPoint(testCtx.targetPath) @@ -431,8 +600,8 @@ func createMountpointPod(testCtx *testCtx) *mountpointPod { pod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(uuid.New().String()), - Name: mppod.MountpointPodNameFor(testCtx.podUID, testCtx.pvName), + UID: types.UID(testCtx.mpPodUID), + Name: testCtx.mpPodName, }, } pod, err := testCtx.client.CoreV1().Pods(mountpointPodNamespace).Create(context.TODO(), pod, metav1.CreateOptions{}) diff --git a/pkg/driver/node/mounter/pod_unmounter.go b/pkg/driver/node/mounter/pod_unmounter.go new file mode 100644 index 00000000..eecb3c1f --- /dev/null +++ b/pkg/driver/node/mounter/pod_unmounter.go @@ -0,0 +1,211 @@ +package mounter + +import ( + "os" + "path/filepath" + "sync" + "time" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + mpmounter "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" + "github.com/awslabs/aws-s3-csi-driver/pkg/util" + corev1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" +) + +const cleanupInterval = 2 * time.Minute + +// PodUnmounter handles unmounting of Mountpoint Pods and cleanup of associated resources +type PodUnmounter struct { + nodeID string + mount *mpmounter.Mounter + kubeletPath string + sourceMountDir string + podWatcher *watcher.Watcher + credProvider *credentialprovider.Provider + mutex sync.Mutex +} + +// NewPodUnmounter creates a new PodUnmounter instance with the given parameters +func NewPodUnmounter( + nodeID string, + mount *mpmounter.Mounter, + podWatcher *watcher.Watcher, + credProvider *credentialprovider.Provider, + sourceMountDir string, +) *PodUnmounter { + return &PodUnmounter{ + nodeID: nodeID, + mount: mount, + kubeletPath: util.KubeletPath(), + sourceMountDir: sourceMountDir, + podWatcher: podWatcher, + credProvider: credProvider, + } +} + +// HandleMountpointPodUpdate is a Pod Update handler that triggers unmounting +// if the Mountpoint Pod is marked for unmounting via annotations +func (u *PodUnmounter) HandleMountpointPodUpdate(old, new any) { + mpPod := new.(*corev1.Pod) + if mpPod.Spec.NodeName != u.nodeID { + return + } + + if value, exists := mpPod.Annotations[mppod.AnnotationNeedsUnmount]; exists && value == "true" { + u.unmountSourceForPod(mpPod) + } +} + +// unmountSourceForPod performs the unmounting process for a specific Mountpoint Pod +// including cleanup of associated resources +// mpPod: The Mountpoint Pod to unmount +func (u *PodUnmounter) unmountSourceForPod(mpPod *corev1.Pod) { + mpPodUID := string(mpPod.UID) + unlockMountpointPod := lockMountpointPod(mpPodUID) + defer unlockMountpointPod() + + klog.Infof("Found Mountpoint Pod %s (UID: %s) with %s annotation, unmounting it", mpPod.Name, mpPodUID, mppod.AnnotationNeedsUnmount) + + podPath := filepath.Join(u.kubeletPath, "pods", mpPodUID) + source := filepath.Join(u.sourceMountDir, mpPodUID) + volumeId := mpPod.Labels[mppod.LabelVolumeId] + + if err := u.writeExitFile(podPath); err != nil { + return + } + + if err := u.unmountAndCleanup(source); err != nil { + return + } + klog.Infof("Successfully unmounted Mountpoint Pod - %s", mpPod.Name) + + if err := u.cleanupCredentials(volumeId, mpPodUID, podPath, source, mpPod); err != nil { + return + } +} + +// writeExitFile creates an exit file in the pod's directory to signal Mountpoint Pod termination +// podPath: Path to the pod's directory +// Returns error if file creation fails +func (u *PodUnmounter) writeExitFile(podPath string) error { + podMountExitPath := mppod.PathOnHost(podPath, mppod.KnownPathMountExit) + _, err := os.OpenFile(podMountExitPath, os.O_RDONLY|os.O_CREATE, credentialprovider.CredentialFilePerm) + if err != nil { + klog.Errorf("Failed to send a exit message to Mountpoint Pod: %s", err) + return err + } + return nil +} + +// unmountAndCleanup unmounts the source directory and removes it +// source: Path to the source directory to unmount +// Returns error if unmounting or cleanup fails +func (u *PodUnmounter) unmountAndCleanup(source string) error { + if err := u.mount.Unmount(source); err != nil { + klog.Errorf("Failed to unmount source %q: %v", source, err) + return err + } + + if err := os.Remove(source); err != nil { + klog.Errorf("Failed to remove source directory %q: %v", source, err) + return err + } + return nil +} + +// cleanupCredentials removes credentials associated with the Mountpoint Pod +func (u *PodUnmounter) cleanupCredentials(volumeId, mpPodUID, podPath, source string, mpPod *corev1.Pod) error { + err := u.credProvider.Cleanup(credentialprovider.CleanupContext{ + VolumeID: volumeId, + PodID: mpPodUID, + WritePath: filepath.Join(u.kubeletPath, "pods", mpPodUID), + }) + if err != nil { + klog.Errorf("Failed to clean up credentials for %s: %v", source, err) + return err + } + return nil +} + +// StartPeriodicCleanup begins periodic cleanup of dangling mounts +// This is needed in case when `HandleMountpointPodUpdate()` missed an update event to trigger cleanup. +// stopCh: Channel to signal stopping of the cleanup routine +func (u *PodUnmounter) StartPeriodicCleanup(stopCh <-chan struct{}) { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-stopCh: + return + case <-ticker.C: + if err := u.CleanupDanglingMounts(); err != nil { + klog.Errorf("Failed to run clean up of dangling mounts: %v", err) + } + } + } +} + +// CleanupDanglingMounts scans the source mount directory for potential dangling mounts +// and cleans them up. It also unmounts any Mountpoint Pods marked for unmounting. +func (u *PodUnmounter) CleanupDanglingMounts() error { + // Ensure only one cleanup runs at a time + if !u.mutex.TryLock() { + return nil + } + defer u.mutex.Unlock() + + entries, err := os.ReadDir(u.sourceMountDir) + if err != nil { + klog.Errorf("Failed to read source mount directory (`%s`): %v", u.sourceMountDir, err) + return err + } + + for _, file := range entries { + if !file.IsDir() { + continue + } + + mpPodUID := file.Name() + source := filepath.Join(u.sourceMountDir, mpPodUID) + // Try to find corresponding pod + mpPod, err := u.findPodByUID(mpPodUID) + if err != nil { + klog.Errorf("Failed to check Mountpoint Pod (UID: %s) existence: %v", mpPodUID, err) + return err + } + + if mpPod == nil { + klog.V(4).Infof("Mountpoint Pod not found for UID %s, will unmount and delete folder", mpPodUID) + if err := u.unmountAndCleanup(source); err != nil { + klog.Errorf("Failed to cleanup dangling mount for UID %s: %v", mpPodUID, err) + } + continue + } + + // Unmount only if Mountpoint Pod is marked for unmounting + if value, exists := mpPod.Annotations[mppod.AnnotationNeedsUnmount]; exists && value == "true" { + u.unmountSourceForPod(mpPod) + } + } + + return nil +} + +// findPodByUID finds Mountpoint Pod by UID in podWatcher's cache +func (u *PodUnmounter) findPodByUID(mpPodUID string) (*corev1.Pod, error) { + pods, err := u.podWatcher.List() + if err != nil { + return nil, err + } + + for _, pod := range pods { + if string(pod.UID) == mpPodUID { + return pod, nil + } + } + return nil, nil +} diff --git a/pkg/driver/node/mounter/pod_unmounter_test.go b/pkg/driver/node/mounter/pod_unmounter_test.go new file mode 100644 index 00000000..17a60204 --- /dev/null +++ b/pkg/driver/node/mounter/pod_unmounter_test.go @@ -0,0 +1,273 @@ +package mounter_test + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + mpmounter "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" + "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod/watcher" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/mount-utils" +) + +const ( + nodeName = "test-node" +) + +func setupPodWatcher(t *testing.T, pods ...*corev1.Pod) (*watcher.Watcher, *fake.Clientset) { + client := fake.NewClientset() + podWatcher := watcher.New(client, mountpointPodNamespace, nodeName, 10*time.Second) + stopCh := make(chan struct{}) + t.Cleanup(func() { + close(stopCh) + }) + + for _, pod := range pods { + if pod != nil { + _, err := client.CoreV1().Pods(mountpointPodNamespace).Create(context.Background(), pod, metav1.CreateOptions{}) + assert.NoError(t, err) + } + } + + err := podWatcher.Start(stopCh) + assert.NoError(t, err) + + return podWatcher, client +} + +func countUnmountCalls(mounter *mount.FakeMounter) int { + unmountCalls := 0 + for _, action := range mounter.GetLog() { + if action.Action == mount.FakeActionUnmount { + unmountCalls++ + } + } + return unmountCalls +} + +func TestHandleS3PodAttachmentUpdate(t *testing.T) { + tests := []struct { + name string + nodeID string + pod *corev1.Pod + unmountError error + expectUnmount bool + }{ + { + name: "different node", + nodeID: "node1", + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", + }, + }, + Spec: corev1.PodSpec{ + NodeName: "different-node", + }, + }, + expectUnmount: false, + }, + { + name: "same node with unmount annotation", + nodeID: nodeName, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", + }, + Labels: map[string]string{ + mppod.LabelVolumeId: "vol1", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + }, + expectUnmount: true, + }, + { + name: "same node without unmount annotation", + nodeID: nodeName, + pod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + }, + expectUnmount: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + kubeletPath := t.TempDir() + t.Setenv("KUBELET_PATH", kubeletPath) + t.Chdir(kubeletPath) + + sourceMountDir := t.TempDir() + + podWatcher, client := setupPodWatcher(t, tt.pod) + + if tt.pod != nil { + podPath := filepath.Join(kubeletPath, "pods", string(tt.pod.UID)) + commDir := mppod.PathOnHost(podPath) + err := os.MkdirAll(commDir, 0750) + assert.NoError(t, err) + + err = os.MkdirAll(filepath.Join(sourceMountDir, string(tt.pod.UID)), 0750) + assert.NoError(t, err) + } + + fakeMounter := mount.NewFakeMounter(nil) + if tt.unmountError != nil { + fakeMounter.UnmountFunc = func(path string) error { + return tt.unmountError + } + } + + credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { + return dummyIMDSRegion, nil + }) + + unmounter := mounter.NewPodUnmounter(tt.nodeID, mpmounter.NewWithMount(fakeMounter), podWatcher, credProvider, sourceMountDir) + unmounter.HandleMountpointPodUpdate(nil, tt.pod) + + unmountCalls := countUnmountCalls(fakeMounter) + expectedUnmounts := 0 + if tt.expectUnmount { + expectedUnmounts = 1 + } + assert.Equals(t, expectedUnmounts, unmountCalls) + }) + } +} + +func TestCleanupDanglingMounts(t *testing.T) { + tests := []struct { + name string + pods []*corev1.Pod + unmountError error + expectedCalls int + }{ + { + name: "no dangling mounts", + pods: []*corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + }, + }, + expectedCalls: 0, + }, + { + name: "pod marked for unmount", + pods: []*corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", + }, + Labels: map[string]string{ + mppod.LabelVolumeId: "vol1", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + }, + }, + expectedCalls: 1, + }, + { + name: "with dangling mount and unmount error", + pods: []*corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: mountpointPodNamespace, + UID: "uid1", + Annotations: map[string]string{ + mppod.AnnotationNeedsUnmount: "true", + }, + Labels: map[string]string{ + mppod.LabelVolumeId: "vol1", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + }, + }, + unmountError: errors.New("unmount error"), + expectedCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + kubeletPath := t.TempDir() + t.Setenv("KUBELET_PATH", kubeletPath) + t.Chdir(kubeletPath) + sourceMountDir := t.TempDir() + + for _, pod := range tt.pods { + podPath := filepath.Join(kubeletPath, "pods", string(pod.UID)) + commDir := mppod.PathOnHost(podPath) + err := os.MkdirAll(commDir, 0750) + assert.NoError(t, err) + + err = os.MkdirAll(filepath.Join(sourceMountDir, string(pod.UID)), 0750) + assert.NoError(t, err) + } + + fakeMounter := mount.NewFakeMounter(nil) + if tt.unmountError != nil { + fakeMounter.UnmountFunc = func(path string) error { + return tt.unmountError + } + } + + podWatcher, client := setupPodWatcher(t, tt.pods...) + credProvider := credentialprovider.New(client.CoreV1(), func() (string, error) { + return dummyIMDSRegion, nil + }) + + unmounter := mounter.NewPodUnmounter(nodeName, mpmounter.NewWithMount(fakeMounter), podWatcher, credProvider, sourceMountDir) + err := unmounter.CleanupDanglingMounts() + assert.NoError(t, err) + + unmountCalls := countUnmountCalls(fakeMounter) + assert.Equals(t, tt.expectedCalls, unmountCalls) + }) + } +} diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index 470533ae..074d6697 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -58,7 +58,7 @@ func (m *SystemdMounter) IsMountPoint(target string) (bool, error) { // // This method will create the target path if it does not exist and if there is an existing corrupt // mount, it will attempt an unmount before attempting the mount. -func (m *SystemdMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (m *SystemdMounter) Mount(ctx context.Context, bucketName string, target string, credentialCtx credentialprovider.ProvideContext, args mountpoint.Args, _, _ string) error { if bucketName == "" { return fmt.Errorf("bucket name is empty") } diff --git a/pkg/driver/node/mounter/systemd_mounter_test.go b/pkg/driver/node/mounter/systemd_mounter_test.go index 771c0047..f55d6b1c 100644 --- a/pkg/driver/node/mounter/systemd_mounter_test.go +++ b/pkg/driver/node/mounter/systemd_mounter_test.go @@ -154,7 +154,7 @@ func TestS3MounterMount(t *testing.T) { testCase.before(t, env) } err := env.mounter.Mount(env.ctx, testCase.bucketName, testCase.targetPath, - testCase.provideCtx, mountpoint.ParseArgs(testCase.options)) + testCase.provideCtx, mountpoint.ParseArgs(testCase.options), "", "") env.mockCtl.Finish() if err != nil && !testCase.expectedErr { t.Fatal(err) diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index 121eeaad..b55df50e 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -124,8 +124,11 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl args := mountpoint.ParseArgs(mountpointArgs) + fsGroup := "" + pvMountOptions := "" if capMount := volCap.GetMount(); capMount != nil && util.UsePodMounter() { if volumeMountGroup := capMount.GetVolumeMountGroup(); volumeMountGroup != "" { + fsGroup = volumeMountGroup // We need to add the following flags to support fsGroup // If these flags were already set by customer in PV mountOptions then we won't override them args.SetIfAbsent(mountpoint.ArgGid, volumeMountGroup) @@ -133,6 +136,10 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl args.SetIfAbsent(mountpoint.ArgDirMode, filePerm770) args.SetIfAbsent(mountpoint.ArgFileMode, filePerm660) } + + if mountFlags := capMount.GetMountFlags(); mountFlags != nil { + pvMountOptions = strings.Join(mountFlags, ",") + } } if util.UsePodMounter() && !args.Has(mountpoint.ArgAllowOther) { @@ -144,7 +151,7 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl credentialCtx := credentialProvideContextFromPublishRequest(req, args) - if err := ns.Mounter.Mount(ctx, bucket, target, credentialCtx, args); err != nil { + if err := ns.Mounter.Mount(ctx, bucket, target, credentialCtx, args, fsGroup, pvMountOptions); err != nil { os.Remove(target) return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } @@ -257,12 +264,17 @@ func credentialProvideContextFromPublishRequest(req *csi.NodePublishVolumeReques podID, _ = podIDFromTargetPath(req.GetTargetPath()) } + authSource := credentialprovider.AuthenticationSourceDriver + if volumeCtx[volumecontext.AuthenticationSource] != credentialprovider.AuthenticationSourceUnspecified { + authSource = volumeCtx[volumecontext.AuthenticationSource] + } + bucketRegion, _ := args.Value(mountpoint.ArgRegion) return credentialprovider.ProvideContext{ PodID: podID, VolumeID: req.GetVolumeId(), - AuthenticationSource: volumeCtx[volumecontext.AuthenticationSource], + AuthenticationSource: authSource, PodNamespace: volumeCtx[volumecontext.CSIPodNamespace], ServiceAccountTokens: volumeCtx[volumecontext.CSIServiceAccountTokens], ServiceAccountName: volumeCtx[volumecontext.CSIServiceAccountName], diff --git a/pkg/driver/node/node_test.go b/pkg/driver/node/node_test.go index ad95fb38..52bee130 100644 --- a/pkg/driver/node/node_test.go +++ b/pkg/driver/node/node_test.go @@ -3,6 +3,7 @@ package node_test import ( "errors" "io/fs" + "strings" "testing" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -69,9 +70,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Any()) + gomock.Any(), + gomock.Eq(""), + gomock.Eq(""), + ) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -104,9 +109,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"}))) + gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"})), + gomock.Eq(""), + gomock.Eq(""), + ) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -142,9 +151,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"}))) + gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"})), + gomock.Eq(""), + gomock.Eq(""), + ) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -180,9 +193,13 @@ func TestNodePublishVolume(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--read-only", "--test=123"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--read-only", "--test=123"})), + gomock.Eq(""), + gomock.Eq(""), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -254,9 +271,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"})), + gomock.Eq("123"), + gomock.Eq(""), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -292,9 +313,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--gid=123", "--allow-other", "--dir-mode=770", "--file-mode=660"})), + gomock.Eq("123"), + gomock.Eq("--allow-other"), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -330,9 +355,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--allow-root"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--allow-root"})), + gomock.Eq(""), + gomock.Eq(""), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -368,9 +397,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs([]string{"--allow-other"}))).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--allow-other"})), + gomock.Eq(""), + gomock.Eq("--allow-other"), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -407,9 +440,13 @@ func TestNodePublishVolumeForPodMounter(t *testing.T) { gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq(credentialprovider.ProvideContext{ - VolumeID: volumeId, + VolumeID: volumeId, + AuthenticationSource: credentialprovider.AuthenticationSourceDriver, }), - gomock.Eq(mountpoint.ParseArgs(mountFlags))).Return(nil) + gomock.Eq(mountpoint.ParseArgs(mountFlags)), + gomock.Eq("123"), + gomock.Eq(strings.Join(mountFlags, ",")), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -572,7 +609,7 @@ var _ mounter.Mounter = &dummyMounter{} type dummyMounter struct{} -func (d *dummyMounter) Mount(ctx context.Context, bucketName string, target string, provideCtx credentialprovider.ProvideContext, args mountpoint.Args) error { +func (d *dummyMounter) Mount(ctx context.Context, bucketName string, target string, provideCtx credentialprovider.ProvideContext, args mountpoint.Args, fsGroup, pvMountOptions string) error { return nil } diff --git a/pkg/mountpoint/mounter/mount.go b/pkg/mountpoint/mounter/mount.go index de14657b..c91dcb84 100644 --- a/pkg/mountpoint/mounter/mount.go +++ b/pkg/mountpoint/mounter/mount.go @@ -4,6 +4,9 @@ package mounter import ( "errors" "fmt" + "os" + "strings" + "syscall" "k8s.io/klog/v2" mountutils "k8s.io/mount-utils" @@ -54,6 +57,14 @@ func (m *Mounter) Mount(target Target, opts MountOptions) (int, error) { return mount(target, opts) } +// BindMount performs a bind mount syscall from `source` to `target`. +func (m *Mounter) BindMount(source, target Target) error { + if target == "" { + return ErrMissingTarget + } + return bindMount(source, target) +} + // Unmount unmounts Mountpoint at `target`. // // This requires `CAP_SYS_ADMIN` capability in the target namespace. @@ -105,3 +116,63 @@ func (m *Mounter) CheckMountPoint(target Target) (bool, error) { func (m *Mounter) IsMountPointCorrupted(err error) bool { return mountutils.IsCorruptedMnt(err) } + +// findSourceMountPoint locates the source S3 mount point for a given target path by comparing +// device IDs and inodes with all S3 mount points at driver source directory `sourceMountDir`. +// +// Parameters: +// - target: The target path whose source mount point needs to be found +// - sourceMountDir: directory where to find source mount points +// +// Returns: +// - string: The path of the source mount point if found +// - error: An error if the operation fails +// +// The function works by: +// 1. Getting the device ID and inode of the target path +// 2. Listing all mount points in the system that has "mountpoint-s3" as device name and prefix `sourceMountDir` +// 3. Finding a mount point that matches both the device ID and inode of the target +func (m *Mounter) FindSourceMountPoint(target, sourceMountDir string) (string, error) { + targetFileInfo, err := os.Stat(target) + if err != nil { + return "", fmt.Errorf("failed to stat %q: %w", target, err) + } + + targetSysInfo, ok := targetFileInfo.Sys().(*syscall.Stat_t) + if !ok { + return "", fmt.Errorf("failed to get system info for target %q", target) + } + + targetDevID := targetSysInfo.Dev + targetInodeID := targetSysInfo.Ino + + mountPoints, err := m.mount.List() + if err != nil { + return "", fmt.Errorf("failed to list mount points: %w", err) + } + + for _, mountPoint := range mountPoints { + if mountPoint.Device != fsName || !strings.HasPrefix(mountPoint.Path, sourceMountDir) { + continue + } + + mountPathInfo, err := os.Stat(mountPoint.Path) + if err != nil { + klog.V(4).Infof("Skipping mount point %q: unable to stat %v", mountPoint.Path, err) + continue + } + + mountSysInfo, ok := mountPathInfo.Sys().(*syscall.Stat_t) + if !ok { + klog.V(4).Infof("Skipping mount point %q: unable to get system info", mountPoint.Path) + continue + } + + if targetDevID == mountSysInfo.Dev && targetInodeID == mountSysInfo.Ino { + return mountPoint.Path, nil + } + } + + return "", fmt.Errorf("no source mount point found for path %q (device: %d, inode: %d)", + target, targetDevID, targetInodeID) +} diff --git a/pkg/mountpoint/mounter/mount_darwin.go b/pkg/mountpoint/mounter/mount_darwin.go index e2a6be0b..3e5ea686 100644 --- a/pkg/mountpoint/mounter/mount_darwin.go +++ b/pkg/mountpoint/mounter/mount_darwin.go @@ -9,6 +9,10 @@ func mount(_ Target, _ MountOptions) (int, error) { return 0, errors.New("Only supported on Linux") } +func bindMount(source, target string) error { + return errors.New("Only supported on Linux") +} + func statx(path string) error { // statx is a Linux-specific syscall, let's simulate with os.Stat _, err := os.Stat(path) diff --git a/pkg/mountpoint/mounter/mount_linux.go b/pkg/mountpoint/mounter/mount_linux.go index c42727a2..98470e53 100644 --- a/pkg/mountpoint/mounter/mount_linux.go +++ b/pkg/mountpoint/mounter/mount_linux.go @@ -75,6 +75,14 @@ func mount(target Target, opts MountOptions) (int, error) { return fd, nil } +// bindMount performs a bind mount syscall from `source` to `target`. +func bindMount(source, target string) error { + if err := unix.Mount(source, target, "", unix.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to bind mount from %s to %s: %v", source, target, err) + } + return nil +} + func statx(path string) error { var stat unix.Statx_t if err := unix.Statx(unix.AT_FDCWD, path, unix.AT_STATX_FORCE_SYNC, 0, &stat); err != nil { diff --git a/pkg/podmounter/mppod/creator.go b/pkg/podmounter/mppod/creator.go index a7ed245d..acb50137 100644 --- a/pkg/podmounter/mppod/creator.go +++ b/pkg/podmounter/mppod/creator.go @@ -16,11 +16,15 @@ import ( // Labels populated on spawned Mountpoint Pods. const ( LabelMountpointVersion = "s3.csi.aws.com/mountpoint-version" - LabelPodUID = "s3.csi.aws.com/pod-uid" LabelVolumeName = "s3.csi.aws.com/volume-name" + LabelVolumeId = "s3.csi.aws.com/volume-id" LabelCSIDriverVersion = "s3.csi.aws.com/mounted-by-csi-driver-version" ) +const ( + AnnotationNeedsUnmount = "s3.csi.aws.com/needs-unmount" +) + const CommunicationDirSizeLimit = 10 * 1024 * 1024 // 10MB // A ContainerConfig represents configuration for containers in the spawned Mountpoint Pods. @@ -50,22 +54,16 @@ func NewCreator(config Config) *Creator { return &Creator{config: config} } -// Create returns a new Mountpoint Pod spec to schedule for given `pod` and `pv`. -// -// It automatically assigns Mountpoint Pod to `pod`'s node. -// The name of the Mountpoint Pod is consistently generated from `pod` and `pv` using `MountpointPodNameFor` function. -func (c *Creator) Create(pod *corev1.Pod, pv *corev1.PersistentVolume) (*corev1.Pod, error) { - node := pod.Spec.NodeName - name := MountpointPodNameFor(string(pod.UID), pv.Name) - +// Create returns a new Mountpoint Pod spec to schedule for given `node` and `pv`. +func (c *Creator) Create(node string, pv *corev1.PersistentVolume) (*corev1.Pod, error) { mpPod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: c.config.Namespace, + GenerateName: "mp-", + Namespace: c.config.Namespace, Labels: map[string]string{ LabelMountpointVersion: c.config.MountpointVersion, - LabelPodUID: string(pod.UID), LabelVolumeName: pv.Name, + LabelVolumeId: pv.Spec.CSI.VolumeHandle, LabelCSIDriverVersion: c.config.CSIDriverVersion, }, }, @@ -140,7 +138,7 @@ func (c *Creator) Create(pod *corev1.Pod, pv *corev1.PersistentVolume) (*corev1. }, } - volumeAttributes := extractVolumeAttributes(pv) + volumeAttributes := ExtractVolumeAttributes(pv) if saName := volumeAttributes[volumecontext.MountpointPodServiceAccountName]; saName != "" { mpPod.Spec.ServiceAccountName = saName @@ -201,9 +199,9 @@ func (c *Creator) Create(pod *corev1.Pod, pv *corev1.PersistentVolume) (*corev1. return mpPod, nil } -// extractVolumeAttributes extracts volume attributes from given `pv`. +// ExtractVolumeAttributes extracts volume attributes from given `pv`. // It always returns a non-nil map, and it's safe to use even though `pv` doesn't contain any volume attributes. -func extractVolumeAttributes(pv *corev1.PersistentVolume) map[string]string { +func ExtractVolumeAttributes(pv *corev1.PersistentVolume) map[string]string { csiSpec := pv.Spec.CSI if csiSpec == nil { return map[string]string{} diff --git a/pkg/podmounter/mppod/creator_test.go b/pkg/podmounter/mppod/creator_test.go index 1a5f0ee7..57bc03fa 100644 --- a/pkg/podmounter/mppod/creator_test.go +++ b/pkg/podmounter/mppod/creator_test.go @@ -7,7 +7,6 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" "github.com/awslabs/aws-s3-csi-driver/pkg/cluster" @@ -25,6 +24,7 @@ const ( testNode = "test-node" testPodUID = "test-pod-uid" testVolName = "test-vol" + testVolID = "test-vol-id" csiDriverVersion = "1.12.0" ) @@ -47,14 +47,14 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu creator := mppod.NewCreator(createTestConfig(clusterVariant)) verifyDefaultValues := func(mpPod *corev1.Pod) { - // This is a hash of `testPodUID` + `testVolName` - assert.Equals(t, "mp-8ef7856a0c7f1d5706bd6af93fdc4bc90b33cf2ceb6769b4afd62586", mpPod.Name) + assert.Equals(t, "mp-", mpPod.GenerateName) + assert.Equals(t, "", mpPod.Name) assert.Equals(t, namespace, mpPod.Namespace) assert.Equals(t, map[string]string{ mppod.LabelMountpointVersion: mountpointVersion, - mppod.LabelPodUID: testPodUID, - mppod.LabelVolumeName: testVolName, mppod.LabelCSIDriverVersion: csiDriverVersion, + mppod.LabelVolumeName: testVolName, + mppod.LabelVolumeId: testVolID, }, mpPod.Labels) assert.Equals(t, priorityClassName, mpPod.Spec.PriorityClassName) @@ -111,17 +111,17 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu } t.Run("Empty PV", func(t *testing.T) { - mpPod, err := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + mpPod, err := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, + Spec: corev1.PersistentVolumeSpec{ + PersistentVolumeSource: corev1.PersistentVolumeSource{ + CSI: &corev1.CSIPersistentVolumeSource{ + VolumeHandle: testVolID, + }, + }, + }, }) assert.NoError(t, err) @@ -129,20 +129,14 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu }) t.Run("With ServiceAccountName specified in PV", func(t *testing.T) { - mpPod, err := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + mpPod, err := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, Spec: corev1.PersistentVolumeSpec{ PersistentVolumeSource: corev1.PersistentVolumeSource{ CSI: &corev1.CSIPersistentVolumeSource{ + VolumeHandle: testVolID, VolumeAttributes: map[string]string{ "mountpointPodServiceAccountName": "mount-s3-sa", }, @@ -158,20 +152,14 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu t.Run("With Container Resources specified in PV", func(t *testing.T) { t.Run("With valid requests and limits", func(t *testing.T) { - mpPod, err := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + mpPod, err := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, Spec: corev1.PersistentVolumeSpec{ PersistentVolumeSource: corev1.PersistentVolumeSource{ CSI: &corev1.CSIPersistentVolumeSource{ + VolumeHandle: testVolID, VolumeAttributes: map[string]string{ "mountpointContainerResourcesRequestsCpu": "1", "mountpointContainerResourcesRequestsMemory": "100Mi", @@ -197,20 +185,14 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu }) t.Run("With valid requests only", func(t *testing.T) { - mpPod, err := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + mpPod, err := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, Spec: corev1.PersistentVolumeSpec{ PersistentVolumeSource: corev1.PersistentVolumeSource{ CSI: &corev1.CSIPersistentVolumeSource{ + VolumeHandle: testVolID, VolumeAttributes: map[string]string{ "mountpointContainerResourcesRequestsCpu": "1", "mountpointContainerResourcesRequestsMemory": "100Mi", @@ -232,20 +214,14 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu }) t.Run("With valid limits only", func(t *testing.T) { - mpPod, err := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + mpPod, err := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, Spec: corev1.PersistentVolumeSpec{ PersistentVolumeSource: corev1.PersistentVolumeSource{ CSI: &corev1.CSIPersistentVolumeSource{ + VolumeHandle: testVolID, VolumeAttributes: map[string]string{ "mountpointContainerResourcesLimitsCpu": "2", "mountpointContainerResourcesLimitsMemory": "200Mi", @@ -294,14 +270,7 @@ func createAndVerifyPod(t *testing.T, clusterVariant cluster.Variant, expectedRu }, } { t.Run(name, func(t *testing.T) { - _, err := creator.Create(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - UID: types.UID(testPodUID), - }, - Spec: corev1.PodSpec{ - NodeName: testNode, - }, - }, &corev1.PersistentVolume{ + _, err := creator.Create(testNode, &corev1.PersistentVolume{ ObjectMeta: metav1.ObjectMeta{ Name: testVolName, }, diff --git a/pkg/podmounter/mppod/mppod.go b/pkg/podmounter/mppod/mppod.go deleted file mode 100644 index 05d3b4d2..00000000 --- a/pkg/podmounter/mppod/mppod.go +++ /dev/null @@ -1,17 +0,0 @@ -// Package mppod provides utilities for creating and accessing Mountpoint Pods. -package mppod - -import ( - "crypto/sha256" - "fmt" -) - -// MountpointPodNameFor returns a consistent and unique Pod name for -// Mountpoint Pod for given `podUID` and `volumeName`. -// -// Changing output of this function might cause duplicate Mountpoint Pods to be spawned, -// ideally multiple implementation of this function shouldn't co-exists in the same cluster -// unless there is a clean install of the CSI Driver. -func MountpointPodNameFor(podUID string, volumeName string) string { - return fmt.Sprintf("mp-%x", sha256.Sum224(fmt.Appendf(nil, "%s%s", podUID, volumeName))) -} diff --git a/pkg/podmounter/mppod/mppod_test.go b/pkg/podmounter/mppod/mppod_test.go deleted file mode 100644 index f24a5d7b..00000000 --- a/pkg/podmounter/mppod/mppod_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package mppod_test - -import ( - "testing" - - "github.com/google/uuid" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" - - "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" - "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" -) - -func TestGeneratingMountpointPodName(t *testing.T) { - t.Run("Consistency", func(t *testing.T) { - pod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{UID: types.UID(uuid.New().String())}, - } - pvc := &corev1.PersistentVolumeClaim{ - Spec: corev1.PersistentVolumeClaimSpec{VolumeName: "test-vol"}, - } - - // Ensure `MountpointPodNameFor` returns a consistent output for the same Pod and PVC. - assert.Equals(t, - mppod.MountpointPodNameFor(string(pod.UID), pvc.Spec.VolumeName), - mppod.MountpointPodNameFor(string(pod.UID), pvc.Spec.VolumeName)) - }) - - t.Run("Uniqueness", func(t *testing.T) { - pod1 := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{UID: types.UID(uuid.New().String())}, - } - pvc1 := &corev1.PersistentVolumeClaim{ - Spec: corev1.PersistentVolumeClaimSpec{VolumeName: "test-vol-1"}, - } - pod2 := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{UID: types.UID(uuid.New().String())}, - } - pvc2 := &corev1.PersistentVolumeClaim{ - Spec: corev1.PersistentVolumeClaimSpec{VolumeName: "test-vol-2"}, - } - - if mppod.MountpointPodNameFor(string(pod1.UID), pvc1.Spec.VolumeName) == mppod.MountpointPodNameFor(string(pod1.UID), pvc2.Spec.VolumeName) { - t.Error("Different PVCs with same Pod should return a different Mountpoint Pod name") - } - if mppod.MountpointPodNameFor(string(pod1.UID), pvc1.Spec.VolumeName) == mppod.MountpointPodNameFor(string(pod2.UID), pvc1.Spec.VolumeName) { - t.Error("Different Pods with same PVC should return a different Mountpoint Pod name") - } - }) - - t.Run("Snapshot", func(t *testing.T) { - mountpointPodName := mppod.MountpointPodNameFor("a4509011-bd2a-4f37-b1b0-05d715087852", "test-vol") - assert.Equals(t, "mp-55f7d2331f3149f00d62d7af839d4cee895e1c68a2f0d96ffd359f79", mountpointPodName) - }) -} diff --git a/pkg/podmounter/mppod/watcher/watcher.go b/pkg/podmounter/mppod/watcher/watcher.go index c33f5692..2fc3ccf6 100644 --- a/pkg/podmounter/mppod/watcher/watcher.go +++ b/pkg/podmounter/mppod/watcher/watcher.go @@ -10,6 +10,9 @@ import ( corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" listerv1 "k8s.io/client-go/listers/core/v1" @@ -33,8 +36,15 @@ type Watcher struct { } // New creates a new [Watcher] with the given Kubernetes client, Mountpoint Pod namespace, and resync duration. -func New(client kubernetes.Interface, namespace string, defaultResync time.Duration) *Watcher { - factory := informers.NewSharedInformerFactoryWithOptions(client, defaultResync, informers.WithNamespace(namespace)) +func New(client kubernetes.Interface, namespace, nodeName string, defaultResync time.Duration) *Watcher { + factory := informers.NewSharedInformerFactoryWithOptions( + client, + defaultResync, + informers.WithNamespace(namespace), + informers.WithTweakListOptions(func(options *metav1.ListOptions) { + options.FieldSelector = fields.OneTermEqualSelector("spec.nodeName", nodeName).String() + }), + ) informer := factory.Core().V1().Pods().Informer() lister := factory.Core().V1().Pods().Lister().Pods(namespace) return &Watcher{informer, lister} @@ -51,6 +61,21 @@ func (w *Watcher) Start(stopCh <-chan struct{}) error { return nil } +// Get returns pod from watcher's cache. +func (w *Watcher) Get(name string) (*corev1.Pod, error) { + return w.lister.Get(name) +} + +// List returns all pods from watcher's cache. +func (w *Watcher) List() ([]*corev1.Pod, error) { + return w.lister.List(labels.Everything()) +} + +// AddEventHandler adds pod event handler. +func (w *Watcher) AddEventHandler(handler cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + return w.informer.AddEventHandler(handler) +} + // Wait blocks until the specified Mountpoint Pod is found and ready, or until the context is cancelled. func (w *Watcher) Wait(ctx context.Context, name string) (*corev1.Pod, error) { // Set a watcher for Pod create & update events diff --git a/pkg/podmounter/mppod/watcher/watcher_test.go b/pkg/podmounter/mppod/watcher/watcher_test.go index 9055c3f9..716a85be 100644 --- a/pkg/podmounter/mppod/watcher/watcher_test.go +++ b/pkg/podmounter/mppod/watcher/watcher_test.go @@ -108,7 +108,7 @@ func TestGettingPodsConcurrently(t *testing.T) { } func createAndStartWatcher(t *testing.T, client kubernetes.Interface) *watcher.Watcher { - mpPodWatcher := watcher.New(client, testMountpointPodNamespace, 10*time.Second) + mpPodWatcher := watcher.New(client, testMountpointPodNamespace, "test-node", 10*time.Second) stopCh := make(chan struct{}) t.Cleanup(func() { diff --git a/tests/controller/controller_test.go b/tests/controller/controller_test.go index 8fa6dd37..8699e43f 100644 --- a/tests/controller/controller_test.go +++ b/tests/controller/controller_test.go @@ -1,8 +1,15 @@ package controller_test import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/google/uuid" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + . "github.com/onsi/gomega/gstruct" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" @@ -10,12 +17,21 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "github.com/awslabs/aws-s3-csi-driver/cmd/aws-s3-csi-controller/csicontroller" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" ) var _ = Describe("Mountpoint Controller", func() { + var testNode string + + BeforeEach(func() { + testNode = generateRandomNodeName() + }) + Context("Static Provisioning", func() { Context("Scheduled Pod with pre-bound PV and PVC", func() { It("should schedule a Mountpoint Pod", func() { @@ -23,9 +39,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -35,10 +51,10 @@ var _ = Describe("Mountpoint Controller", func() { vol2.bind() pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - waitAndVerifyMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { @@ -46,9 +62,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -58,10 +74,10 @@ var _ = Describe("Mountpoint Controller", func() { vol2.bind() pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) @@ -70,13 +86,13 @@ var _ = Describe("Mountpoint Controller", func() { vol := createVolume() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) vol.bind() - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -84,32 +100,32 @@ var _ = Describe("Mountpoint Controller", func() { vol2 := createVolume() pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol1.bind() - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol2.bind() - waitAndVerifyMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { vol := createVolume(withCSIDriver(ebsCSIDriver)) pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) vol.bind() - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -117,20 +133,20 @@ var _ = Describe("Mountpoint Controller", func() { vol2 := createVolume(withCSIDriver(ebsCSIDriver)) pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol2.bind() - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) vol1.bind() - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) @@ -141,11 +157,11 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -156,13 +172,13 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - waitAndVerifyMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { @@ -171,11 +187,11 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -186,13 +202,13 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) @@ -202,15 +218,15 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) vol.bind() - waitAndVerifyMountpointPodFor(pod, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) }) It("should schedule a Mountpoint Pod per PV", func() { @@ -219,18 +235,18 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) vol2.bind() - expectNoMountpointPodFor(pod, vol1) - waitAndVerifyMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol1.pv)) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol2, pod) vol1.bind() - waitAndVerifyMountpointPodFor(pod, vol1) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) }) It("should not schedule a Mountpoint Pod if the volume is backed by a different CSI driver", func() { @@ -238,15 +254,15 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) vol.bind() - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) It("should only schedule Mountpoint Pods for volumes backed by S3 CSI Driver", func() { @@ -255,61 +271,401 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withPVC(vol1.pvc), withPVC(vol2.pvc)) - expectNoMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol1.pv.Name}) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol2.pv.Name}) vol1.bind() vol2.bind() - pod.schedule("test-node") + pod.schedule(testNode) - waitAndVerifyMountpointPodFor(pod, vol1) - expectNoMountpointPodFor(pod, vol2) + waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol1, pod) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol2.pv)) }) }) Context("Multiple Pods using the same PV and PVC", func() { Context("Same Node", func() { Context("Pre-bound PV and PVC", func() { - It("should schedule a Mountpoint Pod per Workload Pod", func() { + It("should schedule single Mountpoint Pod", func() { vol := createVolume() vol.bind() pod1 := createPod(withPVC(vol.pvc)) pod2 := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) + + pod1.schedule(testNode) - pod1.schedule("test-node") + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod1) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + pod2.schedule(testNode) - pod2.schedule("test-node") + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod1, s3pa.ResourceVersion) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod2, s3pa.ResourceVersion) - waitAndVerifyMountpointPodFor(pod2, vol) + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") }) }) Context("Late PV and PVC binding", func() { - It("should schedule a Mountpoint Pod per Workload Pod", func() { + It("should schedule single Mountpoint Pod", func() { vol := createVolume() pod1 := createPod(withPVC(vol.pvc)) pod2 := createPod(withPVC(vol.pvc)) - pod1.schedule("test-node") + pod1.schedule(testNode) + + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + vol.bind() + + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod1) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) + + pod2.schedule(testNode) + + waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod1, s3pa.ResourceVersion) + waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion(testNode, vol, pod2, s3pa.ResourceVersion) + }) + }) + Context("MountOptions", func() { + It("should schedule different Mountpoint Pods if mountOptions were modified", func() { + vol := createVolume() vol.bind() + pv := vol.pv + + pod1 := createPod(withPVC(vol.pvc)) + pod1.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod1) + + // Adding some sleep time before updating PV because reconciler requeues pod1 event to clear expectation + // and it can cause transient test failure if we update PV MountOptions too quickly + time.Sleep(5 * time.Second) + + pv.Spec.MountOptions = []string{"--allow-delete"} + Expect(k8sClient.Update(ctx, pv)).To(Succeed()) + + pod2 := createPod(withPVC(vol.pvc)) + pod2.schedule(testNode) + + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod2) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + }) + }) + + Context("FSGroup", func() { + It("should schedule single Mountpoint Pod if workload pods have the same FSGroup", func() { + vol := createVolume() + vol.bind() + + pod1 := createPod(withPVC(vol.pvc), withFSGroup(1111)) + pod2 := createPod(withPVC(vol.pvc), withFSGroup(1111)) + pod1.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["WorkloadFSGroup"] = "1111" + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(testNode, vol, pod1, expectedFields) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + pod2.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod1, s3pa.ResourceVersion, expectedFields) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod2, s3pa.ResourceVersion, expectedFields) + + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") + }) + + It("should schedule different Mountpoint Pods if workload pods have different FSGroup", func() { + vol := createVolume() + vol.bind() + + pod1 := createPod(withPVC(vol.pvc)) // no fsGroup + pod2 := createPod(withPVC(vol.pvc), withFSGroup(1111)) + pod3 := createPod(withPVC(vol.pvc), withFSGroup(2222)) + pod1.schedule(testNode) + pod2.schedule(testNode) + pod3.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadFSGroup"] = "1111" + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadFSGroup"] = "2222" + s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(s3pa1.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + Expect(s3pa2.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod2.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + }) + }) + + Context("authenticationSource=pod", func() { + It("should schedule single Mountpoint Pod if workload pods have the same namespace and service account", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() + + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod1.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadServiceAccountName"] = sa.Name + expectedFields["WorkloadNamespace"] = defaultNamespace + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(testNode, vol, pod1, expectedFields) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) + + pod2.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod1, s3pa.ResourceVersion, expectedFields) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod2, s3pa.ResourceVersion, expectedFields) + + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") + }) + + It("should schedule single Mountpoint Pod if workload pods have the same namespace, service account and IRSA role annotation", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() + + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + Annotations: map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role"}, + }, + } + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod1.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadServiceAccountName"] = sa.Name + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role" + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(testNode, vol, pod1, expectedFields) + expectNoPodUIDInS3PodAttachment(s3pa, string(pod2.UID)) + + pod2.schedule(testNode) + + s3pa1, mpPod1 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod1, s3pa.ResourceVersion, expectedFields) + s3pa2, mpPod2 := waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod2, s3pa.ResourceVersion, expectedFields) + + Expect(s3pa1.Name).To(Equal(s3pa2.Name), "S3PodAttachment should have the same name") + Expect(mpPod1.Name).To(Equal(mpPod2.Name), "Mountpoint Pods should have the same name") + }) + + It("should schedule different Mountpoint Pods if workload pods have different IRSA role annotations for the same service account", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() + + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) // no IRSA annotation + pod1.schedule(testNode) + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountName"] = sa.Name + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "" + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + + // Adding some sleep time before updating SA because reconciler requeues pod1 event to clear expectation + // and it can cause transient test failure if we update SA annotation too quickly + time.Sleep(5 * time.Second) + + sa.Annotations = map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role-1"} + Expect(k8sClient.Update(ctx, sa)).To(Succeed()) + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod2.schedule(testNode) + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role-1" + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + + time.Sleep(5 * time.Second) + + sa.Annotations = map[string]string{csicontroller.AnnotationServiceAccountRole: "test-role-2"} + Expect(k8sClient.Update(ctx, sa)).To(Succeed()) + pod3 := createPod(withPVC(vol.pvc), withServiceAccount(sa.Name)) + pod3.schedule(testNode) + expectedFields["WorkloadServiceAccountIAMRoleARN"] = "test-role-2" + s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(s3pa1.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + Expect(s3pa2.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod2.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + }) + + It("should schedule different Mountpoint Pods if workload pods have different service account names in the same namespace", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() + + sa1 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + sa2 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa1)).To(Succeed()) + Expect(k8sClient.Create(ctx, sa2)).To(Succeed()) + + pod1 := createPod(withPVC(vol.pvc)) // default service account + pod2 := createPod(withPVC(vol.pvc), withServiceAccount(sa1.Name)) + pod3 := createPod(withPVC(vol.pvc), withServiceAccount(sa2.Name)) + pod1.schedule(testNode) + pod2.schedule(testNode) + pod3.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountName"] = "default" + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadServiceAccountName"] = sa1.Name + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + expectedFields["WorkloadServiceAccountName"] = sa2.Name + s3pa3 := waitForS3PodAttachmentWithFields(expectedFields, "") + + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) + Expect(len(s3pa3.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + mpPod3 := waitAndVerifyMountpointPodFromPodAttachment(s3pa3, pod3, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(s3pa1.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + Expect(s3pa2.Name).NotTo(Equal(s3pa3.Name), "S3PodAttachment should not have the same name") + + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + Expect(mpPod2.Name).NotTo(Equal(mpPod3.Name), "Mountpoint Pods should not have the same name") + }) - pod2.schedule("test-node") + It("should schedule different Mountpoint Pods if workload pods have same service account names in the different namespace", func() { + vol := createVolume(withVolumeAttributes(map[string]string{ + "authenticationSource": "pod", + })) + vol.bind() - waitAndVerifyMountpointPodFor(pod2, vol) + ns := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "ns-", + }, + } + Expect(k8sClient.Create(ctx, ns)).To(Succeed()) + sa1 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "sa-", + Namespace: defaultNamespace, + }, + } + Expect(k8sClient.Create(ctx, sa1)).To(Succeed()) + sa2 := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: sa1.Name, + Namespace: ns.Name, + }, + } + Expect(k8sClient.Create(ctx, sa2)).To(Succeed()) + _, err := controllerutil.CreateOrUpdate(ctx, k8sClient, sa2, func() error { return nil }) + Expect(err).To(Succeed()) + pvc2 := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "test-pvc", + Namespace: ns.Name, + }, + Spec: corev1.PersistentVolumeClaimSpec{ + StorageClassName: vol.pvc.Spec.StorageClassName, + AccessModes: vol.pvc.Spec.AccessModes, + Resources: vol.pvc.Spec.Resources, + }, + } + Expect(k8sClient.Create(ctx, pvc2)).To(Succeed()) + vol2 := testVolume{pv: vol.pv, pvc: pvc2} + + pod1 := createPod(withPVC(vol.pvc), withServiceAccount(sa1.Name)) + pod1.schedule(testNode) + + expectedFields := defaultExpectedFields(testNode, vol.pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = defaultNamespace + expectedFields["WorkloadServiceAccountName"] = sa1.Name + s3pa1 := waitForS3PodAttachmentWithFields(expectedFields, "") + Expect(len(s3pa1.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod1 := waitAndVerifyMountpointPodFromPodAttachment(s3pa1, pod1, vol) + + pod2 := createPod(withPVC(pvc2), withServiceAccount(sa1.Name), withNamespace(ns.Name)) + vol2.bind() + pod2.schedule(testNode) + + expectedFields["WorkloadNamespace"] = ns.Name + s3pa2 := waitForS3PodAttachmentWithFields(expectedFields, "") + Expect(len(s3pa2.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod2 := waitAndVerifyMountpointPodFromPodAttachment(s3pa2, pod2, vol) + + Expect(s3pa1.Name).NotTo(Equal(s3pa2.Name), "S3PodAttachment should not have the same name") + Expect(mpPod1.Name).NotTo(Equal(mpPod2.Name), "Mountpoint Pods should not have the same name") }) }) }) @@ -323,17 +679,16 @@ var _ = Describe("Mountpoint Controller", func() { pod1 := createPod(withPVC(vol.pvc)) pod2 := createPod(withPVC(vol.pvc)) - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) pod1.schedule("test-node1") - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node1", vol, pod1) + expectNoS3PodAttachmentWithFields(defaultExpectedFields("test-node2", vol.pv)) pod2.schedule("test-node2") - waitAndVerifyMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node2", vol, pod2) }) }) @@ -345,17 +700,16 @@ var _ = Describe("Mountpoint Controller", func() { pod2 := createPod(withPVC(vol.pvc)) pod1.schedule("test-node1") - expectNoMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + expectNoS3PodAttachmentWithFields(map[string]string{"PersistentVolumeName": vol.pv.Name}) vol.bind() - waitAndVerifyMountpointPodFor(pod1, vol) - expectNoMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node1", vol, pod1) + expectNoS3PodAttachmentWithFields(defaultExpectedFields("test-node2", vol.pv)) pod2.schedule("test-node2") - waitAndVerifyMountpointPodFor(pod2, vol) + waitAndVerifyS3PodAttachmentAndMountpointPod("test-node2", vol, pod2) }) }) }) @@ -367,9 +721,9 @@ var _ = Describe("Mountpoint Controller", func() { pod := createPod(withVolume("empty-dir", corev1.VolumeSource{ EmptyDir: &corev1.EmptyDirVolumeSource{}, })) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodForWorkloadPod(pod) + expectNoS3PodAttachmentWithFields(map[string]string{"NodeName": testNode}) }) It("should not schedule a Mountpoint Pod if the Pod only uses a hostPath volume", func() { @@ -379,9 +733,9 @@ var _ = Describe("Mountpoint Controller", func() { Type: ptr.To(corev1.HostPathDirectoryOrCreate), }, })) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodForWorkloadPod(pod) + expectNoS3PodAttachmentWithFields(map[string]string{"NodeName": testNode}) }) It("should not schedule a Mountpoint Pod if the Pod only different volume-types/CSI-drivers", func() { @@ -400,9 +754,9 @@ var _ = Describe("Mountpoint Controller", func() { EmptyDir: &corev1.EmptyDirVolumeSource{}, }), ) - pod.schedule("test-node") + pod.schedule(testNode) - expectNoMountpointPodForWorkloadPod(pod) + expectNoS3PodAttachmentWithFields(map[string]string{"NodeName": testNode}) }) }) @@ -412,10 +766,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + _, mountpointPod := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) mountpointPod.succeed() @@ -427,11 +780,10 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) // `pod` got a `mountpointPod` - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + s3pa, mountpointPod := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) // `mountpointPod` got terminated mountpointPod.succeed() @@ -439,30 +791,28 @@ var _ = Describe("Mountpoint Controller", func() { // `pod` got terminated pod.terminate() + waitForObjectToDisappear(s3pa) // Since `pod` was in `Pending` state, termination of Pod will still keep that in // `Pending` state but will populate `DeletionTimestamp` to indicate this Pod is terminating. // In this case, there shouldn't be a new Mountpoint Pod spawned for it. - expectNoMountpointPodFor(pod, vol) + expectNoS3PodAttachmentWithFields(defaultExpectedFields(testNode, vol.pv)) }) - It("should delete Mountpoint Pod if the Workload Pod is terminated", func() { + It("should delete S3 Pod Attachment if the Workload Pod is terminated", func() { vol := createVolume() vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) // `pod` got a `mountpointPod` - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + s3pa, _ := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) // `pod` got terminated pod.terminate() waitForObjectToDisappear(pod.Pod) - - // `mountpointPod` scheduled for `pod` should also get terminated - waitForObjectToDisappear(mountpointPod.Pod) + waitForObjectToDisappear(s3pa) }) }) @@ -481,10 +831,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + _, mountpointPod := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) Expect(mountpointPod.Spec.ServiceAccountName).To(Equal(sa.Name)) @@ -502,10 +851,9 @@ var _ = Describe("Mountpoint Controller", func() { vol.bind() pod := createPod(withPVC(vol.pvc)) - pod.schedule("test-node") + pod.schedule(testNode) - mountpointPod := waitForMountpointPodFor(pod, vol) - verifyMountpointPodFor(pod, vol, mountpointPod) + _, mountpointPod := waitAndVerifyS3PodAttachmentAndMountpointPod(testNode, vol, pod) mpContainer := mountpointPod.Spec.Containers[0] Expect(mpContainer.Resources.Requests).To(Equal(corev1.ResourceList{ @@ -579,6 +927,30 @@ func withVolume(name string, vol corev1.VolumeSource) podModifier { } } +// withFSGroup returns a `podModifier` that sets fsGroup in the Pod's security context. +func withFSGroup(fsGroup int64) podModifier { + return func(pod *corev1.Pod) { + if pod.Spec.SecurityContext == nil { + pod.Spec.SecurityContext = &corev1.PodSecurityContext{} + } + pod.Spec.SecurityContext.FSGroup = &fsGroup + } +} + +// withServiceAccount returns a `podModifier` that sets ServiceAccountName. +func withServiceAccount(saName string) podModifier { + return func(pod *corev1.Pod) { + pod.Spec.ServiceAccountName = saName + } +} + +// withNamespace returns a `podModifier` that sets Namespace. +func withNamespace(namespace string) podModifier { + return func(pod *corev1.Pod) { + pod.Namespace = namespace + } +} + // A podModifier is a function for modifying Pod to be created. type podModifier func(*corev1.Pod) @@ -703,54 +1075,129 @@ func createVolume(modifiers ...volumeModifier) *testVolume { return testVolume } -// waitForMountpointPodFor waits and returns the Mountpoint Pod scheduled for given `pod` and `vol`. -func waitForMountpointPodFor(pod *testPod, vol *testVolume) *testPod { - mountpointPodKey := mountpointPodNameFor(pod, vol) +// waitForMountpointPodWithName waits and returns the Mountpoint Pod scheduled for given `mpPodName` +func waitForMountpointPodWithName(mpPodName string) *testPod { mountpointPod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - Name: mountpointPodKey.Name, - Namespace: mountpointPodKey.Namespace, + Name: mpPodName, + Namespace: mountpointNamespace, }, } waitForObject(mountpointPod) return &testPod{Pod: mountpointPod} } -// expectNoMountpointPodFor verifies that there is no Mountpoint Pod scheduled for given `pod` and `vol`. -func expectNoMountpointPodFor(pod *testPod, vol *testVolume) { - mountpointPodKey := mountpointPodNameFor(pod, vol) - expectNoObject(&corev1.Pod{ObjectMeta: metav1.ObjectMeta{ - Name: mountpointPodKey.Name, - Namespace: mountpointPodKey.Namespace, - }}) +// expectNoS3PodAttachmentWithFields verifies that no MountpointS3PodAttachment matching specified fields exists within a time period +func expectNoS3PodAttachmentWithFields(expectedFields map[string]string) { + Consistently(func(g Gomega) { + list := &crdv1beta.MountpointS3PodAttachmentList{} + g.Expect(k8sClient.List(ctx, list)).To(Succeed()) + + for i := range list.Items { + cr := &list.Items[i] + if matchesSpec(cr.Spec, expectedFields) { + g.Expect(false).To(BeTrue(), "Found matching MountpointS3PodAttachment when none was expected: %#v", cr) + } + } + }, defaultWaitTimeout/2, defaultWaitTimeout/4).Should(Succeed()) } -// expectNoMountpointPodForWorkloadPod verifies that there is no Mountpoint Pod scheduled for given `pod`. -// `expectNoMountpointPodFor` is preferable to this method if the `vol` is known as this performs a slower list operation. -func expectNoMountpointPodForWorkloadPod(pod *testPod) { - Consistently(func(g Gomega) { - podList := &corev1.PodList{} - g.Expect(k8sClient.List(ctx, podList, - client.InNamespace(mountpointNamespace), client.MatchingLabels{ - mppod.LabelPodUID: string(pod.UID), - }, - )).To(Succeed()) +// expectNoPodUIDInS3PodAttachment validates that pod UID does not exist in MountpointS3PodAttachments map +func expectNoPodUIDInS3PodAttachment(s3pa *crdv1beta.MountpointS3PodAttachment, podUID string) { + for _, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == podUID { + Expect(false).To(BeTrue(), "Found pod UID %s in S3PodAttachment when none was expected: %#v", podUID, s3pa) + } + } + } +} - g.Expect(podList.Items).To(BeEmpty(), "Expected empty list but got: %#v", podList) - g.Expect(podList.Continue).To(BeEmpty(), "Continue token on list must be empty but got: %s", podList.Continue) - }, defaultWaitTimeout/2, defaultWaitTimeout/4).Should(Succeed()) +// waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields waits and verifies that MountpointS3PodAttachment and Mountpoint Pod +// are created for given `node`, `vol`, `pod` and `expectedFields` +func waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields( + node string, + vol *testVolume, + pod *testPod, + expectedFields map[string]string, +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { + s3pa := waitForS3PodAttachmentWithFields(expectedFields, "") + Expect(len(s3pa.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) + return s3pa, mpPod +} + +// waitAndVerifyS3PodAttachmentAndMountpointPod waits and verifies that MountpointS3PodAttachment and Mountpoint Pod +// are created for given `node`, `vol` and `pod` +func waitAndVerifyS3PodAttachmentAndMountpointPod( + node string, + vol *testVolume, + pod *testPod, +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { + return waitAndVerifyS3PodAttachmentAndMountpointPodWithExpectedFields(node, vol, pod, defaultExpectedFields(node, vol.pv)) +} + +// waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField waits and verifies that MountpointS3PodAttachment with `minVersion` and Mountpoint Pod +// are created for given `node`, `vol`, `pod` and `expectedFields` +func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField( + testNode string, + vol *testVolume, + pod *testPod, + minVersion string, + expectedFields map[string]string, +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { + s3pa := waitForS3PodAttachmentWithFields(expectedFields, minVersion) + Expect(len(s3pa.Spec.MountpointS3PodAttachments)).To(Equal(1)) + mpPod := waitAndVerifyMountpointPodFromPodAttachment(s3pa, pod, vol) + return s3pa, mpPod +} + +// waitAndVerifyS3PodAttachmentAndMountpointPod waits and verifies that MountpointS3PodAttachment with `minVersion` and Mountpoint Pod +// are created for given `node`, `vol` and `pod` +func waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersion( + testNode string, + vol *testVolume, + pod *testPod, + minVersion string, +) (*crdv1beta.MountpointS3PodAttachment, *testPod) { + return waitAndVerifyS3PodAttachmentAndMountpointPodWithMinVersionAndExpectedField(testNode, vol, pod, minVersion, defaultExpectedFields(testNode, vol.pv)) } -// waitAndVerifyMountpointPodFor waits and verifies Mountpoint Pod scheduled for given `pod` and `vol.` -func waitAndVerifyMountpointPodFor(pod *testPod, vol *testVolume) { - mountpointPod := waitForMountpointPodFor(pod, vol) +// waitAndVerifyMountpointPodFromPodAttachment waits and verifies Mountpoint Pod scheduled for given `s3pa`, `pod` and `vol.` +func waitAndVerifyMountpointPodFromPodAttachment(s3pa *crdv1beta.MountpointS3PodAttachment, pod *testPod, vol *testVolume) *testPod { + // Find the mpPodName where pod.UID exists in the value slice + var mpPodName string + podUID := string(pod.UID) + + for k, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + if attachment.WorkloadPodUID == podUID { + mpPodName = k + break + } + } + if mpPodName != "" { + break + } + } + + Expect(mpPodName).NotTo(BeEmpty(), "No Mountpoint Pod found for pod UID %s in MountpointS3PodAttachment: %#v", podUID, s3pa) + Expect(s3pa.Spec.MountpointS3PodAttachments[mpPodName]).To(ContainElement( + MatchFields(IgnoreExtras, Fields{ + "WorkloadPodUID": Equal(podUID), + "AttachmentTime": Not(BeZero()), + }), + )) + + mountpointPod := waitForMountpointPodWithName(mpPodName) verifyMountpointPodFor(pod, vol, mountpointPod) + + return mountpointPod } // verifyMountpointPodFor verifies given `mountpointPod` for given `pod` and `vol`. func verifyMountpointPodFor(pod *testPod, vol *testVolume, mountpointPod *testPod) { Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelMountpointVersion, mountpointVersion)) - Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelPodUID, string(pod.UID))) Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelVolumeName, vol.pvc.Spec.VolumeName)) Expect(mountpointPod.ObjectMeta.Labels).To(HaveKeyWithValue(mppod.LabelCSIDriverVersion, version.GetVersion().DriverVersion)) @@ -788,6 +1235,82 @@ func waitForObject[Obj client.Object](obj Obj, verifiers ...func(Gomega, Obj)) { }, defaultWaitTimeout, defaultWaitRetryPeriod).Should(Succeed()) } +// waitForS3PodAttachmentWithFields waits until a MountpointS3PodAttachment matching specified node and pv appears in the cluster +func waitForS3PodAttachmentWithFields( + expectedFields map[string]string, + minResourceVersion string, + verifiers ...func(Gomega, *crdv1beta.MountpointS3PodAttachment), +) *crdv1beta.MountpointS3PodAttachment { + var matchedCR *crdv1beta.MountpointS3PodAttachment + + Eventually(func(g Gomega) { + list := &crdv1beta.MountpointS3PodAttachmentList{} + g.Expect(k8sClient.List(ctx, list)).To(Succeed()) + + for i := range list.Items { + cr := &list.Items[i] + if matchesSpec(cr.Spec, expectedFields) { + // Skip if the resource version isn't newer than the minimum + if minResourceVersion != "" { + minVersion, err := strconv.ParseInt(minResourceVersion, 10, 64) + g.Expect(err).NotTo(HaveOccurred()) + + currentVersion, err := strconv.ParseInt(cr.ResourceVersion, 10, 64) + g.Expect(err).NotTo(HaveOccurred()) + + if currentVersion <= minVersion { + continue + } + } + + for _, verifier := range verifiers { + verifier(g, cr) + } + matchedCR = cr + return + } + } + + g.Expect(false).To(BeTrue(), "No matching MountpointS3PodAttachment found") + }, defaultWaitTimeout, defaultWaitRetryPeriod).Should(Succeed()) + + return matchedCR +} + +// matchesSpec checks whether MountpointS3PodAttachmentSpec matches `expected` fields +func matchesSpec(spec crdv1beta.MountpointS3PodAttachmentSpec, expected map[string]string) bool { + specValues := map[string]string{ + "NodeName": spec.NodeName, + "PersistentVolumeName": spec.PersistentVolumeName, + "VolumeID": spec.VolumeID, + "MountOptions": spec.MountOptions, + "AuthenticationSource": spec.AuthenticationSource, + "WorkloadFSGroup": spec.WorkloadFSGroup, + "WorkloadServiceAccountName": spec.WorkloadServiceAccountName, + "WorkloadNamespace": spec.WorkloadNamespace, + "WorkloadServiceAccountIAMRoleARN": spec.WorkloadServiceAccountIAMRoleARN, + } + + for k, v := range expected { + if specValues[k] != v { + return false + } + } + return true +} + +// defaultExpectedFields return default test expected fields for MountpointS3PodAttachmentSpec matching +func defaultExpectedFields(nodeName string, pv *corev1.PersistentVolume) map[string]string { + return map[string]string{ + "NodeName": nodeName, + "PersistentVolumeName": pv.Name, + "VolumeID": pv.Spec.CSI.VolumeHandle, + "MountOptions": strings.Join(pv.Spec.MountOptions, ","), + "AuthenticationSource": "driver", + "WorkloadFSGroup": "", + } +} + // waitForObjectToDisappear waits until `obj` disappears in the control plane. func waitForObjectToDisappear(obj client.Object) { key := types.NamespacedName{Name: obj.GetName(), Namespace: obj.GetNamespace()} @@ -802,20 +1325,7 @@ func waitForObjectToDisappear(obj client.Object) { }, defaultWaitTimeout, defaultWaitRetryPeriod).Should(Succeed()) } -// expectNoObject verifies object with given key does not exists within a time period. -func expectNoObject(obj client.Object) { - key := types.NamespacedName{Name: obj.GetName(), Namespace: obj.GetNamespace()} - Consistently(func(g Gomega) { - err := k8sClient.Get(ctx, key, obj) - g.Expect(err).ToNot(BeNil(), "The object expected not to exists but its found: %#v", obj) - g.Expect(apierrors.IsNotFound(err)).To(BeTrue(), "Expected not found error but fond: %v", err) - }, defaultWaitTimeout/2, defaultWaitTimeout/4).Should(Succeed()) -} - -// mountpointPodNameFor returns namespaced name of Mountpoint Pod for given `pod` and `vol`. -func mountpointPodNameFor(pod *testPod, vol *testVolume) types.NamespacedName { - return types.NamespacedName{ - Name: mppod.MountpointPodNameFor(string(pod.Pod.UID), vol.pvc.Spec.VolumeName), - Namespace: mountpointNamespace, - } +// generateRandomNodeName generates random node name +func generateRandomNodeName() string { + return fmt.Sprintf("test-node-%s", uuid.New().String()[:8]) } diff --git a/tests/controller/suite_test.go b/tests/controller/suite_test.go index 3ae0381c..e2947692 100644 --- a/tests/controller/suite_test.go +++ b/tests/controller/suite_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -65,7 +66,14 @@ var _ = BeforeSuite(func() { ctx, cancel = context.WithCancel(context.TODO()) By("Bootstrapping test environment") - testEnv = &envtest.Environment{} + + crdv1beta.AddToScheme(scheme.Scheme) + testEnv = &envtest.Environment{ + CRDInstallOptions: envtest.CRDInstallOptions{ + Paths: []string{"../crd/mountpoints3podattachments-crd.yaml"}, + }, + ErrorIfCRDPathMissing: true, + } var err error cfg, err = testEnv.Start() @@ -79,6 +87,10 @@ var _ = BeforeSuite(func() { k8sManager, err := ctrl.NewManager(cfg, ctrl.Options{Scheme: scheme.Scheme}) Expect(err).ToNot(HaveOccurred()) + if err := crdv1beta.SetupManagerIndices(k8sManager); err != nil { + Expect(err).NotTo(HaveOccurred()) + } + err = csicontroller.NewReconciler(k8sManager.GetClient(), mppod.Config{ Namespace: mountpointNamespace, MountpointVersion: mountpointVersion, @@ -99,6 +111,7 @@ var _ = BeforeSuite(func() { }() createMountpointNamespace() + createDefaultServiceAccount() createMountpointPriorityClass() }) @@ -117,6 +130,20 @@ func createMountpointNamespace() { waitForObject(namespace) } +// createDefaultServiceAccount creates default service account in the control plane. +func createDefaultServiceAccount() { + sa := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "default", + Namespace: defaultNamespace, + }, + } + + By(fmt.Sprintf("Creating default service account in %q", mountpointNamespace)) + Expect(k8sClient.Create(ctx, sa)).To(Succeed()) + waitForObject(sa) +} + // createMountpointPriorityClass creates priority class for Mountpoint Pods. func createMountpointPriorityClass() { By(fmt.Sprintf("Creating priority class %q for Mountpoint Pods", mountpointPriorityClassName)) diff --git a/tests/crd/mountpoints3podattachments-crd.yaml b/tests/crd/mountpoints3podattachments-crd.yaml new file mode 100644 index 00000000..7da946f3 --- /dev/null +++ b/tests/crd/mountpoints3podattachments-crd.yaml @@ -0,0 +1,116 @@ +# Auto-generated file via `make generate`. Do not edit. +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.17.3 + name: mountpoints3podattachments.s3.csi.aws.com +spec: + group: s3.csi.aws.com + names: + kind: MountpointS3PodAttachment + listKind: MountpointS3PodAttachmentList + plural: mountpoints3podattachments + shortNames: + - s3pa + singular: mountpoints3podattachment + scope: Cluster + versions: + - name: v1beta + schema: + openAPIV3Schema: + description: MountpointS3PodAttachment is the Schema for the mountpoints3podattachments + API. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: MountpointS3PodAttachmentSpec defines the desired state of + MountpointS3PodAttachment. + properties: + authenticationSource: + description: Authentication source taken from volume attribute field + `authenticationSource`. + type: string + mountOptions: + description: Comma separated mount options taken from volume. + type: string + mountpointS3PodAttachments: + additionalProperties: + items: + description: WorkloadAttachment represents the attachment details + of a workload pod to a Mountpoint S3 pod. + properties: + attachmentTime: + description: AttachmentTime represents when the workload pod + was attached to the Mountpoint S3 pod + format: date-time + type: string + workloadPodUID: + description: WorkloadPodUID is the unique identifier of the + attached workload pod + type: string + required: + - attachmentTime + - workloadPodUID + type: object + type: array + description: Maps each Mountpoint S3 pod name to its workload attachments + type: object + nodeName: + description: Name of the node. + type: string + persistentVolumeName: + description: Name of the Persistent Volume. + type: string + volumeID: + description: Volume ID. + type: string + workloadFSGroup: + description: Workload pod's `fsGroup` from pod security context + type: string + workloadNamespace: + description: 'Workload pod''s namespace. Exists only if `authenticationSource: + pod`.' + type: string + workloadServiceAccountIAMRoleARN: + description: 'EKS IAM Role ARN from workload pod''s service account + annotation (IRSA). Exists only if `authenticationSource: pod` and + service account has `eks.amazonaws.com/role-arn` annotation.' + type: string + workloadServiceAccountName: + description: 'Workload pod''s service account name. Exists only if + `authenticationSource: pod`.' + type: string + required: + - authenticationSource + - mountOptions + - mountpointS3PodAttachments + - nodeName + - persistentVolumeName + - volumeID + - workloadFSGroup + type: object + type: object + selectableFields: + - jsonPath: .spec.nodeName + served: true + storage: true + subresources: + status: {} diff --git a/tests/e2e-kubernetes/e2e_test.go b/tests/e2e-kubernetes/e2e_test.go index afe2bbc0..d584fc6d 100644 --- a/tests/e2e-kubernetes/e2e_test.go +++ b/tests/e2e-kubernetes/e2e_test.go @@ -26,11 +26,13 @@ func init() { flag.StringVar(&BucketPrefix, "bucket-prefix", "local", "prefix for temporary buckets") flag.BoolVar(&Performance, "performance", false, "run performance tests") flag.BoolVar(&IMDSAvailable, "imds-available", false, "indicates whether instance metadata service is available") + flag.BoolVar(&IsPodMounter, "pod-mounter", false, "indicates whether CSI Driver is installed with Pod Mounter or not") flag.Parse() s3client.DefaultRegion = BucketRegion custom_testsuites.DefaultRegion = BucketRegion custom_testsuites.IMDSAvailable = IMDSAvailable + custom_testsuites.IsPodMounter = IsPodMounter } func TestE2E(t *testing.T) { @@ -61,6 +63,7 @@ var CSITestSuites = []func() framework.TestSuite{ custom_testsuites.InitS3MountOptionsTestSuite, custom_testsuites.InitS3CSICredentialsTestSuite, custom_testsuites.InitS3CSICacheTestSuite, + custom_testsuites.InitS3CSIPodSharingTestSuite, } // This executes testSuites for csi volumes. diff --git a/tests/e2e-kubernetes/go.mod b/tests/e2e-kubernetes/go.mod index cda5022d..79b7d043 100644 --- a/tests/e2e-kubernetes/go.mod +++ b/tests/e2e-kubernetes/go.mod @@ -49,6 +49,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.5.0 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/evanphx/json-patch/v5 v5.9.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -121,6 +122,7 @@ require ( golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/genproto v0.0.0-20231127180814-3a041ad873d4 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect @@ -145,6 +147,7 @@ require ( k8s.io/kubelet v0.31.3 // indirect k8s.io/mount-utils v0.31.3 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3 // indirect + sigs.k8s.io/controller-runtime v0.19.2 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect sigs.k8s.io/yaml v1.4.0 // indirect diff --git a/tests/e2e-kubernetes/go.sum b/tests/e2e-kubernetes/go.sum index eb82bd88..291ba502 100644 --- a/tests/e2e-kubernetes/go.sum +++ b/tests/e2e-kubernetes/go.sum @@ -80,6 +80,10 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch v4.12.0+incompatible h1:4onqiflcdA9EOZ4RxV643DvftH5pOlLGNtQ5lPWQu84= +github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= +github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -358,6 +362,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= +gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -438,6 +444,8 @@ k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1 k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3 h1:2770sDpzrjjsAtVhSeUFseziht227YAWYHLGNM8QPwY= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= +sigs.k8s.io/controller-runtime v0.19.2 h1:3sPrF58XQEPzbE8T81TN6selQIMGbtYwuaJ6eDssDF8= +sigs.k8s.io/controller-runtime v0.19.2/go.mod h1:iRmWllt8IlaLjvTTDLhRBXIEtkCK6hwVBJJsYS9Ajf4= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= diff --git a/tests/e2e-kubernetes/scripts/run.sh b/tests/e2e-kubernetes/scripts/run.sh index fe025e53..e454c9a0 100755 --- a/tests/e2e-kubernetes/scripts/run.sh +++ b/tests/e2e-kubernetes/scripts/run.sh @@ -86,6 +86,11 @@ fi CI_ROLE_ARN=${CI_ROLE_ARN:-""} MOUNTER_KIND=${MOUNTER_KIND:-systemd} +if [ "$MOUNTER_KIND" = "pod" ]; then + USE_POD_MOUNTER=true +else + USE_POD_MOUNTER=false +fi mkdir -p ${TEST_DIR} mkdir -p ${BIN_DIR} @@ -218,14 +223,14 @@ elif [[ "${ACTION}" == "install_driver" ]]; then elif [[ "${ACTION}" == "run_tests" ]]; then set +e pushd tests/e2e-kubernetes - KUBECONFIG=${KUBECONFIG} ginkgo -p -vv -timeout 60m -- --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --imds-available=true + KUBECONFIG=${KUBECONFIG} ginkgo -p -vv -timeout 60m -- --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --imds-available=true --pod-mounter=${USE_POD_MOUNTER} EXIT_CODE=$? print_cluster_info exit $EXIT_CODE elif [[ "${ACTION}" == "run_perf" ]]; then set +e pushd tests/e2e-kubernetes - KUBECONFIG=${KUBECONFIG} go test -ginkgo.vv --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --performance=true --imds-available=true + KUBECONFIG=${KUBECONFIG} go test -ginkgo.vv --bucket-region=${REGION} --commit-id=${TAG} --bucket-prefix=${CLUSTER_NAME} --performance=true --imds-available=true --pod-mounter=${USE_POD_MOUNTER} EXIT_CODE=$? print_cluster_info popd diff --git a/tests/e2e-kubernetes/testdriver.go b/tests/e2e-kubernetes/testdriver.go index b0ee62d0..c4ef4456 100644 --- a/tests/e2e-kubernetes/testdriver.go +++ b/tests/e2e-kubernetes/testdriver.go @@ -19,6 +19,7 @@ var ( BucketPrefix string Performance bool IMDSAvailable bool + IsPodMounter bool ) type s3Driver struct { diff --git a/tests/e2e-kubernetes/testsuites/pod_sharing.go b/tests/e2e-kubernetes/testsuites/pod_sharing.go new file mode 100644 index 00000000..a5132c17 --- /dev/null +++ b/tests/e2e-kubernetes/testsuites/pod_sharing.go @@ -0,0 +1,616 @@ +package custom_testsuites + +import ( + "context" + "fmt" + "path/filepath" + "strconv" + "strings" + "time" + + crdv1beta "github.com/awslabs/aws-s3-csi-driver/pkg/api/v1beta" + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/kubernetes/test/e2e/framework" + e2epod "k8s.io/kubernetes/test/e2e/framework/pod" + storageframework "k8s.io/kubernetes/test/e2e/storage/framework" + admissionapi "k8s.io/pod-security-admission/api" + "k8s.io/utils/ptr" +) + +var s3paGVR = schema.GroupVersionResource{Group: "s3.csi.aws.com", Version: "v1beta", Resource: "mountpoints3podattachments"} + +const mountpointNamespace = "mount-s3" + +const defaultTimeout = 10 * time.Second +const defaultInterval = 1 * time.Second + +var IsPodMounter bool + +type s3CSIPodSharingTestSuite struct { + tsInfo storageframework.TestSuiteInfo +} + +func InitS3CSIPodSharingTestSuite() storageframework.TestSuite { + return &s3CSIPodSharingTestSuite{ + tsInfo: storageframework.TestSuiteInfo{ + Name: "podsharing", + TestPatterns: []storageframework.TestPattern{ + storageframework.DefaultFsPreprovisionedPV, + }, + }, + } +} + +func (t *s3CSIPodSharingTestSuite) GetTestSuiteInfo() storageframework.TestSuiteInfo { + return t.tsInfo +} + +func (t *s3CSIPodSharingTestSuite) SkipUnsupportedTests(_ storageframework.TestDriver, _ storageframework.TestPattern) { +} + +func (t *s3CSIPodSharingTestSuite) DefineTests(driver storageframework.TestDriver, pattern storageframework.TestPattern) { + type local struct { + resources []*storageframework.VolumeResource + config *storageframework.PerTestConfig + } + var ( + l local + ) + + f := framework.NewFrameworkWithCustomTimeouts(NamespacePrefix+"podsharing", storageframework.GetDriverTimeouts(driver)) + f.NamespacePodSecurityLevel = admissionapi.LevelBaseline + + cleanup := func(ctx context.Context) { + var errs []error + for _, resource := range l.resources { + errs = append(errs, resource.CleanupResource(ctx)) + } + framework.ExpectNoError(errors.NewAggregate(errs), "while cleanup resource") + } + ginkgo.BeforeEach(func(ctx context.Context) { + if !IsPodMounter { + ginkgo.Skip("Pod Mounter is not enabled, skipping pod sharing tests") + } + + l = local{} + l.config = driver.PrepareTest(ctx, f) + ginkgo.DeferCleanup(cleanup) + }) + + ginkgo.It("should share Mountpoint Pod (authenticationSource=driver)", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + var nodeSelector map[string]string + for i := 0; i < 2; i++ { + index := i + 1 + + if i > 0 && targetNode != "" { + nodeSelector = map[string]string{"kubernetes.io/hostname": targetNode} + } + + ginkgo.By(fmt.Sprintf("Creating pod%d with a volume", index)) + pod, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, nodeSelector, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + framework.ExpectNoError(err) + + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, defaultExpectedFields(targetNode, resource.Pv)) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should share Mountpoint Pod if pods have the same fsGroup", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + for i := 0; i < 2; i++ { + index := i + 1 + podConfig := &e2epod.Config{ + NS: f.Namespace.Name, + PVCs: []*v1.PersistentVolumeClaim{resource.Pvc}, + SecurityLevel: admissionapi.LevelBaseline, + FsGroup: ptr.To(int64(1000)), + } + + if i > 0 && targetNode != "" { + podConfig.NodeSelection = e2epod.NodeSelection{ + Name: targetNode, + } + } + + ginkgo.By(fmt.Sprintf("Creating pod%d", index)) + pod, err := e2epod.CreateSecPod(ctx, f.ClientSet, podConfig, 10*time.Second) + framework.ExpectNoError(err) + + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + expectedFields := defaultExpectedFields(targetNode, resource.Pv) + expectedFields["WorkloadFSGroup"] = "1000" + s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, expectedFields) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should not share Mountpoint Pod if pods have different fsGroup", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + for i := 0; i < 2; i++ { + index := i + 1 + podConfig := &e2epod.Config{ + NS: f.Namespace.Name, + PVCs: []*v1.PersistentVolumeClaim{resource.Pvc}, + SecurityLevel: admissionapi.LevelBaseline, + FsGroup: ptr.To(int64(1000 + i)), + } + + // For the first pod, let it schedule anywhere + // For subsequent pods, force them to the same node as the first pod + if i > 0 && targetNode != "" { + podConfig.NodeSelection = e2epod.NodeSelection{ + Name: targetNode, + } + } + + ginkgo.By(fmt.Sprintf("Creating pod%d", index)) + pod, err := e2epod.CreateSecPod(ctx, f.ClientSet, podConfig, 10*time.Second) + framework.ExpectNoError(err) + + // Store the node name from the first pod + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + s3paNames, mountpointPodNames = verifyPodsHaveDifferentMountpointPods(ctx, f, pods, func(pod *v1.Pod) map[string]string { + expectedFields := defaultExpectedFields(pod.Spec.NodeName, resource.Pv) + expectedFields["WorkloadFSGroup"] = strconv.FormatInt(*pod.Spec.SecurityContext.FSGroup, 10) + return expectedFields + }) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should not share Mountpoint Pod if mountOptions are different", func(ctx context.Context) { + resource := createVolumeResourceWithMountOptions(ctx, l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + + // First Pod + ginkgo.By("Creating pod1 with a volume") + pod1, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, nil, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + framework.ExpectNoError(err) + targetNode = pod1.Spec.NodeName + + resource.Pv, err = f.ClientSet.CoreV1().PersistentVolumes().Get(ctx, resource.Pv.Name, metav1.GetOptions{}) + framework.ExpectNoError(err) + firstMountOptions := strings.Join(resource.Pv.Spec.MountOptions, ",") + resource.Pv.Spec.MountOptions = []string{"--allow-delete"} + resource.Pv, err = f.ClientSet.CoreV1().PersistentVolumes().Update(ctx, resource.Pv, metav1.UpdateOptions{}) + framework.ExpectNoError(err) + + // Second Pod + pods = append(pods, pod1) + ginkgo.By("Creating pod2 with a volume") + pod2, err := e2epod.CreatePod(ctx, f.ClientSet, f.Namespace.Name, map[string]string{"kubernetes.io/hostname": targetNode}, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + framework.ExpectNoError(err) + pods = append(pods, pod2) + + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + s3paNames, mountpointPodNames = verifyPodsHaveDifferentMountpointPods(ctx, f, pods, func(pod *v1.Pod) map[string]string { + expectedFields := defaultExpectedFields(pod.Spec.NodeName, resource.Pv) + if pod.Name == pod1.Name { + expectedFields["MountOptions"] = firstMountOptions + } else { + expectedFields["MountOptions"] = "--allow-delete" + } + return expectedFields + }) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should share Mountpoint Pod if pod namespaces and service accounts are the same (authenticationSource=pod)", func(ctx context.Context) { + idConfig, err := setupPodLevelIdentity(ctx, f) + framework.ExpectNoError(err) + defer idConfig.Cleanup(ctx) + resource := createVolumeResourceWithMountOptions(contextWithAuthenticationSource(ctx, "pod"), l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + var nodeSelector map[string]string + for i := 0; i < 2; i++ { + index := i + 1 + + if i > 0 && targetNode != "" { + nodeSelector = map[string]string{"kubernetes.io/hostname": targetNode} + } + + ginkgo.By(fmt.Sprintf("Creating pod%d with a volume", index)) + pod := e2epod.MakePod(f.Namespace.Name, nodeSelector, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + pod.Spec.ServiceAccountName = idConfig.ServiceAccount.Name + pod, err := createPod(ctx, f.ClientSet, f.Namespace.Name, pod) + framework.ExpectNoError(err) + + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + expectedFields := defaultExpectedFields(targetNode, resource.Pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = f.Namespace.Name + expectedFields["WorkloadServiceAccountName"] = idConfig.ServiceAccount.Name + s3paNames, mountpointPodNames = verifyPodsShareMountpointPod(ctx, f, pods, expectedFields) + checkCrossReadWrite(f, pods[0], pods[1]) + }) + + ginkgo.It("should not share Mountpoint Pod if pod service accounts are the different (authenticationSource=pod)", func(ctx context.Context) { + idConfig1, err := setupPodLevelIdentity(ctx, f) + framework.ExpectNoError(err) + defer idConfig1.Cleanup(ctx) + idConfig2, err := setupPodLevelIdentity(ctx, f) + framework.ExpectNoError(err) + defer idConfig2.Cleanup(ctx) + saNames := []string{idConfig1.ServiceAccount.Name, idConfig2.ServiceAccount.Name} + resource := createVolumeResourceWithMountOptions(contextWithAuthenticationSource(ctx, "pod"), l.config, pattern, nil) + l.resources = append(l.resources, resource) + + var s3paNames []string + var mountpointPodNames []string + var pods []*v1.Pod + var targetNode string + var nodeSelector map[string]string + for i := 0; i < 2; i++ { + index := i + 1 + + if i > 0 && targetNode != "" { + nodeSelector = map[string]string{"kubernetes.io/hostname": targetNode} + } + + ginkgo.By(fmt.Sprintf("Creating pod%d with a volume", index)) + pod := e2epod.MakePod(f.Namespace.Name, nodeSelector, []*v1.PersistentVolumeClaim{resource.Pvc}, admissionapi.LevelBaseline, "") + pod.Spec.ServiceAccountName = saNames[i] + pod, err := createPod(ctx, f.ClientSet, f.Namespace.Name, pod) + framework.ExpectNoError(err) + + if i == 0 { + targetNode = pod.Spec.NodeName + } + pods = append(pods, pod) + } + defer func() { + for _, pod := range pods { + framework.ExpectNoError(e2epod.DeletePodWithWait(ctx, f.ClientSet, pod)) + } + verifyMountpointResourcesCleanup(ctx, f, s3paNames, mountpointPodNames) + }() + + s3paNames, mountpointPodNames = verifyPodsHaveDifferentMountpointPods(ctx, f, pods, func(pod *v1.Pod) map[string]string { + expectedFields := defaultExpectedFields(pod.Spec.NodeName, resource.Pv) + expectedFields["AuthenticationSource"] = "pod" + expectedFields["WorkloadNamespace"] = f.Namespace.Name + expectedFields["WorkloadServiceAccountName"] = pod.Spec.ServiceAccountName + return expectedFields + }) + checkCrossReadWrite(f, pods[0], pods[1]) + }) +} + +func verifyPodsShareMountpointPod(ctx context.Context, f *framework.Framework, pods []*v1.Pod, expectedFields map[string]string) ([]string, []string) { + var s3paNames []string + var mountpointPodNames []string + var s3paList *crdv1beta.MountpointS3PodAttachmentList + framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) + if err != nil { + return false, err + } + s3paList, err = convertToCustomResourceList(list) + if err != nil { + return false, err + } + for _, s3pa := range s3paList.Items { + if matchesSpec(s3pa.Spec, expectedFields) { + s3paNames = append(s3paNames, s3pa.Name) + allUIDs := make(map[string]bool) + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + mountpointPodNames = append(mountpointPodNames, mpPodName) + for _, attachment := range attachments { + allUIDs[attachment.WorkloadPodUID] = true + } + } + for _, pod := range pods { + podUID := string(pod.UID) + if _, exists := allUIDs[podUID]; !exists { + return false, fmt.Errorf("pod UID %s not found in MountpointS3PodAttachment", podUID) + } + } + + return true, nil + } + } + + return false, err + })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) + + return s3paNames, mountpointPodNames +} + +func verifyPodsHaveDifferentMountpointPods(ctx context.Context, f *framework.Framework, pods []*v1.Pod, expectedFieldsFunc func(pod *v1.Pod) map[string]string) ([]string, []string) { + var s3paNames []string + var mountpointPodNames []string + var s3paList *crdv1beta.MountpointS3PodAttachmentList + framework.Gomega().Eventually(ctx, framework.HandleRetry(func(ctx context.Context) (bool, error) { + list, err := f.DynamicClient.Resource(s3paGVR).List(ctx, metav1.ListOptions{}) + if err != nil { + return false, fmt.Errorf("failed to list S3PodAttachments: %w", err) + } + s3paList, err = convertToCustomResourceList(list) + if err != nil { + return false, fmt.Errorf("failed to convert to custom resource list: %w", err) + } + + matchCount := 0 + for _, s3pa := range s3paList.Items { + for _, pod := range pods { + if matchesSpec(s3pa.Spec, expectedFieldsFunc(pod)) { + s3paNames = append(s3paNames, s3pa.Name) + matchCount++ + break + } + } + } + + return matchCount == len(pods), nil + })).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) + + podToMountpointPod := make(map[string]string) + for _, s3pa := range s3paList.Items { + for mpPodName, attachments := range s3pa.Spec.MountpointS3PodAttachments { + for _, attachment := range attachments { + podToMountpointPod[attachment.WorkloadPodUID] = mpPodName + } + } + } + + seenMountpointPods := make(map[string]bool) + for _, pod := range pods { + podUID := string(pod.UID) + mpPodName, exists := podToMountpointPod[podUID] + + framework.Gomega().Expect(exists).To(gomega.BeTrue()) + + _, alreadySeen := seenMountpointPods[mpPodName] + framework.Gomega().Expect(alreadySeen).To(gomega.BeFalse()) + + seenMountpointPods[mpPodName] = true + mountpointPodNames = append(mountpointPodNames, mpPodName) + } + + framework.Gomega().Expect(len(seenMountpointPods)).To(gomega.Equal(len(pods))) + + return s3paNames, mountpointPodNames +} + +// TODO: This does not fail for some reason after timeout +func verifyMountpointResourcesCleanup(ctx context.Context, f *framework.Framework, s3paNames []string, mountpointPodNames []string) { + framework.Logf("Verifying MountpointS3PodAttachments are deleted: %v", s3paNames) + framework.Gomega().Eventually(ctx, func() bool { + for _, s3paName := range s3paNames { + _, err := f.DynamicClient.Resource(s3paGVR).Get(ctx, s3paName, metav1.GetOptions{}) + if err == nil { + // S3PodAttachment still exists + return false + } + if !apierrors.IsNotFound(err) { + return false + } + } + return true + }).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) + + framework.Logf("Verifying Mountpoint Pods are deleted: %v", mountpointPodNames) + framework.Gomega().Eventually(ctx, func() bool { + for _, mpPodName := range mountpointPodNames { + _, err := f.ClientSet.CoreV1().Pods(mountpointNamespace).Get(ctx, mpPodName, metav1.GetOptions{}) + if err == nil { + // Pod still exists + return false + } + if !apierrors.IsNotFound(err) { + framework.Logf("Error checking pod %s: %v", mpPodName, err) + return false + } + } + return true + }).WithTimeout(defaultTimeout).WithPolling(defaultInterval).Should(gomega.BeTrue()) +} + +// Convert UnstructuredList to MountpointS3PodAttachmentList +func convertToCustomResourceList(list *unstructured.UnstructuredList) (*crdv1beta.MountpointS3PodAttachmentList, error) { + crList := &crdv1beta.MountpointS3PodAttachmentList{ + Items: make([]crdv1beta.MountpointS3PodAttachment, 0, len(list.Items)), + } + + for _, item := range list.Items { + cr := &crdv1beta.MountpointS3PodAttachment{} + err := runtime.DefaultUnstructuredConverter.FromUnstructured(item.Object, cr) + if err != nil { + return nil, fmt.Errorf("failed to convert item to MountpointS3PodAttachment: %v", err) + } + crList.Items = append(crList.Items, *cr) + } + + return crList, nil +} + +// matchesSpec checks whether MountpointS3PodAttachmentSpec matches `expected` fields +func matchesSpec(spec crdv1beta.MountpointS3PodAttachmentSpec, expected map[string]string) bool { + specValues := map[string]string{ + "NodeName": spec.NodeName, + "PersistentVolumeName": spec.PersistentVolumeName, + "VolumeID": spec.VolumeID, + "MountOptions": spec.MountOptions, + "AuthenticationSource": spec.AuthenticationSource, + "WorkloadFSGroup": spec.WorkloadFSGroup, + "WorkloadServiceAccountName": spec.WorkloadServiceAccountName, + "WorkloadNamespace": spec.WorkloadNamespace, + "WorkloadServiceAccountIAMRoleARN": spec.WorkloadServiceAccountIAMRoleARN, + } + + for k, v := range expected { + if specValues[k] != v { + return false + } + } + return true +} + +// defaultExpectedFields return default test expected fields for MountpointS3PodAttachmentSpec matching +func defaultExpectedFields(nodeName string, pv *v1.PersistentVolume) map[string]string { + return map[string]string{ + "NodeName": nodeName, + "PersistentVolumeName": pv.Name, + "VolumeID": pv.Spec.CSI.VolumeHandle, + "MountOptions": strings.Join(pv.Spec.MountOptions, ","), + "AuthenticationSource": "driver", + "WorkloadFSGroup": "", + } +} + +func checkCrossReadWrite(f *framework.Framework, pod1, pod2 *v1.Pod) { + toWrite := 1024 // 1KB + path := "/mnt/volume1" + + // Check write from pod1 and read from pod2 + checkPodWriteAndOtherPodRead(f, pod1, pod2, path, "file1.txt", toWrite) + + // Check write from pod2 and read from pod1 + checkPodWriteAndOtherPodRead(f, pod2, pod1, path, "file2.txt", toWrite) +} + +func checkPodWriteAndOtherPodRead(f *framework.Framework, writerPod, readerPod *v1.Pod, basePath, filename string, size int) { + filePath := filepath.Join(basePath, filename) + seed := time.Now().UTC().UnixNano() + + checkWriteToPath(f, writerPod, filePath, size, seed) + checkReadFromPath(f, readerPod, filePath, size, seed) +} + +type PodLevelIdentityConfig struct { + OIDCProvider string + ServiceAccount *v1.ServiceAccount + IAMRole string + Cleanup func(context.Context) error +} + +// setupPodLevelIdentity creates necessary resources for pod-level identity tests +func setupPodLevelIdentity(ctx context.Context, f *framework.Framework) (*PodLevelIdentityConfig, error) { + config := &PodLevelIdentityConfig{} + var cleanupFuncs []func(context.Context) error + + // Get OIDC Provider + config.OIDCProvider = oidcProviderForCluster(ctx, f) + if config.OIDCProvider == "" { + return nil, fmt.Errorf("OIDC provider is not configured") + } + + // Create Service Account + sa, removeSA := createServiceAccount(ctx, f) + config.ServiceAccount = sa + cleanupFuncs = append(cleanupFuncs, removeSA) + + // Create IAM Role with full access policy + role, removeRole := createRole(ctx, f, + assumeRoleWithWebIdentityPolicyDocument(ctx, config.OIDCProvider, sa), + iamPolicyS3FullAccess) + config.IAMRole = *role.Arn + cleanupFuncs = append(cleanupFuncs, removeRole) + + // Annotate Service Account with Role ARN + sa, restoreServiceAccountRole := overrideServiceAccountRole(ctx, f, sa, config.IAMRole) + config.ServiceAccount = sa + cleanupFuncs = append(cleanupFuncs, restoreServiceAccountRole) + + // Wait for role to be assumable + waitUntilRoleIsAssumableWithWebIdentity(ctx, f, sa) + + // Combine cleanup functions + config.Cleanup = func(ctx context.Context) error { + var errs []error + // Execute cleanup functions in reverse order + for i := len(cleanupFuncs) - 1; i >= 0; i-- { + if err := cleanupFuncs[i](ctx); err != nil { + errs = append(errs, err) + } + } + return errors.NewAggregate(errs) + } + + return config, nil +}