Skip to content

Commit 70752ca

Browse files
authored
Merge pull request #971 from ricequant/develop
fix use symbol to get instrument error
2 parents ec72594 + 6c73533 commit 70752ca

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

rqalpha/data/base_data_source/data_source.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
# 否则米筐科技有权追究相应的知识产权侵权责任。
1515
# 在此前提下,对本软件的使用同样需要遵守 Apache 2.0 许可,Apache 2.0 许可与本许可冲突之处,以本许可为准。
1616
# 详细的授权流程,请联系 [email protected] 获取。
17-
from collections import defaultdict, ChainMap
17+
from collections import ChainMap
1818
import os
1919
from datetime import date, datetime, timedelta
2020
from itertools import chain, repeat
21-
from typing import DefaultDict, Dict, Iterable, List, Mapping, Optional, Sequence, Union, cast, Tuple
21+
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, cast, Tuple
2222

2323
try:
2424
from typing import Protocol, runtime_checkable
@@ -116,10 +116,10 @@ def _p(name):
116116
self._ex_factor_stores: Dict[Tuple[INSTRUMENT_TYPE, MARKET], AbstractSimpleFactorStore] = {}
117117

118118
# instruments
119-
self._id_instrument_map: DefaultDict[str, dict[datetime, Instrument]] = defaultdict(dict)
120-
self._sym_instrument_map: DefaultDict[str, dict[datetime, Instrument]] = defaultdict(dict)
121-
self._id_or_sym_instrument_map: Mapping[str, dict[datetime, Instrument]] = ChainMap(self._id_instrument_map, self._sym_instrument_map)
122-
self._grouped_instruments: DefaultDict[INSTRUMENT_TYPE, dict[datetime, Instrument]] = defaultdict(dict)
119+
self._id_instrument_map: Dict[str, Dict[datetime, Instrument]] = {}
120+
self._sym_instrument_map: Dict[str, Dict[datetime, Instrument]] = {}
121+
self._id_or_sym_instrument_map: Mapping[str, Dict[datetime, Instrument]] = ChainMap(self._id_instrument_map, self._sym_instrument_map)
122+
self._grouped_instruments: Dict[INSTRUMENT_TYPE, Dict[datetime, Instrument]] = {}
123123

124124
# register instruments
125125
self.register_instruments(load_instruments_from_pkl(_p('instruments.pk'), self._future_info_store))
@@ -150,9 +150,9 @@ def register_day_bar_store(self, instrument_type: INSTRUMENT_TYPE, store: Abstra
150150

151151
def register_instruments(self, instruments: Iterable[Instrument]):
152152
for ins in instruments:
153-
self._id_instrument_map[ins.order_book_id][ins.listed_date] = ins
154-
self._sym_instrument_map[ins.symbol][ins.listed_date] = ins
155-
self._grouped_instruments[ins.type][ins.listed_date] = ins
153+
self._id_instrument_map.setdefault(ins.order_book_id, {})[ins.listed_date] = ins
154+
self._sym_instrument_map.setdefault(ins.symbol, {})[ins.listed_date] = ins
155+
self._grouped_instruments.setdefault(ins.type, {})[ins.listed_date] = ins
156156

157157
def register_dividend_store(self, instrument_type: INSTRUMENT_TYPE, dividend_store: AbstractDividendStore, market: MARKET = MARKET.CN):
158158
self._dividend_stores[instrument_type, market] = dividend_store
@@ -189,8 +189,9 @@ def get_instruments(self, id_or_syms: Optional[Iterable[str]] = None, types: Opt
189189
if id_or_syms is not None:
190190
seen = set()
191191
for i in id_or_syms:
192-
if i in self._id_or_sym_instrument_map:
193-
for ins in self._id_or_sym_instrument_map[i].values():
192+
v = self._id_or_sym_instrument_map.get(i)
193+
if v:
194+
for ins in v.values():
194195
if ins not in seen:
195196
seen.add(ins)
196197
yield ins

0 commit comments

Comments
 (0)