Skip to content

Commit 8f1ede2

Browse files
authored
Merge pull request #4 from omsf/feat/auto-updating-ami
feat: add support for pulling the latest AMI
2 parents 436a717 + 3bb6930 commit 8f1ede2

File tree

5 files changed

+165
-62
lines changed

5 files changed

+165
-62
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ This repository contains the code to start a GitHub Actions runner on an AWS EC2
66
| arch | The AMI architecture | true | x64 |
77
| aws_home_dir | The AWS AMI home directory to use for your runner. Will not start if not specified. | true | |
88
| aws_iam_role | The optional AWS IAM role to assume for provisioning your runner. | false | |
9-
| aws_image_id | The machine AMI to use for your runner. This AMI can be a default but should have docker installed in the AMI. | true | |
9+
| aws_image_id | The machine AMI to use for your runner. This AMI can be a default but should have docker installed in the AMI. If set to `latest`, aws_image_name is required | true | |
10+
| aws_image_name | The name of AMI you want to use, only required if you don't specify `aws_image_id` | false | |
1011
| aws_instance_type | The type of instance to use for your runner. For example: t2.micro, t4g.nano, etc. Will not start if not specified.| true | |
1112
| aws_region_name | The AWS region name to use for your runner. Defaults to AWS_REGION | true | |
1213
| aws_root_device_size | The root device size in GB to use for your runner. | false | The AMI default root disk size |
@@ -46,7 +47,8 @@ jobs:
4647
id: aws-start
4748
uses: omsf/start-aws-gha-runner@v1.0.0
4849
with:
49-
aws_image_id: ami-0f7c4a792e3fb63c8
50+
aws_image_name: "Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 24.04)"
51+
aws_image_id: latest
5052
aws_instance_type: g4dn.xlarge
5153
aws_home_dir: /home/ubuntu
5254
env:

action.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ inputs:
1111
description: "The optional AWS IAM role to assume for provisioning your runner."
1212
required: false
1313
aws_image_id:
14-
description: "The machine AMI to use for your runner. This AMI can be a default but should have docker installed in the AMI. Will not start if not specified."
14+
description: "The machine AMI to use for your runner. This AMI can be a default but should have docker installed in the AMI. Will not start if not specified. If set to `latest`, `aws_image_name` is required"
1515
required: true
16+
aws_image_name:
17+
description: "The AMI name. Only required if `aws_image_id` is set to latest. Be as specific as possible. For example: Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 24.04)"
1618
aws_instance_type:
1719
description: "The type of instance to use for your runner. For example: t2.micro, t4g.nano, etc. Will not start if not specified."
1820
required: true

src/start_aws_gha_runner/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def main():
2020
builder = (
2121
EnvVarBuilder(env)
2222
.update_state("INPUT_AWS_IMAGE_ID", "image_id")
23+
.update_state("INPUT_AWS_IMAGE_NAME", "image_name")
2324
.update_state("INPUT_AWS_INSTANCE_TYPE", "instance_type")
2425
.update_state("INPUT_AWS_SUBNET_ID", "subnet_id")
2526
.update_state("INPUT_AWS_SECURITY_GROUP_ID", "security_group_id")

src/start_aws_gha_runner/start.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class StartAWS(CreateCloudInstance):
5252
repo: str
5353
region_name: str
5454
runner_release: str = ""
55+
image_name: str = ""
5556
tags: list[dict[str, str]] = field(default_factory=list)
5657
gh_runner_tokens: list[str] = field(default_factory=list)
5758
root_device_size: int = 0
@@ -119,8 +120,28 @@ def _build_user_data(self, **kwargs) -> str:
119120
except Exception as e:
120121
raise Exception(f"Error parsing user data template: {e}")
121122

123+
def _fetch_latest_ami(
124+
self, client, ami_name: str, owner: str = "amazon"
125+
) -> str:
126+
out = client.describe_images(
127+
Owners=[owner],
128+
Filters=[
129+
{
130+
"Name": "name",
131+
"Values": [f"{ami_name}*"],
132+
},
133+
{"Name": "state", "Values": ["available"]},
134+
],
135+
)
136+
images = out.get("Images", [])
137+
newest = sorted(images, key=lambda i: i["CreationDate"], reverse=True)[
138+
0
139+
]
140+
141+
return newest["ImageId"]
142+
122143
def _modify_root_disk_size(self, client, params: dict) -> dict:
123-
""" Modify the root disk size of the instance.
144+
"""Modify the root disk size of the instance.
124145
125146
Parameters
126147
----------
@@ -146,11 +167,15 @@ def _modify_root_disk_size(self, client, params: dict) -> dict:
146167
if "DryRunOperation" in str(e):
147168
image_options = client.describe_images(ImageIds=[self.image_id])
148169
root_device_name = image_options["Images"][0]["RootDeviceName"]
149-
block_devices = deepcopy(image_options["Images"][0]["BlockDeviceMappings"])
170+
block_devices = deepcopy(
171+
image_options["Images"][0]["BlockDeviceMappings"]
172+
)
150173
for idx, block_device in enumerate(block_devices):
151174
if block_device["DeviceName"] == root_device_name:
152175
if self.root_device_size > 0:
153-
block_devices[idx]["Ebs"]["VolumeSize"] = self.root_device_size
176+
block_devices[idx]["Ebs"]["VolumeSize"] = (
177+
self.root_device_size
178+
)
154179
params["BlockDeviceMappings"] = block_devices
155180
break
156181
else:
@@ -206,6 +231,14 @@ def create_instances(self) -> dict[str, str]:
206231
"runner_release": self.runner_release,
207232
"labels": labels,
208233
}
234+
# We need to handle the case where someone wants to always use latest
235+
if self.image_id == "latest":
236+
if not self.image_name:
237+
raise ValueError(
238+
"Looking for latest image but name not provided"
239+
)
240+
# This updates the image ID to the latest, will fail if image does not exist
241+
self.image_id = self._fetch_latest_ami(ec2, self.image_name)
209242
params = self._build_aws_params(user_data_params)
210243
if self.root_device_size > 0:
211244
params = self._modify_root_disk_size(ec2, params)

