
Commit 705e543: Adjustments for pre-commit hooks
1 parent 225f1a6

4 files changed: +96 additions, -98 deletions


examples/asha_example/experiment_1.yaml

1 addition, 1 deletion

@@ -21,7 +21,7 @@ fixed:
   num_stages: 10
 asha:
   eta: 3
-  min_r: 1
+  min_r: 1
   max_r: 20
   metric_increases: True #True or False

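(The min_r line above is deleted and re-added with identical visible content; given the commit message, this is presumably a whitespace-only fix applied by the pre-commit hooks.)

For orientation, the eta, min_r, and max_r values in this config set ASHA's rung schedule: a trial is evaluated at resource levels min_r, min_r*eta, min_r*eta^2, and so on, and only roughly the top 1/eta of trials at each rung keep running. A minimal sketch of the schedule these values imply, assuming the standard ASHA rung rule (an illustrative helper, not code from this repo; whether max_r itself counts as a rung depends on the implementation):

def asha_rungs(eta, min_r, max_r):
    # Rung resource levels: min_r * eta**k, capped at max_r.
    rungs = []
    r = min_r
    while r < max_r:
        rungs.append(r)
        r *= eta
    rungs.append(max_r)  # final full-budget evaluation
    return rungs

print(asha_rungs(eta=3, min_r=1, max_r=20))  # -> [1, 3, 9, 20]

So with this file's settings a trial can be culled after stages 1, 3, and 9, and otherwise runs to the full budget of 20.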
examples/asha_example/main.py

40 additions, 37 deletions

@@ -1,22 +1,14 @@
-import os
-import time
-from datetime import datetime
-import random
-import torch
-import numpy as np
 import uuid
-import hashlib
-from torch import nn, optim
-from torchvision import datasets, transforms
-from torch.utils.data import DataLoader, random_split
 
-from seml import Experiment
-from seml.settings import SETTINGS
+import torch
 from model import SimpleNN
-
-from seml.utils import ASHA # Import asha class
-
+from seml import Experiment
 from seml.database import get_mongodb_config
+from seml.utils import ASHA  # Import asha class
+from torch import nn, optim
+from torch.utils.data import DataLoader, random_split
+from torchvision import datasets, transforms
+
 # def seed_everything(job_id):
 #     # Combine job_id and entropy to make a unique seed
 #     entropy = f"{job_id}-{time.time()}-{os.urandom(8)}".encode("utf-8")

@@ -31,10 +23,11 @@
 
 experiment = Experiment()
 
+
 @experiment.config
 def default_config():
     num_stages = 10
-    dataset = 'mnist'
+    dataset = "mnist"
     hidden_units = [64]
     dropout = 0.3
     learning_rate = 1e-3

@@ -43,23 +36,30 @@ def default_config():
     seed = 42
     asha_collection_name = "unknown_experiment"
     samples = 5
-    asha = {
-        "eta" : 3,
-        "min_r": 1,
-        "max_resource": 20,
-        "progression": "increase"
-    }
+    asha = {"eta": 3, "min_r": 1, "max_resource": 20, "progression": "increase"}
 
 
 @experiment.automain
-def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_dir, job_id, asha,_log, _run,):
-
+def main(
+    num_stages,
+    dataset,
+    hidden_units,
+    dropout,
+    learning_rate,
+    base_shared_dir,
+    job_id,
+    asha,
+    _log,
+    _run,
+):
     mongodb_configurations = get_mongodb_config()
 
-    print(f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}")
+    print(
+        f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}"
+    )
     if job_id is None:
         job_id = str(uuid.uuid4())
-        #job_id = str(_run._id)
+        # job_id = str(_run._id)
 
     asha_collection_name = _run.config.get("asha_collection_name", "unknown_experiment")
     print("Run info:", _run.experiment_info)

@@ -70,11 +70,12 @@ def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_
     model.to(device)
 
     # Prepare dataset and loaders
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-    ])
-    full_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    full_dataset = datasets.MNIST(
+        root="./data", train=True, download=True, transform=transform
+    )
     train_size = int(0.8 * len(full_dataset))
     val_size = len(full_dataset) - train_size
     train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

@@ -126,16 +127,19 @@ def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_
         metric = correct / total
         print(f"[Epoch {stage}] Validation Accuracy: {metric:.4f}")
 
-        if stage < (num_stages-1):
+        if stage < (num_stages - 1):
             should_stop = tracker.store_stage_metric(stage, metric)
             if should_stop:
                 print("We should end this process here")
-                print(f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}")
+                print(
+                    f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}"
+                )
                 break
         else:
             print("job finished")
-            print(f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}")
-
+            print(
+                f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}"
+            )
 
     return {
         "asha_collection_name": asha_collection_name,

@@ -148,6 +152,5 @@ def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_
         "num_stages": num_stages,
         "dataset": dataset,
         "device": str(device),
-        "final_stage": len(tracker.metric_history)
+        "final_stage": len(tracker.metric_history),
     }
-

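The training loop above calls tracker.store_stage_metric(stage, metric) after each validation pass and breaks out early when it returns True; the construction of tracker from the imported ASHA class happens in a part of the file this page does not show. The class evidently coordinates concurrent trials through the MongoDB collection named by asha_collection_name (hence get_mongodb_config). The sketch below is only a self-contained illustration of the promotion rule such a tracker applies; the class name and internals here are hypothetical, not seml's actual implementation:

# Hypothetical stand-in for seml.utils.ASHA, for illustration only.
# In the real setup the per-rung results live in MongoDB so that separate
# jobs can rank themselves against each other; here they sit in a dict.
class AshaTrackerSketch:
    def __init__(self, eta=3, min_r=1, max_r=20, metric_increases=True):
        self.eta = eta
        self.metric_increases = metric_increases
        rungs = []
        r = min_r
        while r < max_r:  # rung resource levels: min_r * eta**k
            rungs.append(r)
            r *= eta
        self.metric_history = []  # this trial's metric, one entry per stage
        self.rung_results = {rung: [] for rung in rungs}  # all trials' metrics

    def store_stage_metric(self, stage, metric):
        """Record a stage metric; return True if this trial should stop."""
        self.metric_history.append(metric)
        resource = stage + 1  # stages are 0-indexed in main.py
        if resource not in self.rung_results:
            return False  # not at a rung boundary, keep training
        results = self.rung_results[resource]
        results.append(metric)
        # Keep only (roughly) the top 1/eta of trials seen at this rung.
        ranked = sorted(results, reverse=self.metric_increases)
        cutoff = ranked[max(len(ranked) // self.eta - 1, 0)]
        return metric < cutoff if self.metric_increases else metric > cutoff

One inconsistency the commit leaves in place: default_config() spells the keys min_r and max_resource, while experiment_1.yaml uses min_r and max_r, so which names the ASHA class actually expects would need to be checked against its constructor.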
examples/asha_example/model.py

1 addition, 14 deletions

@@ -1,21 +1,8 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-# Define your custom PyTorch model
-# class SimpleNN(nn.Module):
-#     def __init__(self, hidden_units=256):
-#         super().__init__()
-#         self.fc1 = nn.Linear(28 * 28, hidden_units)
-#         self.fc2 = nn.Linear(hidden_units, 64)
-#         self.fc3 = nn.Linear(64, 10)
 
-#     def forward(self, x):
-#         x = torch.flatten(x, 1) # flatten all dimensions except batch
-#         x = F.relu(self.fc1(x))
-#         x = F.relu(self.fc2(x))
-#         x = self.fc3(x)
-#         return x
 
+# Define your custom PyTorch model
 class SimpleNN(nn.Module):
     def __init__(self, hidden_units=[128, 64], dropout=0.0):
         super().__init__()

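The hunk cuts off right after super().__init__(). The new signature (a hidden_units list plus a dropout rate, replacing the commented-out fixed two-layer net) suggests an MLP assembled layer by layer from hidden_units. A plausible sketch of such a model, assuming 28*28 MNIST inputs and 10 output classes as in the deleted comment block (the file's actual body is simply not shown on this page):

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, hidden_units=[128, 64], dropout=0.0):
        super().__init__()
        layers = []
        in_features = 28 * 28  # flattened MNIST image
        for units in hidden_units:
            layers += [nn.Linear(in_features, units), nn.ReLU(), nn.Dropout(dropout)]
            in_features = units
        layers.append(nn.Linear(in_features, 10))  # 10 MNIST classes
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        return self.net(x)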