Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/backtest_lib/market/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import replace
from enum import Enum, auto
from enum import Enum, StrEnum, auto
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -32,6 +32,12 @@
from backtest_lib.universe.vector_mapping import VectorMapping


class Closed(StrEnum):
LEFT = "left"
RIGHT = "right"
BOTH = "both"


def get_pastview_from_mapping(backend: str) -> type[PastView]:
if backend == "polars":
from backtest_lib.market.polars_impl import PolarsPastView
Expand Down Expand Up @@ -121,6 +127,8 @@ def between(
self,
start: Index | str,
end: Index | str,
*,
closed: Closed | str = Closed.LEFT,
) -> Self:
"""PLACEHOLDER"""
... # will not clone data, must be contiguous, performs a binary search
Expand Down
34 changes: 7 additions & 27 deletions src/backtest_lib/market/polars_impl/_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import polars as pl
from numpy.typing import NDArray

from backtest_lib.market import Closed


@dataclass(frozen=True)
class SecurityAxis:
Expand Down Expand Up @@ -48,11 +50,11 @@ def take(self, idxs: Sequence[int] | NDArray[np.int64]) -> PeriodAxis:
new_labels = tuple(self.labels[i] for i in idxs)
return PeriodAxis(
dt64=self.dt64[idxs],
labels=tuple(new_labels),
labels=new_labels,
pos={lbl: i for i, lbl in enumerate(new_labels)},
)

