Skip to content

Commit 97b85ef

Browse files
committed
add deepspeed support, fine-tuned results
1 parent 86e4d09 commit 97b85ef

File tree

3 files changed

+157
-59
lines changed

3 files changed

+157
-59
lines changed

diffu_grpo/diffu_grpo_train.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from accelerate import PartialState
99
from data_utils import (
1010
get_apps_questions,
11+
get_countdown_questions,
1112
get_dapo17_data,
1213
get_gsm8k_questions,
1314
get_math500_questions,
15+
get_sudoku_questions,
1416
)
1517
from diffu_grpo_config import DiffuGRPOConfig
1618

@@ -23,7 +25,9 @@
2325
boxed_and_answer_tags_format_reward,
2426
codefence_reward_func,
2527
correctness_reward_func_math,
28+
countdown_reward_func,
2629
soft_format_reward_func,
30+
sudoku_reward_func,
2731
)
2832
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
2933
from transformers.trainer_callback import TrainerCallback
@@ -32,7 +36,6 @@
3236

3337
from model.llada.configuration_llada import LLaDAConfig
3438
from model.llada.lladou import LLaDOUModelLM
35-
from model.path_utils import lladou_config_dir
3639
from utils import set_random_seed
3740

3841
logging.set_verbosity_info()
@@ -101,6 +104,11 @@ def main(grpo_config, model_config):
101104
"reward_funcs": [apps_reward_func, codefence_reward_func],
102105
"reward_weights": [3.0, 1.0],
103106
},
107+
"countdown": {
108+
"loader": lambda: get_countdown_questions("train", prompt_mode=prompt_mode),
109+
"reward_funcs": [countdown_reward_func],
110+
"reward_weights": [1.0],
111+
},
104112
"dapo17": {
105113
"loader": lambda: get_dapo17_data("train", prompt_mode=prompt_mode),
106114
"reward_funcs": [
@@ -109,10 +117,15 @@ def main(grpo_config, model_config):
109117
],
110118
"reward_weights": [1.0, 1.0],
111119
},
120+
"sudoku": {
121+
"loader": get_sudoku_questions,
122+
"reward_funcs": [sudoku_reward_func],
123+
"reward_weights": [1.0],
124+
},
112125
}
113126

