@@ -82,31 +82,30 @@ def prepare_maxtext_inputs(input_str, tokenizer_model):
   return input_ids, input_segmentation, input_position, completion_segmentation
 
 
-
 class GrpoTrainerTest(unittest.TestCase):
 
   def setUp(self):
     super().setUp()
     command = [
91- "gsutil" ,
92- "cp" ,
93- "-r" ,
94- "gs://maxtext-dataset/hf/llama3.1-tokenizer" ,
95- os .path .join (os .path .dirname (PKG_DIR ), "assets" , "" ),
90+ "gsutil" ,
91+ "cp" ,
92+ "-r" ,
93+ "gs://maxtext-dataset/hf/llama3.1-tokenizer" ,
94+ os .path .join (os .path .dirname (PKG_DIR ), "assets" , "" ),
9695 ]
     exit_code = subprocess.call(command, cwd=os.path.dirname(PKG_DIR))
     if exit_code != 0:
       raise ValueError(f"{command} failed with exit code: {exit_code}")
     self.config = pyconfig.initialize(
         [None, "MaxText/experimental/rl/grpo_trainer_test.yml"],
         run_name="unit_test_grpo_trainer",
-        tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), 'assets', 'llama3.1-tokenizer'),
+        tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "llama3.1-tokenizer"),
         enable_checkpointing=False,
     )
     self.config_inference = pyconfig.initialize(
         [None, "MaxText/experimental/rl/grpo_trainer_test.yml"],
         run_name="unit_test_grpo_trainer_inference",
-        tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), 'assets', 'llama3.1-tokenizer'),
+        tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "llama3.1-tokenizer"),
         enable_checkpointing=False,
         ici_tensor_parallelism=4,
         per_device_batch_size=self.config.per_device_batch_size * self.config.num_generations,
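
Aside: the inference config's batch is the trainer batch scaled by num_generations because GRPO scores a whole group of sampled completions per prompt. A toy calculation with made-up values, just to show the relationship:

    # Illustrative values only, not the test's real config.
    per_device_batch_size = 4   # prompts per device on the training side
    num_generations = 2         # completions sampled per prompt (GRPO group size)
    inference_batch = per_device_batch_size * num_generations
    print(inference_batch)      # 8 sequences per device on the inference side
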
@@ -115,11 +114,11 @@ def setUp(self):
     self.atol = 1e-08
     self.rng = jax.random.PRNGKey(self.config.init_weights_seed)
     self.tokenizer_model = transformers.AutoTokenizer.from_pretrained(
-      self.config.tokenizer_path,
-      add_bos_token=self.config.add_bos,
-      add_eos_token=self.config.add_eos,
-      legacy=False,
-      padding_side="left"
+        self.config.tokenizer_path,
+        add_bos_token=self.config.add_bos,
+        add_eos_token=self.config.add_eos,
+        legacy=False,
+        padding_side="left",
     )
     self.tokenizer_model.add_special_tokens({"pad_token": "<pad>"})
 
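
Aside: a minimal sketch (not part of the diff) of what the tokenizer setup above buys. padding_side="left" plus an explicit "<pad>" token means batched prompts are padded on the left, so the last position of every row is a real token, which is what autoregressive decoding extends. The "gpt2" name below is an illustrative stand-in for the gated Llama 3.1 tokenizer the test downloads:

    # Sketch under stated assumptions; "gpt2" stands in for the
    # llama3.1-tokenizer assets used by the test.
    import transformers

    tok = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    tok.add_special_tokens({"pad_token": "<pad>"})
    batch = tok(["hi", "a longer prompt"], padding=True, return_tensors="np")
    print(batch["input_ids"])       # rows left-padded to equal length
    print(batch["attention_mask"])  # 0s mark the padding positions
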
@@ -135,22 +134,23 @@ def test_grpo_trainer_correctness(self):
     )
     # Obtain per-token logits.
     maxtext_per_token_logps, _ = compute_log_probs(
-      maxtext_model,
-      state.params,
-      input_ids,
-      input_position,
-      input_segmentation,
-      completion_segmentation,
-      self.config,
-      is_train=False,
-      rngs=self.rng,
+        maxtext_model,
+        state.params,
+        input_ids,
+        input_position,
+        input_segmentation,
+        completion_segmentation,
+        self.config,
+        is_train=False,
+        rngs=self.rng,
     )
-    jax.debug.print("maxtext_per_token_logps={maxtext_per_token_logps}",maxtext_per_token_logps=maxtext_per_token_logps)
-    jax.debug.print("golden_per_token_logps={golden_per_token_logps}",golden_per_token_logps=golden_data["maxtext_per_token_logps_no_ckpt_loading"])
-    golden_maxtext_logits = np.array(golden_data["maxtext_per_token_logps_no_ckpt_loading"])
-    self.assertTrue(
-        jnp.all(np.array(golden_data["input_ids"]) == np.array(input_ids[0]))
+    jax.debug.print("maxtext_per_token_logps={maxtext_per_token_logps}", maxtext_per_token_logps=maxtext_per_token_logps)
+    jax.debug.print(
+        "golden_per_token_logps={golden_per_token_logps}",
+        golden_per_token_logps=golden_data["maxtext_per_token_logps_no_ckpt_loading"],
     )
+    golden_maxtext_logits = np.array(golden_data["maxtext_per_token_logps_no_ckpt_loading"])
+    self.assertTrue(jnp.all(np.array(golden_data["input_ids"]) == np.array(input_ids[0])))
     self.assertTrue(
         jax.numpy.allclose(
             maxtext_per_token_logps[0],
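
Aside: the assertions above reduce to an exact match on the prompt token ids followed by an elementwise allclose on per-token log-probabilities. A self-contained sketch with stand-in arrays (the real test loads these from recorded golden data):

    # Stand-in arrays; golden values come from a recorded file in the test.
    import jax.numpy as jnp
    import numpy as np

    golden_ids = np.array([128000, 9906, 1917])
    current_ids = np.array([128000, 9906, 1917])
    golden_logps = np.array([-1.23, -0.45, -2.10])
    current_logps = jnp.array([-1.23, -0.45, -2.10])

    assert jnp.all(golden_ids == current_ids)                     # ids must match exactly
    assert jnp.allclose(current_logps, golden_logps, atol=1e-08)  # logps within tolerance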