|
16 | 16 | import numpy as np |
17 | 17 | import optuna |
18 | 18 | from optuna.integration import SkorchPruningCallback |
19 | | -import pandas as pd |
20 | 19 | import skorch |
21 | 20 | import torch |
22 | 21 | import torch.nn as nn |
23 | 22 | import torch.nn.functional as F |
| 23 | +from torchvision import datasets |
24 | 24 |
|
25 | | -from sklearn.datasets import fetch_openml |
26 | 25 | from sklearn.metrics import accuracy_score |
27 | 26 | from sklearn.model_selection import train_test_split |
28 | 27 |
|
29 | 28 |
|
30 | 29 | SUBSET_RATIO = 0.4 |
31 | 30 |
|
32 | | -mnist = fetch_openml("mnist_784", cache=False) |
| 31 | +mnist_train = datasets.MNIST(root="./data", train=True, download=True) |
| 32 | +mnist_test = datasets.MNIST(root="./data", train=False, download=True) |
| 33 | + |
| 34 | +X = np.concatenate([mnist_train.data.numpy(), mnist_test.data.numpy()], axis=0) |
| 35 | +y = np.concatenate([mnist_train.targets.numpy(), mnist_test.targets.numpy()], axis=0) |
| 36 | + |
| 37 | +X = X.reshape(X.shape[0], -1).astype(np.float32) / 255.0 |
33 | 38 |
|
34 | | -X = pd.DataFrame(mnist.data) |
35 | | -y = mnist.target.astype("int64") |
36 | 39 | indices = np.random.permutation(len(X)) |
37 | 40 | N = int(len(X) * SUBSET_RATIO) |
38 | | -X = X.iloc[indices][:N].astype(np.float32) |
| 41 | +X = X[indices][:N] |
39 | 42 | y = y[indices][:N] |
40 | 43 |
|
41 | | -X /= 255.0 |
42 | | - |
43 | 44 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42) |
44 | 45 | device = "cuda" if torch.cuda.is_available() else "cpu" |
45 | 46 |
|
@@ -79,9 +80,9 @@ def objective(trial: optuna.Trial) -> float: |
79 | 80 | callbacks=[SkorchPruningCallback(trial, "valid_acc")], |
80 | 81 | ) |
81 | 82 |
|
82 | | - net.fit(X_train.to_numpy().astype(np.float32), y_train) |
| 83 | + net.fit(X_train, y_train) |
83 | 84 |
|
84 | | - return accuracy_score(y_test.to_numpy(), net.predict(X_test.to_numpy().astype(np.float32))) |
| 85 | + return accuracy_score(y_test, net.predict(X_test)) |
85 | 86 |
|
86 | 87 |
|
87 | 88 | if __name__ == "__main__": |
|
0 commit comments