Commit 2bc3cb7

[Refactor] Refactor GRPO as a separate class
ghstack-source-id: eecc32c
Pull-Request: #3205
1 parent 7b85c71 commit 2bc3cb7
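
The diff below moves the GRPO objective into a standalone `GRPOLoss` class and adds a unit test that exercises it against a mocked transformer. For orientation, here is a minimal sketch of the usage pattern the new test follows; it assumes the code runs inside test/llm/test_objectives.py, where the `mock_transformer_model` fixture (which the diff notes is moved to conftest.py) and the `_mock_data_grpo` helper shown below are in scope:

from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm.grpo import GRPOLoss, GRPOLossOutput

# Mocked causal-LM stand-in (pytest fixture assumed to come from conftest.py).
model = mock_transformer_model(vocab_size=1024, device="cpu")

# Wrap the model for log-prob evaluation only; the loss does not generate.
actor_network = TransformersWrapper(
    model,
    generate=False,
    pad_output=True,
    input_mode="history",
)

# GRPO is now its own loss class.
loss_fn = GRPOLoss(actor_network)

# TensorDict carrying "history", "tokens", "masks", "advantage", "log_probs" and "next",
# as built by _mock_data_grpo in the diff.
data = _mock_data_grpo(vocab_size=1024)

loss_vals: GRPOLossOutput = loss_fn(data)
# Scalar diagnostics exposed on the output: loss_objective, clip_fraction, kl_approx, ESS, ...
print(loss_vals.loss_objective, loss_vals.clip_fraction, loss_vals.ESS)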

File tree

6 files changed: +768 −154 lines changed


test/llm/test_objectives.py

Lines changed: 198 additions & 107 deletions
@@ -14,9 +14,9 @@
 from tensordict import lazy_stack, TensorDict
 from torchrl.data import History, LazyStackStorage, ReplayBuffer
 from torchrl.envs.llm.transforms.kl import RetrieveLogProb
-from torchrl.modules.llm import Text, TransformersWrapper, vLLMWrapper
-from torchrl.modules.llm.policies.common import ChatHistory, Masks, Tokens
-from torchrl.objectives.llm.grpo import MCAdvantage
+from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
+from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
+from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
 from torchrl.objectives.llm.sft import SFTLoss

 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -53,7 +53,7 @@ def make_silly_trajectory(n_steps=None):
         rewards = [torch.randn(n_tokens, 1)]
         prompt = np.random.choice(prompts)
         td = TensorDict(
-            text=Text(prompt=prompt),
+            query=prompt,  # MCAdvantage expects "query" key, not "text"
             next=TensorDict(
                 reward=rewards, done=torch.zeros(1, dtype=torch.bool)
             ),
@@ -83,8 +83,158 @@ def make_silly_trajectory(n_steps=None):
         assert "advantage" in s.keys()


-def test_grpo():
-    ...
+# Mock infrastructure moved to conftest.py
+
+
+def _mock_data_grpo(
+    vocab_size: int, device: torch.device | str = "cpu"
+) -> TensorDict:
+    from transformers import AutoTokenizer
+
+    device = torch.device(device)
+
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+    prompt = History(
+        role=["system", "user"],
+        content=["You are a useful assistant.", "What is 2+2?"],
+        batch_size=(2,),
+        device=device,
+    )
+    response = History(
+        role=["assistant"],
+        content=["2 + 2 = 4."],
+        batch_size=(1,),
+        device=device,
+    )
+    full_history = prompt.extend(response, inplace=False)
+    history = ChatHistory(
+        prompt=prompt,
+        response=response,
+        full=full_history,
+        device=device,
+    )
+    batch_size = 1
+
+    # Expand history to match batch size before getting tokens
+    history = history.expand((batch_size,))
+    next_history = ChatHistory(
+        prompt=full_history,
+        device=device,
+    )
+    next_history = next_history.expand((batch_size,))
+
+    # Now get tokens from the expanded history objects
+    tokens_full = history.to_tokens(tokenizer)
+    next_tokens = next_history.to_tokens(tokenizer)
+
+    # Get the actual sequence length from the tokens
+    # tokens_full has structure with "full" key containing the actual tokens
+    # We need to get the padded version to know the actual length
+    tokens_input_ids = tokens_full.get(
+        "full", as_padded_tensor=True, padding_side="left", padding_value=0
+    )
+    seq_len = tokens_input_ids.shape[-1]
+
+    # Create tensors with proper shapes
+    reward = torch.randn(batch_size, seq_len, 1, device=device)
+    done = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool, device=device)
+    advantage = torch.randn(batch_size, seq_len, 1, device=device)
+    log_probs = torch.randn_like(tokens_full, dtype=torch.float32, device=device)
+
+    # Create attention mask (all ones for non-padded tokens)
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
+
+    # Import Masks to create proper mask structure
+    from tensordict import MetaData
+    from torchrl.modules.llm.policies.common import Masks
+
+    masks = Masks(
+        all_attention_mask=attention_mask,
+        all_assistant_mask=None,  # Will be computed by the wrapper
+        padded=MetaData(True),
+        device=device,
+    )
+
+    data = TensorDict(
+        {
+            "advantage": advantage,
+            "history": history,
+            "tokens": tokens_full % vocab_size,
+            "masks": masks,
+            "next": {
+                "history": next_history,
+                "tokens": next_tokens % vocab_size,
+                "reward": reward,
+                "done": done,
+            },
+            "log_probs": log_probs,
+        },
+        batch_size=(batch_size,),
+    )
+    return data
+
+
+class TestLosses:
+    def test_grpo(self, mock_transformer_model):
+        """Test GRPO loss computation with mock models."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = GRPOLoss(actor_network)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: Check output type and structure
+        from torchrl.objectives.llm.grpo import GRPOLossOutput
+
+        assert isinstance(
+            loss_vals, GRPOLossOutput
+        ), f"Expected GRPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "ESS"), "Missing ESS"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.clip_fraction.shape == ()
+        ), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+        assert (
+            loss_vals.ESS.shape == ()
+        ), f"ESS should be scalar, got {loss_vals.ESS.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.ESS), "ESS is not finite"
+
+        # Check that clip_fraction is in valid range [0, 1]
+        assert (
+            0 <= loss_vals.clip_fraction <= 1
+        ), f"clip_fraction out of range: {loss_vals.clip_fraction}"


 class TestSFT:
@@ -203,7 +353,7 @@ def test_sft(
             assistant_only=True,
             tokenizer_kwargs={"chat_template_name": "qwen"},
             tokenizer=tokenizer,
-            log_probs_key=("ref_log_prob", "full"),
+            log_probs_full_key=("ref_log_probs", "full"),
         )
         with torch.no_grad():
             # Compute ref log-probs
@@ -247,7 +397,7 @@ def test_sft_assistant_only(self, data):
             assistant_only=True,
             tokenizer_kwargs={"chat_template_name": "qwen"},
             tokenizer=tokenizer,
-            log_probs_key=("ref_log_prob", "full"),
+            log_probs_full_key=("ref_log_probs", "full"),
         )
         td = transform(data)
         assert td is data
@@ -262,10 +412,12 @@ def test_sft_assistant_only(self, data):
         loss(td)


+@pytest.mark.slow
+@pytest.mark.integration
 class TestGRPOLossIntegration:
-    """Test GRPOLoss integration with the new distribution methods."""
+    """Integration tests for GRPOLoss with real models (vLLM + transformers)."""

-    @pytest.fixture(scope="module")
+    @pytest.fixture(scope="class")
     def transformers_instance(self):
         """Create transformers model and tokenizer for testing."""
         if not _has_transformers:
@@ -277,7 +429,7 @@ def transformers_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return model, tokenizer

-    @pytest.fixture(scope="module")
+    @pytest.fixture(scope="class")
     def vllm_instance(self):
         """Create vLLM model and tokenizer for testing."""
         if not _has_vllm:
@@ -297,102 +449,52 @@ def vllm_instance(self):
         except Exception as e:
             pytest.skip(f"Failed to load vLLM model: {e}")

-    @pytest.fixture(scope="module")
-    def sample_tokens(self, vllm_instance):
-        """Create sample tokens for testing."""
-        model, tokenizer = vllm_instance
-        text = [
-            "Are you happy? Say yes or no.",
-            "Explain the difference between a cat and a dog. Be very detailed.",
-        ]
-        tokenized = tokenizer(
-            text, return_tensors="pt", padding=True, padding_side="left"
-        )
-        return tokenized["input_ids"], tokenized["attention_mask"]
-
-    @pytest.fixture(scope="module")
-    def sample_text(self):
-        """Create sample text for testing."""
-        return [
-            "Are you happy? Say yes or no.",
-            "Explain the difference between a cat and a dog. Be very detailed.",
-        ]
-
-    @pytest.fixture(scope="module")
-    def sample_history(self):
-        """Create sample conversation history for testing."""
-        chats = [
-            [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": "Are you happy? Say yes or no."},
-            ],
-            [
-                {
-                    "role": "system",
-                    "content": "You are a very helpful assistant, but more handsome.",
-                },
-                {
-                    "role": "user",
-                    "content": "Explain the difference between a cat and a dog. Be very detailed.",
-                },
-            ],
-        ]
-        return History.from_chats(chats)
-
-    @pytest.fixture(scope="module")
-    def sample_history_assistant(self):
-        """Create sample conversation history for testing."""
-        chats = [
-            [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": "Are you happy? Say yes or no."},
-                {"role": "assistant", "content": "Yes."},
-            ],
-            [
-                {
-                    "role": "system",
-                    "content": "You are a very helpful assistant, but more handsome.",
-                },
-                {
-                    "role": "user",
-                    "content": "Explain the difference between a cat and a dog. Be very detailed.",
-                },
-                {
-                    "role": "assistant",
-                    "content": "A cat is a small animal that meows, while a dog is a larger animal that barks.",
-                },
-            ],
-        ]
-        return History.from_chats(chats)
-
     @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
     @pytest.mark.parametrize("masking_strategy", ["sft", "rlhf"])
-    def test_grpo_loss_with_transformers(
+    def test_grpo_loss_with_real_models(
         self,
         vllm_instance,
         transformers_instance,
-        sample_history,
-        sample_tokens,
         masking_strategy,
     ):
-        """Test GRPOLoss with vLLM wrapper and different masking strategies."""
+        """Test GRPOLoss with vLLM generation and transformers loss computation."""
         from torchrl.objectives.llm.grpo import GRPOLoss

         model, tokenizer = transformers_instance
         vllm_model, vllm_tokenizer = vllm_instance

-        # Use tokens input mode for SFT, history for RLHF/generic
+        # Create sample input based on masking strategy
         if masking_strategy == "sft":
-            input_mode = "tokens"
-            input_ids, attention_mask = sample_tokens
+            # Use tokens input mode for SFT
+            text = [
+                "Are you happy? Say yes or no.",
+                "What is 2+2?",
+            ]
+            tokenized = tokenizer(
+                text, return_tensors="pt", padding=True, padding_side="left"
+            )
             input_data = {
-                "tokens": Tokens(prompt=input_ids),
-                "masks": Masks(all_attention_mask=attention_mask),
+                "tokens": Tokens(prompt=tokenized["input_ids"]),
+                "masks": Masks(all_attention_mask=tokenized["attention_mask"]),
             }
+            input_mode = "tokens"
         else:
-            input_mode = "history"
+            # Use history input mode for RLHF
+            chats = [
+                [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "Are you happy? Say yes or no."},
+                ],
+                [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "What is 2+2?"},
+                ],
+            ]
+            sample_history = History.from_chats(chats)
             input_data = {"history": ChatHistory(prompt=sample_history)}
+            input_mode = "history"

+        # Generate responses with vLLM
         wrapper_gen = vLLMWrapper(
             vllm_model,
             tokenizer=vllm_tokenizer,
@@ -403,12 +505,11 @@ def test_grpo_loss_with_transformers(
             generate_kwargs={"max_tokens": 10},
         )

-        # Create test data with advantage and correct batch size
         td = TensorDict(input_data, batch_size=(2,)).to_lazystack(0)
         td = wrapper_gen(td)
-        # use a shape that can be broadcast
         td["advantage"] = torch.randn(2, 1, 1)

+        # Compute loss with transformers
         wrapper = TransformersWrapper(
             model,
             tokenizer=tokenizer,
@@ -418,23 +519,13 @@ def test_grpo_loss_with_transformers(
             pad_output=True,
         )

-        # Create GRPOLoss with specified masking strategy
-        loss_fn = GRPOLoss(
-            actor_network=wrapper,
-            masking_strategy=masking_strategy,
-        )
+        loss_fn = GRPOLoss(actor_network=wrapper, masking_strategy=masking_strategy)

-        # This should work without shape mismatch errors
-        try:
-            result = loss_fn(td)
-            assert result is not None
-        except ValueError as e:
-            if "Shape mismatch" in str(e):
-                # This is expected if the advantage shape doesn't match the log-prob shape
-                # due to different masking strategies
-                assert masking_strategy in str(e)
-            else:
-                raise
+        # Should successfully compute loss
+        result = loss_fn(td)
+        assert result is not None
+        assert hasattr(result, "loss_objective")
+        assert torch.isfinite(result.loss_objective)


 if __name__ == "__main__":
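
Besides the new loss class, the diff also tightens two data conventions: MCAdvantage now reads the prompt from a "query" entry instead of a Text object stored under "text", and the reference log-prob keyword becomes log_probs_full_key=("ref_log_probs", "full"). A small illustrative sketch of the per-step trajectory layout the updated make_silly_trajectory helper feeds to MCAdvantage (the values here are made up; the surrounding test then checks that an "advantage" entry appears):

import torch
from tensordict import TensorDict

n_tokens = 5
td = TensorDict(
    # MCAdvantage expects the prompt under "query" (previously Text(prompt=...) under "text").
    query="What is 2+2?",
    next=TensorDict(
        reward=[torch.randn(n_tokens, 1)],  # per-token rewards for this step
        done=torch.zeros(1, dtype=torch.bool),
    ),
)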
