Skip to content

Commit 2bd24d7

Browse files
Merge pull request #81 from UnravelSports/bug/format
Bug/format
2 parents d9b4fbe + 83ce0b5 commit 2bd24d7

File tree

8 files changed

+142
-181
lines changed

8 files changed

+142
-181
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ sphinx>=8.0.0,<9.0.0
33
sphinx-rtd-theme>=3.0.0
44
myst-parser>=2.0.0
55
sphinx-autosummary-accessors>=2023.4.0
6+
sphinxcontrib-youtube

docs/source/api/classifiers.rst

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,7 @@ PyTorch Geometric
4141
from torch_geometric.loader import DataLoader
4242
4343
# Initialize model
44-
model = PyGLightningCrystalGraphClassifier(
45-
node_features=12,
46-
edge_features=6,
47-
global_features=0,
48-
output_features=1,
49-
learning_rate=0.001,
50-
)
44+
model = PyGLightningCrystalGraphClassifier()
5145
5246
# Train
5347
trainer = pyl.Trainer(max_epochs=50)
@@ -56,29 +50,38 @@ PyTorch Geometric
5650
# Test
5751
trainer.test(model, test_loader)
5852
59-
# Predict
60-
predictions = trainer.predict(model, pred_loader)
61-
6253
Spektral
6354
~~~~~~~~
6455

6556
.. code-block:: python
6657
6758
from unravel.classifiers import CrystalGraphClassifier
6859
69-
# Initialize model
70-
model = CrystalGraphClassifier(
71-
node_features=12,
72-
edge_features=6,
73-
output_features=1,
74-
)
60+
from tensorflow.keras.metrics import AUC, BinaryAccuracy
61+
from tensorflow.keras.losses import BinaryCrossentropy
62+
from tensorflow.keras.optimizers import Adam
63+
from tensorflow.keras.callbacks import EarlyStopping
64+
65+
model = CrystalGraphClassifier()
7566
76-
# Compile
7767
model.compile(
78-
optimizer='adam',
79-
loss='binary_crossentropy',
80-
metrics=['accuracy']
68+
loss=BinaryCrossentropy(), optimizer=Adam(), metrics=[AUC(), BinaryAccuracy()]
8169
)
8270
83-
# Train
84-
model.fit(x=train_data, y=train_labels, epochs=50, validation_data=(val_data, val_labels))
71+
model.fit(
72+
loader_tr.load(),
73+
steps_per_epoch=loader_tr.steps_per_epoch,
74+
epochs=5,
75+
use_multiprocessing=True,
76+
validation_data=loader_va.load(),
77+
callbacks=[EarlyStopping(monitor="loss", patience=5, restore_best_weights=True)],
78+
)
79+
80+
from tensorflow.keras.models import load_model
81+
82+
model_path = "models/my-first-graph-classifier"
83+
model.save(model_path)
84+
loaded_model = load_model(model_path)
85+
86+
loader_te = DisjointLoader(test, epochs=1, shuffle=False, batch_size=batch_size)
87+
results = model.evaluate(loader_te.load())

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"sphinx.ext.intersphinx",
2828
"sphinx.ext.mathjax",
2929
"myst_parser",
30+
"sphinxcontrib.youtube",
3031
]
3132

3233
# Napoleon settings for Google/NumPy style docstrings

docs/source/getting_started/concepts.rst

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,11 @@ Labels
135135
For supervised learning, you need labels for each graph:
136136

