diff --git a/.claude/rules/bitfinex-api.md b/.claude/rules/bitfinex-api.md new file mode 120000 index 00000000..067632c3 --- /dev/null +++ b/.claude/rules/bitfinex-api.md @@ -0,0 +1 @@ +../../../.claude/rules/bitfinex-api.md \ No newline at end of file diff --git a/.claude/rules/blazor-server.md b/.claude/rules/blazor-server.md new file mode 120000 index 00000000..7f5e15e4 --- /dev/null +++ b/.claude/rules/blazor-server.md @@ -0,0 +1 @@ +../../../.claude/rules/blazor-server.md \ No newline at end of file diff --git a/.claude/rules/ci-deployment.md b/.claude/rules/ci-deployment.md new file mode 120000 index 00000000..b6b9a444 --- /dev/null +++ b/.claude/rules/ci-deployment.md @@ -0,0 +1 @@ +../../../.claude/rules/ci-deployment.md \ No newline at end of file diff --git a/.claude/rules/cloudflare-worker.md b/.claude/rules/cloudflare-worker.md new file mode 120000 index 00000000..1191d0ae --- /dev/null +++ b/.claude/rules/cloudflare-worker.md @@ -0,0 +1 @@ +../../../.claude/rules/cloudflare-worker.md \ No newline at end of file diff --git a/.claude/rules/dotnet-auth.md b/.claude/rules/dotnet-auth.md new file mode 120000 index 00000000..906599b2 --- /dev/null +++ b/.claude/rules/dotnet-auth.md @@ -0,0 +1 @@ +../../../.claude/rules/dotnet-auth.md \ No newline at end of file diff --git a/.claude/rules/git-hooks.md b/.claude/rules/git-hooks.md new file mode 120000 index 00000000..9bf13972 --- /dev/null +++ b/.claude/rules/git-hooks.md @@ -0,0 +1 @@ +../../../.claude/rules/git-hooks.md \ No newline at end of file diff --git a/.claude/rules/hono-cloudflare.md b/.claude/rules/hono-cloudflare.md new file mode 120000 index 00000000..756c743f --- /dev/null +++ b/.claude/rules/hono-cloudflare.md @@ -0,0 +1 @@ +../../../.claude/rules/hono-cloudflare.md \ No newline at end of file diff --git a/.claude/rules/microsoft-graph.md b/.claude/rules/microsoft-graph.md new file mode 120000 index 00000000..cdd9bcbd --- /dev/null +++ b/.claude/rules/microsoft-graph.md @@ -0,0 +1 @@ +../../../.claude/rules/microsoft-graph.md \ No newline at end of file diff --git a/.claude/rules/shared-packages.md b/.claude/rules/shared-packages.md new file mode 120000 index 00000000..1ac8e5bd --- /dev/null +++ b/.claude/rules/shared-packages.md @@ -0,0 +1 @@ +../../../.claude/rules/shared-packages.md \ No newline at end of file diff --git a/.claude/rules/testing.md b/.claude/rules/testing.md new file mode 120000 index 00000000..e3c732c2 --- /dev/null +++ b/.claude/rules/testing.md @@ -0,0 +1 @@ +../../../.claude/rules/testing.md \ No newline at end of file diff --git a/.commitlintrc.json b/.commitlintrc.json new file mode 100644 index 00000000..73bb18bf --- /dev/null +++ b/.commitlintrc.json @@ -0,0 +1,13 @@ +{ + "extends": ["@commitlint/config-conventional"], + "rules": { + "type-enum": [ + 2, + "always", + ["feat", "fix", "docs", "style", "refactor", "perf", "test", "chore", "build", "ci", "revert"] + ], + "subject-case": [0, "always"], + "header-max-length": [1, "always", 100] + } +} + diff --git a/.flake8 b/.flake8 deleted file mode 100644 index f7aee5b3..00000000 --- a/.flake8 +++ /dev/null @@ -1,13 +0,0 @@ -[flake8] -max-line-length = 80 -extend-select = B950 -extend-ignore = E203,E501,E701 - -exclude = - __pycache__ - build - dist - venv - -per-file-ignores = - */__init__.py:F401 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..fadb4e3a --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,14 @@ +# Require owner review for CI/CD and security-sensitive files. +# Prevents unauthorized workflow changes (supply chain attack vector). + +# Default owner — all files +* @JCBauza + +# GitHub Actions workflows +.github/workflows/ @JCBauza + +# Dependabot / security config +dependabot.yml @JCBauza + +# This file itself +CODEOWNERS @JCBauza diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000..03c16862 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,21 @@ +# Project Guidelines + +> **Primary reference:** Read the repo's `CLAUDE.md` first — it is the single source of truth. This file provides supplementary Copilot-specific context. + + +## Build and Test + +- Use Python 3.12+. Install with `pip install -e ".[dev]"`. Test with `pytest` (coverage threshold 80%). +- Lint/format with Ruff. Type check with mypy (strict mode). + +## Architecture + +- CloudIngenium fork of official Bitfinex Python API client (v2). +- Published as `bitfinex-api-py` v6.0.0. Core library for BfxLendingBot. +- REST + WebSocket communication with Bitfinex exchange. + +## Conventions + +- Build system: Hatchling. Pre-commit hooks via Ruff. +- Always use `Decimal` for financial values — never `float`. +- WebSocket-first, REST as fallback. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..86288e85 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,24 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + groups: + all-security: + applies-to: security-updates + patterns: ["*"] + all-dependencies: + patterns: ["*"] + runs-on: "ubuntu-latest" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + groups: + all-security: + applies-to: security-updates + patterns: ["*"] + all-actions: + patterns: ["*"] + runs-on: "ubuntu-latest" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 59d76c63..8ba6c4a8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,20 +7,39 @@ on: pull_request: branches: - master + workflow_dispatch: jobs: - build: - runs-on: ubuntu-latest + lint: + runs-on: [self-hosted, Linux, Build] + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Set up Python 3.13 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.13" + - name: Install dependencies + run: pip install -e ".[dev]" + - name: Run ruff linter + run: ruff check bfxapi/ + - name: Run ruff formatter check + run: ruff format --check bfxapi/ + - name: Run mypy type checking + run: mypy bfxapi/ - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 + test: + runs-on: [self-hosted, Linux, Build] + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Set up Python 3.13 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: - python-version: '3.8' - - name: Install bitfinex-api-py's dependencies - run: python -m pip install -r dev-requirements.txt - - name: Run pre-commit hooks (see .pre-commit-config.yaml) - uses: pre-commit/action@v3.0.1 - - name: Run mypy to ensure correct type hinting - run: python -m mypy bfxapi + python-version: "3.13" + - name: Install dependencies + run: pip install -e ".[dev]" + - name: Verify import + run: python -c "import bfxapi; print('Import successful')" + - name: Run tests with coverage + run: python -m pytest tests/ -v + - name: Run mypy + run: mypy bfxapi/ diff --git a/.github/workflows/commitlint.yml b/.github/workflows/commitlint.yml new file mode 100644 index 00000000..2f8dc7f1 --- /dev/null +++ b/.github/workflows/commitlint.yml @@ -0,0 +1,20 @@ +name: Commit Lint +on: + pull_request: + branches: [main, master] +permissions: + contents: read +jobs: + commitlint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Lint commits in PR + uses: wagoid/commitlint-github-action@v6 + with: + configFile: ".commitlintrc.json" + failOnWarnings: false + firstParent: false + diff --git a/.github/workflows/copilot-reviewer.yml b/.github/workflows/copilot-reviewer.yml new file mode 100644 index 00000000..745d07b5 --- /dev/null +++ b/.github/workflows/copilot-reviewer.yml @@ -0,0 +1,12 @@ +name: Copilot Reviewer +on: + pull_request: + types: [opened, reopened, ready_for_review] + +permissions: + pull-requests: write + +jobs: + request: + uses: CloudIngenium/.github/.github/workflows/request-copilot-review.yml@main + secrets: inherit diff --git a/.gitignore b/.gitignore index f0c1f9b1..700645c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,21 @@ .venv +venv/ .DS_Store +Thumbs.db .vscode +.idea +.vim +*.swp +*.swo .python-version __pycache__ +# Environment +.env +.env.local +.env.*.local + +# Build / packaging bitfinex_api_py.egg-info bitfinex_api_py.dist-info build/ @@ -11,6 +23,11 @@ dist/ pip-wheel-metadata/ .eggs -.idea +# Test / lint caches +.coverage +htmlcov/ +.mypy_cache/ +.pytest_cache/ +.ruff_cache/ -venv/ +*.log diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 39b73a71..00000000 --- a/.isort.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[settings] -profile = black diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9397acfa..7742b283 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,7 @@ repos: - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.1 hooks: - - id: isort - - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.2.0 - hooks: - - id: black - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - - additional_dependencies: [ - flake8-bugbear - ] + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..fd322e50 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,77 @@ +# CLAUDE.md — bitfinex-api-py + +## Project Purpose + +CloudIngenium fork of the official Bitfinex Python API client (v2). Published as `bitfinex-api-py` v6.0.0. This is the core library used by BfxLendingBot for all Bitfinex REST and WebSocket communication. + +## Stack + +- **Language:** Python 3.12+ (targets 3.13) +- **Build system:** Hatchling +- **Lint/Format:** Ruff (line-length 80, pycodestyle + pyflakes + bugbear + pyupgrade + isort) +- **Type checking:** mypy (strict mode) +- **Test:** pytest + pytest-asyncio + pytest-cov (coverage threshold: 80%) +- **Pre-commit:** Ruff linting and formatting hooks + +## Build & Dev + +```bash +pip install -e ".[dev]" # Install in editable mode with dev deps +pytest # Run tests with coverage +mypy bfxapi/ # Type checking (strict) +ruff check bfxapi/ # Lint +ruff format bfxapi/ # Format +pre-commit run --all-files # Run all pre-commit hooks +``` + +## Structure + +``` +bfxapi/ + __init__.py # Package entry — exports Client, host constants + _client.py # Main Client class (REST + WebSocket) + _version.py # Version string + _utils/ # Internal utilities + exceptions.py # Package-level exceptions + rest/ + __init__.py + _bfx_rest_interface.py # REST interface base + _interface/ + interface.py # Low-level HTTP interface + middleware.py # Request middleware (auth, etc.) + _interfaces/ + rest_auth_endpoints.py # Authenticated REST endpoints + rest_public_endpoints.py # Public REST endpoints + exceptions.py # REST-specific exceptions + types/ + __init__.py + dataclasses.py # All Bitfinex data types (Order, Trade, Candle, etc.) + labeler.py # Field labeling utilities + notification.py # Notification type + serializers.py # Type serialization + websocket/ + __init__.py + _connection.py # WebSocket connection management + subscriptions.py # Subscription types and management + exceptions.py # WebSocket-specific exceptions + _client/ + bfx_websocket_client.py # Main WebSocket client + bfx_websocket_bucket.py # Connection bucketing + bfx_websocket_inputs.py # WebSocket input operations + _event_emitter/ + bfx_event_emitter.py # Event system + _handlers/ + auth_events_handler.py # Authenticated event handling + public_channels_handler.py # Public channel handling +tests/ # pytest test suites +examples/ # Usage examples (REST + WebSocket) +``` + +## Conventions + +- Import types from `bfxapi.types.dataclasses` (never `bfxapi.models` — removed in v4) +- Use `Decimal` for monetary values, never `float` +- WebSocket preferred over REST for real-time data (no rate limits) +- REST rate limit: 90 req/5min on private endpoints +- Never hardcode API keys — use environment variables +- asyncio-based WebSocket client; use `bfx.wss.run()` or `await bfx.wss.start()` diff --git a/README.md b/README.md index 9a0604aa..b8c2ed66 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # bitfinex-api-py [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/bitfinex-api-py)](https://pypi.org/project/bitfinex-api-py/) -[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -![GitHub Action](https://github.com/bitfinexcom/bitfinex-api-py/actions/workflows/build.yml/badge.svg) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +![GitHub Action](https://github.com/JCBauza/bitfinex-api-py/actions/workflows/build.yml/badge.svg) -Official implementation of the [Bitfinex APIs (V2)](https://docs.bitfinex.com/docs) for `Python 3.8+`. +Official implementation of the [Bitfinex APIs (V2)](https://docs.bitfinex.com/docs) for `Python 3.12+`. ### Features @@ -302,7 +302,7 @@ All contributions are welcome! :D A guide on how to install and set up `bitfinex-api-py`'s source code can be found [here](#installation-and-setup). \ Before opening any pull requests, please have a look at [Before Opening a PR](#before-opening-a-pr). \ -Contributors must uphold the [Contributor Covenant code of conduct](https://github.com/bitfinexcom/bitfinex-api-py/blob/master/CODE_OF_CONDUCT.md). +Contributors must uphold the [Contributor Covenant code of conduct](https://github.com/JCBauza/bitfinex-api-py/blob/master/CODE_OF_CONDUCT.md). ### Index @@ -316,41 +316,36 @@ Contributors must uphold the [Contributor Covenant code of conduct](https://gith ## Installation and setup -A brief guide on how to install and set up the project in your Python 3.8+ environment. +A brief guide on how to install and set up the project in your Python 3.12+ environment. ### Cloning the repository ```console -git clone https://github.com/bitfinexcom/bitfinex-api-py.git +git clone https://github.com/JCBauza/bitfinex-api-py.git ``` ### Installing the dependencies ```console -python3 -m pip install -r dev-requirements.txt +python3 -m pip install -e ".[dev]" ``` -Make sure to install `dev-requirements.txt` (and not `requirements.txt`!). \ -`dev-requirements.txt` will install all dependencies in `requirements.txt` plus any development dependency. \ -dev-requirements includes [mypy](https://github.com/python/mypy), [black](https://github.com/psf/black), [isort](https://github.com/PyCQA/isort), [flake8](https://github.com/PyCQA/flake8), and [pre-commit](https://github.com/pre-commit/pre-commit) (more on these tools in later chapters). +This installs the package in editable mode along with all development dependencies, including [mypy](https://github.com/python/mypy), [ruff](https://github.com/astral-sh/ruff), [pytest](https://github.com/pytest-dev/pytest), and [pre-commit](https://github.com/pre-commit/pre-commit). -All done, your Python 3.8+ environment should now be able to run `bitfinex-api-py`'s source code. +All done, your Python 3.12+ environment should now be able to run `bitfinex-api-py`'s source code. ### Set up the pre-commit hooks (optional) **Do not skip this paragraph if you intend to contribute to the project.** -This repository includes a pre-commit configuration file that defines the following hooks: -1. [isort](https://github.com/PyCQA/isort) -2. [black](https://github.com/psf/black) -3. [flake8](https://github.com/PyCQA/flake8) +This repository includes a pre-commit configuration file that runs [ruff](https://github.com/astral-sh/ruff) for both linting and formatting. To set up pre-commit use: ```console python3 -m pre-commit install ``` -These will ensure that isort, black and flake8 are run on each git commit. +This will ensure that ruff is run on each git commit. [Visit this page to learn more about git hooks and pre-commit.](https://pre-commit.com/#introduction) @@ -367,7 +362,7 @@ python3 -m pre-commit run --all-files Wheter you're submitting a bug fix, a new feature or a documentation change, you should first discuss it in an issue. -You must be able to check off all tasks listed in [PULL_REQUEST_TEMPLATE](https://raw.githubusercontent.com/bitfinexcom/bitfinex-api-py/master/.github/PULL_REQUEST_TEMPLATE.md) before opening a pull request. +You must be able to check off all tasks listed in [PULL_REQUEST_TEMPLATE](https://raw.githubusercontent.com/JCBauza/bitfinex-api-py/master/.github/PULL_REQUEST_TEMPLATE.md) before opening a pull request. ### Tip diff --git a/bfxapi/_client.py b/bfxapi/_client.py index f8fc67c7..962ecdd0 100644 --- a/bfxapi/_client.py +++ b/bfxapi/_client.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING from bfxapi._utils.logging import ColorLogger from bfxapi.exceptions import IncompleteCredentialError from bfxapi.rest import BfxRestInterface +from bfxapi.types.labeler import set_decimal_mode from bfxapi.websocket import BfxWebSocketClient if TYPE_CHECKING: @@ -18,16 +19,18 @@ class Client: def __init__( self, - api_key: Optional[str] = None, - api_secret: Optional[str] = None, + api_key: str | None = None, + api_secret: str | None = None, *, rest_host: str = REST_HOST, wss_host: str = WSS_HOST, - filters: Optional[List[str]] = None, - timeout: Optional[int] = 60 * 15, - log_filename: Optional[str] = None, + filters: list[str] | None = None, + timeout: int | None = 60 * 15, + log_filename: str | None = None, + decimal_mode: bool = False, ) -> None: - credentials: Optional["_Credentials"] = None + set_decimal_mode(decimal_mode) + credentials: _Credentials | None = None if api_key and api_secret: credentials = { diff --git a/bfxapi/_utils/json_decoder.py b/bfxapi/_utils/json_decoder.py index 3597de90..89f3469d 100644 --- a/bfxapi/_utils/json_decoder.py +++ b/bfxapi/_utils/json_decoder.py @@ -1,16 +1,20 @@ import json import re -from typing import Any, Dict +from typing import Any def _to_snake_case(string: str) -> str: return re.sub(r"(? Any: +def _object_hook(data: dict[str, Any]) -> Any: return {_to_snake_case(key): value for key, value in data.items()} class JSONDecoder(json.JSONDecoder): def __init__(self, *args: Any, **kwargs: Any) -> None: + # requests uses simplejson as `complexjson` when installed; simplejson + # passes an obsolete `encoding` kwarg that stdlib json.JSONDecoder + # rejects on Python 3.9+ (confirmed failure on 3.12). + kwargs.pop("encoding", None) super().__init__(*args, **kwargs, object_hook=_object_hook) diff --git a/bfxapi/_utils/json_encoder.py b/bfxapi/_utils/json_encoder.py index 0d0d9e35..e316bc1e 100644 --- a/bfxapi/_utils/json_encoder.py +++ b/bfxapi/_utils/json_encoder.py @@ -1,16 +1,25 @@ import json from decimal import Decimal -from typing import Any, Dict, List, Union - -_ExtJSON = Union[ - Dict[str, "_ExtJSON"], List["_ExtJSON"], bool, int, float, str, Decimal, None -] - -_StrictJSON = Union[Dict[str, "_StrictJSON"], List["_StrictJSON"], int, str, None] - - -def _clear(dictionary: Dict[str, Any]) -> Dict[str, Any]: - return {key: value for key, value in dictionary.items() if value is not None} +from typing import Any + +_ExtJSON = ( + dict[str, "_ExtJSON"] + | list["_ExtJSON"] + | bool + | int + | float + | str + | Decimal + | None +) + +_StrictJSON = dict[str, "_StrictJSON"] | list["_StrictJSON"] | int | str | None + + +def _clear(dictionary: dict[str, Any]) -> dict[str, Any]: + return { + key: value for key, value in dictionary.items() if value is not None + } def _adapter(data: _ExtJSON) -> _StrictJSON: diff --git a/bfxapi/_utils/logging.py b/bfxapi/_utils/logging.py index 9eca09c4..2979668e 100644 --- a/bfxapi/_utils/logging.py +++ b/bfxapi/_utils/logging.py @@ -1,7 +1,7 @@ import sys from copy import copy from logging import FileHandler, Formatter, Logger, LogRecord, StreamHandler -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: _Level = Literal["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] @@ -40,7 +40,7 @@ def format(self, record: LogRecord) -> str: return super().format(_record) - def formatTime(self, record: LogRecord, datefmt: Optional[str] = None) -> str: + def formatTime(self, record: LogRecord, datefmt: str | None = None) -> str: return _GREEN + super().formatTime(record, datefmt) + _NC @staticmethod diff --git a/bfxapi/_version.py b/bfxapi/_version.py index e94f36fe..79a961b4 100644 --- a/bfxapi/_version.py +++ b/bfxapi/_version.py @@ -1 +1 @@ -__version__ = "3.0.5" +__version__ = "6.0.1" diff --git a/bfxapi/rest/__init__.py b/bfxapi/rest/__init__.py index 850f4727..cd0d67a3 100644 --- a/bfxapi/rest/__init__.py +++ b/bfxapi/rest/__init__.py @@ -1 +1,3 @@ from ._bfx_rest_interface import BfxRestInterface + +__all__ = ["BfxRestInterface"] diff --git a/bfxapi/rest/_bfx_rest_interface.py b/bfxapi/rest/_bfx_rest_interface.py index 94dec41d..2b6ca584 100644 --- a/bfxapi/rest/_bfx_rest_interface.py +++ b/bfxapi/rest/_bfx_rest_interface.py @@ -1,12 +1,21 @@ -from typing import Optional - +from bfxapi.rest._interface.middleware import RateLimitInfo from bfxapi.rest._interfaces import RestAuthEndpoints, RestPublicEndpoints class BfxRestInterface: def __init__( - self, host: str, api_key: Optional[str] = None, api_secret: Optional[str] = None + self, + host: str, + api_key: str | None = None, + api_secret: str | None = None, ): - self.auth = RestAuthEndpoints(host=host, api_key=api_key, api_secret=api_secret) + self.auth = RestAuthEndpoints( + host=host, api_key=api_key, api_secret=api_secret + ) self.public = RestPublicEndpoints(host=host) + + @property + def last_rate_limit(self) -> RateLimitInfo: + """Rate limit info from the most recent REST call (auth or public).""" + return self.auth._m.last_rate_limit diff --git a/bfxapi/rest/_interface/__init__.py b/bfxapi/rest/_interface/__init__.py index a45df51f..49e1ea6a 100644 --- a/bfxapi/rest/_interface/__init__.py +++ b/bfxapi/rest/_interface/__init__.py @@ -1 +1,3 @@ from .interface import Interface + +__all__ = ["Interface"] diff --git a/bfxapi/rest/_interface/interface.py b/bfxapi/rest/_interface/interface.py index 2975eb48..904cdec7 100644 --- a/bfxapi/rest/_interface/interface.py +++ b/bfxapi/rest/_interface/interface.py @@ -1,10 +1,11 @@ -from typing import Optional - from .middleware import Middleware class Interface: def __init__( - self, host: str, api_key: Optional[str] = None, api_secret: Optional[str] = None + self, + host: str, + api_key: str | None = None, + api_secret: str | None = None, ): self._m = Middleware(host, api_key, api_secret) diff --git a/bfxapi/rest/_interface/middleware.py b/bfxapi/rest/_interface/middleware.py index 92967c17..4d85f350 100644 --- a/bfxapi/rest/_interface/middleware.py +++ b/bfxapi/rest/_interface/middleware.py @@ -1,24 +1,70 @@ import hashlib import hmac import json +import threading +import time +from dataclasses import dataclass from datetime import datetime from enum import IntEnum -from typing import TYPE_CHECKING, Any, List, NoReturn, Optional +from typing import TYPE_CHECKING, Any, NoReturn + +_NONCE_LOCK = threading.Lock() +_NONCE_LAST = 0 + + +def _next_nonce() -> str: + """Monotonic microsecond nonce. Guards against duplicates when two calls + hit the same microsecond tick under concurrent use (H8 fix).""" + global _NONCE_LAST + with _NONCE_LOCK: + candidate = time.time_ns() // 1_000 + if candidate <= _NONCE_LAST: + candidate = _NONCE_LAST + 1 + _NONCE_LAST = candidate + return str(candidate) import requests from bfxapi._utils.json_decoder import JSONDecoder from bfxapi._utils.json_encoder import JSONEncoder from bfxapi.exceptions import InvalidCredentialError -from bfxapi.rest.exceptions import GenericError, RequestParameterError +from bfxapi.rest.exceptions import ( + GenericError, + InsufficientFundsError, + NetworkError, + RateLimitError, + RequestParameterError, +) if TYPE_CHECKING: from requests.sessions import _Params +@dataclass +class RateLimitInfo: + """Rate limit information from the last API response.""" + + remaining: int | None = None + limit: int | None = None + reset: int | None = None + + @classmethod + def from_headers(cls, headers: dict[str, str]) -> "RateLimitInfo": + def _int(key: str) -> int | None: + val = headers.get(key) + return int(val) if val is not None else None + + return cls( + remaining=_int("x-ratelimit-remaining"), + limit=_int("x-ratelimit-limit"), + reset=_int("x-ratelimit-reset"), + ) + + class _Error(IntEnum): ERR_UNK = 10000 ERR_GENERIC = 10001 + ERR_RATE_LIMIT = 10010 ERR_PARAMS = 10020 ERR_AUTH_FAIL = 10100 @@ -27,7 +73,10 @@ class Middleware: __TIMEOUT = 30 def __init__( - self, host: str, api_key: Optional[str] = None, api_secret: Optional[str] = None + self, + host: str, + api_key: str | None = None, + api_secret: str | None = None, ): self.__host = host @@ -35,35 +84,40 @@ def __init__( self.__api_secret = api_secret - def get(self, endpoint: str, params: Optional["_Params"] = None) -> Any: + self.last_rate_limit: RateLimitInfo = RateLimitInfo() + + def get(self, endpoint: str, params: "_Params | None" = None) -> Any: headers = {"Accept": "application/json"} if self.__api_key and self.__api_secret: headers = {**headers, **self.__get_authentication_headers(endpoint)} - request = requests.get( - url=f"{self.__host}/{endpoint}", - params=params, - headers=headers, - timeout=Middleware.__TIMEOUT, - ) - - data = request.json(cls=JSONDecoder) - - if isinstance(data, list) and len(data) > 0 and data[0] == "error": - self.__handle_error(data) + try: + response = requests.get( + url=f"{self.__host}/{endpoint}", + params=params, + headers=headers, + timeout=Middleware.__TIMEOUT, + ) + except requests.ConnectionError as e: + raise NetworkError(f"Connection error: {e}") from e + except requests.Timeout as e: + raise NetworkError(f"Request timeout: {e}") from e - return data + return self.__process_response(response) def post( self, endpoint: str, - body: Optional[Any] = None, - params: Optional["_Params"] = None, + body: Any | None = None, + params: "_Params | None" = None, ) -> Any: _body = body and json.dumps(body, cls=JSONEncoder) or None - headers = {"Accept": "application/json", "Content-Type": "application/json"} + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } if self.__api_key and self.__api_secret: headers = { @@ -71,47 +125,94 @@ def post( **self.__get_authentication_headers(endpoint, _body), } - request = requests.post( - url=f"{self.__host}/{endpoint}", - data=_body, - params=params, - headers=headers, - timeout=Middleware.__TIMEOUT, + try: + response = requests.post( + url=f"{self.__host}/{endpoint}", + data=_body, + params=params, + headers=headers, + timeout=Middleware.__TIMEOUT, + ) + except requests.ConnectionError as e: + raise NetworkError(f"Connection error: {e}") from e + except requests.Timeout as e: + raise NetworkError(f"Request timeout: {e}") from e + + return self.__process_response(response) + + def __process_response(self, response: requests.Response) -> Any: + self.last_rate_limit = RateLimitInfo.from_headers( + dict(response.headers) ) - data = request.json(cls=JSONDecoder) + if response.status_code == 429: + reset = self.last_rate_limit.reset + retry_ms = ( + (reset - int(datetime.now().timestamp())) * 1000 + if reset + else 60_000 + ) + raise RateLimitError( + "Rate limit exceeded (HTTP 429)", + retry_after_ms=max(retry_ms, 1000), + ) + + data = response.json(cls=JSONDecoder) if isinstance(data, list) and len(data) > 0 and data[0] == "error": self.__handle_error(data) return data - def __handle_error(self, error: List[Any]) -> NoReturn: - if error[1] == _Error.ERR_PARAMS: + def __handle_error(self, error: list[Any]) -> NoReturn: + code = error[1] + message = error[2] if len(error) > 2 else str(error) + + if code == _Error.ERR_RATE_LIMIT: + reset = self.last_rate_limit.reset + retry_ms = ( + (reset - int(datetime.now().timestamp())) * 1000 + if reset + else 60_000 + ) + raise RateLimitError( + f"Rate limit exceeded: <{message}>", + retry_after_ms=max(retry_ms, 1000), + ) + + if code == _Error.ERR_PARAMS: raise RequestParameterError( "The request was rejected with the following parameter " - f"error: <{error[2]}>." + f"error: <{message}>." ) - if error[1] == _Error.ERR_AUTH_FAIL: + if code == _Error.ERR_AUTH_FAIL: raise InvalidCredentialError( "Can't authenticate with given API-KEY and API-SECRET." ) - if not error[1] or error[1] == _Error.ERR_UNK or error[1] == _Error.ERR_GENERIC: + # Insufficient funds — check both error code and message + if code == _Error.ERR_GENERIC and isinstance(message, str): + msg_lower = message.lower() + if "insufficient" in msg_lower or "not enough" in msg_lower: + raise InsufficientFundsError(f"Insufficient funds: <{message}>") + + if not code or code == _Error.ERR_UNK or code == _Error.ERR_GENERIC: raise GenericError( "The request was rejected with the following generic " - f"error: <{error[2]}>." + f"error: <{message}>." ) raise RuntimeError( f"The request was rejected with an unexpected error: <{error}>." ) - def __get_authentication_headers(self, endpoint: str, data: Optional[str] = None): + def __get_authentication_headers( + self, endpoint: str, data: str | None = None + ) -> dict[str, str]: assert self.__api_key and self.__api_secret - nonce = str(round(datetime.now().timestamp() * 1_000_000)) + nonce = _next_nonce() if not data: message = f"/api/v2/{endpoint}{nonce}" diff --git a/bfxapi/rest/_interfaces/__init__.py b/bfxapi/rest/_interfaces/__init__.py index 34e1dde0..aee45fed 100644 --- a/bfxapi/rest/_interfaces/__init__.py +++ b/bfxapi/rest/_interfaces/__init__.py @@ -1,2 +1,4 @@ from .rest_auth_endpoints import RestAuthEndpoints from .rest_public_endpoints import RestPublicEndpoints + +__all__ = ["RestAuthEndpoints", "RestPublicEndpoints"] diff --git a/bfxapi/rest/_interfaces/rest_auth_endpoints.py b/bfxapi/rest/_interfaces/rest_auth_endpoints.py index ad3b8066..e79b3813 100644 --- a/bfxapi/rest/_interfaces/rest_auth_endpoints.py +++ b/bfxapi/rest/_interfaces/rest_auth_endpoints.py @@ -1,5 +1,5 @@ from decimal import Decimal -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Literal from bfxapi.rest._interface import Interface from bfxapi.types import ( @@ -36,14 +36,14 @@ Withdrawal, serializers, ) -from bfxapi.types.serializers import _Notification +from bfxapi.types.notification import _Notification class RestAuthEndpoints(Interface): def get_user_info(self) -> UserInfo: return serializers.UserInfo.parse(*self._m.post("auth/r/info/user")) - def get_login_history(self) -> List[LoginHistory]: + def get_login_history(self) -> list[LoginHistory]: return [ serializers.LoginHistory.parse(*sub_data) for sub_data in self._m.post("auth/r/logins/hist") @@ -54,25 +54,31 @@ def get_balance_available_for_orders_or_offers( symbol: str, type: str, *, - dir: Optional[int] = None, - rate: Optional[str] = None, - lev: Optional[str] = None, + dir: int | None = None, + rate: str | None = None, + lev: str | None = None, ) -> BalanceAvailable: - body = {"symbol": symbol, "type": type, "dir": dir, "rate": rate, "lev": lev} + body = { + "symbol": symbol, + "type": type, + "dir": dir, + "rate": rate, + "lev": lev, + } return serializers.BalanceAvailable.parse( *self._m.post("auth/calc/order/avail", body=body) ) - def get_wallets(self) -> List[Wallet]: + def get_wallets(self) -> list[Wallet]: return [ serializers.Wallet.parse(*sub_data) for sub_data in self._m.post("auth/r/wallets") ] def get_orders( - self, *, symbol: Optional[str] = None, ids: Optional[List[str]] = None - ) -> List[Order]: + self, *, symbol: str | None = None, ids: list[str] | None = None + ) -> list[Order]: if symbol is None: endpoint = "auth/r/orders" else: @@ -87,18 +93,18 @@ def submit_order( self, type: str, symbol: str, - amount: Union[str, float, Decimal], - price: Union[str, float, Decimal], + amount: str | float | Decimal, + price: str | float | Decimal, *, - lev: Optional[int] = None, - price_trailing: Optional[Union[str, float, Decimal]] = None, - price_aux_limit: Optional[Union[str, float, Decimal]] = None, - price_oco_stop: Optional[Union[str, float, Decimal]] = None, - gid: Optional[int] = None, - cid: Optional[int] = None, - flags: Optional[int] = None, - tif: Optional[str] = None, - meta: Optional[Dict[str, Any]] = None, + lev: int | None = None, + price_trailing: str | float | Decimal | None = None, + price_aux_limit: str | float | Decimal | None = None, + price_oco_stop: str | float | Decimal | None = None, + gid: int | None = None, + cid: int | None = None, + flags: int | None = None, + tif: str | None = None, + meta: dict[str, Any] | None = None, ) -> Notification[Order]: body = { "type": type, @@ -124,17 +130,17 @@ def update_order( self, id: int, *, - amount: Optional[Union[str, float, Decimal]] = None, - price: Optional[Union[str, float, Decimal]] = None, - cid: Optional[int] = None, - cid_date: Optional[str] = None, - gid: Optional[int] = None, - flags: Optional[int] = None, - lev: Optional[int] = None, - delta: Optional[Union[str, float, Decimal]] = None, - price_aux_limit: Optional[Union[str, float, Decimal]] = None, - price_trailing: Optional[Union[str, float, Decimal]] = None, - tif: Optional[str] = None, + amount: str | float | Decimal | None = None, + price: str | float | Decimal | None = None, + cid: int | None = None, + cid_date: str | None = None, + gid: int | None = None, + flags: int | None = None, + lev: int | None = None, + delta: str | float | Decimal | None = None, + price_aux_limit: str | float | Decimal | None = None, + price_trailing: str | float | Decimal | None = None, + tif: str | None = None, ) -> Notification[Order]: body = { "id": id, @@ -158,39 +164,40 @@ def update_order( def cancel_order( self, *, - id: Optional[int] = None, - cid: Optional[int] = None, - cid_date: Optional[str] = None, + id: int | None = None, + cid: int | None = None, + cid_date: str | None = None, ) -> Notification[Order]: return _Notification[Order](serializers.Order).parse( *self._m.post( - "auth/w/order/cancel", body={"id": id, "cid": cid, "cid_date": cid_date} + "auth/w/order/cancel", + body={"id": id, "cid": cid, "cid_date": cid_date}, ) ) def cancel_order_multi( self, *, - id: Optional[List[int]] = None, - cid: Optional[List[Tuple[int, str]]] = None, - gid: Optional[List[int]] = None, - all: Optional[bool] = None, - ) -> Notification[List[Order]]: + id: list[int] | None = None, + cid: list[tuple[int, str]] | None = None, + gid: list[int] | None = None, + all: bool | None = None, + ) -> Notification[list[Order]]: body = {"id": id, "cid": cid, "gid": gid, "all": all} - return _Notification[List[Order]](serializers.Order, is_iterable=True).parse( - *self._m.post("auth/w/order/cancel/multi", body=body) - ) + return _Notification[list[Order]]( + serializers.Order, is_iterable=True + ).parse(*self._m.post("auth/w/order/cancel/multi", body=body)) def get_orders_history( self, *, - symbol: Optional[str] = None, - ids: Optional[List[int]] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Order]: + symbol: str | None = None, + ids: list[int] | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Order]: if symbol is None: endpoint = "auth/r/orders/hist" else: @@ -203,7 +210,7 @@ def get_orders_history( for sub_data in self._m.post(endpoint, body=body) ] - def get_order_trades(self, symbol: str, id: int) -> List[OrderTrade]: + def get_order_trades(self, symbol: str, id: int) -> list[OrderTrade]: return [ serializers.OrderTrade.parse(*sub_data) for sub_data in self._m.post(f"auth/r/order/{symbol}:{id}/trades") @@ -212,12 +219,12 @@ def get_order_trades(self, symbol: str, id: int) -> List[OrderTrade]: def get_trades_history( self, *, - symbol: Optional[str] = None, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Trade]: + symbol: str | None = None, + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Trade]: if symbol is None: endpoint = "auth/r/trades/hist" else: @@ -232,19 +239,24 @@ def get_trades_history( def get_ledgers( self, - currency: Optional[str] = None, + currency: str | None = None, *, - category: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Ledger]: + category: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Ledger]: if currency is None: endpoint = "auth/r/ledgers/hist" else: endpoint = f"auth/r/ledgers/{currency}/hist" - body = {"category": category, "start": start, "end": end, "limit": limit} + body = { + "category": category, + "start": start, + "end": end, + "limit": limit, + } return [ serializers.Ledger.parse(*sub_data) @@ -261,36 +273,41 @@ def get_symbol_margin_info(self, symbol: str) -> SymbolMarginInfo: *self._m.post(f"auth/r/info/margin/{symbol}") ) - def get_all_symbols_margin_info(self) -> List[SymbolMarginInfo]: + def get_all_symbols_margin_info(self) -> list[SymbolMarginInfo]: return [ serializers.SymbolMarginInfo.parse(*sub_data) for sub_data in self._m.post("auth/r/info/margin/sym_all") ] - def get_positions(self) -> List[Position]: + def get_positions(self) -> list[Position]: return [ serializers.Position.parse(*sub_data) for sub_data in self._m.post("auth/r/positions") ] def claim_position( - self, id: int, *, amount: Optional[Union[str, float, Decimal]] = None + self, id: int, *, amount: str | float | Decimal | None = None ) -> Notification[PositionClaim]: return _Notification[PositionClaim](serializers.PositionClaim).parse( - *self._m.post("auth/w/position/claim", body={"id": id, "amount": amount}) + *self._m.post( + "auth/w/position/claim", body={"id": id, "amount": amount} + ) ) def increase_position( - self, symbol: str, amount: Union[str, float, Decimal] + self, symbol: str, amount: str | float | Decimal ) -> Notification[PositionIncrease]: - return _Notification[PositionIncrease](serializers.PositionIncrease).parse( + return _Notification[PositionIncrease]( + serializers.PositionIncrease + ).parse( *self._m.post( - "auth/w/position/increase", body={"symbol": symbol, "amount": amount} + "auth/w/position/increase", + body={"symbol": symbol, "amount": amount}, ) ) def get_increase_position_info( - self, symbol: str, amount: Union[str, float, Decimal] + self, symbol: str, amount: str | float | Decimal ) -> PositionIncreaseInfo: return serializers.PositionIncreaseInfo.parse( *self._m.post( @@ -302,10 +319,10 @@ def get_increase_position_info( def get_positions_history( self, *, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[PositionHistory]: + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[PositionHistory]: return [ serializers.PositionHistory.parse(*sub_data) for sub_data in self._m.post( @@ -317,10 +334,10 @@ def get_positions_history( def get_positions_snapshot( self, *, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[PositionSnapshot]: + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[PositionSnapshot]: return [ serializers.PositionSnapshot.parse(*sub_data) for sub_data in self._m.post( @@ -332,11 +349,11 @@ def get_positions_snapshot( def get_positions_audit( self, *, - ids: Optional[List[int]] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[PositionAudit]: + ids: list[int] | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[PositionAudit]: body = {"ids": ids, "start": start, "end": end, "limit": limit} return [ @@ -345,7 +362,7 @@ def get_positions_audit( ] def set_derivative_position_collateral( - self, symbol: str, collateral: Union[str, float, Decimal] + self, symbol: str, collateral: str | float | Decimal ) -> DerivativePositionCollateral: return serializers.DerivativePositionCollateral.parse( *( @@ -360,10 +377,14 @@ def get_derivative_position_collateral_limits( self, symbol: str ) -> DerivativePositionCollateralLimits: return serializers.DerivativePositionCollateralLimits.parse( - *self._m.post("auth/calc/deriv/collateral/limit", body={"symbol": symbol}) + *self._m.post( + "auth/calc/deriv/collateral/limit", body={"symbol": symbol} + ) ) - def get_funding_offers(self, *, symbol: Optional[str] = None) -> List[FundingOffer]: + def get_funding_offers( + self, *, symbol: str | None = None + ) -> list[FundingOffer]: if symbol is None: endpoint = "auth/r/funding/offers" else: @@ -378,11 +399,11 @@ def submit_funding_offer( self, type: str, symbol: str, - amount: Union[str, float, Decimal], - rate: Union[str, float, Decimal], + amount: str | float | Decimal, + rate: str | float | Decimal, period: int, *, - flags: Optional[int] = None, + flags: int | None = None, ) -> Notification[FundingOffer]: body = { "type": type, @@ -402,7 +423,9 @@ def cancel_funding_offer(self, id: int) -> Notification[FundingOffer]: *self._m.post("auth/w/funding/offer/cancel", body={"id": id}) ) - def cancel_all_funding_offers(self, currency: str) -> Notification[Literal[None]]: + def cancel_all_funding_offers( + self, currency: str + ) -> Notification[Literal[None]]: return _Notification[Literal[None]](None).parse( *self._m.post( "auth/w/funding/offer/cancel/all", body={"currency": currency} @@ -419,9 +442,9 @@ def toggle_auto_renew( status: bool, currency: str, *, - amount: Optional[str] = None, - rate: Optional[int] = None, - period: Optional[int] = None, + amount: str | None = None, + rate: int | None = None, + period: int | None = None, ) -> Notification[FundingAutoRenew]: body = { "status": status, @@ -431,16 +454,16 @@ def toggle_auto_renew( "period": period, } - return _Notification[FundingAutoRenew](serializers.FundingAutoRenew).parse( - *self._m.post("auth/w/funding/auto", body=body) - ) + return _Notification[FundingAutoRenew]( + serializers.FundingAutoRenew + ).parse(*self._m.post("auth/w/funding/auto", body=body)) def toggle_keep_funding( self, type: Literal["credit", "loan"], *, - ids: Optional[List[int]] = None, - changes: Optional[Dict[int, Literal[1, 2]]] = None, + ids: list[int] | None = None, + changes: dict[int, Literal[1, 2]] | None = None, ) -> Notification[Literal[None]]: return _Notification[Literal[None]](None).parse( *self._m.post( @@ -452,11 +475,11 @@ def toggle_keep_funding( def get_funding_offers_history( self, *, - symbol: Optional[str] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[FundingOffer]: + symbol: str | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[FundingOffer]: if symbol is None: endpoint = "auth/r/funding/offers/hist" else: @@ -469,7 +492,9 @@ def get_funding_offers_history( ) ] - def get_funding_loans(self, *, symbol: Optional[str] = None) -> List[FundingLoan]: + def get_funding_loans( + self, *, symbol: str | None = None + ) -> list[FundingLoan]: if symbol is None: endpoint = "auth/r/funding/loans" else: @@ -483,11 +508,11 @@ def get_funding_loans(self, *, symbol: Optional[str] = None) -> List[FundingLoan def get_funding_loans_history( self, *, - symbol: Optional[str] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[FundingLoan]: + symbol: str | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[FundingLoan]: if symbol is None: endpoint = "auth/r/funding/loans/hist" else: @@ -501,8 +526,8 @@ def get_funding_loans_history( ] def get_funding_credits( - self, *, symbol: Optional[str] = None - ) -> List[FundingCredit]: + self, *, symbol: str | None = None + ) -> list[FundingCredit]: if symbol is None: endpoint = "auth/r/funding/credits" else: @@ -516,11 +541,11 @@ def get_funding_credits( def get_funding_credits_history( self, *, - symbol: Optional[str] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[FundingCredit]: + symbol: str | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[FundingCredit]: if symbol is None: endpoint = "auth/r/funding/credits/hist" else: @@ -536,12 +561,12 @@ def get_funding_credits_history( def get_funding_trades_history( self, *, - symbol: Optional[str] = None, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[FundingTrade]: + symbol: str | None = None, + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[FundingTrade]: if symbol is None: endpoint = "auth/r/funding/trades/hist" else: @@ -565,7 +590,7 @@ def transfer_between_wallets( to_wallet: str, currency: str, currency_to: str, - amount: Union[str, float, Decimal], + amount: str | float | Decimal, ) -> Notification[Transfer]: body = { "from": from_wallet, @@ -580,7 +605,11 @@ def transfer_between_wallets( ) def submit_wallet_withdrawal( - self, wallet: str, method: str, address: str, amount: Union[str, float, Decimal] + self, + wallet: str, + method: str, + address: str, + amount: str | float | Decimal, ) -> Notification[Withdrawal]: body = { "wallet": wallet, @@ -604,7 +633,7 @@ def get_deposit_address( ) def generate_deposit_invoice( - self, wallet: str, currency: str, amount: Union[str, float, Decimal] + self, wallet: str, currency: str, amount: str | float | Decimal ) -> LightningNetworkInvoice: return serializers.LightningNetworkInvoice.parse( *self._m.post( @@ -616,11 +645,11 @@ def generate_deposit_invoice( def get_movements( self, *, - currency: Optional[str] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Movement]: + currency: str | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Movement]: if currency is None: endpoint = "auth/r/movements/hist" else: diff --git a/bfxapi/rest/_interfaces/rest_public_endpoints.py b/bfxapi/rest/_interfaces/rest_public_endpoints.py index 2c2a540a..3d75da74 100644 --- a/bfxapi/rest/_interfaces/rest_public_endpoints.py +++ b/bfxapi/rest/_interfaces/rest_public_endpoints.py @@ -1,5 +1,5 @@ from decimal import Decimal -from typing import Any, Dict, List, Literal, Optional, Union, cast +from typing import Any, Literal, cast from bfxapi.rest._interface import Interface from bfxapi.types import ( @@ -34,8 +34,8 @@ def get_platform_status(self) -> PlatformStatus: return serializers.PlatformStatus.parse(*self._m.get("platform/status")) def get_tickers( - self, symbols: List[str] - ) -> Dict[str, Union[TradingPairTicker, FundingCurrencyTicker]]: + self, symbols: list[str] + ) -> dict[str, TradingPairTicker | FundingCurrencyTicker]: data = self._m.get("tickers", params={"symbols": ",".join(symbols)}) parsers = { @@ -45,7 +45,7 @@ def get_tickers( return { symbol: cast( - Union[TradingPairTicker, FundingCurrencyTicker], + TradingPairTicker | FundingCurrencyTicker, parsers[symbol[0]](*sub_data), ) for sub_data in data @@ -53,8 +53,8 @@ def get_tickers( } def get_t_tickers( - self, symbols: Union[List[str], Literal["ALL"]] - ) -> Dict[str, TradingPairTicker]: + self, symbols: list[str] | Literal["ALL"] + ) -> dict[str, TradingPairTicker]: if isinstance(symbols, str) and symbols == "ALL": return { symbol: cast(TradingPairTicker, sub_data) @@ -64,11 +64,11 @@ def get_t_tickers( data = self.get_tickers(list(symbols)) - return cast(Dict[str, TradingPairTicker], data) + return cast(dict[str, TradingPairTicker], data) def get_f_tickers( - self, symbols: Union[List[str], Literal["ALL"]] - ) -> Dict[str, FundingCurrencyTicker]: + self, symbols: list[str] | Literal["ALL"] + ) -> dict[str, FundingCurrencyTicker]: if isinstance(symbols, str) and symbols == "ALL": return { symbol: cast(FundingCurrencyTicker, sub_data) @@ -78,22 +78,26 @@ def get_f_tickers( data = self.get_tickers(list(symbols)) - return cast(Dict[str, FundingCurrencyTicker], data) + return cast(dict[str, FundingCurrencyTicker], data) def get_t_ticker(self, symbol: str) -> TradingPairTicker: - return serializers.TradingPairTicker.parse(*self._m.get(f"ticker/{symbol}")) + return serializers.TradingPairTicker.parse( + *self._m.get(f"ticker/{symbol}") + ) def get_f_ticker(self, symbol: str) -> FundingCurrencyTicker: - return serializers.FundingCurrencyTicker.parse(*self._m.get(f"ticker/{symbol}")) + return serializers.FundingCurrencyTicker.parse( + *self._m.get(f"ticker/{symbol}") + ) def get_tickers_history( self, - symbols: List[str], + symbols: list[str], *, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[TickersHistory]: + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[TickersHistory]: return [ serializers.TickersHistory.parse(*sub_data) for sub_data in self._m.get( @@ -111,38 +115,45 @@ def get_t_trades( self, pair: str, *, - limit: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - sort: Optional[int] = None, - ) -> List[TradingPairTrade]: + limit: int | None = None, + start: str | None = None, + end: str | None = None, + sort: int | None = None, + ) -> list[TradingPairTrade]: params = {"limit": limit, "start": start, "end": end, "sort": sort} data = self._m.get(f"trades/{pair}/hist", params=params) - return [serializers.TradingPairTrade.parse(*sub_data) for sub_data in data] + return [ + serializers.TradingPairTrade.parse(*sub_data) for sub_data in data + ] def get_f_trades( self, currency: str, *, - limit: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - sort: Optional[int] = None, - ) -> List[FundingCurrencyTrade]: + limit: int | None = None, + start: str | None = None, + end: str | None = None, + sort: int | None = None, + ) -> list[FundingCurrencyTrade]: params = {"limit": limit, "start": start, "end": end, "sort": sort} data = self._m.get(f"trades/{currency}/hist", params=params) - return [serializers.FundingCurrencyTrade.parse(*sub_data) for sub_data in data] + return [ + serializers.FundingCurrencyTrade.parse(*sub_data) + for sub_data in data + ] def get_t_book( self, pair: str, precision: Literal["P0", "P1", "P2", "P3", "P4"], *, - len: Optional[Literal[1, 25, 100]] = None, - ) -> List[TradingPairBook]: + len: Literal[1, 25, 100] | None = None, + ) -> list[TradingPairBook]: return [ serializers.TradingPairBook.parse(*sub_data) - for sub_data in self._m.get(f"book/{pair}/{precision}", params={"len": len}) + for sub_data in self._m.get( + f"book/{pair}/{precision}", params={"len": len} + ) ] def get_f_book( @@ -150,8 +161,8 @@ def get_f_book( currency: str, precision: Literal["P0", "P1", "P2", "P3", "P4"], *, - len: Optional[Literal[1, 25, 100]] = None, - ) -> List[FundingCurrencyBook]: + len: Literal[1, 25, 100] | None = None, + ) -> list[FundingCurrencyBook]: return [ serializers.FundingCurrencyBook.parse(*sub_data) for sub_data in self._m.get( @@ -160,30 +171,32 @@ def get_f_book( ] def get_t_raw_book( - self, pair: str, *, len: Optional[Literal[1, 25, 100]] = None - ) -> List[TradingPairRawBook]: + self, pair: str, *, len: Literal[1, 25, 100] | None = None + ) -> list[TradingPairRawBook]: return [ serializers.TradingPairRawBook.parse(*sub_data) for sub_data in self._m.get(f"book/{pair}/R0", params={"len": len}) ] def get_f_raw_book( - self, currency: str, *, len: Optional[Literal[1, 25, 100]] = None - ) -> List[FundingCurrencyRawBook]: + self, currency: str, *, len: Literal[1, 25, 100] | None = None + ) -> list[FundingCurrencyRawBook]: return [ serializers.FundingCurrencyRawBook.parse(*sub_data) - for sub_data in self._m.get(f"book/{currency}/R0", params={"len": len}) + for sub_data in self._m.get( + f"book/{currency}/R0", params={"len": len} + ) ] def get_stats_hist( self, resource: str, *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Statistic]: + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Statistic]: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get(f"stats1/{resource}/hist", params=params) return [serializers.Statistic.parse(*sub_data) for sub_data in data] @@ -192,10 +205,10 @@ def get_stats_last( self, resource: str, *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, ) -> Statistic: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get(f"stats1/{resource}/last", params=params) @@ -206,11 +219,11 @@ def get_candles_hist( symbol: str, tf: str = "1m", *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Candle]: + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Candle]: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get(f"candles/trade:{tf}:{symbol}/hist", params=params) return [serializers.Candle.parse(*sub_data) for sub_data in data] @@ -220,18 +233,18 @@ def get_candles_last( symbol: str, tf: str = "1m", *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, ) -> Candle: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get(f"candles/trade:{tf}:{symbol}/last", params=params) return serializers.Candle.parse(*data) def get_derivatives_status( - self, keys: Union[List[str], Literal["ALL"]] - ) -> Dict[str, DerivativesStatus]: + self, keys: list[str] | Literal["ALL"] + ) -> dict[str, DerivativesStatus]: if keys == "ALL": params = {"keys": "ALL"} else: @@ -249,50 +262,54 @@ def get_derivatives_status_history( self, key: str, *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[DerivativesStatus]: + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[DerivativesStatus]: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get(f"status/deriv/{key}/hist", params=params) - return [serializers.DerivativesStatus.parse(*sub_data) for sub_data in data] + return [ + serializers.DerivativesStatus.parse(*sub_data) for sub_data in data + ] def get_liquidations( self, *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Liquidation]: + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Liquidation]: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get("liquidations/hist", params=params) - return [serializers.Liquidation.parse(*sub_data[0]) for sub_data in data] + return [ + serializers.Liquidation.parse(*sub_data[0]) for sub_data in data + ] def get_seed_candles( self, symbol: str, tf: str = "1m", *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Candle]: - params = {"sort": sort, "start": start, "end": end, "limit": limit} - data = self._m.get(f"candles/trade:{tf}:{symbol}/hist", params=params) - return [serializers.Candle.parse(*sub_data) for sub_data in data] + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Candle]: + return self.get_candles_hist( + symbol, tf, sort=sort, start=start, end=end, limit=limit + ) def get_leaderboards_hist( self, resource: str, *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Leaderboard]: + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[Leaderboard]: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get(f"rankings/{resource}/hist", params=params) return [serializers.Leaderboard.parse(*sub_data) for sub_data in data] @@ -301,10 +318,10 @@ def get_leaderboards_last( self, resource: str, *, - sort: Optional[int] = None, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, + sort: int | None = None, + start: str | None = None, + end: str | None = None, + limit: int | None = None, ) -> Leaderboard: params = {"sort": sort, "start": start, "end": end, "limit": limit} data = self._m.get(f"rankings/{resource}/last", params=params) @@ -314,35 +331,41 @@ def get_funding_stats( self, symbol: str, *, - start: Optional[str] = None, - end: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[FundingStatistic]: + start: str | None = None, + end: str | None = None, + limit: int | None = None, + ) -> list[FundingStatistic]: params = {"start": start, "end": end, "limit": limit} data = self._m.get(f"funding/stats/{symbol}/hist", params=params) - return [serializers.FundingStatistic.parse(*sub_data) for sub_data in data] + return [ + serializers.FundingStatistic.parse(*sub_data) for sub_data in data + ] def get_trading_market_average_price( self, symbol: str, - amount: Union[str, float, Decimal], + amount: str | float | Decimal, *, - price_limit: Optional[Union[str, float, Decimal]] = None, + price_limit: str | float | Decimal | None = None, ) -> TradingMarketAveragePrice: return serializers.TradingMarketAveragePrice.parse( *self._m.post( "calc/trade/avg", - body={"symbol": symbol, "amount": amount, "price_limit": price_limit}, + body={ + "symbol": symbol, + "amount": amount, + "price_limit": price_limit, + }, ) ) def get_funding_market_average_price( self, symbol: str, - amount: Union[str, float, Decimal], + amount: str | float | Decimal, period: int, *, - rate_limit: Optional[Union[str, float, Decimal]] = None, + rate_limit: str | float | Decimal | None = None, ) -> FundingMarketAveragePrice: return serializers.FundingMarketAveragePrice.parse( *self._m.post( diff --git a/bfxapi/rest/exceptions.py b/bfxapi/rest/exceptions.py index 7bc3c671..e2eae4aa 100644 --- a/bfxapi/rest/exceptions.py +++ b/bfxapi/rest/exceptions.py @@ -7,3 +7,29 @@ class RequestParameterError(BfxBaseException): class GenericError(BfxBaseException): pass + + +class RateLimitError(BfxBaseException): + """Raised when Bitfinex returns HTTP 429 or error code 10010. + + Attributes: + retry_after_ms: Suggested wait time in milliseconds. + """ + + def __init__(self, message: str, retry_after_ms: int = 60_000): + super().__init__(message) + self.retry_after_ms = retry_after_ms + + +class InsufficientFundsError(BfxBaseException): + """Raised when Bitfinex returns error code 10001.""" + + pass + + +class NetworkError(BfxBaseException): + """Raised on connection errors, timeouts, DNS failures.""" + + def __init__(self, message: str, retryable: bool = True): + super().__init__(message) + self.retryable = retryable diff --git a/bfxapi/rest/retry.py b/bfxapi/rest/retry.py new file mode 100644 index 00000000..f50a7387 --- /dev/null +++ b/bfxapi/rest/retry.py @@ -0,0 +1,177 @@ +"""Retry with exponential backoff for Bitfinex REST API calls. + +Understands Bitfinex-specific error patterns and applies appropriate +backoff strategies. + +Example:: + + from bfxapi.rest.retry import retry_with_backoff + + result = await retry_with_backoff( + lambda: client.rest.get_wallets(), + max_attempts=5, + ) +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Callable +from typing import TypeVar + +from bfxapi.exceptions import InvalidCredentialError +from bfxapi.rest.exceptions import ( + GenericError, + InsufficientFundsError, + NetworkError, + RateLimitError, +) + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +def is_retryable(error: BaseException) -> bool: + """Check if an error is worth retrying.""" + if isinstance(error, RateLimitError): + return True + if isinstance(error, NetworkError): + return error.retryable + if isinstance(error, InvalidCredentialError): + return False + if isinstance(error, InsufficientFundsError): + return False + if isinstance(error, GenericError): + msg = str(error).lower() + # Nonce errors are transient + if "nonce" in msg: + return True + return False + # Network-level errors from requests + msg = str(error).lower() + return any( + kw in msg for kw in ("timeout", "connection", "reset", "refused") + ) + + +def get_backoff_delay( + error: BaseException, + attempt: int, + base_delay: float, + max_delay: float, +) -> float: + """Calculate backoff delay for a given error and attempt.""" + if isinstance(error, RateLimitError): + return min(error.retry_after_ms / 1000.0, max_delay) + + msg = str(error).lower() + if "nonce" in msg: + return 1.0 + + if isinstance(error, NetworkError): + return min(base_delay * (attempt + 1), 60.0) + + # Default: exponential backoff + delay: float = base_delay * (2**attempt) + return min(delay, max_delay) + + +def retry_with_backoff( + fn: Callable[[], T], + max_attempts: int = 5, + base_delay: float = 1.0, + max_delay: float = 300.0, +) -> T: + """Retry a synchronous function with exponential backoff. + + Args: + fn: Function to retry. + max_attempts: Maximum number of attempts. + base_delay: Base delay in seconds for backoff. + max_delay: Maximum delay cap in seconds. + + Returns: + Result of the function. + + Raises: + The last exception if all attempts fail. + """ + import time + + last_error: BaseException | None = None + + for attempt in range(max_attempts): + try: + return fn() + except Exception as e: + last_error = e + + if attempt == max_attempts - 1: + raise + + if not is_retryable(e): + raise + + delay = get_backoff_delay(e, attempt, base_delay, max_delay) + logger.warning( + "Attempt %d/%d failed (%s), retrying in %.1fs", + attempt + 1, + max_attempts, + type(e).__name__, + delay, + ) + time.sleep(delay) + + raise last_error or RuntimeError("Max retry attempts exceeded") + + +async def async_retry_with_backoff( + fn: Callable[[], T], + max_attempts: int = 5, + base_delay: float = 1.0, + max_delay: float = 300.0, +) -> T: + """Retry a synchronous function with async sleep between attempts. + + Useful when the function itself is sync (REST calls) but you want + non-blocking sleep in an async context. + + Args: + fn: Synchronous function to retry. + max_attempts: Maximum number of attempts. + base_delay: Base delay in seconds for backoff. + max_delay: Maximum delay cap in seconds. + + Returns: + Result of the function. + + Raises: + The last exception if all attempts fail. + """ + last_error: BaseException | None = None + + for attempt in range(max_attempts): + try: + return fn() + except Exception as e: + last_error = e + + if attempt == max_attempts - 1: + raise + + if not is_retryable(e): + raise + + delay = get_backoff_delay(e, attempt, base_delay, max_delay) + logger.warning( + "Attempt %d/%d failed (%s), retrying in %.1fs", + attempt + 1, + max_attempts, + type(e).__name__, + delay, + ) + await asyncio.sleep(delay) + + raise last_error or RuntimeError("Max retry attempts exceeded") diff --git a/bfxapi/types/__init__.py b/bfxapi/types/__init__.py index ce15a180..6941a5f5 100644 --- a/bfxapi/types/__init__.py +++ b/bfxapi/types/__init__.py @@ -50,4 +50,60 @@ Wallet, Withdrawal, ) -from .notification import Notification +from .notification import Notification, _Notification + +__all__ = [ + "BalanceAvailable", + "BalanceInfo", + "BaseMarginInfo", + "Candle", + "DepositAddress", + "DerivativePositionCollateral", + "DerivativePositionCollateralLimits", + "DerivativesStatus", + "FundingAutoRenew", + "FundingCredit", + "FundingCurrencyBook", + "FundingCurrencyRawBook", + "FundingCurrencyTicker", + "FundingCurrencyTrade", + "FundingInfo", + "FundingLoan", + "FundingMarketAveragePrice", + "FundingOffer", + "FundingStatistic", + "FundingTrade", + "FxRate", + "Leaderboard", + "Ledger", + "LightningNetworkInvoice", + "Liquidation", + "LoginHistory", + "Movement", + "Notification", + "Order", + "OrderTrade", + "PlatformStatus", + "Position", + "PositionAudit", + "PositionClaim", + "PositionHistory", + "PositionIncrease", + "PositionIncreaseInfo", + "PositionSnapshot", + "Statistic", + "SymbolMarginInfo", + "TickersHistory", + "Trade", + "TradingMarketAveragePrice", + "TradingPairBook", + "TradingPairRawBook", + "TradingPairTicker", + "TradingPairTrade", + "Transfer", + "UserInfo", + "Wallet", + "Withdrawal", + "_Notification", + "serializers", +] diff --git a/bfxapi/types/dataclasses.py b/bfxapi/types/dataclasses.py index 99aaa6ee..4a8fdf80 100644 --- a/bfxapi/types/dataclasses.py +++ b/bfxapi/types/dataclasses.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any from .labeler import _Type @@ -148,7 +148,7 @@ class Leaderboard(_Type): username: str ranking: int value: float - twitter_handle: Optional[str] + twitter_handle: str | None @dataclass @@ -204,14 +204,14 @@ class UserInfo(_Type): ppt_enabled: int merchant_enabled: int competition_enabled: int - two_factors_authentication_modes: List[str] + two_factors_authentication_modes: list[str] is_securities_master: int securities_enabled: int allow_disable_ctxswitch: int time_last_login: int ctxtswitch_disabled: int - comp_countries: List[str] - compl_countries_resid: List[str] + comp_countries: list[str] + compl_countries_resid: list[str] is_merchant_enterprise: int @@ -220,7 +220,7 @@ class LoginHistory(_Type): id: int time: int ip: str - extra_info: Dict[str, Any] + extra_info: dict[str, Any] @dataclass @@ -251,7 +251,7 @@ class Order(_Type): hidden: int placed_id: int routing: str - meta: Dict[str, Any] + meta: dict[str, Any] @dataclass @@ -272,7 +272,7 @@ class Position(_Type): type: int collateral: float collateral_min: float - meta: Dict[str, Any] + meta: dict[str, Any] @dataclass @@ -412,7 +412,7 @@ class Wallet(_Type): unsettled_interest: float available_balance: float last_change: str - trade_details: Dict[str, Any] + trade_details: dict[str, Any] @dataclass @@ -497,7 +497,7 @@ class PositionClaim(_Type): pos_type: int collateral: str min_collateral: str - meta: Dict[str, Any] + meta: dict[str, Any] @dataclass @@ -563,7 +563,7 @@ class PositionAudit(_Type): type: int collateral: float collateral_min: float - meta: Dict[str, Any] + meta: dict[str, Any] @dataclass diff --git a/bfxapi/types/labeler.py b/bfxapi/types/labeler.py index 9c966ed5..6c4e0502 100644 --- a/bfxapi/types/labeler.py +++ b/bfxapi/types/labeler.py @@ -1,10 +1,109 @@ -from typing import Any, Dict, Generic, Iterable, List, Tuple, Type, TypeVar, cast +from collections.abc import Callable, Iterable +from decimal import Decimal +from typing import Any, Generic, TypeVar, cast T = TypeVar("T", bound="_Type") - -def compose(*decorators): - def wrapper(function): +# Field names that represent monetary/financial values. +# When decimal_mode is enabled, these are converted from float to Decimal. +MONETARY_FIELDS: frozenset[str] = frozenset( + { + "amount", + "amount_orig", + "ask", + "ask_size", + "available_balance", + "balance", + "base_price", + "bid", + "bid_size", + "buy", + "collateral", + "collateral_min", + "current_pos", + "daily_change", + "daily_change_relative", + "deriv_price", + "exec_amount", + "exec_price", + "fee", + "fees", + "frr", + "frr_amount_available", + "funding", + "funding_amount", + "funding_amount_used", + "funding_avail", + "funding_below_threshold", + "funding_required", + "funding_required_currency", + "funding_value", + "funding_value_currency", + "gross_balance", + "high", + "insurance_fund_balance", + "last_price", + "leverage", + "low", + "margin_balance", + "margin_funding", + "margin_net", + "margin_min", + "mark_price", + "max_pos", + "min_collateral", + "max_collateral", + "next_funding_accrued", + "next_funding_step", + "open_interest", + "order_price", + "pl", + "pl_perc", + "price", + "price_avg", + "price_liq", + "price_trailing", + "price_aux_limit", + "rate", + "rate_avg", + "sell", + "spot_price", + "tradable_balance", + "tradable_balance_base_currency", + "tradable_balance_base_total", + "tradable_balance_quote_currency", + "tradable_balance_quote_total", + "unsettled_interest", + "user_pl", + "user_swaps", + "value", + "volume", + "withdrawal_fee", + "yield_lend", + "yield_loan", + "aum", + "aum_net", + "current_funding", + "base_currency_balance", + "clamp_min", + "clamp_max", + } +) + +# Module-level flag — set by Client when decimal_mode=True +_decimal_mode: bool = False + + +def set_decimal_mode(enabled: bool) -> None: + """Enable or disable Decimal conversion for monetary fields.""" + global _decimal_mode + _decimal_mode = enabled + + +def compose( + *decorators: Callable[[type[Any]], type[Any]], +) -> Callable[[type[Any]], type[Any]]: + def wrapper(function: type[Any]) -> type[Any]: for decorator in reversed(decorators): function = decorator(function) return function @@ -12,8 +111,8 @@ def wrapper(function): return wrapper -def partial(cls): - def __init__(self, **kwargs): +def partial(cls: type[Any]) -> type[Any]: + def __init__(self: Any, **kwargs: Any) -> None: for annotation in self.__annotations__.keys(): if annotation not in kwargs: self.__setattr__(annotation, None) @@ -41,11 +140,21 @@ class _Type: class _Serializer(Generic[T]): def __init__( - self, name: str, klass: Type[_Type], labels: List[str], *, flat: bool = False - ): - self.name, self.klass, self.__labels, self.__flat = name, klass, labels, flat - - def _serialize(self, *args: Any) -> Iterable[Tuple[str, Any]]: + self, + name: str, + klass: type[_Type], + labels: list[str], + *, + flat: bool = False, + ) -> None: + self.name, self.klass, self.__labels, self.__flat = ( + name, + klass, + labels, + flat, + ) + + def _serialize(self, *args: Any) -> Iterable[tuple[str, Any]]: if self.__flat: args = tuple(_Serializer.__flatten(list(args))) @@ -57,16 +166,23 @@ def _serialize(self, *args: Any) -> Iterable[Tuple[str, Any]]: for index, label in enumerate(self.__labels): if label != "_PLACEHOLDER": - yield label, args[index] + value = args[index] + if ( + _decimal_mode + and label in MONETARY_FIELDS + and isinstance(value, (int, float)) + ): + value = Decimal(str(value)) + yield label, value def parse(self, *values: Any) -> T: return cast(T, self.klass(**dict(self._serialize(*values)))) - def get_labels(self) -> List[str]: + def get_labels(self) -> list[str]: return [label for label in self.__labels if label != "_PLACEHOLDER"] @classmethod - def __flatten(cls, array: List[Any]) -> List[Any]: + def __flatten(cls, array: list[Any]) -> list[Any]: if len(array) == 0: return array @@ -76,16 +192,16 @@ def __flatten(cls, array: List[Any]) -> List[Any]: return array[:1] + cls.__flatten(array[1:]) -class _RecursiveSerializer(_Serializer, Generic[T]): +class _RecursiveSerializer(_Serializer[T], Generic[T]): def __init__( self, name: str, - klass: Type[_Type], - labels: List[str], + klass: type[_Type], + labels: list[str], *, - serializers: Dict[str, _Serializer[Any]], + serializers: dict[str, _Serializer[Any]], flat: bool = False, - ): + ) -> None: super().__init__(name, klass, labels, flat=flat) self.serializers = serializers @@ -95,23 +211,25 @@ def parse(self, *values: Any) -> T: for key in serialization: if key in self.serializers.keys(): - serialization[key] = self.serializers[key].parse(*serialization[key]) + serialization[key] = self.serializers[key].parse( + *serialization[key] + ) return cast(T, self.klass(**serialization)) def generate_labeler_serializer( - name: str, klass: Type[T], labels: List[str], *, flat: bool = False + name: str, klass: type[T], labels: list[str], *, flat: bool = False ) -> _Serializer[T]: return _Serializer[T](name, klass, labels, flat=flat) def generate_recursive_serializer( name: str, - klass: Type[T], - labels: List[str], + klass: type[T], + labels: list[str], *, - serializers: Dict[str, _Serializer[Any]], + serializers: dict[str, _Serializer[Any]], flat: bool = False, ) -> _RecursiveSerializer[T]: return _RecursiveSerializer[T]( diff --git a/bfxapi/types/notification.py b/bfxapi/types/notification.py index 2b7d3753..6562e372 100644 --- a/bfxapi/types/notification.py +++ b/bfxapi/types/notification.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Generic, List, Optional, TypeVar, cast +from typing import Any, Generic, TypeVar, cast from .labeler import _Serializer, _Type @@ -10,14 +10,14 @@ class Notification(_Type, Generic[T]): mts: int type: str - message_id: Optional[int] + message_id: int | None data: T - code: Optional[int] + code: int | None status: str text: str -class _Notification(_Serializer, Generic[T]): +class _Notification(_Serializer[Notification[T]], Generic[T]): __LABELS = [ "mts", "type", @@ -30,8 +30,10 @@ class _Notification(_Serializer, Generic[T]): ] def __init__( - self, serializer: Optional[_Serializer] = None, is_iterable: bool = False - ): + self, + serializer: _Serializer[Any] | None = None, + is_iterable: bool = False, + ) -> None: super().__init__("Notification", Notification, _Notification.__LABELS) self.serializer, self.is_iterable = serializer, is_iterable @@ -42,7 +44,7 @@ def parse(self, *values: Any) -> Notification[T]: ) if isinstance(self.serializer, _Serializer): - data = cast(List[Any], notification.data) + data = cast(list[Any], notification.data) if not self.is_iterable: if len(data) == 1 and isinstance(data[0], list): diff --git a/bfxapi/types/serializers.py b/bfxapi/types/serializers.py index 3881fa5b..091a1ed1 100644 --- a/bfxapi/types/serializers.py +++ b/bfxapi/types/serializers.py @@ -353,7 +353,9 @@ ) BalanceAvailable = generate_labeler_serializer( - name="BalanceAvailable", klass=dataclasses.BalanceAvailable, labels=["amount"] + name="BalanceAvailable", + klass=dataclasses.BalanceAvailable, + labels=["amount"], ) Order = generate_labeler_serializer( @@ -444,7 +446,15 @@ FundingTrade = generate_labeler_serializer( name="FundingTrade", klass=dataclasses.FundingTrade, - labels=["id", "currency", "mts_create", "offer_id", "amount", "rate", "period"], + labels=[ + "id", + "currency", + "mts_create", + "offer_id", + "amount", + "rate", + "period", + ], ) OrderTrade = generate_labeler_serializer( @@ -648,7 +658,13 @@ LightningNetworkInvoice = generate_labeler_serializer( name="LightningNetworkInvoice", klass=dataclasses.LightningNetworkInvoice, - labels=["invoice_hash", "invoice", "_PLACEHOLDER", "_PLACEHOLDER", "amount"], + labels=[ + "invoice_hash", + "invoice", + "_PLACEHOLDER", + "_PLACEHOLDER", + "amount", + ], ) Movement = generate_labeler_serializer( diff --git a/bfxapi/websocket/__init__.py b/bfxapi/websocket/__init__.py index ced8300f..2616aa97 100644 --- a/bfxapi/websocket/__init__.py +++ b/bfxapi/websocket/__init__.py @@ -1 +1,3 @@ from ._client import BfxWebSocketClient + +__all__ = ["BfxWebSocketClient"] diff --git a/bfxapi/websocket/_client/__init__.py b/bfxapi/websocket/_client/__init__.py index ebbd6d2c..07b60461 100644 --- a/bfxapi/websocket/_client/__init__.py +++ b/bfxapi/websocket/_client/__init__.py @@ -1 +1,3 @@ from .bfx_websocket_client import BfxWebSocketClient + +__all__ = ["BfxWebSocketClient"] diff --git a/bfxapi/websocket/_client/bfx_websocket_bucket.py b/bfxapi/websocket/_client/bfx_websocket_bucket.py index fa6262fb..a6a0c78f 100644 --- a/bfxapi/websocket/_client/bfx_websocket_bucket.py +++ b/bfxapi/websocket/_client/bfx_websocket_bucket.py @@ -1,9 +1,9 @@ import asyncio import json import uuid -from typing import Any, Dict, List, Optional, cast +from typing import Any, cast -import websockets.client +import websockets.asyncio.client from pyee import EventEmitter from bfxapi._utils.json_decoder import JSONDecoder @@ -14,7 +14,7 @@ _CHECKSUM_FLAG_VALUE = 131_072 -def _strip(message: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: +def _strip(message: dict[str, Any], keys: list[str]) -> dict[str, Any]: return {key: value for key, value in message.items() if key not in keys} @@ -25,12 +25,14 @@ def __init__(self, host: str, event_emitter: EventEmitter) -> None: super().__init__(host) self.__event_emitter = event_emitter - self.__pendings: List[Dict[str, Any]] = [] - self.__subscriptions: Dict[int, Subscription] = {} + self.__pendings: list[dict[str, Any]] = [] + self.__subscriptions: dict[int, Subscription] = {} self.__condition = asyncio.locks.Condition() - self.__handler = PublicChannelsHandler(event_emitter=self.__event_emitter) + self.__handler = PublicChannelsHandler( + event_emitter=self.__event_emitter + ) @property def count(self) -> int: @@ -41,13 +43,14 @@ def is_full(self) -> bool: return self.count == BfxWebSocketBucket.__MAXIMUM_SUBSCRIPTIONS_AMOUNT @property - def ids(self) -> List[str]: + def ids(self) -> list[str]: return [pending["subId"] for pending in self.__pendings] + [ - subscription["sub_id"] for subscription in self.__subscriptions.values() + subscription["sub_id"] + for subscription in self.__subscriptions.values() ] async def start(self) -> None: - async with websockets.client.connect(self._host) as websocket: + async with websockets.asyncio.client.connect(self._host) as websocket: self._websocket = websocket await self.__recover_state() @@ -70,11 +73,12 @@ async def start(self) -> None: ): self.__handler.handle(subscription, message[1:]) - def __on_subscribed(self, message: Dict[str, Any]) -> None: + def __on_subscribed(self, message: dict[str, Any]) -> None: chan_id = cast(int, message["chan_id"]) subscription = cast( - Subscription, _strip(message, keys=["chan_id", "event", "pair", "currency"]) + Subscription, + _strip(message, keys=["chan_id", "event", "pair", "currency"]), ) self.__pendings = [ @@ -98,14 +102,16 @@ async def __recover_state(self) -> None: await self.__set_config([_CHECKSUM_FLAG_VALUE]) - async def __set_config(self, flags: List[int]) -> None: - await self._websocket.send(json.dumps({"event": "conf", "flags": sum(flags)})) + async def __set_config(self, flags: list[int]) -> None: + await self._websocket.send( + json.dumps({"event": "conf", "flags": sum(flags)}) + ) - @Connection._require_websocket_connection + @Connection._require_websocket_connection # type: ignore[arg-type] async def subscribe( - self, channel: str, sub_id: Optional[str] = None, **kwargs: Any + self, channel: str, sub_id: str | None = None, **kwargs: Any ) -> None: - subscription: Dict[str, Any] = { + subscription: dict[str, Any] = { **kwargs, "event": "subscribe", "channel": channel, diff --git a/bfxapi/websocket/_client/bfx_websocket_client.py b/bfxapi/websocket/_client/bfx_websocket_client.py index ffae0adf..8cbd7153 100644 --- a/bfxapi/websocket/_client/bfx_websocket_client.py +++ b/bfxapi/websocket/_client/bfx_websocket_client.py @@ -3,14 +3,16 @@ import random import traceback from asyncio import Task +from collections.abc import Callable from datetime import datetime from logging import Logger from socket import gaierror -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, TypedDict -import websockets -import websockets.client -from websockets.exceptions import ConnectionClosedError, InvalidStatusCode +import websockets.asyncio.client +import websockets.frames +from pyee import Handler +from websockets.exceptions import ConnectionClosedError, InvalidStatus from bfxapi._utils.json_encoder import JSONEncoder from bfxapi.exceptions import InvalidCredentialError @@ -28,13 +30,18 @@ from .bfx_websocket_bucket import BfxWebSocketBucket from .bfx_websocket_inputs import BfxWebSocketInputs -_Credentials = TypedDict( - "_Credentials", {"api_key": str, "api_secret": str, "filters": Optional[List[str]]} -) -_Reconnection = TypedDict( - "_Reconnection", {"attempts": int, "reason": str, "timestamp": datetime} -) +class _Credentials(TypedDict): + api_key: str + api_secret: str + filters: list[str] | None + + +class _Reconnection(TypedDict): + attempts: int + reason: str + timestamp: datetime + _DEFAULT_LOGGER = Logger("bfxapi.websocket._client", level=0) @@ -72,17 +79,21 @@ def __init__( self, host: str, *, - credentials: Optional[_Credentials] = None, - timeout: Optional[int] = 60 * 15, + credentials: _Credentials | None = None, + timeout: int | None = 60 * 15, logger: Logger = _DEFAULT_LOGGER, ) -> None: super().__init__(host) - self.__credentials, self.__timeout, self.__logger = credentials, timeout, logger + self.__credentials, self.__timeout, self.__logger = ( + credentials, + timeout, + logger, + ) - self.__buckets: Dict[BfxWebSocketBucket, Optional[Task]] = {} + self.__buckets: dict[BfxWebSocketBucket, Task[None] | None] = {} - self.__reconnection: Optional[_Reconnection] = None + self.__reconnection: _Reconnection | None = None self.__event_emitter = BfxEventEmitter(loop=None) @@ -100,7 +111,7 @@ def error(exception: Exception) -> None: type(exception), exception, exception.__traceback__ ) - self.__logger.critical(f"{header}\n" + str().join(stack_trace)[:-1]) + self.__logger.critical(f"{header}\n" + "".join(stack_trace)[:-1]) @property def inputs(self) -> BfxWebSocketInputs: @@ -112,9 +123,9 @@ def run(self) -> None: async def start(self) -> None: _delay = _Delay(backoff_factor=1.618) - _sleep: Optional[Task] = None + _sleep: Task[None] | None = None - def _on_timeout(): + def _on_timeout() -> None: if not self.open: if _sleep: _sleep.cancel() @@ -133,14 +144,22 @@ def _on_timeout(): try: await self.__connect() - except (ConnectionClosedError, InvalidStatusCode, gaierror) as error: + except ( + ConnectionClosedError, + InvalidStatus, + gaierror, + ) as error: - async def _cancel(task: Task) -> None: + async def _cancel(task: Task[None]) -> None: task.cancel() try: await task - except (ConnectionClosedError, InvalidStatusCode, gaierror) as _e: + except ( + ConnectionClosedError, + InvalidStatus, + gaierror, + ) as _e: nonlocal error if type(error) is not type(_e) or error.args != _e.args: @@ -154,25 +173,34 @@ async def _cancel(task: Task) -> None: await _cancel(task) - if isinstance(error, ConnectionClosedError) and error.code in ( - 1006, - 1012, + if ( + isinstance(error, ConnectionClosedError) + and error.rcvd + and error.rcvd.code + in ( + 1006, + 1012, + ) ): - if error.code == 1006: - self.__logger.error("Connection lost: trying to reconnect...") + if error.rcvd.code == 1006: + self.__logger.error( + "Connection lost: trying to reconnect..." + ) - if error.code == 1012: + if error.rcvd.code == 1012: self.__logger.warning( "WSS server is restarting: all " "clients need to reconnect (server sent 20051)." ) if self.__timeout: - asyncio.get_event_loop().call_later(self.__timeout, _on_timeout) + asyncio.get_event_loop().call_later( + self.__timeout, _on_timeout + ) self.__reconnection = { "attempts": 1, - "reason": error.reason, + "reason": error.rcvd.reason, "timestamp": datetime.now(), } @@ -180,7 +208,10 @@ async def _cancel(task: Task) -> None: _delay.reset() elif ( - (isinstance(error, InvalidStatusCode) and error.status_code == 408) + ( + isinstance(error, InvalidStatus) + and error.response.status_code == 408 + ) or isinstance(error, gaierror) ) and self.__reconnection: self.__logger.warning( @@ -199,16 +230,23 @@ async def _cancel(task: Task) -> None: raise error if not self.__reconnection: + close_code = None + close_reason = "" + if self._websocket.close_code is not None: + close_code = self._websocket.close_code + if self._websocket.close_reason is not None: + close_reason = self._websocket.close_reason + self.__event_emitter.emit( "disconnected", - self._websocket.close_code, - self._websocket.close_reason, + close_code, + close_reason, ) break async def __connect(self) -> None: - async with websockets.client.connect(self._host) as websocket: + async with websockets.asyncio.client.connect(self._host) as websocket: if self.__reconnection: self.__logger.warning( "Reconnection attempt successful (no." @@ -224,7 +262,9 @@ async def __connect(self) -> None: self.__buckets[bucket] = asyncio.create_task(bucket.start()) if len(self.__buckets) == 0 or ( - await asyncio.gather(*[bucket.wait() for bucket in self.__buckets]) + await asyncio.gather( + *[bucket.wait() for bucket in self.__buckets] + ) ): self.__event_emitter.emit("open") @@ -247,9 +287,12 @@ async def __connect(self) -> None: f"to resolve this error (client version: 2, server " f"version: {message['version']})." ) - elif message["event"] == "info" and message["code"] == 20051: + elif ( + message["event"] == "info" and message["code"] == 20051 + ): rcvd = websockets.frames.Close( - 1012, "Stop/Restart WebSocket Server (please reconnect)." + 1012, + "Stop/Restart WebSocket Server (please reconnect).", ) raise ConnectionClosedError(rcvd=rcvd, sent=None) @@ -279,9 +322,9 @@ async def __new_bucket(self) -> BfxWebSocketBucket: return bucket - @Connection._require_websocket_connection + @Connection._require_websocket_connection # type: ignore[arg-type] async def subscribe( - self, channel: str, sub_id: Optional[str] = None, **kwargs: Any + self, channel: str, sub_id: str | None = None, **kwargs: Any ) -> None: if channel not in ["ticker", "trades", "book", "candles", "status"]: raise UnknownChannelError( @@ -330,22 +373,33 @@ async def close(self, code: int = 1000, reason: str = "") -> None: for bucket in self.__buckets: await bucket.close(code=code, reason=reason) - if self._websocket.open: + if self.open: await self._websocket.close(code=code, reason=reason) - @Connection._require_websocket_authentication + @Connection._require_websocket_authentication # type: ignore[arg-type] async def notify( - self, info: Any, message_id: Optional[int] = None, **kwargs: Any + self, info: Any, message_id: int | None = None, **kwargs: Any ) -> None: await self._websocket.send( json.dumps( - [0, "n", message_id, {"type": "ucm-test", "info": info, **kwargs}] + [ + 0, + "n", + message_id, + {"type": "ucm-test", "info": info, **kwargs}, + ] ) ) @Connection._require_websocket_authentication async def __handle_websocket_input(self, event: str, data: Any) -> None: - await self._websocket.send(json.dumps([0, event, None, data], cls=JSONEncoder)) + await self._websocket.send( + json.dumps([0, event, None, data], cls=JSONEncoder) + ) - def on(self, event, callback=None): + def on( + self, event: str, callback: Handler | None = None + ) -> Handler | Callable[[Handler], Handler]: + if callback is None: + return self.__event_emitter.on(event) return self.__event_emitter.on(event, callback) diff --git a/bfxapi/websocket/_client/bfx_websocket_inputs.py b/bfxapi/websocket/_client/bfx_websocket_inputs.py index 48a39d10..3c6029bf 100644 --- a/bfxapi/websocket/_client/bfx_websocket_inputs.py +++ b/bfxapi/websocket/_client/bfx_websocket_inputs.py @@ -1,5 +1,6 @@ +from collections.abc import Awaitable, Callable from decimal import Decimal -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union +from typing import Any _Handler = Callable[[str, Any], Awaitable[None]] @@ -12,18 +13,18 @@ async def submit_order( self, type: str, symbol: str, - amount: Union[str, float, Decimal], - price: Union[str, float, Decimal], + amount: str | float | Decimal, + price: str | float | Decimal, *, - lev: Optional[int] = None, - price_trailing: Optional[Union[str, float, Decimal]] = None, - price_aux_limit: Optional[Union[str, float, Decimal]] = None, - price_oco_stop: Optional[Union[str, float, Decimal]] = None, - gid: Optional[int] = None, - cid: Optional[int] = None, - flags: Optional[int] = None, - tif: Optional[str] = None, - meta: Optional[Dict[str, Any]] = None, + lev: int | None = None, + price_trailing: str | float | Decimal | None = None, + price_aux_limit: str | float | Decimal | None = None, + price_oco_stop: str | float | Decimal | None = None, + gid: int | None = None, + cid: int | None = None, + flags: int | None = None, + tif: str | None = None, + meta: dict[str, Any] | None = None, ) -> None: await self.__handle_websocket_input( "on", @@ -48,17 +49,17 @@ async def update_order( self, id: int, *, - amount: Optional[Union[str, float, Decimal]] = None, - price: Optional[Union[str, float, Decimal]] = None, - cid: Optional[int] = None, - cid_date: Optional[str] = None, - gid: Optional[int] = None, - flags: Optional[int] = None, - lev: Optional[int] = None, - delta: Optional[Union[str, float, Decimal]] = None, - price_aux_limit: Optional[Union[str, float, Decimal]] = None, - price_trailing: Optional[Union[str, float, Decimal]] = None, - tif: Optional[str] = None, + amount: str | float | Decimal | None = None, + price: str | float | Decimal | None = None, + cid: int | None = None, + cid_date: str | None = None, + gid: int | None = None, + flags: int | None = None, + lev: int | None = None, + delta: str | float | Decimal | None = None, + price_aux_limit: str | float | Decimal | None = None, + price_trailing: str | float | Decimal | None = None, + tif: str | None = None, ) -> None: await self.__handle_websocket_input( "ou", @@ -81,9 +82,9 @@ async def update_order( async def cancel_order( self, *, - id: Optional[int] = None, - cid: Optional[int] = None, - cid_date: Optional[str] = None, + id: int | None = None, + cid: int | None = None, + cid_date: str | None = None, ) -> None: await self.__handle_websocket_input( "oc", {"id": id, "cid": cid, "cid_date": cid_date} @@ -92,10 +93,10 @@ async def cancel_order( async def cancel_order_multi( self, *, - id: Optional[List[int]] = None, - cid: Optional[List[Tuple[int, str]]] = None, - gid: Optional[List[int]] = None, - all: Optional[bool] = None, + id: list[int] | None = None, + cid: list[tuple[int, str]] | None = None, + gid: list[int] | None = None, + all: bool | None = None, ) -> None: await self.__handle_websocket_input( "oc_multi", {"id": id, "cid": cid, "gid": gid, "all": all} @@ -105,11 +106,11 @@ async def submit_funding_offer( self, type: str, symbol: str, - amount: Union[str, float, Decimal], - rate: Union[str, float, Decimal], + amount: str | float | Decimal, + rate: str | float | Decimal, period: int, *, - flags: Optional[int] = None, + flags: int | None = None, ) -> None: await self.__handle_websocket_input( "fon", @@ -127,4 +128,6 @@ async def cancel_funding_offer(self, id: int) -> None: await self.__handle_websocket_input("foc", {"id": id}) async def calc(self, *args: str) -> None: - await self.__handle_websocket_input("calc", list(map(lambda arg: [arg], args))) + await self.__handle_websocket_input( + "calc", list(map(lambda arg: [arg], args)) + ) diff --git a/bfxapi/websocket/_connection.py b/bfxapi/websocket/_connection.py index 10579d94..f1e10af2 100644 --- a/bfxapi/websocket/_connection.py +++ b/bfxapi/websocket/_connection.py @@ -1,15 +1,23 @@ import hashlib import hmac import json +import time from abc import ABC, abstractmethod -from datetime import datetime -from functools import wraps -from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, cast - -from typing_extensions import Concatenate, ParamSpec -from websockets.client import WebSocketClientProtocol - -from bfxapi.websocket.exceptions import ActionRequiresAuthentication, ConnectionNotOpen +from collections.abc import Awaitable, Callable +from typing import ( + Any, + Concatenate, + ParamSpec, + TypeVar, + cast, +) + +from websockets.asyncio.client import ClientConnection + +from bfxapi.websocket.exceptions import ( + ActionRequiresAuthentication, + ConnectionNotOpen, +) _S = TypeVar("_S", bound="Connection") @@ -26,22 +34,24 @@ def __init__(self, host: str) -> None: self._authentication: bool = False - self.__protocol: Optional[WebSocketClientProtocol] = None + self.__protocol: ClientConnection | None = None @property def open(self) -> bool: - return self.__protocol is not None and self.__protocol.open + return ( + self.__protocol is not None and self.__protocol.state.name == "OPEN" + ) @property def authentication(self) -> bool: return self._authentication @property - def _websocket(self) -> WebSocketClientProtocol: - return cast(WebSocketClientProtocol, self.__protocol) + def _websocket(self) -> ClientConnection: + return cast(ClientConnection, self.__protocol) @_websocket.setter - def _websocket(self, protocol: WebSocketClientProtocol) -> None: + def _websocket(self, protocol: ClientConnection) -> None: self.__protocol = protocol @abstractmethod @@ -51,8 +61,9 @@ async def start(self) -> None: ... def _require_websocket_connection( function: Callable[Concatenate[_S, _P], Awaitable[_R]], ) -> Callable[Concatenate[_S, _P], Awaitable[_R]]: - @wraps(function) - async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R: + async def wrapper( + self: _S, /, *args: _P.args, **kwargs: _P.kwargs + ) -> _R: if self.open: return await function(self, *args, **kwargs) @@ -64,8 +75,9 @@ async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R: def _require_websocket_authentication( function: Callable[Concatenate[_S, _P], Awaitable[_R]], ) -> Callable[Concatenate[_S, _P], Awaitable[_R]]: - @wraps(function) - async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R: + async def wrapper( + self: _S, /, *args: _P.args, **kwargs: _P.kwargs + ) -> _R: if not self.authentication: raise ActionRequiresAuthentication( "To perform this action you need to " @@ -80,15 +92,15 @@ async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R: @staticmethod def _get_authentication_message( - api_key: str, api_secret: str, filters: Optional[List[str]] = None + api_key: str, api_secret: str, filters: list[str] | None = None ) -> str: - message: Dict[str, Any] = { + message: dict[str, Any] = { "event": "auth", "filter": filters, "apiKey": api_key, } - message["authNonce"] = round(datetime.now().timestamp() * 1_000_000) + message["authNonce"] = time.time_ns() // 1_000 message["authPayload"] = f"AUTH{message['authNonce']}" diff --git a/bfxapi/websocket/_event_emitter/__init__.py b/bfxapi/websocket/_event_emitter/__init__.py index 66f58aee..8425750e 100644 --- a/bfxapi/websocket/_event_emitter/__init__.py +++ b/bfxapi/websocket/_event_emitter/__init__.py @@ -1 +1,3 @@ from .bfx_event_emitter import BfxEventEmitter + +__all__ = ["BfxEventEmitter"] diff --git a/bfxapi/websocket/_event_emitter/bfx_event_emitter.py b/bfxapi/websocket/_event_emitter/bfx_event_emitter.py index 21bbfd63..720e318c 100644 --- a/bfxapi/websocket/_event_emitter/bfx_event_emitter.py +++ b/bfxapi/websocket/_event_emitter/bfx_event_emitter.py @@ -1,13 +1,13 @@ from asyncio import AbstractEventLoop from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from collections.abc import Callable +from typing import Any, overload +from pyee import Handler from pyee.asyncio import AsyncIOEventEmitter from bfxapi.websocket.exceptions import UnknownEventError -_Handler = TypeVar("_Handler", bound=Callable[..., None]) - _ONCE_PER_CONNECTION = [ "open", "authenticated", @@ -80,12 +80,12 @@ class BfxEventEmitter(AsyncIOEventEmitter): _EVENTS = _ONCE_PER_CONNECTION + _ONCE_PER_SUBSCRIPTION + _COMMON - def __init__(self, loop: Optional[AbstractEventLoop] = None) -> None: + def __init__(self, loop: AbstractEventLoop | None = None) -> None: super().__init__(loop) - self._connection: List[str] = [] + self._connection: list[str] = [] - self._subscriptions: Dict[str, List[str]] = defaultdict(lambda: []) + self._subscriptions: dict[str, list[str]] = defaultdict(lambda: []) def emit(self, event: str, *args: Any, **kwargs: Any) -> bool: if event in _ONCE_PER_CONNECTION: @@ -104,15 +104,24 @@ def emit(self, event: str, *args: Any, **kwargs: Any) -> bool: return super().emit(event, *args, **kwargs) + @overload + def on(self, event: str) -> Callable[[Handler], Handler]: ... + + @overload + def on(self, event: str, f: Handler) -> Handler: ... + def on( - self, event: str, f: Optional[_Handler] = None - ) -> Union[_Handler, Callable[[_Handler], _Handler]]: + self, event: str, f: Handler | None = None + ) -> Handler | Callable[[Handler], Handler]: if event not in BfxEventEmitter._EVENTS: raise UnknownEventError( f"Can't register to unknown event: <{event}> (to get a full " "list of available events see https://docs.bitfinex.com/)." ) + if f is None: + return super().on(event) + return super().on(event, f) def _has_listeners(self, event: str) -> bool: diff --git a/bfxapi/websocket/_handlers/__init__.py b/bfxapi/websocket/_handlers/__init__.py index 3fd99dbe..699d7d9b 100644 --- a/bfxapi/websocket/_handlers/__init__.py +++ b/bfxapi/websocket/_handlers/__init__.py @@ -1,2 +1,4 @@ from .auth_events_handler import AuthEventsHandler from .public_channels_handler import PublicChannelsHandler + +__all__ = ["AuthEventsHandler", "PublicChannelsHandler"] diff --git a/bfxapi/websocket/_handlers/auth_events_handler.py b/bfxapi/websocket/_handlers/auth_events_handler.py index 486e9f7a..8bd3bc60 100644 --- a/bfxapi/websocket/_handlers/auth_events_handler.py +++ b/bfxapi/websocket/_handlers/auth_events_handler.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, Tuple +from typing import Any from pyee.base import EventEmitter from bfxapi.types import serializers from bfxapi.types.dataclasses import FundingOffer, Order -from bfxapi.types.serializers import _Notification +from bfxapi.types.labeler import _Serializer +from bfxapi.types.notification import _Notification class AuthEventsHandler: @@ -37,7 +38,8 @@ class AuthEventsHandler: "bu": "balance_update", } - __SERIALIZERS: Dict[Tuple[str, ...], serializers._Serializer] = { + # Flattened for O(1) lookup per message (was O(n) iterating grouped tuples) + _GROUPED: dict[tuple[str, ...], _Serializer[Any]] = { ("os", "on", "ou", "oc"): serializers.Order, ("ps", "pn", "pu", "pc"): serializers.Position, ("te", "tu"): serializers.Trade, @@ -49,6 +51,12 @@ class AuthEventsHandler: ("bu",): serializers.BalanceInfo, } + __SERIALIZERS: dict[str, _Serializer[Any]] = { + abbr: ser + for abbrs, ser in _GROUPED.items() + for abbr in abbrs + } + def __init__(self, event_emitter: EventEmitter) -> None: self.__event_emitter = event_emitter @@ -58,37 +66,48 @@ def handle(self, abbrevation: str, stream: Any) -> None: elif abbrevation == "miu": if stream[0] == "base": self.__event_emitter.emit( - "base_margin_info", serializers.BaseMarginInfo.parse(*stream) + "base_margin_info", + serializers.BaseMarginInfo.parse(*stream), ) elif stream[0] == "sym": self.__event_emitter.emit( - "symbol_margin_info", serializers.SymbolMarginInfo.parse(*stream) + "symbol_margin_info", + serializers.SymbolMarginInfo.parse(*stream), ) else: - for abbrevations, serializer in AuthEventsHandler.__SERIALIZERS.items(): - if abbrevation in abbrevations: - event = AuthEventsHandler.__ABBREVIATIONS[abbrevation] + serializer = AuthEventsHandler.__SERIALIZERS.get(abbrevation) + if serializer is not None: + event = AuthEventsHandler.__ABBREVIATIONS[abbrevation] - if all(isinstance(sub_stream, list) for sub_stream in stream): - data = [serializer.parse(*sub_stream) for sub_stream in stream] - else: - data = serializer.parse(*stream) + if all( + isinstance(sub_stream, list) for sub_stream in stream + ): + data: Any = [ + serializer.parse(*sub_stream) + for sub_stream in stream + ] + else: + data = serializer.parse(*stream) - self.__event_emitter.emit(event, data) + self.__event_emitter.emit(event, data) def __notification(self, stream: Any) -> None: event: str = "notification" - serializer: _Notification = _Notification[None](serializer=None) + serializer: _Notification[Any] = _Notification[None](serializer=None) if stream[1] in ("on-req", "ou-req", "oc-req"): - event, serializer = f"{stream[1]}-notification", _Notification[Order]( - serializer=serializers.Order + event, serializer = ( + f"{stream[1]}-notification", + _Notification[Order](serializer=serializers.Order), ) if stream[1] in ("fon-req", "foc-req"): - event, serializer = f"{stream[1]}-notification", _Notification[ - FundingOffer - ](serializer=serializers.FundingOffer) + event, serializer = ( + f"{stream[1]}-notification", + _Notification[FundingOffer]( + serializer=serializers.FundingOffer + ), + ) self.__event_emitter.emit(event, serializer.parse(*stream)) diff --git a/bfxapi/websocket/_handlers/public_channels_handler.py b/bfxapi/websocket/_handlers/public_channels_handler.py index 98da88c3..639de8ff 100644 --- a/bfxapi/websocket/_handlers/public_channels_handler.py +++ b/bfxapi/websocket/_handlers/public_channels_handler.py @@ -1,4 +1,4 @@ -from typing import Any, List, cast +from typing import Any from pyee.base import EventEmitter @@ -14,19 +14,24 @@ _CHECKSUM = "cs" +_TRADE_EVENTS: dict[str, str] = { + "te": "t_trade_execution", + "tu": "t_trade_execution_update", + "fte": "f_trade_execution", + "ftu": "f_trade_execution_update", +} + class PublicChannelsHandler: def __init__(self, event_emitter: EventEmitter) -> None: self.__event_emitter = event_emitter - def handle(self, subscription: Subscription, stream: List[Any]) -> None: + def handle(self, subscription: Subscription, stream: list[Any]) -> None: if subscription["channel"] == "ticker": - self.__ticker_channel_handler(cast(Ticker, subscription), stream) + self.__ticker_channel_handler(subscription, stream) elif subscription["channel"] == "trades": - self.__trades_channel_handler(cast(Trades, subscription), stream) + self.__trades_channel_handler(subscription, stream) elif subscription["channel"] == "book": - subscription = cast(Book, subscription) - if stream[0] == _CHECKSUM: self.__checksum_handler(subscription, stream[1]) else: @@ -35,50 +40,51 @@ def handle(self, subscription: Subscription, stream: List[Any]) -> None: else: self.__raw_book_channel_handler(subscription, stream) elif subscription["channel"] == "candles": - self.__candles_channel_handler(cast(Candles, subscription), stream) + self.__candles_channel_handler(subscription, stream) elif subscription["channel"] == "status": - self.__status_channel_handler(cast(Status, subscription), stream) + self.__status_channel_handler(subscription, stream) - def __ticker_channel_handler(self, subscription: Ticker, stream: List[Any]): + def __ticker_channel_handler( + self, subscription: Ticker, stream: list[Any] + ) -> None: if subscription["symbol"].startswith("t"): - return self.__event_emitter.emit( + self.__event_emitter.emit( "t_ticker_update", subscription, serializers.TradingPairTicker.parse(*stream[0]), ) + return if subscription["symbol"].startswith("f"): - return self.__event_emitter.emit( + self.__event_emitter.emit( "f_ticker_update", subscription, serializers.FundingCurrencyTicker.parse(*stream[0]), ) + return - def __trades_channel_handler(self, subscription: Trades, stream: List[Any]): - if (event := stream[0]) and event in ["te", "tu", "fte", "ftu"]: - events = { - "te": "t_trade_execution", - "tu": "t_trade_execution_update", - "fte": "f_trade_execution", - "ftu": "f_trade_execution_update", - } - + def __trades_channel_handler( + self, subscription: Trades, stream: list[Any] + ) -> None: + if isinstance(stream[0], str) and (event := stream[0]) in _TRADE_EVENTS: if subscription["symbol"].startswith("t"): - return self.__event_emitter.emit( - events[event], + self.__event_emitter.emit( + _TRADE_EVENTS[event], subscription, serializers.TradingPairTrade.parse(*stream[1]), ) + return if subscription["symbol"].startswith("f"): - return self.__event_emitter.emit( - events[event], + self.__event_emitter.emit( + _TRADE_EVENTS[event], subscription, serializers.FundingCurrencyTrade.parse(*stream[1]), ) + return if subscription["symbol"].startswith("t"): - return self.__event_emitter.emit( + self.__event_emitter.emit( "t_trades_snapshot", subscription, [ @@ -86,9 +92,10 @@ def __trades_channel_handler(self, subscription: Trades, stream: List[Any]): for sub_stream in stream[0] ], ) + return if subscription["symbol"].startswith("f"): - return self.__event_emitter.emit( + self.__event_emitter.emit( "f_trades_snapshot", subscription, [ @@ -96,11 +103,14 @@ def __trades_channel_handler(self, subscription: Trades, stream: List[Any]): for sub_stream in stream[0] ], ) + return - def __book_channel_handler(self, subscription: Book, stream: List[Any]): + def __book_channel_handler( + self, subscription: Book, stream: list[Any] + ) -> None: if subscription["symbol"].startswith("t"): if all(isinstance(sub_stream, list) for sub_stream in stream[0]): - return self.__event_emitter.emit( + self.__event_emitter.emit( "t_book_snapshot", subscription, [ @@ -108,16 +118,18 @@ def __book_channel_handler(self, subscription: Book, stream: List[Any]): for sub_stream in stream[0] ], ) + return - return self.__event_emitter.emit( + self.__event_emitter.emit( "t_book_update", subscription, serializers.TradingPairBook.parse(*stream[0]), ) + return if subscription["symbol"].startswith("f"): if all(isinstance(sub_stream, list) for sub_stream in stream[0]): - return self.__event_emitter.emit( + self.__event_emitter.emit( "f_book_snapshot", subscription, [ @@ -125,17 +137,21 @@ def __book_channel_handler(self, subscription: Book, stream: List[Any]): for sub_stream in stream[0] ], ) + return - return self.__event_emitter.emit( + self.__event_emitter.emit( "f_book_update", subscription, serializers.FundingCurrencyBook.parse(*stream[0]), ) + return - def __raw_book_channel_handler(self, subscription: Book, stream: List[Any]): + def __raw_book_channel_handler( + self, subscription: Book, stream: list[Any] + ) -> None: if subscription["symbol"].startswith("t"): if all(isinstance(sub_stream, list) for sub_stream in stream[0]): - return self.__event_emitter.emit( + self.__event_emitter.emit( "t_raw_book_snapshot", subscription, [ @@ -143,16 +159,18 @@ def __raw_book_channel_handler(self, subscription: Book, stream: List[Any]): for sub_stream in stream[0] ], ) + return - return self.__event_emitter.emit( + self.__event_emitter.emit( "t_raw_book_update", subscription, serializers.TradingPairRawBook.parse(*stream[0]), ) + return if subscription["symbol"].startswith("f"): if all(isinstance(sub_stream, list) for sub_stream in stream[0]): - return self.__event_emitter.emit( + self.__event_emitter.emit( "f_raw_book_snapshot", subscription, [ @@ -160,39 +178,51 @@ def __raw_book_channel_handler(self, subscription: Book, stream: List[Any]): for sub_stream in stream[0] ], ) + return - return self.__event_emitter.emit( + self.__event_emitter.emit( "f_raw_book_update", subscription, serializers.FundingCurrencyRawBook.parse(*stream[0]), ) + return - def __candles_channel_handler(self, subscription: Candles, stream: List[Any]): + def __candles_channel_handler( + self, subscription: Candles, stream: list[Any] + ) -> None: if all(isinstance(sub_stream, list) for sub_stream in stream[0]): - return self.__event_emitter.emit( + self.__event_emitter.emit( "candles_snapshot", subscription, - [serializers.Candle.parse(*sub_stream) for sub_stream in stream[0]], + [ + serializers.Candle.parse(*sub_stream) + for sub_stream in stream[0] + ], ) + return - return self.__event_emitter.emit( + self.__event_emitter.emit( "candles_update", subscription, serializers.Candle.parse(*stream[0]) ) - def __status_channel_handler(self, subscription: Status, stream: List[Any]): + def __status_channel_handler( + self, subscription: Status, stream: list[Any] + ) -> None: if subscription["key"].startswith("deriv:"): - return self.__event_emitter.emit( + self.__event_emitter.emit( "derivatives_status_update", subscription, serializers.DerivativesStatus.parse(*stream[0]), ) + return if subscription["key"].startswith("liq:"): - return self.__event_emitter.emit( + self.__event_emitter.emit( "liquidation_feed_update", subscription, serializers.Liquidation.parse(*stream[0][0]), ) + return - def __checksum_handler(self, subscription: Book, value: int): - return self.__event_emitter.emit("checksum", subscription, value & 0xFFFFFFFF) + def __checksum_handler(self, subscription: Book, value: int) -> None: + self.__event_emitter.emit("checksum", subscription, value & 0xFFFFFFFF) diff --git a/bfxapi/websocket/subscriptions.py b/bfxapi/websocket/subscriptions.py index 9f4f61f3..43bbf2a8 100644 --- a/bfxapi/websocket/subscriptions.py +++ b/bfxapi/websocket/subscriptions.py @@ -1,6 +1,4 @@ -from typing import Literal, TypedDict, Union - -Subscription = Union["Ticker", "Trades", "Book", "Candles", "Status"] +from typing import Literal, TypedDict Channel = Literal["ticker", "trades", "book", "candles", "status"] @@ -36,3 +34,6 @@ class Status(TypedDict): channel: Literal["status"] sub_id: str key: str + + +Subscription = Ticker | Trades | Book | Candles | Status diff --git a/dev-requirements.txt b/dev-requirements.txt deleted file mode 100644 index cd669e3a..00000000 Binary files a/dev-requirements.txt and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index 1833e97b..e2e700cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,82 @@ -[tool.black] -target-version = ["py38", "py39", "py310", "py311"] +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "bitfinex-api-py" +version = "6.0.1" +description = "Official Bitfinex Python API" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.12" +authors = [ + { name = "Bitfinex", email = "support@bitfinex.com" }, +] +keywords = ["bitfinex", "api", "trading"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Topic :: Software Development :: Build Tools", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "pyee~=13.0", + "websockets~=16.0", + "requests>=2.32.3", +] + +[project.optional-dependencies] +typing = [ + "types-requests~=2.32.4", +] +dev = [ + "mypy~=1.19.0", + "types-requests~=2.32.4", + "ruff~=0.15.0", + "pre-commit~=4.5.0", + "pytest~=9.0", + "pytest-asyncio~=1.3.0", + "pytest-mock~=3.15.0", + "pytest-cov~=6.0", +] + +[project.urls] +Homepage = "https://github.com/CloudIngenium/bitfinex-api-py" +"Bug Reports" = "https://github.com/CloudIngenium/bitfinex-api-py/issues" +Source = "https://github.com/CloudIngenium/bitfinex-api-py" + +[tool.hatch.build.targets.wheel] +packages = ["bfxapi"] + +[tool.ruff] +target-version = "py312" +line-length = 80 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "B", # flake8-bugbear + "UP", # pyupgrade + "I", # isort +] +ignore = ["E203", "E501", "E701", "UP046", "UP047"] + +[tool.ruff.lint.per-file-ignores] +"*/__init__.py" = ["F401"] + +[tool.ruff.lint.isort] +known-first-party = ["bfxapi"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" +addopts = "--cov=bfxapi --cov-report=term-missing --cov-fail-under=80" + +[tool.mypy] +python_version = "3.13" +strict = true +warn_unused_ignores = true +disallow_any_unimported = true diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 40ba2e47..00000000 Binary files a/requirements.txt and /dev/null differ diff --git a/setup.py b/setup.py deleted file mode 100644 index ddeb4b20..00000000 --- a/setup.py +++ /dev/null @@ -1,57 +0,0 @@ -from distutils.core import setup - -setup( - name="bitfinex-api-py", - version="4.0.0", - description="Official Bitfinex Python API", - long_description=( - "A Python reference implementation of the Bitfinex API " - "for both REST and websocket interaction." - ), - long_description_content_type="text/markdown", - url="https://github.com/bitfinexcom/bitfinex-api-py", - author="Bitfinex", - author_email="support@bitfinex.com", - license="Apache-2.0", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "Topic :: Software Development :: Build Tools", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - ], - keywords="bitfinex,api,trading", - project_urls={ - "Bug Reports": "https://github.com/bitfinexcom/bitfinex-api-py/issues", - "Source": "https://github.com/bitfinexcom/bitfinex-api-py", - }, - packages=[ - "bfxapi", - "bfxapi._utils", - "bfxapi.types", - "bfxapi.websocket", - "bfxapi.websocket._client", - "bfxapi.websocket._handlers", - "bfxapi.websocket._event_emitter", - "bfxapi.rest", - "bfxapi.rest._interface", - "bfxapi.rest._interfaces", - ], - install_requires=[ - "pyee~=11.1.0", - "websockets~=12.0", - "requests~=2.32.3", - ], - extras_require={ - "typing": [ - "types-requests~=2.32.0.20241016", - ] - }, - python_requires=">=3.8", - package_data={"bfxapi": ["py.typed"]}, -) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_event_emitter.py b/tests/test_event_emitter.py new file mode 100644 index 00000000..02b4b4f6 --- /dev/null +++ b/tests/test_event_emitter.py @@ -0,0 +1,101 @@ +import pytest + +from bfxapi.websocket._event_emitter import BfxEventEmitter +from bfxapi.websocket.exceptions import UnknownEventError + + +class TestBfxEventEmitter: + def test_register_known_event(self): + ee = BfxEventEmitter(loop=None) + + @ee.on("t_ticker_update") + def handler(*args): + pass + + assert ee._has_listeners("t_ticker_update") + + def test_register_unknown_event_raises(self): + ee = BfxEventEmitter(loop=None) + + with pytest.raises(UnknownEventError, match="unknown event"): + + @ee.on("totally_fake_event") + def handler(*args): + pass + + def test_has_listeners_false_when_no_listeners(self): + ee = BfxEventEmitter(loop=None) + assert ee._has_listeners("t_ticker_update") is False + + def test_once_per_connection_events(self): + ee = BfxEventEmitter(loop=None) + calls = [] + + @ee.on("open") + def handler(): + calls.append(1) + + # First emit should trigger + ee.emit("open") + # Second emit should be suppressed (once per connection) + ee.emit("open") + + # Only called once because second emit is suppressed + assert len(calls) == 1 + + def test_once_per_subscription_events(self): + ee = BfxEventEmitter(loop=None) + calls = [] + + @ee.on("subscribed") + def handler(subscription): + calls.append(subscription) + + sub1 = {"sub_id": "abc", "channel": "ticker"} + sub2 = {"sub_id": "def", "channel": "trades"} + + # First emit for sub1 + ee.emit("subscribed", sub1) + # Second emit for sub1 should be suppressed + ee.emit("subscribed", sub1) + # First emit for sub2 should go through + ee.emit("subscribed", sub2) + + assert len(calls) == 2 + assert calls[0]["sub_id"] == "abc" + assert calls[1]["sub_id"] == "def" + + def test_common_events_always_emit(self): + ee = BfxEventEmitter(loop=None) + calls = [] + + @ee.on("t_ticker_update") + def handler(sub, data): + calls.append(data) + + sub = {"sub_id": "abc"} + ee.emit("t_ticker_update", sub, "data1") + ee.emit("t_ticker_update", sub, "data2") + ee.emit("t_ticker_update", sub, "data3") + + assert len(calls) == 3 + + def test_all_known_events_can_be_registered(self): + ee = BfxEventEmitter(loop=None) + for event in BfxEventEmitter._EVENTS: + + @ee.on(event) + def handler(*args, **kwargs): + pass + + def test_events_list_not_empty(self): + assert len(BfxEventEmitter._EVENTS) > 0 + + def test_on_returns_handler(self): + ee = BfxEventEmitter(loop=None) + + def my_handler(*args): + pass + + result = ee.on("order_new", my_handler) + assert result is my_handler diff --git a/tests/test_json_decoder.py b/tests/test_json_decoder.py new file mode 100644 index 00000000..2504bd82 --- /dev/null +++ b/tests/test_json_decoder.py @@ -0,0 +1,57 @@ +import json + +from bfxapi._utils.json_decoder import JSONDecoder, _to_snake_case + + +class TestToSnakeCase: + def test_camel_case(self): + assert _to_snake_case("camelCase") == "camel_case" + + def test_pascal_case(self): + # regex (? tuple[RestPublicEndpoints, MagicMock]: + ep = RestPublicEndpoints("https://api.example.com") + ep._m = MagicMock() + return ep, ep._m + + +class TestPlatformStatus: + def test_get_platform_status(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [1] + result = ep.get_platform_status() + mock_m.get.assert_called_once_with("platform/status") + assert isinstance(result, PlatformStatus) + assert result.status == 1 + + +class TestConf: + def test_conf(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [["BTC", "ETH"]] + result = ep.conf("pub:list:currency") + mock_m.get.assert_called_once_with("conf/pub:list:currency") + assert result == ["BTC", "ETH"] + + +class TestTickers: + def test_get_tickers_trading(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [ + "tBTCUSD", + 10000, + 1.5, + 10001, + 2.0, + 100, + 0.01, + 10000, + 50000, + 10500, + 9500, + ] + ] + result = ep.get_tickers(["tBTCUSD"]) + mock_m.get.assert_called_once_with( + "tickers", params={"symbols": "tBTCUSD"} + ) + assert "tBTCUSD" in result + assert isinstance(result["tBTCUSD"], TradingPairTicker) + + def test_get_tickers_funding(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [ + "fUSD", + 0.0002, + 0.00025, + 30, + 1000000, + 0.0002, + 2, + 500000, + 0.0001, + 0.0003, + 100000, + 50000, + 0.001, + 0.0, + None, + None, + 500000, + ] + ] + result = ep.get_tickers(["fUSD"]) + assert "fUSD" in result + assert isinstance(result["fUSD"], FundingCurrencyTicker) + + def test_get_t_ticker(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + 10000, + 1.5, + 10001, + 2.0, + 100, + 0.01, + 10000, + 50000, + 10500, + 9500, + ] + result = ep.get_t_ticker("tBTCUSD") + mock_m.get.assert_called_once_with("ticker/tBTCUSD") + assert isinstance(result, TradingPairTicker) + + def test_get_f_ticker(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + 0.0002, + 0.00025, + 30, + 1000000, + 0.0002, + 2, + 500000, + 0.0001, + 0.0003, + 100000, + 50000, + 0.001, + 0.0, + None, + None, + 500000, + ] + result = ep.get_f_ticker("fUSD") + mock_m.get.assert_called_once_with("ticker/fUSD") + assert isinstance(result, FundingCurrencyTicker) + + def test_get_t_tickers_with_list(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [ + "tBTCUSD", + 10000, + 1.5, + 10001, + 2.0, + 100, + 0.01, + 10000, + 50000, + 10500, + 9500, + ] + ] + result = ep.get_t_tickers(["tBTCUSD"]) + assert "tBTCUSD" in result + + def test_get_f_tickers_with_list(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [ + "fUSD", + 0.0002, + 0.00025, + 30, + 1000000, + 0.0002, + 2, + 500000, + 0.0001, + 0.0003, + 100000, + 50000, + 0.001, + 0.0, + None, + None, + 500000, + ] + ] + result = ep.get_f_tickers(["fUSD"]) + assert "fUSD" in result + + +class TestTickersHistory: + def test_get_tickers_history(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + ["tBTCUSD", 10000, None, 10001, None, None, None, None, None, None, None, None, 1609459200000] + ] + result = ep.get_tickers_history(["tBTCUSD"]) + mock_m.get.assert_called_once_with( + "tickers/hist", + params={ + "symbols": "tBTCUSD", + "start": None, + "end": None, + "limit": None, + }, + ) + assert len(result) == 1 + assert isinstance(result[0], TickersHistory) + + +class TestTrades: + def test_get_t_trades(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [123456, 1609459200000, 0.5, 10000] + ] + result = ep.get_t_trades("tBTCUSD") + mock_m.get.assert_called_once_with( + "trades/tBTCUSD/hist", + params={"limit": None, "start": None, "end": None, "sort": None}, + ) + assert len(result) == 1 + assert isinstance(result[0], TradingPairTrade) + + def test_get_f_trades(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [123456, 1609459200000, 1000, 0.0002, 30] + ] + result = ep.get_f_trades("fUSD") + mock_m.get.assert_called_once_with( + "trades/fUSD/hist", + params={"limit": None, "start": None, "end": None, "sort": None}, + ) + assert len(result) == 1 + assert isinstance(result[0], FundingCurrencyTrade) + + def test_get_t_trades_with_params(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [] + result = ep.get_t_trades("tBTCUSD", limit=10, sort=-1) + call_kwargs = mock_m.get.call_args + assert call_kwargs.kwargs["params"]["limit"] == 10 + assert call_kwargs.kwargs["params"]["sort"] == -1 + + +class TestBook: + def test_get_t_book(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [10000, 2, 1.5] + ] + result = ep.get_t_book("tBTCUSD", "P0") + mock_m.get.assert_called_once_with( + "book/tBTCUSD/P0", params={"len": None} + ) + assert len(result) == 1 + assert isinstance(result[0], TradingPairBook) + + def test_get_f_book(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [0.0002, 30, 2, 1000] + ] + result = ep.get_f_book("fUSD", "P0") + mock_m.get.assert_called_once_with( + "book/fUSD/P0", params={"len": None} + ) + assert len(result) == 1 + assert isinstance(result[0], FundingCurrencyBook) + + def test_get_t_raw_book(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [12345, 10000, 1.5] + ] + result = ep.get_t_raw_book("tBTCUSD") + mock_m.get.assert_called_once_with( + "book/tBTCUSD/R0", params={"len": None} + ) + assert len(result) == 1 + assert isinstance(result[0], TradingPairRawBook) + + def test_get_f_raw_book(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [12345, 30, 0.0002, 1000] + ] + result = ep.get_f_raw_book("fUSD") + mock_m.get.assert_called_once_with( + "book/fUSD/R0", params={"len": None} + ) + assert len(result) == 1 + assert isinstance(result[0], FundingCurrencyRawBook) + + def test_get_t_book_with_len(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [] + ep.get_t_book("tBTCUSD", "P1", len=25) + mock_m.get.assert_called_once_with( + "book/tBTCUSD/P1", params={"len": 25} + ) + + +class TestStats: + def test_get_stats_hist(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [1609459200000, 100] + ] + result = ep.get_stats_hist("pos.size:1m:tBTCUSD:long") + mock_m.get.assert_called_once() + assert len(result) == 1 + assert isinstance(result[0], Statistic) + + def test_get_stats_last(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [1609459200000, 100] + result = ep.get_stats_last("pos.size:1m:tBTCUSD:long") + assert isinstance(result, Statistic) + + +class TestCandles: + def test_get_candles_hist(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [1609459200000, 10000, 10100, 10200, 9900, 500] + ] + result = ep.get_candles_hist("tBTCUSD") + mock_m.get.assert_called_once() + assert len(result) == 1 + assert isinstance(result[0], Candle) + + def test_get_candles_last(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [1609459200000, 10000, 10100, 10200, 9900, 500] + result = ep.get_candles_last("tBTCUSD") + assert isinstance(result, Candle) + + def test_get_candles_hist_with_params(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [] + ep.get_candles_hist("tBTCUSD", tf="1h", limit=100, sort=-1) + call_args = mock_m.get.call_args + assert "candles/trade:1h:tBTCUSD/hist" in call_args.args + + def test_get_seed_candles(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [1609459200000, 10000, 10100, 10200, 9900, 500] + ] + result = ep.get_seed_candles("tBTCUSD", tf="5m") + assert len(result) == 1 + assert isinstance(result[0], Candle) + + +class TestDerivatives: + def test_get_derivatives_status(self): + ep, mock_m = _make_endpoint() + # 23 labels: key + mts, _PH, deriv_price, spot_price, _PH, insurance, + # _PH, next_funding_evt_mts, next_funding_accrued, next_funding_step, + # _PH, current_funding, _PH, _PH, mark_price, _PH, _PH, open_interest, + # _PH, _PH, _PH, clamp_min, clamp_max + mock_m.get.return_value = [ + ["tBTCF0:USTF0"] + [None] * 2 + [10000, 10100] + [None] * 2 + + [None, 0.0001, 0.0002, None, None, 0.001] + + [None] * 2 + [10050] + [None] * 2 + [100000] + + [None] * 2 + [None, -0.001, 0.001] + ] + result = ep.get_derivatives_status(["tBTCF0:USTF0"]) + mock_m.get.assert_called_once_with( + "status/deriv", params={"keys": "tBTCF0:USTF0"} + ) + assert "tBTCF0:USTF0" in result + assert isinstance(result["tBTCF0:USTF0"], DerivativesStatus) + + def test_get_derivatives_status_all(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [] + ep.get_derivatives_status("ALL") + mock_m.get.assert_called_once_with( + "status/deriv", params={"keys": "ALL"} + ) + + def test_get_derivatives_status_history(self): + ep, mock_m = _make_endpoint() + # 23 labels (no key prefix for history) + mock_m.get.return_value = [ + [1609459200000, None, 10000, 10100, None, None, None, + 0.0001, 0.0002, None, None, 0.001, None, None, 10050, + None, None, 100000, None, None, None, -0.001, 0.001] + ] + result = ep.get_derivatives_status_history("tBTCF0:USTF0") + assert len(result) == 1 + assert isinstance(result[0], DerivativesStatus) + + +class TestLiquidations: + def test_get_liquidations(self): + ep, mock_m = _make_endpoint() + # 12 labels: _PH, pos_id, mts, _PH, symbol, amount, base_price, + # _PH, is_match, is_market_sold, _PH, liquidation_price + mock_m.get.return_value = [ + [[None, 12345, 1609459200000, None, "tBTCUSD", 0.5, 10000, + None, 1, 0, None, 9800]] + ] + result = ep.get_liquidations() + mock_m.get.assert_called_once() + assert len(result) == 1 + assert isinstance(result[0], Liquidation) + + +class TestLeaderboards: + def test_get_leaderboards_hist(self): + ep, mock_m = _make_endpoint() + # 10 labels: mts, _PH, username, ranking, _PH, _PH, value, _PH, _PH, twitter_handle + mock_m.get.return_value = [ + [1609459200000, None, "username", 1, None, None, 100000, None, None, "@user"] + ] + result = ep.get_leaderboards_hist("plu_diff:1M:tGLOBAL:USD") + assert len(result) == 1 + assert isinstance(result[0], Leaderboard) + + def test_get_leaderboards_last(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + 1609459200000, None, "username", 1, None, None, 100000, None, None, "@user" + ] + result = ep.get_leaderboards_last("plu_diff:1M:tGLOBAL:USD") + assert isinstance(result, Leaderboard) + + +class TestFundingStats: + def test_get_funding_stats(self): + ep, mock_m = _make_endpoint() + mock_m.get.return_value = [ + [1609459200000, None, None, 0.0002, 0.00025, None, None, None, None, None, None, None] + ] + result = ep.get_funding_stats("fUSD") + mock_m.get.assert_called_once() + assert len(result) == 1 + assert isinstance(result[0], FundingStatistic) + + +class TestMarketAveragePrice: + def test_get_trading_market_average_price(self): + ep, mock_m = _make_endpoint() + mock_m.post.return_value = [10050.5, 0.5] + result = ep.get_trading_market_average_price("tBTCUSD", "0.5") + mock_m.post.assert_called_once_with( + "calc/trade/avg", + body={"symbol": "tBTCUSD", "amount": "0.5", "price_limit": None}, + ) + assert isinstance(result, TradingMarketAveragePrice) + + def test_get_funding_market_average_price(self): + ep, mock_m = _make_endpoint() + mock_m.post.return_value = [0.0002, 1000] + result = ep.get_funding_market_average_price("fUSD", "1000", 30) + mock_m.post.assert_called_once_with( + "calc/trade/avg", + body={ + "symbol": "fUSD", + "amount": "1000", + "period": 30, + "rate_limit": None, + }, + ) + assert isinstance(result, FundingMarketAveragePrice) + + +class TestFxRate: + def test_get_fx_rate(self): + ep, mock_m = _make_endpoint() + mock_m.post.return_value = [1.085] + result = ep.get_fx_rate("EUR", "USD") + mock_m.post.assert_called_once_with( + "calc/fx", body={"ccy1": "EUR", "ccy2": "USD"} + ) + assert isinstance(result, FxRate) diff --git a/tests/test_rest_retry.py b/tests/test_rest_retry.py new file mode 100644 index 00000000..50c6eb5f --- /dev/null +++ b/tests/test_rest_retry.py @@ -0,0 +1,110 @@ +import pytest + +from bfxapi.exceptions import InvalidCredentialError +from bfxapi.rest.exceptions import ( + GenericError, + InsufficientFundsError, + NetworkError, + RateLimitError, +) +from bfxapi.rest.retry import ( + get_backoff_delay, + is_retryable, + retry_with_backoff, +) + + +class TestIsRetryable: + def test_rate_limit_is_retryable(self): + assert is_retryable(RateLimitError("limit")) + + def test_network_error_retryable(self): + assert is_retryable(NetworkError("timeout", retryable=True)) + + def test_network_error_not_retryable(self): + assert not is_retryable(NetworkError("bad host", retryable=False)) + + def test_invalid_credentials_not_retryable(self): + assert not is_retryable(InvalidCredentialError("bad key")) + + def test_insufficient_funds_not_retryable(self): + assert not is_retryable(InsufficientFundsError("no funds")) + + def test_generic_error_not_retryable(self): + assert not is_retryable(GenericError("something")) + + def test_nonce_error_retryable(self): + assert is_retryable(GenericError("nonce too small")) + + def test_connection_string_retryable(self): + assert is_retryable(Exception("connection refused")) + + def test_unknown_error_not_retryable(self): + assert not is_retryable(Exception("something else")) + + +class TestGetBackoffDelay: + def test_rate_limit_uses_retry_after(self): + err = RateLimitError("limit", retry_after_ms=30000) + assert get_backoff_delay(err, 0, 1.0, 300.0) == 30.0 + + def test_nonce_short_delay(self): + err = GenericError("nonce too small") + assert get_backoff_delay(err, 0, 1.0, 300.0) == 1.0 + + def test_network_linear_backoff(self): + err = NetworkError("timeout") + assert get_backoff_delay(err, 0, 1.0, 300.0) == 1.0 + assert get_backoff_delay(err, 1, 1.0, 300.0) == 2.0 + assert get_backoff_delay(err, 2, 1.0, 300.0) == 3.0 + + def test_default_exponential(self): + err = Exception("something") + assert get_backoff_delay(err, 0, 1.0, 300.0) == 1.0 + assert get_backoff_delay(err, 1, 1.0, 300.0) == 2.0 + assert get_backoff_delay(err, 2, 1.0, 300.0) == 4.0 + + def test_caps_at_max(self): + err = Exception("something") + assert get_backoff_delay(err, 20, 1.0, 5.0) == 5.0 + + +class TestRetryWithBackoff: + def test_returns_on_first_success(self): + assert retry_with_backoff(lambda: 42) == 42 + + def test_retries_on_retryable_error(self): + calls = 0 + + def fn(): + nonlocal calls + calls += 1 + if calls < 3: + raise RateLimitError("limit", retry_after_ms=10) + return "ok" + + result = retry_with_backoff(fn, max_attempts=5, base_delay=0.01) + assert result == "ok" + assert calls == 3 + + def test_raises_after_max_attempts(self): + with pytest.raises(RateLimitError): + retry_with_backoff( + lambda: (_ for _ in ()).throw( + RateLimitError("limit", retry_after_ms=10) + ), + max_attempts=2, + base_delay=0.01, + ) + + def test_fails_fast_on_non_retryable(self): + calls = 0 + + def fn(): + nonlocal calls + calls += 1 + raise InsufficientFundsError("no funds") + + with pytest.raises(InsufficientFundsError): + retry_with_backoff(fn, max_attempts=5, base_delay=0.01) + assert calls == 1 diff --git a/tests/test_serializers.py b/tests/test_serializers.py new file mode 100644 index 00000000..86157148 --- /dev/null +++ b/tests/test_serializers.py @@ -0,0 +1,743 @@ +import pytest + +from bfxapi.types import dataclasses, serializers +from bfxapi.types.labeler import _Serializer, generate_labeler_serializer +from bfxapi.types.notification import Notification, _Notification + + +class TestSerializerCore: + """Tests for the _Serializer base class.""" + + def test_simple_parse(self): + s = generate_labeler_serializer( + name="PlatformStatus", + klass=dataclasses.PlatformStatus, + labels=["status"], + ) + result = s.parse(1) + assert isinstance(result, dataclasses.PlatformStatus) + assert result.status == 1 + + def test_placeholder_filtering(self): + s = generate_labeler_serializer( + name="TickersHistory", + klass=dataclasses.TickersHistory, + labels=[ + "symbol", + "bid", + "_PLACEHOLDER", + "ask", + "_PLACEHOLDER", + "_PLACEHOLDER", + "_PLACEHOLDER", + "_PLACEHOLDER", + "_PLACEHOLDER", + "_PLACEHOLDER", + "_PLACEHOLDER", + "_PLACEHOLDER", + "mts", + ], + ) + result = s.parse( + "tBTCUSD", 50000.0, None, 50001.0, *([None] * 8), 1234567890 + ) + assert result.symbol == "tBTCUSD" + assert result.bid == 50000.0 + assert result.ask == 50001.0 + assert result.mts == 1234567890 + + def test_get_labels_excludes_placeholders(self): + s = generate_labeler_serializer( + name="Test", + klass=dataclasses.FxRate, + labels=["_PLACEHOLDER", "current_rate"], + ) + assert s.get_labels() == ["current_rate"] + + def test_mismatched_args_raises(self): + s = generate_labeler_serializer( + name="Test", + klass=dataclasses.TradingPairBook, + labels=["price", "count", "amount"], + ) + with pytest.raises(AssertionError, match=" and <\\*args>"): + s.parse(1.0, 2) # Missing one arg + + def test_flat_serializer(self): + result = serializers.SymbolMarginInfo.parse( + "sym", "tBTCUSD", 1000.0, 2000.0, 500.0, 600.0 + ) + assert isinstance(result, dataclasses.SymbolMarginInfo) + assert result.symbol == "tBTCUSD" + assert result.tradable_balance == 1000.0 + assert result.gross_balance == 2000.0 + assert result.buy == 500.0 + assert result.sell == 600.0 + + def test_base_margin_info_flat(self): + result = serializers.BaseMarginInfo.parse( + "base", 100.0, 200.0, 300.0, 400.0, 500.0 + ) + assert isinstance(result, dataclasses.BaseMarginInfo) + assert result.user_pl == 100.0 + assert result.margin_net == 400.0 + + +class TestPublicSerializers: + """Tests for public endpoint serializers.""" + + def test_trading_pair_ticker(self): + result = serializers.TradingPairTicker.parse( + 50000.0, + 1.5, + 50001.0, + 2.0, + 100.0, + 0.002, + 50000.5, + 10000.0, + 51000.0, + 49000.0, + ) + assert isinstance(result, dataclasses.TradingPairTicker) + assert result.bid == 50000.0 + assert result.bid_size == 1.5 + assert result.ask == 50001.0 + assert result.ask_size == 2.0 + assert result.daily_change == 100.0 + assert result.daily_change_relative == 0.002 + assert result.last_price == 50000.5 + assert result.volume == 10000.0 + assert result.high == 51000.0 + assert result.low == 49000.0 + + def test_funding_currency_ticker(self): + result = serializers.FundingCurrencyTicker.parse( + 0.0001, + 0.00009, + 2, + 100.0, + 0.00011, + 30, + 200.0, + 0.00001, + 0.01, + 0.0001, + 5000.0, + 0.00012, + 0.00008, + None, + None, + 10000.0, + ) + assert isinstance(result, dataclasses.FundingCurrencyTicker) + assert result.frr == 0.0001 + assert result.bid_period == 2 + assert result.frr_amount_available == 10000.0 + + def test_trading_pair_trade(self): + result = serializers.TradingPairTrade.parse( + 12345, 1000000, 0.5, 50000.0 + ) + assert isinstance(result, dataclasses.TradingPairTrade) + assert result.id == 12345 + assert result.mts == 1000000 + assert result.amount == 0.5 + assert result.price == 50000.0 + + def test_funding_currency_trade(self): + result = serializers.FundingCurrencyTrade.parse( + 67890, 1000000, 100.0, 0.0001, 30 + ) + assert isinstance(result, dataclasses.FundingCurrencyTrade) + assert result.id == 67890 + assert result.rate == 0.0001 + assert result.period == 30 + + def test_trading_pair_book(self): + result = serializers.TradingPairBook.parse(50000.0, 3, 1.5) + assert isinstance(result, dataclasses.TradingPairBook) + assert result.price == 50000.0 + assert result.count == 3 + assert result.amount == 1.5 + + def test_funding_currency_book(self): + result = serializers.FundingCurrencyBook.parse(0.0001, 30, 5, 1000.0) + assert isinstance(result, dataclasses.FundingCurrencyBook) + assert result.rate == 0.0001 + assert result.period == 30 + + def test_trading_pair_raw_book(self): + result = serializers.TradingPairRawBook.parse(111, 50000.0, 0.5) + assert isinstance(result, dataclasses.TradingPairRawBook) + assert result.order_id == 111 + + def test_funding_currency_raw_book(self): + result = serializers.FundingCurrencyRawBook.parse( + 222, 30, 0.0001, 500.0 + ) + assert isinstance(result, dataclasses.FundingCurrencyRawBook) + assert result.offer_id == 222 + + def test_candle(self): + result = serializers.Candle.parse( + 1000000, 50000, 50100, 51000, 49000, 1234.5 + ) + assert isinstance(result, dataclasses.Candle) + assert result.mts == 1000000 + assert result.open == 50000 + assert result.close == 50100 + assert result.high == 51000 + assert result.low == 49000 + assert result.volume == 1234.5 + + def test_statistic(self): + result = serializers.Statistic.parse(1000000, 42.5) + assert isinstance(result, dataclasses.Statistic) + assert result.mts == 1000000 + assert result.value == 42.5 + + def test_fx_rate(self): + result = serializers.FxRate.parse(1.12) + assert isinstance(result, dataclasses.FxRate) + assert result.current_rate == 1.12 + + def test_derivatives_status(self): + result = serializers.DerivativesStatus.parse( + 1000000, + None, + 50000.0, + 49900.0, + None, + 1000000.0, + None, + 2000000, + 0.0001, + 100, + None, + 0.001, + None, + None, + 50050.0, + None, + None, + 5000.0, + None, + None, + None, + -0.5, + 0.5, + ) + assert isinstance(result, dataclasses.DerivativesStatus) + assert result.deriv_price == 50000.0 + assert result.spot_price == 49900.0 + assert result.mark_price == 50050.0 + assert result.open_interest == 5000.0 + + def test_liquidation(self): + result = serializers.Liquidation.parse( + None, + 111, + 1000000, + None, + "tBTCUSD", + 0.5, + 50000.0, + None, + 1, + 0, + None, + 49000.0, + ) + assert isinstance(result, dataclasses.Liquidation) + assert result.pos_id == 111 + assert result.symbol == "tBTCUSD" + assert result.liquidation_price == 49000.0 + + def test_leaderboard(self): + result = serializers.Leaderboard.parse( + 1000000, + None, + "trader1", + 1, + None, + None, + 99.5, + None, + None, + "@trader1", + ) + assert isinstance(result, dataclasses.Leaderboard) + assert result.username == "trader1" + assert result.ranking == 1 + assert result.twitter_handle == "@trader1" + + def test_funding_statistic(self): + result = serializers.FundingStatistic.parse( + 1000000, + None, + None, + 0.0001, + 30.0, + None, + None, + 5000000.0, + 3000000.0, + None, + None, + 1000000.0, + ) + assert isinstance(result, dataclasses.FundingStatistic) + assert result.frr == 0.0001 + assert result.avg_period == 30.0 + assert result.funding_amount == 5000000.0 + + def test_trading_market_average_price(self): + result = serializers.TradingMarketAveragePrice.parse(50000.0, 1.5) + assert isinstance(result, dataclasses.TradingMarketAveragePrice) + assert result.price_avg == 50000.0 + assert result.amount == 1.5 + + def test_funding_market_average_price(self): + result = serializers.FundingMarketAveragePrice.parse(0.0001, 1000.0) + assert isinstance(result, dataclasses.FundingMarketAveragePrice) + assert result.rate_avg == 0.0001 + + +class TestAuthSerializers: + """Tests for authenticated endpoint serializers.""" + + def test_order(self): + result = serializers.Order.parse( + 1001, + 0, + 123, + "tBTCUSD", + 1000000, + 1000001, + 0.5, + 1.0, + "EXCHANGE LIMIT", + "LIMIT", + 0, + None, + 0, + "ACTIVE", + None, + None, + 50000.0, + 50000.0, + 0, + 0, + None, + None, + None, + 0, + 0, + 0, + None, + None, + "API>2", + None, + None, + {}, + ) + assert isinstance(result, dataclasses.Order) + assert result.id == 1001 + assert result.symbol == "tBTCUSD" + assert result.amount == 0.5 + assert result.order_type == "EXCHANGE LIMIT" + assert result.order_status == "ACTIVE" + assert result.price == 50000.0 + + def test_position(self): + result = serializers.Position.parse( + "tBTCUSD", + "ACTIVE", + 0.5, + 50000.0, + 0.0, + 0, + 100.0, + 0.002, + 45000.0, + 2.0, + None, + 9999, + 1000000, + 1000001, + None, + 0, + None, + 1000.0, + 500.0, + {}, + ) + assert isinstance(result, dataclasses.Position) + assert result.symbol == "tBTCUSD" + assert result.status == "ACTIVE" + assert result.leverage == 2.0 + assert result.position_id == 9999 + + def test_trade(self): + result = serializers.Trade.parse( + 5001, + "tBTCUSD", + 1000000, + 1001, + 0.5, + 50000.0, + "EXCHANGE LIMIT", + 50000.0, + 1, + -0.001, + "USD", + 123, + ) + assert isinstance(result, dataclasses.Trade) + assert result.id == 5001 + assert result.exec_amount == 0.5 + assert result.fee == -0.001 + + def test_wallet(self): + result = serializers.Wallet.parse( + "exchange", "BTC", 1.5, 0.0, 1.5, "2024-01-01", {} + ) + assert isinstance(result, dataclasses.Wallet) + assert result.wallet_type == "exchange" + assert result.currency == "BTC" + assert result.balance == 1.5 + + def test_funding_offer(self): + result = serializers.FundingOffer.parse( + 2001, + "fUSD", + 1000000, + 1000001, + 1000.0, + 1000.0, + "LIMIT", + None, + None, + 0, + "ACTIVE", + None, + None, + None, + 0.0001, + 30, + 0, + 0, + None, + 0, + None, + ) + assert isinstance(result, dataclasses.FundingOffer) + assert result.id == 2001 + assert result.symbol == "fUSD" + assert result.rate == 0.0001 + assert result.period == 30 + + def test_funding_credit(self): + result = serializers.FundingCredit.parse( + 3001, + "fUSD", + 1, + 1000000, + 1000001, + 500.0, + 0, + "ACTIVE", + "FIXED", + None, + None, + 0.0001, + 30, + 1000000, + 1000000, + 0, + 0, + None, + 0, + None, + 0, + "tBTCUSD", + ) + assert isinstance(result, dataclasses.FundingCredit) + assert result.id == 3001 + assert result.position_pair == "tBTCUSD" + + def test_funding_loan(self): + result = serializers.FundingLoan.parse( + 4001, + "fUSD", + 1, + 1000000, + 1000001, + 500.0, + 0, + "ACTIVE", + "FIXED", + None, + None, + 0.0001, + 30, + 1000000, + 1000000, + 0, + 0, + None, + 0, + None, + 0, + ) + assert isinstance(result, dataclasses.FundingLoan) + assert result.id == 4001 + + def test_funding_auto_renew(self): + result = serializers.FundingAutoRenew.parse("USD", 30, 0.0001, 500.0) + assert isinstance(result, dataclasses.FundingAutoRenew) + assert result.currency == "USD" + assert result.period == 30 + + def test_funding_info_flat(self): + result = serializers.FundingInfo.parse( + "sym", "fUSD", 0.001, 0.002, 10.0, 15.0 + ) + assert isinstance(result, dataclasses.FundingInfo) + assert result.symbol == "fUSD" + assert result.yield_loan == 0.001 + assert result.yield_lend == 0.002 + + def test_ledger(self): + result = serializers.Ledger.parse( + 7001, "BTC", None, 1000000, None, 0.5, 1.5, None, "Trade" + ) + assert isinstance(result, dataclasses.Ledger) + assert result.id == 7001 + assert result.currency == "BTC" + assert result.description == "Trade" + + def test_transfer(self): + result = serializers.Transfer.parse( + 1000000, "exchange", "margin", None, "BTC", "BTC", None, 0.5 + ) + assert isinstance(result, dataclasses.Transfer) + assert result.wallet_from == "exchange" + assert result.wallet_to == "margin" + + def test_withdrawal(self): + result = serializers.Withdrawal.parse( + 8001, + None, + "bitcoin", + "addr123", + "exchange", + 0.5, + None, + None, + 0.0001, + ) + assert isinstance(result, dataclasses.Withdrawal) + assert result.withdrawal_id == 8001 + assert result.withdrawal_fee == 0.0001 + + def test_deposit_address(self): + result = serializers.DepositAddress.parse( + None, "bitcoin", "BTC", None, "1A2B3C4D", "pool_addr" + ) + assert isinstance(result, dataclasses.DepositAddress) + assert result.method == "bitcoin" + assert result.address == "1A2B3C4D" + + def test_movement(self): + result = serializers.Movement.parse( + "9001", + "BTC", + "Bitcoin", + None, + None, + 1000000, + 1000001, + None, + None, + "COMPLETED", + None, + None, + 1, + 0, + None, + None, + "1A2B3C", + None, + None, + None, + "tx123", + "note", + ) + assert isinstance(result, dataclasses.Movement) + assert result.id == "9001" + assert result.status == "COMPLETED" + assert result.transaction_id == "tx123" + + def test_balance_info(self): + result = serializers.BalanceInfo.parse(100000.0, 95000.0) + assert isinstance(result, dataclasses.BalanceInfo) + assert result.aum == 100000.0 + assert result.aum_net == 95000.0 + + def test_derivative_position_collateral(self): + result = serializers.DerivativePositionCollateral.parse(1) + assert isinstance(result, dataclasses.DerivativePositionCollateral) + assert result.status == 1 + + def test_derivative_position_collateral_limits(self): + result = serializers.DerivativePositionCollateralLimits.parse( + 100.0, 50000.0 + ) + assert isinstance( + result, dataclasses.DerivativePositionCollateralLimits + ) + assert result.min_collateral == 100.0 + assert result.max_collateral == 50000.0 + + def test_position_increase_info_flat(self): + result = serializers.PositionIncreaseInfo.parse( + 10, + 5.0, + 1000.0, + 2000.0, + 3000.0, + 4000.0, + 5000.0, + None, + None, + None, + None, + 6000.0, + None, + None, + 7000.0, + 8000.0, + "USD", + "BTC", + ) + assert isinstance(result, dataclasses.PositionIncreaseInfo) + assert result.max_pos == 10 + assert result.funding_avail == 6000.0 + assert result.funding_value_currency == "USD" + + +class TestNotificationSerializer: + """Tests for the _Notification serializer.""" + + def test_plain_notification(self): + s = _Notification(serializer=None) + result = s.parse( + 1000000, "info", None, None, "data", 0, "SUCCESS", "ok" + ) + assert isinstance(result, Notification) + assert result.mts == 1000000 + assert result.type == "info" + assert result.status == "SUCCESS" + assert result.text == "ok" + assert result.data == "data" + + def test_notification_with_order_serializer(self): + s = _Notification(serializer=serializers.Order) + order_data = [ + 1001, + 0, + 123, + "tBTCUSD", + 1000000, + 1000001, + 0.5, + 1.0, + "EXCHANGE LIMIT", + "LIMIT", + 0, + None, + 0, + "ACTIVE", + None, + None, + 50000.0, + 50000.0, + 0, + 0, + None, + None, + None, + 0, + 0, + 0, + None, + None, + "API>2", + None, + None, + {}, + ] + result = s.parse( + 1000000, + "on-req", + None, + None, + [order_data], + 0, + "SUCCESS", + "Submitted", + ) + assert isinstance(result, Notification) + assert isinstance(result.data, dataclasses.Order) + assert result.data.symbol == "tBTCUSD" + + def test_notification_with_funding_offer_serializer(self): + s = _Notification(serializer=serializers.FundingOffer) + offer_data = [ + 2001, + "fUSD", + 1000000, + 1000001, + 1000.0, + 1000.0, + "LIMIT", + None, + None, + 0, + "ACTIVE", + None, + None, + None, + 0.0001, + 30, + 0, + 0, + None, + 0, + None, + ] + result = s.parse( + 1000000, + "fon-req", + None, + None, + [offer_data], + 0, + "SUCCESS", + "Submitted", + ) + assert isinstance(result, Notification) + assert isinstance(result.data, dataclasses.FundingOffer) + + +class TestSerializerAllDefined: + """Verify all declared serializers exist and are functional.""" + + def test_all_serializers_in_list_are_accessible(self): + for name in serializers.__serializers__: + s = getattr(serializers, name) + assert isinstance(s, _Serializer), f"{name} is not a _Serializer" + labels = s.get_labels() + assert len(labels) > 0, f"{name} has no labels" diff --git a/tests/test_websocket_client.py b/tests/test_websocket_client.py new file mode 100644 index 00000000..ac3b9490 --- /dev/null +++ b/tests/test_websocket_client.py @@ -0,0 +1,67 @@ +import random + +from bfxapi.websocket._client.bfx_websocket_client import _Delay + + +class TestDelay: + """Tests for the exponential backoff _Delay class.""" + + def test_initial_delay_is_random(self): + random.seed(42) + d = _Delay(backoff_factor=1.618) + first = d.next() + assert 1.0 <= first <= 5.0 + + def test_backoff_increases(self): + random.seed(42) + d = _Delay(backoff_factor=2.0) + d.next() # initial random delay + second = d.next() # should be 1.92 * 2.0 = 3.84 + third = d.next() # should be 3.84 * 2.0 = 7.68 + + # After initial, delays should grow + assert second < third + + def test_backoff_max_cap(self): + d = _Delay(backoff_factor=100.0) + # Force past initial + d.next() # initial + # Each subsequent call multiplies by 100, but max is 60 + for _ in range(10): + val = d.next() + assert val <= 60.0 + + def test_peek_does_not_advance(self): + random.seed(42) + d = _Delay(backoff_factor=1.618) + peek1 = d.peek() + peek2 = d.peek() + assert peek1 == peek2 + + def test_next_advances_past_peek(self): + random.seed(42) + d = _Delay(backoff_factor=1.618) + peek_val = d.peek() + next_val = d.next() + assert next_val == peek_val + # After next, peek should return a different (larger) value + assert d.peek() != peek_val + + def test_reset(self): + random.seed(42) + d = _Delay(backoff_factor=2.0) + d.next() # initial + d.next() # advance + d.next() # advance more + + d.reset() + # After reset, peek returns initial delay again + peek = d.peek() + assert 1.0 <= peek <= 5.0 + + def test_backoff_factor_applied(self): + d = _Delay(backoff_factor=2.0) + d.next() # initial random delay + val1 = d.next() # 1.92 * 2 = 3.84 + val2 = d.next() # 3.84 * 2 = 7.68 + assert abs(val2 / val1 - 2.0) < 0.01 diff --git a/tests/test_websocket_connection.py b/tests/test_websocket_connection.py new file mode 100644 index 00000000..59afddb2 --- /dev/null +++ b/tests/test_websocket_connection.py @@ -0,0 +1,171 @@ +import hashlib +import hmac +import json +from unittest.mock import MagicMock, PropertyMock + +import pytest + +from bfxapi.websocket._connection import Connection +from bfxapi.websocket.exceptions import ( + ActionRequiresAuthentication, + ConnectionNotOpen, +) + + +class ConcreteConnection(Connection): + """Concrete subclass for testing the abstract Connection.""" + + async def start(self) -> None: + pass + + +class TestConnectionProperties: + def test_open_initially_false(self): + conn = ConcreteConnection("wss://example.com") + assert conn.open is False + + def test_authentication_initially_false(self): + conn = ConcreteConnection("wss://example.com") + assert conn.authentication is False + + def test_set_authentication(self): + conn = ConcreteConnection("wss://example.com") + conn._authentication = True + assert conn.authentication is True + + +class TestRequireWebsocketConnection: + @pytest.mark.asyncio + async def test_raises_when_not_connected(self): + conn = ConcreteConnection("wss://example.com") + + @Connection._require_websocket_connection + async def some_action(self): + return "success" + + with pytest.raises(ConnectionNotOpen, match="No open connection"): + await some_action(conn) + + @pytest.mark.asyncio + async def test_passes_when_connected(self): + conn = ConcreteConnection("wss://example.com") + mock_ws = MagicMock() + mock_state = MagicMock() + mock_state.name = "OPEN" + type(mock_ws).state = PropertyMock(return_value=mock_state) + conn._websocket = mock_ws + + @Connection._require_websocket_connection + async def some_action(self): + return "success" + + result = await some_action(conn) + assert result == "success" + + +class TestRequireWebsocketAuthentication: + @pytest.mark.asyncio + async def test_raises_when_not_authenticated(self): + conn = ConcreteConnection("wss://example.com") + + @Connection._require_websocket_authentication + async def auth_action(self): + return "success" + + with pytest.raises( + ActionRequiresAuthentication, + match="authenticate using your API_KEY", + ): + await auth_action(conn) + + @pytest.mark.asyncio + async def test_raises_when_authenticated_but_not_connected(self): + conn = ConcreteConnection("wss://example.com") + conn._authentication = True + + @Connection._require_websocket_authentication + async def auth_action(self): + return "success" + + with pytest.raises(ConnectionNotOpen): + await auth_action(conn) + + @pytest.mark.asyncio + async def test_passes_when_authenticated_and_connected(self): + conn = ConcreteConnection("wss://example.com") + conn._authentication = True + mock_ws = MagicMock() + mock_state = MagicMock() + mock_state.name = "OPEN" + type(mock_ws).state = PropertyMock(return_value=mock_state) + conn._websocket = mock_ws + + @Connection._require_websocket_authentication + async def auth_action(self): + return "success" + + result = await auth_action(conn) + assert result == "success" + + +class TestGetAuthenticationMessage: + def test_message_structure(self): + msg = json.loads( + Connection._get_authentication_message( + api_key="test_key", api_secret="test_secret" + ) + ) + assert msg["event"] == "auth" + assert msg["apiKey"] == "test_key" + assert "authNonce" in msg + assert "authPayload" in msg + assert "authSig" in msg + + def test_payload_format(self): + msg = json.loads( + Connection._get_authentication_message( + api_key="key", api_secret="secret" + ) + ) + assert msg["authPayload"] == f"AUTH{msg['authNonce']}" + + def test_signature_correctness(self): + api_secret = "my_secret" + msg = json.loads( + Connection._get_authentication_message( + api_key="key", api_secret=api_secret + ) + ) + expected_sig = hmac.new( + key=api_secret.encode("utf8"), + msg=msg["authPayload"].encode("utf8"), + digestmod=hashlib.sha384, + ).hexdigest() + assert msg["authSig"] == expected_sig + + def test_filters_included(self): + msg = json.loads( + Connection._get_authentication_message( + api_key="key", + api_secret="secret", + filters=["trading", "funding"], + ) + ) + assert msg["filter"] == ["trading", "funding"] + + def test_filters_none_by_default(self): + msg = json.loads( + Connection._get_authentication_message( + api_key="key", api_secret="secret" + ) + ) + assert msg["filter"] is None + + def test_nonce_is_integer(self): + msg = json.loads( + Connection._get_authentication_message( + api_key="key", api_secret="secret" + ) + ) + assert isinstance(msg["authNonce"], int) + assert msg["authNonce"] > 0 diff --git a/tests/test_websocket_handlers.py b/tests/test_websocket_handlers.py new file mode 100644 index 00000000..d5a98eed --- /dev/null +++ b/tests/test_websocket_handlers.py @@ -0,0 +1,863 @@ +from unittest.mock import MagicMock + +from bfxapi.types import dataclasses +from bfxapi.websocket._handlers.auth_events_handler import AuthEventsHandler +from bfxapi.websocket._handlers.public_channels_handler import ( + PublicChannelsHandler, +) + + +class TestPublicChannelsHandler: + def setup_method(self): + self.ee = MagicMock() + self.handler = PublicChannelsHandler(event_emitter=self.ee) + + def test_trading_ticker(self): + subscription = { + "channel": "ticker", + "sub_id": "abc", + "symbol": "tBTCUSD", + } + stream = [ + [ + 50000.0, + 1.5, + 50001.0, + 2.0, + 100.0, + 0.002, + 50000.5, + 10000.0, + 51000.0, + 49000.0, + ] + ] + self.handler.handle(subscription, stream) + + self.ee.emit.assert_called_once() + args = self.ee.emit.call_args + assert args[0][0] == "t_ticker_update" + assert args[0][1] == subscription + assert isinstance(args[0][2], dataclasses.TradingPairTicker) + assert args[0][2].bid == 50000.0 + + def test_funding_ticker(self): + subscription = {"channel": "ticker", "sub_id": "abc", "symbol": "fUSD"} + stream = [ + [ + 0.0001, + 0.00009, + 2, + 100.0, + 0.00011, + 30, + 200.0, + 0.00001, + 0.01, + 0.0001, + 5000.0, + 0.00012, + 0.00008, + None, + None, + 10000.0, + ] + ] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_ticker_update" + assert isinstance(args[0][2], dataclasses.FundingCurrencyTicker) + + def test_trading_trade_execution(self): + subscription = { + "channel": "trades", + "sub_id": "abc", + "symbol": "tBTCUSD", + } + stream = ["te", [12345, 1000000, 0.5, 50000.0]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "t_trade_execution" + assert isinstance(args[0][2], dataclasses.TradingPairTrade) + + def test_trading_trade_execution_update(self): + subscription = { + "channel": "trades", + "sub_id": "abc", + "symbol": "tBTCUSD", + } + stream = ["tu", [12345, 1000000, 0.5, 50000.0]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "t_trade_execution_update" + + def test_funding_trade_execution(self): + subscription = {"channel": "trades", "sub_id": "abc", "symbol": "fUSD"} + stream = ["fte", [67890, 1000000, 100.0, 0.0001, 30]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_trade_execution" + assert isinstance(args[0][2], dataclasses.FundingCurrencyTrade) + + def test_funding_trade_execution_update(self): + subscription = {"channel": "trades", "sub_id": "abc", "symbol": "fUSD"} + stream = ["ftu", [67890, 1000000, 100.0, 0.0001, 30]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_trade_execution_update" + + def test_trading_trades_snapshot(self): + subscription = { + "channel": "trades", + "sub_id": "abc", + "symbol": "tBTCUSD", + } + stream = [ + [[12345, 1000000, 0.5, 50000.0], [12346, 1000001, -0.3, 49999.0]] + ] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "t_trades_snapshot" + assert len(args[0][2]) == 2 + assert all( + isinstance(t, dataclasses.TradingPairTrade) for t in args[0][2] + ) + + def test_funding_trades_snapshot(self): + subscription = {"channel": "trades", "sub_id": "abc", "symbol": "fUSD"} + stream = [ + [ + [67890, 1000000, 100.0, 0.0001, 30], + [67891, 1000001, 200.0, 0.0002, 15], + ] + ] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_trades_snapshot" + assert len(args[0][2]) == 2 + + def test_trading_book_snapshot(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "tBTCUSD", + "prec": "P0", + "freq": "F0", + "len": "25", + } + stream = [[[50000.0, 3, 1.5], [49999.0, 2, 0.8]]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "t_book_snapshot" + assert len(args[0][2]) == 2 + assert all( + isinstance(b, dataclasses.TradingPairBook) for b in args[0][2] + ) + + def test_trading_book_update(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "tBTCUSD", + "prec": "P0", + "freq": "F0", + "len": "25", + } + stream = [[50000.0, 3, 1.5]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "t_book_update" + assert isinstance(args[0][2], dataclasses.TradingPairBook) + + def test_funding_book_snapshot(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "fUSD", + "prec": "P0", + "freq": "F0", + "len": "25", + } + stream = [[[0.0001, 30, 5, 1000.0], [0.0002, 15, 3, 500.0]]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_book_snapshot" + assert len(args[0][2]) == 2 + + def test_funding_book_update(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "fUSD", + "prec": "P0", + "freq": "F0", + "len": "25", + } + stream = [[0.0001, 30, 5, 1000.0]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_book_update" + assert isinstance(args[0][2], dataclasses.FundingCurrencyBook) + + def test_trading_raw_book_snapshot(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "tBTCUSD", + "prec": "R0", + "freq": "F0", + "len": "25", + } + stream = [[[111, 50000.0, 0.5], [222, 49999.0, -0.3]]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "t_raw_book_snapshot" + assert all( + isinstance(b, dataclasses.TradingPairRawBook) for b in args[0][2] + ) + + def test_trading_raw_book_update(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "tBTCUSD", + "prec": "R0", + "freq": "F0", + "len": "25", + } + stream = [[111, 50000.0, 0.5]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "t_raw_book_update" + + def test_funding_raw_book_snapshot(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "fUSD", + "prec": "R0", + "freq": "F0", + "len": "25", + } + stream = [[[222, 30, 0.0001, 500.0], [333, 15, 0.0002, 300.0]]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_raw_book_snapshot" + + def test_funding_raw_book_update(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "fUSD", + "prec": "R0", + "freq": "F0", + "len": "25", + } + stream = [[222, 30, 0.0001, 500.0]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "f_raw_book_update" + + def test_candles_snapshot(self): + subscription = { + "channel": "candles", + "sub_id": "abc", + "key": "trade:1m:tBTCUSD", + } + stream = [ + [ + [1000000, 50000, 50100, 51000, 49000, 1234.5], + [1000060, 50100, 50200, 51100, 49100, 1235.5], + ] + ] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "candles_snapshot" + assert len(args[0][2]) == 2 + assert all(isinstance(c, dataclasses.Candle) for c in args[0][2]) + + def test_candles_update(self): + subscription = { + "channel": "candles", + "sub_id": "abc", + "key": "trade:1m:tBTCUSD", + } + stream = [[1000000, 50000, 50100, 51000, 49000, 1234.5]] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "candles_update" + assert isinstance(args[0][2], dataclasses.Candle) + + def test_derivatives_status(self): + subscription = { + "channel": "status", + "sub_id": "abc", + "key": "deriv:tBTCF0:USTF0", + } + stream = [ + [ + 1000000, + None, + 50000.0, + 49900.0, + None, + 1000000.0, + None, + 2000000, + 0.0001, + 100, + None, + 0.001, + None, + None, + 50050.0, + None, + None, + 5000.0, + None, + None, + None, + -0.5, + 0.5, + ] + ] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "derivatives_status_update" + assert isinstance(args[0][2], dataclasses.DerivativesStatus) + + def test_liquidation_feed(self): + subscription = { + "channel": "status", + "sub_id": "abc", + "key": "liq:global", + } + stream = [ + [ + [ + None, + 111, + 1000000, + None, + "tBTCUSD", + 0.5, + 50000.0, + None, + 1, + 0, + None, + 49000.0, + ] + ] + ] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "liquidation_feed_update" + assert isinstance(args[0][2], dataclasses.Liquidation) + + def test_checksum(self): + subscription = { + "channel": "book", + "sub_id": "abc", + "symbol": "tBTCUSD", + "prec": "P0", + "freq": "F0", + "len": "25", + } + stream = ["cs", 123456789] + self.handler.handle(subscription, stream) + + args = self.ee.emit.call_args + assert args[0][0] == "checksum" + assert args[0][2] == 123456789 & 0xFFFFFFFF + + +class TestAuthEventsHandler: + def setup_method(self): + self.ee = MagicMock() + self.handler = AuthEventsHandler(event_emitter=self.ee) + + def test_order_snapshot(self): + orders = [ + [ + 1001, + 0, + 123, + "tBTCUSD", + 1000000, + 1000001, + 0.5, + 1.0, + "EXCHANGE LIMIT", + "LIMIT", + 0, + None, + 0, + "ACTIVE", + None, + None, + 50000.0, + 50000.0, + 0, + 0, + None, + None, + None, + 0, + 0, + 0, + None, + None, + "API>2", + None, + None, + {}, + ], + ] + self.handler.handle("os", orders) + + args = self.ee.emit.call_args + assert args[0][0] == "order_snapshot" + assert isinstance(args[0][1], list) + assert isinstance(args[0][1][0], dataclasses.Order) + + def test_order_new(self): + order = [ + 1001, + 0, + 123, + "tBTCUSD", + 1000000, + 1000001, + 0.5, + 1.0, + "EXCHANGE LIMIT", + "LIMIT", + 0, + None, + 0, + "ACTIVE", + None, + None, + 50000.0, + 50000.0, + 0, + 0, + None, + None, + None, + 0, + 0, + 0, + None, + None, + "API>2", + None, + None, + {}, + ] + self.handler.handle("on", order) + + args = self.ee.emit.call_args + assert args[0][0] == "order_new" + assert isinstance(args[0][1], dataclasses.Order) + + def test_order_update(self): + order = [ + 1001, + 0, + 123, + "tBTCUSD", + 1000000, + 1000001, + 0.3, + 1.0, + "EXCHANGE LIMIT", + "LIMIT", + 0, + None, + 0, + "PARTIALLY FILLED", + None, + None, + 50000.0, + 50000.0, + 0, + 0, + None, + None, + None, + 0, + 0, + 0, + None, + None, + "API>2", + None, + None, + {}, + ] + self.handler.handle("ou", order) + + args = self.ee.emit.call_args + assert args[0][0] == "order_update" + + def test_order_cancel(self): + order = [ + 1001, + 0, + 123, + "tBTCUSD", + 1000000, + 1000001, + 0.5, + 1.0, + "EXCHANGE LIMIT", + "LIMIT", + 0, + None, + 0, + "CANCELED", + None, + None, + 50000.0, + 0, + 0, + 0, + None, + None, + None, + 0, + 0, + 0, + None, + None, + "API>2", + None, + None, + {}, + ] + self.handler.handle("oc", order) + + args = self.ee.emit.call_args + assert args[0][0] == "order_cancel" + + def test_position_snapshot(self): + positions = [ + [ + "tBTCUSD", + "ACTIVE", + 0.5, + 50000.0, + 0.0, + 0, + 100.0, + 0.002, + 45000.0, + 2.0, + None, + 9999, + 1000000, + 1000001, + None, + 0, + None, + 1000.0, + 500.0, + {}, + ], + ] + self.handler.handle("ps", positions) + + args = self.ee.emit.call_args + assert args[0][0] == "position_snapshot" + assert isinstance(args[0][1], list) + assert isinstance(args[0][1][0], dataclasses.Position) + + def test_wallet_snapshot(self): + wallets = [ + ["exchange", "BTC", 1.5, 0.0, 1.5, "2024-01-01", {}], + ["exchange", "USD", 50000.0, 0.0, 50000.0, "2024-01-01", {}], + ] + self.handler.handle("ws", wallets) + + args = self.ee.emit.call_args + assert args[0][0] == "wallet_snapshot" + assert len(args[0][1]) == 2 + assert all(isinstance(w, dataclasses.Wallet) for w in args[0][1]) + + def test_wallet_update(self): + wallet = ["exchange", "BTC", 1.6, 0.0, 1.6, "2024-01-02", {}] + self.handler.handle("wu", wallet) + + args = self.ee.emit.call_args + assert args[0][0] == "wallet_update" + assert isinstance(args[0][1], dataclasses.Wallet) + + def test_trade_execution(self): + trade = [ + 5001, + "tBTCUSD", + 1000000, + 1001, + 0.5, + 50000.0, + "EXCHANGE LIMIT", + 50000.0, + 1, + -0.001, + "USD", + 123, + ] + self.handler.handle("te", trade) + + args = self.ee.emit.call_args + assert args[0][0] == "trade_execution" + assert isinstance(args[0][1], dataclasses.Trade) + + def test_funding_offer_snapshot(self): + offers = [ + [ + 2001, + "fUSD", + 1000000, + 1000001, + 1000.0, + 1000.0, + "LIMIT", + None, + None, + 0, + "ACTIVE", + None, + None, + None, + 0.0001, + 30, + 0, + 0, + None, + 0, + None, + ], + ] + self.handler.handle("fos", offers) + + args = self.ee.emit.call_args + assert args[0][0] == "funding_offer_snapshot" + assert isinstance(args[0][1][0], dataclasses.FundingOffer) + + def test_funding_credit_new(self): + credit = [ + 3001, + "fUSD", + 1, + 1000000, + 1000001, + 500.0, + 0, + "ACTIVE", + "FIXED", + None, + None, + 0.0001, + 30, + 1000000, + 1000000, + 0, + 0, + None, + 0, + None, + 0, + "tBTCUSD", + ] + self.handler.handle("fcn", credit) + + args = self.ee.emit.call_args + assert args[0][0] == "funding_credit_new" + assert isinstance(args[0][1], dataclasses.FundingCredit) + + def test_funding_loan_update(self): + loan = [ + 4001, + "fUSD", + 1, + 1000000, + 1000001, + 500.0, + 0, + "ACTIVE", + "FIXED", + None, + None, + 0.0001, + 30, + 1000000, + 1000000, + 0, + 0, + None, + 0, + None, + 0, + ] + self.handler.handle("flu", loan) + + args = self.ee.emit.call_args + assert args[0][0] == "funding_loan_update" + assert isinstance(args[0][1], dataclasses.FundingLoan) + + def test_balance_update(self): + balance = [100000.0, 95000.0] + self.handler.handle("bu", balance) + + args = self.ee.emit.call_args + assert args[0][0] == "balance_update" + assert isinstance(args[0][1], dataclasses.BalanceInfo) + + def test_base_margin_info(self): + stream = ["base", 100.0, 200.0, 300.0, 400.0, 500.0] + self.handler.handle("miu", stream) + + args = self.ee.emit.call_args + assert args[0][0] == "base_margin_info" + assert isinstance(args[0][1], dataclasses.BaseMarginInfo) + + def test_symbol_margin_info(self): + stream = ["sym", "tBTCUSD", 1000.0, 2000.0, 500.0, 600.0] + self.handler.handle("miu", stream) + + args = self.ee.emit.call_args + assert args[0][0] == "symbol_margin_info" + assert isinstance(args[0][1], dataclasses.SymbolMarginInfo) + + def test_notification_plain(self): + stream = [1000000, "info", None, None, "data", 0, "SUCCESS", "ok"] + self.handler.handle("n", stream) + + args = self.ee.emit.call_args + assert args[0][0] == "notification" + + def test_notification_on_req(self): + order_data = [ + 1001, + 0, + 123, + "tBTCUSD", + 1000000, + 1000001, + 0.5, + 1.0, + "EXCHANGE LIMIT", + "LIMIT", + 0, + None, + 0, + "ACTIVE", + None, + None, + 50000.0, + 50000.0, + 0, + 0, + None, + None, + None, + 0, + 0, + 0, + None, + None, + "API>2", + None, + None, + {}, + ] + stream = [ + 1000000, + "on-req", + None, + None, + [order_data], + 0, + "SUCCESS", + "Submitted", + ] + self.handler.handle("n", stream) + + args = self.ee.emit.call_args + assert args[0][0] == "on-req-notification" + + def test_notification_fon_req(self): + offer_data = [ + 2001, + "fUSD", + 1000000, + 1000001, + 1000.0, + 1000.0, + "LIMIT", + None, + None, + 0, + "ACTIVE", + None, + None, + None, + 0.0001, + 30, + 0, + 0, + None, + 0, + None, + ] + stream = [ + 1000000, + "fon-req", + None, + None, + [offer_data], + 0, + "SUCCESS", + "Submitted", + ] + self.handler.handle("n", stream) + + args = self.ee.emit.call_args + assert args[0][0] == "fon-req-notification" + + +class TestAuthEventsHandlerAbbreviations: + """Verify all abbreviation mappings are complete.""" + + def test_all_abbreviations_have_serializers(self): + abbrevs = AuthEventsHandler._AuthEventsHandler__ABBREVIATIONS + serializer_abbrevs = set( + AuthEventsHandler._AuthEventsHandler__SERIALIZERS.keys() + ) + + for abbrev in abbrevs: + assert abbrev in serializer_abbrevs, ( + f"Abbreviation '{abbrev}' has no serializer mapping" + )