Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/gpu-mutating-webhook/admission_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"log"
"net/http"

admissionv1 "k8s.io/api/admission/v1"

Check failure on line 27 in cmd/gpu-mutating-webhook/admission_controller.go

View workflow job for this annotation

GitHub Actions / golang / Build

cannot find module providing package k8s.io/api/admission/v1: import lookup disabled by -mod=vendor

Check failure on line 27 in cmd/gpu-mutating-webhook/admission_controller.go

View workflow job for this annotation

GitHub Actions / code-scanning / Analyze Go code with CodeQL

cannot find module providing package k8s.io/api/admission/v1: import lookup disabled by -mod=vendor

Check failure on line 27 in cmd/gpu-mutating-webhook/admission_controller.go

View workflow job for this annotation

GitHub Actions / golang / Unit test

cannot find module providing package k8s.io/api/admission/v1: import lookup disabled by -mod=vendor
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/serializer"
Expand All @@ -46,6 +46,7 @@

type admitFunc func(*admissionv1.AdmissionRequest) ([]patchOperation, error)

// TODO(swati): also skip the nvidia-dra-driver-gpu namespace.
// isKubeNamespace reports whether ns is one of the built-in Kubernetes
// namespaces (kube-public or kube-system) that the webhook must not mutate.
func isKubeNamespace(ns string) bool {
	switch ns {
	case metav1.NamespacePublic, metav1.NamespaceSystem:
		return true
	}
	return false
}
Expand Down
213 changes: 123 additions & 90 deletions cmd/gpu-mutating-webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,126 +26,155 @@ import (
admissionv1 "k8s.io/api/admission/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/klog/v2"
)

const (
tlsDir = `/etc/webhook/tls`
tlsCertFile = `tls.crt`
tlsKeyFile = `tls.key`
tlsDir = `/etc/webhook/tls`
tlsCertFile = `tls.crt`
tlsKeyFile = `tls.key`
gpuResourceName = "nvidia.com/gpu"
gpuClaimName = "nvidia-gpu-resourceclaim"
gpuTemplateName = "nvidia-gpu-resourceclaim-template"
)

var (
podResource = metav1.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"}
gpuClaimName = "nvidia-gpu-resourceclaim"
gpuTemplateName = "nvidia-gpu-resourceclaim-template"
podResource = metav1.GroupVersionResource{Version: "v1", Resource: "pods"}
)

