Skip to content

Commit 062f9d4

Browse files
authored
Merge pull request #303 from ParagEkbote/fix-skorch-example
Fix Skorch Example
2 parents 5bb6de8 + 07431ff commit 062f9d4

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

pytorch/skorch_simple.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
"""
22
Optuna example that optimizes multi-layer perceptrons using skorch.
33
4-
In this example, we optimize the validation accuracy of hand-written digit recognition using
5-
skorch and MNIST. We optimize the neural network architecture. As it is too time
4+
In this example, we optimize the validation accuracy of hand-written digit recognition
5+
using skorch and MNIST. We optimize the neural network architecture. As it is too time
66
consuming to use the whole MNIST dataset, we here use a small subset of it.
77
8-
You can run this example as follows, pruning can be turned on and off with the `--pruning`
9-
argument.
8+
You can run this example as follows, pruning can be turned on and off with the
9+
`--pruning` argument.
1010
$ python skorch_simple.py [--pruning]
1111
1212
"""
1313

1414
import argparse
15-
import urllib
1615

1716
import numpy as np
1817
import optuna
1918
from optuna.integration import SkorchPruningCallback
19+
import pandas as pd
2020
import skorch
2121
import torch
2222
import torch.nn as nn
@@ -27,22 +27,15 @@
2727
from sklearn.model_selection import train_test_split
2828

2929

30-
# Register a global custom opener to avoid HTTP Error 403: Forbidden when downloading MNIST.
31-
# This is a temporary fix until torchvision v0.9 is released.
32-
opener = urllib.request.build_opener()
33-
opener.addheaders = [("User-agent", "Mozilla/5.0")]
34-
urllib.request.install_opener(opener)
35-
36-
3730
SUBSET_RATIO = 0.4
3831

3932
mnist = fetch_openml("mnist_784", cache=False)
4033

41-
X = mnist.data.astype("float32")
34+
X = pd.DataFrame(mnist.data)
4235
y = mnist.target.astype("int64")
4336
indices = np.random.permutation(len(X))
4437
N = int(len(X) * SUBSET_RATIO)
45-
X = X[indices][:N]
38+
X = X.iloc[indices][:N].astype(np.float32)
4639
y = y[indices][:N]
4740

4841
X /= 255.0
@@ -72,6 +65,8 @@ def __init__(self, trial: optuna.Trial) -> None:
7265
self.model = nn.Sequential(*layers)
7366

7467
def forward(self, x):
68+
if isinstance(x, dict):
69+
x = x["data"]
7570
return F.softmax(self.model(x), dim=-1)
7671

7772

@@ -84,9 +79,9 @@ def objective(trial: optuna.Trial) -> float:
8479
callbacks=[SkorchPruningCallback(trial, "valid_acc")],
8580
)
8681

87-
net.fit(X_train, y_train)
82+
net.fit(X_train.to_numpy().astype(np.float32), y_train)
8883

89-
return accuracy_score(y_test, net.predict(X_test))
84+
return accuracy_score(y_test.to_numpy(), net.predict(X_test.to_numpy().astype(np.float32)))
9085

9186

9287
if __name__ == "__main__":
@@ -110,8 +105,8 @@ def objective(trial: optuna.Trial) -> float:
110105
print("Best trial:")
111106
trial = study.best_trial
112107

113-
print(" Value: {}".format(trial.value))
108+
print("Value: {}".format(trial.value))
114109

115-
print(" Params: ")
110+
print("Params: ")
116111
for key, value in trial.params.items():
117112
print(" {}: {}".format(key, value))

0 commit comments

Comments
 (0)