1 change: 1 addition & 0 deletions agentlightning/verl/config.yaml
@@ -19,3 +19,4 @@ actor_rollout_ref:
custom_async_server:
path: pkg://agentlightning.verl.async_server
name: PatchedvLLMServer
trace_agg_mode: transition # transition or trajectory
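
Reviewer note: the new key lives under actor_rollout_ref.rollout (the full path is taken from the trainer.py change below, not from this file alone). A minimal sketch, assuming OmegaConf (which VERL configs are built on), of reading the flag with a backward-compatible default:

    # Sketch only; the config path comes from trainer.py below.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("agentlightning/verl/config.yaml")
    mode = OmegaConf.select(cfg, "actor_rollout_ref.rollout.trace_agg_mode", default="transition")
    if mode not in ("transition", "trajectory"):
        raise ValueError(f"Unknown trace_agg_mode: {mode}")
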
131 changes: 102 additions & 29 deletions agentlightning/verl/daemon.py
Expand Up @@ -139,6 +139,7 @@ def __init__(
llm_proxy: LLMProxy | None = None,
store: LightningStore | None = None,
adapter: TraceTripletAdapter | None = None,
trace_agg_mode: Literal["transition", "trajectory"] = "transition",
):
self.mode = mode

@@ -179,6 +180,7 @@ def __init__(
self.pad_token_id = pad_token_id
self.tokenizer = tokenizer
self.reward_fillna_value = reward_fillna_value
self.trace_agg_mode = trace_agg_mode

# Internal State
self.backend_llm_server_addresses: List[str] = []
@@ -630,49 +632,119 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
reward_list: List[float] = []
data_id_list: List[str] = []
rollout_id_list: List[str] = []
turn_index_list: List[int] = []
turn_index_list: List[int] | List[List[int]] = []
is_drop_list: List[bool] = []
n_trunc_sample_because_of_response = 0

for rollout_id, sample_info in finished_id_to_sample_info.items():
for turn_index, trace in enumerate(sample_info["trace_list"]):
if self.trace_agg_mode == "transition":
for rollout_id, sample_info in finished_id_to_sample_info.items():
for turn_index, trace in enumerate(sample_info["trace_list"]):

reward_list.append(sample_info["reward"])
prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]
reward_list.append(sample_info["reward"])
prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]

# Mark samples with prompts exceeding max_prompt_length to be dropped later
if len(prompt_ids) > max_prompt_length:
prompt_ids = prompt_ids[:max_prompt_length]
is_drop_list.append(True)
else:
is_drop_list.append(False)
# Mark samples with prompts exceeding max_prompt_length to be dropped later
if len(prompt_ids) > max_prompt_length:
prompt_ids = prompt_ids[:max_prompt_length]
is_drop_list.append(True)
else:
is_drop_list.append(False)

# Truncate responses that exceed max_response_length
if len(response_ids) > max_response_length:
response_ids = response_ids[:max_response_length]
n_trunc_sample_because_of_response += 1
# Truncate responses that exceed max_response_length
if len(response_ids) > max_response_length:
response_ids = response_ids[:max_response_length]
n_trunc_sample_because_of_response += 1

# Pad prompts to the left and responses to the right
one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
prompt_ids, max_prompt_length, self.pad_token_id
)
one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
response_ids, max_response_length, self.pad_token_id
)
# Pad prompts to the left and responses to the right
one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
prompt_ids, max_prompt_length, self.pad_token_id
)
one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
response_ids, max_response_length, self.pad_token_id
)

input_ids_list.append(one_input_ids)
input_attention_mask_list.append(one_input_attention_mask)
response_ids_list.append(one_response_ids)
response_attention_mask_list.append(one_response_attention_mask)
data_id_list.append(sample_info["data_id"])
rollout_id_list.append(rollout_id)
turn_index_list.append(turn_index)
input_ids_list.append(one_input_ids)
input_attention_mask_list.append(one_input_attention_mask)
response_ids_list.append(one_response_ids)
response_attention_mask_list.append(one_response_attention_mask)
data_id_list.append(sample_info["data_id"])
rollout_id_list.append(rollout_id)
turn_index_list.append(turn_index)

elif self.trace_agg_mode == "trajectory":
response_mask_list: List[List[int]] = []

for rollout_id, sample_info in finished_id_to_sample_info.items():
merged_trace_idx: List[List[int]] = []
current_merged_trace_idx: List[int] = []
current_context: List[int] = []
for turn_index, trace in enumerate(sample_info["trace_list"]):
if (trace["prompt_ids"] + trace["response_ids"])[:len(current_context)] == current_context:
current_context = trace["prompt_ids"] + trace["response_ids"]
current_merged_trace_idx.append(turn_index)
else:
                        # The first turn always extends the empty context, so the group is never empty here.
merged_trace_idx.append(current_merged_trace_idx)
current_merged_trace_idx = [turn_index]
current_context = trace["prompt_ids"] + trace["response_ids"]
                if current_merged_trace_idx:  # flush the final group (skip if the rollout had no traces)
merged_trace_idx.append(current_merged_trace_idx)

for current_merged_trace_idx in merged_trace_idx:
prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"]
response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"]
prompt_length = len(prompt_ids)
response_mask = [1] * len(response_ids)
for turn_index in current_merged_trace_idx[1:]:
trace = sample_info["trace_list"][turn_index]
                        # only tokens this turn added beyond the merged context; front-slicing avoids the [-0:] pitfall
                        new_prompt_ids = trace["prompt_ids"][prompt_length + len(response_ids):]
                        response_ids += new_prompt_ids + trace["response_ids"]
                        response_mask += [0] * len(new_prompt_ids)
                        response_mask += [1] * len(trace["response_ids"])

reward_list.append(sample_info["reward"])

# Mark samples with prompts exceeding max_prompt_length to be dropped later
if len(prompt_ids) > max_prompt_length:
prompt_ids = prompt_ids[:max_prompt_length]
is_drop_list.append(True)
else:
is_drop_list.append(False)

# Truncate responses that exceed max_response_length
if len(response_ids) > max_response_length:
response_ids = response_ids[:max_response_length]
n_trunc_sample_because_of_response += 1

# Pad prompts to the left and responses to the right
one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
prompt_ids, max_prompt_length, self.pad_token_id
)
one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
response_ids, max_response_length, self.pad_token_id
)
one_response_mask, _ = get_right_padded_ids_and_attention_mask(
response_mask, max_response_length, 0
)

input_ids_list.append(one_input_ids)
input_attention_mask_list.append(one_input_attention_mask)
response_ids_list.append(one_response_ids)
response_attention_mask_list.append(one_response_attention_mask)
response_mask_list.append(one_response_mask)
data_id_list.append(sample_info["data_id"])
rollout_id_list.append(rollout_id)
turn_index_list.append(current_merged_trace_idx)
else:
raise ValueError(f"Unknown trace_agg_mode: {self.trace_agg_mode}")

n_transition = len(input_ids_list)
batch_input_ids = torch.LongTensor(input_ids_list).to(device)
input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
batch_response_ids = torch.LongTensor(response_ids_list).to(device)
response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)
response_mask = torch.LongTensor(response_mask_list).to(device) if self.trace_agg_mode == "trajectory" else None

# Concatenate prompts and responses to form the full sequence
batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
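
Reviewer note: a self-contained sketch of the trajectory branch above, on toy token ids (the helper name and data are made up for illustration). Turns whose token stream extends the running context are merged into one sample; prompt tokens that a later turn inserted (tool or user messages) get mask 0 so they do not contribute to the loss:

    from typing import Dict, List, Tuple

    def merge_trajectory(trace_list: List[Dict[str, List[int]]]) -> Tuple[List[int], List[int], List[int]]:
        # Mask convention: 1 = model-generated token, 0 = token inserted between turns.
        prompt_ids = trace_list[0]["prompt_ids"]
        response_ids = list(trace_list[0]["response_ids"])
        response_mask = [1] * len(response_ids)
        for trace in trace_list[1:]:
            # Only the tokens this turn added beyond the merged context so far.
            new_prompt_ids = trace["prompt_ids"][len(prompt_ids) + len(response_ids):]
            response_ids += new_prompt_ids + trace["response_ids"]
            response_mask += [0] * len(new_prompt_ids) + [1] * len(trace["response_ids"])
        return prompt_ids, response_ids, response_mask

    traces = [
        {"prompt_ids": [1, 2], "response_ids": [3, 4]},        # turn 0
        {"prompt_ids": [1, 2, 3, 4, 5], "response_ids": [6]},  # turn 1 extends turn 0; 5 is a tool token
    ]
    print(merge_trajectory(traces))  # ([1, 2], [3, 4, 5, 6], [1, 1, 0, 1])

This relies on every turn in a group extending the previous context exactly, which is what the prefix check at the top of the branch guarantees.
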
@@ -700,6 +772,7 @@
"position_ids": position_ids,
"is_drop_mask": is_drop_mask,
"token_level_scores": token_level_scores.contiguous(),
**({"response_mask": response_mask} if self.trace_agg_mode == "trajectory" else {}),
},
batch_size=n_transition,
)
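
Reviewer note: the two padding helpers are not part of this diff; the sketch below shows the behavior the code above assumes of them (an assumption, not the actual implementations). Note that the trajectory branch reuses the right-padding helper with pad value 0 on response_mask, so padded positions never count toward the loss:

    from typing import List, Tuple

    def get_left_padded_ids_and_attention_mask(ids: List[int], length: int, pad_id: int) -> Tuple[List[int], List[int]]:
        # Prompts pad on the left so real tokens end right before the response.
        n_pad = length - len(ids)
        return [pad_id] * n_pad + ids, [0] * n_pad + [1] * len(ids)

    def get_right_padded_ids_and_attention_mask(ids: List[int], length: int, pad_id: int) -> Tuple[List[int], List[int]]:
        # Responses pad on the right so generation starts at a fixed offset.
        n_pad = length - len(ids)
        return ids + [pad_id] * n_pad, [1] * len(ids) + [0] * n_pad
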
4 changes: 3 additions & 1 deletion agentlightning/verl/trainer.py
@@ -137,7 +137,8 @@ def _train_step(self, batch_dict: dict) -> dict:
        # uid is used by algorithms like GRPO and should be aligned with the data id
batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]

batch.batch["response_mask"] = compute_response_mask(batch)
if "response_mask" not in batch.batch:
batch.batch["response_mask"] = compute_response_mask(batch)

# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
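
Reviewer note: the "response_mask not in batch.batch" guard is what lets trajectory mode through; as I read it, compute_response_mask derives the mask from the attention mask alone, so every non-pad response token would count and the stricter mask shipped by the daemon would be clobbered. A toy illustration with made-up per-token losses:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0]])    # last position is padding
    trajectory_mask = torch.tensor([[1, 1, 0, 0]])   # position 2 was a tool/user token
    token_loss = torch.tensor([[0.5, 0.7, 9.9, 0.0]])

    print((token_loss * attention_mask).sum())   # tensor(11.1000) -- tool token trained on
    print((token_loss * trajectory_mask).sum())  # tensor(1.2000)  -- tool token excluded
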
@@ -310,6 +312,7 @@ def fit(self):
store=self.store,
llm_proxy=self.llm_proxy,
adapter=self.adapter,
trace_agg_mode=self.config.actor_rollout_ref.rollout.trace_agg_mode,
)
self.agent_mode_daemon.start()
