4 changes: 4 additions & 0 deletions config.py
@@ -2,6 +2,10 @@
 GITHUB_REPO_URL = "https://github.com/WATonomous/infra-config"
 ALLOCATE_RUNNER_SCRIPT_PATH = "apptainer.sh" # relative path from '/allocation_script'

+# Timeout configurations
+NETWORK_TIMEOUT = 30  # seconds for HTTP requests (GitHub API calls)
+SLURM_COMMAND_TIMEOUT = 60  # seconds for SLURM commands (sbatch, sacct, etc.)
+THREAD_SLEEP_TIMEOUT = 5  # seconds between polling cycles for threads

 REPOS_TO_MONITOR = [
     {
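Two different timeout mechanisms consume these constants, and they fail differently: `requests` raises `requests.exceptions.Timeout`, while `subprocess.run` kills the child process and raises `subprocess.TimeoutExpired`. A standalone sketch of both failure modes (illustrative URL and command, not code from this PR):

```python
# Illustrative only: the two failure modes the new constants guard against.
import subprocess
import requests

NETWORK_TIMEOUT = 30       # mirrors config.py
SLURM_COMMAND_TIMEOUT = 60

try:
    # timeout applies separately to the connect and read phases
    requests.get("https://api.github.com/rate_limit", timeout=NETWORK_TIMEOUT)
except requests.exceptions.Timeout:
    print("HTTP request exceeded its time budget")

try:
    # on expiry, subprocess.run kills the child and raises
    subprocess.run(["sleep", "120"], timeout=SLURM_COMMAND_TIMEOUT)
except subprocess.TimeoutExpired:
    print("command overran its deadline and was killed")
```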
180 changes: 133 additions & 47 deletions main.py
@@ -12,6 +12,7 @@
 from config import ALLOCATE_RUNNER_SCRIPT_PATH, REPOS_TO_MONITOR
 from KubernetesLogFormatter import KubernetesLogFormatter
 from runner_size_config import get_runner_resources
+from config import NETWORK_TIMEOUT, SLURM_COMMAND_TIMEOUT, THREAD_SLEEP_TIMEOUT
 from RunningJob import RunningJob

 logger = logging.getLogger()
@@ -64,7 +65,7 @@ def get_gh_api(url, token, etag=None):
         if etag:
             headers["If-None-Match"] = etag

-        response = requests.get(url, headers=headers)
+        response = requests.get(url, headers=headers, timeout=NETWORK_TIMEOUT)

         response.raise_for_status()
@@ -92,18 +93,28 @@ def get_gh_api(url, token, etag=None):
         else:
             logger.error(f"Unexpected status code: {response.status_code}")
             return None, etag
+    except requests.exceptions.Timeout:
+        logger.error(
+            f"GitHub API request timed out after {NETWORK_TIMEOUT} seconds: {url}"
+        )
+        return None, etag
+    except requests.exceptions.ConnectionError as e:
+        logger.error(f"Connection error while calling GitHub API: {e}")
+        return None, etag
     except requests.exceptions.RequestException as e:
         logger.error(f"Exception occurred while calling GitHub API: {e}")
         return None, etag

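Ordering matters in the handler chain above: `Timeout` and `ConnectionError` are both subclasses of `RequestException`, so the catch-all clause must come last — Python tries `except` clauses top to bottom, and a leading `RequestException` would shadow the specific branches. The hierarchy can be checked directly:

```python
import requests

# Both specific exceptions derive from the catch-all:
assert issubclass(requests.exceptions.Timeout,
                  requests.exceptions.RequestException)
assert issubclass(requests.exceptions.ConnectionError,
                  requests.exceptions.RequestException)
# ConnectTimeout is both a ConnectionError and a Timeout,
# so either specific branch can catch it.
assert issubclass(requests.exceptions.ConnectTimeout,
                  requests.exceptions.ConnectionError)
assert issubclass(requests.exceptions.ConnectTimeout,
                  requests.exceptions.Timeout)
```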

-def poll_github_actions_and_allocate_runners(token, sleep_time=5):
+def poll_github_actions_and_allocate_runners(token, sleep_time=THREAD_SLEEP_TIMEOUT):
     """
     Polls each repository in REPOS_TO_MONITOR for queued workflows, then tries
     to allocate ephemeral runners.
     """
     global POLLED_WITHOUT_ALLOCATING

+    logger.info(f"Starting GitHub Actions polling thread with {sleep_time}s intervals")
+
     while True:
         try:
             something_allocated = False
@@ -219,19 +230,44 @@ def allocate_actions_runner(job_id, token, repo_api_base_url, repo_url, repo_nam…
         reg_url = f"{repo_api_base_url}/actions/runners/registration-token"
         remove_url = f"{repo_api_base_url}/actions/runners/remove-token"

-        reg_resp = requests.post(reg_url, headers=headers)
-        reg_resp.raise_for_status()
-        reg_data = reg_resp.json()
-        registration_token = reg_data["token"]
+        try:
+            reg_resp = requests.post(reg_url, headers=headers, timeout=NETWORK_TIMEOUT)
+            reg_resp.raise_for_status()
+            reg_data = reg_resp.json()
+            registration_token = reg_data["token"]
+            logger.debug("Successfully obtained registration token")
+        except requests.exceptions.Timeout:
+            logger.error(
+                f"Registration token request timed out after {NETWORK_TIMEOUT} seconds"
+            )
+            del allocated_jobs[(repo_name, job_id)]
+            return False
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Failed to get registration token: {e}")
+            del allocated_jobs[(repo_name, job_id)]
+            return False

         # recommended small delay https://docs.github.com/en/rest/using-the-rest-api/best-practices-for-using-the-rest-api?apiVersion=2022-11-28#pause-between-mutative-requests
         time.sleep(1)

         # Get removal token
-        remove_resp = requests.post(remove_url, headers=headers)
-        remove_resp.raise_for_status()
-        remove_data = remove_resp.json()
-        removal_token = remove_data["token"]
+        try:
+            remove_resp = requests.post(
+                remove_url, headers=headers, timeout=NETWORK_TIMEOUT
+            )
+            remove_resp.raise_for_status()
+            remove_data = remove_resp.json()
+            removal_token = remove_data["token"]
+        except requests.exceptions.Timeout:
+            logger.error(
+                f"Removal token request timed out after {NETWORK_TIMEOUT} seconds"
+            )
+            del allocated_jobs[(repo_name, job_id)]
+            return False
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Failed to get removal token: {e}")
+            del allocated_jobs[(repo_name, job_id)]
+            return False

         # Get job details to see labels
         job_api_url = f"{repo_api_base_url}/actions/jobs/{job_id}"
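Both token endpoints above are GitHub's self-hosted runner APIs (`POST .../actions/runners/registration-token` and `.../remove-token`); each returns a JSON body carrying a short-lived `token`, documented to expire after about an hour. A hedged sketch of the exchange — `get_runner_token` is a hypothetical helper, not code from this PR:

```python
import requests

def get_runner_token(repo_api_base_url, token, kind="registration-token"):
    # Hypothetical helper; kind is "registration-token" or "remove-token".
    resp = requests.post(
        f"{repo_api_base_url}/actions/runners/{kind}",
        headers={
            "Authorization": f"Bearer {token}",
            "Accept": "application/vnd.github+json",
        },
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["token"]  # short-lived; expires after ~1 hour
```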
@@ -287,40 +323,55 @@ def allocate_actions_runner(job_id, token, repo_api_base_url, repo_url, repo_nam…
         ]

         logger.info(f"Running command: {' '.join(command)}")
-        result = subprocess.run(command, capture_output=True, text=True)
-        output = result.stdout.strip()
-        error_output = result.stderr.strip()
-        logger.info(f"Command stdout: {output}")
-        if error_output:
-            logger.error(f"Command stderr: {error_output}")
-
-        # Attempt to parse the SLURM job ID from output (e.g. "Submitted batch job 3828")
-        if result.returncode == 0:
-            try:
-                slurm_job_id = int(output.split()[-1])
-                # Store the SLURM job ID in allocated_jobs
-                allocated_jobs[(repo_name, job_id)] = RunningJob(
-                    repo=repo_name,
-                    job_id=job_id,
-                    slurm_job_id=slurm_job_id,
-                    workflow_name=job_data["workflow_name"],
-                    job_name=job_data["name"],
-                    labels=labels,
-                )
-                logger.info(
-                    f"Allocated runner for job {job_id} in {repo_name} with SLURM job ID {slurm_job_id}."
-                )
-                return True
-            except (IndexError, ValueError) as parse_err:
-                logger.error(
-                    f"Failed to parse SLURM job ID from: {output}. Error: {parse_err}"
-                )
-        else:
-            logger.error(f"sbatch command failed with return code {result.returncode}")
-
-        # If we get here, something failed, so remove from tracking and consider retry
-        del allocated_jobs[(repo_name, job_id)]
-        return False
+        try:
+            result = subprocess.run(
+                command, capture_output=True, text=True, timeout=SLURM_COMMAND_TIMEOUT
+            )
+            output = result.stdout.strip()
+            error_output = result.stderr.strip()
+            logger.info(f"Command stdout: {output}")
+            if error_output:
+                logger.error(f"Command stderr: {error_output}")
+
+            # Attempt to parse the SLURM job ID from output (e.g. "Submitted batch job 3828")
+            if result.returncode == 0:
+                try:
+                    slurm_job_id = int(output.split()[-1])
+                    # Store the SLURM job ID in allocated_jobs
+                    allocated_jobs[(repo_name, job_id)] = RunningJob(
+                        repo=repo_name,
+                        job_id=job_id,
+                        slurm_job_id=slurm_job_id,
+                        workflow_name=job_data["workflow_name"],
+                        job_name=job_data["name"],
+                        labels=labels,
+                    )
+                    logger.info(
+                        f"Allocated runner for job {job_id} in {repo_name} with SLURM job ID {slurm_job_id}."
+                    )
+                    return True
+                except (IndexError, ValueError) as parse_err:
+                    logger.error(
+                        f"Failed to parse SLURM job ID from: {output}. Error: {parse_err}"
+                    )
+            else:
+                logger.error(
+                    f"sbatch command failed with return code {result.returncode}"
+                )
+
+            # If we get here, something failed, so remove from tracking and consider retry
+            del allocated_jobs[(repo_name, job_id)]
+            return False
+        except subprocess.TimeoutExpired:
+            logger.error(
+                f"SLURM command timed out after {SLURM_COMMAND_TIMEOUT} seconds: {' '.join(command)}"
+            )
+            del allocated_jobs[(repo_name, job_id)]
+            return False
+        except subprocess.SubprocessError as e:
+            logger.error(f"Subprocess error running SLURM command: {e}")
+            del allocated_jobs[(repo_name, job_id)]
+            return False

     except Exception as e:
         logger.error(f"Exception in allocate_actions_runner for job_id {job_id}: {e}")
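The success path above parses the job ID out of sbatch's human-readable "Submitted batch job 3828" message with `int(output.split()[-1])`. If that ever proves brittle, sbatch's `--parsable` flag prints only the ID (with an optional `;cluster` suffix). A sketch of that alternative — not what this PR does:

```python
import subprocess

def submit_and_get_job_id(sbatch_args, timeout=60):
    # Hypothetical alternative: --parsable prints "jobid" or "jobid;cluster".
    result = subprocess.run(
        ["sbatch", "--parsable", *sbatch_args],
        capture_output=True, text=True, timeout=timeout, check=True,
    )
    return int(result.stdout.strip().split(";")[0])
```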
@@ -353,7 +404,12 @@ def check_slurm_status():
         ]

         try:
-            sacct_result = subprocess.run(sacct_cmd, capture_output=True, text=True)
+            logger.debug(
+                f"Checking SLURM job status for job ID: {running_job.slurm_job_id}"
+            )
+            sacct_result = subprocess.run(
+                sacct_cmd, capture_output=True, text=True, timeout=SLURM_COMMAND_TIMEOUT
+            )
             if sacct_result.returncode != 0:
                 logger.error(
                     f"sacct command failed with return code {sacct_result.returncode}"
@@ -408,6 +464,14 @@ def check_slurm_status():
                 )
                 to_remove.append(job_id)

+        except subprocess.TimeoutExpired:
+            logger.error(
+                f"SLURM status check timed out after {SLURM_COMMAND_TIMEOUT} seconds for job ID {running_job.slurm_job_id}"
+            )
+        except subprocess.SubprocessError as e:
+            logger.error(
+                f"Subprocess error checking SLURM job status for job ID {running_job.slurm_job_id}: {e}"
+            )
         except Exception as e:
             logger.error(
                 f"Error querying SLURM job details for job ID {running_job.slurm_job_id}: {e}"
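For context, `check_slurm_status()` shells out to SLURM's accounting tool; the actual `sacct_cmd` is collapsed in this view. A minimal sketch of the underlying pattern — the flag choices here are assumptions, not the PR's exact command:

```python
import subprocess

def get_slurm_job_state(slurm_job_id, timeout=60):
    # -n drops the header row, -X restricts output to the main allocation.
    result = subprocess.run(
        ["sacct", "-j", str(slurm_job_id), "--format=State", "-n", "-X"],
        capture_output=True, text=True, timeout=timeout,
    )
    if result.returncode != 0 or not result.stdout.strip():
        return None
    return result.stdout.split()[0]  # e.g. "RUNNING", "COMPLETED", "FAILED"
```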
@@ -418,10 +482,12 @@ def check_slurm_status():
         del allocated_jobs[key]


-def poll_slurm_statuses(sleep_time=5):
+def poll_slurm_statuses(sleep_time=THREAD_SLEEP_TIMEOUT):
     """
     Wrapper function to poll check_slurm_status.
     """
+    logger.info(f"Starting SLURM status polling thread with {sleep_time}s intervals")
+
     while True:
         try:
             check_slurm_status()
@@ -431,18 +497,38 @@ def poll_slurm_statuses(sleep_time=5):


 if __name__ == "__main__":
+    logger.info("Starting SLURM GitHub Actions runner with timeout configurations:")
+    logger.info(f"  Network timeout: {NETWORK_TIMEOUT}s")
+    logger.info(f"  SLURM command timeout: {SLURM_COMMAND_TIMEOUT}s")
+    logger.info(f"  Thread sleep timeout: {THREAD_SLEEP_TIMEOUT}s")
+
     # Thread to poll GitHub for new queued workflows
     github_thread = threading.Thread(
-        target=poll_github_actions_and_allocate_runners, args=(GITHUB_ACCESS_TOKEN, 2)
+        target=poll_github_actions_and_allocate_runners,
+        args=(GITHUB_ACCESS_TOKEN, 2),
+        name="GitHub-Poller",
     )

     # Thread to poll SLURM job statuses
     slurm_thread = threading.Thread(
-        target=poll_slurm_statuses, kwargs={"sleep_time": 5}
+        target=poll_slurm_statuses,
+        kwargs={"sleep_time": THREAD_SLEEP_TIMEOUT},
+        name="SLURM-Status-Poller",
     )

+    # Set threads as daemon so they exit when main thread exits
+    github_thread.daemon = True
+    slurm_thread.daemon = True
+
     github_thread.start()
     slurm_thread.start()

-    github_thread.join()
-    slurm_thread.join()
+    try:
+        github_thread.join()
+        slurm_thread.join()
+    except KeyboardInterrupt:
+        logger.info("Received interrupt signal, shutting down gracefully...")
+    except Exception as e:
+        logger.error(f"Unexpected error in main thread: {e}")
+    finally:
+        logger.info("SLURM GitHub Actions runner shutdown complete.")
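One design note on the shutdown path: because both pollers are daemon threads running `while True` loops, the `join()` calls never return normally — they only unblock on interrupt or error, after which the threads are terminated abruptly and the `finally` log line is the sole cleanup. If cooperative shutdown were wanted later, the usual pattern is a shared `threading.Event` — a sketch under that assumption, not part of this PR:

```python
import threading

stop = threading.Event()

def poll_loop(sleep_time=5):
    # Exits cleanly once the main thread sets the event.
    while not stop.is_set():
        # ... one polling cycle ...
        stop.wait(timeout=sleep_time)  # sleeps, but wakes early on shutdown

worker = threading.Thread(target=poll_loop, name="Poller")
worker.start()
try:
    while worker.is_alive():
        worker.join(timeout=1)         # short joins keep Ctrl-C responsive
except KeyboardInterrupt:
    stop.set()
    worker.join()
```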