Skip to content

Commit 9db3fda

Browse files
Merge pull request #36 from UnravelSports/feat/unravel_graph
fix padding datatypes
2 parents 7df593a + fbcd405 commit 9db3fda

File tree

6 files changed

+32
-7
lines changed

6 files changed

+32
-7
lines changed

tests/files/plot/test-1.mp4

2.88 KB
Binary file not shown.
16.3 KB
Loading

tests/files/plot/test-png.png

16.3 KB
Loading

tests/test_kloppy_polars.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,21 @@ def custom_edge_feature(**kwargs):
350350
verbose=False,
351351
)
352352

353+
# def test_skillcorner_warning(
354+
# self, kloppy_dataset: TrackingDataset
355+
# ) -> KloppyPolarsDataset:
356+
# with pytest.warns(UserWarning):
357+
# dataset = KloppyPolarsDataset(
358+
# kloppy_dataset=kloppy_dataset,
359+
# ball_carrier_threshold=25.0,
360+
# max_player_speed=12.0,
361+
# max_player_acceleration=12.0,
362+
# max_ball_speed=13.5,
363+
# max_ball_acceleration=100,
364+
# )
365+
# dataset.add_dummy_labels(by=["game_id", "frame_id"], random_seed=42)
366+
# dataset.add_graph_ids(by=["game_id", "frame_id"])
367+
353368
def test_incorrect_custom_features_no_decorator(
354369
self, kloppy_polars_dataset: KloppyPolarsDataset
355370
) -> SoccerGraphConverterPolars:

unravel/soccer/dataset/kloppy_polars.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import polars as pl
2727

28+
import warnings
29+
2830

2931
DEFAULT_PLAYER_SMOOTHING_PARAMS = {"window_length": 7, "polyorder": 1}
3032
DEFAULT_BALL_SMOOTHING_PARAMS = {"window_length": 3, "polyorder": 1}
@@ -688,13 +690,16 @@ def load(
688690
df = df.drop(["dx", "dy", "dz", "dt", "dvx", "dvy", "dvz"])
689691
df = df.filter(~(pl.col(Column.X).is_null() & pl.col(Column.Y).is_null()))
690692

691-
if (
692-
df[Column.BALL_OWNING_TEAM_ID].is_null().all()
693-
and self.ball_carrier_threshold is None
694-
):
695-
raise ValueError(
696-
f"This dataset requires us to infer the {Column.BALL_OWNING_TEAM_ID}, please specifiy a ball_carrier_threshold (float) to do so."
697-
)
693+
if df[Column.BALL_OWNING_TEAM_ID].is_null().all():
694+
if self.ball_carrier_threshold is None:
695+
raise ValueError(
696+
f"This dataset requires us to infer the {Column.BALL_OWNING_TEAM_ID}, please specifiy a ball_carrier_threshold (float) to do so."
697+
)
698+
else:
699+
warnings.warn(
700+
"This dataset does not come with 'ball owning team' information. As a result we infer this using distance to ball using the 'ball_carrier_threshold'. Please note this might cause unexpected results.",
701+
UserWarning,
702+
)
698703

699704
df = self.__infer_ball_carrier(df)
700705

unravel/soccer/graphs/graph_converter_pl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,11 @@ def _apply_padding(self) -> pl.DataFrame:
260260
for col in user_defined_columns
261261
]
262262
)
263+
264+
padding_df = padding_df.with_columns(
265+
[pl.col(col).cast(df.schema[col]).alias(col) for col in group_by_columns]
266+
)
267+
263268
padding_df = padding_df.join(
264269
(
265270
df.unique(group_by_columns).select(

0 commit comments

Comments
 (0)