1+ import json
12import logging
3+ import os
24import random
35import re
46import socket
57import string
68import subprocess
9+ from pathlib import Path
710
811import numpy as np
12+ import psutil
913import requests
14+ from filelock import FileLock
1015from libcloud .compute .providers import get_driver
1116from libcloud .compute .types import Provider
1217
@@ -81,7 +86,7 @@ def still_exists(self):
8186
8287 @property
8388 def free (self ):
84- return not self ._in_use
89+ return not self ._in_use and not self . manager . lockfile . check_if_in_use ( self )
8590
8691 @property
8792 def usable (self ):
@@ -131,9 +136,73 @@ def delete(self, background=True):
131136
132137 def in_use (self ):
133138 self ._in_use = True
139+ self .manager .lockfile .register_in_use (self )
134140
135141 def release (self ):
142+ assert self ._in_use
136143 self ._in_use = False
144+ self .manager .lockfile .register_free (self )
145+
146+
147+ class TPULockFile :
148+
149+ def __init__ (self , filepath ):
150+ self .filepath = Path (filepath ).expanduser ()
151+ self .lockpath = Path (filepath + ".lock" ).expanduser ()
152+ self .filelock = FileLock (self .lockpath )
153+
154+ if not self .filepath .exists ():
155+ self .filepath .touch ()
156+ if not self .lockpath .exists ():
157+ self .lockpath .touch ()
158+
159+ def _write_registry (self , registry ):
160+ f = open (self .filepath , "w" )
161+ f .write (json .dumps (registry ))
162+ f .close ()
163+
164+ def register_free (self , tpu ):
165+ with self .filelock :
166+ f = open (self .filepath , "r" )
167+ f_raw = f .read ()
168+ tpu_registry = json .loads (f_raw ) if f_raw else {}
169+ if tpu .name not in tpu_registry :
170+ return
171+
172+ del tpu_registry [tpu .name ]
173+ f .close ()
174+ self ._write_registry (tpu_registry )
175+
176+ def register_in_use (self , tpu ):
177+ with self .filelock :
178+ f = open (self .filepath , "r" )
179+ f_raw = f .read ()
180+ tpu_registry = json .loads (f_raw ) if f_raw else {}
181+ if tpu .name in tpu_registry :
182+ if os .getpid () == tpu_registry [tpu .name ]:
183+ pass
184+ elif psutil .pid_exists (tpu_registry [tpu .name ]):
185+ raise Exception ("TPU is already registered" )
186+ else :
187+ logger .warn (f"Forcefully acquiring TPU { tpu .name } from dead pid { tpu_registry [tpu .name ]} ." )
188+ tpu_registry [tpu .name ] = os .getpid ()
189+ f .close ()
190+ self ._write_registry (tpu_registry )
191+
192+ def check_if_in_use (self , tpu ):
193+ with self .filelock :
194+ f = open (self .filepath , "r" )
195+ f_raw = f .read ()
196+ tpu_registry = json .loads (f_raw ) if f_raw else {}
197+ if tpu .name in tpu_registry :
198+ if psutil .pid_exists (tpu_registry [tpu .name ]):
199+ return True
200+ else :
201+ logger .warn (f"Removing TPU { tpu .name } from dead pid { tpu_registry [tpu .name ]} ." )
202+ del tpu_registry [tpu .name ]
203+ self ._write_registry (tpu_registry )
204+
205+ return False
137206
138207
139208class TPUManager (env .ResourceManager ):
@@ -156,6 +225,8 @@ def __init__(self, instance):
156225 lines = r .split ("\n " )[1 :]
157226 lines = list (filter (lambda l : l != "" , lines ))
158227 self .zone = lines [0 ].split ()[1 ]
228+ from cloud import socket_path
229+ self .lockfile = TPULockFile (os .path .join ("~" , ".tpu_registry" ))
159230 self .refresh ()
160231
161232 @property
0 commit comments