diff --git a/disnake/client.py b/disnake/client.py index d5c1d8cf33..fa25b89797 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -398,7 +398,7 @@ def __init__( connector, proxy=proxy, proxy_auth=proxy_auth, - unsync_clock=assume_unsync_clock, + # unsync_clock=assume_unsync_clock, loop=self.loop, ) @@ -432,6 +432,7 @@ def __init__( self._first_connect: asyncio.Event = asyncio.Event() self._connection._get_websocket = self._get_websocket self._connection._get_client = lambda: self + self._token: str | None = None if VoiceClient.warn_nacl: VoiceClient.warn_nacl = False @@ -862,7 +863,9 @@ async def login(self, token: str) -> None: if not isinstance(token, str): raise TypeError(f"token must be of type str, got {type(token).__name__} instead") - data = await self.http.static_login(token.strip()) + self._token = token.strip() + + data = await self.http.static_login(f"Bot {self._token}") self._connection.user = ClientUser(state=self._connection, data=data) async def connect( diff --git a/disnake/errors.py b/disnake/errors.py index 21a1834dff..3bcf04def4 100644 --- a/disnake/errors.py +++ b/disnake/errors.py @@ -19,6 +19,7 @@ "NoMoreItems", "GatewayNotFound", "HTTPException", + "Unauthorized", "Forbidden", "NotFound", "DiscordServerError", @@ -134,6 +135,14 @@ def __init__( super().__init__(fmt.format(self.response, self.code, self.text)) +class Unauthorized(HTTPException): + """Exception that's raised for when status code 401 occurs. + Subclass of :exc:`HTTPException` + """ + + pass + + class Forbidden(HTTPException): """Exception that's raised for when status code 403 occurs. diff --git a/disnake/gateway.py b/disnake/gateway.py index 2081493509..abfc29a252 100644 --- a/disnake/gateway.py +++ b/disnake/gateway.py @@ -423,7 +423,7 @@ async def from_client( ws = cls(socket, loop=client.loop) # dynamically add attributes needed - ws.token = client.http.token # type: ignore + ws.token = client._token # type: ignore ws._connection = client._connection ws._discord_parsers = client._connection.parsers ws._dispatch = client.dispatch diff --git a/disnake/http.py b/disnake/http.py index f289620a8d..4acfac3987 100644 --- a/disnake/http.py +++ b/disnake/http.py @@ -6,8 +6,7 @@ import logging import re import sys -import weakref -from errno import ECONNRESET +from datetime import datetime, timezone from typing import ( TYPE_CHECKING, Any, @@ -23,6 +22,7 @@ Type, TypeVar, Union, + cast, ) from urllib.parse import quote as _uriquote @@ -31,12 +31,14 @@ from . import __version__, utils from .errors import ( + DiscordException, DiscordServerError, Forbidden, GatewayNotFound, HTTPException, LoginFailure, NotFound, + Unauthorized, ) from .gateway import DiscordClientWebSocketResponse from .utils import MISSING @@ -46,8 +48,6 @@ if TYPE_CHECKING: from types import TracebackType - from typing_extensions import Self - from .enums import InteractionResponseType from .file import File from .message import Attachment @@ -163,8 +163,17 @@ def to_multipart_with_attachments( return to_multipart(payload, files) +def _get_logging_auth(auth: str | None) -> str: + if auth is None: + return "None" + elif len(auth) < 12: # This shouldn't ever occur, but whatever. Better safe than sorry. 
+        return "[redacted]"
+    else:
+        return f"{auth[:12]}[redacted]"
+
+
 class Route:
-    BASE: ClassVar[str] = "https://discord.com/api/v10"
+    BASE: ClassVar[str] = f"https://discord.com/api/v{_API_VERSION}"
 
     def __init__(self, method: str, path: str, **parameters: Any) -> None:
         self.path: str = path
@@ -188,25 +197,303 @@ def bucket(self) -> str:
         return f"{self.channel_id}:{self.guild_id}:{self.path}"
 
 
-class MaybeUnlock:
-    def __init__(self, lock: asyncio.Lock) -> None:
-        self.lock: asyncio.Lock = lock
-        self._unlock: bool = True
+class RateLimitMigrating(DiscordException):
+    ...
+
+
+class IncorrectBucket(DiscordException):
+    ...
+
+
+class RateLimit:
+    """Gates a batch of requests so that at most X of them occur every Y seconds. Used via an async context manager.
+
+    NOT THREAD SAFE.
+
+    Parameters
+    ----------
+    time_offset: :class:`float`
+        Number in seconds to increase all timers by. Used for lag compensation.
+    """
+
+    def __init__(self, time_offset: float = 0.3) -> None:
+        self.limit: int = 1
+        """Maximum amount of requests before requests have to wait for the rate limit to reset."""
+        self.remaining: int = 1
+        """Remaining amount of requests before requests have to wait for the rate limit to reset."""
+        self.reset: datetime | None = None
+        """Datetime that the bucket will roughly be reset at."""
+        self.reset_after: float = 1.0
+        """Amount of seconds, roughly, until the rate limit will be reset."""
+        self.bucket: str | None = None
+        """Name of the bucket, if it has one."""
+
+        self._time_offset: float = time_offset
+        """Number in seconds to increase all timers by. Used for lag compensation."""
+        self._first_update: bool = True
+        """If the next update to be run will be the first."""
+        self._reset_remaining_task: asyncio.Task | None = None
+        """Holds the task object for resetting the remaining count."""
+        self._on_reset_event: asyncio.Event = asyncio.Event()
+        """Used to indicate when the rate limit is ready to be acquired."""
+        self._on_reset_event.set()
+        self._deny: bool = False
+        """Set to error all acquiring requests with a 404 value error."""
+        self._migrating: str | None = None
+        """When this RateLimit is being deprecated and acquiring requests need to migrate to a different RateLimit,
+        this variable should be set to the name of the bucket they should migrate to.
+        """
+
+    @property
+    def resetting(self) -> bool:
+        return self._reset_remaining_task is not None and not self._reset_remaining_task.done()
+
+    async def update(self, response: aiohttp.ClientResponse) -> None:
+        """Updates the rate limit with information from the response."""
+        if response.headers.get("X-RateLimit-Global") == "true":
+            # The response is intended for the global rate limit, not a regular rate limit.
+            return
+
+        # Updates the bucket name. The bucket name not existing is fine, as ``None`` is desired for that.
+        # This is done immediately, so we can error out if we get an update not meant for this bucket.
+        x_bucket = response.headers.get("X-RateLimit-Bucket")
+
+        if self.bucket == x_bucket:
+            pass  # No need to set it again.
+        elif self.bucket is None:
+            self.bucket = x_bucket
+        else:
+            raise IncorrectBucket(
+                f"Update given for bucket {x_bucket}, but this RateLimit is for bucket {self.bucket}!"
+            )
+
+        if response.status == 404:
+            self._deny = True
 
-    def __enter__(self) -> Self:
-        return self
+        # Updates the limit if it exists.
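+        # (For reference, the rate limit headers parsed here and below typically look like
+        # the following, with illustrative values only:
+        #   X-RateLimit-Bucket: abcd1234
+        #   X-RateLimit-Limit: 5
+        #   X-RateLimit-Remaining: 3
+        #   X-RateLimit-Reset: 1470173023.123
+        #   X-RateLimit-Reset-After: 4.288
+        # See the Discord API documentation on rate limits for the full list.)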
+ x_limit = response.headers.get("X-RateLimit-Limit") + self.limit = 1 if x_limit is None else int(x_limit) - def defer(self) -> None: - self._unlock = False + # Updates the remaining left if it exists, being pessimistic. + x_remaining = response.headers.get("X-RateLimit-Remaining") - def __exit__( + if x_remaining is None: + self.remaining = 1 + elif self._first_update: + self.remaining = int(x_remaining) + else: + # If requests come back out of order, it's possible that we could get a wrong amount remaining. + # It's best to be pessimistic and assume it cannot go back up unless the reset task occurs. + self.remaining = ( + int(x_remaining) if int(x_remaining) < self.remaining else self.remaining + ) + + # Updates the datetime of the reset. + x_reset = response.headers.get("X-RateLimit-Reset") + if x_reset is not None: + # self.reset = datetime.utcfromtimestamp(float(x_reset)) + self.reset = datetime.fromtimestamp(float(x_reset), tz=timezone.utc) + + # Updates the reset-after count, being pessimistic. + x_reset_after = response.headers.get("X-RateLimit-Reset-After") + if x_reset_after is not None: + x_reset_after = float(x_reset_after) + self._time_offset + if self.reset_after is None: + self.reset_after = x_reset_after + else: + if self.reset_after < x_reset_after: + _log.debug( + "Bucket %s: Reset after time increased, adapting reset time.", self.bucket + ) + self.reset_after = x_reset_after + self.start_reset_task() + + if not self.resetting: + self.start_reset_task() + + # If for whatever reason we have requests remaining but the reset event isn't set, set it. + if 0 < self.remaining and not self._on_reset_event.is_set(): + _log.debug( + "Bucket %s: Updated with remaining %s, setting reset event.", + self.bucket, + self.remaining, + ) + self._on_reset_event.set() + + # If this is our first update, indicate that all future updates aren't the first. + if self._first_update: + self._first_update = False + + _log.debug( + "Bucket %s: Updated with limit %s, remaining %s, reset %s, and reset_after %s seconds.", + self.bucket, + self.limit, + self.remaining, + self.reset, + self.reset_after, + ) + + def start_reset_task(self) -> None: + """Starts the reset task, non-blocking.""" + if self.resetting: + _log.debug("Bucket %s: Reset task already running, cancelling.", self.bucket) + self._reset_remaining_task.cancel() # pyright: ignore [reportOptionalMemberAccess] + + loop = asyncio.get_running_loop() + _log.debug("Bucket %s: Resetting after %s seconds.", self.bucket, self.reset_after) + self._reset_remaining_task = loop.create_task(self.reset_remaining(self.reset_after)) + + async def reset_remaining(self, time: float) -> None: + """|coro| + Sleeps for the specified amount of time, then resets the remaining request count to the limit. + + Parameters + ---------- + time: :class:`float` + Amount of time to sleep until the request count is reset to the limit. ``time_offset`` is not added to + this number. 
+ """ + await asyncio.sleep(time) + self.remaining = self.limit + self._on_reset_event.set() + _log.debug("Bucket %s: Reset, allowing requests to continue.", self.bucket) + + @property + def migrating(self) -> str | None: + """If not ``None``, this indicates what bucket acquiring requests should migrate to.""" + return self._migrating + + def migrate_to(self, bucket: str) -> None: + """Signals to acquiring requests, both present and future, that they need to migrate to a new bucket.""" + self._migrating = bucket + self.remaining = self.limit + self._on_reset_event.set() + _log.debug( + "Bucket %s: Deprecating, acquiring requests will migrate to a new bucket.", bucket + ) + + async def __aenter__(self) -> None: + await self.acquire() + return None + + async def __aexit__( self, exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - traceback: Optional[TracebackType], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: - if self._unlock: - self.lock.release() + self.release() + + def locked(self) -> bool: + return self.remaining <= 0 + + async def acquire(self) -> bool: + # If no more requests can be made but the event is set, clear it. + if self.remaining <= 0 and self._on_reset_event.is_set(): + _log.debug( + "Bucket %s: Hit the remaining request limit of %s, locking until reset.", + self.bucket, + self.limit, + ) + self._on_reset_event.clear() + if not self.resetting: + self.start_reset_task() + + # Waits in a loop for the event to be set, clearing the event as needed and looping. + while not self._on_reset_event.is_set(): + _log.debug("Bucket %s: Not set yet, waiting for it to be set.", self.bucket) + await self._on_reset_event.wait() + + if self.remaining <= 0 and self._on_reset_event.is_set(): + _log.debug( + "Bucket %s: Hit the remaining limit of %s, locking until reset.", + self.bucket, + self.limit, + ) + self._on_reset_event.clear() + if not self.resetting: + self.start_reset_task() + + if self.migrating: + raise RateLimitMigrating( + f"This RateLimit is deprecated, you need to migrate to bucket {self.migrating}" + ) + elif self._deny: + raise ValueError("This request path 404'd and is now denied.") + + _log.debug("Bucket %s: Continuing with request.", self.bucket) + self.remaining -= 1 + return True + + def release(self) -> None: + # Basically a placeholder, could probably be removed ;) + pass + + +class GlobalRateLimit(RateLimit): + """Represents the global rate limit, and thus has to have slightly modified behavior. + + Still not thread safe. + """ + + async def acquire(self) -> bool: + ret = await super().acquire() + # As updates are little weird, it's best to start the reset task as soon as the first request has acquired. + if not self.resetting: + self.start_reset_task() + + return ret + + async def update(self, response: aiohttp.ClientResponse) -> None: + if response.headers.get("X-RateLimit-Global") != "true": + # The response is intended for the regular rate limit, not a global rate limit. + return + + if response.status == 429: + # Oh dear, we hit the rate limit. 
+ _log.warning("Global rate limit 429 encountered, setting remaining to 0.") + self.remaining = 0 + if response.headers.get("X-RateLimit-Scope") == "global": + data = await response.json() + _log.warning(data) + if (retry_after := data.get("retry_after")) or ( + retry_after := response.headers.get("Retry-After") + ): + _log.debug( + "Got global retry_after, resetting global after %s seconds", retry_after + ) + self.reset_after = float(retry_after) + self._time_offset + if self.resetting: + self._reset_remaining_task.cancel() # pyright: ignore [reportOptionalMemberAccess] + + self.start_reset_task() + + self._on_reset_event.clear() + if not self.resetting: + self.start_reset_task() + + _log.warning("Cleared global ratelimit, waiting for reset.") + + +# class MaybeUnlock: +# def __init__(self, lock: asyncio.Lock) -> None: +# self.lock: asyncio.Lock = lock +# self._unlock: bool = True +# +# def __enter__(self) -> Self: +# return self +# +# def defer(self) -> None: +# self._unlock = False +# +# def __exit__( +# self, +# exc_type: Optional[Type[BaseException]], +# exc: Optional[BaseException], +# traceback: Optional[TracebackType], +# ) -> None: +# if self._unlock: +# self.lock.release() # For some reason, the Discord voice websocket expects this header to be @@ -214,43 +501,185 @@ def __exit__( aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore +# class HTTPClient: +# """Represents an HTTP client sending HTTP requests to the Discord API.""" +# +# def __init__( +# self, +# connector: Optional[aiohttp.BaseConnector] = None, +# *, +# loop: asyncio.AbstractEventLoop, +# proxy: Optional[str] = None, +# proxy_auth: Optional[aiohttp.BasicAuth] = None, +# unsync_clock: bool = True, +# ) -> None: +# self.loop: asyncio.AbstractEventLoop = loop +# self.connector = connector +# self.__session: aiohttp.ClientSession = MISSING # filled in static_login +# self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() +# self._global_over: asyncio.Event = asyncio.Event() +# self._global_over.set() +# self.token: Optional[str] = None +# self.bot_token: bool = False +# self.proxy: Optional[str] = proxy +# self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth +# self.use_clock: bool = not unsync_clock +# +# user_agent = "DiscordBot (https://github.com/DisnakeDev/disnake {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" +# self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__) + + class HTTPClient: - """Represents an HTTP client sending HTTP requests to the Discord API.""" + """Represents an HTTP client for sending HTTP requests to and handling the rate limits of the Discord API. + Also, not thread safe. + + Parameters + ---------- + connector + default_max_per_second: :class:`int` + Maximum amount of requests per second per authorization. + Discord by default only allows 50 requests per second, but if your bot has had its maximum increased, then + increase this parameter. + time_offset: :class:`float` + Amount of seconds added to all ratelimit timers for lag compensation. + Due to latency and Discord servers not perfectly time synced, having no offset can cause 429's to occur even + with us following the reported X-RateLimit-Reset-After. + Increasing will protect from erroneous 429s but will slow bucket resets, lowering max theoretical speed. + Decreasing will hasten bucket resets and increase max theoretical speed but may cause 429s. 
+    default_auth: Optional[:class:`str`]
+        Default string to use in the Authorization header if it's not manually provided.
+    proxy: Optional[:class:`str`]
+        Proxy URL to send requests through.
+    proxy_auth: Optional[:class:`aiohttp.BasicAuth`]
+        Authentication to use with the proxy.
+    loop: Optional[:class:`asyncio.AbstractEventLoop`]
+        Event loop to run on. Defaults to the current event loop.
+    """
 
     def __init__(
         self,
         connector: Optional[aiohttp.BaseConnector] = None,
         *,
-        loop: asyncio.AbstractEventLoop,
+        default_max_per_second: int = 50,
+        time_offset: float = 0.0,
+        default_auth: Optional[str] = None,
         proxy: Optional[str] = None,
         proxy_auth: Optional[aiohttp.BasicAuth] = None,
-        unsync_clock: bool = True,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+        # dispatch: Callable,
    ) -> None:
-        self.loop: asyncio.AbstractEventLoop = loop
-        self.connector = connector
-        self.__session: aiohttp.ClientSession = MISSING  # filled in static_login
-        self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
-        self._global_over: asyncio.Event = asyncio.Event()
-        self._global_over.set()
-        self.token: Optional[str] = None
-        self.bot_token: bool = False
-        self.proxy: Optional[str] = proxy
-        self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth
-        self.use_clock: bool = not unsync_clock
+        # TODO: Think about adding ratelimit_multiplier? Would reduce the internal RateLimit.limit by that
+        # float (such as 0.7) and could allow people to run multiple Disnake bots/processes on the same token while
+        # avoiding ratelimit issues. Could also help with replit-style scenarios.
+        self._session: aiohttp.ClientSession = MISSING  # Set when performing a request.
+        self._connector = connector
+        self._default_max_per_second = default_max_per_second
+        """Maximum amount of requests per second per authorization."""
+        self._time_offset = time_offset
+        """Amount of seconds added to all ratelimit timers for lag compensation."""
+        self._default_auth = None
+        # Set via the setter for consistency with possible future changes to _set_default_auth.
+        self._set_default_auth(default_auth)
+        self._proxy = proxy
+        self._proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth
+        # A passed loop is truthy, so this falls back to the current event loop only when no loop is given.
+        self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
+        # self._dispatch = dispatch
 
         user_agent = "DiscordBot (https://github.com/DisnakeDev/disnake {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
         self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
 
+        self._buckets: dict[str, RateLimit] = {}
+        """{"Discord bucket name": RateLimit}"""
+        self._global_rate_limits: dict[str | None, RateLimit] = {}
+        """{"Auth string": RateLimit}, None for the auth-less rate limit."""
+        self._url_rate_limits: dict[tuple[str, str, str | None], RateLimit] = {}
+        """{("METHOD", "Route.bucket", "auth string"): RateLimit} The auth string may be None to indicate auth-less."""
+
+    @property
+    def __session(self) -> aiohttp.ClientSession | None:
+        # TODO: Track down uses of this and remove them.
+        return self._session
+
+    def _make_global_rate_limit(self, auth: str | None, max_per_second: int) -> GlobalRateLimit:
+        _log.debug(
+            "Creating global ratelimit for auth %s with max per second %s.",
+            _get_logging_auth(auth),
+            max_per_second,
+        )
+        rate_limit = GlobalRateLimit(time_offset=self._time_offset)
+        rate_limit.limit = max_per_second
+        rate_limit.remaining = max_per_second
+        rate_limit.reset_after = 1 + self._time_offset
+        rate_limit.bucket = f"Global {_get_logging_auth(auth) if auth else 'Unauthorized'}"
+
+        self._global_rate_limits[auth] = rate_limit
+        return rate_limit
+
+    def _make_url_rate_limit(self, method: str, route: Route, auth: str | None) -> RateLimit:
+        _log.debug(
+            "Making URL rate limit for %s %s %s", method, route.bucket, _get_logging_auth(auth)
+        )
+        ret = RateLimit(time_offset=self._time_offset)
+        self._url_rate_limits[(method, route.bucket, auth)] = ret
+        return ret
+
+    def _set_url_rate_limit(
+        self, method: str, route: Route, auth: str | None, rate_limit: RateLimit
+    ) -> None:
+        self._url_rate_limits[(method, route.bucket, auth)] = rate_limit
+
+    def _get_url_rate_limit(self, method: str, route: Route, auth: str | None) -> RateLimit | None:
+        return self._url_rate_limits.get((method, route.bucket, auth), None)
+
+    def _set_default_auth(self, auth: str | None) -> None:
+        self._default_auth = auth
+
+    def _make_headers(
+        self,
+        original_headers: dict[str, str],
+        *,
+        auth: str | None | MISSING = MISSING,
+    ) -> dict[str, str]:
+        """Creates a new dictionary of headers, without overwriting values from the given headers.
+
+        Parameters
+        ----------
+        original_headers: :class:`dict`[:class:`str`, :class:`str`]
+            Headers to make a shallow copy of.
+        auth: :class:`str` | `None` | `MISSING`
+            Authorization string to use. Will not auto-format given tokens. For example, a bot token must be provided
+            as "Bot <token>". If set to `None`, no authorization header will be added. If left unset, the default
+            auth string will be used. (If the default is not set, no auth will be used.)
+
+        Returns
+        -------
+        :class:`dict`[:class:`str`, :class:`str`]
+            Modified headers to use.
+        """
+        ret = original_headers.copy()
+
+        if "Authorization" not in ret:
+            if auth is None:
+                pass  # We do nothing; this branch just makes the logic easier to read.
+            elif auth is MISSING:
+                if self._default_auth is not None:
+                    ret["Authorization"] = self._default_auth
+            else:  # auth isn't None or MISSING, so it must be something.
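+                    # The value is used verbatim, e.g. "Bot <token>" for bots or
+                    # "Bearer <access token>" for OAuth2.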
+ ret["Authorization"] = auth + + if "User-Agent" not in ret and self.user_agent: + ret["User-Agent"] = self.user_agent + + return ret + def recreate(self) -> None: - if self.__session.closed: - self.__session = aiohttp.ClientSession( - connector=self.connector, ws_response_class=DiscordClientWebSocketResponse + if self._session.closed: + self._session = aiohttp.ClientSession( + connector=self._connector, ws_response_class=DiscordClientWebSocketResponse ) async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse: kwargs = { - "proxy_auth": self.proxy_auth, - "proxy": self.proxy, + "proxy_auth": self._proxy_auth, + "proxy": self._proxy, "max_msg_size": 0, "timeout": 30.0, "autoclose": False, @@ -260,7 +689,176 @@ async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebS "compress": compress, } - return await self.__session.ws_connect(url, **kwargs) + return await self._session.ws_connect(url, **kwargs) + + # async def request( + # self, + # route: Route, + # *, + # files: Optional[Sequence[File]] = None, + # form: Optional[Iterable[Dict[str, Any]]] = None, + # **kwargs: Any, + # ) -> Any: + # bucket = route.bucket + # method = route.method + # url = route.url + # + # lock = self._locks.get(bucket) + # if lock is None: + # lock = asyncio.Lock() + # if bucket is not None: + # self._locks[bucket] = lock + # + # # header creation + # headers: Dict[str, str] = { + # "User-Agent": self.user_agent, + # } + # + # if self.token is not None: + # headers["Authorization"] = "Bot " + self.token + # # some checking if it's a JSON request + # if "json" in kwargs: + # headers["Content-Type"] = "application/json" + # kwargs["data"] = utils._to_json(kwargs.pop("json")) + # + # try: + # reason = kwargs.pop("reason") + # except KeyError: + # pass + # else: + # if reason: + # headers["X-Audit-Log-Reason"] = _uriquote(reason, safe="/ ") + # + # kwargs["headers"] = headers + # + # # Proxy support + # if self.proxy is not None: + # kwargs["proxy"] = self.proxy + # if self.proxy_auth is not None: + # kwargs["proxy_auth"] = self.proxy_auth + # + # if not self._global_over.is_set(): + # # wait until the global lock is complete + # await self._global_over.wait() + # + # response: Optional[aiohttp.ClientResponse] = None + # data: Optional[Union[Dict[str, Any], str]] = None + # await lock.acquire() + # with MaybeUnlock(lock) as maybe_lock: + # for tries in range(5): + # if files: + # for f in files: + # f.reset(seek=tries) + # + # if form: + # # NOTE: for `quote_fields`, see https://github.com/aio-libs/aiohttp/issues/4012 + # form_data = aiohttp.FormData(quote_fields=False) + # for p in form: + # # manually escape chars, just in case + # name = re.sub( + # r"[^\x21\x23-\x5b\x5d-\x7e]", lambda m: f"\\{m.group(0)}", p["name"] + # ) + # form_data.add_field( + # name=name, **{k: v for k, v in p.items() if k != "name"} + # ) + # kwargs["data"] = form_data + # + # try: + # async with self.__session.request(method, url, **kwargs) as response: + # _log.debug( + # "%s %s with %s has returned %s", + # method, + # url, + # kwargs.get("data"), + # response.status, + # ) + # + # # even errors have text involved in them so this is safe to call + # data = await json_or_text(response) + # + # # check if we have rate limit header information + # remaining = response.headers.get("X-Ratelimit-Remaining") + # if remaining == "0" and response.status != 429: + # # we've depleted our current bucket + # delta = utils._parse_ratelimit_header( + # response, use_clock=self.use_clock + # ) 
+ # _log.debug( + # "A rate limit bucket has been exhausted (bucket: %s, retry: %s).", + # bucket, + # delta, + # ) + # maybe_lock.defer() + # self.loop.call_later(delta, lock.release) + # + # # the request was successful so just return the text/json + # if 300 > response.status >= 200: + # _log.debug("%s %s has received %s", method, url, data) + # return data + # + # # we are being rate limited + # if response.status == 429: + # if not response.headers.get("Via") or isinstance(data, str): + # # Banned by Cloudflare more than likely. + # raise HTTPException(response, data) + # + # fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' + # + # # sleep a bit + # retry_after: float = data["retry_after"] + # _log.warning(fmt, retry_after, bucket) + # + # # check if it's a global rate limit + # is_global = data.get("global", False) + # if is_global: + # _log.warning( + # "Global rate limit has been hit. Retrying in %.2f seconds.", + # retry_after, + # ) + # self._global_over.clear() + # + # await asyncio.sleep(retry_after) + # _log.debug("Done sleeping for the rate limit. Retrying...") + # + # # release the global lock now that the + # # global rate limit has passed + # if is_global: + # self._global_over.set() + # _log.debug("Global rate limit is now over.") + # + # continue + # + # # we've received a 500, 502, or 504, unconditional retry + # if response.status in {500, 502, 504}: + # await asyncio.sleep(1 + tries * 2) + # continue + # + # # the usual error cases + # if response.status == 403: + # raise Forbidden(response, data) + # elif response.status == 404: + # raise NotFound(response, data) + # elif response.status >= 500: + # raise DiscordServerError(response, data) + # else: + # raise HTTPException(response, data) + # + # # This is handling exceptions from the request + # except OSError as e: + # # Connection reset by peer + # if tries < 4 and e.errno == ECONNRESET: + # await asyncio.sleep(1 + tries * 2) + # continue + # raise + # + # if response is not None: + # # We've run out of retries, raise. + # if response.status >= 500: + # raise DiscordServerError(response, data) + # + # raise HTTPException(response, data) + # + # raise RuntimeError("Unreachable code in HTTP handling") async def request( self, @@ -268,29 +866,43 @@ async def request( *, files: Optional[Sequence[File]] = None, form: Optional[Iterable[Dict[str, Any]]] = None, + auth: Optional[str] = MISSING, + retry_request: bool = True, **kwargs: Any, ) -> Any: - bucket = route.bucket - method = route.method - url = route.url - - lock = self._locks.get(bucket) - if lock is None: - lock = asyncio.Lock() - if bucket is not None: - self._locks[bucket] = lock - - # header creation - headers: Dict[str, str] = { - "User-Agent": self.user_agent, - } + """|coro| - if self.token is not None: - headers["Authorization"] = "Bot " + self.token - # some checking if it's a JSON request - if "json" in kwargs: - headers["Content-Type"] = "application/json" - kwargs["data"] = utils._to_json(kwargs.pop("json")) + Makes an API request to Discord, handling authorization (if needed), rate limits, and limited error handling. + + Parameters + ---------- + route: :class:`Route` + The Discord Route to make the API request for. + files: Optional[Sequence[:class:`File`]] + pass + form: Optional[Iterable[:class:`dict`[:class:`str`, `Any`]]] + pass + auth: :class:`str` | `None` | `MISSING` + Authorization string to use. Will not auto-format given tokens. For example, a bot token must be provided + as "Bot ". 
+            auth will be used. (If the default is not set, no auth will be used.)
+        retry_request: :class:`bool`
+            If the request should be retried in specific cases. This mainly concerns 500 errors (Discord server
+            issues) and 429s (rate limit issues).
+        kwargs
+            This is purposefully undocumented. Behavior of extra kwargs may change in a breaking way, and extra kwargs
+            may not be allowed in the future.
+
+        Returns
+        -------
+        `Any`
+            The JSON-deserialized response body, or the raw response text if the body was not JSON.
+        """
+        if not self._session:
+            self._session = aiohttp.ClientSession(
+                connector=self._connector, ws_response_class=DiscordClientWebSocketResponse
+            )
+
+        headers = self._make_headers(kwargs.pop("headers", {}), auth=auth)
 
         try:
             reason = kwargs.pop("reason")
@@ -300,139 +912,265 @@ async def request(
             if reason:
                 headers["X-Audit-Log-Reason"] = _uriquote(reason, safe="/ ")
 
-        kwargs["headers"] = headers
-
-        # Proxy support
-        if self.proxy is not None:
-            kwargs["proxy"] = self.proxy
-        if self.proxy_auth is not None:
-            kwargs["proxy_auth"] = self.proxy_auth
-
-        if not self._global_over.is_set():
-            # wait until the global lock is complete
-            await self._global_over.wait()
-
-        response: Optional[aiohttp.ClientResponse] = None
-        data: Optional[Union[Dict[str, Any], str]] = None
-        await lock.acquire()
-        with MaybeUnlock(lock) as maybe_lock:
-            for tries in range(5):
-                if files:
-                    for f in files:
-                        f.reset(seek=tries)
-
-                if form:
-                    # NOTE: for `quote_fields`, see https://github.com/aio-libs/aiohttp/issues/4012
-                    form_data = aiohttp.FormData(quote_fields=False)
-                    for p in form:
-                        # manually escape chars, just in case
-                        name = re.sub(
-                            r"[^\x21\x23-\x5b\x5d-\x7e]", lambda m: f"\\{m.group(0)}", p["name"]
-                        )
-                        form_data.add_field(
-                            name=name, **{k: v for k, v in p.items() if k != "name"}
-                        )
-                    kwargs["data"] = form_data
-
-                try:
-                    async with self.__session.request(method, url, **kwargs) as response:
-                        _log.debug(
-                            "%s %s with %s has returned %s",
-                            method,
-                            url,
-                            kwargs.get("data"),
-                            response.status,
-                        )
-
-                        # even errors have text involved in them so this is safe to call
-                        data = await json_or_text(response)
-
-                        # check if we have rate limit header information
-                        remaining = response.headers.get("X-Ratelimit-Remaining")
-                        if remaining == "0" and response.status != 429:
-                            # we've depleted our current bucket
-                            delta = utils._parse_ratelimit_header(
-                                response, use_clock=self.use_clock
-                            )
+        auth = headers.get("Authorization")
+
+        # If a global rate limit for this authorization doesn't exist yet, make it.
+        if (global_rate_limit := self._global_rate_limits.get(auth)) is None:
+            global_rate_limit = self._make_global_rate_limit(auth, self._default_max_per_second)
+
+        global_rate_limit = cast(GlobalRateLimit, global_rate_limit)
+
+        # If a rate limit for this url path doesn't exist yet, make it.
+        if (url_rate_limit := self._get_url_rate_limit(route.method, route, auth)) is None:
+            url_rate_limit = self._make_url_rate_limit(route.method, route, auth)
+
+        rate_limit_path = (
+            route.method,
+            route.bucket,
+            _get_logging_auth(auth),
+        )  # Only use this for logging.
+        ret: Any | None = None
+        response: aiohttp.ClientResponse | None = None
+
+        # The loop is to allow migration to a different RateLimit if needed.
+        # If we hit this loop max_try_count times, something is wrong. Either we're migrating buckets way
+        # too much, 429s keep getting hit, or something is internally wrong.
+        max_try_count = 5
+        for retry_count in range(max_try_count):  # To prevent infinite loops.
+ should_retry = False + try: + async with global_rate_limit: + async with url_rate_limit: + # This check is for asyncio.gather()'d requests where the rate limit can change. + if (temp := self._get_url_rate_limit(route.method, route, auth)) not in ( + url_rate_limit, + None, + ): + # temp = cast(RateLimit, temp) _log.debug( - "A rate limit bucket has been exhausted (bucket: %s, retry: %s).", - bucket, - delta, + "Route %s had the rate limit changed, resetting and retrying.", + rate_limit_path, ) - maybe_lock.defer() - self.loop.call_later(delta, lock.release) - - # the request was successful so just return the text/json - if 300 > response.status >= 200: - _log.debug("%s %s has received %s", method, url, data) - return data - - # we are being rate limited - if response.status == 429: - if not response.headers.get("Via") or isinstance(data, str): - # Banned by Cloudflare more than likely. - raise HTTPException(response, data) - - fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' - - # sleep a bit - retry_after: float = data["retry_after"] - _log.warning(fmt, retry_after, bucket) - - # check if it's a global rate limit - is_global = data.get("global", False) - if is_global: - _log.warning( - "Global rate limit has been hit. Retrying in %.2f seconds.", - retry_after, + url_rate_limit = temp + continue + + if files: + for f in files: + f.reset(seek=retry_count) + + if form: + # form_data = aiohttp.FormData(quote_fields=False) + # for params in form: + # form_data.add_field(**params) + # kwargs["data"] = form_data + # NOTE: for `quote_fields`, see https://github.com/aio-libs/aiohttp/issues/4012 + form_data = aiohttp.FormData(quote_fields=False) + for p in form: + # manually escape chars, just in case + name = re.sub( + r"[^\x21\x23-\x5b\x5d-\x7e]", + lambda m: f"\\{m.group(0)}", + p["name"], + ) + form_data.add_field( + name=name, **{k: v for k, v in p.items() if k != "name"} ) - self._global_over.clear() + kwargs["data"] = form_data + + async with self._session.request( + method=route.method, + url=route.url, + headers=headers, + proxy=self._proxy, + proxy_auth=self._proxy_auth, + **kwargs, + ) as response: + _log.debug( + "%s %s with %s has returned %s", + route.method, + route.url, + kwargs.get("data"), + response.status, + ) - await asyncio.sleep(retry_after) - _log.debug("Done sleeping for the rate limit. Retrying...") + await global_rate_limit.update(response) + try: + await url_rate_limit.update(response) + except IncorrectBucket as e: + # This condition can be met when doing asyncio.gather()'d requests. + if ( + temp := self._buckets.get( + # The empty string default makes pyright happy. (hopefully) + response.headers.get("X-RateLimit-Bucket", "") + ) + ) is not None: + _log.debug( + "Route %s was given a different bucket, found it.", + rate_limit_path, + ) + url_rate_limit = temp + self._set_url_rate_limit( + route.method, route, auth, url_rate_limit + ) + await url_rate_limit.update(response) + else: + _log.debug( + "Route %s was given a different bucket, making a new one: %s", + rate_limit_path, + e, + ) + url_rate_limit = self._make_url_rate_limit( + route.method, route, auth + ) + await url_rate_limit.update(response) + + if url_rate_limit.bucket is not None and self._buckets.get( + url_rate_limit.bucket + ) not in (url_rate_limit, None): + # If the current RateLimit bucket name exists, but the stored RateLimit is not the + # current RateLimit, finish up and signal that the current bucket should be migrated + # to the stored one. 
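+                            # (Two routes can turn out to share one Discord bucket hash, and
+                            # asyncio.gather()'d requests can race to create separate RateLimit
+                            # objects; both cases funnel through this migration.)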
+                            _log.debug(
+                                "Route %s with bucket %s already exists, migrating other possible requests to "
+                                "that bucket.",
+                                rate_limit_path,
+                                url_rate_limit.bucket,
+                            )
+                            correct_rate_limit = self._buckets[url_rate_limit.bucket]
+                            self._set_url_rate_limit(
+                                route.method, route, auth, correct_rate_limit
+                            )
+                            if correct_rate_limit.bucket:
+                                # Signals to all requests waiting to acquire to migrate.
+                                url_rate_limit.migrate_to(correct_rate_limit.bucket)
+                            else:
+                                raise ValueError(
+                                    f"Migrating to bucket {correct_rate_limit.bucket}, but "
+                                    f"correct_rate_limit.bucket is falsy. This is likely an internal Disnake "
+                                    f"issue and should be reported."
+                                )
+                            # Update the correct RateLimit object with our findings.
+                            await correct_rate_limit.update(response)
+                        elif url_rate_limit.bucket is not None:
+                            self._buckets[url_rate_limit.bucket] = url_rate_limit
+
+                        # even errors have text involved in them so this is safe to call
+                        ret = await json_or_text(response)
+
+                        if response.status >= 400:
+                            # >= 500 was considered, but stuff like 501 and 505+ are not good to retry on.
+                            if response.status in {500, 502, 504}:
+                                if retry_request:
+                                    _log.info(
+                                        "Path %s encountered a Discord server issue, retrying.",
+                                        rate_limit_path,
+                                    )
+                                    await asyncio.sleep(1 + retry_count * 2)
+                                    should_retry = True
+                                else:
+                                    _log.info(
+                                        "Path %s encountered a Discord server issue.",
+                                        rate_limit_path,
+                                    )
+                                    raise DiscordServerError(response, ret)
+                            elif response.status == 401:
+                                _log.warning(
+                                    "Path %s resulted in error 401, rejected authorization?",
+                                    rate_limit_path,
+                                )
+                                raise Unauthorized(response, ret)
+                            elif response.status == 403:
+                                _log.warning(
+                                    "Path %s resulted in error 403, check your permissions?",
+                                    rate_limit_path,
+                                )
+                                raise Forbidden(response, ret)
+                            elif response.status == 404:
+                                _log.warning(
+                                    "Path %s resulted in error 404, check your path?",
+                                    rate_limit_path,
+                                )
+                                raise NotFound(response, ret)
+                            elif response.status == 429:
+                                if not response.headers.get("Via") or isinstance(ret, str):
+                                    _log.error(
+                                        "Path %s resulted in what appears to be a CloudFlare ban, either a "
+                                        "large number of errors happened recently or Disnake has a bug.",
+                                        rate_limit_path,
+                                    )
+                                    # Banned by Cloudflare more than likely.
+                                    raise HTTPException(response, ret)
+
+                                if retry_request:
+                                    _log.warning(
+                                        "We are being rate limited on path %s. Retrying in %.2f seconds.",
+                                        rate_limit_path,
+                                        url_rate_limit.reset_after,
+                                    )
+                                    should_retry = True
+                                else:
+                                    _log.warning(
+                                        "We are being rate limited on path %s.", rate_limit_path
+                                    )
+                                    raise HTTPException(response, ret)  # TODO: Make an actual HTTPRateLimit error?
+
+                            elif response.status >= 500:
+                                raise DiscordServerError(response, ret)
+                            else:
+                                raise HTTPException(response, ret)
+
+            # This is handling exceptions from the request
+            except OSError as e:
+                # Connection reset by peer
+                if retry_count < max_try_count - 1 and e.errno in (54, 104, 10054):
+                    # ECONNRESET on macOS/BSD (54), Linux (104), and Windows (10054).
+                    await asyncio.sleep(1 + retry_count * 2)
+                    continue
+
+                raise
+
+            except RateLimitMigrating:
+                if url_rate_limit.migrating is None:
+                    raise ValueError(
+                        "RateLimitMigrating raised, but RateLimit.migrating is None. This is an internal Disnake "
+                        "error and should be reported!"
+                    )
+                else:
+                    url_rate_limit = self._buckets.get(url_rate_limit.migrating)
+                    if url_rate_limit is None:
+                        # This means we have an internal issue that we need to fix.
+                        raise ValueError(
+                            "RateLimit said to migrate, but the RateLimit to migrate was not found?
This is an " + "internal Disnake error and should be reported!" + ) - # release the global lock now that the - # global rate limit has passed - if is_global: - self._global_over.set() - _log.debug("Global rate limit is now over.") + else: + if not should_retry: + break - continue + if retry_count >= max_try_count - 1: + _log.error( + "Hit retry %s/%s on %s, either something is wrong with Discord or Disnake.", + retry_count + 1, + max_try_count, + rate_limit_path, + ) + if response is not None: + if response.status >= 500: + raise DiscordServerError(response, ret) - # we've received a 500, 502, or 504, unconditional retry - if response.status in {500, 502, 504}: - await asyncio.sleep(1 + tries * 2) - continue + raise HTTPException(response, ret) - # the usual error cases - if response.status == 403: - raise Forbidden(response, data) - elif response.status == 404: - raise NotFound(response, data) - elif response.status >= 500: - raise DiscordServerError(response, data) - else: - raise HTTPException(response, data) - - # This is handling exceptions from the request - except OSError as e: - # Connection reset by peer - if tries < 4 and e.errno == ECONNRESET: - await asyncio.sleep(1 + tries * 2) - continue - raise - - if response is not None: - # We've run out of retries, raise. - if response.status >= 500: - raise DiscordServerError(response, data) - - raise HTTPException(response, data) - - raise RuntimeError("Unreachable code in HTTP handling") + return ret async def get_from_cdn(self, url: str) -> bytes: - async with self.__session.get(url) as resp: + async with self._session.get(url) as resp: if resp.status == 200: return await resp.read() elif resp.status == 404: @@ -445,29 +1183,48 @@ async def get_from_cdn(self, url: str) -> bytes: # state management async def close(self) -> None: - if self.__session: - await self.__session.close() + if self._session: + await self._session.close() # login management - async def static_login(self, token: str) -> user.User: - # Necessary to get aiohttp to stop complaining about session creation - self.__session = aiohttp.ClientSession( - connector=self.connector, ws_response_class=DiscordClientWebSocketResponse - ) - old_token = self.token - self.token = token + async def static_login(self, auth: str) -> user.User: + # # Necessary to get aiohttp to stop complaining about session creation + # self.__session = aiohttp.ClientSession( + # connector=self.connector, ws_response_class=DiscordClientWebSocketResponse + # ) + # old_token = self.token + # self.token = token + + # TODO: Change this? This is literally just fetching /users/@me AKA "Get Current User", and is totally + # usable with OAuth2. This doesn't actually have anything to do with logging in. + self._set_default_auth(auth) try: data: user.User = await self.request(Route("GET", "/users/@me")) except HTTPException as exc: - self.token = old_token + # self.token = old_token if exc.status == 401: raise LoginFailure("Improper token has been passed.") from exc raise return data + async def exchange_access_code( + self, *, client_id: int, client_secret: str, code: str, redirect_uri: str + ): + # TODO: Look into how viable this function is here. + # This doesn't actually have hard ratelimits it seems? Not in the headers at least. The default bucket should + # keep it at 1 every 1 second. 
+ data = { + "client_id": client_id, + "client_secret": client_secret, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + return await self.request(Route("POST", "/oauth2/token"), data=data) + def create_party( self, channel_id: Snowflake, diff --git a/disnake/webhook/async_.py b/disnake/webhook/async_.py index edd9ec3dcd..8892154381 100644 --- a/disnake/webhook/async_.py +++ b/disnake/webhook/async_.py @@ -1217,12 +1217,12 @@ def _as_follower(cls, data, *, channel, user) -> Webhook: state = channel._state session = channel._state.http._HTTPClient__session - return cls(feed, session=session, state=state, token=state.http.token) + return cls(feed, session=session, state=state, token=state.http._default_auth) @classmethod def from_state(cls, data, state) -> Webhook: session = state.http._HTTPClient__session - return cls(data, session=session, state=state, token=state.http.token) + return cls(data, session=session, state=state, token=state.http._default_auth) async def fetch(self, *, prefer_auth: bool = True) -> Webhook: """|coro|
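
A minimal end-to-end sketch of driving the reworked HTTPClient, using only names introduced in this patch; the token and channel ID are placeholders:

    import asyncio

    from disnake.http import HTTPClient, Route


    async def main() -> None:
        # default_auth is used verbatim as the Authorization header value;
        # request() does not prefix tokens itself.
        http = HTTPClient(default_auth="Bot <token>")

        # The first request lazily creates the aiohttp session, the per-auth
        # global rate limit, and a per-(method, bucket, auth) RateLimit.
        me = await http.request(Route("GET", "/users/@me"))

        # Re-using a route re-uses its RateLimit; 429 handling and bucket
        # migration happen inside request() itself.
        channel = await http.request(Route("GET", "/channels/{channel_id}", channel_id=0))

        print(me, channel)
        await http.close()


    asyncio.run(main())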