@@ -8,7 +8,12 @@
 # yapf: disable

 # __torch_operator_start__
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
 from ray.util.sgd.torch import TrainingOperator
+from ray.util.sgd.torch.examples.train_example import LinearDataset

 class MyTrainingOperator(TrainingOperator):
     def setup(self, config):
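The diff elides the body of `setup` between these two hunks; the imports added above (`torch`, `nn`, `DataLoader`, `LinearDataset`) exist to serve it. For orientation only, here is a plausible reconstruction of that elided body, modeled on `ray.util.sgd.torch.examples.train_example`; every value and object here is an assumption, not part of this diff:

    # Hypothetical sketch of the elided setup body; hyperparameters are illustrative.
    def setup(self, config):
        train_dataset = LinearDataset(2, 5)  # toy data: y = 2x + 5
        val_dataset = LinearDataset(2, 5)
        train_loader = DataLoader(train_dataset, batch_size=config["batch_size"])
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

        model = nn.Linear(1, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
        criterion = nn.MSELoss()
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)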
@@ -44,10 +49,29 @@ def setup(self, config):
         self.model, self.optimizer, self.criterion, self.scheduler = \
             self.register(models=model, optimizers=optimizer,
                           criterion=criterion,
-                          scheduler=scheduler)
+                          schedulers=scheduler)
         self.register_data(train_loader=train_loader, validation_loader=val_loader)
 # __torch_operator_end__

+# __torch_ray_start__
+import ray
+
+ray.init()
+# or ray.init(address="auto") to connect to a running cluster.
+# __torch_ray_end__
+
+# __torch_trainer_start__
+from ray.util.sgd import TorchTrainer
+
+trainer = TorchTrainer(
+    training_operator_cls=MyTrainingOperator,
+    scheduler_step_freq="epoch",  # if scheduler is used
+    config={"lr": 0.001, "batch_size": 64})
+
+# __torch_trainer_end__
+
+trainer.shutdown()
+
 # __torch_model_start__
 import torch.nn as nn

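The keyword rename in the hunk above is the substantive fix: assuming `TrainingOperator.register` follows the RaySGD signature (`models`, `optimizers`, `criterion=None`, `schedulers=None`, ...), the parameter is plural even for a single scheduler instance, which is presumably why the singular `scheduler=` spelling failed. A sketch of the corrected call under that assumption:

    # Sketch, assuming the RaySGD TrainingOperator.register signature;
    # a single instance is accepted under the plural keyword.
    self.model, self.optimizer, self.criterion, self.scheduler = \
        self.register(models=model, optimizers=optimizer,
                      criterion=criterion,
                      schedulers=scheduler)  # plural keyword; 'scheduler=' is rejected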
@@ -144,13 +168,6 @@ def scheduler_creator(optimizer, config):

 # __torch_scheduler_end__

-# __torch_ray_start__
-import ray
-
-ray.init()
-# or ray.init(address="auto") to connect to a running cluster.
-# __torch_ray_end__
-
 # __backwards_compat_start__
 from ray.util.sgd import TorchTrainer

@@ -167,15 +184,3 @@ def scheduler_creator(optimizer, config):
 # __backwards_compat_end__

 trainer.shutdown()
-
-# __torch_trainer_start__
-from ray.util.sgd import TorchTrainer
-
-trainer = TorchTrainer(
-    training_operator_cls=MyTrainingOperator,
-    scheduler_step_freq="epoch",  # if scheduler is used
-    config={"lr": 0.001, "batch_size": 64})
-
-# __torch_trainer_end__
-
-trainer.shutdown()
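Taken together, the relocated blocks form a runnable example once this diff is applied. A minimal end-to-end sketch under the RaySGD API this file targets; the `num_workers` argument and the `trainer.train()` / `trainer.validate()` loop are illustrative additions, not part of the diff:

    import ray
    from ray.util.sgd import TorchTrainer

    ray.init()  # or ray.init(address="auto") to connect to a running cluster

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=2,                # illustrative; the diff does not set this
        scheduler_step_freq="epoch",  # step the scheduler once per epoch
        config={"lr": 0.001, "batch_size": 64})

    for epoch in range(5):
        stats = trainer.train()       # one pass over the training DataLoader
        print(f"epoch {epoch}: {stats}")

    print(trainer.validate())         # metrics over the validation DataLoader
    trainer.shutdown()                # tear down the Ray worker group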