func applyGPUMutation(req *admissionv1.AdmissionRequest) ([]patchOperation, error) {
// Only mutate if the incoming resource is a Pod CREATE request.
if req.Resource != podResource {
log.Printf("applyGPUMutation invoked for a non-Pod resource: %v", req.Resource)
return nil, nil
}
if req.Operation != admissionv1.Create {
log.Printf("applyGPUMutation invoked for operation %s, ignoring", req.Operation)
// Only mutate Pod CREATE
// TODO(swati): consider also handling UPDATE operations.
if req.Resource != podResource || req.Operation != admissionv1.Create {
klog.Infof("skip mutation for %v/%v", req.Resource, req.Operation)
return nil, nil
}

raw := req.Object.Raw
var pod corev1.Pod
if _, _, err := universalDeserializer.Decode(raw, nil, &pod); err != nil {
return nil, fmt.Errorf("could not deserialize pod object: %v", err)
if _, _, err := universalDeserializer.Decode(req.Object.Raw, nil, &pod); err != nil {
klog.Errorf("failed to decode Pod: %v", err)
return nil, fmt.Errorf("could not deserialize pod: %w", err)
}

key := escapeJSONPointer(gpuResourceName)
var patches []patchOperation

// Check if the Pod already has a resource claim
hasGPUClaim := false
for _, rc := range pod.Spec.ResourceClaims {
if rc.Name == gpuClaimName {
hasGPUClaim = true
break
var ctrGPUResourceClaims []string

// Iterate on all containers and check for "nvidia.com/gpu" limits
// using the logic described here for preferring limits over requests
// GPUs are only supposed to be specified in the limits section, meaning
// - can specify GPU limits without specifying requests. limit will be used as request value by default
// - can specify GPU in both limits and requests but they must be equal
// - cannot specify GPU requests without specifying limits
// refer: https://kubernetes.io/docs/tasks/manage-gpus/scheduling-gpus/#using-device-plugins
for ci, ctr := range pod.Spec.Containers {
ctrName := ctr.Name
limitCount, limitOk := ctr.Resources.Limits[gpuResourceName]

// skip if no GPUs in limits
if !limitOk || limitCount.Value() < 1 {
continue
}
}

// Escape "nvidia.com/gpu" for JSON Patch
escapedGPUKey := strings.ReplaceAll(strings.ReplaceAll("nvidia.com/gpu", "~", "~0"), "/", "~1")

for i, c := range pod.Spec.Containers {
foundGPU := false

if _, ok := c.Resources.Requests["nvidia.com/gpu"]; ok {
foundGPU = true
patches = append(patches, patchOperation{
Op: "remove",
Path: fmt.Sprintf("/spec/containers/%d/resources/requests/%s", i, escapedGPUKey),
})
gpuCount := limitCount.Value()

// check any GPUs in requests
// it must be equal to limits
if reqCount, reqOK := ctr.Resources.Requests[gpuResourceName]; reqOK {
if reqCount.Value() != gpuCount {
klog.Warningf("container[%q]: gpu request (%d) != limit (%d), skipping mutation", ctrName, reqCount.Value(), gpuCount)
continue
}
reqPatch := removeResourceRequest(ci, "requests", key)
patches = append(patches, reqPatch)
klog.Infof("removed container[%q].Resources.Requests: %v", ctrName, reqPatch)
}

if _, ok := c.Resources.Limits["nvidia.com/gpu"]; ok {
foundGPU = true
patches = append(patches, patchOperation{
Op: "remove",
Path: fmt.Sprintf("/spec/containers/%d/resources/limits/%s", i, escapedGPUKey),
})
limitPatch := removeResourceRequest(ci, "limits", key)
patches = append(patches, limitPatch)
klog.Infof("removed container[%q].Resources.Limits: %v", ctrName, limitPatch)

// ensure container-claims slice exists
// this is the JSON Patch way: first create the field if it does not exist, then append later with "-"
if len(ctr.Resources.Claims) == 0 {
createPatch := createClaimPatch(fmt.Sprintf("/spec/containers/%d/resources/claims", ci))
patches = append(patches, createPatch)
klog.Infof("created container[%q] empty claims array: %v", ctrName, createPatch)
}

if foundGPU {
gpuClaimPresent := false
for _, claimRef := range c.Resources.Claims {
if claimRef.Name == gpuClaimName {
gpuClaimPresent = true
break
}
}
if !gpuClaimPresent {
if c.Resources.Claims == nil {
patches = append(patches, patchOperation{
Op: "add",
Path: fmt.Sprintf("/spec/containers/%d/resources/claims", i),
Value: []map[string]string{
{"name": gpuClaimName},
},
})
} else {
patches = append(patches, patchOperation{
Op: "add",
Path: fmt.Sprintf("/spec/containers/%d/resources/claims/-", i),
Value: map[string]string{"name": gpuClaimName},
})
}
}
// append one claim per GPU
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated code test:

$ kubectl logs gpu-mutating-webhook-7f676685c8-8w5sf -n nvidia-dra-driver-gpu
2025/05/02 19:52:34 Handling webhook request ...
I0502 19:52:34.675245       1 main.go:49] skip mutation for { v1 pods}/UPDATE
2025/05/02 19:52:34 Webhook request handled successfully
2025/05/02 19:53:05 Handling webhook request ...
I0502 19:53:05.135191       1 main.go:89] removed container["main"].Resources.Requests: {remove /spec/containers/0/resources/requests/nvidia.com~1gpu <nil>}
I0502 19:53:05.135219       1 main.go:93] removed container["main"].Resources.Limits: {remove /spec/containers/0/resources/limits/nvidia.com~1gpu <nil>}
I0502 19:53:05.135226       1 main.go:100] created container["main"] empty claims array: {add /spec/containers/0/resources/claims []}
I0502 19:53:05.135236       1 main.go:112] added to container["main"].Resources.Claims: {add /spec/containers/0/resources/claims/- map[name:nvidia-gpu-resourceclaim-0]}
I0502 19:53:05.135245       1 main.go:112] added to container["main"].Resources.Claims: {add /spec/containers/0/resources/claims/- map[name:nvidia-gpu-resourceclaim-1]}
I0502 19:53:05.135249       1 main.go:123] created pod["swati-gpu-pod"] empty claims array: {add /spec/resourceClaims []}
I0502 19:53:05.135256       1 main.go:136] added ResourceClaim "nvidia-gpu-resourceclaim-0" (template="nvidia-gpu-resourceclaim-template") to "swati-gpu-pod": {add /spec/resourceClaims/- map[name:nvidia-gpu-resourceclaim-0 resourceClaimTemplateName:nvidia-gpu-resourceclaim-template]}
I0502 19:53:05.135264       1 main.go:136] added ResourceClaim "nvidia-gpu-resourceclaim-1" (template="nvidia-gpu-resourceclaim-template") to "swati-gpu-pod": {add /spec/resourceClaims/- map[name:nvidia-gpu-resourceclaim-1 resourceClaimTemplateName:nvidia-gpu-resourceclaim-template]}
2025/05/02 19:53:05 Webhook request handled successfully

