Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ def extract_data(
from openbb_core.provider.utils.helpers import (
to_snake_case,
)
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

period = "yearly" if query.period == "annual" else "quarterly" # type: ignore
data = Ticker(query.symbol).get_balance_sheet(
data = Ticker(normalize_yfinance_symbol(query.symbol)).get_balance_sheet(
as_dict=False, pretty=False, freq=period
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ def extract_data(
from openbb_core.provider.utils.helpers import (
to_snake_case,
)
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

period = "yearly" if query.period == "annual" else "quarterly" # type: ignore
data = Ticker(query.symbol).get_cash_flow(
data = Ticker(normalize_yfinance_symbol(query.symbol)).get_cash_flow(
as_dict=False, pretty=False, freq=period
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def aextract_data(
from warnings import warn

from openbb_core.provider.utils.errors import EmptyDataError
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

symbols = [s.strip() for s in query.symbol.split(",") if s.strip()]
Expand Down Expand Up @@ -122,10 +123,11 @@ def _normalize_news_item(item: dict, sym: str) -> dict | None:

def _fetch_news(sym: str) -> list[dict]:
"""Fetch the data in a worker thread."""
raw = Ticker(sym).get_news() or []
provider_symbol = normalize_yfinance_symbol(sym)
raw = Ticker(provider_symbol).get_news() or []
out: list[dict] = []
for item in raw:
norm = _normalize_news_item(item, sym)
norm = _normalize_news_item(item, sym.upper())
if norm:
out.append(norm)
return out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,11 @@ async def aextract_data(
import asyncio # noqa
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from warnings import warn
from yfinance import Ticker

symbols = query.symbol.split(",")
symbols = [s.strip() for s in query.symbol.split(",") if s.strip()]
results = []
fields = [
"symbol",
Expand Down Expand Up @@ -158,10 +159,14 @@ async def aextract_data(

async def get_one(symbol):
"""Get the data for one ticker symbol."""
requested_symbol = symbol.upper()
provider_symbol = normalize_yfinance_symbol(symbol)
result: dict = {}
ticker: dict = {}
try:
ticker = await asyncio.to_thread(lambda: Ticker(symbol).get_info())
ticker = await asyncio.to_thread(
lambda: Ticker(provider_symbol).get_info()
)
except Exception as e:
messages.append(
f"Error getting data for {symbol} -> {e.__class__.__name__}: {e}"
Expand All @@ -175,6 +180,7 @@ async def get_one(symbol):
)
] = ticker.get(field, None)
if result:
result["symbol"] = requested_symbol
results.append(result)

tasks = [get_one(symbol) for symbol in symbols]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ async def aextract_data(
"""Extract the raw data from YFinance."""
# pylint: disable=import-outside-toplevel
import asyncio # noqa
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

symbols = [s.strip() for s in query.symbol.split(",") if s.strip()]
Expand Down Expand Up @@ -108,14 +109,18 @@ async def aextract_data(
]

async def get_one(symbol: str) -> None:
provider_symbol = normalize_yfinance_symbol(symbol)
try:
ticker = await asyncio.to_thread(lambda: Ticker(symbol).get_info())
ticker = await asyncio.to_thread(
lambda: Ticker(provider_symbol).get_info()
)
except Exception as e:
warn(f"Error getting data for {symbol}: {e}")
return

result = {f: ticker.get(f) for f in fields if f in ticker}
if result:
result["symbol"] = symbol.upper()
results.append(result)

await asyncio.gather(*(get_one(symbol) for symbol in symbols))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ def extract_data(
) -> list[dict]:
"""Extract the raw data from YFinance."""
# pylint: disable=import-outside-toplevel
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

try:
ticker = Ticker(query.symbol).get_dividends()
ticker = Ticker(normalize_yfinance_symbol(query.symbol)).get_dividends()
if isinstance(ticker, list) and not ticker or ticker.empty: # type: ignore
raise OpenBBError(f"No dividend data found for {query.symbol}")
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ def extract_data(
from openbb_core.provider.utils.helpers import (
to_snake_case,
)
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

period = "yearly" if query.period == "annual" else "quarterly"
data = Ticker(query.symbol).get_income_stmt(
data = Ticker(normalize_yfinance_symbol(query.symbol)).get_income_stmt(
as_dict=False, pretty=False, freq=period
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def extract_data(
"""Extract the raw data from YFinance."""
# pylint: disable=import-outside-toplevel
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

try:
ticker = Ticker(query.symbol).get_info()
ticker = Ticker(normalize_yfinance_symbol(query.symbol)).get_info()
except Exception as e:
raise OpenBBError(
f"Error getting data for {query.symbol} -> {e.__class__.__name__}: {e}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,11 @@ async def aextract_data(
import asyncio # noqa
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from warnings import warn
from yfinance import Ticker

symbols = query.symbol.split(",")
symbols = [s.strip() for s in query.symbol.split(",") if s.strip()]
results = []
fields = [
"symbol",
Expand Down Expand Up @@ -292,10 +293,14 @@ async def aextract_data(

async def get_one(symbol):
"""Get the data for one ticker symbol."""
requested_symbol = symbol.upper()
provider_symbol = normalize_yfinance_symbol(symbol)
result: dict = {}
ticker: dict = {}
try:
ticker = await asyncio.to_thread(lambda: Ticker(symbol).get_info())
ticker = await asyncio.to_thread(
lambda: Ticker(provider_symbol).get_info()
)
except Exception as e:
messages.append(
f"Error getting data for {symbol} -> {e.__class__.__name__}: {e}"
Expand All @@ -307,6 +312,7 @@ async def get_one(symbol):
if field in ticker:
result[field] = ticker.get(field, None)
if result and result.get("52WeekChange") is not None:
result["symbol"] = requested_symbol
results.append(result)

tasks = [get_one(symbol) for symbol in symbols]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,18 @@ async def aextract_data(
"""Extract the raw data from YFinance."""
# pylint: disable=import-outside-toplevel
import asyncio # noqa
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from pandas import concat
from yfinance import Ticker
from pytz import timezone

symbol = query.symbol.upper()
symbol = "^" + symbol if symbol in ["VIX", "RUT", "SPX", "NDX"] else symbol
provider_symbol = normalize_yfinance_symbol(symbol)

def _get_all_data(symbol: str):
def _get_all_data(provider_symbol: str):
"""Get all options data in a single thread-safe operation."""
t = Ticker(symbol)
t = Ticker(provider_symbol)
expirations = list(t.options)

if not expirations or len(expirations) == 0:
Expand Down Expand Up @@ -117,7 +119,7 @@ def _get_all_data(symbol: str):
return underlying, chains_output, expirations

underlying, chains_output, expirations = await asyncio.to_thread(
_get_all_data, symbol
_get_all_data, provider_symbol
)

if not expirations or len(expirations) == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ async def aextract_data(
# pylint: disable=import-outside-toplevel
import asyncio # noqa
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from warnings import warn
from yfinance import Ticker

symbols = query.symbol.split(",") # type: ignore
symbols = [s.strip() for s in query.symbol.split(",") if s.strip()] # type: ignore
results = []
fields = [
"symbol",
Expand All @@ -107,10 +108,14 @@ async def aextract_data(

async def get_one(symbol):
"""Get the data for one ticker symbol."""
requested_symbol = symbol.upper()
provider_symbol = normalize_yfinance_symbol(symbol)
result: dict = {}
ticker: dict = {}
try:
ticker = await asyncio.to_thread(lambda: Ticker(symbol).get_info())
ticker = await asyncio.to_thread(
lambda: Ticker(provider_symbol).get_info()
)
except Exception as e:
messages.append(
f"Error getting data for {symbol}: {e.__class__.__name__}: {e}"
Expand All @@ -120,6 +125,7 @@ async def get_one(symbol):
if field in ticker:
result[field] = ticker.get(field, None)
if result and result.get("numberOfAnalystOpinions") is not None:
result["symbol"] = requested_symbol
results.append(result)

tasks = [get_one(symbol) for symbol in symbols]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,10 @@ async def aextract_data(
import asyncio # noqa
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_yfinance.utils.helpers import normalize_yfinance_symbol
from yfinance import Ticker

symbols = query.symbol.split(",")
symbols = [s.strip() for s in query.symbol.split(",") if s.strip()]
results = []
fields = [
"symbol",
Expand All @@ -143,11 +144,13 @@ async def aextract_data(

async def get_one(symbol):
"""Get the data for one ticker symbol."""
requested_symbol = symbol.upper()
provider_symbol = normalize_yfinance_symbol(symbol)
result: dict = {}
ticker: dict = {}
try:
_ticker = await asyncio.to_thread(lambda: Ticker(symbol))
ticker = await asyncio.to_thread(lambda: _ticker.get_info())
_ticker = await asyncio.to_thread(lambda: Ticker(provider_symbol))
ticker = await asyncio.to_thread(_ticker.get_info)
major_holders = await asyncio.to_thread(
lambda: _ticker.get_major_holders(as_dict=True).get("Value")
)
Expand All @@ -162,6 +165,7 @@ async def get_one(symbol):
if field in ticker:
result[field] = ticker.get(field, None)
if result and result.get("sharesOutstanding") is not None:
result["symbol"] = requested_symbol
results.append(result)

tasks = [get_one(symbol) for symbol in symbols]
Expand Down
60 changes: 46 additions & 14 deletions openbb_platform/providers/yfinance/openbb_yfinance/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# pylint: disable=unused-argument,too-many-arguments,too-many-branches,too-many-locals,too-many-statements

import re
from typing import TYPE_CHECKING, Any, Literal, Union

from openbb_core.provider.utils.errors import EmptyDataError
Expand Down Expand Up @@ -74,6 +75,29 @@
"earnings_date",
]

# Keep this conservative; Yahoo also uses one-letter exchange suffixes like ".L".
CLASS_SHARE_SYMBOL_REGEX = re.compile(r"^[A-Z]{1,5}\.[ABC]$")


def normalize_yfinance_symbol(symbol: str) -> str:
"""Convert US class-share symbols to Yahoo's dash convention."""
clean_symbol = symbol.strip()
upper_symbol = clean_symbol.upper()

if CLASS_SHARE_SYMBOL_REGEX.fullmatch(upper_symbol):
return upper_symbol.replace(".", "-")

return clean_symbol


def normalize_yfinance_symbols(symbols: str) -> str:
"""Normalize a comma-separated symbol list for yfinance requests."""
return ",".join(
normalize_yfinance_symbol(symbol)
for symbol in symbols.split(",")
if symbol.strip()
)


async def get_custom_screener(
body: dict[str, Any],
Expand Down Expand Up @@ -220,8 +244,8 @@ async def get_defined_screener(

def get_expiration_month(symbol: str) -> str:
"""Get the expiration month for a given symbol."""
month = symbol.split(".")[0][-3]
year = "20" + symbol.split(".")[0][-2:]
month = symbol.split(".", maxsplit=1)[0][-3]
year = "20" + symbol.split(".", maxsplit=1)[0][-2:]
return f"{year}-{MONTH_MAP[month]}"


Expand Down Expand Up @@ -500,7 +524,13 @@ def yf_download( # pylint: disable=too-many-positional-arguments
from pandas import DataFrame, concat, to_datetime
import yfinance as yf

symbol = symbol.upper()
requested_tickers = [
ticker.strip().upper() for ticker in symbol.split(",") if ticker.strip()
]
provider_tickers = [
normalize_yfinance_symbol(ticker) for ticker in requested_tickers
]
provider_symbols = ",".join(provider_tickers)
_start_date = start_date
intraday = False
if interval in ["60m", "1h"]:
Expand Down Expand Up @@ -529,7 +559,7 @@ def yf_download( # pylint: disable=too-many-positional-arguments

try:
data = yf.download(
tickers=symbol,
tickers=provider_symbols,
start=_start_date,
end=None,
interval=interval,
Expand All @@ -550,22 +580,24 @@ def yf_download( # pylint: disable=too-many-positional-arguments
except ValueError as exc:
raise EmptyDataError() from exc

tickers = symbol.split(",")
if len(tickers) == 1:
if len(provider_tickers) == 1:
provider_symbol = provider_tickers[0]
if hasattr(data.columns, "levels"):
try:
if symbol in data.columns.get_level_values(0):
data = data[symbol] # type: ignore
elif symbol in data.columns.get_level_values(1):
data = data.xs(symbol, level=1, axis=1) # type: ignore
if provider_symbol in data.columns.get_level_values(0):
data = data[provider_symbol] # type: ignore
elif provider_symbol in data.columns.get_level_values(1):
data = data.xs(provider_symbol, level=1, axis=1) # type: ignore
except (KeyError, IndexError):
pass
elif len(tickers) > 1:
elif len(provider_tickers) > 1:
_data = DataFrame()
for ticker in tickers:
temp = data[ticker].copy().dropna(how="all") # type: ignore
for requested_ticker, provider_ticker in zip(
requested_tickers, provider_tickers
):
temp = data[provider_ticker].copy().dropna(how="all") # type: ignore
if len(temp) > 0:
temp["symbol"] = ticker
temp["symbol"] = requested_ticker
temp = temp.reset_index().rename(
columns={"Date": "date", "Datetime": "date", "index": "date"}
)
Expand Down
Loading