Skip to content

使用本地数据库导入数据源时,env.set_data_source报错 #940

@1995zhanbudu0115

Description

@1995zhanbudu0115

提 ISSUE 须知

请先阅读文档 rqalpha文档

如果仍有问题的话请在 issue列表 中寻找是否有相关问题的解决方案

如果没有的话 麻烦开一个issue 描述以下问题:

RQAlpha 5.6.4 版本

Python 3.11

Windows 系统

class FactorDataSource(BaseDataSource):
    """Data source that serves daily bars preloaded from a MySQL table.

    The constructor reads every row of ``daily_etf_bar`` whose ``tradedate``
    lies in ``[start_date, end_date]`` and caches it in memory: one DataFrame
    per instrument code, indexed by date. The MySQL connection is closed as
    soon as loading finishes.

    NOTE(review): ``BaseDataSource.__init__`` is not called (the original left
    ``super().__init__`` commented out), so any inherited method that is not
    overridden here may touch uninitialised base-class state — confirm which
    methods the backtest actually reaches.
    """

    # Field names callers may request that differ from the cached column names
    # (get_bar already exposes the 'amount' column as 'total_turnover').
    _FIELD_ALIASES = {"total_turnover": "amount"}

    def __init__(self, mysql_config, start_date, end_date):
        """Connect to MySQL and preload all bars in the requested window.

        Args:
            mysql_config: keyword arguments passed to ``pymysql.connect``.
            start_date: first ``tradedate`` to load (inclusive, SQL BETWEEN).
            end_date: last ``tradedate`` to load (inclusive, SQL BETWEEN).
        """
        self.mysql_config = mysql_config
        self._conn = pymysql.connect(**mysql_config)
        self._data_cache = {}  # instrument code -> DataFrame indexed by date
        self._instruments = []  # sorted codes that have at least one bar
        self._date_range = (None, None)  # (first, last) loaded bar timestamp

        self._preload_all_data(start_date, end_date)

    def _preload_all_data(self, start_date, end_date):
        """Load every bar in the date window into memory, then close the DB.

        Populates ``self._data_cache``, ``self._instruments`` and
        ``self._date_range``.
        """
        print("🚀 预加载行情数据和因子数据...")

        # The dates are sent as query parameters instead of being formatted
        # into the SQL text, so their content cannot change the statement.
        all_data_sql = (
            "SELECT etfcode instrument, tradedate date, open, high, low, close, volume, amount "
            "FROM daily_etf_bar WHERE tradedate BETWEEN %s AND %s"
        )
        try:
            all_data = pd.read_sql(all_data_sql, self._conn, params=(start_date, end_date))
        finally:
            # Close even when the query fails; nothing below needs the database.
            self._conn.close()

        all_data['date'] = pd.to_datetime(all_data['date'])
        all_data.set_index(['date', 'instrument'], inplace=True)
        all_data.sort_index(inplace=True)

        for code, group in all_data.groupby(level='instrument'):
            # Drop the instrument level so each cached frame is indexed by date
            # alone; history_bars and get_bar compare/look up plain timestamps.
            self._data_cache[code] = group.droplevel('instrument')

        # Same set of codes a separate SELECT DISTINCT over the window returns.
        self._instruments = sorted(self._data_cache)
        if not all_data.empty:
            dates = all_data.index.get_level_values('date')
            self._date_range = (dates.min(), dates.max())

        print(f"✅ 加载完成!共 {len(self._data_cache)} 只股票")

    def _frame_for(self, instrument):
        """Return the cached bars for *instrument*, or ``None`` if unknown.

        Accepts a code string or any object exposing ``order_book_id``.
        """
        # NOTE(review): rqalpha may pass Instrument objects rather than code
        # strings to its data source; resolving order_book_id covers both.
        code = getattr(instrument, 'order_book_id', instrument)
        return self._data_cache.get(code)

    def history_bars(self, instrument, bar_count, frequency, fields, dt, skip_suspended=True):
        """Return the last *bar_count* values of *fields* at or before *dt*.

        Args:
            instrument: instrument code (or an object with ``order_book_id``).
            bar_count: maximum number of most recent bars to return.
            frequency: accepted for interface compatibility; ignored.
            fields: one field name or a list of names. Cached columns are
                ``open``, ``high``, ``low``, ``close``, ``volume`` and
                ``amount``; ``total_turnover`` is accepted as ``amount``.
            dt: bars dated after this moment are excluded.
            skip_suspended: accepted for interface compatibility; ignored.

        Returns:
            ``numpy.ndarray`` of the selected values, or an empty array when
            the instrument is unknown or has no bar at or before *dt*.
        """
        stock_data = self._frame_for(instrument)
        if stock_data is None:
            return np.array([])

        # The cached index holds dates only, so this is a plain timestamp
        # comparison rather than a comparison against (date, instrument) tuples.
        available_data = stock_data[stock_data.index <= pd.Timestamp(dt)]
        if available_data.empty:
            return np.array([])

        if isinstance(fields, str):
            columns = self._FIELD_ALIASES.get(fields, fields)
        else:
            columns = [self._FIELD_ALIASES.get(f, f) for f in fields]
        return available_data.tail(bar_count)[columns].values

    def get_all_instruments(self):
        """Return one ``{'order_book_id', 'type'}`` dict per loaded code.

        NOTE(review): every code is tagged 'CS' although the rows come from
        ``daily_etf_bar``; also confirm whether rqalpha expects Instrument
        objects here rather than plain dicts.
        """
        return [{'order_book_id': code, 'type': 'CS'} for code in self._instruments]

    def available_data_range(self, frequency):
        """Return ``(first, last)`` timestamps of the loaded bars.

        Falls back to the fixed 2015-01-01 .. 2023-12-31 window when no bars
        were loaded. *frequency* is ignored.
        """
        first, last = self._date_range
        if first is None:
            return (pd.Timestamp('2015-01-01'), pd.Timestamp('2023-12-31'))
        return (first, last)

    def get_bar(self, instrument, dt, frequency):
        """Return the bar dated *dt* as a dict of floats, or ``None``.

        ``None`` is returned when the instrument is unknown or has no bar on
        that date. *frequency* is accepted for interface compatibility only.
        """
        stock_data = self._frame_for(instrument)
        if stock_data is None:
            return None

        # Drop any time of day so e.g. a 15:00 datetime still matches the
        # date-only index. Assumes tradedate holds whole dates — confirm.
        dt_timestamp = pd.Timestamp(dt).normalize()
        if dt_timestamp not in stock_data.index:
            return None

        bar_data = stock_data.loc[dt_timestamp]
        return {
            'open': float(bar_data['open']),
            'high': float(bar_data['high']),
            'low': float(bar_data['low']),
            'close': float(bar_data['close']),
            'volume': float(bar_data['volume']),
            'total_turnover': float(bar_data.get('amount', 0)),
        }

    # The overrides below replace base-class methods that might otherwise
    # trigger data-bundle checks; each returns empty data.
    def get_fund_info(self, fund, dt):
        """Return empty fund information."""
        return {}

    def get_yield_curve(self, start_date, end_date, tenor=None):
        """Return an empty yield-curve DataFrame."""
        return pd.DataFrame()

    def get_share_transformation(self, order_book_id, dt):
        """Return empty share-transformation information."""
        return {}

    def get_dividend(self, order_book_id, dt):
        """Return empty dividend information."""
        return {}

    def _get_preds_his(self, *args, **kwargs):
        """Placeholder for historical predictions; not implemented (returns None)."""
        return None

    def get_preds(self, *args, **kwargs):
        """Placeholder for predictions; not implemented (returns None)."""
        return None

