Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions xtuner/v1/ray/dataflow/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutStat
return RolloutState.SKIPPED
elif RolloutState.FAILED in group_states:
return RolloutState.FAILED
elif RolloutState.EXPIRED in group_states:
return RolloutState.EXPIRED
elif RolloutState.ABORTED in group_states:
return RolloutState.ABORTED
elif all(state == RolloutState.COMPLETED for state in group_states):
Expand Down Expand Up @@ -391,6 +393,34 @@ def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: R
ray.internal.free(old_obs_refs, local_only=False)
replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids]
self._update_replay_meta_state(replay_meta, new_state)
if new_state == RolloutState.EXPIRED and self.tail_batch_trigger_size <= 0:
self._clear_multimodal_objectrefs(replay_meta)

def _clear_multimodal_objectrefs(self, replay_meta: ReplayMeta):
    """Null out ObjectRef-valued multimodal entries of the stored action.

    The action payload is fetched with ``ray.get``; every top-level value of
    its ``multimodal_train_info`` mapping that is an ``ObjectRef`` is replaced
    by ``None``. The modified payload is then stored again with ``ray.put``
    and ``replay_meta.action_ref`` is pointed at the new object. The removed
    refs, plus the previous action ref, are handed to ``free_object_refs``.

    Nothing happens when there is no action ref, when the payload has no
    (or an empty) ``multimodal_train_info``, or when none of its values is
    an ``ObjectRef``.

    Args:
        replay_meta: Replay metadata whose ``action_ref`` is rewritten in
            place.
    """
    previous_ref = replay_meta.action_ref
    if previous_ref is None:
        return

    payload = ray.get(previous_ref)
    mm_info = getattr(payload, "multimodal_train_info", None)
    if not mm_info:
        return

    # Snapshot the ObjectRef entries first so the mapping is only mutated
    # after iteration has finished.
    ref_entries = [(name, entry) for name, entry in mm_info.items() if isinstance(entry, ObjectRef)]
    if not ref_entries:
        return

    pending_free: List[ObjectRef] = []
    for name, entry in ref_entries:
        pending_free.append(entry)
        mm_info[name] = None

    # Re-publish the stripped payload before releasing the old object.
    replay_meta.action_ref = ray.put(payload)
    if isinstance(previous_ref, ObjectRef):
        pending_free.append(previous_ref)
    free_object_refs(pending_free)

def add(self, grouped_dataitem: List[RLDataFlowItem]):
"""Adds a group of data items to the storage.
Expand Down Expand Up @@ -848,6 +878,7 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta):
This is the single source of truth for deleting an action.
"""
action_id = replay_meta.action_id
root_id = replay_meta.root_id

self._release_replay_meta_refs(replay_meta)

Expand All @@ -859,6 +890,12 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta):

self._actions.pop(action_id, None)
self._action2observations.pop(action_id, None)
if root_id in self._root2actions:
self._root2actions[root_id] = [
stored_action_id for stored_action_id in self._root2actions[root_id] if stored_action_id != action_id
]
if not self._root2actions[root_id]:
del self._root2actions[root_id]
del replay_meta

def _clear_meta_for_root(self, replay_meta: ReplayMeta):
Expand Down
Loading