Skip to content

Commit 61e832e

Browse files
authored
enforce isort, strict typing (+mypy) (#251)
* enforce isort, strict typing (+mypy) * all lints passing * fix unused import
1 parent 3160694 commit 61e832e

File tree

17 files changed

+960
-832
lines changed

17 files changed

+960
-832
lines changed

.github/workflows/python-app.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ jobs:
2525
- name: Type check with pyright
2626
run: |
2727
uv run pyright tastytrade/ tests/
28+
- name: Type check with mypy
29+
run: |
30+
uv run mypy tastytrade/
2831
- name: Test with pytest
2932
run: |
3033
uv run pytest --cov=tastytrade --cov-report=term-missing tests/ --cov-fail-under=95

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
install:
44
uv sync
5-
uv pip install -e .
65

76
lint:
7+
uv run ruff check --select I --fix
88
uv run ruff format tastytrade/ tests/
99
uv run ruff check tastytrade/ tests/
1010
uv run pyright tastytrade/ tests/
11+
uv run mypy tastytrade/
1112

1213
test:
1314
uv run pytest --cov=tastytrade --cov-report=term-missing tests/ --cov-fail-under=95

pyproject.toml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,31 @@ dev = [
5959
"pytest-aio>=1.5.0",
6060
"pytest-cov>=5.0.0",
6161
"ruff>=0.6.9",
62-
"pyright>=1.1.390",
62+
"pyright>=1.1.401",
6363
"sphinx>=7.4.7",
6464
"enum-tools[sphinx]>=0.12.0",
6565
"autodoc-pydantic>=2.2.0",
6666
"proxy-py>=2.4.9",
6767
"sphinx-immaterial>=0.12.5",
68+
"mypy>=1.15.0",
6869
]
6970

7071
[tool.setuptools.package-data]
7172
"tastytrade" = ["py.typed"]
7273

7374
[tool.setuptools.packages.find]
7475
where = ["tastytrade"]
76+
77+
[tool.ruff.lint]
78+
select = ["E", "F", "I"]
79+
80+
[tool.pyright]
81+
pythonVersion = "3.9"
82+
include = ["tastytrade", "tests"]
83+
exclude = ["**/__pycache__"]
84+
typeCheckingMode = "strict"
85+
reportPrivateUsage = false
86+
87+
[tool.mypy]
88+
strict = true
89+
warn_unused_ignores = false

tastytrade/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
BACKTEST_URL = "https://backtester.vast.tastyworks.com"
55
CERT_URL = "https://api.cert.tastyworks.com"
66
VAST_URL = "https://vast.tastyworks.com"
7-
VERSION = "10.2.1"
7+
VERSION = "10.2.2"
88

99
__version__ = VERSION
10-
version_str = f"tastyware/tastytrade:v{VERSION}"
10+
version_str: str = f"tastyware/tastytrade:v{VERSION}"
1111

1212
logger = logging.getLogger(__name__)
1313
logger.setLevel(logging.DEBUG)

tastytrade/account.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import date, datetime
22
from decimal import Decimal
3-
from typing import Any, Literal, Optional, Union, overload
3+
from typing import Any, Literal, Optional, Union, cast, overload
44

55
import httpx
66
from pydantic import BaseModel, model_validator
@@ -22,8 +22,8 @@
2222
from tastytrade.session import Session
2323
from tastytrade.utils import (
2424
PriceEffect,
25-
TastytradeError,
2625
TastytradeData,
26+
TastytradeError,
2727
set_sign_for,
2828
today_in_new_york,
2929
validate_response,
@@ -94,9 +94,10 @@ class AccountBalance(TastytradeData):
9494
@classmethod
9595
def validate_price_effects(cls, data: Any) -> Any:
9696
if isinstance(data, dict):
97+
data = cast(dict[str, Any], data)
9798
key = "unsettled-cryptocurrency-fiat-amount"
98-
effect = data.get("unsettled-cryptocurrency-fiat-effect")
99-
if effect == PriceEffect.DEBIT:
99+
effect: Any = data.get("unsettled-cryptocurrency-fiat-effect")
100+
if effect == PriceEffect.DEBIT.value:
100101
data[key] = -abs(Decimal(data[key]))
101102
return set_sign_for(data, ["pending_cash", "buying_power_adjustment"])
102103

@@ -149,8 +150,9 @@ class AccountBalanceSnapshot(TastytradeData):
149150
@classmethod
150151
def validate_price_effects(cls, data: Any) -> Any:
151152
if isinstance(data, dict):
153+
data = cast(dict[str, Any], data)
152154
key = "unsettled-cryptocurrency-fiat-amount"
153-
effect = data.get("unsettled-cryptocurrency-fiat-effect")
155+
effect: Any = data.get("unsettled-cryptocurrency-fiat-effect")
154156
if effect == PriceEffect.DEBIT:
155157
data[key] = -abs(Decimal(data[key]))
156158
return set_sign_for(data, ["pending_cash"])
@@ -614,7 +616,8 @@ async def a_get_balance_snapshots(
614616
:param end_date: the ending date of the range.
615617
:param snapshot_date: the date of the snapshot to get.
616618
:param time_of_day:
617-
the time of day of the snapshots to get, either 'EOD' (End Of Day) or 'BOD' (Beginning Of Day).
619+
the time of day of the snapshots to get, either 'EOD' (End Of Day) or 'BOD'
620+
(Beginning Of Day).
618621
"""
619622
paginate = False
620623
if page_offset is None:
@@ -629,7 +632,7 @@ async def a_get_balance_snapshots(
629632
"snapshot-date": snapshot_date,
630633
"time-of-day": time_of_day,
631634
}
632-
snapshots = []
635+
snapshots: list[AccountBalanceSnapshot] = []
633636
while True:
634637
response = await session.async_client.get(
635638
f"/accounts/{self.account_number}/balance-snapshots",
@@ -677,7 +680,8 @@ def get_balance_snapshots(
677680
:param end_date: the ending date of the range.
678681
:param snapshot_date: the date of the snapshot to get.
679682
:param time_of_day:
680-
the time of day of the snapshots to get, either 'EOD' (End Of Day) or 'BOD' (Beginning Of Day).
683+
the time of day of the snapshots to get, either 'EOD' (End Of Day) or 'BOD'
684+
(Beginning Of Day).
681685
"""
682686
paginate = False
683687
if page_offset is None:
@@ -692,7 +696,7 @@ def get_balance_snapshots(
692696
"snapshot-date": snapshot_date,
693697
"time-of-day": time_of_day,
694698
}
695-
snapshots = []
699+
snapshots: list[AccountBalanceSnapshot] = []
696700
while True:
697701
response = session.sync_client.get(
698702
f"/accounts/{self.account_number}/balance-snapshots",
@@ -878,7 +882,7 @@ async def a_get_history(
878882
"end-at": end_at,
879883
}
880884
# loop through pages and get all transactions
881-
txns = []
885+
txns: list[Transaction] = []
882886
while True:
883887
response = await session.async_client.get(
884888
f"/accounts/{self.account_number}/transactions",
@@ -973,7 +977,7 @@ def get_history(
973977
"end-at": end_at,
974978
}
975979
# loop through pages and get all transactions
976-
txns = []
980+
txns: list[Transaction] = []
977981
while True:
978982
response = session.sync_client.get(
979983
f"/accounts/{self.account_number}/transactions",
@@ -1376,7 +1380,7 @@ async def a_get_order_history(
13761380
"end-at": end_at,
13771381
}
13781382
# loop through pages and get all transactions
1379-
orders = []
1383+
orders: list[PlacedOrder] = []
13801384
while True:
13811385
response = await session.async_client.get(
13821386
f"/accounts/{self.account_number}/orders",
@@ -1457,7 +1461,7 @@ def get_order_history(
14571461
"end-at": end_at,
14581462
}
14591463
# loop through pages and get all transactions
1460-
orders = []
1464+
orders: list[PlacedOrder] = []
14611465
while True:
14621466
response = session.sync_client.get(
14631467
f"/accounts/{self.account_number}/orders",
@@ -1500,11 +1504,11 @@ async def a_get_complex_order_history(
15001504
paginate = True
15011505
params = {"per-page": per_page, "page-offset": page_offset}
15021506
# loop through pages and get all transactions
1503-
orders = []
1507+
orders: list[PlacedComplexOrder] = []
15041508
while True:
15051509
response = await session.async_client.get(
15061510
f"/accounts/{self.account_number}/complex-orders",
1507-
params={k: v for k, v in params.items() if v is not None},
1511+
params=params,
15081512
)
15091513
validate_response(response)
15101514
json = response.json()
@@ -1539,11 +1543,11 @@ def get_complex_order_history(
15391543
paginate = True
15401544
params = {"per-page": per_page, "page-offset": page_offset}
15411545
# loop through pages and get all transactions
1542-
orders = []
1546+
orders: list[PlacedComplexOrder] = []
15431547
while True:
15441548
response = session.sync_client.get(
15451549
f"/accounts/{self.account_number}/complex-orders",
1546-
params={k: v for k, v in params.items() if v is not None},
1550+
params=params,
15471551
)
15481552
validate_response(response)
15491553
json = response.json()
@@ -1698,7 +1702,7 @@ async def a_get_order_chains(
16981702
response = await client.get(
16991703
f"{VAST_URL}/order-chains",
17001704
headers=headers,
1701-
params=params,
1705+
params=params, # type: ignore[arg-type]
17021706
)
17031707
validate_response(response)
17041708
chains = response.json()["data"]["items"]
@@ -1736,7 +1740,7 @@ def get_order_chains(
17361740
response = httpx.get(
17371741
f"{VAST_URL}/order-chains",
17381742
headers=headers,
1739-
params=params,
1743+
params=params, # type: ignore[arg-type]
17401744
)
17411745
validate_response(response)
17421746
chains = response.json()["data"]["items"]

tastytrade/dxfeed/candle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from decimal import Decimal
2-
from typing import Annotated, Any, Optional
2+
from typing import Annotated, Any, Optional, cast
33

44
from pydantic import (
55
ValidationError,
@@ -17,7 +17,7 @@ def zero_from_none(
1717
v: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
1818
) -> Decimal:
1919
try:
20-
return handler(v)
20+
return cast(Decimal, handler(v))
2121
except ValidationError:
2222
return ZERO
2323

tastytrade/dxfeed/event.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Any
24

35
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
@@ -33,7 +35,7 @@ def change_nan_to_none(cls, v: Any) -> Any:
3335
return v
3436

3537
@classmethod
36-
def from_stream(cls, data: list) -> list["Event"]:
38+
def from_stream(cls, data: list[Any]) -> list[Event]:
3739
"""
3840
Makes a list of event objects from a list of raw trade data fetched by
3941
a :class:`~tastyworks.streamer.DXFeedStreamer`.
@@ -42,7 +44,7 @@ def from_stream(cls, data: list) -> list["Event"]:
4244
4345
:return: list of event objects from data
4446
"""
45-
objs = []
47+
objs: list[Event] = []
4648
size = len(cls.model_fields)
4749
multiples = len(data) / size
4850
if not multiples.is_integer():

tastytrade/instruments.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class NestedFutureOptionFuture(TastytradeData):
162162

163163
@field_validator("maturity_date", mode="before")
164164
@classmethod
165-
def parse_date_with_utc(cls, value: Any) -> str:
165+
def parse_date_with_utc(cls, value: Any) -> Any:
166166
if isinstance(value, str):
167167
return value.split(" ")[0]
168168
return value
@@ -327,7 +327,7 @@ async def a_get_active_equities(
327327
"lendability": lendability,
328328
}
329329
# loop through pages and get all active equities
330-
equities = []
330+
equities: list[Self] = []
331331
while True:
332332
response = await session.async_client.get(
333333
"/instruments/equities/active",
@@ -378,7 +378,7 @@ def get_active_equities(
378378
"lendability": lendability,
379379
}
380380
# loop through pages and get all active equities
381-
equities = []
381+
equities: list[Self] = []
382382
while True:
383383
response = session.sync_client.get(
384384
"/instruments/equities/active",
@@ -581,10 +581,10 @@ async def a_get(
581581
f"/instruments/equity-options/{symbol}", params=params
582582
)
583583
return cls(**data)
584-
params = {"symbol[]": symbols, "active": active, "with-expired": with_expired}
584+
_params = {"symbol[]": symbols, "active": active, "with-expired": with_expired}
585585
data = await session._a_get(
586586
"/instruments/equity-options",
587-
params={k: v for k, v in params.items() if v is not None},
587+
params={k: v for k, v in _params.items() if v is not None},
588588
)
589589
return [cls(**i) for i in data["items"]]
590590

@@ -628,10 +628,10 @@ def get(
628628
params = {"active": active} if active is not None else None
629629
data = session._get(f"/instruments/equity-options/{symbol}", params=params)
630630
return cls(**data)
631-
params = {"symbol[]": symbols, "active": active, "with-expired": with_expired}
631+
_params = {"symbol[]": symbols, "active": active, "with-expired": with_expired}
632632
data = session._get(
633633
"/instruments/equity-options",
634-
params={k: v for k, v in params.items() if v is not None},
634+
params={k: v for k, v in _params.items() if v is not None},
635635
)
636636
return [cls(**i) for i in data["items"]]
637637

@@ -649,7 +649,7 @@ def _set_streamer_symbol(self) -> None:
649649
)
650650

651651
@classmethod
652-
def streamer_symbol_to_occ(cls, streamer_symbol) -> str:
652+
def streamer_symbol_to_occ(cls, streamer_symbol: str) -> str:
653653
"""
654654
Returns the OCC 2010 symbol equivalent to the given streamer symbol.
655655
@@ -670,7 +670,7 @@ def streamer_symbol_to_occ(cls, streamer_symbol) -> str:
670670
return f"{symbol}{exp}{option_type}{strike}{decimal}"
671671

672672
@classmethod
673-
def occ_to_streamer_symbol(cls, occ) -> str:
673+
def occ_to_streamer_symbol(cls, occ: str) -> str:
674674
"""
675675
Returns the dxfeed symbol for use in the streamer from the given OCC
676676
2010 symbol.
@@ -1100,7 +1100,7 @@ class FutureOption(TradeableTastytradeData):
11001100

11011101
@field_validator("maturity_date", mode="before")
11021102
@classmethod
1103-
def parse_date_with_utc(cls, value: Any) -> str:
1103+
def parse_date_with_utc(cls, value: Any) -> Any:
11041104
if isinstance(value, str):
11051105
return value.split(" ")[0]
11061106
return value
@@ -1386,7 +1386,7 @@ async def a_get_option_chain(session: Session, symbol: str) -> dict[date, list[O
13861386
"""
13871387
symbol = symbol.replace("/", "%2F")
13881388
data = await session._a_get(f"/option-chains/{symbol}")
1389-
chain = defaultdict(list)
1389+
chain: dict[date, list[Option]] = defaultdict(list)
13901390
for i in data["items"]:
13911391
option = Option(**i)
13921392
chain[option.expiration_date].append(option)
@@ -1409,7 +1409,7 @@ def get_option_chain(session: Session, symbol: str) -> dict[date, list[Option]]:
14091409
"""
14101410
symbol = symbol.replace("/", "%2F")
14111411
data = session._get(f"/option-chains/{symbol}")
1412-
chain = defaultdict(list)
1412+
chain: dict[date, list[Option]] = defaultdict(list)
14131413
for i in data["items"]:
14141414
option = Option(**i)
14151415
chain[option.expiration_date].append(option)
@@ -1434,7 +1434,7 @@ async def a_get_future_option_chain(
14341434
"""
14351435
symbol = symbol.replace("/", "")
14361436
data = await session._a_get(f"/futures-option-chains/{symbol}")
1437-
chain = defaultdict(list)
1437+
chain: dict[date, list[FutureOption]] = defaultdict(list)
14381438
for i in data["items"]:
14391439
option = FutureOption(**i)
14401440
chain[option.expiration_date].append(option)
@@ -1459,7 +1459,7 @@ def get_future_option_chain(
14591459
"""
14601460
symbol = symbol.replace("/", "")
14611461
data = session._get(f"/futures-option-chains/{symbol}")
1462-
chain = defaultdict(list)
1462+
chain: dict[date, list[FutureOption]] = defaultdict(list)
14631463
for i in data["items"]:
14641464
option = FutureOption(**i)
14651465
chain[option.expiration_date].append(option)

tastytrade/market_sessions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class MarketSessionSnapshot(TastytradeData):
4545

4646
class MarketSession(TastytradeData):
4747
"""
48-
Dataclass representing the current session as well as the next and previous sessions.
48+
Dataclass representing the current, next, and previous sessions.
4949
"""
5050

5151
close_at: Optional[datetime] = None

0 commit comments

Comments
 (0)