@@ -14,9 +14,9 @@
 from tensordict import lazy_stack, TensorDict
 from torchrl.data import History, LazyStackStorage, ReplayBuffer
 from torchrl.envs.llm.transforms.kl import RetrieveLogProb
-from torchrl.modules.llm import Text, TransformersWrapper, vLLMWrapper
-from torchrl.modules.llm.policies.common import ChatHistory, Masks, Tokens
-from torchrl.objectives.llm.grpo import MCAdvantage
+from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
+from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
+from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
 from torchrl.objectives.llm.sft import SFTLoss

 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -53,7 +53,7 @@ def make_silly_trajectory(n_steps=None):
             rewards = [torch.randn(n_tokens, 1)]
             prompt = np.random.choice(prompts)
             td = TensorDict(
-                text=Text(prompt=prompt),
+                query=prompt,  # MCAdvantage expects "query" key, not "text"
                 next=TensorDict(
                     reward=rewards, done=torch.zeros(1, dtype=torch.bool)
                 ),
@@ -83,8 +83,158 @@ def make_silly_trajectory(n_steps=None):
         assert "advantage" in s.keys()


-def test_grpo():
-    ...
+# Mock infrastructure moved to conftest.py
+
+
+def _mock_data_grpo(
+    vocab_size: int, device: torch.device | str = "cpu"
+) -> TensorDict:
+    from transformers import AutoTokenizer
+
+    device = torch.device(device)
+
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+    prompt = History(
+        role=["system", "user"],
+        content=["You are a useful assistant.", "What is 2+2?"],
+        batch_size=(2,),
+        device=device,
+    )
+    response = History(
+        role=["assistant"],
+        content=["2 + 2 = 4."],
+        batch_size=(1,),
+        device=device,
+    )
+    full_history = prompt.extend(response, inplace=False)
+    history = ChatHistory(
+        prompt=prompt,
+        response=response,
+        full=full_history,
+        device=device,
+    )
+    batch_size = 1
+
+    # Expand history to match batch size before getting tokens
+    history = history.expand((batch_size,))
+    next_history = ChatHistory(
+        prompt=full_history,
+        device=device,
+    )
+    next_history = next_history.expand((batch_size,))
+
+    # Now get tokens from the expanded history objects
+    tokens_full = history.to_tokens(tokenizer)
+    next_tokens = next_history.to_tokens(tokenizer)
+
+    # Get the actual sequence length from the tokens
+    # tokens_full has structure with "full" key containing the actual tokens
+    # We need to get the padded version to know the actual length
+    tokens_input_ids = tokens_full.get(
+        "full", as_padded_tensor=True, padding_side="left", padding_value=0
+    )
+    seq_len = tokens_input_ids.shape[-1]
+
+    # Create tensors with proper shapes
+    reward = torch.randn(batch_size, seq_len, 1, device=device)
+    done = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool, device=device)
+    advantage = torch.randn(batch_size, seq_len, 1, device=device)
+    log_probs = torch.randn_like(tokens_full, dtype=torch.float32, device=device)
+
+    # Create attention mask (all ones for non-padded tokens)
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
+
+    # Import Masks to create proper mask structure
+    from tensordict import MetaData
+    from torchrl.modules.llm.policies.common import Masks
+
+    masks = Masks(
+        all_attention_mask=attention_mask,
+        all_assistant_mask=None,  # Will be computed by the wrapper
+        padded=MetaData(True),
+        device=device,
+    )
+
+    data = TensorDict(
+        {
+            "advantage": advantage,
+            "history": history,
+            "tokens": tokens_full % vocab_size,
+            "masks": masks,
+            "next": {
+                "history": next_history,
+                "tokens": next_tokens % vocab_size,
+                "reward": reward,
+                "done": done,
+            },
+            "log_probs": log_probs,
+        },
+        batch_size=(batch_size,),
+    )
+    return data
+
+
+class TestLosses:
+    def test_grpo(self, mock_transformer_model):
+        """Test GRPO loss computation with mock models."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = GRPOLoss(actor_network)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: check output type and structure
+        from torchrl.objectives.llm.grpo import GRPOLossOutput
+
+        assert isinstance(
+            loss_vals, GRPOLossOutput
+        ), f"Expected GRPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "ESS"), "Missing ESS"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.clip_fraction.shape == ()
+        ), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+        assert (
+            loss_vals.ESS.shape == ()
+        ), f"ESS should be scalar, got {loss_vals.ESS.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.ESS), "ESS is not finite"
+
+        # Check that clip_fraction is in valid range [0, 1]
+        assert (
+            0 <= loss_vals.clip_fraction <= 1
+        ), f"clip_fraction out of range: {loss_vals.clip_fraction}"


 class TestSFT:
@@ -203,7 +353,7 @@ def test_sft(
             assistant_only=True,
             tokenizer_kwargs={"chat_template_name": "qwen"},
             tokenizer=tokenizer,
-            log_probs_key=("ref_log_prob", "full"),
+            log_probs_full_key=("ref_log_probs", "full"),
         )
         with torch.no_grad():
             # Compute ref log-probs
@@ -247,7 +397,7 @@ def test_sft_assistant_only(self, data):
             assistant_only=True,
             tokenizer_kwargs={"chat_template_name": "qwen"},
             tokenizer=tokenizer,
-            log_probs_key=("ref_log_prob", "full"),
+            log_probs_full_key=("ref_log_probs", "full"),
         )
         td = transform(data)
         assert td is data
@@ -262,10 +412,12 @@ def test_sft_assistant_only(self, data):
             loss(td)


+@pytest.mark.slow
+@pytest.mark.integration
 class TestGRPOLossIntegration:
-    """Test GRPOLoss integration with the new distribution methods."""
+    """Integration tests for GRPOLoss with real models (vLLM + transformers)."""

-    @pytest.fixture(scope="module")
+    @pytest.fixture(scope="class")
     def transformers_instance(self):
         """Create transformers model and tokenizer for testing."""
         if not _has_transformers:
@@ -277,7 +429,7 @@ def transformers_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return model, tokenizer

-    @pytest.fixture(scope="module")
+    @pytest.fixture(scope="class")
     def vllm_instance(self):
         """Create vLLM model and tokenizer for testing."""
         if not _has_vllm:
@@ -297,102 +449,52 @@ def vllm_instance(self):
         except Exception as e:
             pytest.skip(f"Failed to load vLLM model: {e}")

-    @pytest.fixture(scope="module")
-    def sample_tokens(self, vllm_instance):
-        """Create sample tokens for testing."""
-        model, tokenizer = vllm_instance
-        text = [
-            "Are you happy? Say yes or no.",
-            "Explain the difference between a cat and a dog. Be very detailed.",
-        ]
-        tokenized = tokenizer(
-            text, return_tensors="pt", padding=True, padding_side="left"
-        )
-        return tokenized["input_ids"], tokenized["attention_mask"]
-
-    @pytest.fixture(scope="module")
-    def sample_text(self):
-        """Create sample text for testing."""
-        return [
-            "Are you happy? Say yes or no.",
-            "Explain the difference between a cat and a dog. Be very detailed.",
-        ]
-
-    @pytest.fixture(scope="module")
-    def sample_history(self):
-        """Create sample conversation history for testing."""
-        chats = [
-            [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": "Are you happy? Say yes or no."},
-            ],
-            [
-                {
-                    "role": "system",
-                    "content": "You are a very helpful assistant, but more handsome.",
-                },
-                {
-                    "role": "user",
-                    "content": "Explain the difference between a cat and a dog. Be very detailed.",
-                },
-            ],
-        ]
-        return History.from_chats(chats)
-
-    @pytest.fixture(scope="module")
-    def sample_history_assistant(self):
-        """Create sample conversation history for testing."""
-        chats = [
-            [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": "Are you happy? Say yes or no."},
-                {"role": "assistant", "content": "Yes."},
-            ],
-            [
-                {
-                    "role": "system",
-                    "content": "You are a very helpful assistant, but more handsome.",
-                },
-                {
-                    "role": "user",
-                    "content": "Explain the difference between a cat and a dog. Be very detailed.",
-                },
-                {
-                    "role": "assistant",
-                    "content": "A cat is a small animal that meows, while a dog is a larger animal that barks.",
-                },
-            ],
-        ]
-        return History.from_chats(chats)
-
     @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
     @pytest.mark.parametrize("masking_strategy", ["sft", "rlhf"])
-    def test_grpo_loss_with_transformers(
+    def test_grpo_loss_with_real_models(
         self,
         vllm_instance,
         transformers_instance,
-        sample_history,
-        sample_tokens,
         masking_strategy,
     ):
-        """Test GRPOLoss with vLLM wrapper and different masking strategies."""
+        """Test GRPOLoss with vLLM generation and transformers loss computation."""
         from torchrl.objectives.llm.grpo import GRPOLoss

         model, tokenizer = transformers_instance
         vllm_model, vllm_tokenizer = vllm_instance

-        # Use tokens input mode for SFT, history for RLHF/generic
+        # Create sample input based on masking strategy
         if masking_strategy == "sft":
-            input_mode = "tokens"
-            input_ids, attention_mask = sample_tokens
+            # Use tokens input mode for SFT
+            text = [
+                "Are you happy? Say yes or no.",
+                "What is 2+2?",
+            ]
+            tokenized = tokenizer(
+                text, return_tensors="pt", padding=True, padding_side="left"
+            )
             input_data = {
-                "tokens": Tokens(prompt=input_ids),
-                "masks": Masks(all_attention_mask=attention_mask),
+                "tokens": Tokens(prompt=tokenized["input_ids"]),
+                "masks": Masks(all_attention_mask=tokenized["attention_mask"]),
             }
+            input_mode = "tokens"
         else:
-            input_mode = "history"
+            # Use history input mode for RLHF
+            chats = [
+                [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "Are you happy? Say yes or no."},
+                ],
+                [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "What is 2+2?"},
+                ],
+            ]
+            sample_history = History.from_chats(chats)
             input_data = {"history": ChatHistory(prompt=sample_history)}
+            input_mode = "history"

+        # Generate responses with vLLM
         wrapper_gen = vLLMWrapper(
             vllm_model,
             tokenizer=vllm_tokenizer,
@@ -403,12 +505,11 @@ def test_grpo_loss_with_transformers(
             generate_kwargs={"max_tokens": 10},
         )

-        # Create test data with advantage and correct batch size
         td = TensorDict(input_data, batch_size=(2,)).to_lazystack(0)
         td = wrapper_gen(td)
-        # use a shape that can be broadcast
         td["advantage"] = torch.randn(2, 1, 1)

+        # Compute loss with transformers
         wrapper = TransformersWrapper(
             model,
             tokenizer=tokenizer,
@@ -418,23 +519,13 @@ def test_grpo_loss_with_transformers(
             pad_output=True,
         )

-        # Create GRPOLoss with specified masking strategy
-        loss_fn = GRPOLoss(
-            actor_network=wrapper,
-            masking_strategy=masking_strategy,
-        )
+        loss_fn = GRPOLoss(actor_network=wrapper, masking_strategy=masking_strategy)

-        # This should work without shape mismatch errors
-        try:
-            result = loss_fn(td)
-            assert result is not None
-        except ValueError as e:
-            if "Shape mismatch" in str(e):
-                # This is expected if the advantage shape doesn't match the log-prob shape
-                # due to different masking strategies
-                assert masking_strategy in str(e)
-            else:
-                raise
+        # Should successfully compute loss
+        result = loss_fn(td)
+        assert result is not None
+        assert hasattr(result, "loss_objective")
+        assert torch.isfinite(result.loss_objective)


 if __name__ == "__main__":