@@ -44,10 +44,20 @@ def driver(self):
4444 self ._driver = get_driver (Provider .GCE )("" , "" , project = project_id )
4545 return self ._driver
4646
47+ @property
48+ def zone (self ):
49+ if getattr (self , '_zone' , None ) is None :
50+ r = requests .get ("http://metadata.google.internal/computeMetadata/v1/instance/zone" ,
51+ headers = {"Metadata-Flavor" : "Google" })
52+ self ._zone = r .text
53+ return self ._zone
54+
4755 @property
4856 def name (self ):
4957 if getattr (self , '_name' , None ) is None :
50- self ._name = utils .call (["hostname" ])[1 ].strip ()
58+ r = requests .get ("http://metadata.google.internal/computeMetadata/v1/instance/name" ,
59+ headers = {"Metadata-Flavor" : "Google" })
60+ self ._name = r .text
5161 return self ._name
5262
5363
@@ -69,7 +79,7 @@ def name(self):
6979 @property
7080 def details (self ):
7181 _ , r , _ = utils .call (
72- ["gcloud" , "alpha" , " compute" , "tpus" , "describe" , "--zone={}" .format (self .manager .zone ), self .name ])
82+ ["gcloud" , "compute" , "tpus" , "describe" , "--zone={}" .format (self .manager .zone ), self .name ])
7383 r = r .split ("\n " )
7484 details = dict ()
7585 for line in r :
@@ -220,11 +230,10 @@ def __init__(self, instance):
220230 except :
221231 logger .warn ("Unable to determine Tensorflow version. Assuming 1.15" )
222232 self .tf_version = "1.15"
233+
223234 self .hostname = socket .gethostname ()
224- _ , r , _ = utils .call (["gcloud" , "compute" , "instances" , "list" , "--filter=\" name={}\" " .format (self .hostname )])
225- lines = r .split ("\n " )[1 :]
226- lines = list (filter (lambda l : l != "" , lines ))
227- self .zone = lines [0 ].split ()[1 ]
235+ self .zone = instance .zone .split ('/' )[- 1 ]
236+
228237 from cloud import socket_path
229238 self .lockfile = TPULockFile (os .path .join ("~" , ".tpu_registry" ))
230239 self .refresh ()
@@ -238,7 +247,7 @@ def ips(self):
238247 return [r .ip for r in self .resources ]
239248
240249 def get_all_tpu_names (self ):
241- _ , r , _ = utils .call (["gcloud" , "alpha" , " compute" , "tpus" , "list" , "--zone={}" .format (self .zone )])
250+ _ , r , _ = utils .call (["gcloud" , "compute" , "tpus" , "list" , "--zone={}" .format (self .zone )])
242251 lines = r .split ("\n " )[1 :]
243252 lines = filter (lambda l : l != "" , lines )
244253 names = [l .split ()[0 ] for l in lines ]
@@ -308,7 +317,7 @@ def get(self, preemptible=True, name=None, version='v3-8', zone=None):
308317 def _up (self , name , ip , preemptible , version , zone , background ):
309318 logger .info ("Trying to acquire TPU with name: {} ip: {}" .format (name , ip ))
310319 cmd = [
311- "gcloud" , "alpha" , " compute" , "tpus" , "create" , name , "--range=10.0.{}.0" .format (ip ),
320+ "gcloud" , "compute" , "tpus" , "create" , name , "--range=10.0.{}.0" .format (ip ),
312321 "--accelerator-type={}" .format (version ), "--version={}" .format (self .tf_version ), "--network=default"
313322 ]
314323 if zone :
0 commit comments