1313# limitations under the License.
1414"""Common RL helper classes and functions."""
1515
16+ from functools import partial # pylint: disable=g-importing-member
1617from typing import Any , Iterable
1718
1819import flax
@@ -177,7 +178,7 @@ def get_per_token_logps(
177178
178179# TODO(abheesht): This is computed 4 times - twice in `compute_per_token_logps`
179180# and twice in `compute_score`. We can factor this out and compute it just once.
180- @nnx .jit ( static_argnames = ("pad_id" , "eos_id" ))
181+ @partial ( jax .jit , static_argnames = ("pad_id" , "eos_id" ))
181182def process_ids (
182183 prompt_tokens : jax .Array ,
183184 completion_tokens : jax .Array ,
@@ -202,9 +203,13 @@ def process_ids(
202203 return prompt_completion_ids , positions , attn_mask
203204
204205
205- @nnx .jit (static_argnames = ("pad_id" , "eos_id" , "stop_gradient" , "return_logits" ))
206+ @partial (
207+ jax .jit ,
208+ static_argnames = ("pad_id" , "eos_id" , "stop_gradient" , "return_logits" ),
209+ )
206210def compute_per_token_logps (
207- model : nnx .Module ,
211+ graphdef ,
212+ state ,
208213 prompt_tokens : jax .Array ,
209214 completion_tokens : jax .Array ,
210215 pad_id : int ,
@@ -214,6 +219,7 @@ def compute_per_token_logps(
214219 return_logits : bool = False ,
215220) -> jax .Array | tuple [jax .Array , jax .Array ]:
216221 """Computes the per-token log probabilities."""
222+ model = nnx .merge (graphdef , state )
217223 input_tokens , positions , attn_mask = process_ids (
218224 prompt_tokens , completion_tokens , pad_id , eos_id , completion_mask
219225 )
0 commit comments