Skip to content

Commit 55bbf69

Browse files
authored
Merge pull request #338 from sotagg/fix/skorch-mnist
Fix skorch example: Replace unavailable OpenML MNIST
2 parents f192f3c + 1458b61 commit 55bbf69

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

.github/workflows/pytorch.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ jobs:
7979
env:
8080
OMP_NUM_THREADS: 1
8181
- name: Run skorch example
82-
# TODO(c-bata): Remove the following if statement after fixing the example.
83-
# See https://github.com/optuna/optuna-examples/issues/336 for details.
84-
if: github.event_name != 'schedule'
8582
run: |
8683
python pytorch/skorch_simple.py
8784
env:

pytorch/skorch_simple.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,31 @@
1616
import numpy as np
1717
import optuna
1818
from optuna.integration import SkorchPruningCallback
19-
import pandas as pd
2019
import skorch
2120
import torch
2221
import torch.nn as nn
2322
import torch.nn.functional as F
23+
from torchvision import datasets
2424

25-
from sklearn.datasets import fetch_openml
2625
from sklearn.metrics import accuracy_score
2726
from sklearn.model_selection import train_test_split
2827

2928

3029
SUBSET_RATIO = 0.4
3130

32-
mnist = fetch_openml("mnist_784", cache=False)
31+
mnist_train = datasets.MNIST(root="./data", train=True, download=True)
32+
mnist_test = datasets.MNIST(root="./data", train=False, download=True)
33+
34+
X = np.concatenate([mnist_train.data.numpy(), mnist_test.data.numpy()], axis=0)
35+
y = np.concatenate([mnist_train.targets.numpy(), mnist_test.targets.numpy()], axis=0)
36+
37+
X = X.reshape(X.shape[0], -1).astype(np.float32) / 255.0
3338

34-
X = pd.DataFrame(mnist.data)
35-
y = mnist.target.astype("int64")
3639
indices = np.random.permutation(len(X))
3740
N = int(len(X) * SUBSET_RATIO)
38-
X = X.iloc[indices][:N].astype(np.float32)
41+
X = X[indices][:N]
3942
y = y[indices][:N]
4043

41-
X /= 255.0
42-
4344
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
4445
device = "cuda" if torch.cuda.is_available() else "cpu"
4546

@@ -79,9 +80,9 @@ def objective(trial: optuna.Trial) -> float:
7980
callbacks=[SkorchPruningCallback(trial, "valid_acc")],
8081
)
8182

82-
net.fit(X_train.to_numpy().astype(np.float32), y_train)
83+
net.fit(X_train, y_train)
8384

84-
return accuracy_score(y_test.to_numpy(), net.predict(X_test.to_numpy().astype(np.float32)))
85+
return accuracy_score(y_test, net.predict(X_test))
8586

8687

8788
if __name__ == "__main__":

0 commit comments

Comments
 (0)