tests/test_start.py

Lines changed: 121 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from moto import mock_aws
3+
from moto.ec2.models import ec2_backends
34
import boto3
45
from unittest.mock import call, patch, mock_open, Mock
56
from start_aws_gha_runner.start import StartAWS
@@ -21,6 +22,26 @@ def aws():
2122
yield StartAWS(**params)
2223

2324

25+
@pytest.fixture(scope="function")
26+
def aws_latest_ami():
27+
with mock_aws():
28+
params = {
29+
"image_id": "latest",
30+
# This comes from https://github.com/getmoto/moto/blob/master/moto/ec2/resources/amis.json
31+
# These are AMIs used to mock out data.
32+
# For more info see here:
33+
# https://docs.getmoto.org/en/latest/docs/services/ec2.html
34+
"image_name": "Ubuntu CUDA9 DLAMI",
35+
"instance_type": "t2.micro",
36+
"region_name": "us-east-1",
37+
"gh_runner_tokens": ["testing"],
38+
"home_dir": "/home/ec2-user",
39+
"runner_release": "testing",
40+
"repo": "omsf-eco-infra/awsinfratesting",
41+
}
42+
yield StartAWS(**params)
43+
44+
2445
def test_build_user_data(aws):
2546
params = {
2647
"homedir": "/home/ec2-user",
@@ -57,6 +78,7 @@ def test_build_user_data_missing_params(aws):
5778
with pytest.raises(Exception):
5879
aws._build_user_data(**params)
5980

81+
6082
@pytest.fixture(scope="function")
6183
def complete_params():
6284
params = {
@@ -74,10 +96,11 @@ def complete_params():
7496
"subnet_id": "test",
7597
"security_group_id": "test",
7698
"iam_role": "test",
77-
"root_device_size": 100
99+
"root_device_size": 100,
78100
}
79101
yield params
80102

103+
81104
def test_build_aws_params(complete_params):
82105
user_data_params = {
83106
"token": "test",
@@ -119,40 +142,37 @@ def test_build_aws_params(complete_params):
119142
],
120143
}
121144

145+
122146
def test_modify_root_disk_size(complete_params):
123147
mock_client = Mock()
124148

125149
# Mock image data with all device mappings
126150
mock_image_data = {
127-
"Images": [{
128-
"RootDeviceName": "/dev/sda1",
129-
"BlockDeviceMappings": [
130-
{
131-
"Ebs": {
132-
"DeleteOnTermination": True,
133-
"VolumeSize": 50,
134-
"VolumeType": "gp3",
135-
"Encrypted": False
151+
"Images": [
152+
{
153+
"RootDeviceName": "/dev/sda1",
154+
"BlockDeviceMappings": [
155+
{
156+
"Ebs": {
157+
"DeleteOnTermination": True,
158+
"VolumeSize": 50,
159+
"VolumeType": "gp3",
160+
"Encrypted": False,
161+
},
162+
"DeviceName": "/dev/sda1",
136163
},
137-
"DeviceName": "/dev/sda1"
138-
},
139-
{
140-
"DeviceName": "/dev/sdb",
141-
"VirtualName": "ephemeral0"
142-
},
143-
{
144-
"DeviceName": "/dev/sdc",
145-
"VirtualName": "ephemeral1"
146-
}
147-
]
148-
}]
164+
{"DeviceName": "/dev/sdb", "VirtualName": "ephemeral0"},
165+
{"DeviceName": "/dev/sdc", "VirtualName": "ephemeral1"},
166+
],
167+
}
168+
]
149169
}
150170

151171
def mock_describe_images(**kwargs):
152-
if kwargs.get('DryRun', False):
172+
if kwargs.get("DryRun", False):
153173
raise ClientError(
154174
error_response={"Error": {"Code": "DryRunOperation"}},
155-
operation_name="DescribeImages"
175+
operation_name="DescribeImages",
156176
)
157177
return mock_image_data
158178

@@ -168,65 +188,57 @@ def mock_describe_images(**kwargs):
168188
"DeleteOnTermination": True,
169189
"VolumeSize": 100,
170190
"VolumeType": "gp3",
171-
"Encrypted": False
172-
}
173-
},
174-
{
175-
"DeviceName": "/dev/sdb",
176-
"VirtualName": "ephemeral0"
191+
"Encrypted": False,
192+
},
177193
},
178-
{
179-
"DeviceName": "/dev/sdc",
180-
"VirtualName": "ephemeral1"
181-
}
194+
{"DeviceName": "/dev/sdb", "VirtualName": "ephemeral0"},
195+
{"DeviceName": "/dev/sdc", "VirtualName": "ephemeral1"},
182196
]
183197
}
184198
assert out == expected_output
185199

