Skip to content

Commit 7c8fc7e

Browse files
Merge pull request #6 from UnravelSports/feat/kloppy-polars
⚽ Polars implementation
2 parents 663a024 + 77ab8c2 commit 7c8fc7e

31 files changed

+2934
-323
lines changed

examples/1_kloppy_gnn_train.ipynb

Lines changed: 94 additions & 48 deletions
Large diffs are not rendered by default.

examples/2_big_data_bowl_guide.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@
218218
"name": "python",
219219
"nbconvert_exporter": "python",
220220
"pygments_lexer": "ipython3",
221-
"version": "3.11.9"
221+
"version": "3.11.11"
222222
}
223223
},
224224
"nbformat": 4,

examples/deprecated/1_kloppy_gnn_train.ipynb

Lines changed: 794 additions & 0 deletions
Large diffs are not rendered by default.

examples/graphs_faq.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ In section 6.1 we can see what this looks like in Python.
5151
| `max_ball_acceleration` | float | The maximum speed of the ball in yards per second squared. Used for normalizing node features. | 10.0 | 🏈 |
5252
| `attacking_non_qb_node_value` | float | Value for the node feature when player is NOT the QB, but is on the attacking team | 0.1 | 🏈 |
5353
| `chunk_size` | int | Set to determine size of conversions from Polars to Graphs. Preferred setting depends on available computing power | 2_000 | 🏈 |
54-
| `ball_carrier_threshold` | float | The distance threshold to determine the ball carrier in meters. If no ball carrier within ball_carrier_threshold, we skip the frame. | 25.0 ||
55-
| `boundary_correction` | float | A correction factor for boundary calculations, used to correct out of bounds as a percentage (Used as 1+boundary_correction, i.e., 0.05). Not setting this might lead to players outside the pitch markings to have values that fall slightly outside of our normalization range. When we set boundary_correction, any players outside the pitch will be moved to be on the closest line. | None ||
56-
| `infer_ball_ownership` | bool | Infers 'attacking_team' if no 'ball_owning_team' exist (in Kloppy TrackingDataset) by finding the player closest to the ball using ball xyz, uses 'ball_carrier_threshold' as a cut-off. | True ||
57-
| `infer_goalkeepers` | bool | Set True if no GK label is provided, set False for incomplete (broadcast tracking) data that might not have a GK in every frame. | True ||
5854
| `non_potential_receiver_node_value` | float | Value for the node feature when player is NOT a potential receiver of a pass (when on opposing team or in possession of the ball). Should be between 0 and 1 including. | 0.1 ||
5955

6056

@@ -64,7 +60,7 @@ In section 6.1 we can see what this looks like in Python.
6460
#### C. What features does each Graph have?
6561

6662
<details>
67-
<summary> <b><i> 🌀 ⚽ Expand for a full list of Soccer features </b></i></summary>
63+
<summary> <b><i> 🌀 ⚽ Expand for a full list of Soccer features (note: `SoccerGraphConverter`, `SoccerGraphConverterPolars` has slightly different features) </b></i></summary>
6864

6965
| Variable | Datatype | Index | Features |
7066
|----------|-----------------------------------|-------|---------------------------------------------------------------------------------------------------------------------------------|

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
numpy==1.26.4
22
spektral==1.2.0
3-
kloppy==3.15.0
3+
kloppy==3.16.0
44
tensorflow>=2.14.0; platform_machine != 'arm64' or platform_system != 'Darwin'
55
tensorflow-macos>=2.14.0; platform_machine == 'arm64' and platform_system == 'Darwin'
66
keras==2.14.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def read_version():
3232
python_requires="~=3.11",
3333
install_requires=[
3434
"spektral==1.2.0",
35-
"kloppy==3.15.0",
35+
"kloppy==3.16.0",
3636
"tensorflow>=2.14.0;platform_machine != 'arm64' or platform_system != 'Darwin'",
3737
"tensorflow-macos>=2.14.0;platform_machine == 'arm64' and platform_system == 'Darwin'",
3838
"keras==2.14.0",

tests/test_bigdb.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
AmericanFootballGraphConverter,
2121
AmericanFootballPitchDimensions,
2222
)
23+
from unravel.american_football.graphs.dataset import Constant
2324
from unravel.utils import (
24-
add_graph_id_column,
25-
add_dummy_label_column,
2625
flatten_to_reshaped_array,
2726
make_sparse,
2827
CustomSpektralDataset,
@@ -53,10 +52,8 @@ def dataset(self, coordinates: str, players: str, plays: str):
5352
plays_file_path=plays,
5453
)
5554
bdb_dataset.load()
56-
bdb_dataset.add_graph_ids(by=["gameId", "playId"], column_name="graph_id")
57-
bdb_dataset.add_dummy_labels(
58-
by=["gameId", "playId", "frameId"], column_name="label"
59-
)
55+
bdb_dataset.add_graph_ids(by=["gameId", "playId"])
56+
bdb_dataset.add_dummy_labels(by=["gameId", "playId", "frameId"])
6057
return bdb_dataset
6158