137137
.. code-block:: python
138-
139-
from unravel.utils import add_dummy_label_column
140-
141138
# Add random binary labels (for demonstration)
142-
dataset.dataset = add_dummy_label_column(dataset.dataset)
139+
dataset.add_dummy_labels()
143140
144141
# Or join real labels from your own data
145-
# dataset.dataset = dataset.dataset.join(your_labels, on="some_key")
142+
dataset.dataset = dataset.dataset.join(your_labels, on="some_key")
146143
147144
Graph IDs
148145
~~~~~~~~~
@@ -153,13 +150,10 @@ Graph IDs group frames that belong to the same "sample":
153150
154151
from unravel.utils import add_graph_id_column
155152
156-
# Each frame is a separate graph
157-
dataset.dataset = add_graph_id_column(dataset.dataset, by=["frame_id"])
158-
159-
# Or group by possession
160-
dataset.dataset = add_graph_id_column(dataset.dataset, by=["possession_id"])
153+
# Each frame (graph) from the same game belongs to a subset
154+
dataset.add_graph_ids(by=["game_id"])
161155
162-
**Important**: Always split data by graph_id to avoid data leakage!
156+
**Important**: Always split data by game_id to avoid data leakage!
163157

164158
Soccer Analytics Models
165159
-----------------------

docs/source/getting_started/quickstart.rst

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ Convert the tracking data to graph structures for training Graph Neural Networks
5757
from unravel.utils import add_dummy_label_column, add_graph_id_column
5858
5959
# Add labels and graph IDs
60-
polars_dataset.dataset = add_dummy_label_column(polars_dataset.dataset)
61-
polars_dataset.dataset = add_graph_id_column(
62-
polars_dataset.dataset,
60+
polars_dataset.add_dummy_labels()
61+
# We group by 'frame_id' instead of 'game_id' here because in this example all
62+
# data comes from the same game.
63+
polars_dataset.add_graph_id_column(
6364
by=["frame_id"]
6465
)
6566
@@ -97,11 +98,7 @@ Split the data and train a model:
9798
test_loader = DataLoader(test, batch_size=32)
9899
99100
# Initialize model
100-
model = PyGLightningCrystalGraphClassifier(
101-
node_features=converter.n_node_features,
102-
edge_features=converter.n_edge_features,
103-
global_features=converter.n_graph_features,
104-
)
101+
model = PyGLightningCrystalGraphClassifier()
105102
106103
# Train
107104
trainer = pyl.Trainer(max_epochs=10)

docs/source/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ The **unravelsports** package aims to aid researchers, analysts and enthusiasts
2323
intermediary steps in the complex process of converting raw sports data into meaningful
2424
information and actionable insights.
2525

26+
.. youtube:: PUXU3SokbW0
27+
2628
Installation
2729
------------
2830

docs/source/tutorials/american_football.rst

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
American Football (NFL)
1+
American Football
22
=======================
33

44
This tutorial covers how to work with NFL tracking data from the Big Data Bowl using the
@@ -93,16 +93,10 @@ For supervised learning, add labels and graph IDs:
9393

9494
.. code-block:: python
9595
96-
from unravel.utils import add_dummy_label_column, add_graph_id_column
97-
98-
# Add labels (use your own labels for real tasks)
99-
bdb_dataset.dataset = add_dummy_label_column(bdb_dataset.dataset)
96+
bdb_dataset.add_dummy_labels()
10097
10198
# Create graph ID for each play
102-
bdb_dataset.dataset = add_graph_id_column(
103-
bdb_dataset.dataset,
104-
by=["gameId", "playId"]
105-
)
99+
bdb_dataset.add_graph_ids(by=["playId", "gameId"])
106100
107101
Step 3: Convert to Graphs
108102
~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -146,11 +140,7 @@ Train a Graph Neural Network:
146140
test_loader = DataLoader(test, batch_size=32)
147141
148142
# Initialize and train model
149-
model = PyGLightningCrystalGraphClassifier(
150-
node_features=converter.n_node_features,
151-
edge_features=converter.n_edge_features,
152-
global_features=converter.n_graph_features,
153-
)
143+
model = PyGLightningCrystalGraphClassifier()
154144
155145
trainer = pyl.Trainer(max_epochs=10)
156146
trainer.fit(model, train_loader, val_loader)

0 commit comments

Comments
 (0)