Skip to content

Commit 1f5b976

Browse files
author
finanalyzer
authored
Merge pull request #6 from shugaoye/master
Enabled cache
2 parents d7f6e08 + 69fad67 commit 1f5b976

File tree

10 files changed

+1522
-381
lines changed

10 files changed

+1522
-381
lines changed

docs/akshare_test.ipynb

Lines changed: 1165 additions & 346 deletions
Large diffs are not rendered by default.

docs/models.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,24 @@ Please refer to the following table. AKShare only supports part of the Hong Kong
1111
| AKShareEquityQuoteData | x | x |
1212
| AKShareHistoricalDividendsData | x | x |
1313

14+
## BlobCache
15+
16+
```
17+
CREATE TABLE IF NOT EXISTS {self.table_name} (
18+
key TEXT PRIMARY KEY,
19+
timestamp REAL,
20+
type TEXT,
21+
data BLOB
22+
)
23+
```
24+
25+
To cache the income statement, balance sheet and cash flow data, the table name and fields are defined as follows:
26+
27+
- `table_name`: one of `income_statement`, `balance_sheet` or `cash_flow`
28+
- `key`: {market}{symbol}{type}, such as `SH600325annual` or `SH600325quarter`
29+
- `type`: "annual", "quarter"
30+
- `timestamp`: the time at which the cache entry was created; it is used to compute when the entry expires and the data must be fetched again.
31+
1432
## EquityHistorical
1533

1634
To use the cache, follow the steps below.

docs/models.xlsx

-11 Bytes
Binary file not shown.

openbb_akshare/models/balance_sheet.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""AKShare Balance Sheet Model."""
22

33
# pylint: disable=unused-argument
4-
4+
import pandas as pd
55
from datetime import datetime
66
from typing import Any, Literal, Optional
77

@@ -35,6 +35,10 @@ class AKShareBalanceSheetQueryParams(BalanceSheetQueryParams):
3535
description=QUERY_DESCRIPTIONS.get("limit", ""),
3636
le=5,
3737
)
38+
use_cache: bool = Field(
39+
default=True,
40+
description="Whether to use a cached request. The quote is cached for one hour.",
41+
)
3842

3943

4044
class AKShareBalanceSheetData(BalanceSheetData):
@@ -43,7 +47,6 @@ class AKShareBalanceSheetData(BalanceSheetData):
4347
__alias_dict__ = {
4448
"period_ending": "REPORT_DATE",
4549
"fiscal_period": "REPORT_TYPE",
46-
"fiscal_year": "REPORT_DATE_NAME",
4750
"totalEquity": "TOTAL_EQUITY",
4851
"totalDebt": "TOTAL_LIABILITIES",
4952
"totalAssets": "TOTAL_ASSETS"
@@ -79,17 +82,9 @@ def extract_data(
7982
) -> list[dict]:
8083
"""Extract the data from the AKShare endpoints."""
8184
# pylint: disable=import-outside-toplevel
82-
import akshare as ak
83-
import pandas as pd
84-
from openbb_akshare.utils.tools import normalize_symbol
85+
em_df = get_data(query.symbol, query.period, query.use_cache)
8586

86-
symbol_b, symbol_f, market = normalize_symbol(query.symbol)
87-
symbol_em = f"SH{symbol_b}"
88-
stock_balance_sheet_by_yearly_em_df = ak.stock_balance_sheet_by_yearly_em(symbol=symbol_em)
89-
balance_sheet_em = stock_balance_sheet_by_yearly_em_df[["REPORT_DATE", "REPORT_TYPE", "REPORT_DATE_NAME", "TOTAL_ASSETS", "TOTAL_LIABILITIES", "TOTAL_EQUITY"]]
90-
balance_sheet_em['REPORT_DATE_NAME'] = pd.to_datetime(balance_sheet_em['REPORT_DATE']).dt.year.astype(int)
91-
92-
return balance_sheet_em.to_dict(orient="records")
87+
return em_df.to_dict(orient="records")
9388

9489
@staticmethod
9590
def transform_data(
@@ -99,3 +94,24 @@ def transform_data(
9994
) -> list[AKShareBalanceSheetData]:
10095
"""Transform the data."""
10196
return [AKShareBalanceSheetData.model_validate(d) for d in data]
97+
98+
def get_data(symbol: str, period: str = "annual", use_cache: bool = True) -> pd.DataFrame:
99+
if use_cache:
100+
from openbb_akshare.utils.blob_cache import BlobCache
101+
cache = BlobCache(table_name="balance_sheet")
102+
return cache.load_cached_data(symbol, period, get_ak_data)
103+
else:
104+
return get_ak_data(symbol, period)
105+
106+
def get_ak_data(symbol: str, period: str = "annual") -> pd.DataFrame:
107+
import akshare as ak
108+
from openbb_akshare.utils.helpers import normalize_symbol
109+
symbol_b, symbol_f, market = normalize_symbol(symbol)
110+
symbol_em = f"{market}{symbol_b}"
111+
112+
if period == "annual":
113+
return ak.stock_balance_sheet_by_yearly_em(symbol=symbol_em)
114+
elif period == "quarter":
115+
return ak.stock_balance_sheet_by_report_em(symbol=symbol_em)
116+
else:
117+
raise ValueError("Invalid period. Please use 'annual' or 'quarter'.")

openbb_akshare/models/cash_flow.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""AKShare Cash Flow Statement Model."""
22

