Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ scipy>=1.6.3
curl_cffi>=0.7
protobuf>=5.29.0,<6
websockets>=11.0
Better-Holidays>=0.1.1
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
'platformdirs>=2.0.0', 'pytz>=2022.5',
'frozendict>=2.3.4', 'peewee>=3.16.2',
'beautifulsoup4>=4.11.1', 'curl_cffi>=0.7',
'protobuf>=5.29.0,<6', 'websockets>=11.0'],
'protobuf>=5.29.0,<6', 'websockets>=11.0',
"Better-Holidays>=0.1.1"],
extras_require={
'nospam': ['requests_cache>=1.0', 'requests_ratelimiter>=0.3.1'],
'repair': ['scipy>=1.6.3'],
Expand Down
5 changes: 3 additions & 2 deletions yfinance/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from . import utils, cache
from .data import YfData
from .exceptions import YFEarningsDateMissing, YFRateLimitError
from .live import WebSocket
from .scrapers.analysis import Analysis
from .scrapers.fundamentals import Fundamentals
from .scrapers.holders import Holders
Expand Down Expand Up @@ -99,7 +98,7 @@ def history(self, *args, **kwargs) -> pd.DataFrame:

def _lazy_load_price_history(self):
if self._price_history is None:
self._price_history = PriceHistory(self._data, self.ticker, self._get_ticker_tz(timeout=10))
self._price_history = PriceHistory(self._data, self, self._get_ticker_tz(timeout=10))
return self._price_history

def _get_ticker_tz(self, timeout):
Expand Down Expand Up @@ -799,6 +798,8 @@ def get_funds_data(self, proxy=_SENTINEL_) -> Optional[FundsData]:
return self._funds_data

def live(self, message_handler=None, verbose=True):
from .live import WebSocket

self._message_handler = message_handler

self.ws = WebSocket(verbose=verbose)
Expand Down
3 changes: 3 additions & 0 deletions yfinance/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class YFTzMissingError(YFTickerMissingError):
def __init__(self, ticker):
super().__init__(ticker, "no timezone found")

class YFMarketHoliday(YFTickerMissingError):
def __init__(self, ticker, holiday):
super().__init__(ticker, f"holiday {holiday} is on market holiday")

class YFPricesMissingError(YFTickerMissingError):
def __init__(self, ticker, debug_info):
Expand Down
92 changes: 71 additions & 21 deletions yfinance/live.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import asyncio
import base64
import json
from typing import List, Optional, Callable
from typing import List, Optional, Callable, TypedDict
import datetime as _datetime

import BetterHolidays as bh
from BetterHolidays.days import Holiday
from .ticker import Ticker
from .exceptions import YFMarketHoliday

from websockets.sync.client import connect as sync_connect
from websockets.asyncio.client import connect as async_connect
Expand All @@ -10,32 +16,45 @@
from yfinance.pricing_pb2 import PricingData
from google.protobuf.json_format import MessageToDict

DATA = TypedDict("DATA", {"type": str, "data": dict})

class BaseWebSocket:
def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
def __init__(self, url: 'str' = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
self.url = url
self.verbose = verbose
self.logger = utils.get_yf_logger()
self._ws = None
self._subscriptions = set()
self._subscription_interval = 15 # seconds

def _decode_message(self, base64_message: str) -> dict:
def _decode_message(self, base64_message: 'str') -> 'DATA':
try:
decoded_bytes = base64.b64decode(base64_message)
pricing_data = PricingData()
pricing_data.ParseFromString(decoded_bytes)
return MessageToDict(pricing_data, preserving_proto_field_name=True)
return self._encode_message("received", MessageToDict(pricing_data, preserving_proto_field_name=True))
except Exception as e:
self.logger.error("Failed to decode message: %s", e, exc_info=True)
if self.verbose:
print("Failed to decode message: %s", e)
return {
'error': str(e),
'raw_base64': base64_message
}


return self._encode_message("error", self._encode_error(e, base64_message))

def _encode_message(self, type:'str', data:'dict') -> 'DATA':
if type == "received":
return {"type": "received", "data": data}

elif type == "error":
return {"type": "error", "data": data}

else:
return {"type": "unknown", "data": data}

def _encode_error(self, error:'Exception', raw) -> 'dict[str, str]':
return {
"error": str(error),
"type": str(type(error)),
"raw_base64": raw
}
class AsyncWebSocket(BaseWebSocket):
"""
Asynchronous WebSocket client for streaming real time pricing data.
Expand All @@ -52,6 +71,7 @@ def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", ver
super().__init__(url, verbose)
self._message_handler = None # Callable to handle messages
self._heartbeat_task = None # Task to send heartbeat subscribe
self.messages = []

async def _connect(self):
try:
Expand Down Expand Up @@ -84,17 +104,24 @@ async def _periodic_subscribe(self):
print(f"Error in heartbeat subscription: {e}")
break

async def subscribe(self, symbols: str | List[str]):
async def subscribe(self, symbols: 'str | list[str]'):
"""
Subscribe to a stock symbol or a list of stock symbols.

Args:
symbols (str | List[str]): Stock symbol(s) to subscribe to.
symbols (str | list[str]): Stock symbol(s) to subscribe to.
"""
await self._connect()

if isinstance(symbols, str):
symbols = [symbols]
symbols = symbols.replace(',', ' ').split()

tickers = [Ticker(symbol) for symbol in symbols]

for ticker in tickers:
market = bh.get_market(ticker.history_metadata["fullExchangeName"], None)
if market and isinstance(day := market.day(_datetime.date.today()), Holiday):
self.messages.append(self._encode_message("error", self._encode_error(YFMarketHoliday(ticker, day), None)))

self._subscriptions.update(symbols)

Expand All @@ -109,17 +136,17 @@ async def subscribe(self, symbols: str | List[str]):
if self.verbose:
print(f"Subscribed to symbols: {symbols}")

async def unsubscribe(self, symbols: str | List[str]):
async def unsubscribe(self, symbols: 'str | list[str]'):
"""
Unsubscribe from a stock symbol or a list of stock symbols.

Args:
symbols (str | List[str]): Stock symbol(s) to unsubscribe from.
symbols (str | list[str]): Stock symbol(s) to unsubscribe from.
"""
await self._connect()

if isinstance(symbols, str):
symbols = [symbols]
symbols = symbols.replace(',', ' ').split()

self._subscriptions.difference_update(symbols)

Expand Down Expand Up @@ -150,6 +177,16 @@ async def listen(self, message_handler=None):

while True:
try:
while len(self.messages) > 0:
msg = self.messages.pop(0)
if self._message_handler is not None:
if asyncio.iscoroutinefunction(self._message_handler):
await self._message_handler(msg)
else:
self._message_handler(msg)
else:
print(msg)

async for message in self._ws:
message_json = json.loads(message)
encoded_data = message_json.get("message", "")
Expand Down Expand Up @@ -220,6 +257,7 @@ def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", ver
verbose (bool): Flag to enable or disable print statements. Defaults to True.
"""
super().__init__(url, verbose)
self.messages = []

def _connect(self):
try:
Expand All @@ -245,7 +283,13 @@ def subscribe(self, symbols: str | List[str]):
self._connect()

if isinstance(symbols, str):
symbols = [symbols]
symbols = symbols.replace(',', ' ').split()

tickers = [Ticker(symbol) for symbol in symbols]
for ticker in tickers:
market = bh.get_market(ticker.history_metadata["fullExchangeName"], None)
if market and isinstance(day := market.day(_datetime.date.today()), Holiday):
self.messages.append(self._encode_message("error", self._encode_error(YFMarketHoliday(ticker, day), None)))

self._subscriptions.update(symbols)

Expand All @@ -256,17 +300,17 @@ def subscribe(self, symbols: str | List[str]):
if self.verbose:
print(f"Subscribed to symbols: {symbols}")

def unsubscribe(self, symbols: str | List[str]):
def unsubscribe(self, symbols: 'str | list[str]'):
"""
Unsubscribe from a stock symbol or a list of stock symbols.

Args:
symbols (str | List[str]): Stock symbol(s) to unsubscribe from.
symbols (str | list[str]): Stock symbol(s) to unsubscribe from.
"""
self._connect()

if isinstance(symbols, str):
symbols = [symbols]
symbols = symbols.replace(',', ' ').split()

self._subscriptions.difference_update(symbols)

Expand All @@ -277,7 +321,7 @@ def unsubscribe(self, symbols: str | List[str]):
if self.verbose:
print(f"Unsubscribed from symbols: {symbols}")

def listen(self, message_handler: Optional[Callable[[dict], None]] = None):
def listen(self, message_handler: 'Optional[Callable[[dict], None]]' = None):
"""
Start listening to messages from the WebSocket server.

Expand All @@ -292,6 +336,12 @@ def listen(self, message_handler: Optional[Callable[[dict], None]] = None):

while True:
try:
for msg in self.messages:
if message_handler:
message_handler(msg)
else:
print(msg)

message = self._ws.recv()
message_json = json.loads(message)
encoded_data = message_json.get("message", "")
Expand Down
Loading
Loading