|
15 | 15 |
|
16 | 16 | import numpy as np |
17 | 17 | import optuna |
18 | | -from optuna.integration import SkorchPruningCallback |
19 | 18 | import skorch |
20 | 19 | import torch |
21 | 20 | import torch.nn as nn |
22 | 21 | import torch.nn.functional as F |
23 | | -from torchvision import datasets |
24 | | - |
| 22 | +from optuna.integration import SkorchPruningCallback |
25 | 23 | from sklearn.metrics import accuracy_score |
26 | 24 | from sklearn.model_selection import train_test_split |
27 | | - |
| 25 | +from torchvision import datasets |
28 | 26 |
|
29 | 27 | SUBSET_RATIO = 0.4 |
30 | 28 |
|
31 | | -mnist_train = datasets.MNIST(root="./data", train=True, download=False) |
32 | | -mnist_test = datasets.MNIST(root="./data", train=False, download=False) |
| 29 | +mnist_train = datasets.MNIST(root="./data", train=True, download=True) |
| 30 | +mnist_test = datasets.MNIST(root="./data", train=False, download=True) |
33 | 31 |
|
34 | 32 | X = np.concatenate([mnist_train.data.numpy(), mnist_test.data.numpy()], axis=0) |
35 | 33 | y = np.concatenate([mnist_train.targets.numpy(), mnist_test.targets.numpy()], axis=0) |
|
41 | 39 | X = X[indices][:N] |
42 | 40 | y = y[indices][:N] |
43 | 41 |
|
44 | | -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42) |
| 42 | +X_train, X_test, y_train, y_test = train_test_split( |
| 43 | + X, y, test_size=0.25, random_state=42 |
| 44 | +) |
45 | 45 | device = "cuda" if torch.cuda.is_available() else "cpu" |
46 | 46 |
|
47 | 47 |
|
@@ -96,7 +96,9 @@ def objective(trial: optuna.Trial) -> float: |
96 | 96 | ) |
97 | 97 | args = parser.parse_args() |
98 | 98 |
|
99 | | - pruner = optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner() |
| 99 | + pruner = ( |
| 100 | + optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner() |
| 101 | + ) |
100 | 102 |
|
101 | 103 | study = optuna.create_study(direction="maximize", pruner=pruner) |
102 | 104 | study.optimize(objective, n_trials=100, timeout=600) |
|
0 commit comments