Skip to content

Commit 027b223

Browse files
authored
Merge pull request #82 from eth-cscs/fcv2
Add FirecREST API V2 support
2 parents 2967889 + 9198ee4 commit 027b223

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

firecrestspawner/spawner.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import asyncio
99
import base64
10-
import firecrest
1110
import hostlist
1211
import httpx
1312
import inspect
@@ -22,6 +21,9 @@
2221
from async_generator import async_generator, yield_
2322
from enum import Enum
2423
from firecrest.FirecrestException import PollingIterException
24+
from firecrest import ClientCredentialsAuth
25+
from firecrest.FirecrestException import UnexpectedStatusException
26+
from firecrest.v2._async.Client import AsyncFirecrest as Firecrest
2527
from jinja2 import Template
2628
from jupyterhub.spawner import Spawner
2729
from time import sleep
@@ -146,6 +148,11 @@ class FirecRESTSpawnerBase(Spawner):
146148
"button availability"
147149
)
148150

151+
workdir = Unicode(
152+
"/home",
153+
help="Directory where the job will be submitted from"
154+
).tag(config=True)
155+
149156
# override default since batch systems typically need longer
150157
start_timeout = Integer(
151158
300,
@@ -306,7 +313,7 @@ async def get_firecrest_client(self):
306313
"log back in to refresh the credentials.")
307314
raise err
308315

309-
client = firecrest.AsyncFirecrest(
316+
client = Firecrest(
310317
firecrest_url=self.firecrest_url, authorization=auth
311318
)
312319

@@ -330,13 +337,13 @@ async def get_firecrest_client_service_account(self):
330337
client_secret = os.environ["SA_CLIENT_SECRET"]
331338
token_url = os.environ["SA_AUTH_TOKEN_URL"]
332339

333-
auth = firecrest.ClientCredentialsAuth(
340+
auth = ClientCredentialsAuth(
334341
client_id,
335342
client_secret,
336343
token_url
337344
)
338345

339-
client = firecrest.AsyncFirecrest(
346+
client = Firecrest(
340347
firecrest_url=self.firecrest_url, authorization=auth
341348
)
342349

@@ -365,8 +372,11 @@ async def firecrest_poll(self):
365372
# to be an empty list
366373
poll_result = []
367374
while poll_result == []:
368-
poll_result = await client.poll(self.host, [self.job_id])
369-
await asyncio.sleep(1)
375+
try:
376+
poll_result = await client.job_info(self.host, self.job_id)
377+
except UnexpectedStatusException as e:
378+
self.log.info(f"Polling job status fail: {e}")
379+
await asyncio.sleep(1)
370380

371381
return poll_result
372382

@@ -407,7 +417,7 @@ async def submit_batch_script(self):
407417
else:
408418
client = await self.get_firecrest_client()
409419

410-
groups = await client.groups(self.host)
420+
groups = await client.userinfo(self.host)
411421
account_from_form = self.user_options.get("account")
412422
if not account_from_form or account_from_form == [""]:
413423
subvars["account"] = groups["group"]["name"]
@@ -419,10 +429,13 @@ async def submit_batch_script(self):
419429
try:
420430
self.log.info("firecREST: Submitting job")
421431
self.job = await client.submit(
422-
self.host, script_str=script, env_vars=job_env
432+
self.host,
433+
script_str=script,
434+
env_vars=job_env,
435+
working_dir="/".join((self.workdir, self.user.name))
423436
)
424437
self.log.debug(f"[client.submit] {self.job}")
425-
self.job_id = f"{self.job['jobid']}"
438+
self.job_id = f"{self.job['jobId']}"
426439
self.log.info(f"Job {self.job_id} submitted")
427440
# In case the connection to the firecrest server times out
428441
# catch httpx.ConnectTimeout since httpx.ConnectTimeout
@@ -454,8 +467,8 @@ async def query_job_status(self):
454467
try:
455468
poll_result = await self.firecrest_poll()
456469
self.log.debug(f"[client.poll] [query_job_status] {poll_result}")
457-
state = poll_result[0]["state"]
458-
nodelist = hostlist.expand_hostlist(poll_result[0]["nodelist"])
470+
state = poll_result[0]["status"]["state"]
471+
nodelist = hostlist.expand_hostlist(poll_result[0]["nodes"])
459472
# when PENDING nodelist is []
460473
host = nodelist[0] if len(nodelist) > 0 else ""
461474
# `job_status` must keep the format used in the original
@@ -487,7 +500,7 @@ async def cancel_batch_job(self) -> None:
487500
client = await self.get_firecrest_client()
488501

489502
self.log.info("firecREST: Canceling job")
490-
cancel_result = await client.cancel(self.host, self.job_id)
503+
cancel_result = await client.cancel_job(self.host, self.job_id)
491504
self.log.debug(f"[client.cancel] {cancel_result}")
492505

493506
def load_state(self, state) -> None:
@@ -676,28 +689,29 @@ async def progress(self) -> AsyncGenerator[dict[str, str], None]:
676689
while True:
677690
if self.state_ispending():
678691
try:
679-
poll_result = await client.poll_active(self.host, [self.job["jobid"]])
692+
poll_result = await client.job_info(self.host, self.job["jobId"])
680693
if poll_result[0]["state"] != "RUNNING":
681-
reason = poll_result[0]["nodelist"]
682-
message = f"Job {self.job['jobid']} is pending in queue {reason} "
694+
reason = poll_result[0]["nodes"]
695+
message = f"Job {self.job['jobId']} is pending in queue {reason} "
683696
else:
684-
message = f"Job {self.job['jobid']} is being allocated"
697+
message = f"Job {self.job['jobId']} is being allocated"
685698

686699
except:
687-
message = f"Job {self.job['jobid']} is pending in queue "
700+
message = f"Job {self.job['jobId']} is pending in queue "
688701

689702
await yield_(
690703
{
691704
"message": message,
692705
}
693706
)
694707
elif self.state_isrunning():
708+
poll_result = await client.job_metadata(self.host, self.job["jobId"])
695709
await yield_(
696710
{
697711
"message": "Cluster job running... waiting to connect. "
698712
"If the server fails to start in a few moments, "
699713
"check the log file for possible reasons: "
700-
f"{self.job['job_file_out']}",
714+
f"{poll_result[0]['standardOutput']}",
701715
}
702716
)
703717
return
@@ -788,7 +802,7 @@ async def state_gethost(self) -> str:
788802

789803
# this function is called only when the job has been allocated,
790804
# then ``nodelist`` won't be ``[]``
791-
host = hostlist.expand_hostlist(poll_result[0]["nodelist"])[0]
805+
host = hostlist.expand_hostlist(poll_result[0]["nodes"])[0]
792806
return self.node_name_template.format(host)
793807

794808

0 commit comments

Comments
 (0)