6259
@pytest.fixture
@@ -141,8 +138,6 @@ def node_feature_values(self):
141138
@pytest.fixture
142139
def arguments(self):
143140
return dict(
144-
label_col="label",
145-
graph_id_col="graph_id",
146141
max_player_speed=8.0,
147142
max_ball_speed=28.0,
148143
max_player_acceleration=10.0,
@@ -161,8 +156,6 @@ def arguments(self):
161156
@pytest.fixture
162157
def non_default_arguments(self):
163158
return dict(
164-
label_col="label",
165-
graph_id_col="graph_id",
166159
max_player_speed=12.0,
167160
max_ball_speed=24.0,
168161
max_player_acceleration=11.0,
@@ -199,8 +192,8 @@ def test_settings(self, gnnc_non_default, non_default_arguments):
199192
assert settings.pitch_dimensions.y_dim.min == -26.65
200193
assert settings.pitch_dimensions.end_zone == 50.0
201194

202-
assert settings.ball_id == "football"
203-
assert settings.qb_id == "QB"
195+
assert Constant.BALL == "football"
196+
assert Constant.QB == "QB"
204197
assert settings.max_height == 225.0
205198
assert settings.min_height == 150.0
206199
assert settings.max_weight == 200.0

tests/test_kloppy_polars.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
from pathlib import Path
2+
from unravel.soccer import SoccerGraphConverterPolars, KloppyPolarsDataset
3+
from unravel.utils import (
4+
dummy_labels,
5+
dummy_graph_ids,
6+
CustomSpektralDataset,
7+
)
8+
9+
from kloppy import skillcorner
10+
from kloppy.domain import Ground, TrackingDataset, Orientation
11+
from typing import List, Dict
12+
13+
from spektral.data import Graph
14+
15+
import pytest
16+
17+
import numpy as np
18+
19+
20+
class TestKloppyPolarsData:
21+
@pytest.fixture
22+
def match_data(self, base_dir: Path) -> str:
23+
return base_dir / "files" / "skillcorner_match_data.json"
24+
25+
@pytest.fixture
26+
def structured_data(self, base_dir: Path) -> str:
27+
return base_dir / "files" / "skillcorner_structured_data.json.gz"
28+
29+
@pytest.fixture()
30+
def kloppy_dataset(self, match_data: str, structured_data: str) -> TrackingDataset:
31+
return skillcorner.load(
32+
raw_data=structured_data,
33+
meta_data=match_data,
34+
coordinates="tracab",
35+
include_empty_frames=False,
36+
limit=500,
37+
)
38+
39+
@pytest.fixture()
40+
def kloppy_polars_dataset(
41+
self, kloppy_dataset: TrackingDataset
42+
) -> KloppyPolarsDataset:
43+
dataset = KloppyPolarsDataset(
44+
kloppy_dataset=kloppy_dataset,
45+
ball_carrier_threshold=25.0,
46+
)
47+
dataset.load()
48+
dataset.add_dummy_labels(by=["game_id", "frame_id"])
49+
dataset.add_graph_ids(by=["game_id", "frame_id"])
50+
return dataset
51+
52+
@pytest.fixture()
53+
def spc_padding(
54+
self, kloppy_polars_dataset: KloppyPolarsDataset
55+
) -> SoccerGraphConverterPolars:
56+
return SoccerGraphConverterPolars(
57+
dataset=kloppy_polars_dataset,
58+
chunk_size=2_0000,
59+
non_potential_receiver_node_value=0.1,
60+
max_player_speed=12.0,
61+
max_player_acceleration=12.0,
62+
max_ball_speed=13.5,
63+
max_ball_acceleration=100,
64+
self_loop_ball=True,
65+
adjacency_matrix_connect_type="ball",
66+
adjacency_matrix_type="split_by_team",
67+
label_type="binary",
68+
defending_team_node_value=0.0,
69+
random_seed=False,
70+
pad=True,
71+
verbose=False,
72+
)
73+
74+
@pytest.fixture()
75+
def soccer_polars_converter(
76+
self, kloppy_polars_dataset: KloppyPolarsDataset
77+
) -> SoccerGraphConverterPolars:
78+
79+
return SoccerGraphConverterPolars(
80+
dataset=kloppy_polars_dataset,
81+
chunk_size=2_0000,
82+
non_potential_receiver_node_value=0.1,
83+
max_player_speed=12.0,
84+
max_player_acceleration=12.0,
85+
max_ball_speed=13.5,
86+
max_ball_acceleration=100,
87+
self_loop_ball=True,
88+
adjacency_matrix_connect_type="ball",
89+
adjacency_matrix_type="split_by_team",
90+
label_type="binary",
91+
defending_team_node_value=0.0,
92+
random_seed=False,
93+
pad=False,
94+
verbose=False,
95+
)
96+
97+
def test_padding(self, spc_padding: SoccerGraphConverterPolars):
98+
spektral_graphs = spc_padding.to_spektral_graphs()
99+
100+
assert 1 == 1
101+
102+
data = spektral_graphs
103+
assert len(data) == 384
104+
assert isinstance(data[0], Graph)
105+
106+
def test_to_spektral_graph(
107+
self, soccer_polars_converter: SoccerGraphConverterPolars
108+
):
109+
"""
110+
Test navigating (next/prev) through events
111+
"""
112+
spektral_graphs = soccer_polars_converter.to_spektral_graphs()
113+
114+
assert 1 == 1
115+
116+
data = spektral_graphs
117+
assert data[0].id == "2417-1529"
118+
assert len(data) == 489
119+
assert isinstance(data[0], Graph)
120+
121+
x = data[0].x
122+
n_players = x.shape[0]
123+
assert x.shape == (n_players, 15)
124+
assert 0.4524340998288571 == pytest.approx(x[0, 0], abs=1e-5)
125+
assert 0.9948105277764999 == pytest.approx(x[0, 4], abs=1e-5)
126+
assert 0.2941671698429814 == pytest.approx(x[8, 2], abs=1e-5)
127+
128+
e = data[0].e
129+
assert e.shape == (129, 6)
130+
assert 0.0 == pytest.approx(e[0, 0], abs=1e-5)
131+
assert 0.5 == pytest.approx(e[0, 4], abs=1e-5)
132+
assert 0.7140882876637022 == pytest.approx(e[8, 2], abs=1e-5)
133+
134+
a = data[0].a
135+
assert a.shape == (n_players, n_players)
136+
assert 1.0 == pytest.approx(a[0, 0], abs=1e-5)
137+
assert 1.0 == pytest.approx(a[0, 4], abs=1e-5)
138+
assert 0.0 == pytest.approx(a[8, 2], abs=1e-5)
139+
140+
dataset = CustomSpektralDataset(graphs=spektral_graphs)
141+
N, F, S, n_out, n = dataset.dimensions()
142+
assert N == 20
143+
assert F == 15
144+
assert S == 6
145+
assert n_out == 1
146+
assert n == 489
147+
148+
train, test, val = dataset.split_test_train_validation(
149+
split_train=4,
150+
split_test=1,
151+
split_validation=1,
152+
by_graph_id=True,
153+
random_seed=42,
154+
)
155+
assert train.n_graphs == 326
156+
assert test.n_graphs == 81
157+
assert val.n_graphs == 82
158+
159+
train, test, val = dataset.split_test_train_validation(
160+
split_train=4,
161+
split_test=1,
162+
split_validation=1,
163+
by_graph_id=False,
164+
random_seed=42,
165+
)
166+
assert train.n_graphs == 326
167+
assert test.n_graphs == 81
168+
assert val.n_graphs == 82
169+
170+
train, test = dataset.split_test_train(
171+
split_train=4, split_test=1, by_graph_id=False, random_seed=42
172+
)
173+
assert train.n_graphs == 391
174+
assert test.n_graphs == 98
175+
176+
train, test = dataset.split_test_train(
177+
split_train=4, split_test=5, by_graph_id=False, random_seed=42
178+
)
179+
assert train.n_graphs == 217
180+
assert test.n_graphs == 272
181+
182+
with pytest.raises(
183+
NotImplementedError,
184+
match="Make sure split_train > split_test >= split_validation, other behaviour is not supported when by_graph_id is True...",
185+
):
186+
dataset.split_test_train(
187+
split_train=4, split_test=5, by_graph_id=True, random_seed=42
188+
)

tests/test_spektral.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,8 @@ def bdb_dataset(self, coordinates: str, players: str, plays: str):
4545
plays_file_path=plays,
4646
)
4747
bdb_dataset.load()
48-
bdb_dataset.add_graph_ids(by=["gameId", "playId"], column_name="graph_id")
49-
bdb_dataset.add_dummy_labels(
50-
by=["gameId", "playId", "frameId"], column_name="label"
51-
)
48+
bdb_dataset.add_graph_ids(by=["gameId", "playId"])
49+
bdb_dataset.add_dummy_labels(by=["gameId", "playId", "frameId"])
5250
return bdb_dataset
5351

5452
@pytest.fixture
@@ -122,8 +120,6 @@ def bdb_converter(
122120
) -> AmericanFootballGraphConverter:
123121
return AmericanFootballGraphConverter(
124122
dataset=bdb_dataset,
125-
label_col="label",
126-
graph_id_col="graph_id",
127123
max_player_speed=8.0,
128124
max_ball_speed=28.0,
129125
max_player_acceleration=10.0,

unravel/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.2.0"
1+
__version__ = "0.3.0"
22

33
from .soccer import *
44
from .american_football import *

0 commit comments

Comments
 (0)