@@ -1,22 +1,14 @@
-import os
-import time
-from datetime import datetime
-import random
-import torch
-import numpy as np
 import uuid
-import hashlib
-from torch import nn, optim
-from torchvision import datasets, transforms
-from torch.utils.data import DataLoader, random_split
 
-from seml import Experiment
-from seml.settings import SETTINGS
+import torch
 from model import SimpleNN
-
-from seml.utils import ASHA  # Import asha class
-
+from seml import Experiment
 from seml.database import get_mongodb_config
+from seml.utils import ASHA  # Import asha class
+from torch import nn, optim
+from torch.utils.data import DataLoader, random_split
+from torchvision import datasets, transforms
+
 # def seed_everything(job_id):
 #     # Combine job_id and entropy to make a unique seed
 #     entropy = f"{job_id}-{time.time()}-{os.urandom(8)}".encode("utf-8")
@@ -31,10 +23,11 @@
 
 experiment = Experiment()
 
+
 @experiment.config
 def default_config():
     num_stages = 10
-    dataset = 'mnist'
+    dataset = "mnist"
     hidden_units = [64]
     dropout = 0.3
     learning_rate = 1e-3
@@ -43,23 +36,30 @@ def default_config():
     seed = 42
     asha_collection_name = "unknown_experiment"
     samples = 5
-    asha = {
-        "eta": 3,
-        "min_r": 1,
-        "max_resource": 20,
-        "progression": "increase"
-    }
+    asha = {"eta": 3, "min_r": 1, "max_resource": 20, "progression": "increase"}
 
 
 @experiment.automain
-def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_dir, job_id, asha, _log, _run,):
-
+def main(
+    num_stages,
+    dataset,
+    hidden_units,
+    dropout,
+    learning_rate,
+    base_shared_dir,
+    job_id,
+    asha,
+    _log,
+    _run,
+):
     mongodb_configurations = get_mongodb_config()
 
-    print(f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}")
+    print(
+        f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}"
+    )
     if job_id is None:
         job_id = str(uuid.uuid4())
-        #job_id = str(_run._id)
+        # job_id = str(_run._id)
 
     asha_collection_name = _run.config.get("asha_collection_name", "unknown_experiment")
     print("Run info:", _run.experiment_info)
@@ -70,11 +70,12 @@ def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_
     model.to(device)
 
     # Prepare dataset and loaders
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-    ])
-    full_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    full_dataset = datasets.MNIST(
+        root="./data", train=True, download=True, transform=transform
+    )
     train_size = int(0.8 * len(full_dataset))
     val_size = len(full_dataset) - train_size
     train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
@@ -126,16 +127,19 @@ def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_
         metric = correct / total
         print(f"[Epoch {stage}] Validation Accuracy: {metric:.4f}")
 
-        if stage < (num_stages - 1):
+        if stage < (num_stages - 1):
             should_stop = tracker.store_stage_metric(stage, metric)
             if should_stop:
                 print("We should end this process here")
-                print(f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}")
+                print(
+                    f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}"
+                )
                 break
         else:
             print("job finished")
-            print(f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}")
-
+            print(
+                f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}"
+            )
 
     return {
         "asha_collection_name": asha_collection_name,
@@ -148,6 +152,5 @@ def main(num_stages, dataset, hidden_units, dropout, learning_rate, base_shared_
         "num_stages": num_stages,
         "dataset": dataset,
         "device": str(device),
-        "final_stage": len(tracker.metric_history)
+        "final_stage": len(tracker.metric_history),
     }
-
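
The early-stopping call in the loop above (tracker.store_stage_metric(stage, metric) returning a should-stop flag, with tracker.metric_history holding the per-stage metrics) follows the usual ASHA pattern: compare a trial against its peers at geometrically spaced rungs and stop it if it falls out of the top 1/eta fraction. Since seml.utils.ASHA is not part of stock SEML and its implementation is not shown in this diff, the following is only a minimal sketch; ToyASHATracker and everything inside it are assumptions, with just the config keys (eta, min_r, max_resource, progression) and the store_stage_metric contract taken from the code above.

# Hypothetical stand-in for the tracker used above; the real seml.utils.ASHA
# class is not shown in this diff, so every name below is an assumption.
class ToyASHATracker:
    def __init__(self, eta=3, min_r=1, max_resource=20, progression="increase"):
        self.eta = eta
        self.min_r = min_r
        self.max_resource = max_resource
        self.higher_is_better = progression == "increase"
        self.rung_metrics = {}  # rung resource -> metrics seen from all trials
        self.metric_history = []  # this trial's per-stage metrics

    def _rungs(self):
        # Geometrically spaced rungs min_r, min_r*eta, min_r*eta^2, ...
        # e.g. 1, 3, 9 for eta=3, min_r=1, max_resource=20.
        r = self.min_r
        while r <= self.max_resource:
            yield r
            r *= self.eta

    def store_stage_metric(self, stage, metric):
        """Record this stage's metric; return True if the trial should stop."""
        self.metric_history.append(metric)
        resource = stage + 1  # assumes 0-based stages, one resource unit each
        if resource not in set(self._rungs()):
            return False  # between rungs: keep training
        peers = self.rung_metrics.setdefault(resource, [])
        peers.append(metric)
        # Survive the rung only while in the top 1/eta fraction of peers.
        k = max(1, len(peers) // self.eta)
        ranked = sorted(peers, reverse=self.higher_is_better)
        cutoff = ranked[k - 1]
        return metric < cutoff if self.higher_is_better else metric > cutoff


tracker = ToyASHATracker(eta=3, min_r=1, max_resource=20)
print(tracker.store_stage_metric(0, 0.91))  # stage 0 -> rung 1; False: sole peer

Note that in the real setup the rung metrics would have to live in the shared MongoDB collection (asha_collection_name) so that concurrently running jobs can compare against one another; a plain in-process dict like the one above only illustrates the decision rule.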