Skip to content

Commit 641f0fa

Browse files
authored
Support transfer ObjectRef from JAVA application to PYTHON deployment (ray-project#45729)
## Why are these changes needed? In the current implementation, it's not possible to transfer an ObjectRef from a Java application to a Python deployment. For more details, please refer to this issue: ray-project#45676 **Root Cause** In the original approach, all arguments were bundled together and passed as a list parameter, like so: {1, 2, "parameter3"}. If one of these parameters was an ObjectRef, it would lead to serialization issues, preventing cross-language transfer: {1, 2, "parameter3", objectRef4}. **Proposed Solution** In the new approach, each argument is passed independently. The ObjectRef is implemented as PassByReference, which avoids the previous serialization issues and enables the transfer of ObjectRef: {1, 2, "parameter3", objectRef4}. This change enhances the interoperability between Java applications and Python deployments in Ray.
1 parent 82ad4a8 commit 641f0fa

File tree

4 files changed

+56
-13
lines changed

4 files changed

+56
-13
lines changed

java/serve/src/main/java/io/ray/serve/router/ReplicaSet.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,17 @@
66
import io.ray.api.ObjectRef;
77
import io.ray.api.PyActorHandle;
88
import io.ray.api.Ray;
9+
import io.ray.api.call.PyActorTaskCaller;
910
import io.ray.api.function.PyActorMethod;
1011
import io.ray.serve.common.Constants;
1112
import io.ray.serve.exception.RayServeException;
1213
import io.ray.serve.generated.ActorNameList;
1314
import io.ray.serve.replica.RayServeWrappedReplica;
1415
import io.ray.serve.util.CollectionUtil;
15-
import java.util.ArrayList;
16-
import java.util.HashSet;
17-
import java.util.List;
18-
import java.util.Map;
19-
import java.util.Optional;
20-
import java.util.Set;
16+
import java.util.*;
2117
import java.util.concurrent.ConcurrentHashMap;
2218
import java.util.concurrent.TimeUnit;
19+
import java.util.stream.Stream;
2320
import org.apache.commons.lang3.RandomUtils;
2421
import org.slf4j.Logger;
2522
import org.slf4j.LoggerFactory;
@@ -108,12 +105,15 @@ private ObjectRef<Object> tryAssignReplica(Query query) {
108105
handles.get(randomIndex); // TODO controll concurrency using maxConcurrentQueries
109106
LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica);
110107
if (replica instanceof PyActorHandle) {
111-
return ((PyActorHandle) replica)
112-
.task(
113-
PyActorMethod.of("handle_request_from_java"),
114-
query.getMetadata().toByteArray(),
115-
query.getArgs())
116-
.remote();
108+
Object[] args =
109+
Stream.concat(
110+
Stream.of(query.getMetadata().toByteArray()),
111+
Arrays.stream((Object[]) query.getArgs()))
112+
.toArray();
113+
PyActorTaskCaller<Object> pyCaller =
114+
new PyActorTaskCaller<>(
115+
(PyActorHandle) replica, PyActorMethod.of("handle_request_from_java"), args);
116+
return pyCaller.remote();
117117
} else {
118118
return ((ActorHandle<RayServeWrappedReplica>) replica)
119119
.task(

java/serve/src/test/java/io/ray/serve/deployment/CrossLanguageDeploymentTest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.ray.serve.deployment;
22

3+
import io.ray.api.ObjectRef;
4+
import io.ray.api.Ray;
35
import io.ray.serve.BaseServeTest;
46
import io.ray.serve.api.Serve;
57
import io.ray.serve.generated.DeploymentLanguage;
@@ -58,6 +60,21 @@ public void createPyClassTest() {
5860
Assert.assertEquals(handle.method("increase").remote("6").result(), "34");
5961
}
6062

63+
@Test
64+
public void createPyClassWithObjectRefTest() {
65+
Application deployment =
66+
Serve.deployment()
67+
.setLanguage(DeploymentLanguage.PYTHON)
68+
.setName("createPyClassWithObjectRefTest")
69+
.setDeploymentDef(PYTHON_MODULE + ".Counter")
70+
.setNumReplicas(1)
71+
.bind("28");
72+
73+
DeploymentHandle handle = Serve.run(deployment).get();
74+
ObjectRef<Integer> numRef = Ray.put(10);
75+
Assert.assertEquals(handle.method("increase").remote(numRef).result(), "38");
76+
}
77+
6178
@Test
6279
public void createPyMethodTest() {
6380
Application deployment =
@@ -71,6 +88,20 @@ public void createPyMethodTest() {
7188
Assert.assertEquals(handle.method("__call__").remote("6").result(), "6");
7289
}
7390

91+
@Test
92+
public void createPyMethodWithObjectRefTest() {
93+
Application deployment =
94+
Serve.deployment()
95+
.setLanguage(DeploymentLanguage.PYTHON)
96+
.setName("createPyMethodWithObjectRefTest")
97+
.setDeploymentDef(PYTHON_MODULE + ".echo_server")
98+
.setNumReplicas(1)
99+
.bind();
100+
DeploymentHandle handle = Serve.run(deployment).get();
101+
ObjectRef<String> numRef = Ray.put("10");
102+
Assert.assertEquals(handle.method("__call__").remote(numRef).result(), "10");
103+
}
104+
74105
@Test
75106
public void userConfigTest() throws InterruptedException {
76107
Application deployment =

java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ public void testCallingPythonFunction() {
139139
Assert.assertEquals(((Number) r7array[0]).intValue(), input7array[0]);
140140
Assert.assertEquals(((Number) r7array[1]).intValue(), input7array[1]);
141141
}
142+
// objectRef
143+
{
144+
ObjectRef<Integer> input = Ray.put(1);
145+
ObjectRef<Integer> res =
146+
Ray.task(PyFunction.of(PYTHON_MODULE, "py_return_input", Integer.class), input).remote();
147+
Assert.assertEquals(res.get(), input.get());
148+
}
142149
// Unsupported types, all Java specific types, e.g. List / Map...
143150
{
144151
Assert.expectThrows(
@@ -173,6 +180,11 @@ public void testCallingPythonActor() {
173180
ObjectRef<byte[]> res =
174181
actor.task(PyActorMethod.of("increase", byte[].class), "1".getBytes()).remote();
175182
Assert.assertEquals(res.get(), "2".getBytes());
183+
184+
ObjectRef<String> numRef = Ray.put("2");
185+
ObjectRef<byte[]> res2 =
186+
actor.task(PyActorMethod.of("increase", byte[].class), numRef).remote();
187+
Assert.assertEquals(res2.get(), "4".getBytes());
176188
}
177189

178190
@Test

python/ray/serve/_private/replica.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ async def handle_request_from_java(
560560
)
561561
with self._wrap_user_method_call(request_metadata):
562562
return await self._user_callable_wrapper.call_user_method(
563-
request_metadata, request_args[0], request_kwargs
563+
request_metadata, request_args, request_kwargs
564564
)
565565

566566
async def is_allocated(self) -> str:

0 commit comments

Comments
 (0)