Skip to content
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Added

- Add support for launching jobs on instances with GPUs

## [0.14.0] - 2023-12-05

### Authors
Expand Down
25 changes: 24 additions & 1 deletion covalent_gcpbatch_plugin/gcpbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class ExecutorPluginDefaults(BaseModel):
project_id: str = ""
region: str = ""
vcpus: int = 2
num_gpus: int = 0
gpu_type: str = "nvidia-tesla-t4"
memory: int = 512
time_limit: float = 300
poll_freq: int = 5
Expand All @@ -58,6 +60,8 @@ class ExecutorInfraDefaults(BaseModel):
project_id: str = "covalenttesting"
access_token: str = ""
vcpus: Optional[int] = 2
num_gpus: Optional[int] = 0
gpu_type: Optional[str] = ""
memory: Optional[float] = 512
time_limit: Optional[int] = 300
poll_freq: Optional[int] = 5
Expand Down Expand Up @@ -87,6 +91,8 @@ class GCPBatchExecutor(RemoteExecutor):
project_id: Google project ID
region: Google region
vcpus: Number of virtual CPU cores needed by the job
num_gpus: Number of GPUs available to a task
gpu_type: Type of GPU to allocate (see `gcloud compute accelerator-types list`)
        memory: Memory requirement for the job in MB
time_limit: Number of seconds to wait before the job is considered to have failed
        poll_freq: Frequency with which to poll the bucket and job for results
Expand All @@ -106,6 +112,8 @@ def __init__(
project_id: Optional[str] = None,
region: Optional[str] = None,
vcpus: Optional[int] = None,
num_gpus: Optional[int] = None,
gpu_type: Optional[str] = None,
memory: Optional[int] = None,
time_limit: Optional[int] = None,
poll_freq: Optional[int] = None,
Expand All @@ -122,6 +130,8 @@ def __init__(
"executors.gcpbatch.service_account_email"
)
self.vcpus = vcpus or int(get_config("executors.gcpbatch.vcpus"))
self.num_gpus = num_gpus or int(get_config("executors.gcpbatch.num_gpus"))
self.gpu_type = gpu_type or get_config("executors.gcpbatch.gpu_type")
self.memory = memory or int(get_config("executors.gcpbatch.memory"))
self.time_limit = time_limit or int(get_config("executors.gcpbatch.time_limit"))
self.poll_freq = poll_freq or int(get_config("executors.gcpbatch.poll_freq"))
Expand Down Expand Up @@ -296,9 +306,22 @@ def _create_batch_job_sync(
# Create task group
task_group = batch_v1.TaskGroup(task_count=1, task_spec=task_spec)

# Create an InstancePolicyOrTemplate
if self.num_gpus > 0:
accelerators = [
batch_v1.AllocationPolicy.Accelerator(type_=self.gpu_type, count=self.num_gpus)
]
else:
accelerators = []

instance = batch_v1.AllocationPolicy.InstancePolicyOrTemplate(
install_gpu_drivers=self.num_gpus > 0,
policy=batch_v1.AllocationPolicy.InstancePolicy(accelerators=accelerators),
)

# Set job's allocation policies
alloc_policy = batch_v1.AllocationPolicy(
service_account={"email": self.service_account_email}
instances=[instance], service_account={"email": self.service_account_email}
)

# Set the cloud logging policy on the job
Expand Down
9 changes: 7 additions & 2 deletions tests/gcpbatch_executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def test_executor_explicit_constructor(mocker):
service_account_email="test-email",
region="test-region",
vcpus=2,
num_gpus=1,
gpu_type="nvidia-tesla-a100",
memory=256,
time_limit=300,
poll_freq=2,
Expand All @@ -66,6 +68,8 @@ def test_executor_explicit_constructor(mocker):
assert test_executor.service_account_email == "test-email"
assert test_executor.region == "test-region"
assert test_executor.vcpus == 2
assert test_executor.num_gpus == 1
assert test_executor.gpu_type == "nvidia-tesla-a100"
assert test_executor.memory == 256
assert test_executor.time_limit == 300
assert test_executor.poll_freq == 2
Expand All @@ -80,7 +84,7 @@ def test_executor_default_constructor(mocker):
"""
mock_get_config = mocker.patch("covalent_gcpbatch_plugin.gcpbatch.get_config")
GCPBatchExecutor()
assert mock_get_config.call_count == 11
assert mock_get_config.call_count == 13


def test_get_batch_service_client(gcpbatch_executor, mocker):
Expand Down Expand Up @@ -277,7 +281,8 @@ async def test_create_batch_job(gcpbatch_executor, mocker):
task_count=1, task_spec=mock_batch_v1.TaskSpec.return_value
)
mock_batch_v1.AllocationPolicy.assert_called_once_with(
service_account={"email": gcpbatch_executor.service_account_email}
instances=[mock_batch_v1.AllocationPolicy.InstancePolicyOrTemplate.return_value],
service_account={"email": gcpbatch_executor.service_account_email},
)
mock_batch_v1.LogsPolicy.assert_called_once_with(destination="CLOUD_LOGGING")
mock_batch_v1.Job.assert_called_once_with(
Expand Down