@@ -30,11 +30,16 @@ func NewK8sWithGpuOperator(client client.Client) *K8sWithGpuOperator {
3030 }
3131}
3232
33- // Discover discovers GPU capacity by iterating over nodes and checking GFD labels.
34- // It queries nodes for each GPU vendor (NVIDIA, AMD, Intel) separately since
35- // Kubernetes LabelSelectors don't support OR logic across different label keys.
36- func (d * K8sWithGpuOperator ) Discover (ctx context.Context ) (map [string ]map [string ]AcceleratorModelInfo , error ) {
37- inv := make (map [string ]map [string ]AcceleratorModelInfo )
33+ // listGPUNodes queries GPU-bearing nodes across all supported vendors
34+ // (NVIDIA, AMD, Intel) and returns a canonical per-node view keyed by node name.
35+ // It queries per vendor because Kubernetes LabelSelectors don't support OR logic
36+ // across different label keys. Multi-vendor nodes (nodes with labels from more
37+ // than one vendor) are merged into a single NodeInfo entry.
38+ //
39+ // This is the single internal node-listing primitive; the methods Discover,
40+ // discoverNodeGPUTypes, and DiscoverNodes all project from its result.
41+ func (d * K8sWithGpuOperator ) listGPUNodes (ctx context.Context ) (map [string ]NodeInfo , error ) {
42+ nodes := make (map [string ]NodeInfo )
3843
3944 // Parse WVA_NODE_SELECTOR once for reuse across vendor queries
4045 var userRequirements []labels.Requirement
@@ -46,12 +51,12 @@ func (d *K8sWithGpuOperator) Discover(ctx context.Context) (map[string]map[strin
4651 userRequirements , _ = userSelector .Requirements ()
4752 }
4853
49- // Query nodes for each GPU vendor separately
50- // K8s LabelSelectors don't support OR logic across different keys (e.g. nvidia OR amd)
54+ // Query nodes for each GPU vendor separately.
55+ // K8s LabelSelectors don't support OR logic across different keys (e.g. nvidia OR amd).
5156 for _ , vendor := range vendors {
5257 prodKey := vendor + "/gpu.product"
58+ memKey := vendor + "/gpu.memory"
5359
54- // Create vendor-specific selector
5560 req , err := labels .NewRequirement (prodKey , selection .Exists , nil )
5661 if err != nil {
5762 return nil , fmt .Errorf ("failed to create label requirement for %s: %w" , vendor , err )
@@ -68,34 +73,69 @@ func (d *K8sWithGpuOperator) Discover(ctx context.Context) (map[string]map[strin
6873 return nil , fmt .Errorf ("failed to list nodes for vendor %s: %w" , vendor , err )
6974 }
7075
71- // Process nodes for this vendor
7276 for _ , node := range nodeList .Items {
73- nodeName := node .Name
74- memKey := vendor + "/gpu.memory"
75-
7677 model , ok := node .Labels [prodKey ]
7778 if ! ok {
7879 continue
7980 }
8081
81- mem := node .Labels [memKey ]
8282 count := 0
8383 if cap , ok := node .Status .Allocatable [corev1 .ResourceName (vendor + "/gpu" )]; ok {
8484 count = int (cap .Value ())
8585 }
8686
87- if inv [nodeName ] == nil {
88- inv [nodeName ] = make (map [string ]AcceleratorModelInfo )
87+ ni , exists := nodes [node .Name ]
88+ if ! exists {
89+ ni = NodeInfo {
90+ Name : node .Name ,
91+ Labels : copyStringMap (node .Labels ),
92+ Accelerators : make (map [string ]AcceleratorModelInfo ),
93+ }
8994 }
90-
91- inv [nodeName ][model ] = AcceleratorModelInfo {
95+ ni .Accelerators [model ] = AcceleratorModelInfo {
9296 Count : count ,
93- Memory : mem ,
97+ Memory : node . Labels [ memKey ] ,
9498 }
99+ nodes [node .Name ] = ni
95100 }
96101 }
97102
98- return inv , nil
103+ return nodes , nil
104+ }
105+
106+ // copyStringMap returns a shallow copy of m, or an empty map if m is nil.
107+ // Used to ensure the labels map returned in NodeInfo is independent of the
108+ // underlying corev1.Node object.
109+ func copyStringMap (m map [string ]string ) map [string ]string {
110+ out := make (map [string ]string , len (m ))
111+ for k , v := range m {
112+ out [k ] = v
113+ }
114+ return out
115+ }
116+
117+ // Discover discovers GPU capacity by iterating over nodes and checking GFD labels.
118+ // It queries nodes for each GPU vendor (NVIDIA, AMD, Intel) separately since
119+ // Kubernetes LabelSelectors don't support OR logic across different label keys.
120+ //
121+ // This is a projection of listGPUNodes into the CapacityDiscovery shape
122+ // (per-node accelerator map without labels).
123+ func (d * K8sWithGpuOperator ) Discover (ctx context.Context ) (map [string ]map [string ]AcceleratorModelInfo , error ) {
124+ nodes , err := d .listGPUNodes (ctx )
125+ if err != nil {
126+ return nil , err
127+ }
128+ out := make (map [string ]map [string ]AcceleratorModelInfo , len (nodes ))
129+ for name , n := range nodes {
130+ out [name ] = n .Accelerators
131+ }
132+ return out , nil
133+ }
134+
135+ // DiscoverNodes returns per-node info (labels + accelerators) for all GPU-bearing
136+ // nodes. Used by label-aware features such as the namespace-scoped limiter.
137+ func (d * K8sWithGpuOperator ) DiscoverNodes (ctx context.Context ) (map [string ]NodeInfo , error ) {
138+ return d .listGPUNodes (ctx )
99139}
100140
101141// DiscoverUsage calculates current GPU usage by summing GPU requests from running pods.
@@ -143,48 +183,38 @@ func (d *K8sWithGpuOperator) DiscoverUsage(ctx context.Context) (map[string]int,
143183}
144184
145185// discoverNodeGPUTypes returns a map of node name to GPU type (model name).
146- // It queries nodes for each GPU vendor separately to support multi-vendor clusters.
186+ // For multi-vendor nodes (nodes labeled for more than one GPU vendor), the model
187+ // from the LAST matching vendor in `vendors` wins (order: nvidia.com, amd.com,
188+ // intel.com → intel wins if present, else amd, else nvidia).
189+ //
190+ // This preserves the pre-refactor behavior: the original implementation iterated
191+ // vendors in order and assigned `nodeGPUType[node] = model` on each match,
192+ // causing later assignments to overwrite earlier ones. Changing this would
193+ // silently affect usage attribution for multi-vendor nodes, so the refactor
194+ // keeps the exact same tie-break semantics.
195+ //
196+ // This is a projection of listGPUNodes into a single-model-per-node shape.
147197func (d * K8sWithGpuOperator ) discoverNodeGPUTypes (ctx context.Context ) (map [string ]string , error ) {
148- nodeGPUType := make (map [string ]string )
149-
150- // Parse WVA_NODE_SELECTOR once for reuse across vendor queries
151- var userRequirements []labels.Requirement
152- if selectorStr := os .Getenv ("WVA_NODE_SELECTOR" ); selectorStr != "" {
153- userSelector , err := labels .Parse (selectorStr )
154- if err != nil {
155- return nil , fmt .Errorf ("invalid WVA_NODE_SELECTOR: %w" , err )
156- }
157- userRequirements , _ = userSelector .Requirements ()
198+ nodes , err := d .listGPUNodes (ctx )
199+ if err != nil {
200+ return nil , err
158201 }
159202
160- // Query nodes for each GPU vendor separately
161- for _ , vendor := range vendors {
162- prodKey := vendor + "/gpu.product"
163-
164- req , err := labels .NewRequirement (prodKey , selection .Exists , nil )
165- if err != nil {
166- return nil , fmt .Errorf ("failed to create label requirement for %s: %w" , vendor , err )
167- }
168- selector := labels .NewSelector ().Add (* req )
169-
170- // Add user requirements for sharding
171- for _ , userReq := range userRequirements {
172- selector = selector .Add (userReq )
173- }
174-
175- var nodeList corev1.NodeList
176- if err := d .Client .List (ctx , & nodeList , & client.ListOptions {LabelSelector : selector }); err != nil {
177- return nil , fmt .Errorf ("failed to list nodes for vendor %s: %w" , vendor , err )
178- }
179-
180- for _ , node := range nodeList .Items {
181- if model , ok := node .Labels [prodKey ]; ok {
182- nodeGPUType [node .Name ] = model
203+ out := make (map [string ]string , len (nodes ))
204+ for name , n := range nodes {
205+ // Iterate vendors in REVERSE order and break on first match so the
206+ // last vendor in `vendors` wins (intel > amd > nvidia). Relies on the
207+ // listGPUNodes invariant that n.Accelerators[model] exists whenever
208+ // n.Labels[vendor+"/gpu.product"] == model.
209+ for i := len (vendors ) - 1 ; i >= 0 ; i -- {
210+ prodKey := vendors [i ] + "/gpu.product"
211+ if model , ok := n .Labels [prodKey ]; ok {
212+ out [name ] = model
213+ break
183214 }
184215 }
185216 }
186-
187- return nodeGPUType , nil
217+ return out , nil
188218}
189219
190220// getPodGPURequests returns the total GPU requests for a pod across all containers.
0 commit comments