|
14 | 14 | # 否则米筐科技有权追究相应的知识产权侵权责任。 |
15 | 15 | # 在此前提下,对本软件的使用同样需要遵守 Apache 2.0 许可,Apache 2.0 许可与本许可冲突之处,以本许可为准。 |
16 | 16 | # 详细的授权流程,请联系 [email protected] 获取。 |
17 | | -from collections import defaultdict, ChainMap |
| 17 | +from collections import ChainMap |
18 | 18 | import os |
19 | 19 | from datetime import date, datetime, timedelta |
20 | 20 | from itertools import chain, repeat |
21 | | -from typing import DefaultDict, Dict, Iterable, List, Mapping, Optional, Sequence, Union, cast, Tuple |
| 21 | +from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, cast, Tuple |
22 | 22 |
|
23 | 23 | try: |
24 | 24 | from typing import Protocol, runtime_checkable |
@@ -116,10 +116,10 @@ def _p(name): |
116 | 116 | self._ex_factor_stores: Dict[Tuple[INSTRUMENT_TYPE, MARKET], AbstractSimpleFactorStore] = {} |
117 | 117 |
|
118 | 118 | # instruments |
119 | | - self._id_instrument_map: DefaultDict[str, dict[datetime, Instrument]] = defaultdict(dict) |
120 | | - self._sym_instrument_map: DefaultDict[str, dict[datetime, Instrument]] = defaultdict(dict) |
121 | | - self._id_or_sym_instrument_map: Mapping[str, dict[datetime, Instrument]] = ChainMap(self._id_instrument_map, self._sym_instrument_map) |
122 | | - self._grouped_instruments: DefaultDict[INSTRUMENT_TYPE, dict[datetime, Instrument]] = defaultdict(dict) |
| 119 | + self._id_instrument_map: Dict[str, Dict[datetime, Instrument]] = {} |
| 120 | + self._sym_instrument_map: Dict[str, Dict[datetime, Instrument]] = {} |
| 121 | + self._id_or_sym_instrument_map: Mapping[str, Dict[datetime, Instrument]] = ChainMap(self._id_instrument_map, self._sym_instrument_map) |
| 122 | + self._grouped_instruments: Dict[INSTRUMENT_TYPE, Dict[datetime, Instrument]] = {} |
123 | 123 |
|
124 | 124 | # register instruments |
125 | 125 | self.register_instruments(load_instruments_from_pkl(_p('instruments.pk'), self._future_info_store)) |
@@ -150,9 +150,9 @@ def register_day_bar_store(self, instrument_type: INSTRUMENT_TYPE, store: Abstra |
150 | 150 |
|
151 | 151 | def register_instruments(self, instruments: Iterable[Instrument]): |
152 | 152 | for ins in instruments: |
153 | | - self._id_instrument_map[ins.order_book_id][ins.listed_date] = ins |
154 | | - self._sym_instrument_map[ins.symbol][ins.listed_date] = ins |
155 | | - self._grouped_instruments[ins.type][ins.listed_date] = ins |
| 153 | + self._id_instrument_map.setdefault(ins.order_book_id, {})[ins.listed_date] = ins |
| 154 | + self._sym_instrument_map.setdefault(ins.symbol, {})[ins.listed_date] = ins |
| 155 | + self._grouped_instruments.setdefault(ins.type, {})[ins.listed_date] = ins |
156 | 156 |
|
157 | 157 | def register_dividend_store(self, instrument_type: INSTRUMENT_TYPE, dividend_store: AbstractDividendStore, market: MARKET = MARKET.CN): |
158 | 158 | self._dividend_stores[instrument_type, market] = dividend_store |
@@ -189,8 +189,9 @@ def get_instruments(self, id_or_syms: Optional[Iterable[str]] = None, types: Opt |
189 | 189 | if id_or_syms is not None: |
190 | 190 | seen = set() |
191 | 191 | for i in id_or_syms: |
192 | | - if i in self._id_or_sym_instrument_map: |
193 | | - for ins in self._id_or_sym_instrument_map[i].values(): |
| 192 | + v = self._id_or_sym_instrument_map.get(i) |
| 193 | + if v: |
| 194 | + for ins in v.values(): |
194 | 195 | if ins not in seen: |
195 | 196 | seen.add(ins) |
196 | 197 | yield ins |
|
0 commit comments