diff --git a/opensearchpy/_async/http_aiohttp.py b/opensearchpy/_async/http_aiohttp.py index a3c98c1a5..d5d5a152d 100644 --- a/opensearchpy/_async/http_aiohttp.py +++ b/opensearchpy/_async/http_aiohttp.py @@ -29,7 +29,7 @@ import os import ssl import warnings -from typing import Any, Collection, Mapping, Optional, Union +from typing import Any, Callable, Collection, Mapping, Optional, Union import urllib3 @@ -146,9 +146,14 @@ def __init__( ) if http_auth is not None: - if isinstance(http_auth, (tuple, list)): - http_auth = ":".join(http_auth) - self.headers.update(urllib3.make_headers(basic_auth=http_auth)) + if isinstance(http_auth, Callable): # type: ignore + pass + elif isinstance(http_auth, (tuple, list)): + self.headers.update( + urllib3.make_headers(basic_auth=":".join(http_auth)) + ) + else: + self.headers.update(urllib3.make_headers(basic_auth=http_auth)) # if providing an SSL context, raise error if any other SSL related flag is used if ssl_context and ( @@ -285,6 +290,9 @@ async def perform_request( if headers: req_headers.update(headers) + if isinstance(self._http_auth, Callable): # type: ignore + req_headers.update(self._http_auth(method, str(url), body)) + if self.http_compress and body: body = self._gzip_compress(body) req_headers["content-encoding"] = "gzip"