77
88import asyncio
99import base64
10- import firecrest
1110import hostlist
1211import httpx
1312import inspect
2221from async_generator import async_generator , yield_
2322from enum import Enum
2423from firecrest .FirecrestException import PollingIterException
24+ from firecrest import ClientCredentialsAuth
25+ from firecrest .FirecrestException import UnexpectedStatusException
26+ from firecrest .v2 ._async .Client import AsyncFirecrest as Firecrest
2527from jinja2 import Template
2628from jupyterhub .spawner import Spawner
2729from time import sleep
@@ -146,6 +148,11 @@ class FirecRESTSpawnerBase(Spawner):
146148 "button availability"
147149 )
148150
151+ workdir = Unicode (
152+ "/home" ,
153+ help = "Directory where the job will be submitted from"
154+ ).tag (config = True )
155+
149156 # override default since batch systems typically need longer
150157 start_timeout = Integer (
151158 300 ,
@@ -306,7 +313,7 @@ async def get_firecrest_client(self):
306313 "log back in to refresh the credentials." )
307314 raise err
308315
309- client = firecrest . AsyncFirecrest (
316+ client = Firecrest (
310317 firecrest_url = self .firecrest_url , authorization = auth
311318 )
312319
@@ -330,13 +337,13 @@ async def get_firecrest_client_service_account(self):
330337 client_secret = os .environ ["SA_CLIENT_SECRET" ]
331338 token_url = os .environ ["SA_AUTH_TOKEN_URL" ]
332339
333- auth = firecrest . ClientCredentialsAuth (
340+ auth = ClientCredentialsAuth (
334341 client_id ,
335342 client_secret ,
336343 token_url
337344 )
338345
339- client = firecrest . AsyncFirecrest (
346+ client = Firecrest (
340347 firecrest_url = self .firecrest_url , authorization = auth
341348 )
342349
@@ -365,8 +372,11 @@ async def firecrest_poll(self):
365372 # to be an empty list
366373 poll_result = []
367374 while poll_result == []:
368- poll_result = await client .poll (self .host , [self .job_id ])
369- await asyncio .sleep (1 )
375+ try :
376+ poll_result = await client .job_info (self .host , self .job_id )
377+ except UnexpectedStatusException as e :
378+ self .log .info (f"Polling job status fail: { e } " )
379+ await asyncio .sleep (1 )
370380
371381 return poll_result
372382
@@ -407,7 +417,7 @@ async def submit_batch_script(self):
407417 else :
408418 client = await self .get_firecrest_client ()
409419
410- groups = await client .groups (self .host )
420+ groups = await client .userinfo (self .host )
411421 account_from_form = self .user_options .get ("account" )
412422 if not account_from_form or account_from_form == ["" ]:
413423 subvars ["account" ] = groups ["group" ]["name" ]
@@ -419,10 +429,13 @@ async def submit_batch_script(self):
419429 try :
420430 self .log .info ("firecREST: Submitting job" )
421431 self .job = await client .submit (
422- self .host , script_str = script , env_vars = job_env
432+ self .host ,
433+ script_str = script ,
434+ env_vars = job_env ,
435+ working_dir = "/" .join ((self .workdir , self .user .name ))
423436 )
424437 self .log .debug (f"[client.submit] { self .job } " )
425- self .job_id = f"{ self .job ['jobid ' ]} "
438+ self .job_id = f"{ self .job ['jobId ' ]} "
426439 self .log .info (f"Job { self .job_id } submitted" )
427440 # In case the connection to the firecrest server times out
428441 # catch httpx.ConnectTimeout since httpx.ConnectTimeout
@@ -454,8 +467,8 @@ async def query_job_status(self):
454467 try :
455468 poll_result = await self .firecrest_poll ()
456469 self .log .debug (f"[client.poll] [query_job_status] { poll_result } " )
457- state = poll_result [0 ]["state" ]
458- nodelist = hostlist .expand_hostlist (poll_result [0 ]["nodelist " ])
470+ state = poll_result [0 ]["status" ][ " state" ]
471+ nodelist = hostlist .expand_hostlist (poll_result [0 ]["nodes " ])
459472 # when PENDING nodelist is []
460473 host = nodelist [0 ] if len (nodelist ) > 0 else ""
461474 # `job_status` must keep the format used in the original
@@ -487,7 +500,7 @@ async def cancel_batch_job(self) -> None:
487500 client = await self .get_firecrest_client ()
488501
489502 self .log .info ("firecREST: Canceling job" )
490- cancel_result = await client .cancel (self .host , self .job_id )
503+ cancel_result = await client .cancel_job (self .host , self .job_id )
491504 self .log .debug (f"[client.cancel] { cancel_result } " )
492505
493506 def load_state (self , state ) -> None :
@@ -676,28 +689,29 @@ async def progress(self) -> AsyncGenerator[dict[str, str], None]:
676689 while True :
677690 if self .state_ispending ():
678691 try :
679- poll_result = await client .poll_active (self .host , [ self .job ["jobid" ] ])
692+ poll_result = await client .job_info (self .host , self .job ["jobId" ])
680693 if poll_result [0 ]["state" ] != "RUNNING" :
681- reason = poll_result [0 ]["nodelist " ]
682- message = f"Job { self .job ['jobid ' ]} is pending in queue { reason } "
694+ reason = poll_result [0 ]["nodes " ]
695+ message = f"Job { self .job ['jobId ' ]} is pending in queue { reason } "
683696 else :
684- message = f"Job { self .job ['jobid ' ]} is being allocated"
697+ message = f"Job { self .job ['jobId ' ]} is being allocated"
685698
686699 except :
687- message = f"Job { self .job ['jobid ' ]} is pending in queue "
700+ message = f"Job { self .job ['jobId ' ]} is pending in queue "
688701
689702 await yield_ (
690703 {
691704 "message" : message ,
692705 }
693706 )
694707 elif self .state_isrunning ():
708+ poll_result = await client .job_metadata (self .host , self .job ["jobId" ])
695709 await yield_ (
696710 {
697711 "message" : "Cluster job running... waiting to connect. "
698712 "If the server fails to start in a few moments, "
699713 "check the log file for possible reasons: "
700- f"{ self . job [ 'job_file_out ' ]} " ,
714+ f"{ poll_result [ 0 ][ 'standardOutput ' ]} " ,
701715 }
702716 )
703717 return
@@ -788,7 +802,7 @@ async def state_gethost(self) -> str:
788802
789803 # this function is called only when the job has been allocated,
790804 # then ``nodelist`` won't be ``[]``
791- host = hostlist .expand_hostlist (poll_result [0 ]["nodelist " ])[0 ]
805+ host = hostlist .expand_hostlist (poll_result [0 ]["nodes " ])[0 ]
792806 return self .node_name_template .format (host )
793807
794808
0 commit comments