Skip to content

Commit 4a36300

Browse files
committed
account streamer fix, add oauth session
1 parent 0f54f33 commit 4a36300

File tree

3 files changed

+149
-10
lines changed

3 files changed

+149
-10
lines changed

tastytrade/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
BACKTEST_URL = "https://backtester.vast.tastyworks.com"
55
CERT_URL = "https://api.cert.tastyworks.com"
66
VAST_URL = "https://vast.tastyworks.com"
7-
VERSION = "9.10"
7+
VERSION = "9.11"
88

99
__version__ = VERSION
1010

tastytrade/session.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import time
12
from datetime import date, datetime
23
from typing import Any, Optional, Union
34
from typing_extensions import Self
45

56
import json
67
from httpx import AsyncClient, Client
78

8-
from tastytrade import API_URL, CERT_URL
9+
from tastytrade import API_URL, CERT_URL, logger
910
from tastytrade.utils import (
1011
TastytradeError,
1112
TastytradeJsonDataclass,
@@ -487,3 +488,137 @@ def deserialize(cls, serialized: str) -> Self:
487488
base_url=base_url, headers=headers, proxy=self.proxy
488489
)
489490
return self
491+
492+
493+
def _now_ms() -> int:
494+
return round(time.time() * 1000)
495+
496+
497+
class OAuthSession(Session): # pragma: no cover
498+
"""
499+
Contains a managed user login which can then be used to interact with the
500+
remote API.
501+
502+
Note that OAuth sessions can't be used to create streamers!
503+
504+
:param provider_secret: OAuth secret for your provider
505+
:param refresh_token: refresh token for the user
506+
:param is_test:
507+
whether to use the test API endpoints, default False
508+
:param proxy:
509+
if provided, all requests will be made through this proxy, as well as
510+
web socket connections for streamers.
511+
"""
512+
513+
def __init__(
514+
self,
515+
provider_secret: str,
516+
refresh_token: str,
517+
is_test: bool = False,
518+
proxy: Optional[str] = None,
519+
):
520+
#: Whether this is a cert or real session
521+
self.is_test = is_test
522+
#: Proxy URL to use for requests and web sockets
523+
self.proxy = proxy
524+
#: OAuth secret for your provider
525+
self.provider_secret = provider_secret
526+
#: Refresh token for the user
527+
self.refresh_token = refresh_token
528+
#: Unix timestamp for when the session token expires
529+
self.expires_at = 0
530+
# The headers to use for API requests
531+
headers = {
532+
"Accept": "application/json",
533+
"Content-Type": "application/json",
534+
}
535+
#: httpx client for sync requests
536+
self.sync_client = Client(
537+
base_url=(CERT_URL if is_test else API_URL), headers=headers, proxy=proxy
538+
)
539+
#: httpx client for async requests
540+
self.async_client = AsyncClient(
541+
base_url=self.sync_client.base_url,
542+
headers=self.sync_client.headers.copy(),
543+
proxy=proxy,
544+
)
545+
self.refresh()
546+
547+
def refresh(self) -> None:
548+
"""
549+
Refreshes the acccess token using the stored refresh token.
550+
"""
551+
response = self.sync_client.post(
552+
"/oauth/token",
553+
json={
554+
"grant_type": "refresh_token",
555+
"client_secret": self.provider_secret,
556+
"refresh_token": self.refresh_token,
557+
},
558+
)
559+
validate_response(response)
560+
data = response.json()
561+
# update the relevant tokens
562+
self.session_token = data["access_token"]
563+
token_lifetime = data.get("expires_in", 900) * 1000
564+
self.expires_at = _now_ms() + token_lifetime
565+
logger.debug(f"Refreshed token, expires in {token_lifetime}ms")
566+
auth_headers = {"Authorization": f"Bearer {self.session_token}"}
567+
# update the httpx clients with the new token
568+
self.sync_client.headers.update(auth_headers)
569+
self.async_client.headers.update(auth_headers)
570+
571+
async def a_refresh(self) -> None:
572+
"""
573+
Refreshes the acccess token using the stored refresh token.
574+
"""
575+
response = await self.async_client.post(
576+
"/oauth/token",
577+
json={
578+
"grant_type": "refresh_token",
579+
"client_secret": self.provider_secret,
580+
"refresh_token": self.refresh_token,
581+
},
582+
)
583+
validate_response(response)
584+
data = response.json()
585+
# update the relevant tokens
586+
self.session_token = data["access_token"]
587+
token_lifetime = data.get("expires_in", 900) * 1000
588+
self.expires_at = _now_ms() + token_lifetime
589+
logger.debug(f"Refreshed token, expires in {token_lifetime}ms")
590+
auth_headers = {"Authorization": f"Bearer {self.session_token}"}
591+
# update the httpx clients with the new token
592+
self.sync_client.headers.update(auth_headers)
593+
self.async_client.headers.update(auth_headers)
594+
595+
def serialize(self) -> str:
596+
"""
597+
Serializes the session to a string, useful for storing
598+
a session for later use.
599+
Could be used with pickle, Redis, etc.
600+
"""
601+
attrs = self.__dict__.copy()
602+
del attrs["async_client"]
603+
del attrs["sync_client"]
604+
return json.dumps(attrs)
605+
606+
@classmethod
607+
def deserialize(cls, serialized: str) -> Self:
608+
"""
609+
Create a new Session object from a serialized string.
610+
"""
611+
deserialized = json.loads(serialized)
612+
self = cls.__new__(cls)
613+
self.__dict__ = deserialized
614+
base_url = CERT_URL if self.is_test else API_URL
615+
headers = {
616+
"Accept": "application/json",
617+
"Content-Type": "application/json",
618+
"Authorization": self.session_token,
619+
}
620+
self.sync_client = Client(base_url=base_url, headers=headers, proxy=self.proxy)
621+
self.async_client = AsyncClient(
622+
base_url=base_url, headers=headers, proxy=self.proxy
623+
)
624+
return self

