Commit e24594b

amogkam authored and barakmich committed
fix example (#10964)
1 parent 86dd29e commit e24594b

File tree

1 file changed: +25 −20 lines changed


python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py (+25 −20)
@@ -8,7 +8,12 @@
 # yapf: disable
 
 # __torch_operator_start__
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
 from ray.util.sgd.torch import TrainingOperator
+from ray.util.sgd.torch.examples.train_example import LinearDataset
 
 class MyTrainingOperator(TrainingOperator):
     def setup(self, config):
@@ -44,10 +49,29 @@ def setup(self, config):
         self.model, self.optimizer, self.criterion, self.scheduler = \
             self.register(models=model, optimizers=optimizer,
                           criterion=criterion,
-                          scheduler=scheduler)
+                          schedulers=scheduler)
         self.register_data(train_loader=train_loader, validation_loader=val_loader)
 # __torch_operator_end__
 
+# __torch_ray_start__
+import ray
+
+ray.init()
+# or ray.init(address="auto") to connect to a running cluster.
+# __torch_ray_end__
+
+# __torch_trainer_start__
+from ray.util.sgd import TorchTrainer
+
+trainer = TorchTrainer(
+    training_operator_cls=MyTrainingOperator,
+    scheduler_step_freq="epoch",  # if scheduler is used
+    config={"lr": 0.001, "batch_size": 64})
+
+# __torch_trainer_end__
+
+trainer.shutdown()
+
 # __torch_model_start__
 import torch.nn as nn
 
@@ -144,13 +168,6 @@ def scheduler_creator(optimizer, config):
 
 # __torch_scheduler_end__
 
-# __torch_ray_start__
-import ray
-
-ray.init()
-# or ray.init(address="auto") to connect to a running cluster.
-# __torch_ray_end__
-
 # __backwards_compat_start__
 from ray.util.sgd import TorchTrainer
 
@@ -167,15 +184,3 @@ def scheduler_creator(optimizer, config):
 # __backwards_compat_end__
 
 trainer.shutdown()
-
-# __torch_trainer_start__
-from ray.util.sgd import TorchTrainer
-
-trainer = TorchTrainer(
-    training_operator_cls=MyTrainingOperator,
-    scheduler_step_freq="epoch",  # if scheduler is used
-    config={"lr": 0.001, "batch_size": 64})
-
-# __torch_trainer_end__
-
-trainer.shutdown()
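
For reference, the corrected example assembled from the hunks above reads end-to-end roughly as follows. This is a minimal sketch, not the committed file: the model, optimizer, criterion, scheduler, and dataset arguments inside setup() are elided between the hunks and are filled in here as illustrative assumptions based on the surrounding RaySGD examples, and the trainer.train() call is assumed from the TorchTrainer API rather than shown in this diff.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.examples.train_example import LinearDataset


class MyTrainingOperator(TrainingOperator):
    def setup(self, config):
        # Assumed setup (not part of this diff): a 1-in/1-out linear model
        # trained on the LinearDataset used by the other RaySGD examples.
        model = nn.Linear(1, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))
        criterion = nn.MSELoss()
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)
        train_loader = DataLoader(
            LinearDataset(2, 5), batch_size=config.get("batch_size", 32))
        val_loader = DataLoader(
            LinearDataset(2, 5), batch_size=config.get("batch_size", 32))

        # The keyword fixed by this commit: register() takes `schedulers=`,
        # not `scheduler=`.
        self.model, self.optimizer, self.criterion, self.scheduler = \
            self.register(models=model, optimizers=optimizer,
                          criterion=criterion,
                          schedulers=scheduler)
        self.register_data(train_loader=train_loader, validation_loader=val_loader)


if __name__ == "__main__":
    import ray
    from ray.util.sgd import TorchTrainer

    ray.init()  # or ray.init(address="auto") to connect to a running cluster

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        scheduler_step_freq="epoch",  # needed if a scheduler is used
        config={"lr": 0.001, "batch_size": 64})

    # train() runs one pass over the training data and returns a stats dict;
    # it is assumed from the TorchTrainer API and does not appear in this diff.
    print(trainer.train())
    trainer.shutdown()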