def _slice(self, key: slice) -> PeriodAxis:
def slice(self, key: slice) -> PeriodAxis:
"""
Creates a new PeriodAxis from a slice
of the current PeriodAxis.
Expand Down Expand Up @@ -94,32 +96,10 @@ def bounds_between(
start: np.datetime64,
end: np.datetime64,
*,
closed: str = "both",
closed: str | Closed = Closed.BOTH,
) -> tuple[int, int]:
inc_start = closed in ("both", "left")
inc_end = closed in ("both", "right")
inc_start = closed in (Closed.BOTH, Closed.LEFT)
inc_end = closed in (Closed.BOTH, Closed.RIGHT)
left, _ = self.bounds_after(start, inclusive=inc_start)
_, right = self.bounds_before(end, inclusive=inc_end)
return left, right

def after(
self, start: np.datetime64, *, inclusive: bool = True
) -> NDArray[np.int64]:
left, right = self.bounds_after(start, inclusive=inclusive)
return np.arange(left, right, dtype=np.int64)

def before(
self, end: np.datetime64, *, inclusive: bool = False
) -> NDArray[np.int64]:
left, right = self.bounds_before(end, inclusive=inclusive)
return np.arange(left, right, dtype=np.int64)

def between(
self,
start: np.datetime64,
end: np.datetime64,
*,
closed: str = "left",
) -> NDArray[np.int64]:
left, right = self.bounds_between(start, end, closed=closed)
return np.arange(left, right, dtype=np.int64)
4 changes: 2 additions & 2 deletions src/backtest_lib/market/polars_impl/_past_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
import polars as pl

from backtest_lib.market import ByPeriod, BySecurity, PastView
from backtest_lib.market import ByPeriod, BySecurity, Closed, PastView
from backtest_lib.market.plotting import (
ByPeriodPlotAccessor,
BySecurityPlotAccessor,
Expand Down Expand Up @@ -589,7 +589,7 @@ def between(
start: np.datetime64 | str,
end: np.datetime64 | str,
*,
closed: str = "left",
closed: Closed | str = Closed.LEFT,
) -> PolarsPastView:
left, right = self._period_axis.bounds_between(
to_npdt64(start), to_npdt64(end), closed=closed
Expand Down
2 changes: 1 addition & 1 deletion src/backtest_lib/market/polars_impl/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __getitem__(self, key: int | slice) -> T | Self:
# return self._scalar_type(self._vec[key])
else:
return PolarsTimeseries[T](
self._vec[key], self._axis._slice(key), self._name, self._scalar_type
self._vec[key], self._axis.slice(key), self._name, self._scalar_type
)

def before(self, end: np.datetime64 | str, *, inclusive=False) -> Self:
Expand Down
151 changes: 150 additions & 1 deletion tests/market/polars_impl/test_axis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from backtest_lib.market.polars_impl._axis import SecurityAxis
import datetime

import numpy as np
import polars as pl
import pytest
from polars.exceptions import InvalidOperationError

from backtest_lib.market import Closed
from backtest_lib.market.polars_impl._axis import PeriodAxis, SecurityAxis


def test_static_constructor_security_axis():
Expand All @@ -17,3 +25,144 @@ def test_returning_length_of_names_security_axis():
def test_ability_to_handle_empty_names_security_axis():
security_axis = SecurityAxis.from_names([])
assert len(security_axis.names) == 0


def test_casting_incompatible_data_to_date_should_throw_error_period_axis():
s_string = pl.Series("string", ["a", "b", "c"])
with pytest.raises(InvalidOperationError) as e_info:
PeriodAxis.from_series(s_string)
assert "conversion from `str` to `datetime[μs]` failed in column 'string'" in str(
e_info.value
)


def test_constructor_casting_required_period_axis():
series_date = pl.Series(
"dates",
[
datetime.date(2023, 1, 1),
datetime.date(2023, 1, 2),
datetime.date(2023, 1, 3),
],
)
period_axis = PeriodAxis.from_series(series_date)
assert period_axis.labels == ("2023-01-01", "2023-01-02", "2023-01-03")
assert period_axis.pos == {"2023-01-01": 0, "2023-01-02": 1, "2023-01-03": 2}
expected_dt64_array = np.array(
["2023-01-01", "2023-01-02", "2023-01-03"], dtype="datetime64[us]"
)
np.testing.assert_array_equal(period_axis.dt64, expected_dt64_array)


def test_len_period_axis():
series_date = pl.Series(
"dates",
[
datetime.date(2023, 1, 1),
datetime.date(2023, 1, 2),
datetime.date(2023, 1, 3),
],
)
period_axis = PeriodAxis.from_series(series_date)
assert len(period_axis) == 3


def test_slicing_incontiguous_sequence():
series_date = pl.Series(
"dates",
[
datetime.date(2023, 1, 1),
datetime.date(2023, 1, 2),
datetime.date(2023, 1, 3),
],
)
period_axis = PeriodAxis.from_series(series_date)
sliced_period_axis = period_axis.slice(slice(None, None, 2))
assert sliced_period_axis.labels == ("2023-01-01", "2023-01-03")
assert sliced_period_axis.pos == {"2023-01-01": 0, "2023-01-03": 1}
expected_dt64_array = np.array(["2023-01-01", "2023-01-03"], dtype="datetime64[us]")
np.testing.assert_array_equal(sliced_period_axis.dt64, expected_dt64_array)


def test_slicing_contiguous_sequence():
series_date = pl.Series(
"dates",
[
datetime.date(2023, 1, 1),
datetime.date(2023, 1, 2),
datetime.date(2023, 1, 3),
],
)
period_axis = PeriodAxis.from_series(series_date)
sliced_period_axis = period_axis.slice(slice(1, 3))
assert sliced_period_axis.labels == ("2023-01-02", "2023-01-03")
assert sliced_period_axis.pos == {"2023-01-02": 0, "2023-01-03": 1}
expected_dt64_array = np.array(["2023-01-02", "2023-01-03"], dtype="datetime64[us]")
np.testing.assert_array_equal(sliced_period_axis.dt64, expected_dt64_array)


def test_bounds_after():
series_date = pl.Series(
"dates",
[
datetime.date(2023, 1, 1),
datetime.date(2023, 1, 2),
datetime.date(2023, 1, 3),
],
)
period_axis = PeriodAxis.from_series(series_date)
assert period_axis.bounds_after(np.datetime64("2023-01-02"), inclusive=True) == (
1,
3,
)
assert period_axis.bounds_after(np.datetime64("2023-01-02"), inclusive=False) == (
2,
3,
)


def test_bounds_before():
series_date = pl.Series(
"dates",
[
datetime.date(2023, 1, 1),
datetime.date(2023, 1, 2),
datetime.date(2023, 1, 3),
],
)
period_axis = PeriodAxis.from_series(series_date)
assert period_axis.bounds_before(np.datetime64("2023-01-02"), inclusive=True) == (
0,
2,
)
assert period_axis.bounds_before(np.datetime64("2023-01-02"), inclusive=False) == (
0,
1,
)


def test_bounds_between():
series_date = pl.Series(
"dates",
[
datetime.date(2023, 1, 1),
datetime.date(2023, 1, 2),
datetime.date(2023, 1, 3),
],
)
period_axis = PeriodAxis.from_series(series_date)
assert period_axis.bounds_between(
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed=Closed.LEFT
) == (0, 2)
assert period_axis.bounds_between(
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed="left"
) == (0, 2)
assert period_axis.bounds_between(
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed=Closed.RIGHT
) == (1, 3)
assert period_axis.bounds_between(
np.datetime64("2023-01-01"), np.datetime64("2023-01-03"), closed=Closed.BOTH
) == (0, 3)
assert period_axis.bounds_between(
np.datetime64("2023-01-02"), np.datetime64("2023-01-03"), closed=Closed.BOTH
) == (1, 3)