def advanced_multi_factor_strategy(context):
    """Advanced multi-factor strategy with industry neutralization and risk control.

    Currently a stub: it prints ``context.trading_dt`` and fetches the last 60
    daily closes of 000300.XSHG through ``context.history_bars``.

    Returns:
        Whatever ``context.history_bars`` returns (previously discarded).
    """
    print(context.trading_dt)
    # NOTE(review): confirm the strategy context really provides history_bars;
    # in rqalpha it is usually the module-level API function
    # history_bars(order_book_id, bar_count, frequency, fields). Also, the data
    # source in this script only loads rows from daily_etf_bar, so this index
    # code may not be in its cache.
    return context.history_bars('000300.XSHG', 60, '1d', ['close'])

def run_multi_factor_backtest(start_date="2015-01-01", end_date="2015-01-31", universe_size=300):
    """Run the multi-factor stock-selection backtest and save a metric summary.

    The same date window is used both to preload bars from MySQL and as the
    backtest period, so the two cannot drift apart.

    Args:
        start_date: first date to load and backtest.
        end_date: last date to load and backtest.
        universe_size: maximum number of instruments in the universe; ``None``
            keeps all of them.

    Returns:
        The value returned by ``run(config)``.
    """
    import os

    mysql_config = {
        # pymysql.connect keyword arguments (host, user, password, ...) go here.
    }

    # Build the data source.
    data_source = FactorDataSource(mysql_config, start_date=start_date, end_date=end_date)

    # Use every loaded instrument, capped at `universe_size`; slicing is already
    # safe when fewer instruments exist.
    # NOTE(review): no ranking is applied (e.g. by market cap) — the order is
    # simply the one returned by the data source.
    all_instruments = [inst['order_book_id'] for inst in data_source.get_all_instruments()]
    universe = all_instruments[:universe_size]

    config = {
        "base": {
            "start_date": start_date,
            "end_date": end_date,
            "accounts": {"stock": 1000000},
            "frequency": "1d",
            "matching_type": "next_bar",
            # NOTE(review): per the traceback in this issue, rqalpha's run()
            # still constructs its own bundle-backed BaseDataSource
            # (env.set_data_source(BaseDataSource(...)) in main.py), so this
            # entry does not by itself replace the data source. A custom mod
            # that calls env.set_data_source(...) during start-up is presumably
            # required — confirm against the installed rqalpha version.
            "data_source": data_source,
            "benchmark": None,
            "universe": universe,  # instrument universe
        },
        "extra": {
            "log_level": "info",
        },
        "mod": {
            "sys_analyser": {
                "enabled": True,
                "plot": True,
                "output_file": "./results/multi_factor_result.pkl",
                "report_save_path": "./results/"
            },
            "sys_simulation": {
                "enabled": True,
                "slippage": 0.0001,
                "commission_multiplier": 0,
            }
        },
        # NOTE(review): assumed that newer rqalpha versions read the strategy
        # from this key; confirm, since function-based strategies are commonly
        # passed via rqalpha.run_func(config=..., init=..., handle_bar=...).
        "strategy": advanced_multi_factor_strategy
    }

    # sys_analyser writes under ./results/, so make sure the directory exists
    # even when this function is called from outside the __main__ guard.
    os.makedirs('./results', exist_ok=True)

    print("开始多因子选股回测...")
    result = run(config)

    # Save the factor-analysis summary.
    _save_factor_analysis(result)

    return result

