diff --git a/funcat/api.py b/funcat/api.py index 92259e9..dae564d 100644 --- a/funcat/api.py +++ b/funcat/api.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np -from .time_series import MarketDataSeries +from .time_series import MarketDataSeries, NumericSeries from .func import ( SumSeries, AbsSeries, @@ -19,7 +19,7 @@ llv, Ref, iif, -) + barslast) from .context import ( symbol, set_current_security, @@ -59,6 +59,7 @@ HHV = hhv LLV = llv IF = IIF = iif +BARSLAST=barslast S = set_current_security T = set_current_date @@ -90,6 +91,7 @@ "HHV", "LLV", "IF", "IIF", + "BARSLAST", "S", "T", @@ -101,4 +103,6 @@ "set_start_date", "set_data_backend", "set_current_freq", + + "NumericSeries", ] diff --git a/funcat/func.py b/funcat/func.py index aca5846..81a0e6b 100644 --- a/funcat/func.py +++ b/funcat/func.py @@ -19,6 +19,17 @@ ) +# delete nan of series for error made by some operator +def filter_begin_nan(series): + i = 0 + for x in series: + if np.isnan(x): + i += 1 + else: + break + return series[i:] + + class OneArgumentSeries(NumericSeries): func = talib.MA @@ -29,6 +40,7 @@ def __init__(self, series, arg): try: series[series == np.inf] = np.nan series = self.func(series, arg) + series = filter_begin_nan(series) except Exception as e: raise FormulaException(e) super(OneArgumentSeries, self).__init__(series) @@ -64,6 +76,7 @@ def __init__(self, series, arg1, arg2): try: series[series == np.inf] = np.nan series = self.func(series, arg1, arg2) + series = filter_begin_nan(series) except Exception as e: raise FormulaException(e) super(TwoArgumentSeries, self).__init__(series) @@ -129,6 +142,8 @@ def CrossOver(s1, s2): def Ref(s1, n): + if isinstance(n, NumericSeries): + return s1[int(n.value)] return s1[n] @@ -213,3 +228,24 @@ def iif(condition, true_statement, false_statement): series[cond_series] = series1[cond_series] return NumericSeries(series) + + +@handle_numpy_warning +def barslast(statement): + series = get_series(statement) + size = len(series) + end = size + begin = size - 1 + + try: + result = np.full(size, 1e16, dtype=np.int64) + except ValueError as e: + raise FormulaException(e) + + for s in series[::-1]: + if s: + result[begin:end] = range(0, end - begin) + end = begin + begin -= 1 + + return NumericSeries(result) diff --git a/funcat/time_series.py b/funcat/time_series.py index e791bb7..e090d9a 100644 --- a/funcat/time_series.py +++ b/funcat/time_series.py @@ -204,6 +204,9 @@ def __invert__(self): def __repr__(self): return str(self.value) + def __int__(self): + return int(self.value) + class NumericSeries(TimeSeries): def __init__(self, series=[]): @@ -216,7 +219,13 @@ def series(self): return self._series def __getitem__(self, index): - assert isinstance(index, int) and index >= 0 + assert (isinstance(index, int) and index >= 0) \ + or (isinstance(index, NumericSeries)) + + if isinstance(index, NumericSeries): + index = int(index.value) + assert index >= 0 + return self.__class__(series=self.series[:len(self.series) - index], **self.extra_create_kwargs) @@ -255,6 +264,10 @@ def __getitem__(self, index): if isinstance(index, int): assert index >= 0 + if isinstance(index, NumericSeries): + index = int(index.value) + assert index >= 0 + if isinstance(index, six.string_types): unit = index[-1] period = int(index[:-1]) diff --git a/requirements.txt b/requirements.txt index 4afefcc..70d5247 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,5 @@ lxml requests TA-Lib cached-property +bs4 +tushare \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 8015417..ca1a83c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -14,6 +14,7 @@ def test_000001(): T("20161216") S("000001.XSHG") + assert np.equal(REF(C, BARSLAST(C > 0)).value, C.value) assert np.equal(round(CLOSE.value, 2), 3122.98) assert np.equal(round(OPEN[2].value, 2), 3149.38) assert np.equal(round((CLOSE - OPEN).value, 2), 11.47)