Skip to content

Commit b87b886

Browse files
authored
Add fixes to remove hostname assumption and avoid using alpha in commands (#35)
* Add fixes to remove hostname assumption and avoid * don't alpha describe
1 parent de7f7cb commit b87b886

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

cloud/envs/gcp.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,20 @@ def driver(self):
4444
self._driver = get_driver(Provider.GCE)("", "", project=project_id)
4545
return self._driver
4646

47+
@property
48+
def zone(self):
49+
if getattr(self, '_zone', None) is None:
50+
r = requests.get("http://metadata.google.internal/computeMetadata/v1/instance/zone",
51+
headers={"Metadata-Flavor": "Google"})
52+
self._zone = r.text
53+
return self._zone
54+
4755
@property
4856
def name(self):
4957
if getattr(self, '_name', None) is None:
50-
self._name = utils.call(["hostname"])[1].strip()
58+
r = requests.get("http://metadata.google.internal/computeMetadata/v1/instance/name",
59+
headers={"Metadata-Flavor": "Google"})
60+
self._name = r.text
5161
return self._name
5262

5363

@@ -69,7 +79,7 @@ def name(self):
6979
@property
7080
def details(self):
7181
_, r, _ = utils.call(
72-
["gcloud", "alpha", "compute", "tpus", "describe", "--zone={}".format(self.manager.zone), self.name])
82+
["gcloud", "compute", "tpus", "describe", "--zone={}".format(self.manager.zone), self.name])
7383
r = r.split("\n")
7484
details = dict()
7585
for line in r:
@@ -220,11 +230,10 @@ def __init__(self, instance):
220230
except:
221231
logger.warn("Unable to determine Tensorflow version. Assuming 1.15")
222232
self.tf_version = "1.15"
233+
223234
self.hostname = socket.gethostname()
224-
_, r, _ = utils.call(["gcloud", "compute", "instances", "list", "--filter=\"name={}\"".format(self.hostname)])
225-
lines = r.split("\n")[1:]
226-
lines = list(filter(lambda l: l != "", lines))
227-
self.zone = lines[0].split()[1]
235+
self.zone = instance.zone.split('/')[-1]
236+
228237
from cloud import socket_path
229238
self.lockfile = TPULockFile(os.path.join("~", ".tpu_registry"))
230239
self.refresh()
@@ -238,7 +247,7 @@ def ips(self):
238247
return [r.ip for r in self.resources]
239248

240249
def get_all_tpu_names(self):
241-
_, r, _ = utils.call(["gcloud", "alpha", "compute", "tpus", "list", "--zone={}".format(self.zone)])
250+
_, r, _ = utils.call(["gcloud", "compute", "tpus", "list", "--zone={}".format(self.zone)])
242251
lines = r.split("\n")[1:]
243252
lines = filter(lambda l: l != "", lines)
244253
names = [l.split()[0] for l in lines]
@@ -308,7 +317,7 @@ def get(self, preemptible=True, name=None, version='v3-8', zone=None):
308317
def _up(self, name, ip, preemptible, version, zone, background):
309318
logger.info("Trying to acquire TPU with name: {} ip: {}".format(name, ip))
310319
cmd = [
311-
"gcloud", "alpha", "compute", "tpus", "create", name, "--range=10.0.{}.0".format(ip),
320+
"gcloud", "compute", "tpus", "create", name, "--range=10.0.{}.0".format(ip),
312321
"--accelerator-type={}".format(version), "--version={}".format(self.tf_version), "--network=default"
313322
]
314323
if zone:

0 commit comments

Comments
 (0)