33
# pylint: disable=unused-argument
4-
4+
import pandas as pd
55
from datetime import datetime
66
from typing import Any, Literal, Optional
77

@@ -35,6 +35,10 @@ class AKShareCashFlowStatementQueryParams(CashFlowStatementQueryParams):
3535
description=QUERY_DESCRIPTIONS.get("limit", ""),
3636
le=5,
3737
)
38+
use_cache: bool = Field(
39+
default=True,
40+
description="Whether to use a cached request. The quote is cached for one hour.",
41+
)
3842

3943

4044
class AKShareCashFlowStatementData(CashFlowStatementData):
@@ -76,19 +80,9 @@ def extract_data(
7680
) -> list[dict]:
7781
"""Extract the data from the AKShare endpoints."""
7882
# pylint: disable=import-outside-toplevel
79-
import akshare as ak
80-
import pandas as pd
81-
from openbb_akshare.utils.tools import normalize_symbol
83+
em_df = get_data(query.symbol, query.period, query.use_cache)
8284

83-
symbol_b, symbol_f, market = normalize_symbol(query.symbol)
84-
symbol_em = f"SH{symbol_b}"
85-
stock_cash_flow_sheet_by_yearly_em_df = ak.stock_cash_flow_sheet_by_yearly_em(symbol=symbol_em)
86-
cash_flow_em = stock_cash_flow_sheet_by_yearly_em_df[["REPORT_DATE", "REPORT_TYPE", "CURRENCY",
87-
"TOTAL_OPERATE_INFLOW", "TOTAL_OPERATE_OUTFLOW", "NETCASH_OPERATE",
88-
"TOTAL_INVEST_INFLOW", "TOTAL_INVEST_OUTFLOW", "NETCASH_INVEST",
89-
"TOTAL_FINANCE_INFLOW", "TOTAL_FINANCE_OUTFLOW","NETCASH_FINANCE", "NETPROFIT"]]
90-
91-
return cash_flow_em.to_dict(orient="records")
85+
return em_df.to_dict(orient="records")
9286