200+
186201
def test_modify_root_disk_size_permission_error(complete_params):
187202
mock_client = Mock()
188203

189204
# Mock permission denied error
190205
mock_client.describe_images.side_effect = ClientError(
191-
error_response={'Error': {'Code': 'AccessDenied'}},
192-
operation_name='DescribeImages'
206+
error_response={"Error": {"Code": "AccessDenied"}},
207+
operation_name="DescribeImages",
193208
)
194209

195210
aws = StartAWS(**complete_params)
196211

197212
with pytest.raises(ClientError) as exc_info:
198213
aws._modify_root_disk_size(mock_client, {})
199214

200-
assert 'AccessDenied' in str(exc_info.value)
215+
assert "AccessDenied" in str(exc_info.value)
216+
201217

202218
def test_modify_root_disk_size_no_change(complete_params):
203219
mock_client = Mock()
204220
complete_params["root_device_size"] = 0
205221

206222
mock_image_data = {
207-
"Images": [{
208-
"RootDeviceName": "/dev/sda1",
209-
"BlockDeviceMappings": [
210-
{
211-
"DeviceName": "/dev/sda1",
212-
"Ebs": {
213-
"VolumeSize": 50,
214-
"VolumeType": "gp3"
215-
}
216-
},
217-
{
218-
"DeviceName": "/dev/sdb",
219-
"VirtualName": "ephemeral0"
220-
}
221-
]
222-
}]
223+
"Images": [
224+
{
225+
"RootDeviceName": "/dev/sda1",
226+
"BlockDeviceMappings": [
227+
{
228+
"DeviceName": "/dev/sda1",
229+
"Ebs": {"VolumeSize": 50, "VolumeType": "gp3"},
230+
},
231+
{"DeviceName": "/dev/sdb", "VirtualName": "ephemeral0"},
232+
],
233+
}
234+
]
223235
}
224236

225237
def mock_describe_images(**kwargs):
226-
if kwargs.get('DryRun', False):
238+
if kwargs.get("DryRun", False):
227239
raise ClientError(
228-
error_response={'Error': {'Code': 'DryRunOperation'}},
229-
operation_name='DescribeImages'
240+
error_response={"Error": {"Code": "DryRunOperation"}},
241+
operation_name="DescribeImages",
230242
)
231243
return mock_image_data
232244

@@ -239,6 +251,59 @@ def mock_describe_images(**kwargs):
239251
# With root_device_size = 0, no modifications should be made
240252
assert result == input_params
241253

254+
255+
@pytest.fixture(scope="function")
256+
def complete_params_latest():
257+
params = {
258+
"image_name": "Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 22.04)",
259+
"image_id": "latest",
260+
"instance_type": "t2.micro",
261+
"tags": [
262+
{"Key": "Name", "Value": "test"},
263+
{"Key": "Owner", "Value": "test"},
264+
],
265+
"region_name": "us-east-1",
266+
"gh_runner_tokens": ["testing"],
267+
"home_dir": "/home/ec2-user",
268+
"runner_release": "testing",
269+
"repo": "omsf-eco-infra/awsinfratesting",
270+
"subnet_id": "test",
271+
"security_group_id": "test",
272+
"iam_role": "test",
273+
"root_device_size": 100,
274+
}
275+
yield params
276+
277+
278+
def test_fetch_latest_ami(complete_params_latest):
279+
mock_client = Mock()
280+
281+
mock_image_data = {
282+
"Images": [
283+
{"CreationDate": "2025-08-03", "ImageId": "ami-12345678"},
284+
{"CreationDate": "2025-08-05", "ImageId": "ami-89123456"},
285+
{"CreationDate": "2025-09-05", "ImageId": "ami-89121111"},
286+
]
287+
}
288+
mock_client.describe_images.return_value = mock_image_data
289+
aws = StartAWS(**complete_params_latest)
290+
result = aws._fetch_latest_ami(mock_client, "Test")
291+
assert result == "ami-89121111"
292+
293+
294+
def test_create_instances_latest(aws_latest_ami):
295+
ids = aws_latest_ami.create_instances()
296+
assert len(ids) == 1
297+
298+
299+
def test_create_instatnces_latest_no_name(aws_latest_ami):
300+
aws_latest_ami.image_name = ""
301+
with pytest.raises(
302+
ValueError, match="Looking for latest image but name not provided"
303+
):
304+
aws_latest_ami.create_instances()
305+
306+
242307
def test_create_instance_with_labels(aws):
243308
aws.labels = "test"
244309
ids = aws.create_instances()

0 commit comments

Comments
 (0)