Skip to content

Commit b121235

Browse files
tianshub (The tunix Authors)
authored and committed
Code update
PiperOrigin-RevId: 874934348
1 parent e1d4836 commit b121235

File tree

9 files changed

+55
-36
lines changed

9 files changed

+55
-36
lines changed

tests/rl/common_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,10 @@ def test_compute_per_token_logps(self):
123123
completion_tokens = jnp.array(
124124
[[10, 11, -1, 12], [10, 11, 12, 13], [10, 11, 12, -1]]
125125
)
126+
graphdef, state = nnx.split(model)
126127
per_token_logps = common.compute_per_token_logps(
127-
model,
128+
graphdef,
129+
state,
128130
prompt_tokens,
129131
completion_tokens,
130132
pad_id=0,
@@ -142,7 +144,8 @@ def test_compute_per_token_logps(self):
142144
rtol=1e-2,
143145
)
144146
_, logits = common.compute_per_token_logps(
145-
model,
147+
graphdef,
148+
state,
146149
prompt_tokens,
147150
completion_tokens,
148151
pad_id=0,

tunix/generate/sampler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ def __init__(
202202
self._compiled_decode_fn = jax.jit(self._decode_fn)
203203
self._compiled_prefill_fn = jax.jit(self._prefill_fn)
204204

205+
def model_def_and_state(self) -> tuple[graph.NodeDef, statelib.State]:
206+
"""Returns the transformer graphdef and state."""
207+
return self._transformer_graphdef, self._flattened_transformer_state
208+
205209
@property
206210
def transformer(self) -> nnx.Module:
207211
return nnx.merge(

tunix/rl/common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Common RL helper classes and functions."""
1515

16+
from functools import partial # pylint: disable=g-importing-member
1617
from typing import Any, Iterable
1718

1819
import flax
@@ -177,7 +178,7 @@ def get_per_token_logps(
177178

178179
# TODO(abheesht): This is computed 4 times - twice in `compute_per_token_logps`
179180
# and twice in `compute_score`. We can factor this out and compute it just once.
180-
@nnx.jit(static_argnames=("pad_id", "eos_id"))
181+
@partial(jax.jit, static_argnames=("pad_id", "eos_id"))
181182
def process_ids(
182183
prompt_tokens: jax.Array,
183184
completion_tokens: jax.Array,
@@ -202,9 +203,13 @@ def process_ids(
202203
return prompt_completion_ids, positions, attn_mask
203204

204205

205-
@nnx.jit(static_argnames=("pad_id", "eos_id", "stop_gradient", "return_logits"))
206+
@partial(
207+
jax.jit,
208+
static_argnames=("pad_id", "eos_id", "stop_gradient", "return_logits"),
209+
)
206210
def compute_per_token_logps(
207-
model: nnx.Module,
211+
graphdef,
212+
state,
208213
prompt_tokens: jax.Array,
209214
completion_tokens: jax.Array,
210215
pad_id: int,
@@ -214,6 +219,7 @@ def compute_per_token_logps(
214219
return_logits: bool = False,
215220
) -> jax.Array | tuple[jax.Array, jax.Array]:
216221
"""Computes the per-token log probabilities."""
222+
model = nnx.merge(graphdef, state)
217223
input_tokens, positions, attn_mask = process_ids(
218224
prompt_tokens, completion_tokens, pad_id, eos_id, completion_mask
219225
)

tunix/rl/experimental/agentic_grpo_learner.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from typing import Any, Dict, List, Sequence, Type, TypeVar
3434

3535
from absl import logging
36+
from flax import nnx
3637
import jax
3738
import jax.numpy as jnp
3839
import numpy as np
@@ -345,7 +346,7 @@ def _process_results(
345346
completion_tokens=completion_ids,
346347
pad_id=pad_value,
347348
eos_id=eos_value,
348-
micro_batch_size=1,
349+
micro_batch_size=None,
349350
)
350351
else:
351352
ref_per_token_logps = None
@@ -390,7 +391,7 @@ def _process_results(
390391
rewards=rewards, num_generations=self.algo_config.num_generations
391392
)
392393

393-
policy_versions = jnp.array(policy_versions_list, dtype=jnp.int32)
394+
policy_versions = np.array(policy_versions_list, dtype=np.int32)
394395

395396
# Log completion lengths.
396397
agg_completion_mask = completion_mask.sum(axis=-1)
@@ -439,10 +440,7 @@ def _process_results(
439440
old_per_token_logps=old_per_token_logps,
440441
policy_version=policy_versions,
441442
)
442-
return [
443-
rl_utils.get_batch_slice(combined_batch, slice(i, i + 1))
444-
for i in range(self.algo_config.num_generations)
445-
]
443+
return [combined_batch]
446444

447445

448446
@function_registry.register_policy_loss_fn("agentic_grpo")
@@ -486,10 +484,11 @@ def grpo_loss_fn(
486484
train_example.completion_mask,
487485
)
488486

489-
# TODO(yangmu): trace this part as "actor_inference_and_training".
490-
# with perf_tracer.span("...", list(completion_ids.devices())):
487+
# TODO(tsbao): split can be avoided with updated peft_trainer model handling.
488+
graphdef, state = nnx.split(model)
491489
per_token_logps = common.compute_per_token_logps(
492-
model,
490+
graphdef,
491+
state,
493492
prompt_tokens=train_example.prompt_ids,
494493
completion_tokens=completion_ids,
495494
pad_id=pad_id,

tunix/rl/experimental/agentic_rl_learner.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858

5959
@flax.struct.dataclass(frozen=True)
6060
class TrainExample(common.TrainExample):
61-
policy_version: jax.Array | None = None
61+
policy_version: np.ndarray | None = None
6262

6363

6464
@dataclasses.dataclass(slots=True, kw_only=True)
@@ -705,7 +705,7 @@ def train(
705705

706706
# 2. Consume training examples and train.
707707
train_data_gen = self._data_consumer_batch_generator(
708-
train_data_queue, train_micro_batch_size * self._num_generations()
708+
train_data_queue, train_micro_batch_size
709709
)
710710
micro_batches_since_last_sync = 0
711711
micro_batches_per_full_batch = full_batch_size // train_micro_batch_size
@@ -720,13 +720,14 @@ def train(
720720
break
721721
self._iter_steps += 1
722722

723+
# TODO(tsbao): Re-enable this once off-policy filtering is needed.
723724
# Filter out examples that are too old (off-policy).
724-
filtered_train_micro_batch = self._filter_outdated_offpolicy_examples(
725-
train_micro_batch
726-
)
727-
if not filtered_train_micro_batch:
728-
continue
729-
train_micro_batch = filtered_train_micro_batch
725+
# filtered_train_micro_batch = self._filter_outdated_offpolicy_examples(
726+
# train_micro_batch
727+
# )
728+
# if not filtered_train_micro_batch:
729+
# continue
730+
# train_micro_batch = filtered_train_micro_batch
730731

731732
merged_train_micro_batch = jax.tree.map(
732733
lambda *xs: jnp.concatenate(xs, axis=0), *train_micro_batch
@@ -770,7 +771,7 @@ async def _eval_runner_async(current_eval_orchestrator):
770771
)
771772
if hasattr(self.rl_cluster, "critic_trainer"):
772773
self.rl_cluster.update_critic(
773-
train_micro_batch, current_eval_dataset, skip_jit
774+
[merged_train_micro_batch], current_eval_dataset, skip_jit
774775
)
775776

776777
# --- Weight Sync Logic ---

tunix/rl/grpo/grpo_learner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Iterable, List, Sequence, TypeVar
2121

2222
import flax
23+
from flax import nnx
2324
import jax
2425
import jax.numpy as jnp
2526
import numpy as np
@@ -455,8 +456,10 @@ def grpo_loss_fn(
455456

456457
# TODO(yangmu): trace this part as "actor_inference_and_training".
457458
# with perf_tracer.span("...", list(completion_ids.devices())):
459+
graphdef, state = nnx.split(model)
458460
per_token_logps = common.compute_per_token_logps(
459-
model,
461+
graphdef,
462+
state,
460463
prompt_tokens=train_example.prompt_ids,
461464
completion_tokens=completion_ids,
462465
pad_id=pad_id,

tunix/rl/inference/inference_worker.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from flax import nnx
1818
import jax
19-
import jaxtyping
2019
from tunix.rl import common
2120

2221

@@ -31,6 +30,9 @@ def __init__(self, models: dict[str, nnx.Module]):
3130
" reference and reward."
3231
)
3332
self._models = models
33+
self._model_states = {}
34+
for k, m in models.items():
35+
self._model_states[k] = nnx.split(m)
3436
# TODO(tsbao): support multiple reward models.
3537

3638
def get_rewards(
@@ -55,11 +57,12 @@ def get_ref_per_token_logps(
5557
eos_id: int,
5658
completion_mask: jax.Array | None = None,
5759
) -> jax.Array:
58-
ref_model = self._models.get("reference")
59-
if ref_model is None:
60+
graphdef, state = self._model_states.get("reference")
61+
if graphdef is None:
6062
raise ValueError("Reference model is not available.")
6163
return common.compute_per_token_logps(
62-
ref_model,
64+
graphdef,
65+
state,
6366
prompt_tokens=prompt_tokens,
6467
completion_tokens=completion_tokens,
6568
pad_id=pad_id,
@@ -77,7 +80,8 @@ def get_values(
7780
eos_id: int,
7881
completion_mask: jax.Array | None = None,
7982
) -> jax.Array:
80-
critic_model = self._models.get("critic")
83+
graphdef, state = self._model_states.get("critic")
84+
critic_model = nnx.merge(graphdef, state)
8185
if critic_model is None:
8286
raise ValueError("Critic model is not available.")
8387
return common.compute_score(
@@ -93,8 +97,3 @@ def get_model(self, role: str) -> nnx.Module:
9397
if role not in self._models:
9498
raise ValueError(f"Model role {role} is not available.")
9599
return self._models[role]
96-
97-
def update_model(self, role: str, params: jaxtyping.PyTree):
98-
if role not in self._models:
99-
raise ValueError(f"Model role {role} is not available.")
100-
nnx.update(self._models[role], params)

tunix/rl/ppo/ppo_learner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,10 @@ def ppo_policy_loss_fn(
586586
use_dual_clip_ppo = epsilon_c is not None
587587

588588
# Get log probs.
589+
graphdef, state = nnx.split(model)
589590
per_token_logps, logits = common.compute_per_token_logps(
590-
model,
591+
graphdef,
592+
state,
591593
prompt_tokens=prompt_ids,
592594
completion_tokens=completion_ids,
593595
pad_id=pad_id,

tunix/rl/rollout/vanilla_rollout.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,10 @@ def get_per_token_logps(
7878
completion_mask: jax.Array | None = None,
7979
) -> jax.Array:
8080
"""Returns per-token log probabilities from the rollout policy."""
81+
graphdef, state = self._sampler.model_def_and_state()
8182
return common.compute_per_token_logps(
82-
self.model(),
83+
graphdef,
84+
state,
8385
prompt_tokens=prompt_tokens,
8486
completion_tokens=completion_tokens,
8587
pad_id=self.pad_id(),

0 commit comments

Comments
 (0)