Skip to content

Commit b0d2cb2

Browse files
authored
[core] change GcsClient.get_all_node_info return type to proto type. (ray-project#46057)
Signed-off-by: Ruiyang Wang <rywang014@gmail.com>
1 parent ab85dd2 commit b0d2cb2

File tree

7 files changed

+32
-30
lines changed

7 files changed

+32
-30
lines changed

python/ray/_private/gcs_aio_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import Dict, List, Optional
33
from concurrent.futures import ThreadPoolExecutor
4-
from ray._raylet import GcsClient
4+
from ray._raylet import GcsClient, JobID
55
from ray.core.generated import (
66
gcs_pb2,
77
)
@@ -146,7 +146,7 @@ async def internal_kv_keys(
146146

147147
async def get_all_job_info(
148148
self, timeout: Optional[float] = None
149-
) -> Dict[bytes, gcs_pb2.JobTableData]:
149+
) -> Dict[JobID, gcs_pb2.JobTableData]:
150150
"""
151151
Return dict key: bytes of job_id; value: JobTableData pb message.
152152
"""

python/ray/_private/usage/usage_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def get_total_num_nodes_to_report(gcs_client, timeout=None) -> Optional[int]:
556556
result = gcs_client.get_all_node_info(timeout=timeout)
557557
total_num_nodes = 0
558558
for node_id, node_info in result.items():
559-
if node_info["state"] == gcs_pb2.GcsNodeInfo.GcsNodeState.ALIVE:
559+
if node_info.state == gcs_pb2.GcsNodeInfo.GcsNodeState.ALIVE:
560560
total_num_nodes += 1
561561
return total_num_nodes
562562
except Exception as e:

python/ray/_raylet.pyx

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ from ray.util.scheduling_strategies import (
194194
import ray._private.ray_constants as ray_constants
195195
import ray.cloudpickle as ray_pickle
196196
from ray.core.generated.common_pb2 import ActorDiedErrorContext
197-
from ray.core.generated.gcs_pb2 import JobTableData
197+
from ray.core.generated.gcs_pb2 import JobTableData, GcsNodeInfo
198198
from ray.core.generated.gcs_service_pb2 import GetAllResourceUsageReply
199199
from ray._private.async_compat import (
200200
sync_to_async,
@@ -2880,27 +2880,26 @@ cdef class GcsClient:
28802880
c_uri, expiration_s, timeout_ms))
28812881

28822882
@_auto_reconnect
2883-
def get_all_node_info(self, timeout=None):
2883+
def get_all_node_info(self, timeout=None) -> Dict[NodeID, GcsNodeInfo]:
28842884
cdef:
28852885
int64_t timeout_ms = round(1000 * timeout) if timeout else -1
2886-
CGcsNodeInfo node_info
2887-
c_vector[CGcsNodeInfo] node_infos
2886+
CGcsNodeInfo c_node_info
2887+
c_vector[CGcsNodeInfo] c_node_infos
2888+
c_vector[c_string] serialized_node_infos
28882889
with nogil:
2889-
check_status(self.inner.get().GetAllNodeInfo(timeout_ms, node_infos))
2890+
check_status(self.inner.get().GetAllNodeInfo(timeout_ms, c_node_infos))
2891+
for c_node_info in c_node_infos:
2892+
serialized_node_infos.push_back(c_node_info.SerializeAsString())
28902893

28912894
result = {}
2892-
for node_info in node_infos:
2893-
c_resources = PythonGetResourcesTotal(node_info)
2894-
result[node_info.node_id()] = {
2895-
"node_name": node_info.node_name(),
2896-
"state": node_info.state(),
2897-
"labels": PythonGetNodeLabels(node_info),
2898-
"resources": {key.decode(): value for key, value in c_resources}
2899-
}
2895+
for serialized in serialized_node_infos:
2896+
node_info = GcsNodeInfo()
2897+
node_info.ParseFromString(serialized)
2898+
result[NodeID.from_binary(node_info.node_id)] = node_info
29002899
return result
29012900

29022901
@_auto_reconnect
2903-
def get_all_job_info(self, timeout=None) -> Dict[bytes, JobTableData]:
2902+
def get_all_job_info(self, timeout=None) -> Dict[JobID, JobTableData]:
29042903
# Ideally we should use json_format.MessageToDict(job_info),
29052904
# but `job_info` is a cpp pb message not a python one.
29062905
# Manually converting each and every protobuf field is out of question,
@@ -2918,7 +2917,7 @@ cdef class GcsClient:
29182917
for serialized in serialized_job_infos:
29192918
job_info = JobTableData()
29202919
job_info.ParseFromString(serialized)
2921-
result[job_info.job_id] = job_info
2920+
result[JobID.from_binary(job_info.job_id)] = job_info
29222921
return result
29232922

29242923
@_auto_reconnect

python/ray/includes/common.pxd

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
506506

507507
cdef cppclass CJobConfig "ray::rpc::JobConfig":
508508
c_string ray_namespace() const
509-
const c_string &SerializeAsString()
509+
const c_string &SerializeAsString() const
510510

511511
cdef cppclass CNodeDeathInfo "ray::rpc::NodeDeathInfo":
512512
int reason() const
@@ -526,6 +526,7 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
526526
int runtime_env_agent_port() const
527527
CNodeDeathInfo death_info() const
528528
void ParseFromString(const c_string &serialized)
529+
const c_string& SerializeAsString() const
529530

530531
cdef enum CGcsNodeState "ray::rpc::GcsNodeInfo_GcsNodeState":
531532
ALIVE "ray::rpc::GcsNodeInfo_GcsNodeState_ALIVE",
@@ -534,7 +535,7 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
534535
c_string job_id() const
535536
c_bool is_dead() const
536537
CJobConfig config() const
537-
const c_string &SerializeAsString()
538+
const c_string &SerializeAsString() const
538539

539540
cdef cppclass CPythonFunction "ray::rpc::PythonFunction":
540541
void set_key(const c_string &key)
@@ -570,7 +571,7 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
570571
cdef cppclass CActorTableData "ray::rpc::ActorTableData":
571572
CAddress address() const
572573
void ParseFromString(const c_string &serialized)
573-
const c_string &SerializeAsString()
574+
const c_string &SerializeAsString() const
574575

575576
cdef extern from "ray/common/task/task_spec.h" nogil:
576577
cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":

python/ray/includes/unique_ids.pxi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ cdef class JobID(BaseID):
230230
check_id(id, CJobID.Size())
231231
self.data = CJobID.FromBinary(<c_string>id)
232232

233+
@classmethod
234+
def from_binary(cls, id_bytes):
235+
if not isinstance(id_bytes, bytes):
236+
raise TypeError("Expect bytes, got " + str(type(id_bytes)))
237+
return cls(id_bytes)
238+
233239
cdef CJobID native(self):
234240
return <CJobID>self.data
235241

python/ray/serve/_private/cluster_node_info_cache.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,21 @@ def update(self):
2626
"""
2727
nodes = self._gcs_client.get_all_node_info(timeout=RAY_GCS_RPC_TIMEOUT_S)
2828
alive_nodes = [
29-
(ray.NodeID.from_binary(node_id).hex(), node["node_name"].decode("utf-8"))
29+
(node_id.hex(), node.node_name)
3030
for (node_id, node) in nodes.items()
31-
if node["state"] == ray.core.generated.gcs_pb2.GcsNodeInfo.ALIVE
31+
if node.state == ray.core.generated.gcs_pb2.GcsNodeInfo.ALIVE
3232
]
3333

3434
# Sort on NodeID to ensure the ordering is deterministic across the cluster.
3535
sorted(alive_nodes)
3636
self._cached_alive_nodes = alive_nodes
3737
self._cached_node_labels = {
38-
ray.NodeID.from_binary(node_id).hex(): {
39-
label_name.decode("utf-8"): label_value.decode("utf-8")
40-
for label_name, label_value in node["labels"].items()
41-
}
42-
for (node_id, node) in nodes.items()
38+
node_id.hex(): dict(node.labels) for (node_id, node) in nodes.items()
4339
}
4440

4541
# Node resources
4642
self._cached_total_resources_per_node = {
47-
ray.NodeID.from_binary(node_id).hex(): node["resources"]
43+
node_id.hex(): dict(node.resources_total)
4844
for (node_id, node) in nodes.items()
4945
}
5046

python/ray/tests/test_state_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3584,7 +3584,7 @@ def f(signal):
35843584
ray.get(signal.wait.remote())
35853585

35863586
client = ray.worker.global_worker.gcs_client
3587-
job_id = ray.worker.global_worker.current_job_id.binary()
3587+
job_id = ray.worker.global_worker.current_job_id
35883588
all_job_info = client.get_all_job_info()
35893589
assert len(all_job_info) == 1
35903590
assert job_id in all_job_info

0 commit comments

Comments
 (0)