44import logging
55from dataclasses import dataclass
66from datetime import datetime
7- from typing import Dict , List
7+ from typing import Dict , List , Optional
88
99from hostlist import expand_hostlist
1010from simple_parsing import field
@@ -25,7 +25,7 @@ class AcquireSlurmConfig:
2525 required = False ,
2626 help = (
2727 "Cluster config file date (format YYYY-MM-DD). "
28- "Used for file versioning. Should represents a day when config file has been updated "
28+ "Used for file versioning. Should represent a day when config file has been updated "
2929 "(e.g. for new GPU billings, node GPUs, etc.). "
3030 "If not specified, uses current day and downloads config file from cluster."
3131 ),
@@ -43,7 +43,7 @@ def execute(self) -> int:
4343 self .cluster_name , parser .day , slurm_conf .gpu_to_billing
4444 )
4545 _node_gpu_mapping_collection ().save_node_gpu_mapping (
46- self .cluster_name , parser .day , slurm_conf .node_to_gpu
46+ self .cluster_name , parser .day , slurm_conf .node_to_gpus
4747 )
4848 return 0
4949
@@ -98,17 +98,18 @@ def _cache_key(self):
9898 def load (self , file ) -> SlurmConfig :
9999 """
100100 Parse cached slurm conf file and return a SlurmConfig object
101- containing node_to_gpu and gpu_to_billing.
101+ containing node_to_gpus and gpu_to_billing.
102102 """
103103 partitions : List [Partition ] = []
104- node_to_gpu = {}
104+ node_to_gpus : Dict [ str , List [ str ]] = {}
105105
106- # Parse lines: extract partitions and node_to_gpu
106+ # Parse lines: extract partitions and node_to_gpus
107107 for line_number , line in enumerate (file ):
108108 line = line .strip ()
109109 if line .startswith ("PartitionName=" ):
110110 partitions .append (
111111 Partition (
112+ cluster_name = self .cluster .name ,
112113 line_number = line_number + 1 ,
113114 line = line ,
114115 info = dict (
@@ -120,43 +121,40 @@ def load(self, file) -> SlurmConfig:
120121 nodes_config = dict (
121122 option .split ("=" , maxsplit = 1 ) for option in line .split ()
122123 )
123- gpu_type = nodes_config .get ("Gres" )
124- if gpu_type :
125- node_to_gpu .update (
124+ gres = nodes_config .get ("Gres" )
125+ if gres :
126+ # A node may have many GPUs, e.g. MIG GPUs
127+ # Example on narval (2023-11-28):
128+ # NodeName=ng20304 ... Gres=gpu:a100_1g.5gb:8,gpu:a100_2g.10gb:4,gpu:a100_3g.20gb:4
129+ gpu_types = gres .split ("," )
130+ node_to_gpus .update (
126131 {
127- node_name : gpu_type
132+ node_name : gpu_types
128133 for node_name in expand_hostlist (nodes_config ["NodeName" ])
129134 }
130135 )
131136
132137 # Parse partitions: extract gpu_to_billing
133- gpu_to_billing = self ._parse_gpu_to_billing (partitions , node_to_gpu )
138+ gpu_to_billing = self ._parse_gpu_to_billing (partitions , node_to_gpus )
134139
135140 # Return parsed data
136- return SlurmConfig (node_to_gpu = node_to_gpu , gpu_to_billing = gpu_to_billing )
141+ return SlurmConfig (node_to_gpus = node_to_gpus , gpu_to_billing = gpu_to_billing )
137142
138143 def _parse_gpu_to_billing (
139- self , partitions : List [Partition ], node_to_gpu : Dict [str , str ]
144+ self , partitions : List [Partition ], node_to_gpus : Dict [str , List [ str ] ]
140145 ) -> Dict [str , float ]:
141-
142146 # Mapping of GPU to partition billing.
143- # ALlow to check that inferred billing for a GPU is the same across partitions.
147+ # Allow to check that inferred billing for a GPU is the same across partitions.
144148 # If not, an error will be raised with additional info about involved partitions.
145149 gpu_to_partition_billing : Dict [str , PartitionGPUBilling ] = {}
146150
147- # Collection for all GPUs found in partition nodes.
148- # We will later iterate on this collection to resolve any GPU without billing.
149- all_partition_node_gpus = set ()
150-
151151 for partition in partitions :
152- # Get all GPUs in partition nodes and all partition GPU billings.
153- local_gres , local_gpu_to_billing = (
154- partition .get_gpus_and_partition_billings (node_to_gpu )
152+ # Get billing from this partition
153+ parsed_partition = partition .parse (node_to_gpus )
154+ local_gpu_to_billing = parsed_partition .get_harmonized_gpu_to_billing (
155+ self .cluster
155156 )
156157
157- # Merge local GPUs into global partition node GPUs.
158- all_partition_node_gpus .update (local_gres )
159-
160158 # Merge local GPU billings into global GPU billings
161159 for gpu_type , value in local_gpu_to_billing .items ():
162160 new_billing = PartitionGPUBilling (
@@ -170,59 +168,47 @@ def _parse_gpu_to_billing(
170168 raise InconsistentGPUBillingError (
171169 gpu_type , gpu_to_partition_billing [gpu_type ], new_billing
172170 )
173-
174- # Generate GPU->billing mapping
175- gpu_to_billing = {
176- gpu_type : billing .value
177- for gpu_type , billing in gpu_to_partition_billing .items ()
178- }
179-
180- # Resolve GPUs without billing
181- for gpu_desc in all_partition_node_gpus :
182- if gpu_desc not in gpu_to_billing :
183- if gpu_desc .startswith ("gpu:" ) and gpu_desc .count (":" ) == 2 :
184- # GPU resource with format `gpu:<real-gpu-type>:<gpu-count>`
185- _ , gpu_type , gpu_count = gpu_desc .split (":" )
186- if gpu_type in gpu_to_billing :
187- billing = gpu_to_billing [gpu_type ] * int (gpu_count )
188- gpu_to_billing [gpu_desc ] = billing
189- logger .info (f"Inferred billing for { gpu_desc } := { billing } " )
190- else :
191- logger .warning (
192- f"Cannot find GPU billing for GPU type { gpu_type } in GPU resource { gpu_desc } "
193- )
194- else :
195- logger .warning (f"Cannot infer billing for GPU: { gpu_desc } " )
196-
197- # We can finally return GPU->billing mapping.
198- return gpu_to_billing
171+ return {gpu : billing .value for gpu , billing in gpu_to_partition_billing .items ()}
199172
200173
201174@dataclass
202175class SlurmConfig :
203176 """Parsed data from slurm config file"""
204177
205- node_to_gpu : Dict [str , str ]
178+ node_to_gpus : Dict [str , List [ str ] ]
206179 gpu_to_billing : Dict [str , float ]
207180
208181
209182@dataclass
210183class Partition :
211184 """Partition entry in slurm config file"""
212185
186+ cluster_name : str
213187 line_number : int
214188 line : str
215189 info : Dict [str , str ]
216190
217- def get_gpus_and_partition_billings (self , node_to_gpu : Dict [str , str ]):
218- """
219- Parse and return:
220- - partition node GPUs
221- - partition GPU billings inferred from field `TRESBillingWeights`
222- """
191+ @property
192+ def nodes (self ) -> str :
193+ """Return hostlist of partition nodes"""
194+ return self .info ["Nodes" ]
195+
196+ def message (self , msg : str ) -> str :
197+ """For logging: prepend given message with cluster name and partition name"""
198+ return f"[{ self .cluster_name } ][{ self .info ['PartitionName' ]} ] { msg } "
199+
200+ def parse (self , node_to_gpus : Dict [str , List [str ]]) -> ParsedPartition :
201+ """Parse partition's gpu => nodes, gpu => billing, and default billing"""
223202
224- # Get partition node GPUs
225- local_gres = self ._get_node_gpus (node_to_gpu )
203+ # Map each partition GPU to belonging nodes
204+ gpu_to_nodes = {}
205+ for node_name in expand_hostlist (self .nodes ):
206+ for gpu_type in node_to_gpus .get (node_name , ()):
207+ # Parse `gpu:<real gpu name>:<count>` if necessary
208+ if gpu_type .startswith ("gpu:" ) and gpu_type .count (":" ) == 2 :
209+ _ , real_gpu_type , _ = gpu_type .split (":" )
210+ gpu_type = real_gpu_type
211+ gpu_to_nodes .setdefault (gpu_type , []).append (node_name )
226212
227213 # Get GPU weights from TRESBillingWeights
228214 tres_billing_weights = dict (
@@ -236,48 +222,107 @@ def get_gpus_and_partition_billings(self, node_to_gpu: Dict[str, str]):
236222 if key .startswith ("GRES/gpu" )
237223 }
238224
239- # Parse local GPU billings
240- local_gpu_to_billing = {}
225+ # Parse partition GPU billings
226+ default_billing = None
227+ gpu_to_billing = {}
241228 for key , value in gpu_weights .items ():
242229 value = float (value )
243230 if key == "GRES/gpu" :
244231 if len (gpu_weights ) == 1 :
245232 # We only have `GRES/gpu=<value>`
246- # Let's map value to each GPU found in partition nodes
247- local_gpu_to_billing .update (
248- {gpu_name : value for gpu_name in local_gres }
249- )
233+ # Save it as default billing for all partition GPUs
234+ default_billing = value
250235 else :
251236 # We have both `GRES/gpu=<value>` and at least one `GRES/gpu:a_gpu=a_value`.
252237 # Ambiguous case, cannot map `GRES/gpu=<value>` to a specific GPU.
253238 logger .debug (
254- f"[line { self .line_number } ] "
255- f"Ignored ambiguous GPU billing (cannot match to a specific GPU): `{ key } ={ value } ` "
256- f"| partition: { self .info ['PartitionName' ]} "
257- # f"| nodes: {partition.info['Nodes']}, "
258- f"| nodes GPUs: { ', ' .join (local_gres )} "
259- f"| TRESBillingWeights: { self .info ['TRESBillingWeights' ]} "
239+ self .message (
240+ f"Ignored ambiguous GPU billing (cannot match to a specific GPU): `{ key } ={ value } ` "
241+ f"| nodes GPUs: { ', ' .join (sorted (gpu_to_nodes .keys ()))} "
242+ f"| TRESBillingWeights: { self .info ['TRESBillingWeights' ]} "
243+ )
260244 )
261245 else :
262246 # We have `GRES/gpu:a_gpu=a_value`.
263247 # We can parse.
264- _ , gpu_name = key .split (":" , maxsplit = 1 )
265- local_gpu_to_billing [gpu_name ] = value
266-
267- return local_gres , local_gpu_to_billing
268-
269- def _get_node_gpus (self , node_to_gpu : Dict [str , str ]) -> List [str ]:
270- """Return all GPUs found in partition nodes"""
271- return sorted (
272- {
273- gres
274- for node_name in expand_hostlist (self .info ["Nodes" ])
275- for gres in node_to_gpu .get (node_name , "" ).split ("," )
276- if gres
277- }
248+ _ , gpu_type = key .split (":" , maxsplit = 1 )
249+ gpu_to_billing [gpu_type ] = value
250+
251+ return ParsedPartition (
252+ partition = self ,
253+ gpu_to_nodes = gpu_to_nodes ,
254+ gpu_to_billing = gpu_to_billing ,
255+ default_billing = default_billing ,
278256 )
279257
280258
259+ @dataclass
260+ class ParsedPartition :
261+ partition : Partition
262+ gpu_to_nodes : Dict [str , List [str ]]
263+ gpu_to_billing : Dict [str , float ]
264+ default_billing : Optional [float ]
265+
266+ def get_harmonized_gpu_to_billing (self , cluster : ClusterConfig ) -> Dict [str , float ]:
267+ """
268+ Convert GPU names read from slurm conf file into harmonized GPU names.
269+
270+ Return harmonized GPU => billing mapping.
271+ """
272+
273+ gpu_to_billing = self .gpu_to_billing .copy ()
274+ gpus_not_billed = [
275+ gpu for gpu in self .gpu_to_nodes if gpu not in gpu_to_billing
276+ ]
277+ # If default billing is available,
278+ # set it as billing for all GPUs not yet billed in this partition.
279+ if self .default_billing is not None :
280+ for gpu_type in gpus_not_billed :
281+ gpu_to_billing [gpu_type ] = self .default_billing
282+
283+ # Build harmonized GPU => billing mapping.
284+ harmonized_gpu_to_billing = {}
285+ for gpu , billing in gpu_to_billing .items ():
286+ if gpu in self .gpu_to_nodes :
287+ harmonized_gpu_names = {
288+ cluster .harmonize_gpu (node_name , gpu )
289+ for node_name in self .gpu_to_nodes [gpu ]
290+ }
291+ harmonized_gpu_names .discard (None )
292+ if not harmonized_gpu_names :
293+ logger .warning (
294+ self .partition .message (
295+ f"Cannot harmonize: { gpu } (keep this name as-is) : { self .partition .nodes } "
296+ )
297+ )
298+ harmonized_gpu_to_billing [gpu ] = billing
299+ else :
300+ if len (harmonized_gpu_names ) != 1 :
301+ # We may find many harmonized names for a same GPU name.
302+ # Example on graham (2024-04-03), partition gpubase_bynode_b1:
303+ # v100 => {'V100-SXM2-32GB', 'V100-PCIe-16GB'}
304+ # Let's just associate billing to all harmonized names
305+ logger .debug (
306+ self .partition .message (
307+ f"harmonize to multiple names: { gpu } => { harmonized_gpu_names } : { self .partition .nodes } "
308+ )
309+ )
310+ for name in sorted (harmonized_gpu_names ):
311+ assert name not in harmonized_gpu_to_billing , (
312+ name ,
313+ billing ,
314+ harmonized_gpu_to_billing ,
315+ )
316+ harmonized_gpu_to_billing [name ] = billing
317+ else :
318+ logger .warning (
319+ self .partition .message (
320+ f"GPU not in partition nodes: { gpu } (billing: { billing } )"
321+ )
322+ )
323+ return harmonized_gpu_to_billing
324+
325+
281326@dataclass
282327class PartitionGPUBilling :
283328 """Represents a GPU billing found in a specific partition entry."""
0 commit comments