tastytrade/streamer.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class SubscriptionType(str, Enum):
105105
for the alert streamer.
106106
"""
107107

108-
ACCOUNT = "account-subscribe" # may be 'connect' in the future
108+
ACCOUNT = "connect"
109109
HEARTBEAT = "heartbeat"
110110
PUBLIC_WATCHLISTS = "public-watchlists-subscribe"
111111
QUOTE_ALERTS = "quote-alerts-subscribe"
@@ -212,6 +212,8 @@ def __init__(
212212
self.reconnect_args = reconnect_args
213213
#: The proxy URL, if any, associated with the session
214214
self.proxy = session.proxy
215+
#: Counter used to track the request ID for the streamer
216+
self.request_id = 0
215217

216218
self._queues: dict[str, Queue] = defaultdict(Queue)
217219
self._websocket: Optional[ClientConnection] = None
@@ -252,11 +254,8 @@ async def _connect(self) -> None:
252254
Connect to the websocket server using the URL and authorization
253255
token provided during initialization.
254256
"""
255-
headers = {"Authorization": f"Bearer {self.token}"}
256257
reconnecting = False
257-
async for websocket in connect(
258-
self.base_url, additional_headers=headers, proxy=self.proxy
259-
):
258+
async for websocket in connect(self.base_url, proxy=self.proxy):
260259
self._websocket = websocket
261260
self._heartbeat_task = asyncio.create_task(self._heartbeat())
262261
logger.debug("Websocket connection established.")
@@ -342,8 +341,8 @@ async def _heartbeat(self) -> None:
342341
try:
343342
while True:
344343
await self._subscribe(SubscriptionType.HEARTBEAT, "")
345-
# send the heartbeat every 10 seconds
346-
await asyncio.sleep(10)
344+
# send the heartbeat every 15 seconds
345+
await asyncio.sleep(15)
347346
except asyncio.CancelledError:
348347
logger.debug("Websocket interrupted, cancelling heartbeat.")
349348
return
@@ -357,7 +356,12 @@ async def _subscribe(
357356
Subscribes to a :class:`SubscriptionType`. Depending on the kind of
358357
subscription, the value parameter may be required.
359358
"""
360-
message: dict[str, Any] = {"auth-token": self.token, "action": subscription}
359+
self.request_id += 1
360+
message: dict[str, Any] = {
361+
"auth-token": self.token,
362+
"action": subscription.value,
363+
"request-id": self.request_id,
364+
}
361365
if value:
362366
message["value"] = value
363367
logger.debug("sending alert subscription: %s", message)

0 commit comments

Comments
 (0)