Skip to content

Commit 30fb5dc

Browse files
committed
Enhance instrument handling by adding InstrumentNotFound exception, updating assure_instrument and related methods to improve order book ID validation, and refining argument checking for instrument types.
1 parent d926588 commit 30fb5dc

File tree

8 files changed

+125
-105
lines changed

8 files changed

+125
-105
lines changed

rqalpha/apis/api_base.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from rqalpha.environment import Environment
2929
from rqalpha.core.execution_context import ExecutionContext
3030
from rqalpha.utils import is_valid_price
31-
from rqalpha.utils.exception import RQInvalidArgument
31+
from rqalpha.utils.exception import RQInvalidArgument, InstrumentNotFound
3232
from rqalpha.utils.i18n import gettext as _
3333
from rqalpha.utils.arg_checker import apply_rules, verify_that
3434
from rqalpha.api import export_as_api
@@ -60,22 +60,35 @@
6060
export_as_api(EVENT, name='EVENT')
6161

6262

63-
def assure_instrument(id_or_ins) -> Instrument:
63+
def assure_instrument(id_or_ins, verify_listing: bool) -> Instrument:
64+
# verify_listing: 是否验证合约是否上市,注意,港股存在复用代码的情况,因此该参数应尽量设为 True 以避免搜索 instrument 失败
6465
if isinstance(id_or_ins, Instrument):
6566
return id_or_ins
6667
elif isinstance(id_or_ins, six.string_types):
67-
ins = Environment.get_instance().data_proxy.instrument(id_or_ins)
68-
if not ins:
69-
raise RQInvalidArgument(_(
70-
"invalid argument, expected order_book_ids or Instrument objects, got {} (type: {})"
71-
).format(id_or_ins, type(id_or_ins)))
68+
env = Environment.get_instance()
69+
try:
70+
ins = env.data_proxy.instrument_not_none(id_or_ins, env.trading_dt if verify_listing else None)
71+
except InstrumentNotFound as e:
72+
if verify_listing:
73+
raise RQInvalidArgument(_(
74+
"instrument {} not found or not listed at {}"
75+
).format(id_or_ins, env.trading_dt.date()))
76+
else:
77+
raise RQInvalidArgument(_(
78+
"instrument {} not found"
79+
).format(id_or_ins))
7280
return ins
7381
else:
7482
raise RQInvalidArgument(_(u"unsupported order_book_id type"))
7583

7684

7785
def assure_order_book_id(id_or_ins):
78-
return assure_instrument(id_or_ins).order_book_id
86+
if isinstance(id_or_ins, Instrument):
87+
return id_or_ins.order_book_id
88+
try:
89+
return Environment.get_instance().data_proxy.assure_order_book_id(id_or_ins)
90+
except InstrumentNotFound as e:
91+
raise RQInvalidArgument(_("instrument {} not found").format(id_or_ins))
7992

8093

8194
def cal_style(price, style, price_or_style=None):
@@ -134,7 +147,7 @@ def get_open_orders():
134147

