Skip to content

Commit 0dd7d8c

Browse files
richardliawDmitriGekhtman
authored andcommitted
[tune] Support object refs in with_parameters
1 parent e9177be commit 0dd7d8c

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

python/ray/tune/registry.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,10 @@ def get(self, k):
168168

169169
def flush(self):
170170
for k, v in self.to_flush.items():
171-
self.references[k] = ray.put(v)
171+
if isinstance(v, ray.ObjectRef):
172+
self.references[k] = v
173+
else:
174+
self.references[k] = ray.put(v)
172175
self.to_flush.clear()
173176

174177

python/ray/tune/tests/test_api.py

+21
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,27 @@ def step(self):
12561256
dumped = cp.dumps(trainable)
12571257
assert sys.getsizeof(dumped) < 100 * 1024
12581258

1259+
def testWithParameters3(self):
1260+
class Data:
1261+
def __init__(self):
1262+
import numpy as np
1263+
self.data = np.random.rand((2 * 1024 * 1024))
1264+
1265+
class TestTrainable(Trainable):
1266+
def setup(self, config, data):
1267+
self.data = data.data
1268+
1269+
def step(self):
1270+
return dict(metric=len(self.data), done=True)
1271+
1272+
new_data = Data()
1273+
ref = ray.put(new_data)
1274+
trainable = tune.with_parameters(TestTrainable, data=ref)
1275+
# ray.cloudpickle will crash for some reason
1276+
import cloudpickle as cp
1277+
dumped = cp.dumps(trainable)
1278+
assert sys.getsizeof(dumped) < 100 * 1024
1279+
12591280

12601281
class SerializabilityTest(unittest.TestCase):
12611282
@classmethod

0 commit comments

Comments
 (0)