Skip to content

Commit 7954ad2

Browse files
Closes #3: add tests for PeriodAxis
1 parent 0cde7f4 commit 7954ad2

File tree

5 files changed

+101
-32
lines changed

5 files changed

+101
-32
lines changed

src/backtest_lib/market/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABC, abstractmethod
55
from collections.abc import Iterable, Iterator, Mapping, Sequence
66
from dataclasses import replace
7-
from enum import Enum, auto
7+
from enum import Enum, StrEnum, auto
88
from typing import (
99
TYPE_CHECKING,
1010
Any,
@@ -32,6 +32,12 @@
3232
from backtest_lib.universe.vector_mapping import VectorMapping
3333

3434

35+
class Closed(StrEnum):
36+
LEFT = "left"
37+
RIGHT = "right"
38+
BOTH = "both"
39+
40+
3541
def get_pastview_from_mapping(backend: str) -> type[PastView]:
3642
if backend == "polars":
3743
from backtest_lib.market.polars_impl import PolarsPastView
@@ -121,6 +127,8 @@ def between(
121127
self,
122128
start: Index | str,
123129
end: Index | str,
130+
*,
131+
closed: Closed | str = Closed.LEFT,
124132
) -> Self:
125133
"""PLACEHOLDER"""
126134
... # will not clone data, must be contiguous, performs a binary search

src/backtest_lib/market/polars_impl/_axis.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import polars as pl
88
from numpy.typing import NDArray
99

10+
from backtest_lib.market import Closed
11+
1012

1113
@dataclass(frozen=True)
1214
class SecurityAxis:
@@ -48,11 +50,11 @@ def take(self, idxs: Sequence[int] | NDArray[np.int64]) -> PeriodAxis:
4850
new_labels = tuple(self.labels[i] for i in idxs)
4951
return PeriodAxis(
5052
dt64=self.dt64[idxs],
51-
labels=tuple(new_labels),
53+
labels=new_labels,
5254
pos={lbl: i for i, lbl in enumerate(new_labels)},
5355
)
5456

55-
def _slice(self, key: slice) -> PeriodAxis:
57+
def slice(self, key: slice) -> PeriodAxis:
5658
"""
5759
Creates a new PeriodAxis from a slice
5860
of the current PeriodAxis.
@@ -94,32 +96,10 @@ def bounds_between(
9496
start: np.datetime64,
9597
end: np.datetime64,
9698
*,
97-
closed: str = "both",
99+
closed: str | Closed = Closed.BOTH,
98100
) -> tuple[int, int]:
99-
inc_start = closed in ("both", "left")
100-
inc_end = closed in ("both", "right")
101+
inc_start = closed in (Closed.BOTH, Closed.LEFT)
102+
inc_end = closed in (Closed.BOTH, Closed.RIGHT)
101103
left, _ = self.bounds_after(start, inclusive=inc_start)
102104
_, right = self.bounds_before(end, inclusive=inc_end)
103105
return left, right
104-
105-
def after(
106-
self, start: np.datetime64, *, inclusive: bool = True
107-
) -> NDArray[np.int64]:
108-
left, right = self.bounds_after(start, inclusive=inclusive)
109-
return np.arange(left, right, dtype=np.int64)
110-
111-
def before(
112-
self, end: np.datetime64, *, inclusive: bool = False
113-
) -> NDArray[np.int64]:
114-
left, right = self.bounds_before(end, inclusive=inclusive)
115-
return np.arange(left, right, dtype=np.int64)
116-
117-
def between(
118-
self,
119-
start: np.datetime64,
120-
end: np.datetime64,
121-
*,
122-
closed: str = "left",
123-
) -> NDArray[np.int64]:
124-
left, right = self.bounds_between(start, end, closed=closed)
125-
return np.arange(left, right, dtype=np.int64)

src/backtest_lib/market/polars_impl/_past_view.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy as np
1717
import polars as pl
1818

19-
from backtest_lib.market import ByPeriod, BySecurity, PastView
19+
from backtest_lib.market import ByPeriod, BySecurity, Closed, PastView
2020
from backtest_lib.market.plotting import (
2121
ByPeriodPlotAccessor,
2222
BySecurityPlotAccessor,
@@ -589,7 +589,7 @@ def between(
589589
start: np.datetime64 | str,
590590
end: np.datetime64 | str,
591591
*,
592-
closed: str = "left",
592+
closed: Closed | str = Closed.LEFT,
593593
) -> PolarsPastView:
594594
left, right = self._period_axis.bounds_between(
595595
to_npdt64(start), to_npdt64(end), closed=closed

src/backtest_lib/market/polars_impl/_timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __getitem__(self, key: int | slice) -> T | Self:
8989
# return self._scalar_type(self._vec[key])
9090
else:
9191
return PolarsTimeseries[T](
92-
self._vec[key], self._axis._slice(key), self._name, self._scalar_type
92+
self._vec[key], self._axis.slice(key), self._name, self._scalar_type
9393
)
9494

9595
def before(self, end: np.datetime64 | str, *, inclusive=False) -> Self:

tests/market/polars_impl/test_axis.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
from backtest_lib.market.polars_impl._axis import SecurityAxis
1+
import datetime
2+
3+
import numpy as np
4+
import polars as pl
5+
import pytest
6+
from polars.exceptions import InvalidOperationError
7+
8+
from backtest_lib.market.polars_impl._axis import PeriodAxis, SecurityAxis
29

310

411
def test_static_constructor_security_axis():
@@ -17,3 +24,77 @@ def test_returning_length_of_names_security_axis():
1724
def test_ability_to_handle_empty_names_security_axis():
1825
security_axis = SecurityAxis.from_names([])
1926
assert len(security_axis.names) == 0
27+
28+
29+
def test_casting_incompatible_data_to_date_should_throw_error_period_axis():
30+
s_string = pl.Series("string", ["a", "b", "c"])
31+
with pytest.raises(InvalidOperationError) as e_info:
32+
PeriodAxis.from_series(s_string)
33+
assert "conversion from `str` to `datetime[μs]` failed in column 'string'" in str(
34+
e_info.value
35+
)
36+
37+
38+
def test_constructor_casting_required_period_axis():
39+
series_date = pl.Series(
40+
"dates",
41+
[
42+
datetime.date(2023, 1, 1),
43+
datetime.date(2023, 1, 2),
44+
datetime.date(2023, 1, 3),
45+
],
46+
)
47+
period_axis = PeriodAxis.from_series(series_date)
48+
assert period_axis.labels == ("2023-01-01", "2023-01-02", "2023-01-03")
49+
assert period_axis.pos == {"2023-01-01": 0, "2023-01-02": 1, "2023-01-03": 2}
50+
expected_dt64_array = np.array(
51+
["2023-01-01", "2023-01-02", "2023-01-03"], dtype="datetime64[us]"
52+
)
53+
np.testing.assert_array_equal(period_axis.dt64, expected_dt64_array)
54+
55+
56+
def test_len_period_axis():
57+
series_date = pl.Series(
58+
"dates",
59+
[
60+
datetime.date(2023, 1, 1),
61+
datetime.date(2023, 1, 2),
62+
datetime.date(2023, 1, 3),
63+
],
64+
)
65+
period_axis = PeriodAxis.from_series(series_date)
66+
assert len(period_axis) == 3
67+
68+
69+
def test_slicing_incontiguous_sequence():
70+
series_date = pl.Series(
71+
"dates",
72+
[
73+
datetime.date(2023, 1, 1),
74+
datetime.date(2023, 1, 2),
75+
datetime.date(2023, 1, 3),
76+
],
77+
)
78+
period_axis = PeriodAxis.from_series(series_date)
79+
sliced_period_axis = period_axis.slice(slice(None, None, 2))
80+
assert sliced_period_axis.labels == ("2023-01-01", "2023-01-03")
81+
assert sliced_period_axis.pos == {"2023-01-01": 0, "2023-01-03": 1}
82+
expected_dt64_array = np.array(["2023-01-01", "2023-01-03"], dtype="datetime64[us]")
83+
np.testing.assert_array_equal(sliced_period_axis.dt64, expected_dt64_array)
84+
85+
86+
def test_slicing_contiguous_sequence():
87+
series_date = pl.Series(
88+
"dates",
89+
[
90+
datetime.date(2023, 1, 1),
91+
datetime.date(2023, 1, 2),
92+
datetime.date(2023, 1, 3),
93+
],
94+
)
95+
period_axis = PeriodAxis.from_series(series_date)
96+
sliced_period_axis = period_axis.slice(slice(1, 3))
97+
assert sliced_period_axis.labels == ("2023-01-02", "2023-01-03")
98+
assert sliced_period_axis.pos == {"2023-01-02": 0, "2023-01-03": 1}
99+
expected_dt64_array = np.array(["2023-01-02", "2023-01-03"], dtype="datetime64[us]")
100+
np.testing.assert_array_equal(sliced_period_axis.dt64, expected_dt64_array)

0 commit comments

Comments
 (0)