Skip to content

Commit 153f902

Browse files
xiaotongchang-creatorjathoms
authored andcommitted
Closes #3: add tests for PeriodAxis
1 parent f2ce5dc commit 153f902

File tree

5 files changed

+169
-32
lines changed

5 files changed

+169
-32
lines changed

src/backtest_lib/market/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABC, abstractmethod
55
from collections.abc import Iterable, Iterator, Mapping, Sequence
66
from dataclasses import replace
7-
from enum import Enum, auto
7+
from enum import Enum, StrEnum, auto
88
from typing import (
99
TYPE_CHECKING,
1010
Any,
@@ -32,6 +32,12 @@
3232
from backtest_lib.universe.vector_mapping import VectorMapping
3333

3434

35+
class Closed(StrEnum):
36+
LEFT = "left"
37+
RIGHT = "right"
38+
BOTH = "both"
39+
40+
3541
def get_pastview_from_mapping(backend: str) -> type[PastView]:
3642
if backend == "polars":
3743
from backtest_lib.market.polars_impl import PolarsPastView
@@ -121,6 +127,8 @@ def between(
121127
self,
122128
start: Index | str,
123129
end: Index | str,
130+
*,
131+
closed: Closed | str = Closed.LEFT,
124132
) -> Self:
125133
"""PLACEHOLDER"""
126134
... # will not clone data, must be contiguous, performs a binary search

src/backtest_lib/market/polars_impl/_axis.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import polars as pl
88
from numpy.typing import NDArray
99

10+
from backtest_lib.market import Closed
11+
1012

1113
@dataclass(frozen=True)
1214
class SecurityAxis:
@@ -48,11 +50,11 @@ def take(self, idxs: Sequence[int] | NDArray[np.int64]) -> PeriodAxis:
4850
new_labels = tuple(self.labels[i] for i in idxs)
4951
return PeriodAxis(
5052
dt64=self.dt64[idxs],
51-
labels=tuple(new_labels),
53+
labels=new_labels,
5254
pos={lbl: i for i, lbl in enumerate(new_labels)},
5355
)
5456

55-
def _slice(self, key: slice) -> PeriodAxis:
57+
def slice(self, key: slice) -> PeriodAxis:
5658
"""
5759
Creates a new PeriodAxis from a slice
5860
of the current PeriodAxis.
@@ -94,32 +96,10 @@ def bounds_between(
9496
start: np.datetime64,
9597
end: np.datetime64,
9698
*,
97-
closed: str = "both",
99+
closed: str | Closed = Closed.BOTH,
98100
) -> tuple[int, int]:
99-
inc_start = closed in ("both", "left")
100-
inc_end = closed in ("both", "right")
101+
inc_start = closed in (Closed.BOTH, Closed.LEFT)
102+
inc_end = closed in (Closed.BOTH, Closed.RIGHT)
101103
left, _ = self.bounds_after(start, inclusive=inc_start)
102104
_, right = self.bounds_before(end, inclusive=inc_end)
103105
return left, right
104-
105-
def after(
106-
self, start: np.datetime64, *, inclusive: bool = True
107-
) -> NDArray[np.int64]:
108-
left, right = self.bounds_after(start, inclusive=inclusive)
109-
return np.arange(left, right, dtype=np.int64)
110-
111-
def before(
112-
self, end: np.datetime64, *, inclusive: bool = False
113-
) -> NDArray[np.int64]:
114-
left, right = self.bounds_before(end, inclusive=inclusive)
115-
return np.arange(left, right, dtype=np.int64)
116-
117-
def between(
118-
self,
119-
start: np.datetime64,
120-
end: np.datetime64,
121-
*,
122-
closed: str = "left",
123-
) -> NDArray[np.int64]:
124-
left, right = self.bounds_between(start, end, closed=closed)
125-
return np.arange(left, right, dtype=np.int64)

src/backtest_lib/market/polars_impl/_past_view.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy as np
1717
import polars as pl
1818

19-
from backtest_lib.market import ByPeriod, BySecurity, PastView
19+
from backtest_lib.market import ByPeriod, BySecurity, Closed, PastView
2020
from backtest_lib.market.plotting import (
2121
ByPeriodPlotAccessor,
2222
BySecurityPlotAccessor,
@@ -589,7 +589,7 @@ def between(
589589
start: np.datetime64 | str,
590590
end: np.datetime64 | str,
591591
*,
592-
closed: str = "left",
592+
closed: Closed | str = Closed.LEFT,
593593
) -> PolarsPastView:
594594
left, right = self._period_axis.bounds_between(
595595
to_npdt64(start), to_npdt64(end), closed=closed

src/backtest_lib/market/polars_impl/_timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __getitem__(self, key: int | slice) -> T | Self:
8989
# return self._scalar_type(self._vec[key])
9090
else:
9191
return PolarsTimeseries[T](
92-
self._vec[key], self._axis._slice(key), self._name, self._scalar_type
92+
self._vec[key], self._axis.slice(key), self._name, self._scalar_type
9393
)
9494

9595
def before(self, end: np.datetime64 | str, *, inclusive=False) -> Self:

tests/market/polars_impl/test_axis.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
from backtest_lib.market.polars_impl._axis import SecurityAxis
1+
import datetime
2+
3+
import numpy as np
4+
import polars as pl
5+
import pytest
6+
from polars.exceptions import InvalidOperationError
7+
8+
from backtest_lib.market import Closed
9+
from backtest_lib.market.polars_impl._axis import PeriodAxis, SecurityAxis
210

311

412
def test_static_constructor_security_axis():
@@ -17,3 +25,144 @@ def test_returning_length_of_names_security_axis():
1725
def test_ability_to_handle_empty_names_security_axis():
1826
security_axis = SecurityAxis.from_names([])
1927
assert len(security_axis.names) == 0
28+
29+
30+
def test_casting_incompatible_data_to_date_should_throw_error_period_axis():
31+
s_string = pl.Series("string", ["a", "b", "c"])
32+
with pytest.raises(InvalidOperationError) as e_info:
33+
PeriodAxis.from_series(s_string)
34+
assert "conversion from `str` to `datetime[μs]` failed in column 'string'" in str(
35+
e_info.value
36+
)
37+
38+
39+
def test_constructor_casting_required_period_axis():
40+
series_date = pl.Series(
41+
"dates",
42+
[
43+
datetime.date(2023, 1, 1),
44+
datetime.date(2023, 1, 2),
45+
datetime.date(2023, 1, 3),
46+
],
47+
)
48+
period_axis = PeriodAxis.from_series(series_date)
49+
assert period_axis.labels == ("2023-01-01", "2023-01-02", "2023-01-03")
50+
assert period_axis.pos == {"2023-01-01": 0, "2023-01-02": 1, "2023-01-03": 2}
51+
expected_dt64_array = np.array(
52+
["2023-01-01", "2023-01-02", "2023-01-03"], dtype="datetime64[us]"
53+
)
54+
np.testing.assert_array_equal(period_axis.dt64, expected_dt64_array)
55+
56+
57+
def test_len_period_axis():
58+
series_date = pl.Series(
59+
"dates",
60+
[
61+
datetime.date(2023, 1, 1),
62+
datetime.date(2023, 1, 2),
63+
datetime.date(2023, 1, 3),
64+
],
65+
)
66+
period_axis = PeriodAxis.from_series(series_date)
67+
assert len(period_axis) == 3
68+
69+
70+
def test_slicing_incontiguous_sequence():
71+
series_date = pl.Series(
72+
"dates",
73+
[
74+
datetime.date(2023, 1, 1),
75+
datetime.date(2023, 1, 2),
76+
datetime.date(2023, 1, 3),
77+
],
78+
)
79+
period_axis = PeriodAxis.from_series(series_date)
80+
sliced_period_axis = period_axis.slice(slice(None, None, 2))
81+
assert sliced_period_axis.labels == ("2023-01-01", "2023-01-03")
82+
assert sliced_period_axis.pos == {"2023-01-01": 0, "2023-01-03": 1}
83+
expected_dt64_array = np.array(["2023-01-01", "2023-01-03"], dtype="datetime64[us]")
84+
np.testing.assert_array_equal(sliced_period_axis.dt64, expected_dt64_array)
85+
86+
87+
def test_slicing_contiguous_sequence():
88+
series_date = pl.Series(
89+
"dates",
90+
[
91+
datetime.date(2023, 1, 1),
92+
datetime.date(2023, 1, 2),
93+
datetime.date(2023, 1, 3),
94+
],
95+
)
96+
period_axis = PeriodAxis.from_series(series_date)
97+
sliced_period_axis = period_axis.slice(slice(1, 3))
98+
assert sliced_period_axis.labels == ("2023-01-02", "2023-01-03")
99+
assert sliced_period_axis.pos == {"2023-01-02": 0, "2023-01-03": 1}
100+
expected_dt64_array = np.array(["2023-01-02", "2023-01-03"], dtype="datetime64[us]")
101+
np.testing.assert_array_equal(sliced_period_axis.dt64, expected_dt64_array)
102+
103+
104+
def test_bounds_after():
105+
series_date = pl.Series(
106+
"dates",
107+
[
108+
datetime.date(2023, 1, 1),
109+
datetime.date(2023, 1, 2),
110+
datetime.date(2023, 1, 3),
111+
],
112+
)
113+
period_axis = PeriodAxis.from_series(series_date)
114+
assert period_axis.bounds_after(np.datetime64("2023-01-02"), inclusive=True) == (
115+
1,
116+
3,
117+
)
118+
assert period_axis.bounds_after(np.datetime64("2023-01-02"), inclusive=False) == (
119+
2,
120+
3,
121+
)
122+
123+
124+
def test_bounds_before():
125+
series_date = pl.Series(
126+
"dates",
127+
[
128+
datetime.date(2023, 1, 1),
129+
datetime.date(2023, 1, 2),
130+
datetime.date(2023, 1, 3),
131+
],
132+
)
133+
period_axis = PeriodAxis.from_series(series_date)
134+
assert period_axis.bounds_before(np.datetime64("2023-01-02"), inclusive=True) == (
135+
0,
136+
2,
137+
)
138+
assert period_axis.bounds_before(np.datetime64("2023-01-02"), inclusive=False) == (
139+
0,
140+
1,
141+
)
142+
143+
144+
def test_bounds_between():
145+
series_date = pl.Series(
146+
"dates",
147+
[
148+
datetime.date(2023, 1, 1),
149+
datetime.date(2023, 1, 2),
150+
datetime.date(2023, 1, 3),
151+
],
152+
)
153+
period_axis = PeriodAxis.from_series(series_date)
154+
assert period_axis.bounds_between(
155+
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed=Closed.LEFT
156+
) == (0, 2)
157+
assert period_axis.bounds_between(
158+
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed="left"
159+
) == (0, 2)
160+
assert period_axis.bounds_between(
161+
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed=Closed.RIGHT
162+
) == (1, 3)
163+
assert period_axis.bounds_between(
164+
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed=Closed.BOTH
165+
) == (0, 3)
166+
assert period_axis.bounds_between(
167+
np.datetime64("2023-01-02"), np.datetime64("2023-01-03"), closed=Closed.BOTH
168+
) == (1, 3)

0 commit comments

Comments
 (0)