Skip to content

Commit c0574e4

Browse files
Merge pull request #28 from complex-reasoning/codex/pr23-safer-fixes
[codex] fix data indexing and clip metrics
2 parents d509152 + 6cbe5f8 commit c0574e4

2 files changed

Lines changed: 40 additions & 21 deletions

File tree

process-data.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
import numpy as np
33
import os # Import os for path manipulation
44

5+
6+
def build_extra_info(value: object, index: int) -> dict[str, object]:
7+
if isinstance(value, dict):
8+
extra_info = dict(value)
9+
else:
10+
extra_info = {}
11+
extra_info["index"] = index
12+
return extra_info
13+
514
# --- Configuration ---
615
# Define the directory containing the input file
716
data_directory = 'data'
@@ -30,9 +39,7 @@
3039
dummy_df.to_parquet(input_parquet_path)
3140

3241

33-
# Read the Parquet file into a pandas DataFrame
34-
# We don't need the original index, so we can reset it immediately if needed,
35-
# but setting df.index directly below overwrites it anyway.
42+
# Read the Parquet file into a pandas DataFrame.
3643
print(f"Reading Parquet file from: {input_parquet_path}")
3744
df = pd.read_parquet(input_parquet_path)
3845
print("Original DataFrame info:")
@@ -45,23 +52,31 @@
4552
num_rows = len(df)
4653
print(f"\nDataFrame has {num_rows} rows.")
4754

48-
# Create a new sequential index starting from 1 up to the number of rows
49-
# Name the new index 'extra_info' as requested
50-
print("Generating new sequential index named 'extra_info' from 1...")
51-
new_index = pd.RangeIndex(start=1, stop=num_rows + 1, step=1, name='extra_info')
55+
# RLHFDataset reads row_dict["extra_info"]["index"], so store the repeat
56+
# index inside the extra_info column rather than as a pandas index.
57+
print("Generating 0-based extra_info.index values...")
58+
if "extra_info" in df.columns:
59+
existing_extra_info = df["extra_info"].tolist()
60+
else:
61+
existing_extra_info = [None] * num_rows
5262

53-
# Set the new index for the DataFrame, replacing the old one
54-
df.index = new_index
55-
print("New index assigned.")
63+
df["extra_info"] = [
64+
build_extra_info(value=value, index=index)
65+
for index, value in enumerate(existing_extra_info)
66+
]
67+
df = df.reset_index(drop=True)
68+
print("extra_info.index assigned.")
5669

5770
# Write the modified DataFrame back to a new Parquet file
58-
# index=True ensures the new index ('extra_info') is written to the file
5971
print(f"Writing modified DataFrame to: {output_parquet_path}")
60-
df.to_parquet(output_parquet_path, index=True)
72+
df.to_parquet(output_parquet_path, index=False)
6173

6274
print("\n--- Success ---")
6375
print(f"Successfully processed '{input_parquet_path}'.")
64-
print(f"Created new index named 'extra_info' from 1 to {num_rows}.")
76+
if num_rows:
77+
print(f"Created 0-based extra_info.index values from 0 to {num_rows - 1}.")
78+
else:
79+
print("Created empty extra_info.index values.")
6580
print(f"Output saved to '{output_parquet_path}'.")
6681

6782
# Display the first few rows with the new index to verify

verl/trainer/ppo/core_algos.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,12 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
135135
id2score[index[i]].append(scores[i])
136136
for idx in id2score:
137137
if len(id2score[idx]) == 1:
138-
id2mean[idx] = torch.tensor(0.0)
139-
id2std[idx] = torch.tensor(1.0)
138+
id2mean[idx] = scores.new_tensor(0.0)
139+
id2std[idx] = scores.new_tensor(1.0)
140140
elif len(id2score[idx]) > 1:
141-
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
142-
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
141+
scores_tensor = torch.stack(id2score[idx])
142+
id2mean[idx] = scores_tensor.mean()
143+
id2std[idx] = scores_tensor.std()
143144
else:
144145
raise ValueError(f"no score in prompt index: {idx}")
145146
for i in range(bsz):
@@ -522,7 +523,7 @@ def compute_policy_loss_reinforce(old_log_prob,
522523
ppo_kl: (float)
523524
the estimated KL divergence between the latest updating policy and the old sampling policy
524525
pg_clipfrac_lower: (float)
525-
the fraction of policy gradient loss being clipped when the advantage is negative
526+
the fraction of policy gradient loss being clipped at the lower bound
526527
"""
527528

528529
negative_approx_kl = log_prob - old_log_prob
@@ -567,9 +568,12 @@ def compute_policy_loss_reinforce(old_log_prob,
567568
else:
568569
A = (advantages * w_ + kl_term).detach()
569570
pg_losses = -A * log_prob
570-
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses, pg_losses).float(), response_mask)
571-
pg_clipfrac_lower = verl_F.masked_mean(
572-
torch.gt(pg_losses, pg_losses) * (advantages < 0).float(), response_mask)
571+
# This branch uses hard-clipped importance weights in A, so report how
572+
# often w falls outside the clamp bounds.
573+
lower_clipped = w < (1 - clip_ratio_low)
574+
upper_clipped = w > (1 + clip_ratio_high)
575+
pg_clipfrac = verl_F.masked_mean((lower_clipped | upper_clipped).float(), response_mask)
576+
pg_clipfrac_lower = verl_F.masked_mean(lower_clipped.float(), response_mask)
573577

574578
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
575579

0 commit comments

Comments
 (0)