Open
Description
Here is the code:
class ConnectionPool(object):
    """Thread-safe pool of ``Connection`` objects spread over several hosts.

    ``size`` connections are created up front, assigned to ``hosts`` in
    round-robin order, and kept in a LIFO queue. A thread borrows one via
    the :meth:`connection` context manager; nested ``with pool.connection()``
    blocks in the same thread reuse the connection the outermost block
    borrowed (tracked in a ``threading.local``).

    If the body of a ``with`` block raises ``TException`` or
    ``socket.error``, the borrowed connection is treated as tainted: it is
    closed and replaced by a new connection to a different configured host
    (or the same host when no other is configured), and the exception is
    re-raised to the caller.
    """

    def __init__(self, size, hosts, **kwargs):
        """Create the pool and all of its connections.

        :param size: number of connections to create; must be a positive int.
        :param hosts: a single host name, or a non-empty iterable of host
            names; connections are distributed over them round-robin.
        :param kwargs: extra keyword arguments passed to ``Connection``.
            ``autoconnect`` is always forced to ``False``; ``host`` is set
            per connection.
        :raises TypeError: if ``size`` is not an integer.
        :raises ValueError: if ``size`` is not positive or ``hosts`` is empty.
        """
        if not isinstance(size, int):
            raise TypeError("Pool 'size' arg must be an integer")
        if not size > 0:
            raise ValueError("Pool 'size' arg must be greater than zero")

        # A bare string would otherwise be treated as one host per character.
        if isinstance(hosts, str):
            hosts = [hosts]
        # Materialise once: hosts are indexed below and iterated again on
        # every failover, which a one-shot iterator could not support.
        hosts = list(hosts)
        if not hosts:
            # Previously this surfaced as ZeroDivisionError from ``i % 0``.
            raise ValueError("Pool 'hosts' arg must contain at least one host")

        logger.debug(
            "Initializing connection pool with %d connections", size)

        self._lock = threading.Lock()
        self._queue = queue.LifoQueue(maxsize=size)
        self._thread_connections = threading.local()
        self._hosts = hosts

        # Shared template for every connection the pool creates. It is never
        # given a per-connection 'host'; see _new_connection().
        self.connection_kwargs = kwargs
        self.connection_kwargs['autoconnect'] = False

        for i in range(size):
            self._queue.put(self._new_connection(hosts[i % len(hosts)]))

    def _new_connection(self, host):
        """Return a new ``Connection`` to ``host`` built from the pool kwargs.

        The shared ``connection_kwargs`` dict is copied rather than mutated,
        so concurrent failovers in different threads cannot overwrite each
        other's ``host`` value.
        """
        kwargs = dict(self.connection_kwargs)
        kwargs['host'] = host
        return Connection(**kwargs)

    def _failover_host(self, failed_host):
        """Return the first configured host that differs from ``failed_host``.

        Falls back to ``failed_host`` itself when every configured host equals
        it (e.g. a single-host pool).
        """
        for host in self._hosts:
            if host != failed_host:
                return host
        return failed_host

    @staticmethod
    def _close_quietly(connection):
        """Close ``connection``, logging (not raising) transport errors."""
        try:
            connection.close()
        except (TException, socket.error):
            logger.debug("Ignoring error while closing tainted connection",
                         exc_info=True)

    def _acquire_connection(self, timeout=None):
        """Take a connection from the queue, waiting up to ``timeout`` seconds.

        :raises NoConnectionsAvailable: if none is available within
            ``timeout`` (``None`` waits indefinitely).
        """
        try:
            return self._queue.get(True, timeout)
        except queue.Empty:
            raise NoConnectionsAvailable(
                "No connection available from pool within specified "
                "timeout")

    def _return_connection(self, connection):
        """Put ``connection`` back into the pool's queue."""
        self._queue.put(connection)

    @contextlib.contextmanager
    def connection(self, timeout=None):
        """Borrow a connection for the duration of a ``with`` block.

        The outermost ``with`` in a thread takes a connection from the pool
        (waiting up to ``timeout`` seconds). Nested ``with`` blocks in the
        same thread get that same connection. When the outermost block exits,
        the thread's current connection is closed and returned to the pool.

        NOTE(review): connections are created with ``autoconnect=False`` and
        this method never opens them, yet it closes each one before returning
        it to the pool; presumably callers open the connection themselves —
        confirm.

        :param timeout: seconds to wait for a free connection.
        :raises NoConnectionsAvailable: if no connection frees up in time.
        """
        connection = getattr(self._thread_connections, 'current', None)

        return_after_use = False
        if connection is None:
            # Outermost request in this thread: take a connection from the
            # pool and remember it so nested requests reuse it.
            return_after_use = True
            connection = self._acquire_connection(timeout)
            with self._lock:
                self._thread_connections.current = connection

        try:
            yield connection
        except (TException, socket.error):
            logger.info("Replacing tainted pool connection")
            failed_host = connection.host
            # Release the tainted connection's resources instead of leaking it.
            self._close_quietly(connection)
            connection = self._new_connection(self._failover_host(failed_host))
            with self._lock:
                self._thread_connections.current = connection
            # Re-raise to the caller; see contextlib.contextmanager docs.
            raise
        finally:
            if return_after_use:
                # Use the thread-local: a failover above (or in a nested
                # block) may have replaced the connection we acquired.
                connection = self._thread_connections.current
                del self._thread_connections.current
                try:
                    connection.close()
                finally:
                    # Return it even if close() raised, otherwise the pool
                    # would permanently lose one slot.
                    self._return_connection(connection)
Metadata
Assignees
Labels
No labels