$ kubectl get resourceclaim 
NAME                                             STATE                AGE
swati-gpu-pod-nvidia-gpu-resourceclaim-0-x6dln   allocated,reserved   3m10s
swati-gpu-pod-nvidia-gpu-resourceclaim-1-hkbgk   allocated,reserved   3m10s

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we construct unit tests that exercise the same logic?

for i := int64(0); i < gpuCount; i++ {
claimName := fmt.Sprintf("%s-%d", gpuClaimName, i)
ctrGPUResourceClaims = append(ctrGPUResourceClaims, claimName)
appendPatch := appendClaimPatch(
fmt.Sprintf("/spec/containers/%d/resources/claims", ci),
map[string]string{"name": claimName},
)
patches = append(patches, appendPatch)
klog.Infof("added to container[%q].Resources.Claims: %v", ctrName, appendPatch)
}
}

if len(patches) > 0 && !hasGPUClaim {
newClaim := map[string]string{
"name": gpuClaimName,
"resourceClaimTemplateName": gpuTemplateName,
// Add claims pod-level
podName := pod.Name
if len(ctrGPUResourceClaims) > 0 {
// ensure pod-claims slice exists
if len(pod.Spec.ResourceClaims) == 0 {
createPatch := createClaimPatch("/spec/resourceClaims")
patches = append(patches, createPatch)
klog.Infof("created pod[%q] empty claims array: %v", podName, createPatch)
}

if pod.Spec.ResourceClaims == nil {
patches = append(patches, patchOperation{
Op: "add",
Path: "/spec/resourceClaims",
Value: []map[string]string{
newClaim,
// append each container GPU claim at pod-level
for _, name := range ctrGPUResourceClaims {
appendPatch := appendClaimPatch(
"/spec/resourceClaims",
map[string]string{
"name": name,
"resourceClaimTemplateName": gpuTemplateName,
},
})
} else {
patches = append(patches, patchOperation{
Op: "add",
Path: "/spec/resourceClaims/-",
Value: newClaim,
})
)
patches = append(patches, appendPatch)
klog.Infof("added ResourceClaim %q (template=%q) to %q: %v", name, gpuTemplateName, podName, appendPatch)
}
log.Printf("Added ResourceClaim %q referencing template %q to Pod %q",
gpuClaimName, gpuTemplateName, pod.Name)
}

return patches, nil
}

// escapeJSONPointer escapes a string for use as a JSON Pointer reference
// token (RFC 6901), as required by JSON Patch paths.
// Per RFC 6901 section 3, "~" must be escaped to "~0" BEFORE "/" is escaped
// to "~1"; doing it in the other order (or skipping "~") corrupts keys that
// contain a tilde. Needed for "nvidia.com/gpu": otherwise JSON Patch would
// treat "/" as a path delimiter and interpret "gpu" as a nested field.
// refer: https://datatracker.ietf.org/doc/html/rfc6901#section-3
func escapeJSONPointer(s string) string {
	return strings.ReplaceAll(strings.ReplaceAll(s, "~", "~0"), "/", "~1")
}

// removeResourceRequest builds a JSON Patch "remove" operation that deletes
// the resource entry key (already JSON-Pointer-escaped) from either the
// "requests" or "limits" map of container ci.
func removeResourceRequest(ci int, field, key string) patchOperation {
	path := fmt.Sprintf("/spec/containers/%d/resources/%s/%s", ci, field, key)
	return patchOperation{
		Op:   "remove",
		Path: path,
	}
}

// createClaimPatch builds a JSON Patch "add" operation that initializes the
// target path with an empty claims slice, so later "-" appends have an array
// to append to.
func createClaimPatch(path string) patchOperation {
	empty := []map[string]string{}
	return patchOperation{
		Op:    "add",
		Path:  path,
		Value: empty,
	}
}

// appendClaimPatch builds a JSON Patch "add" operation that appends entry to
// the array at path. The trailing "-" token is the RFC 6902 way of inserting
// at the end of an array without specifying an index.
// refer: https://datatracker.ietf.org/doc/html/rfc6902
func appendClaimPatch(path string, entry map[string]string) patchOperation {
	target := path + "/-"
	return patchOperation{
		Op:    "add",
		Path:  target,
		Value: entry,
	}
}

func main() {
certPath := filepath.Join(tlsDir, tlsCertFile)
keyPath := filepath.Join(tlsDir, tlsKeyFile)
Expand All @@ -157,6 +186,10 @@ func main() {
Addr: ":8443",
Handler: mux,
}
log.Printf("Starting webhook server on %s", server.Addr)
log.Fatal(server.ListenAndServeTLS(certPath, keyPath))

if err := server.ListenAndServeTLS(certPath, keyPath); err != nil {
// TODO(swati): improve error handling here (e.g. distinguish graceful shutdown from startup failure).
log.Fatalf("Failed to start server: %v", err)
}
klog.Infof("Started gpu-mutating-webhook server at %s", server.Addr)
}
Loading