Skip to content

Commit 1458b61

Browse files
committed
run pre-commit
1 parent a5cea3e commit 1458b61

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

pytorch/skorch_simple.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515

1616
import numpy as np
1717
import optuna
18+
from optuna.integration import SkorchPruningCallback
1819
import skorch
1920
import torch
2021
import torch.nn as nn
2122
import torch.nn.functional as F
22-
from optuna.integration import SkorchPruningCallback
23+
from torchvision import datasets
24+
2325
from sklearn.metrics import accuracy_score
2426
from sklearn.model_selection import train_test_split
25-
from torchvision import datasets
27+
2628

2729
SUBSET_RATIO = 0.4
2830

@@ -39,9 +41,7 @@
3941
X = X[indices][:N]
4042
y = y[indices][:N]
4143

42-
X_train, X_test, y_train, y_test = train_test_split(
43-
X, y, test_size=0.25, random_state=42
44-
)
44+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
4545
device = "cuda" if torch.cuda.is_available() else "cpu"
4646

4747

@@ -96,9 +96,7 @@ def objective(trial: optuna.Trial) -> float:
9696
)
9797
args = parser.parse_args()
9898

99-
pruner = (
100-
optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
101-
)
99+
pruner = optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
102100

103101
study = optuna.create_study(direction="maximize", pruner=pruner)
104102
study.optimize(objective, n_trials=100, timeout=600)

0 commit comments

Comments
 (0)