11# Copyright (c) Alibaba, Inc. and its affiliates.
22import threading
3- import time
43from contextlib import contextmanager
54from copy import deepcopy
65from typing import Any , Dict , List , Optional
76
8- import baostock as bs
97import pandas as pd
108from ms_agent .tools .findata .data_source_base import (DataSourceError ,
119 FinancialDataSource ,
1210 NoDataFoundError )
1311from ms_agent .utils import get_logger
12+ from ms_agent .utils .utils import install_package
1413
1514logger = get_logger ()
1615
@@ -40,7 +39,7 @@ def _cancel_logout(self):
4039 def _force_logout (self ):
4140 with self ._session_lock :
4241 if self ._login_count == 0 and self ._is_logged_in :
43- bs .logout ()
42+ baostock .logout ()
4443 self ._is_logged_in = False
4544 logger .debug ('BaoStock session closed by idle-timeout' )
4645 self ._timer = None
@@ -56,7 +55,7 @@ def ensure_login(self):
5655 """Ensure BaoStock is logged in (thread-safe)"""
5756 with self ._session_lock :
5857 if not self ._is_logged_in :
59- lg = bs .login ()
58+ lg = baostock .login ()
6059 if lg .error_code != '0' :
6160 raise DataSourceError (
6261 f'BaoStock login failed: { lg .error_msg } ' )
@@ -108,6 +107,15 @@ class BaoStockDataSource(FinancialDataSource):
108107 ]
109108
110109 def __init__ (self ):
110+ logger .info ('Installing BaoStock package...' )
111+ try :
112+ install_package (package_name = 'baostock' )
113+ except Exception as e :
114+ raise e
115+
116+ global baostock
117+ import baostock
118+
111119 logger .info ('Initializing BaoStock data source' )
112120 # Test connection
113121 with baostock_session ():
@@ -146,7 +154,7 @@ def get_historical_k_data(
146154 logger .info (f'Fetching K-data for { code } ({ start_date } to { end_date } )' )
147155
148156 with baostock_session ():
149- rs = bs .query_history_k_data_plus (
157+ rs = baostock .query_history_k_data_plus (
150158 code = code ,
151159 fields = fields_str ,
152160 start_date = start_date ,
@@ -160,7 +168,7 @@ def get_stock_basic_info(self, code: str) -> pd.DataFrame:
160168 logger .info (f'Fetching basic info for { code } ' )
161169
162170 with baostock_session ():
163- rs = bs .query_stock_basic (code = code )
171+ rs = baostock .query_stock_basic (code = code )
164172 return self ._query_to_dataframe (rs , f'basic info for { code } ' )
165173
166174 def get_dividend_data (self ,
@@ -171,7 +179,7 @@ def get_dividend_data(self,
171179 logger .info (f'Fetching dividend data for { code } ({ year } { year_type } )' )
172180
173181 with baostock_session ():
174- rs = bs .query_dividend_data (
182+ rs = baostock .query_dividend_data (
175183 code = code , year = year , yearType = year_type )
176184 return self ._query_to_dataframe (
177185 rs , f'dividend data for { code } ({ year } { year_type } )' )
@@ -184,7 +192,7 @@ def get_adjust_factor_data(self, code: str, start_date: str,
184192 )
185193
186194 with baostock_session ():
187- rs = bs .query_adjust_factor (
195+ rs = baostock .query_adjust_factor (
188196 code = code , start_date = start_date , end_date = end_date )
189197 return self ._query_to_dataframe (
190198 rs ,
@@ -205,17 +213,17 @@ def get_financial_data(self, code: str, year: str, quarter: int,
205213 with baostock_session ():
206214 for data_type in data_types :
207215 if data_type == 'profit' :
208- query_func = bs .query_profit_data
216+ query_func = baostock .query_profit_data
209217 elif data_type == 'operation' :
210- query_func = bs .query_operation_data
218+ query_func = baostock .query_operation_data
211219 elif data_type == 'growth' :
212- query_func = bs .query_growth_data
220+ query_func = baostock .query_growth_data
213221 elif data_type == 'balance' :
214- query_func = bs .query_balance_data
222+ query_func = baostock .query_balance_data
215223 elif data_type == 'cash_flow' :
216- query_func = bs .query_cash_flow_data
224+ query_func = baostock .query_cash_flow_data
217225 elif data_type == 'dupont' :
218- query_func = bs .query_dupont_data
226+ query_func = baostock .query_dupont_data
219227 else :
220228 raise ValueError (f'Invalid data type: { data_type } ' )
221229
@@ -252,10 +260,10 @@ def get_report(self,
252260
253261 with baostock_session ():
254262 if report_type == 'performance_express' :
255- rs = bs .query_performance_express_report (
263+ rs = baostock .query_performance_express_report (
256264 code = code , start_date = start_date , end_date = end_date )
257265 elif report_type == 'performance_forecast' :
258- rs = bs .query_forecast_report (
266+ rs = baostock .query_forecast_report (
259267 code = code , start_date = start_date , end_date = end_date )
260268 else :
261269 raise ValueError (f'Invalid report type: { report_type } ' )
@@ -274,7 +282,7 @@ def get_stock_industry(self,
274282 )
275283
276284 with baostock_session ():
277- rs = bs .query_stock_industry (code = code , date = date )
285+ rs = baostock .query_stock_industry (code = code , date = date )
278286 return self ._query_to_dataframe (
279287 rs , f'stock industry for { code or "all" } ({ date or "latest" } )' )
280288
@@ -288,13 +296,13 @@ def get_stock_list(self,
288296
289297 with baostock_session ():
290298 if data_type == 'sse50' :
291- rs = bs .query_sz50_stocks (date = date )
299+ rs = baostock .query_sz50_stocks (date = date )
292300 elif data_type == 'hs300' :
293- rs = bs .query_hs300_stocks (date = date )
301+ rs = baostock .query_hs300_stocks (date = date )
294302 elif data_type == 'zz500' :
295- rs = bs .query_zz500_stocks (date = date )
303+ rs = baostock .query_zz500_stocks (date = date )
296304 elif data_type == 'all_a_share' :
297- rs = bs .query_all_stock (day = date )
305+ rs = baostock .query_all_stock (day = date )
298306 else :
299307 raise ValueError (f'Invalid data type: { data_type } ' )
300308
@@ -313,7 +321,8 @@ def get_trade_dates(self,
313321 )
314322
315323 with baostock_session ():
316- rs = bs .query_trade_dates (start_date = start_date , end_date = end_date )
324+ rs = baostock .query_trade_dates (
325+ start_date = start_date , end_date = end_date )
317326 return self ._query_to_dataframe (rs , 'trade dates' )
318327
319328 def get_macro_data (
@@ -341,27 +350,27 @@ def get_macro_data(
341350 parsed_end_date = end_date
342351
343352 if data_type == 'deposit_rate' :
344- query_func = bs .query_deposit_rate_data
353+ query_func = baostock .query_deposit_rate_data
345354
346355 elif data_type == 'loan_rate' :
347- query_func = bs .query_loan_rate_data
356+ query_func = baostock .query_loan_rate_data
348357
349358 elif data_type == 'required_reserve_ratio' :
350- query_func = bs .query_required_reserve_ratio_data
359+ query_func = baostock .query_required_reserve_ratio_data
351360 if extra_kwargs :
352361 parsed_extra_kwargs .update (extra_kwargs )
353362 if 'yearType' not in parsed_extra_kwargs :
354363 parsed_extra_kwargs ['yearType' ] = '0'
355364
356365 elif data_type == 'money_supply_month' :
357- query_func = bs .query_money_supply_data_month
366+ query_func = baostock .query_money_supply_data_month
358367 parsed_start_date = pd .to_datetime (
359368 start_date ).strftime ('%Y-%m' )
360369 parsed_end_date = pd .to_datetime (end_date ).strftime (
361370 '%Y-%m' )
362371
363372 elif data_type == 'money_supply_year' :
364- query_func = bs .query_money_supply_data_year
373+ query_func = baostock .query_money_supply_data_year
365374 parsed_start_date = pd .to_datetime (
366375 start_date ).strftime ('%Y' )
367376 parsed_end_date = pd .to_datetime (end_date ).strftime (
0 commit comments