Skip to content

Commit a5cea3e

Browse files
committed
enable download
1 parent e0165c5 commit a5cea3e

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

pytorch/skorch_simple.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,19 @@
1515

1616
import numpy as np
1717
import optuna
18-
from optuna.integration import SkorchPruningCallback
1918
import skorch
2019
import torch
2120
import torch.nn as nn
2221
import torch.nn.functional as F
23-
from torchvision import datasets
24-
22+
from optuna.integration import SkorchPruningCallback
2523
from sklearn.metrics import accuracy_score
2624
from sklearn.model_selection import train_test_split
27-
25+
from torchvision import datasets
2826

2927
SUBSET_RATIO = 0.4
3028

31-
mnist_train = datasets.MNIST(root="./data", train=True, download=False)
32-
mnist_test = datasets.MNIST(root="./data", train=False, download=False)
29+
mnist_train = datasets.MNIST(root="./data", train=True, download=True)
30+
mnist_test = datasets.MNIST(root="./data", train=False, download=True)
3331

3432
X = np.concatenate([mnist_train.data.numpy(), mnist_test.data.numpy()], axis=0)
3533
y = np.concatenate([mnist_train.targets.numpy(), mnist_test.targets.numpy()], axis=0)
@@ -41,7 +39,9 @@
4139
X = X[indices][:N]
4240
y = y[indices][:N]
4341

44-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
42+
X_train, X_test, y_train, y_test = train_test_split(
43+
X, y, test_size=0.25, random_state=42
44+
)
4545
device = "cuda" if torch.cuda.is_available() else "cpu"
4646

4747

@@ -96,7 +96,9 @@ def objective(trial: optuna.Trial) -> float:
9696
)
9797
args = parser.parse_args()
9898

99-
pruner = optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
99+
pruner = (
100+
optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
101+
)
100102

101103
study = optuna.create_study(direction="maximize", pruner=pruner)
102104
study.optimize(objective, n_trials=100, timeout=600)

0 commit comments

Comments
 (0)