Skip to content

Commit 23ab831

Browse files
ok Merge branch 'main' of github.com:modelscope/ms-agent into release/1.5
2 parents 4ad0d5a + 2b863ce commit 23ab831

3 files changed

Lines changed: 68 additions & 50 deletions

File tree

ms_agent/tools/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
from .code import CodeExecutionTool, SandboxManagerFactory
33
from .filesystem_tool import FileSystemTool
4-
from .findata import FinancialDataFetcher
54
from .mcp_client import MCPClient
65
from .split_task import SplitTask
76
from .tool_manager import ToolManager

ms_agent/tools/findata/akshare_source.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import re
33
from typing import Any, Dict, List, Optional
44

5-
import akshare as ak
65
import pandas as pd
76
from ms_agent.tools.findata.data_source_base import (DataSourceError,
87
FinancialDataSource,
98
NoDataFoundError)
109
from ms_agent.utils import get_logger
10+
from ms_agent.utils.utils import install_package
1111

1212
logger = get_logger()
1313

@@ -25,10 +25,19 @@ class AKShareDataSource(FinancialDataSource):
2525
"""
2626

2727
def __init__(self):
28+
logger.info('Installing AKShare package...')
29+
try:
30+
install_package(package_name='akshare')
31+
except Exception as e:
32+
raise e
33+
34+
global akshare
35+
import akshare
36+
2837
logger.info('Initializing AKShare data source')
2938
try:
3039
# Test AKShare availability
31-
ak.tool_trade_date_hist_sina()
40+
akshare.tool_trade_date_hist_sina()
3241
logger.info('AKShare initialized successfully')
3342
except Exception as e:
3443
raise DataSourceError(f'Failed to initialize AKShare: {e}')
@@ -141,23 +150,23 @@ def get_historical_k_data(
141150
if code.startswith('sh.') or code.startswith(
142151
'sz.') or code.startswith('bj'):
143152
clean_code = self._convert_code(code, market='A')
144-
df = ak.stock_zh_a_hist(
153+
df = akshare.stock_zh_a_hist(
145154
symbol=clean_code,
146155
period=period,
147156
start_date=st_date,
148157
end_date=ed_date,
149158
adjust=adjust)
150159
elif code.startswith('hk'):
151160
clean_code = self._convert_code(code, market='HK')
152-
df = ak.stock_hk_hist(
161+
df = akshare.stock_hk_hist(
153162
symbol=clean_code,
154163
period=period,
155164
start_date=st_date,
156165
end_date=ed_date,
157166
adjust=adjust)
158167
else:
159168
clean_code = self._convert_code(code, market='US')
160-
df = ak.stock_us_hist(
169+
df = akshare.stock_us_hist(
161170
symbol=clean_code,
162171
period=period,
163172
start_date=st_date,
@@ -204,7 +213,7 @@ def _get_hk_basic_info(self, code: str) -> pd.DataFrame:
204213

205214
# Try to fetch base info
206215
try:
207-
df_base_info = ak.stock_hk_spot_em()
216+
df_base_info = akshare.stock_hk_spot_em()
208217
stock_info = df_base_info[df_base_info['代码'] == clean_code]
209218
if not stock_info.empty:
210219
df_stock_info = pd.DataFrame({
@@ -220,7 +229,7 @@ def _get_hk_basic_info(self, code: str) -> pd.DataFrame:
220229

221230
# Try to fetch business info
222231
try:
223-
df_business_info = ak.stock_zyjs_ths(symbol=clean_code)
232+
df_business_info = akshare.stock_zyjs_ths(symbol=clean_code)
224233
if not df_business_info.empty:
225234
df_business_info = df_business_info.rename(
226235
columns={
@@ -256,7 +265,7 @@ def _get_us_basic_info(self, code: str) -> pd.DataFrame:
256265
symbol = self._convert_code(code, 'US')
257266

258267
try:
259-
df = ak.stock_us_spot_em()
268+
df = akshare.stock_us_spot_em()
260269
stock_info = df[df['代码'] == symbol]
261270

262271
if stock_info.empty:
@@ -283,7 +292,7 @@ def _get_a_share_basic_info(self, code: str) -> pd.DataFrame:
283292
clean_code = self._convert_code(code, 'A')
284293

285294
try:
286-
df_base_info = ak.stock_individual_info_em(symbol=clean_code)
295+
df_base_info = akshare.stock_individual_info_em(symbol=clean_code)
287296

288297
if df_base_info.empty:
289298
raise NoDataFoundError(f'No basic info found for {code}')
@@ -307,7 +316,7 @@ def _get_a_share_basic_info(self, code: str) -> pd.DataFrame:
307316
'status': ['1']
308317
})
309318

310-
df_business_info = ak.stock_zyjs_ths(symbol=clean_code)
319+
df_business_info = akshare.stock_zyjs_ths(symbol=clean_code)
311320
if df_business_info.empty:
312321
raise NoDataFoundError(f'No business info found for {code}')
313322

@@ -479,8 +488,9 @@ def _filter_columns(row_df: pd.DataFrame,
479488
ind_df = pd.DataFrame()
480489
if code.startswith(('hk.', 'us.')):
481490
try:
482-
ind_df = ak.stock_financial_hk_analysis_indicator_em(symbol=clean_code) if code.startswith('hk.') else \
483-
ak.stock_financial_us_analysis_indicator_em(symbol=clean_code)
491+
ind_df = akshare.stock_financial_hk_analysis_indicator_em(
492+
symbol=clean_code) if code.startswith('hk.') else \
493+
akshare.stock_financial_us_analysis_indicator_em(symbol=clean_code)
484494
ind_df = _select_row_by_report(ind_df)
485495
except Exception as e:
486496
logger.warning(
@@ -495,7 +505,7 @@ def _filter_columns(row_df: pd.DataFrame,
495505

496506
if needs_indicator:
497507
try:
498-
ind_df = ak.stock_financial_analysis_indicator(
508+
ind_df = akshare.stock_financial_analysis_indicator(
499509
symbol=clean_code)
500510
ind_df = _select_row_by_report(ind_df)
501511
except Exception as e:
@@ -517,14 +527,14 @@ def _filter_columns(row_df: pd.DataFrame,
517527
continue
518528

519529
elif data_type == 'balance':
520-
df = ak.stock_balance_sheet_by_report_em(
530+
df = akshare.stock_balance_sheet_by_report_em(
521531
symbol=code.replace('.', '').upper())
522532
row = _select_row_by_report(df)
523533
if not row.empty:
524534
result[data_type] = row
525535

526536
elif data_type == 'cash_flow':
527-
df = ak.stock_cash_flow_sheet_by_report_em(
537+
df = akshare.stock_cash_flow_sheet_by_report_em(
528538
symbol=code.replace('.', '').upper())
529539
row = _select_row_by_report(df)
530540
if not row.empty:
@@ -570,13 +580,13 @@ def get_stock_list(self,
570580

571581
try:
572582
if data_type == 'sse50':
573-
df = ak.index_stock_cons(symbol='000016')
583+
df = akshare.index_stock_cons(symbol='000016')
574584
elif data_type == 'hs300':
575-
df = ak.index_stock_cons(symbol='000300')
585+
df = akshare.index_stock_cons(symbol='000300')
576586
elif data_type == 'zz500':
577-
df = ak.index_stock_cons(symbol='000905')
587+
df = akshare.index_stock_cons(symbol='000905')
578588
elif data_type == 'all_a_share':
579-
df_a_share = ak.stock_zh_a_spot_em()
589+
df_a_share = akshare.stock_zh_a_spot_em()
580590
df_a_share['market'] = 'A'
581591
df = df_a_share[['代码', '名称', 'market']].copy()
582592
df = df.rename(columns={'代码': 'code', '名称': 'code_name'})
@@ -593,7 +603,7 @@ def get_trade_dates(self,
593603
logger.info(f'Fetching trade dates ({start_date} to {end_date})')
594604

595605
try:
596-
df = ak.tool_trade_date_hist_sina()
606+
df = akshare.tool_trade_date_hist_sina()
597607

598608
# Ensure trade_date is string for comparison
599609
if 'trade_date' in df.columns:
@@ -633,7 +643,7 @@ def get_macro_data(
633643
for data_type in data_types:
634644
try:
635645
if data_type in ('deposit_rate', 'loan_rate'):
636-
result[data_type] = ak.rate_interbank()
646+
result[data_type] = akshare.rate_interbank()
637647
elif data_type in ('required_reserve_ratio'):
638648
raise DataSourceError(
639649
'Required reserve ratio is not supported by AKShare')
@@ -661,7 +671,7 @@ def _get_money_supply_data_month(
661671
start_date: Optional[str] = None,
662672
end_date: Optional[str] = None) -> pd.DataFrame:
663673
try:
664-
df = ak.macro_china_money_supply() # from 2008-01 to now
674+
df = akshare.macro_china_money_supply() # from 2008-01 to now
665675
df['月份'] = pd.to_datetime(df['月份'].str.replace('月份',
666676
'').str.replace(
667677
'年', '-'))

ms_agent/tools/findata/baostock_source.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
import threading
3-
import time
43
from contextlib import contextmanager
54
from copy import deepcopy
65
from typing import Any, Dict, List, Optional
76

8-
import baostock as bs
97
import pandas as pd
108
from ms_agent.tools.findata.data_source_base import (DataSourceError,
119
FinancialDataSource,
1210
NoDataFoundError)
1311
from ms_agent.utils import get_logger
12+
from ms_agent.utils.utils import install_package
1413

1514
logger = get_logger()
1615

@@ -40,7 +39,7 @@ def _cancel_logout(self):
4039
def _force_logout(self):
4140
with self._session_lock:
4241
if self._login_count == 0 and self._is_logged_in:
43-
bs.logout()
42+
baostock.logout()
4443
self._is_logged_in = False
4544
logger.debug('BaoStock session closed by idle-timeout')
4645
self._timer = None
@@ -56,7 +55,7 @@ def ensure_login(self):
5655
"""Ensure BaoStock is logged in (thread-safe)"""
5756
with self._session_lock:
5857
if not self._is_logged_in:
59-
lg = bs.login()
58+
lg = baostock.login()
6059
if lg.error_code != '0':
6160
raise DataSourceError(
6261
f'BaoStock login failed: {lg.error_msg}')
@@ -108,6 +107,15 @@ class BaoStockDataSource(FinancialDataSource):
108107
]
109108

110109
def __init__(self):
110+
logger.info('Installing BaoStock package...')
111+
try:
112+
install_package(package_name='baostock')
113+
except Exception as e:
114+
raise e
115+
116+
global baostock
117+
import baostock
118+
111119
logger.info('Initializing BaoStock data source')
112120
# Test connection
113121
with baostock_session():
@@ -146,7 +154,7 @@ def get_historical_k_data(
146154
logger.info(f'Fetching K-data for {code} ({start_date} to {end_date})')
147155

148156
with baostock_session():
149-
rs = bs.query_history_k_data_plus(
157+
rs = baostock.query_history_k_data_plus(
150158
code=code,
151159
fields=fields_str,
152160
start_date=start_date,
@@ -160,7 +168,7 @@ def get_stock_basic_info(self, code: str) -> pd.DataFrame:
160168
logger.info(f'Fetching basic info for {code}')
161169

162170
with baostock_session():
163-
rs = bs.query_stock_basic(code=code)
171+
rs = baostock.query_stock_basic(code=code)
164172
return self._query_to_dataframe(rs, f'basic info for {code}')
165173

166174
def get_dividend_data(self,
@@ -171,7 +179,7 @@ def get_dividend_data(self,
171179
logger.info(f'Fetching dividend data for {code} ({year} {year_type})')
172180

173181
with baostock_session():
174-
rs = bs.query_dividend_data(
182+
rs = baostock.query_dividend_data(
175183
code=code, year=year, yearType=year_type)
176184
return self._query_to_dataframe(
177185
rs, f'dividend data for {code} ({year} {year_type})')
@@ -184,7 +192,7 @@ def get_adjust_factor_data(self, code: str, start_date: str,
184192
)
185193

186194
with baostock_session():
187-
rs = bs.query_adjust_factor(
195+
rs = baostock.query_adjust_factor(
188196
code=code, start_date=start_date, end_date=end_date)
189197
return self._query_to_dataframe(
190198
rs,
@@ -205,17 +213,17 @@ def get_financial_data(self, code: str, year: str, quarter: int,
205213
with baostock_session():
206214
for data_type in data_types:
207215
if data_type == 'profit':
208-
query_func = bs.query_profit_data
216+
query_func = baostock.query_profit_data
209217
elif data_type == 'operation':
210-
query_func = bs.query_operation_data
218+
query_func = baostock.query_operation_data
211219
elif data_type == 'growth':
212-
query_func = bs.query_growth_data
220+
query_func = baostock.query_growth_data
213221
elif data_type == 'balance':
214-
query_func = bs.query_balance_data
222+
query_func = baostock.query_balance_data
215223
elif data_type == 'cash_flow':
216-
query_func = bs.query_cash_flow_data
224+
query_func = baostock.query_cash_flow_data
217225
elif data_type == 'dupont':
218-
query_func = bs.query_dupont_data
226+
query_func = baostock.query_dupont_data
219227
else:
220228
raise ValueError(f'Invalid data type: {data_type}')
221229

@@ -252,10 +260,10 @@ def get_report(self,
252260

253261
with baostock_session():
254262
if report_type == 'performance_express':
255-
rs = bs.query_performance_express_report(
263+
rs = baostock.query_performance_express_report(
256264
code=code, start_date=start_date, end_date=end_date)
257265
elif report_type == 'performance_forecast':
258-
rs = bs.query_forecast_report(
266+
rs = baostock.query_forecast_report(
259267
code=code, start_date=start_date, end_date=end_date)
260268
else:
261269
raise ValueError(f'Invalid report type: {report_type}')
@@ -274,7 +282,7 @@ def get_stock_industry(self,
274282
)
275283

276284
with baostock_session():
277-
rs = bs.query_stock_industry(code=code, date=date)
285+
rs = baostock.query_stock_industry(code=code, date=date)
278286
return self._query_to_dataframe(
279287
rs, f'stock industry for {code or "all"} ({date or "latest"})')
280288

@@ -288,13 +296,13 @@ def get_stock_list(self,
288296

289297
with baostock_session():
290298
if data_type == 'sse50':
291-
rs = bs.query_sz50_stocks(date=date)
299+
rs = baostock.query_sz50_stocks(date=date)
292300
elif data_type == 'hs300':
293-
rs = bs.query_hs300_stocks(date=date)
301+
rs = baostock.query_hs300_stocks(date=date)
294302
elif data_type == 'zz500':
295-
rs = bs.query_zz500_stocks(date=date)
303+
rs = baostock.query_zz500_stocks(date=date)
296304
elif data_type == 'all_a_share':
297-
rs = bs.query_all_stock(day=date)
305+
rs = baostock.query_all_stock(day=date)
298306
else:
299307
raise ValueError(f'Invalid data type: {data_type}')
300308

@@ -313,7 +321,8 @@ def get_trade_dates(self,
313321
)
314322

315323
with baostock_session():
316-
rs = bs.query_trade_dates(start_date=start_date, end_date=end_date)
324+
rs = baostock.query_trade_dates(
325+
start_date=start_date, end_date=end_date)
317326
return self._query_to_dataframe(rs, 'trade dates')
318327

319328
def get_macro_data(
@@ -341,27 +350,27 @@ def get_macro_data(
341350
parsed_end_date = end_date
342351

343352
if data_type == 'deposit_rate':
344-
query_func = bs.query_deposit_rate_data
353+
query_func = baostock.query_deposit_rate_data
345354

346355
elif data_type == 'loan_rate':
347-
query_func = bs.query_loan_rate_data
356+
query_func = baostock.query_loan_rate_data
348357

349358
elif data_type == 'required_reserve_ratio':
350-
query_func = bs.query_required_reserve_ratio_data
359+
query_func = baostock.query_required_reserve_ratio_data
351360
if extra_kwargs:
352361
parsed_extra_kwargs.update(extra_kwargs)
353362
if 'yearType' not in parsed_extra_kwargs:
354363
parsed_extra_kwargs['yearType'] = '0'
355364

356365
elif data_type == 'money_supply_month':
357-
query_func = bs.query_money_supply_data_month
366+
query_func = baostock.query_money_supply_data_month
358367
parsed_start_date = pd.to_datetime(
359368
start_date).strftime('%Y-%m')
360369
parsed_end_date = pd.to_datetime(end_date).strftime(
361370
'%Y-%m')
362371

363372
elif data_type == 'money_supply_year':
364-
query_func = bs.query_money_supply_data_year
373+
query_func = baostock.query_money_supply_data_year
365374
parsed_start_date = pd.to_datetime(
366375
start_date).strftime('%Y')
367376
parsed_end_date = pd.to_datetime(end_date).strftime(

0 commit comments

Comments
 (0)