Skip to content

Commit 2ee5d49

Browse files
authored
Use ray backend by default when we're running in a Ray worker. (#1959)
1 parent 12d6c55 commit 2ee5d49

2 files changed

Lines changed: 30 additions & 5 deletions

File tree

lib/zephyr/src/zephyr/backend_factory.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Literal
2323

2424
import humanfriendly
25+
import ray
2526

2627
from zephyr.backends import Backend, BackendConfig, RayBackend, SyncBackend, ThreadPoolBackend
2728

@@ -145,14 +146,22 @@ def flow_backend(
145146
)
146147
if not has_params:
147148
if current is None:
148-
logger.warning("No backend configured in context, using ThreadPoolBackend as default.")
149-
return create_backend("threadpool")
149+
# Default to Ray backend if Ray is already initialized
150+
if ray.is_initialized():
151+
logger.info("Ray is initialized, using RayBackend as default.")
152+
return create_backend("ray")
153+
else:
154+
logger.warning("No backend configured in context, using ThreadPoolBackend as default.")
155+
return create_backend("threadpool")
150156
return current
151157

152158
# Parameters provided: create new backend with merged config
153159
if current is None:
154-
# No current backend, create default threadpool
155-
current = create_backend("threadpool")
160+
# No current backend, create default based on Ray availability
161+
if ray.is_initialized():
162+
current = create_backend("ray")
163+
else:
164+
current = create_backend("threadpool")
156165

157166
# Build override dict, only including non-None values
158167
overrides = {}

lib/zephyr/tests/test_backends.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
"""Tests for backend implementations."""
1616

17-
from zephyr.backends import format_shard_path
17+
import ray
18+
19+
from zephyr.backend_factory import flow_backend
20+
from zephyr.backends import RayBackend, format_shard_path
1821

1922

2023
def test_format_shard_path_basic():
@@ -109,3 +112,16 @@ def test_write_jsonl_no_compression_without_gz_extension(tmp_path):
109112
assert len(lines) == 2
110113
assert json.loads(lines[0]) == {"id": 1, "text": "hello"}
111114
assert json.loads(lines[1]) == {"id": 2, "text": "world"}
115+
116+
117+
def test_flow_backend_defaults_to_ray_when_initialized():
118+
"""Test that flow_backend returns RayBackend when Ray is initialized."""
119+
if ray.is_initialized():
120+
ray.shutdown()
121+
122+
try:
123+
ray.init(ignore_reinit_error=True)
124+
backend = flow_backend()
125+
assert isinstance(backend, RayBackend)
126+
finally:
127+
ray.shutdown()

0 commit comments

Comments
 (0)