9387
@staticmethod
9488
def transform_data(
@@ -98,3 +92,24 @@ def transform_data(
9892
) -> list[AKShareCashFlowStatementData]:
9993
"""Transform the data."""
10094
return [AKShareCashFlowStatementData.model_validate(d) for d in data]
95+
96+
def get_data(symbol: str, period: str = "annual", use_cache: bool = True) -> pd.DataFrame:
97+
if use_cache:
98+
from openbb_akshare.utils.blob_cache import BlobCache
99+
cache = BlobCache(table_name="cash_flow")
100+
return cache.load_cached_data(symbol, period, get_ak_data)
101+
else:
102+
return get_ak_data(symbol, period)
103+
104+
def get_ak_data(symbol: str, period: str = "annual") -> pd.DataFrame:
105+
import akshare as ak
106+
from openbb_akshare.utils.helpers import normalize_symbol
107+
symbol_b, symbol_f, market = normalize_symbol(symbol)
108+
symbol_em = f"{market}{symbol_b}"
109+
110+
if period == "annual":
111+
return ak.stock_cash_flow_sheet_by_yearly_em(symbol=symbol_em)
112+
elif period == "quarter":
113+
return ak.stock_cash_flow_sheet_by_report_em(symbol=symbol_em)
114+
else:
115+
raise ValueError("Invalid period. Please use 'annual' or 'quarter'.")

openbb_akshare/models/income_statement.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""AKShare Income Statement Model."""
22

33
# pylint: disable=unused-argument
4+
import pandas as pd
45
from datetime import (
56
date as dateType,
67
datetime,
@@ -32,6 +33,10 @@ class AKShareIncomeStatementQueryParams(IncomeStatementQueryParams):
3233
default="annual",
3334
description=QUERY_DESCRIPTIONS.get("period", ""),
3435
)
36+
use_cache: bool = Field(
37+
default=True,
38+
description="Whether to use a cached request. The quote is cached for one hour.",
39+
)
3540

3641

3742
class AKShareIncomeStatementData(IncomeStatementData):
@@ -98,21 +103,16 @@ def transform_query(params: Dict[str, Any]) -> AKShareIncomeStatementQueryParams
98103
return AKShareIncomeStatementQueryParams(**params)
99104

100105
@staticmethod
101-
async def aextract_data(
106+
async def extract_data(
102107
query: AKShareIncomeStatementQueryParams,
103108
credentials: Optional[Dict[str, str]],
104109
**kwargs: Any,
105110
) -> List[Dict]:
106111
"""Return the raw data from the AKShare endpoint."""
107-
import akshare as ak
108-
from openbb_akshare.utils.tools import normalize_symbol
109112

110-
symbol_b, symbol_f, market = normalize_symbol(query.symbol)
111-
symbol_em = f"SH{symbol_b}"
112-
stock_profit_sheet_by_yearly_em_df = ak.stock_profit_sheet_by_yearly_em(symbol=symbol_em)
113-
#income_statement_em = stock_profit_sheet_by_yearly_em_df[["REPORT_DATE", "REPORT_TYPE", "CURRENCY", "TOTAL_OPERATE_COST", "OPERATE_INCOME", "TOTAL_PROFIT", "INCOME_TAX", "NETPROFIT", "BASIC_EPS", "DILUTED_EPS"]]
113+
em_df = get_data(query.symbol, query.period, query.use_cache)
114114

115-
return stock_profit_sheet_by_yearly_em_df.to_dict(orient="records")
115+
return em_df.to_dict(orient="records")
116116

117117
@staticmethod
118118
def transform_data(
@@ -123,3 +123,24 @@ def transform_data(
123123
result.pop("symbol", None)
124124
result.pop("cik", None)
125125
return [AKShareIncomeStatementData.model_validate(d) for d in data]
126+
127+
def get_data(symbol: str, period: str = "annual", use_cache: bool = True) -> pd.DataFrame:
128+
if use_cache:
129+
from openbb_akshare.utils.blob_cache import BlobCache
130+
cache = BlobCache(table_name="income_statement")
131+
return cache.load_cached_data(symbol, period, get_income_statement)
132+
else:
133+
return get_income_statement(symbol, period)
134+
135+
def get_income_statement(symbol: str, period: str = "annual") -> pd.DataFrame:
136+
import akshare as ak
137+
from openbb_akshare.utils.helpers import normalize_symbol
138+
symbol_b, symbol_f, market = normalize_symbol(symbol)
139+
symbol_em = f"{market}{symbol_b}"
140+
141+
if period == "annual":
142+
return ak.stock_profit_sheet_by_yearly_em(symbol=symbol_em)
143+
elif period == "quarter":
144+
return ak.stock_profit_sheet_by_report_em(symbol=symbol_em)
145+
else:
146+
raise ValueError("Invalid period. Please use 'annual' or 'quarter'.")

openbb_akshare/utils/blob_cache.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import os
2+
import sqlite3
3+
import pandas as pd
4+
import time
5+
import pickle
6+
import logging
7+
from typing import Optional, List, Dict, Any
8+
from datetime import datetime, timedelta
9+
from openbb_akshare.utils.tools import setup_logger
10+
11+
CACHE_TTL = 60*60 # 3600 seconds (one hour)
12+
setup_logger()
13+
logger = logging.getLogger(__name__)
14+
15+
# Constant TTL strategy
16+
def constant_ttl(now: datetime, ttl_seconds: int) -> datetime:
17+
return now + timedelta(seconds=ttl_seconds)
18+
19+
# Quarter-based expiry (each quarter is 3 months)
20+
def get_next_quarter_start(dt: datetime) -> datetime:
21+
month = ((dt.month - 1) // 3 + 1) * 3 + 1
22+
if month > 12:
23+
return datetime(dt.year + 1, 1, 1)
24+
return datetime(dt.year, month, 1)
25+
26+
# Year-based expiry
27+
def get_next_year_start(dt: datetime) -> datetime:
28+
return datetime(dt.year + 1, 1, 1)
29+
30+
def calculate_cache_ttl(ttl_strategy_func, *args, now=None):
31+
"""
32+
Generic function to calculate cache TTL using a strategy function.
33+
34+
Args:
35+
ttl_strategy_func: A function that calculates the TTL end time.
36+
*args: Arguments for the strategy function.
37+
now: Optional current time (for testing or simulation).
38+
39+
Returns:
40+
The calculated TTL expiry time.
41+
"""
42+
now = now or datetime.now()
43+
return ttl_strategy_func(now, *args)
44+
45+
class BlobCache:
46+
def __init__(self, table_name: Optional[str] = None, db_path: Optional[str] = None):
47+
if table_name is None:
48+
raise ValueError("Table name must be provided")
49+
50+
self.table_name = table_name
51+
self.conn = None
52+
if db_path is None:
53+
from openbb_akshare.utils import get_cache_path
54+
self.db_path = get_cache_path()
55+
else:
56+
os.makedirs(db_path, exist_ok=True)
57+
db_path = f"{db_path}/equity.db"
58+
self.db_path = db_path
59+
self._ensure_db_exists()
60+
61+
def _ensure_db_exists(self):
62+
"""Ensure the SQLite database and table exist."""
63+
with sqlite3.connect(self.db_path) as conn:
64+
cursor = conn.cursor()
65+
cursor.execute(f'''
66+
CREATE TABLE IF NOT EXISTS {self.table_name} (
67+
key TEXT PRIMARY KEY,
68+
timestamp REAL,
69+
data BLOB
70+
)
71+
''')
72+
conn.commit()
73+
74+
def load_cached_data(self, symbol:str, report_type, get_data, *args, **kwargs):
75+
"""Load cached data from SQLite cache or generate new data."""
76+
from openbb_akshare.utils.tools import normalize_symbol
77+
symbol_b, symbol_f, market = normalize_symbol(symbol)
78+
key = f"{market}{symbol_b}{report_type}"
79+
now = time.time()
80+
with sqlite3.connect(self.db_path) as conn:
81+
cursor = conn.cursor()
82+
cursor.execute(f'SELECT timestamp, data FROM {self.table_name} WHERE key=?', (key,))
83+
row = cursor.fetchone()
84+
85+
if row:
86+
timestamp, data_blob = row
87+
stored_date = datetime.fromtimestamp(timestamp)
88+
if report_type == "annual":
89+
expired_date = calculate_cache_ttl(get_next_year_start, now=stored_date)
90+
if now < expired_date.timestamp():
91+
logger.info("Loading annual data from SQLite cache...")
92+
return pickle.loads(data_blob)
93+
elif report_type == "quarter":
94+
expired_date = calculate_cache_ttl(get_next_quarter_start, now=stored_date)
95+
if now < expired_date.timestamp():
96+
logger.info("Loading quarter data from SQLite cache...")
97+
return pickle.loads(data_blob)
98+
else:
99+
if now - timestamp < CACHE_TTL:
100+
logger.info("Loading data from SQLite cache...")
101+
return pickle.loads(data_blob)
102+
103+
logger.info(f"Generating new {report_type} data...")
104+
df = get_data(symbol, report_type)
105+
106+
# Serialize the DataFrame
107+
data_blob = pickle.dumps(df)
108+
109+
# Update or insert the cache entry
110+
cursor.execute(f'''
111+
INSERT OR REPLACE INTO {self.table_name} (key, timestamp, data)
112+
VALUES (?, ?, ?)
113+
''', (key, now, data_blob))
114+
115+
conn.commit()
116+
return df

openbb_akshare/utils/tools.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
# support logging
77
import logging
88
from logging.handlers import RotatingFileHandler
9+
from openbb_core.app.utils import get_user_cache_directory
910

1011
# Configure logging
1112
def setup_logger():
1213
# Create logs directory if it doesn't exist
13-
log_dir = "logs"
14+
log_dir = f"{get_user_cache_directory()}/akshare/logs"
1415
log_file = os.path.join(log_dir, "openbb_akshare.log")
1516
os.makedirs(log_dir, exist_ok=True)
1617

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "openbb-akshare"
3-
version = "0.4.39"
3+
version = "0.4.43"
44
description = "AKShare extension for OpenBB"
55
authors = ["Roger Ye <shugaoye@yahoo.com>"]
66
license = "AGPL-3.0-only"

0 commit comments

Comments
 (0)