def _save_factor_analysis(backtest_result, output_path='./results/factor_analysis.csv'):
    """Write the headline backtest metrics to a one-row CSV file.

    Best effort: any failure is reported with ``print`` instead of being
    raised, so a finished backtest is not lost because its summary could not
    be written.

    Args:
        backtest_result: the value returned by the backtest run. Metrics are
            read from ``backtest_result["sys_analyser"]["summary"]`` when that
            nesting exists, otherwise from the top level of the dict.
        output_path: CSV destination; missing parent directories are created.
    """
    import os

    try:
        if backtest_result is None:
            raise ValueError("backtest returned no result")
        # NOTE(review): rqalpha appears to key results by mod name, with the
        # analyser's metrics under "sys_analyser" -> "summary"; the top-level
        # fallback keeps a flat dict working.
        summary = backtest_result.get('sys_analyser', {}).get('summary', backtest_result)

        analysis_df = pd.DataFrame({
            'description': ['多因子选股策略回测结果'],
            'total_return': [summary['total_returns']],
            'annual_return': [summary['annualized_returns']],
            'max_drawdown': [summary['max_drawdown']],
            'sharpe_ratio': [summary['sharpe']],
        })

        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        analysis_df.to_csv(output_path, index=False)
        print("因子分析结果已保存")
    except Exception as e:
        # Deliberate best-effort boundary: report and carry on.
        print(f"保存分析结果失败: {e}")

# Script entry point: create the output directory and run the backtest.
if __name__ == "__main__":
    import os

    os.makedirs('./results', exist_ok=True)
    run_multi_factor_backtest()

在执行 run(config) 时报错:File "D:\tools\Anaconda\envs\new\Lib\site-packages\rqalpha\main.py", line 150, in run
env.set_data_source(BaseDataSource(
^^^^^^^^^^^^^^^
如何跳过 bundle 数据检查,直接使用自有行情数据进行策略回测?

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions