1010# ====================================
1111# ====== DataFlow 数据流 ==============
1212# ====================================
13+ from xtuner .v1 .utils .cache import CacheObj
1314from xtuner .v1 .utils .logger import get_logger
1415
1516
1617if TYPE_CHECKING :
17- import ray
18-
19- RayObjectRef = ray .ObjectRef
18+ from ray import ObjectRef as RayObjectRef
2019else :
2120 RayObjectRef : TypeAlias = Any
2221
@@ -60,12 +59,12 @@ class Status(Enum):
6059
6160class MultimodalInfo (TypedDict ):
6261 # 使用TypedDict给出pixel_values的类型提示
63- pixel_values : NotRequired [torch .Tensor | RayObjectRef | None ] # type: ignore[valid-type]
62+ pixel_values : NotRequired [torch .Tensor | RayObjectRef | None ]
6463 image_grid_thw : NotRequired [torch .Tensor ]
6564 position_ids : NotRequired [torch .Tensor ]
6665
6766
68- class RolloutState (BaseModel ):
67+ class RolloutState (CacheObj , BaseModel ):
6968 model_config = ConfigDict (extra = "forbid" , arbitrary_types_allowed = True )
7069
7170 # --- 数据 ---
@@ -88,22 +87,22 @@ class RolloutState(BaseModel):
8887 response : str | None = None
8988 response_ids : list [int ] | None = None
9089 logprobs : list [float ] | None = None
91- routed_experts : list [int ] | RayObjectRef | None = None # type: ignore[valid-type]
90+ routed_experts : list [int ] | RayObjectRef | None = None
9291 finish_reason : str | None = None
9392
94- @field_serializer (' routed_experts' )
93+ @field_serializer (" routed_experts" )
9594 def _serialize_routed_experts (self , value : list [int ] | RayObjectRef | None ) -> list [int ] | None :
9695 """Dump 时跳过 ray.ObjectRef,序列化为 None,避免 PydanticSerializationError。"""
9796 if value is None :
9897 return None
9998 try :
10099 import ray
100+
101101 if isinstance (value , ray .ObjectRef ):
102102 return None
103103 except ImportError :
104104 pass
105- if type (value ).__name__ == 'ObjectRef' and 'ray' in getattr (
106- type (value ), '__module__' , '' ):
105+ if type (value ).__name__ == "ObjectRef" and "ray" in getattr (type (value ), "__module__" , "" ):
107106 return None
108107 return value # list[int]
109108
0 commit comments