135148
@export_as_api
136149
@apply_rules(
137-
verify_that("id_or_ins").is_valid_instrument(),
150+
verify_that("id_or_ins").is_valid_order_book_id(),
138151
verify_that("amount").is_number().is_greater_than(0),
139152
verify_that("side").is_in([SIDE.BUY, SIDE.SELL]),
140153
)
@@ -755,7 +768,7 @@ def get_next_trading_date(date, n=1):
755768
EXECUTION_PHASE.AFTER_TRADING,
756769
EXECUTION_PHASE.SCHEDULED,
757770
)
758-
@apply_rules(verify_that("id_or_symbol").is_valid_instrument())
771+
@apply_rules(verify_that("id_or_symbol").is_valid_order_book_id())
759772
def current_snapshot(id_or_symbol):
760773
# type: (Union[str, Instrument]) -> Optional[TickObject]
761774
"""

rqalpha/apis/api_rqdatac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ def get_stock_connect(order_book_ids, count=1, fields=None, expect_df=False):
827827

828828

829829
@export_as_api
830-
@apply_rules(verify_that('order_book_id').is_valid_instrument(),
830+
@apply_rules(verify_that('order_book_id').is_valid_order_book_id(),
831831
verify_that('quarter').is_valid_quarter(),
832832
verify_that('fields').are_valid_fields(VALID_CURRENT_PERFORMANCE_FIELDS, ignore_none=True))
833833
def current_performance(order_book_id, info_date=None, quarter=None, interval='1q', fields=None):

rqalpha/data/base_data_source/data_source.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,7 @@ def get_dividend(self, instrument):
183183
def get_trading_minutes_for(self, instrument, trading_dt):
184184
raise NotImplementedError
185185

186-
def get_trading_calendars(self):
187-
# type: () -> Dict[TRADING_CALENDAR_TYPE, pd.DatetimeIndex]
186+
def get_trading_calendars(self) -> Dict[TRADING_CALENDAR_TYPE, pd.DatetimeIndex]:
188187
return {t: store.get_trading_calendar() for t, store in self._calendar_stores.items()}
189188

190189
def get_instruments(self, id_or_syms: Optional[Iterable[str]] = None, types: Optional[Iterable[INSTRUMENT_TYPE]] = None) -> Iterable[Instrument]:

rqalpha/data/data_proxy.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,26 @@
1616
# 详细的授权流程,请联系 [email protected] 获取。
1717

1818
from datetime import datetime, date
19-
from typing import Union, List, Sequence, Optional, Tuple, Iterable, Dict
19+
from typing import Union, List, Sequence, Optional, Tuple, Iterable, Dict, Callable
2020

2121
import numpy as np
2222
import pandas as pd
2323

24-
from rqalpha.const import INSTRUMENT_TYPE, TRADING_CALENDAR_TYPE, EXECUTION_PHASE, MARKET
24+
from rqalpha.const import INSTRUMENT_TYPE, EXECUTION_PHASE, MARKET
2525
from rqalpha.utils import risk_free_helper, TimeRange, merge_trading_period
2626
from rqalpha.data.trading_dates_mixin import TradingDatesMixin
2727
from rqalpha.model.bar import BarObject, NANDict, PartialBarObject
2828
from rqalpha.model.tick import TickObject
2929
from rqalpha.model.instrument import Instrument
3030
from rqalpha.model.order import ALGO_ORDER_STYLES
3131
from rqalpha.utils.functools import lru_cache
32-
from rqalpha.utils.datetime_func import convert_int_to_datetime, convert_date_to_int
32+
from rqalpha.utils.datetime_func import convert_int_to_datetime
3333
from rqalpha.utils.typing import DateLike, StrOrIter
3434
from rqalpha.utils.i18n import gettext as _
3535
from rqalpha.interface import AbstractDataSource, AbstractPriceBoard, ExchangeRate
3636
from rqalpha.core.execution_context import ExecutionContext
3737
from rqalpha.utils.typing import DateLike
38+
from rqalpha.utils.exception import InstrumentNotFound
3839

3940

4041
class DataProxy(TradingDatesMixin):
@@ -285,8 +286,7 @@ def get_last_price(self, order_book_id):
285286
# type: (str) -> float
286287
return float(self._price_board.get_last_price(order_book_id))
287288

288-
def all_instruments(self, types, dt=None):
289-
# type: (List[INSTRUMENT_TYPE], Optional[datetime]) -> List[Instrument]
289+
def all_instruments(self, types: List[INSTRUMENT_TYPE], dt: Optional[datetime] = None) -> List[Instrument]:
290290
li = []
291291
for i in self._data_source.get_instruments(types=types):
292292
if dt is None or i.listing_at(dt):
@@ -295,20 +295,42 @@ def all_instruments(self, types, dt=None):
295295
# return [i for i in self._data_source.get_instruments(types=types) if dt is None or i.listing_at(dt)]
296296

297297
@lru_cache(2048)
298-
def instrument(self, sym_or_id):
299-
return next(iter(self._data_source.get_instruments(id_or_syms=[sym_or_id])), None)
298+
def instrument_not_none(self, id_or_sym: str, dt: Optional[datetime] = None) -> Instrument:
299+
"""
300+
根据合约代码获取唯一的 Instrument 对象
300301
301-
@lru_cache(2048)
302-
def instrument_not_none(self, sym_or_id) -> Instrument:
303-
try:
304-
return next(iter(self._data_source.get_instruments(id_or_syms=[sym_or_id])))
305-
except StopIteration:
306-
raise LookupError(_("Instrument not found: {}").format(sym_or_id))
302+
:param str id_or_sym: 合约代码或合约简称
303+
:param datetime dt: 可选,指定查询的时间点。若提供此参数,则仅返回在该时间点处于上市状态的合约
304+
注意,对于港股等可能出现复用代码情况等品种,请一律指定 dt 参数
305+
:return: 匹配的 Instrument 对象
306+
"""
307+
candidates = []
308+
for instrument in self._data_source.get_instruments(id_or_syms=[id_or_sym]):
309+
if dt is None or instrument.listing_at(dt):
310+
candidates.append(instrument)
311+
if not candidates:
312+
raise InstrumentNotFound(_("No instrument found at {dt}: {id_or_sym}").format(dt=dt, id_or_sym=id_or_sym))
313+
if len(candidates) > 1:
314+
raise InstrumentNotFound(_("Multiple instruments found at {dt}: {id_or_sym}").format(dt=dt, id_or_sym=id_or_sym))
315+
return candidates[0]
307316

308317
def multi_instruments(self, order_book_ids: Iterable[str]) -> Dict[str, Instrument]:
309318
return {i.order_book_id: i for i in self._data_source.get_instruments(id_or_syms=order_book_ids)}
310319

320+
def assure_order_book_id(self, order_book_id: str, expected_type: Optional[INSTRUMENT_TYPE] = None) -> str:
321+
for instrument in self._data_source.get_instruments(id_or_syms=[order_book_id]):
322+
if expected_type is not None and instrument.type != expected_type:
323+
continue
324+
return instrument.order_book_id
325+
raise InstrumentNotFound(_("No instrument found: {}").format(order_book_id))
326+
327+
@lru_cache(2048)
328+
def instrument(self, sym_or_id):
329+
# deprecated
330+
return next(iter(self._data_source.get_instruments(id_or_syms=[sym_or_id])), None)
331+
311332
def instruments(self, sym_or_ids):
333+
# deprecated
312334
# type: (StrOrIter) -> Union[None, Instrument, List[Instrument]]
313335
if isinstance(sym_or_ids, str):
314336
return next(iter(self._data_source.get_instruments(id_or_syms=[sym_or_ids])), None)

rqalpha/mod/rqalpha_mod_sys_accounts/api/api_stock.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def stock_order_to(order_book_id, quantity, price_or_style=None, price=None, sty
268268
EXECUTION_PHASE.SCHEDULED,
269269
EXECUTION_PHASE.GLOBAL
270270
)
271-
@apply_rules(verify_that('id_or_ins').is_valid_stock(), verify_that('amount').is_number(), *common_rules)
271+
@apply_rules(
272+
verify_that('id_or_ins').is_valid_order_book_id(expected_type=INSTRUMENT_TYPE.CS),
273+
verify_that('amount').is_number(),
274+
*common_rules
275+
)
272276
def order_lots(id_or_ins, amount, price_or_style=None, price=None, style=None):
273277
# type: (Union[str, Instrument], int, PRICE_OR_STYLE_TYPE, Optional[float], Optional[OrderStyle]) -> Optional[Order]
274278
"""
@@ -462,7 +466,7 @@ def order_target_portfolio(
462466
EXECUTION_PHASE.ON_TICK,
463467
EXECUTION_PHASE.AFTER_TRADING,
464468
EXECUTION_PHASE.SCHEDULED)
465-
@apply_rules(verify_that('order_book_id').is_valid_instrument(),
469+
@apply_rules(verify_that('order_book_id').is_valid_order_book_id(),
466470
verify_that('count').is_greater_than(0))
467471
def is_suspended(order_book_id, count=1):
468472
# type: (str, Optional[int]) -> Union[bool, pd.DataFrame]
@@ -486,7 +490,7 @@ def is_suspended(order_book_id, count=1):
486490
EXECUTION_PHASE.ON_TICK,
487491
EXECUTION_PHASE.AFTER_TRADING,
488492
EXECUTION_PHASE.SCHEDULED)
489-
@apply_rules(verify_that('order_book_id').is_valid_instrument())
493+
@apply_rules(verify_that('order_book_id').is_valid_order_book_id())
490494
def is_st_stock(order_book_id, count=1):
491495
# type: (str, Optional[int]) -> Union[bool, pd.DataFrame]
492496
"""
@@ -697,7 +701,7 @@ def init(context):
697701

698702
@export_as_api
699703
@apply_rules(
700-
verify_that("order_book_id").is_valid_instrument(),
704+
verify_that("order_book_id").is_valid_order_book_id(),
701705
verify_that("start_date").is_valid_date(ignore_none=False),
702706
)
703707
def get_dividend(order_book_id, start_date):

rqalpha/model/instrument.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
import re
1919
import copy
20-
import datetime
21-
import inspect
20+
from datetime import datetime, time, date
2221
from typing import Dict, Callable, Optional
2322
from methodtools import lru_cache
2423

@@ -34,17 +33,17 @@
3433

3534
# TODO:改为 namedtuple,提升性能
3635
class Instrument(metaclass=PropertyReprMeta):
37-
DEFAULT_DE_LISTED_DATE = datetime.datetime(2999, 12, 31)
36+
DEFAULT_DE_LISTED_DATE = datetime(2999, 12, 31)
3837

3938
@staticmethod
40-
def _fix_date(ds, dflt=None) -> Optional[datetime.datetime]:
41-
if isinstance(ds, datetime.datetime):
39+
def _fix_date(ds, dflt=None) -> Optional[datetime]:
40+
if isinstance(ds, datetime):
4241
return ds
4342
if ds == '0000-00-00' or ds is None:
4443
return dflt
4544
try:
4645
year, month, day = ds.split('-')
47-
return datetime.datetime(int(year), int(month), int(day))
46+
return datetime(int(year), int(month), int(day))
4847
except:
4948
return parse(ds)
5049

@@ -92,15 +91,14 @@ def round_lot(self) -> int:
9291
return int(self._dict["round_lot"])
9392

9493
@cached_property
95-
def listed_date(self) -> Optional[datetime.datetime]:
94+
def listed_date(self) -> datetime:
9695
"""
9796
[datetime] 股票:该证券上市日期。期货:期货的上市日期,主力连续合约与指数连续合约都为 datetime(1990, 1, 1)。
9897
"""
9998
return self._dict["listed_date"]
10099

101100
@cached_property
102-
def de_listed_date(self):
103-
# type: () -> datetime.datetime
101+
def de_listed_date(self) -> datetime:
104102
"""
105103
[datetime] 股票:退市日期。期货:交割日期。
106104
"""
@@ -262,7 +260,7 @@ def underlying_symbol(self):
262260

263261
@cached_property
264262
def maturity_date(self):
265-
# type: () -> datetime.datetime
263+
# type: () -> datetime
266264
"""
267265
[datetime] 到期日
268266
"""
@@ -316,36 +314,30 @@ def account_type(self):
316314
else:
317315
raise NotImplementedError
318316

319-
def listing_at(self, dt):
317+
def listing_at(self, dt: datetime) -> bool:
320318
"""
321319
该合约在指定日期是否在交易
322-
:param dt: datetime.datetime
323-
:return: bool
324320
"""
325321
return self.listed_at(dt) and not self.de_listed_at(dt)
326322

327-
def listed_at(self, dt):
323+
def listed_at(self, dt: datetime) -> bool:
328324
"""
329325
该合约在指定日期是否已上市
330-
:param dt: datetime.datetime
331-
:return: bool
332326
"""
333-
return self.listed_date and self.listed_date <= dt
327+
return self.listed_date <= dt
334328

335-
def de_listed_at(self, dt):
329+
def de_listed_at(self, dt: datetime) -> bool:
336330
"""
337331
该合约在指定日期是否已退市
338-
:param dt: datetime.datetime
339-
:return: bool
340332
"""
341333
if self.type in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION):
342334
return dt.date() > self.de_listed_date.date()
343335
else:
344336
return dt >= self.de_listed_date
345337

346338
STOCK_TRADING_PERIOD = [
347-
TimeRange(start=datetime.time(9, 31), end=datetime.time(11, 30)),
348-
TimeRange(start=datetime.time(13, 1), end=datetime.time(15, 0)),
339+
TimeRange(start=time(9, 31), end=time(11, 30)),
340+
TimeRange(start=time(13, 1), end=time(15, 0)),
349341
]
350342

351343
@cached_property
@@ -361,16 +353,16 @@ def trading_hours(self):
361353
trading_hours = trading_hours.replace("-", ":")
362354
for time_range_str in trading_hours.split(","):
363355
start_h, start_m, end_h, end_m = (int(i) for i in time_range_str.split(":"))
364-
start, end = datetime.time(start_h, start_m), datetime.time(end_h, end_m)
356+
start, end = time(start_h, start_m), time(end_h, end_m)
365357
if start > end:
366-
trading_period.append(TimeRange(start, datetime.time(23, 59)))
367-
trading_period.append(TimeRange(datetime.time(0, 0), end))
358+
trading_period.append(TimeRange(start, time(23, 59)))
359+
trading_period.append(TimeRange(time(0, 0), end))
368360
else:
369361
trading_period.append(TimeRange(start, end))
370362
return trading_period
371363

372364
def during_continuous_auction(self, time):
373-
# type: (datetime.time) -> bool
365+
# type: (time) -> bool
374366
""" 是否处于连续竞价时间段内 """
375367
for time_range in self.trading_hours:
376368
if time_range.start <= time <= time_range.end:
@@ -389,7 +381,7 @@ def trading_code(self):
389381

390382
@cached_property
391383
def trade_at_night(self):
392-
return any(r.start <= datetime.time(4, 0) or r.end >= datetime.time(19, 0) for r in (self.trading_hours or []))
384+
return any(r.start <= time(4, 0) or r.end >= time(19, 0) for r in (self.trading_hours or []))
393385

394386
@cached_property
395387
def min_order_quantity(self):
@@ -455,21 +447,21 @@ def tick_size(self):
455447
raise NotImplementedError
456448

457449
@lru_cache(8)
458-
def get_long_margin_ratio(self, dt: datetime.date) -> float:
450+
def get_long_margin_ratio(self, dt: date) -> float:
459451
"""
460452
获取多头保证金率(期货专用)
461453
"""
462454
return Environment.get_instance().data_proxy.get_futures_trading_parameters(self.order_book_id, dt).long_margin_ratio
463455

464456
@lru_cache(8)
465-
def get_short_margin_ratio(self, dt: datetime.date) -> float:
457+
def get_short_margin_ratio(self, dt: date) -> float:
466458
"""
467459
获取空头保证金率(期货专用)
468460
"""
469461
return Environment.get_instance().data_proxy.get_futures_trading_parameters(self.order_book_id, dt).short_margin_ratio
470462

471463
def calc_cash_occupation(self, price, quantity, direction, dt):
472-
# type: (float, int, POSITION_DIRECTION, datetime.date) -> float
464+
# type: (float, int, POSITION_DIRECTION, date) -> float
473465
if self.market != MARKET.CN:
474466
exchagne_rate = Environment.get_instance().data_proxy.get_exchange_rate(dt, self.market)
475467
price = price * exchagne_rate.ask_reference

0 commit comments

Comments
 (0)