Include VertexAI cluster environment for Fabric #19911

Open
wants to merge 11 commits into
base: master
from
1 change: 1 addition & 0 deletions docs/source-fabric/api/environments.rst
Original file line number Diff line number Diff line change
@@ -23,3 +23,4 @@ Environments
~slurm.SLURMEnvironment
~torchelastic.TorchElasticEnvironment
~xla.XLAEnvironment
~vertexai.VertexAIEnvironment
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/plugins.rst
@@ -115,3 +115,4 @@ You can define the interface of your own cluster environment based on the requir
SLURMEnvironment
TorchElasticEnvironment
XLAEnvironment
VertexAIEnvironment
6 changes: 6 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unReleased] - 2024-05-27

### Added
- `VertexAIEnvironment` to run DDP on [Vertex AI custom training jobs](https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-spec-format)


## [unReleased] - 2024-MM-DD

### Added
1 change: 1 addition & 0 deletions src/lightning/fabric/plugins/environments/__init__.py
@@ -18,4 +18,5 @@
from lightning.fabric.plugins.environments.mpi import MPIEnvironment # noqa: F401
from lightning.fabric.plugins.environments.slurm import SLURMEnvironment # noqa: F401
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment # noqa: F401
from lightning.fabric.plugins.environments.vertexai import VertexAIEnvironment # noqa: F401
from lightning.fabric.plugins.environments.xla import XLAEnvironment # noqa: F401
54 changes: 54 additions & 0 deletions src/lightning/fabric/plugins/environments/vertexai.py
@@ -0,0 +1,54 @@
import os
import json

from lightning.fabric.plugins.environments.lightning import LightningEnvironment


class VertexAIEnvironment(LightningEnvironment):
    """Configures distributed training on a Vertex AI custom training job.

    More information:
    https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-spec-format

    Example:
        Consider a cluster with 3 nodes, each with 2 GPUs.

        The "cluster" key in CLUSTER_SPEC will be:
            {
                'workerpool0': ['cmle-training-workerpool0-d604929a6a-0:2222'],
                'workerpool1': [
                    'cmle-training-workerpool1-d604929a6a-0:2222',
                    'cmle-training-workerpool1-d604929a6a-1:2222'
                ]
            }

        Each scheduled process is described under the "task" key. For the same example,
        the three tasks look like this:
            task 0 ("chief" spawn process) -> node 0:
                {'type': 'workerpool0', 'index': 0}
            task 1 (on the first node of workerpool1) -> node 1:
                {'type': 'workerpool1', 'index': 0}
            task 2 (on the second node of workerpool1) -> node 2:
                {'type': 'workerpool1', 'index': 1}
    """

    def __init__(self):
        super().__init__()
        # CLUSTER_SPEC holds a JSON document describing the cluster and the current task.
        assert "CLUSTER_SPEC" in os.environ
        self.cluster_spec = json.loads(os.environ["CLUSTER_SPEC"])

    @property
    def main_address(self) -> str:
        """Host name of the first (and only) workerpool0 entry, with the port stripped."""
        return self.cluster_spec["cluster"]["workerpool0"][0].split(':')[0]

    @property
    def main_port(self) -> int:
        """Port of the workerpool0 entry, used as the common MASTER_PORT across processes."""
        return int(self.cluster_spec["cluster"]["workerpool0"][0].split(':')[1])

    def node_rank(self) -> int:
        task = self.cluster_spec["task"]
        # The chief (workerpool0) is node 0; any other task is ranked by its pool index, offset by one.
        if task["type"] == "workerpool0":
            return 0
        else:
            return task["index"] + 1
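The mapping from `CLUSTER_SPEC` to main address, main port, and node rank can be checked without Lightning installed. The following is a minimal sketch that mirrors the logic in `vertexai.py`. The helper name `parse_cluster_spec` and the payload `EXAMPLE_CLUSTER_SPEC` are hypothetical and not part of this PR. The payload simply reuses the host names from the docstring example.

```python
import json

# Hypothetical CLUSTER_SPEC payload, shaped like the docstring example:
# one chief node in workerpool0 and two nodes in workerpool1.
EXAMPLE_CLUSTER_SPEC = json.dumps(
    {
        "cluster": {
            "workerpool0": ["cmle-training-workerpool0-d604929a6a-0:2222"],
            "workerpool1": [
                "cmle-training-workerpool1-d604929a6a-0:2222",
                "cmle-training-workerpool1-d604929a6a-1:2222",
            ],
        },
        # This process is the second node of workerpool1.
        "task": {"type": "workerpool1", "index": 1},
    }
)


def parse_cluster_spec(raw: str) -> dict:
    """Derive main address, main port and node rank the same way the PR does."""
    spec = json.loads(raw)
    # The first workerpool0 entry is "host:port" and acts as the rendezvous point.
    host, port = spec["cluster"]["workerpool0"][0].split(":")
    task = spec["task"]
    # Chief is node 0; every other task is its pool index plus one.
    node_rank = 0 if task["type"] == "workerpool0" else task["index"] + 1
    return {"main_address": host, "main_port": int(port), "node_rank": node_rank}


info = parse_cluster_spec(EXAMPLE_CLUSTER_SPEC)
print(info)
```

Note that the `index + 1` offset is applied to every task type other than `workerpool0`, so the ranks are only unique when `workerpool1` is the sole additional pool. A task in any further pool with index 0 would receive the same rank as the first `workerpool1` node.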
1 change: 1 addition & 0 deletions src/lightning/pytorch/plugins/environments/__init__.py
@@ -19,5 +19,6 @@
    MPIEnvironment,
    SLURMEnvironment,
    TorchElasticEnvironment,
    VertexAIEnvironment,
    XLAEnvironment,
)