114127
if thinking_mode:
115-
thinking_datasets = ("gsm8k", "math500", "dapo17", "apps")
128+
thinking_datasets = ("gsm8k", "math500", "dapo17", "countdown", "apps")
116129
for key in thinking_datasets:
117130
if key in dataset_registry:
118131
dataset_registry[key]["reward_funcs"] = [
@@ -148,7 +161,13 @@ def main(grpo_config, model_config):
148161
# Shuffle dataset with fixed seed for reproducibility
149162
dataset = dataset.shuffle(seed=grpo_config.seed)
150163

151-
train_set = dataset
164+
# Split dataset if needed
165+
if grpo_config.dataset in ["countdown", "sudoku"]:
166+
train_set = dataset.select(
167+
range(0, len(dataset) - 500)
168+
) # Leave last 500 for evaluation
169+
else:
170+
train_set = dataset
152171

153172
# 4 bit quantization configuration
154173
if model_config.load_in_4bit:
@@ -170,6 +189,8 @@ def main(grpo_config, model_config):
170189
state.wait_for_everyone()
171190
local_dir = snapshot_download(grpo_config.model_path)
172191

192+
repo_root = Path(__file__).resolve().parent.parent
193+
173194
if grpo_config.use_official_model:
174195
model = AutoModel.from_pretrained(
175196
local_dir,
@@ -191,7 +212,9 @@ def main(grpo_config, model_config):
191212
# quantization_config=bnb_config,
192213
# )
193214
# elif "lladou" in grpo_config.model_path.lower():
194-
lladou_config = LLaDAConfig.from_pretrained(lladou_config_dir())
215+
lladou_config = LLaDAConfig.from_pretrained(
216+
repo_root / "model/llada/lladou_config"
217+
)
195218
assert lladou_config.flash_attention
196219
model = LLaDOUModelLM.from_pretrained(
197220
local_dir,
@@ -239,6 +262,13 @@ def main(grpo_config, model_config):
239262
processing_class=tokenizer,
240263
)
241264

265+
# Propagate teacher device preference (if any) so accel_reward can honor it
266+
# when loading the verifier/teacher model.
267+
if grpo_config.teacher_device is not None:
268+
os.environ["DIFFUGRPO_TEACHER_DEVICE"] = grpo_config.teacher_device
269+
else:
270+
os.environ.pop("DIFFUGRPO_TEACHER_DEVICE", None)
271+
242272
local_log_path = grpo_config.local_log_path or os.path.join(
243273
grpo_config.output_dir, "local_training_logs.jsonl"
244274
)

diffu_grpo/diffu_grpo_trainer.py

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(self, **kwargs):
5252
super().__init__(**kwargs)
5353
self.beta = beta
5454

55-
self.model_wrapped = self.model_wrapped.to(torch.bfloat16)
55+
# self.model_wrapped = self.model_wrapped.to(torch.bfloat16)
5656

5757
def _generate_and_score_completions(
5858
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -153,7 +153,7 @@ def _generate_and_score_completions(
153153
prompt_ids,
154154
prompt_mask,
155155
)
156-
156+
# breakpoint()
157157
# Rollout
158158
with (
159159
profiling_context(self, "transformers.generate"),
@@ -189,10 +189,9 @@ def _generate_and_score_completions(
189189
use_scheduler=self.args.use_scheduler,
190190
)
191191
logger.info("Rollout completed")
192-
if (
193-
self.args.torch_empty_cache_steps is not None
194-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
195-
):
192+
193+
# let deepspeed manage cuda cache
194+
if self.accelerator.distributed_type != DistributedType.DEEPSPEED:
196195
torch.cuda.empty_cache()
197196

198197
# Compute prompt length and extract completion ids
@@ -243,22 +242,12 @@ def _generate_and_score_completions(
243242
example["student_logprob"] = logprob
244243

245244
with torch.no_grad():
246-
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
247-
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
248-
# **samples** may come from an earlier version of the model. In that case, we need to track old_per_token_logps
249-
# for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
250-
# old_per_token_logps to None.
251-
# This will only run when self._step % generate_every == 0 or self._buffered_inputs is None
252-
# generate_every = (
253-
# self.args.steps_per_generation * self.num_iterations
254-
# ) # generation frequency
245+
# In the diffusion setting we already have per-token log-probs for the rollout trajectory (`sequence_logp`)
246+
# computed in `generate`. We always reuse them as `old_per_token_logps` so that we can explicitly
247+
# measure and correct any on/off-policy mismatch during replay.
255248
old_per_token_logps = sequence_logp.clone().detach()
256-
# if self.args.gradient_accumulation_steps % generate_every != 0:
257-
# old_per_token_logps = sequence_logp.clone().detach()
258-
# else:
259-
# old_per_token_logps = None
260249

261-
# Compute the per-token log probabilities for the reference model
250+
# Compute the per-token log probabilities for the reference model when KL regularization is enabled.
262251
if self.beta != 0.0:
263252
ref_per_token_logps = old_per_token_logps.clone().detach()
264253
else:
@@ -452,16 +441,10 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
452441
inputs["completion_mask"],
453442
)
454443
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
455-
# Replay must mirror the rollout mask: during generation only the prompt tokens
456-
# were marked as valid, so keep zeros on the completion portion.
457-
prompt_only_mask = torch.ones_like(prompt_ids, dtype=prompt_mask.dtype)
458-
attention_mask = torch.cat(
459-
[
460-
prompt_only_mask,
461-
torch.zeros_like(completion_ids, dtype=prompt_mask.dtype),
462-
],
463-
dim=1,
464-
)
444+
# Replay must mirror the rollout mask used during generation:
445+
# prompt padding is preserved, completion tokens are treated as valid (non-padding) tokens.
446+
attention_mask = torch.ones_like(input_ids, dtype=prompt_mask.dtype)
447+
attention_mask[:, :prompt_len] = prompt_mask
465448
sampling_traj = inputs["sampling_traj"]
466449
x0_hist = inputs["x0_hist"]
467450
all_advantages = inputs["advantages"]
@@ -483,17 +466,13 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
483466

484467
all_traj_len = self.accelerator.gather(
485468
torch.tensor(traj_len, device=input_ids_batch.device)
486-
)
469+
)
487470
max_traj_len = all_traj_len.max().item()
488471

489472
mask_id = self.args.mask_id
490473
cur_input = input_ids_batch.clone()
491474
cur_input[:, prompt_len:] = mask_id
492-
if (
493-
self.args.torch_empty_cache_steps is not None
494-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
495-
):
496-
torch.cuda.empty_cache()
475+
# torch.cuda.empty_cache()
497476
for step in tqdm(range(max_traj_len), desc="Computing per-token logps"):
498477
# logger.info(f"Step {step} of {traj_len}")
499478
# running the model in batches per step
@@ -546,14 +525,19 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
546525
cur_logp = torch.zeros_like(
547526
unmasking_prob[batch], dtype=torch.float32
548527
).unsqueeze(0)
528+
EPS = 1e-6
529+
clamped_prob = torch.clamp(unmasking_prob[batch], min=EPS, max=1.0 - EPS)
549530
if len(cur_traj[batch][step]) > 0:
531+
# Use log1p for log(1-p) when p is small
550532
cur_logp[:, keep_mask_index_mask] = torch.log1p(
551-
-unmasking_prob[batch, keep_mask_index_mask]
533+
-clamped_prob[keep_mask_index_mask]
552534
)
535+
# Use log for log(p), now safe due to clamping
553536
cur_logp[:, unmasking_index_mask] = (
554-
torch.log(unmasking_prob[batch, unmasking_index_mask])
537+
torch.log(clamped_prob[unmasking_index_mask])
555538
+ x0_logp[batch, unmasking_index_mask]
556539
)
540+
557541
if (
558542
torch.isnan(cur_logp).sum() > 0
559543
or not torch.isfinite(cur_logp).all()
@@ -620,17 +604,15 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
620604
# Two-sided clipping
621605
if self.args.delta is not None:
622606
coef_1 = torch.clamp(coef_1, max=self.args.delta)
623-
advantages = torch.where(
624-
advantages < self.args.advantage_min_clip,
625-
torch.zeros_like(
626-
advantages
627-
), # ignores advantages below a threshold
628-
advantages,
607+
advantages = torch.clamp(
608+
advantages, min=self.args.advantage_min_clip
629609
)
630610

631611
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
632612
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
633613
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
614+
# if entropy_mask is not None:
615+
# per_token_loss = per_token_loss * entropy_mask
634616

635617
if self.beta != 0.0:
636618
per_token_loss = per_token_loss + self.beta * per_token_kl
@@ -640,21 +622,28 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
640622
/ per_token_loss.size(0)
641623
/ self.max_completion_length
642624
)
643-
loss = loss / self.current_gradient_accumulation_steps
625+
644626
if loss.grad_fn is None:
645627
# this means that no token is unmasked, this can happen because generated completion rollout is splitted into smaller batches
646628
# raise ValueError("No gradient found")
647629
loss = logits.exp().sum() * 0.0 + unmasking_prob.sum() * 0.0
648630
loss_list.append(loss.item())
649631
# print(f"Loss: {loss}")
650-
# Backward pass
651632
if bad_flag or loss.isnan():
652633
accel_break(bad_process_index)
653-
# logger.info(f"[Rank {self.accelerator.process_index}]Loss: {loss}")
654-
self.backward(loss, num_items_in_batch)
634+
635+
# Backward pass: accumulate gradients over diffusion steps but only let DeepSpeed
636+
# take an optimizer step on the final (chunk, step) pair.
637+
force_deepspeed_step = False
638+
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
639+
is_last_chunk = start + batch_size == input_ids.size(0)
640+
is_last_step = step == max_traj_len - 1
641+
force_deepspeed_step = is_last_chunk and is_last_step
642+
self.backward(loss, num_items_in_batch, force_deepspeed_step=force_deepspeed_step)
655643
return_loss += loss.detach()
656644

657645
del cur_input
646+
# torch.cuda.empty_cache() # to reduce memory usage but will make things super slow
658647
cur_input = next_input
659648

660649
# Log the metrics
@@ -758,11 +747,31 @@ def compute_loss(
758747
else:
759748
return self._compute_loss(model, inputs, num_items_in_batch)
760749

761-
def backward(self, loss: torch.Tensor, num_items_in_batch):
750+
def backward(self, loss: torch.Tensor, num_items_in_batch, force_deepspeed_step=False):
751+
if (force_deepspeed_step and self.accelerator.distributed_type != DistributedType.DEEPSPEED):
752+
raise ValueError("force_deepspeed_step should only be true during DeepSpeed runs")
753+
762754
kwargs = {}
755+
756+
757+
# since we don't want deepspeed to step the optimizer
758+
# every single time we unmask a chunk, we force the gradient sync
759+
# flag to be false here until we're ready to step (after the full rollout)
760+
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
761+
orig_sync = getattr(self.accelerator, "sync_gradients", True)
762+
self.accelerator.sync_gradients = force_deepspeed_step
763+
self.accelerator.backward(loss, **kwargs)
764+
self.accelerator.sync_gradients = orig_sync
765+
return # exit early so that we don't clear any live gradients from the cache
766+
767+
if (
768+
self.args.torch_empty_cache_steps is not None
769+
and self.state.global_step % self.args.torch_empty_cache_steps == 0
770+
):
771+
torch.cuda.empty_cache()
763772

764773
if self.args.n_gpu > 1:
765-
loss = loss.mean() # mean() to average on multi-gpu parallel training
774+
loss = loss.mean() # mean() to average on multi-gpu parallel training (non-deepspeed)
766775

767776
# Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
768777
if (
@@ -771,9 +780,4 @@ def backward(self, loss: torch.Tensor, num_items_in_batch):
771780
# If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
772781
loss = loss / self.current_gradient_accumulation_steps
773782

774-
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
775-
# https://github.com/huggingface/transformers/pull/35808
776-
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
777-
kwargs["scale_wrt_gas"] = False
778-
779783
self.accelerator.backward(loss, **kwargs)

diffu_grpo/run_grpo_gsm_wei.sbatch

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#!/bin/bash
#SBATCH --partition=h100
#SBATCH --gres=gpu:8
#SBATCH --time=96:00:00
#SBATCH --job-name=gsm8k
#SBATCH --output=logs/gsm8k_%j.out
#SBATCH --error=logs/gsm8k_%j.err

# Fail fast: abort on any command error, on use of an undefined variable,
# and on failures inside pipelines. `set -u` in particular would have caught
# the previously-undefined WARMUP_* variables below.
set -euo pipefail

# Set environment variables
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
export TOKENIZERS_PARALLELISM=false
export PYTHONUNBUFFERED=1

# Create logs directory if it doesn't exist
mkdir -p logs

# The running script is in a different setup as in the paper and requires different hyperparameters.
TEMPERATURE=1.0
BLOCK_LENGTH=32
MAX_COMPLETION_LENGTH=256
NUM_GENERATIONS=96
PER_DEVICE_TRAIN_BATCH_SIZE=12
LEARNING_RATE=1e-6
ADVANTAGE_MIN_CLIP=0.0
BETA=0.0
EFFICIENCY_REWARD_WEIGHT=0.2
# BUGFIX: these two were referenced in the launch command but never defined,
# so they expanded to empty strings and broke argument parsing.
# TODO(review): confirm these warmup values against the paper/train.yaml.
WARMUP_STEPS=0
WARMUP_LEARNING_RATE=${LEARNING_RATE}

RUN_NAME="gsm8k"
MODEL_DIR="YOUR_PATH"

# Guard against launching with the unedited placeholder path.
if [ "${MODEL_DIR}" = "YOUR_PATH" ]; then
    echo "ERROR: set MODEL_DIR to your model checkpoint path before submitting." >&2
    exit 1
fi

OUTPUT_DIR="./checkpoints/${RUN_NAME}"

accelerate launch \
    --config_file accelerate.yaml \
    --num_processes 8 \
    --main_process_port 12446 \
    diffu_grpo_train.py \
    --config sbatch_scripts/train.yaml \
    --model_path "${MODEL_DIR}" \
    --dataset "gsm8k" \
    --run_name "${RUN_NAME}" \
    --output_dir "${OUTPUT_DIR}" \
    --temperature ${TEMPERATURE} \
    --num_iterations 1 \
    --max_steps 10000 \
    --gen_step_efficiency_reward_weight ${EFFICIENCY_REWARD_WEIGHT} \
    --beta ${BETA} \
    --normalize true \
    --scale 30.0 \
    --use_scheduler false \
    --max_prompt_length 256 \
    --max_completion_length ${MAX_COMPLETION_LENGTH} \
    --block_length ${BLOCK_LENGTH} \
    --learning_rate ${LEARNING_RATE} \
    --warmup_steps ${WARMUP_STEPS} \
    --warmup_learning_rate ${WARMUP_LEARNING_RATE} \
    --advantage_min_clip ${ADVANTAGE_MIN_CLIP} \
    --freeze_unmasking_head false \
    --num_generations ${NUM_GENERATIONS} \
    --per_device_train_batch_size ${PER_DEVICE_TRAIN_BATCH_SIZE} \
    --gradient_accumulation_steps 1 \
    --rollout_mode training \
    --scale_reward none \
    --prompt_mode "non-thinking"

0 commit comments

Comments
 (0)