diff --git a/apps/blackjack/__init__.py b/apps/blackjack/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/apps/blackjack/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/apps/blackjack/blackjack_env.py b/apps/blackjack/blackjack_env.py new file mode 100644 index 000000000..ab1205634 --- /dev/null +++ b/apps/blackjack/blackjack_env.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from dataclasses import dataclass, field +from typing import Any + +from envs.openspiel_env import OpenSpielAction, OpenSpielEnv +from forge.observability.metrics import record_metric, Reduce + + +@dataclass +class EnvStepResult: + """Result from environment step.""" + + observation: dict[str, str] # Next message: {"role": "user", "content": "..."} + reward: float # Reward for this step + done: bool # Episode ended? + metadata: dict[str, Any] = field(default_factory=dict) + + +class BlackjackEnv: + """ + Minimal blackjack environment. + + Responsibilities: + - Manage game state via OpenSpielEnv + - Parse actions from text + - Return next observation message + - Compute rewards + + Does NOT: + - Hold message history (rollout loop does this) + - Tokenize (rollout loop does this) + - Track cumulative tokens (rollout loop does this) + """ + + def __init__(self, server_url: str): + self.server_url = server_url + self.client = OpenSpielEnv(base_url=server_url) + self.client._http.trust_env = False + + # Game state + self.turn_count = 0 + self.has_invalid_action = False + + def reset(self) -> str: + """ + Reset game and return initial user message. + + Returns: + Initial observation text (NOT a dict, just the content string) + """ + self.turn_count = 0 + self.has_invalid_action = False + + # Reset game + result = self.client.reset() + + # Build initial observation + return self._format_observation(result.observation) + + def step(self, action_text: str) -> EnvStepResult: + """ + Execute action and return next observation. 
+
+        Args:
+            action_text: The assistant's text response
+
+        Returns:
+            EnvStepResult with next observation message, reward, done
+        """
+
+        # Parse action
+        action_name, error_type = self._parse_action(action_text)
+
+        # Track invalid actions
+        is_invalid = action_name == "INVALID"
+        if is_invalid:
+            self.has_invalid_action = True
+            action_name = "STAND"  # Treat invalid as STAND
+            record_metric("game/invalid_action_rate", 1, Reduce.MEAN)
+            if error_type == "NO_TAGS":
+                record_metric("game/missing_answer_tags", 1, Reduce.SUM)
+            elif error_type == "INVALID_CONTENT":
+                record_metric("game/invalid_answer_content", 1, Reduce.SUM)
+        else:
+            record_metric("game/invalid_action_rate", 0, Reduce.MEAN)
+
+        # Execute in game
+        action_id = 0 if action_name == "HIT" else 1
+        result = self.client.step(
+            OpenSpielAction(action_id=action_id, game_name="blackjack")
+        )
+
+        self.turn_count += 1
+
+        # Compute reward
+        if result.done:
+            reward = self._compute_reward(
+                result.reward, is_invalid=self.has_invalid_action
+            )
+            # Record game outcome metrics
+            record_metric("game/games_played", 1, Reduce.SUM)
+            record_metric("game/average_turns", self.turn_count, Reduce.MEAN)
+            record_metric("game/win_rate", 1 if result.reward > 0 else 0, Reduce.MEAN)
+            record_metric("game/env_reward", result.reward, Reduce.MEAN)
+        else:
+            reward = 0.0  # No intermediate rewards
+
+        # Build next observation (if game continues)
+        if result.done:
+            observation = {"role": "user", "content": ""}  # Empty, game ended
+        else:
+            obs_text = self._format_observation(result.observation)
+            observation = {"role": "user", "content": obs_text}
+
+        return EpisodeStepResult = EnvStepResult(
+            observation=observation,
+            reward=reward,
+            done=result.done,
+            metadata={
+                "turn_count": self.turn_count,
+                "has_invalid_action": self.has_invalid_action,
+                "env_reward": result.reward if result.done else 0.0,
+            },
+        )
+        return EpisodeStepResult
+
+    def _format_observation(self, observation) -> str:
+        """Format game observation into text."""
+        player_total = observation.metadata.get("player_total", "?")
+        dealer_card = observation.metadata.get("dealer_card", "?")
+        dealer_str = "Ace" if dealer_card == 1 else str(dealer_card)
+
+        return f"Hand: {player_total}, Dealer: {dealer_str}"
+
+    def _parse_action(self, text: str) -> tuple[str, str]:
+        """Parse action from assistant text using <answer> tags.
+
+        Returns:
+            (action, error_type): action is "HIT", "STAND", or "INVALID"
+            error_type is "" for valid, "NO_TAGS" or "INVALID_CONTENT"
+        """
+        import re
+
+        # Try to extract content from <answer>...</answer> tags
+        match = re.search(
+            r"<answer>\s*(.*?)\s*</answer>", text, re.IGNORECASE | re.DOTALL
+        )
+
+        if match:
+            answer = match.group(1).strip().upper()
+            if answer == "HIT":
+                return ("HIT", "")
+            elif answer == "STAND":
+                return ("STAND", "")
+            else:
+                # Has tags but invalid content
+                return ("INVALID", "INVALID_CONTENT")
+        else:
+            # No tags found
+            return ("INVALID", "NO_TAGS")
+
+    def _compute_reward(self, env_reward: float, is_invalid: bool) -> float:
+        """Compute final reward."""
+        if env_reward > 0:  # Win
+            rwd = 3.0
+        else:  # Loss or push
+            rwd = -1.0
+
+        if is_invalid:
+            rwd = -10.0  # Penalty for any invalid action during the episode
+            record_metric("game/invalid_action_penalty", 1, Reduce.SUM)
+
+        return rwd
+
+    def close(self):
+        """Clean up."""
+        self.client.close()
diff --git a/apps/blackjack/main.py b/apps/blackjack/main.py
new file mode 100644
index 000000000..179d8df42
--- /dev/null
+++ b/apps/blackjack/main.py
@@ -0,0 +1,1098 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Usage: python -m apps.blackjack.main_v2 --config apps/blackjack/qwen3_1_7b.yaml + +import asyncio +import multiprocessing +import os +import signal +import subprocess +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum +from functools import lru_cache, partial +from typing import Any, Optional + +import requests + +import torch +import torch.nn.functional as F +import torchstore as ts + +from apps.blackjack.blackjack_env import BlackjackEnv, EnvStepResult +from apps.blackjack.token_accumulator import ( + EpisodeData, + TokenAccumulator, + TruncationReason, + ValidationMode, +) +from envs.openspiel_env import OpenSpielAction, OpenSpielEnv +from forge.actors._torchstore_utils import ( + get_dcp_whole_state_dict_key, + get_param_prefix, +) +from forge.actors.generator import Generator +from forge.actors.reference_model import ReferenceModel +from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import TitanTrainer +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import init_provisioner, shutdown +from forge.data.common import CROSS_ENTROPY_IGNORE_IDX +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer +from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse +from forge.util.ops import compute_logprobs, create_shifted_targets +from monarch.actor import endpoint +from omegaconf import DictConfig +from vllm import SamplingParams +from vllm.transformers_utils.tokenizer import get_tokenizer + +# ============================================================================ +# Server Management Functions for OpenSpiel / OpenEnv +# TODO: Written by claude, probably very messy +# ============================================================================ + + +def start_openspiel_server(game_name: str, port: int): + """Start OpenSpiel server in background process.""" + os.environ["OPENSPIEL_GAME"] = game_name + + import uvicorn + from envs.openspiel_env.server.app import app + + print(f"[SERVER] Starting uvicorn for game '{game_name}' on port {port}") + uvicorn.run(app, host="0.0.0.0", port=port, log_level="info", access_log=False) + + +def kill_process_on_port(port: int): + """Kill any process using the specified port.""" + result = subprocess.run( + ["lsof", "-ti", f":{port}"], + capture_output=True, + text=True, + timeout=5, + ) + if result.stdout.strip(): + pids = result.stdout.strip().split("\n") + for pid in pids: + try: + os.kill(int(pid), signal.SIGKILL) + except ProcessLookupError: + pass + time.sleep(0.5) + + +def _wait_for_server_health(port: int, timeout: int = 30) -> bool: + """Wait for server health check to pass.""" + for attempt in range(timeout): + try: + resp = requests.get( + f"http://localhost:{port}/health", + timeout=1, + proxies={"http": None, "https": None}, + ) + if resp.status_code == 200: + return True + except Exception: + pass + time.sleep(1) + return False + + +def start_servers( + num_servers: int, base_port: int, game_name: str +) -> tuple[list, list]: + """Start OpenSpiel servers and wait for them to be ready. + + Args: + num_servers: Number of servers to start + base_port: Base port (will use base_port, base_port+1, ...) 
+ game_name: Name of the game (e.g., "blackjack") + + Returns: + (server_processes, server_ports) + + Raises: + RuntimeError: If any server fails to start + """ + server_processes = [] + server_ports = [] + + # Start all servers + for i in range(num_servers): + port = base_port + i + server_ports.append(port) + + kill_process_on_port(port) # Clean up existing + + proc = multiprocessing.Process( + target=start_openspiel_server, args=(game_name, port) + ) + proc.start() + server_processes.append(proc) + + # Wait for health checks + time.sleep(1) # Give servers time to start + for i, port in enumerate(server_ports): + if not _wait_for_server_health(port, timeout=30): + # Cleanup and fail + for proc in server_processes: + proc.terminate() + raise RuntimeError(f"Server on port {port} failed to start") + + print(f"✓ Started {num_servers} OpenSpiel server(s)") + return server_processes, server_ports + + +def shutdown_servers(server_processes: list): + """Shutdown all OpenSpiel servers gracefully.""" + for proc in server_processes: + proc.terminate() + proc.join(timeout=2) + if proc.is_alive(): + proc.kill() + proc.join(timeout=1) + + +# ============================================================================ +# debugging +# ============================================================================ + + +def print_episode_debug(episode, tokenizer, rollout_count: int): + """Print detailed episode debug info using TokenAccumulator's visualization. + + Creates a temporary TokenAccumulator and populates it with episode data + to reuse the colorized token stream display. + """ + print(f"\n[ROLLOUT {rollout_count}] Episode Debug") + print( + f"Reward: {episode.reward:.2f}, Tokens: {len(episode.all_token_ids)}, " + f"Trainable: {episode.response_mask.sum().item()}, Truncated: {episode.is_truncated}" + ) + + # Create a minimal TokenAccumulator just for visualization + # We need to provide the required init params, but we'll override internals + dummy_messages = [{"role": "system", "content": ""}] + acc = TokenAccumulator( + tokenizer=tokenizer, + messages=dummy_messages, + max_len=len(episode.all_token_ids), + eos_id=tokenizer.eos_token_id, + thinking=False, + validation=ValidationMode.OFF, + ) + + # Replace internal state with episode data + acc._tokens = episode.all_token_ids.tolist() + acc._mask = episode.response_mask.tolist() + acc._logprobs = [0.0] * len(episode.all_token_ids) # Dummy logprobs + acc.messages = episode.message_log if episode.message_log else [] + + # Use TokenAccumulator's existing show_messages method + acc.show_messages(max_chars=2000) + + +# ============================================================================ +# Episode +# ============================================================================ + + +@dataclass +class Episode: + """Episode data for GRPO training (new structure).""" + + episode_id: str + all_token_ids: torch.Tensor # [seq_len] + response_mask: torch.Tensor # [seq_len] + loss_mask: torch.Tensor # [seq_len] + reward: float + + task_name: str = "blackjack" + policy_version: int = 0 + is_truncated: bool = False + advantage: float | None = None + logprobs: torch.Tensor | None = None # [seq_len] + ref_logprobs: torch.Tensor | None = None # [seq_len] + metadata: dict[str, Any] = field(default_factory=dict) + message_log: list[dict[str, str]] | None = None + + +# ============================================================================ +# Rollout Functions (from v5) +# ============================================================================ + + 
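+# Illustrative note on how the Episode tensors produced by the rollout below
+# line up (token IDs here are made up, for illustration only). response_mask
+# marks the assistant tokens sampled by vLLM; loss_mask is response_mask
+# shifted left by one (torch.roll in do_single_rollout), so the loss at
+# position i scores the prediction of token i+1 and the final position is
+# never trained:
+#
+#   all_token_ids: [sys ... user ... <gen prompt>  791  374  eos  <suffix>]
+#   response_mask: [0   ... 0    ... 0             1    1    1    0       ]
+#   loss_mask:     [0   ... 0    ... 1             1    1    0    0       ]
+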
+async def do_single_rollout( + env: BlackjackEnv, + policy, + tokenizer, + max_seq_len: int, + max_turns: int, + messages: list[dict], + game_id: str | None = None, +) -> Episode: + """ + Play one game and return one Episode. + + Uses TokenAccumulator for efficient multi-turn token management with BASE anchor pattern. + + Args: + env: BlackjackEnv instance + policy: Policy for generation + tokenizer: Tokenizer with apply_chat_template + max_seq_len: Maximum tokens for full conversation + max_turns: Maximum game turns + messages: Initial messages (e.g., [{"role": "system", "content": "..."}]) + game_id: Optional game ID + + Returns: + Episode with accumulated tokens, masks, and logprobs + """ + + if game_id is None: + game_id = str(uuid.uuid4()) + + # Initialize TokenAccumulator with BASE anchor pattern + accumulator = TokenAccumulator( + tokenizer=tokenizer, + messages=messages, + max_len=max_seq_len, + eos_id=tokenizer.eos_token_id, + validation=ValidationMode.OFF, + thinking=False, + ) + + try: + # ============ Reset environment ============ + initial_obs = env.reset() + accumulator.add_user(initial_obs) + + # ============ Multi-turn loop ============ + final_reward = 0.0 + turn_num = 0 + game_done = False + policy_version = 0 + + while not game_done and turn_num < max_turns: + remaining_budget = accumulator.budget + + if remaining_budget <= 0: + break + + # ============ Generate ============ + prompt = accumulator.format_prompt() + sampling_params = SamplingParams(max_tokens=remaining_budget) + responses = await policy.generate.route( + prompt, sampling_params=sampling_params + ) + response = responses[0] + + policy_version = response.generator_version + + # ============ Add assistant response ============ + response_logprobs = response.logprobs + response_text = response.text + response_token_ids_list = list(response.token_ids) + + # success means not truncated. We drop the entire response if truncated. 
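+            # Per TokenAccumulator.add_assistant (see token_accumulator.py in
+            # this diff): on success it appends the generation prompt (not
+            # trainable), the vLLM response tokens (trainable), and any
+            # post-EOS suffix (not trainable). It returns False and appends
+            # nothing if the response lacks EOS or exceeds the remaining budget.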
+ success = accumulator.add_assistant( + text=response_text, + token_ids=response_token_ids_list, + logprobs=response_logprobs, + ) + + # If generation truncated, break + if not success: + break + + # ============ Step environment ============ + result = env.step(action_text=response.text) + final_reward = result.reward + game_done = result.done + turn_num += 1 + + # ============ Add environment observation ============ + if not result.done: + obs_text = result.observation["content"] + success = accumulator.add_user(obs_text) + + # If env obs would exceed budget, break + if not success: + break + + # ============ Get episode data ============ + episode_data = accumulator.get_data() + + # Record metrics + if episode_data.truncation_reason: + record_metric( + f"episode/truncated_{episode_data.truncation_reason}", + 1, + Reduce.SUM, + ) + record_metric("episode/total_tokens", len(episode_data.token_ids), Reduce.MEAN) + record_metric("episode/turns", turn_num, Reduce.MEAN) + + # ============ Create episode ============ + # Create loss_mask by shifting response_mask + loss_mask_tensor = torch.roll( + episode_data.response_mask, shifts=-1, dims=0 + ).float() + loss_mask_tensor[-1] = 0.0 # Last position should not train + + return Episode( + episode_id=game_id, + task_name="blackjack", + policy_version=policy_version, + is_truncated=episode_data.is_truncated, + all_token_ids=episode_data.token_ids, + response_mask=episode_data.response_mask, + loss_mask=loss_mask_tensor, + reward=final_reward, + logprobs=episode_data.logprobs, + message_log=accumulator.messages.copy(), + metadata={ + "truncation_reason": episode_data.truncation_reason, + "num_turns": turn_num, + "num_trainable_tokens": episode_data.response_mask.sum().item(), + **(result.metadata if "result" in locals() else {}), + }, + ) + + finally: + env.close() + + +async def do_group_rollout( + envs: list[BlackjackEnv], + policy, + tokenizer, + max_seq_len: int, + max_turns: int, + messages: list[dict], +) -> list[Episode]: + """ + Rollout multiple games in parallel. 
+ + Args: + envs: List of N BlackjackEnv instances + policy: Policy for generation + tokenizer: Tokenizer for chat template + max_seq_len: Episode-level token budget + max_turns: Max turns per game + messages: Initial messages for all games (e.g., [{"role": "system", ...}]) + + Returns: + List of N Episodes + """ + tasks = [ + do_single_rollout( + env=envs[i], + policy=policy, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + max_turns=max_turns, + messages=messages, + game_id=f"game_{i}_{uuid.uuid4().hex[:8]}", + ) + for i in range(len(envs)) + ] + + episodes = await asyncio.gather(*tasks) + return list(episodes) + + +# ============================================================================ +# Helper Actors (from main.py) +# ============================================================================ + + +@dataclass +class ComputeAdvantages(ForgeActor): + """Compute advantages for a group of episodes.""" + + @endpoint + async def compute(self, group: list[Episode]) -> list[float]: + """Compute advantages using reward standardization.""" + rewards = torch.tensor([[e.reward for e in group]]) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() + + +# ============================================================================ +# Training Functions (from main.py) +# ============================================================================ + + +def collate( + batches: list[list[Episode]], + pad_id: int, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Collates a list of batches (groups) into inputs and targets. + + Args: + batches: List of groups, where each group is a list of Episodes + pad_id: Padding token ID from tokenizer + + Returns: + (inputs, targets) for training + """ + inputs = [] + targets = [] + + for batch in batches: + # Stack all tensors (pad to max length in batch) + all_tokens = [e.all_token_ids for e in batch] + all_tokens = torch.nn.utils.rnn.pad_sequence( + all_tokens, batch_first=True, padding_value=pad_id + ) + + loss_masks = [e.loss_mask for e in batch] + loss_masks = torch.nn.utils.rnn.pad_sequence( + loss_masks, batch_first=True, padding_value=0.0 + ) + + ref_logprobs = [e.ref_logprobs for e in batch] + ref_logprobs = torch.nn.utils.rnn.pad_sequence( + ref_logprobs, batch_first=True, padding_value=0.0 + ) + + advantages = torch.tensor([e.advantage for e in batch]).unsqueeze(-1) # [b, 1] + + # Create input and target dicts + input = {"tokens": all_tokens} + target = { + "input_ids": all_tokens, # For torch.roll in loss + "loss_mask": loss_masks, # Trainable positions + "ref_logprobs": ref_logprobs, + "advantages": advantages, + } + + inputs.append(input) + targets.append(target) + + return inputs, targets + + +# TODO: delete extensive debugging +# TODO: make KL clipping optional +def simple_grpo_loss( + logits: torch.Tensor, # [b, seq_len, vocab] + input_ids: torch.Tensor, # [b, seq_len] + loss_mask: torch.Tensor, # [b, seq_len] float + ref_logprobs: torch.Tensor, # [b, seq_len] + advantages: torch.Tensor, # [b, 1] + beta: float = 0.1, +) -> torch.Tensor: + """ + GRPO loss with KL clipping + + Args: + logits: Model logits [b, seq_len, vocab_size] + input_ids: Input token IDs [b, seq_len] + loss_mask: Loss mask [b, seq_len] - 1.0 for trainable positions + ref_logprobs: Reference logprobs [b, seq_len] + advantages: Advantages [b, 1] + beta: KL penalty coefficient + + Returns: + Loss scalar + """ + # Create targets using utility function 
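+    # Sketch of the assumed convention (see forge.util.ops for the actual
+    # implementation): targets[i] = input_ids[i + 1] wherever loss_mask[i] is 1,
+    # and CROSS_ENTROPY_IGNORE_IDX elsewhere. For one row:
+    #
+    #   input_ids: [10, 20, 30, 40]    loss_mask: [0, 1, 1, 0]
+    #   targets:   [IGN, 30, 40, IGN]
+    #
+    # compute_logprobs then returns the logprob of each target token, with
+    # ignored positions contributing 0.0, so only trainable positions enter
+    # the loss below.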
+ targets = create_shifted_targets(input_ids, loss_mask) # [b, seq_len] + + # Compute policy logprobs (ignore_index automatically zeros masked positions) + logprobs = compute_logprobs( + logits, targets, ignore_index=CROSS_ENTROPY_IGNORE_IDX + ) # [b, seq_len] - masked positions already 0.0! + + # ======================================================================== + # LOGGING: Input validation + # ======================================================================== + record_metric("loss_debug/batch_size", float(input_ids.shape[0]), Reduce.MEAN) + record_metric("loss_debug/seq_len", float(input_ids.shape[1]), Reduce.MEAN) + record_metric( + "loss_debug/num_trainable_tokens", loss_mask.sum().item(), Reduce.MEAN + ) + record_metric("loss_debug/targets_min", targets.float().min().item(), Reduce.MEAN) + record_metric("loss_debug/targets_max", targets.float().max().item(), Reduce.MEAN) + + # ======================================================================== + # LOGGING: Logprobs statistics + # ======================================================================== + # Mask logprobs for stats (only look at trainable positions) + masked_logprobs = logprobs * loss_mask + masked_ref_logprobs = ref_logprobs * loss_mask + num_trainable = loss_mask.sum().clamp(min=1.0) + + record_metric( + "loss_debug/logprobs_mean", + (masked_logprobs.sum() / num_trainable).item(), + Reduce.MEAN, + ) + record_metric( + "loss_debug/logprobs_min", + logprobs[loss_mask.bool()].min().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/logprobs_max", + logprobs[loss_mask.bool()].max().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/logprobs_std", + logprobs[loss_mask.bool()].std().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + + record_metric( + "loss_debug/ref_logprobs_mean", + (masked_ref_logprobs.sum() / num_trainable).item(), + Reduce.MEAN, + ) + record_metric( + "loss_debug/ref_logprobs_min", + ref_logprobs[loss_mask.bool()].min().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/ref_logprobs_max", + ref_logprobs[loss_mask.bool()].max().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/ref_logprobs_std", + ref_logprobs[loss_mask.bool()].std().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + + # Logprob difference + logprob_diff = ref_logprobs - logprobs + masked_logprob_diff = logprob_diff * loss_mask + record_metric( + "loss_debug/logprob_diff_mean", + (masked_logprob_diff.sum() / num_trainable).item(), + Reduce.MEAN, + ) + record_metric( + "loss_debug/logprob_diff_min", + logprob_diff[loss_mask.bool()].min().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/logprob_diff_max", + logprob_diff[loss_mask.bool()].max().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + + # KL divergence (masked positions are 0.0, so they don't contribute) + # Following VERL's approach: clip log difference before exp for numerical stability + # See: verl/trainer/ppo/core_algos.py kl_penalty_forward() + logprob_diff_clipped = torch.clamp(logprob_diff, min=-20.0, max=20.0) + kl = torch.exp(logprob_diff_clipped) - logprob_diff_clipped - 1 + # Clip final KL to prevent extreme values + kl = torch.clamp(kl, min=-10.0, max=10.0) + + # ======================================================================== + # LOGGING: KL divergence statistics + # 
======================================================================== + masked_kl = kl * loss_mask + record_metric( + "loss_debug/kl_mean", (masked_kl.sum() / num_trainable).item(), Reduce.MEAN + ) + record_metric( + "loss_debug/kl_min", + kl[loss_mask.bool()].min().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/kl_max", + kl[loss_mask.bool()].max().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/kl_std", + kl[loss_mask.bool()].std().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/beta_times_kl_mean", + (beta * masked_kl.sum() / num_trainable).item(), + Reduce.MEAN, + ) + + # ======================================================================== + # LOGGING: Advantages statistics + # ======================================================================== + record_metric("loss_debug/advantages_mean", advantages.mean().item(), Reduce.MEAN) + record_metric("loss_debug/advantages_min", advantages.min().item(), Reduce.MEAN) + record_metric("loss_debug/advantages_max", advantages.max().item(), Reduce.MEAN) + record_metric("loss_debug/advantages_std", advantages.std().item(), Reduce.MEAN) + + # Policy loss + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - beta * kl) # [b, seq_len] + + # ======================================================================== + # LOGGING: Per-token loss statistics + # ======================================================================== + masked_policy_loss = per_token_policy_loss * loss_mask + masked_per_token_loss = per_token_loss * loss_mask + + record_metric( + "loss_debug/policy_loss_mean", + (masked_policy_loss.sum() / num_trainable).item(), + Reduce.MEAN, + ) + record_metric( + "loss_debug/policy_loss_min", + ( + per_token_policy_loss[loss_mask.bool()].min().item() + if num_trainable > 0 + else 0.0 + ), + Reduce.MEAN, + ) + record_metric( + "loss_debug/policy_loss_max", + ( + per_token_policy_loss[loss_mask.bool()].max().item() + if num_trainable > 0 + else 0.0 + ), + Reduce.MEAN, + ) + + record_metric( + "loss_debug/per_token_loss_mean", + (masked_per_token_loss.sum() / num_trainable).item(), + Reduce.MEAN, + ) + record_metric( + "loss_debug/per_token_loss_min", + per_token_loss[loss_mask.bool()].min().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + record_metric( + "loss_debug/per_token_loss_max", + per_token_loss[loss_mask.bool()].max().item() if num_trainable > 0 else 0.0, + Reduce.MEAN, + ) + + # Masked average (per sample, then batch average) + loss = ( + (per_token_loss * loss_mask).sum(dim=1) / loss_mask.sum(dim=1).clamp(min=1.0) + ).mean() + + # ======================================================================== + # LOGGING: Final loss + # ======================================================================== + record_metric("loss_debug/final_loss", loss.item(), Reduce.MEAN) + + # ======================================================================== + # EMERGENCY DUMP: If any value is huge, save tensors to file + # ======================================================================== + huge_threshold = 1000.0 + all_stats = [ + ("logprobs_mean", (masked_logprobs.sum() / num_trainable).item()), + ("ref_logprobs_mean", (masked_ref_logprobs.sum() / num_trainable).item()), + ("kl_mean", (masked_kl.sum() / num_trainable).item()), + ("kl_max", kl[loss_mask.bool()].max().item() if num_trainable > 0 else 0.0), + ("advantages_mean", 
advantages.mean().item()), + ("advantages_max", advantages.max().item()), + ("policy_loss_mean", (masked_policy_loss.sum() / num_trainable).item()), + ( + "policy_loss_max", + ( + per_token_policy_loss[loss_mask.bool()].max().item() + if num_trainable > 0 + else 0.0 + ), + ), + ("per_token_loss_mean", (masked_per_token_loss.sum() / num_trainable).item()), + ( + "per_token_loss_max", + per_token_loss[loss_mask.bool()].max().item() if num_trainable > 0 else 0.0, + ), + ("final_loss", loss.item()), + ] + + # for name, value in all_stats: + # if abs(value) > huge_threshold: + # # Save all tensors to file for debugging + # import datetime + + # timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + # dump_file = f"/tmp/grpo_loss_debug_{timestamp}.pt" + # torch.save( + # { + # "logits": logits.cpu(), + # "input_ids": input_ids.cpu(), + # "targets": targets.cpu(), + # "loss_mask": loss_mask.cpu(), + # "logprobs": logprobs.cpu(), + # "ref_logprobs": ref_logprobs.cpu(), + # "advantages": advantages.cpu(), + # "kl": kl.cpu(), + # "per_token_policy_loss": per_token_policy_loss.cpu(), + # "per_token_loss": per_token_loss.cpu(), + # "loss": loss.cpu(), + # "beta": beta, + # "trigger_stat": name, + # "trigger_value": value, + # }, + # dump_file, + # ) + # print(f"\n{'='*80}") + # print(f"⚠️ HUGE VALUE DETECTED: {name} = {value:.2f}") + # print(f"Dumped all tensors to: {dump_file}") + # print(f"{'='*80}\n") + # break # Only dump once + + return loss + + +async def drop_weights(version: int): + """Drop old weights from torchstore.""" + print(f"Dropping weights @ version {version}") + start_time = time.perf_counter() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") + + +# ============================================================================ +# Main Training Loop +# ============================================================================ + + +async def main(cfg: DictConfig): + """Main GRPO training loop with rollout and training processes.""" + + # ---- Start OpenSpiel Servers ---- # + server_processes, server_ports = start_servers( + num_servers=cfg.get("rollout_threads", 1), + base_port=cfg.blackjack_env.server_port, + game_name=cfg.blackjack_env.game_name, + ) + + # ---- Global setups ---- # + provisioner = None + if cfg.get("provisioner", None) is not None: + provisioner = await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) + else: + provisioner = await init_provisioner() + + metric_logging_cfg = cfg.metric_logging + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + + # ---- Setup tokenizers ---- # + # Create N tokenizers for N rollout threads (one per thread, no sharing) + num_rollout_threads = cfg.rollout_threads + tokenizers = [ + get_tokenizer(cfg.blackjack_env.model) for _ in range(num_rollout_threads) + ] + pad_id = ( + tokenizers[0].pad_token_id + if tokenizers[0].pad_token_id is not None + else tokenizers[0].eos_token_id + ) + + # Create collate function with pad_id + collate_fn = partial(collate, pad_id=pad_id) + + # ---- Setup services ---- # + ( + policy, + trainer, + replay_buffer, + compute_advantages, + 
ref_model,
+    ) = await asyncio.gather(
+        Generator.options(**cfg.services.policy).as_service(**cfg.policy),
+        TitanTrainer.options(**cfg.actors.trainer).as_actor(
+            **cfg.trainer, loss=simple_grpo_loss
+        ),
+        ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
+            **cfg.replay_buffer, collate=collate_fn
+        ),
+        ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
+        ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
+    )
+
+    max_steps = cfg.trainer.training.steps or -1
+
+    print("All services initialized successfully!")
+    shutdown_event = asyncio.Event()
+
+    # Initialize torchstore
+    trainer_num_procs = cfg.actors.trainer["procs"]
+    trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
+    trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
+    await ts.initialize(
+        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
+        strategy=ts.LocalRankStrategy(),
+    )
+    print("Torchstore successfully initialized with local rank strategy")
+
+    # ---- Core RL loops ---- #
+    async def continuous_rollouts(thread_id: int, tokenizer):
+        """Main GRPO rollout loop using the new architecture."""
+        rollout_count = 0
+
+        # Config - use dedicated server for this thread
+        server_url = f"http://localhost:{server_ports[thread_id]}"
+        max_seq_len = cfg.blackjack_env.max_seq_len
+        max_turns = cfg.blackjack_env.max_turns
+        group_size = cfg.group_size
+
+        print(f"[Thread {thread_id}] Using server at {server_url}")
+
+        # Initial messages
+        initial_messages = [
+            {
+                "role": "system",
+                "content": """You are an expert Blackjack player.
+
+GOAL: Get a hand total closer to 21 than the dealer without going over 21 (busting).
+
+RULES:
+- Card values: Ace=1 or 11, Face cards (J,Q,K)=10, Number cards=face value
+- If you go over 21, you bust and lose immediately
+- The dealer plays after you and must hit until reaching 17+
+
+ACTIONS:
+- HIT: Take another card (increases your hand total)
+- STAND: Keep your current hand and end your turn
+
+WIN CONDITIONS:
+- Your hand is closer to 21 than the dealer's final hand
+- Dealer busts (goes over 21) and you don't
+- You get exactly 21
+
+IMPORTANT: You MUST output your action in the following format:
+<answer>HIT</answer> or <answer>STAND</answer>""",
+            }
+        ]
+
+        while not shutdown_event.is_set():
+            t = Tracer("main_perf/continuous_rollouts")
+            t.start()
+
+            # ============ Step 1: Rollout group ============
+            # TODO: currently done serially
+            episodes = []
+            for i in range(group_size):
+                env = BlackjackEnv(server_url=server_url)
+                game_id = f"game_{i}_{uuid.uuid4().hex[:8]}"
+
+                episode = await do_single_rollout(
+                    env=env,
+                    policy=policy,
+                    tokenizer=tokenizer,
+                    max_seq_len=max_seq_len,
+                    max_turns=max_turns,
+                    messages=initial_messages,
+                    game_id=game_id,
+                )
+                episodes.append(episode)
+
+            t.step("play_games")
+
+            # Print episode details every 10 rollouts
+            if episodes and rollout_count % 10 == 0:
+                print_episode_debug(episodes[0], tokenizer, rollout_count)
+
+            # ============ Step 2: Filter groups (constant rewards) ============
+            rewards = [e.reward for e in episodes]
+            if len(set(rewards)) == 1:
+                print(
+                    f"[ROLLOUT {rollout_count}] ⚠️ DROPPED GROUP - All {len(episodes)} episodes have same reward: {rewards[0]}"
+                )
+                record_metric("groups/rate_dropped", 1, Reduce.MEAN)
+                rollout_count += 1
+                t.stop()
+                continue
+            record_metric("groups/rate_dropped", 0, Reduce.MEAN)
+
+            # ============ Step 3: Compute ref_model ============
+            max_len = max(len(e.all_token_ids) for e in episodes)
+
+            # Pad input_ids and loss_masks
+            padded_input_ids,
padded_loss_masks = [], [] + for i, e in enumerate(episodes): + pad_len = max_len - len(e.all_token_ids) + + padded_input_ids.append( + F.pad(e.all_token_ids, (0, pad_len), value=pad_id) + ) + padded_loss_masks.append(F.pad(e.loss_mask, (0, pad_len), value=0.0)) + + input_ids = torch.stack(padded_input_ids) # [batch, max_len] + loss_mask_batch = torch.stack(padded_loss_masks) # [batch, max_len] + + # Call ref_model with loss_mask - returns [batch, max_len] + ref_logprobs_padded = await ref_model.forward.route( + input_ids, return_logprobs=True, loss_mask=loss_mask_batch + ) + + t.step("reference_model_calculate_logprobs") + + # Assign ref_logprobs to episodes (unpad to original length) + for i, episode in enumerate(episodes): + seq_len = len(episode.all_token_ids) + episode.ref_logprobs = ref_logprobs_padded[i, :seq_len] # [seq_len] + + del ref_logprobs_padded, input_ids, loss_mask_batch + + # ============ Step 4: Compute advantages ============ + advantages = await compute_advantages.compute.call_one(episodes) + for episode, advantage in zip(episodes, advantages): + episode.advantage = advantage + + # ============ Step 5: Episode-level acceptance ============ + accepted = [] + for episode in episodes: + if episode.is_truncated and not cfg.accept_truncated: + record_metric("buffer/rate_rejected_truncated", 1, Reduce.MEAN) + else: + record_metric("buffer/rate_rejected_truncated", 0, Reduce.MEAN) + accepted.append(episode) + + # ============ Step 6: Add to buffer ============ + for episode in accepted: + await replay_buffer.add.call_one(episode) + + record_metric("buffer/episodes_accepted", len(accepted), Reduce.SUM) + record_metric( + "buffer/episode_acceptance_rate", + len(accepted) / len(episodes) if episodes else 0, + Reduce.MEAN, + ) + + rollout_count += 1 + record_metric( + "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM + ) + t.stop() + + async def continuous_training(): + """Training loop.""" + training_step = 0 + restart_tracer = True + + while max_steps == -1 or training_step < max_steps: + if restart_tracer: + t = Tracer("main_perf/continuous_training") + t.start() + restart_tracer = False + + batch = await replay_buffer.sample.call_one( + curr_policy_version=training_step + ) + if batch is None: + await asyncio.sleep(0.1) + else: + t.step("waiting_for_buffer") + print(f"[TRAINING] Step {training_step}: Starting training") + + inputs, targets = batch + await trainer.train_step.call(inputs, targets) + training_step += 1 + t.step("train_step") + + await trainer.push_weights.call(training_step) + t.step("push_weights") + + await policy.update_weights.fanout(training_step) + t.step("update_weights") + + if training_step >= 2: + await drop_weights(training_step - 1) + t.step("drop_weights") + + t.stop() + restart_tracer = True + + # Flush metrics every training step + await mlogger.flush.call_one(training_step) + + print( + f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." + ) + + print(f"Starting GRPO with {num_rollout_threads} rollout threads") + rollout_tasks = [ + asyncio.create_task(continuous_rollouts(thread_id=i, tokenizer=tokenizers[i])) + for i in range(num_rollout_threads) + ] + training_task = asyncio.create_task(continuous_training()) + + try: + await training_task + except KeyboardInterrupt: + print("Training interrupted by user") + finally: + print("Shutting down... 
(this may take a few seconds)") + shutdown_event.set() + + # Cancel rollout tasks + try: + await asyncio.wait_for( + asyncio.gather(*rollout_tasks, return_exceptions=True), + timeout=5, + ) + except asyncio.TimeoutError: + print("Timeout waiting for rollouts; forcing cancellation...") + for t in rollout_tasks: + t.cancel() + await asyncio.gather(*rollout_tasks, return_exceptions=True) + + # Cancel training task + training_task.cancel() + try: + await asyncio.wait_for(training_task, timeout=2) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + + # Shutdown forge actors/services + print("Shutting down Forge actors...") + try: + await asyncio.wait_for(shutdown(), timeout=10) + print("✓ Forge actors shut down") + except asyncio.TimeoutError: + print("⚠ Forge shutdown timed out after 10s, forcing exit...") + + # Shutdown OpenSpiel servers + shutdown_servers(server_processes) + + +if __name__ == "__main__": + + @parse + def _main(cfg): + asyncio.run(main(cfg)) + + _main() # @parse grabs the cfg from CLI diff --git a/apps/blackjack/openenv_patch/README.md b/apps/blackjack/openenv_patch/README.md new file mode 100644 index 000000000..a444afbc1 --- /dev/null +++ b/apps/blackjack/openenv_patch/README.md @@ -0,0 +1,65 @@ +# Blackjack RL Training + +## Setup + +```bash +# Clone and install OpenEnv +git clone git@github.com:meta-pytorch/OpenEnv.git +cd OpenEnv +pip install -e . + +# Apply blackjack modifications +python ../forge/apps/blackjack/openenv_patch/apply_patch.py + +# Run training +cd ../forge +python -m apps.blackjack.main --config apps/blackjack/qwen3_1_7b.yaml +``` + +## What gets changed in OpenEnv + +### 1. Enable metadata passthrough (`src/core/env_server/http_server.py`) + +```python +# Before: +obs_dict.pop("metadata", None) # Remove metadata from observation + +# After: +# obs_dict.pop("metadata", None) # Remove metadata from observation +``` + +### 2. Extract blackjack game state (`src/envs/openspiel_env/server/openspiel_environment.py`) + +```python +# Add this after line 252 (before creating OpenSpielObservation): + +# Extract game-specific metadata for blackjack +metadata = {} +if self.game_name == "blackjack" and not time_step.last(): + try: + state = self._ospiel_env.get_state + if hasattr(state, "get_best_player_total"): + metadata["player_total"] = state.get_best_player_total( + self.agent_player + ) + if hasattr(state, "dealers_visible_card"): + dealer_card_idx = state.dealers_visible_card() + rank = dealer_card_idx % 13 + if rank == 0: + dealer_value = 1 # Ace + elif rank <= 9: + dealer_value = rank + 1 # 2-10 + else: + dealer_value = 10 # Jack, Queen, King + metadata["dealer_card"] = dealer_value + except Exception: + pass + +# Then update OpenSpielObservation creation: +obs = OpenSpielObservation( + ..., + metadata=metadata, # Add this line +) +``` + +This allows observations like `"Hand: 17, Dealer: Ace"` instead of raw state vectors. diff --git a/apps/blackjack/openenv_patch/apply_patch.py b/apps/blackjack/openenv_patch/apply_patch.py new file mode 100755 index 000000000..17fe51661 --- /dev/null +++ b/apps/blackjack/openenv_patch/apply_patch.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Apply OpenEnv modifications for blackjack training.""" + +import subprocess +import sys +from pathlib import Path + + +def main(): + # Get script directory + script_dir = Path(__file__).parent + patch_file = script_dir / "openenv_blackjack.patch" + + if not patch_file.exists(): + print(f"Error: Patch file not found at {patch_file}") + sys.exit(1) + + # Apply patch + try: + subprocess.run( + ["git", "apply", str(patch_file)], + check=True, + capture_output=True, + text=True, + ) + print("✓ Patch applied successfully") + except subprocess.CalledProcessError as e: + print(f"Error applying patch: {e.stderr}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/apps/blackjack/openenv_patch/openenv_blackjack.patch b/apps/blackjack/openenv_patch/openenv_blackjack.patch new file mode 100644 index 000000000..3826ba474 --- /dev/null +++ b/apps/blackjack/openenv_patch/openenv_blackjack.patch @@ -0,0 +1,160 @@ +diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py +index d18873f..31b99df 100644 +--- a/src/core/env_server/http_server.py ++++ b/src/core/env_server/http_server.py +@@ -17,9 +17,11 @@ import os + from dataclasses import asdict + from typing import Any, Dict, Type + ++from fastapi import Body, FastAPI ++ + from .interfaces import Environment + from .types import Action, Observation +-from fastapi import Body, FastAPI ++ + + class HTTPEnvServer: + """ +@@ -107,7 +109,6 @@ class HTTPEnvServer: + """Health check endpoint.""" + return {"status": "healthy"} + +- + def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: + """ + Convert JSON dict to Action instance. +@@ -150,7 +151,7 @@ class HTTPEnvServer: + # Extract reward and done (these are part of StepResult on client side) + reward = obs_dict.pop("reward", None) + done = obs_dict.pop("done", False) +- obs_dict.pop("metadata", None) # Remove metadata from observation ++ # obs_dict.pop("metadata", None) # Remove metadata from observation + + # Return in HTTPEnvClient expected format + return { +@@ -159,6 +160,7 @@ class HTTPEnvServer: + "done": done, + } + ++ + def create_app( + env: Environment, + action_cls: Type[Action], +@@ -167,33 +169,36 @@ def create_app( + ) -> Any: + """ + Create a FastAPI application with or without web interface. +- ++ + This function creates a FastAPI app with the web interface enabled by default, + including README integration for better user experience. 
+- ++ + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading +- ++ + Returns: + FastAPI application instance with or without web interface and README integration + """ + # Check if web interface should be enabled + # This can be controlled via environment variable or build argument +- enable_web = ( +- os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes") ++ enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( ++ "true", ++ "1", ++ "yes", + ) + + if enable_web: + # Import web interface only when needed + from .web_interface import create_web_interface_app ++ + return create_web_interface_app(env, action_cls, observation_cls, env_name) + else: + # Use standard FastAPI app without web interface + return create_fastapi_app(env, action_cls, observation_cls) +- ++ + + def create_fastapi_app( + env: Environment, +diff --git a/src/envs/openspiel_env/server/openspiel_environment.py b/src/envs/openspiel_env/server/openspiel_environment.py +index 481aefb..580ec81 100644 +--- a/src/envs/openspiel_env/server/openspiel_environment.py ++++ b/src/envs/openspiel_env/server/openspiel_environment.py +@@ -21,8 +21,8 @@ from .opponent_policies import get_opponent_policy, OpponentPolicy + + # Import OpenSpiel + try: +- from open_spiel.python import rl_environment + import pyspiel ++ from open_spiel.python import rl_environment + except ImportError as e: + raise ImportError( + "OpenSpiel is not installed. " +@@ -73,9 +73,7 @@ class OpenSpielEnvironment(Environment): + + # Create OpenSpiel environment + try: +- self._ospiel_env = rl_environment.Environment( +- game_name, **self.game_params +- ) ++ self._ospiel_env = rl_environment.Environment(game_name, **self.game_params) + except Exception as e: + raise ValueError( + f"Failed to create OpenSpiel game '{game_name}': {e}" +@@ -252,15 +250,48 @@ class OpenSpielEnvironment(Environment): + if time_step.rewards is not None: + reward = float(time_step.rewards[self.agent_player]) + ++ # Extract game-specific metadata for blackjack ++ metadata = {} ++ if self.game_name == "blackjack" and not time_step.last(): ++ # Get underlying OpenSpiel state to access blackjack-specific methods ++ try: ++ state = self._ospiel_env.get_state # Property, not method - no () ++ if hasattr(state, "get_best_player_total"): ++ metadata["player_total"] = state.get_best_player_total( ++ self.agent_player ++ ) ++ if hasattr(state, "dealers_visible_card"): ++ dealer_card_idx = state.dealers_visible_card() ++ # Convert card index (0-51) to blackjack value (1-10) ++ # This matches the C++ CardValue() logic in blackjack.cc ++ # Cards are indexed from 0 to kDeckSize-1 (52 cards total) ++ # Rank = card_idx % 13, where 0=Ace, 1-9=2-10, 10=J, 11=Q, 12=K ++ rank = dealer_card_idx % 13 ++ if rank == 0: ++ dealer_value = 1 # Ace ++ elif rank <= 9: ++ dealer_value = rank + 1 # 2-10 ++ else: ++ dealer_value = 10 # Jack, Queen, King ++ metadata["dealer_card"] = dealer_value ++ except Exception: ++ # If extraction fails, continue without metadata ++ pass ++ + # Create observation + obs = OpenSpielObservation( +- info_state=info_state.tolist() if hasattr(info_state, "tolist") else list(info_state), ++ info_state=( ++ info_state.tolist() ++ if hasattr(info_state, "tolist") ++ else list(info_state) ++ ), + legal_actions=legal_actions, + game_phase=game_phase, + current_player_id=current_player_id, + 
opponent_last_action=self._last_opponent_action, + done=time_step.last(), + reward=reward, ++ metadata=metadata, + ) + + return obs diff --git a/apps/blackjack/qwen3_1_7b.yaml b/apps/blackjack/qwen3_1_7b.yaml new file mode 100644 index 000000000..57231e1f6 --- /dev/null +++ b/apps/blackjack/qwen3_1_7b.yaml @@ -0,0 +1,153 @@ +# BlackJack GRPO Training Configuration +# >>> python -m apps.blackjack.main --config apps/blackjack/qwen3_1_7b.yaml +# +# The OpenSpiel server will be started automatically by the training script. + +# Global configuration +group_size: 16 # Number of parallel games per rollout +local_batch_size: 16 # Per-device batch size +max_seq_len: 2048 # Maximum tokens for full conversation (including all turns) +model: "Qwen/Qwen3-1.7B" +off_by_n: 1 # Off-policy tolerance +accept_truncated: true # Accept truncated episodes in replay buffer + +# Main loop configuration +rollout_threads: 1 # Number of parallel rollout threads + +# Observability configuration +metric_logging: + wandb: + project: "blackjack-grpo" + group: "blackjack_exp_${oc.env:USER}" + logging_mode: global_reduce + console: + logging_mode: global_reduce + +# OpenSpiel environment configuration +blackjack_env: + game_name: "blackjack" # OpenSpiel game to run (blackjack, catch, tic_tac_toe, etc.) + server_url: "http://localhost:9000" + server_port: 9000 + model: ${model} + max_seq_len: ${max_seq_len} # Maximum tokens for full conversation (including all turns) + max_turns: 10 # Maximum number of turns per game + +# Policy configuration +policy: + engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs + model: ${model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams + n: 1 # Generate 1 response per game state (not group_size, since we play full games) + max_tokens: ${max_seq_len} # changed dinamically on generate call + temperature: 1.0 + top_p: 1.0 + +# Trainer configuration +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + lr_scheduler: + warmup_steps: 1 + training: + local_batch_size: ${local_batch_size} + seq_len: ${max_seq_len} + max_norm: 1.0 + steps: 1000 # Tutorial: 1000 steps (increase for production) + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + folder: ./checkpoint # The folder to save checkpoints to. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo + last_save_in_hf: true + interval: 500 + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Replay buffer configuration +replay_buffer: + batch_size: ${local_batch_size} + max_policy_age: ${off_by_n} + dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree + +# Reference model configuration +ref_model: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + training: + seq_len: ${trainer.training.seq_len} + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + checkpoint: + enable: true + initial_load_path: hf://${model} + initial_load_in_hf: true + +# All resource allocations +services: + policy: + procs: ${policy.engine_args.tensor_parallel_size} + num_replicas: 1 + mesh_name: policy + with_gpus: true + ref_model: + procs: 1 + num_replicas: 1 + mesh_name: ref_model + with_gpus: true + reward_actor: + procs: 1 + num_replicas: 1 + mesh_name: reward_actor + with_gpus: false + +actors: + blackjack_env: + procs: 1 + with_gpus: false + mesh_name: blackjack_env + trainer: + procs: 1 + with_gpus: true + mesh_name: trainer + replay_buffer: + procs: 1 + with_gpus: false + mesh_name: replay_buffer + compute_advantages: + procs: 1 + with_gpus: false + mesh_name: compute_advantages diff --git a/apps/blackjack/token_accumulator.py b/apps/blackjack/token_accumulator.py new file mode 100644 index 000000000..249a84a68 --- /dev/null +++ b/apps/blackjack/token_accumulator.py @@ -0,0 +1,621 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +import torch + + +class ValidationMode(Enum): + """Validation strictness.""" + + STRICT = "strict" # Raise on failures + WARN = "warn" # Print warnings + OFF = "off" # No validation + + +class TruncationReason(Enum): + """Truncation reason.""" + + USER_TOO_LONG = "user_too_long" + ASSISTANT_TOO_LONG = "assistant_too_long" + TOOL_TOO_LONG = "tool_too_long" + MAX_NUM_TURNS = "max_num_turns" + + +@dataclass +class EpisodeData: + """ + Episode data as tensors, ready for training. + + All tensors have shape (T,) where T is sequence length. + """ + + token_ids: torch.Tensor # dtype=long + response_mask: torch.Tensor # dtype=bool + logprobs: torch.Tensor # dtype=float + is_truncated: bool + truncation_reason: Optional[str] = None + + +class TokenAccumulator: + """ + Accumulate tokens for multi-turn RL episodes using vLLM tokens directly. + + ## Why Delta Tokenization? + + vLLM only returns assistant response tokens. We need the full conversation with + chat template tokens for training. We can't re-tokenize because it's expensive + and error-prone. 
+ + **What we get from vLLM:** + ``` + response_tokens = [791, 19, 374, 220, 2] # ["The", "answer", "is", "4", ""] + ``` + + **What we need for training:** + ``` + [1, 2, 3] # ["You", "are", "helpful"] (not trainable) + [10, 11, 12, 13] # ["What", "is", "2+2", "?"] (not trainable) + [150, 123] # ["<|im_start|>", "assistant"] (not trainable) + [791, 19, 374, 220, 2] # ["The", "answer", "is", "4", eos] (TRAINABLE!) + [151] # ["<|im_end|>"] (not trainable, Qwen only) + ``` + + **Solution:** Use an anchor conversation [system, empty_user] that never changes. + Tokenize new messages against it and extract deltas. For assistant responses, + add generation prompt prefix and any model-specific suffix. + + ## Truncation Behavior + + - **add_user**: If truncated, adds partial message (truncated to fit budget) + - **add_assistant**: If truncated, DROPS entire response (nothing added) + - Once truncated, all subsequent adds will fail (return False) + + ## Usage + + ```python + acc = TokenAccumulator(tok, [{"role": "system", "content": "Help"}], 2048, eos_id=2) + + # Add messages + acc.add_user("What is 2+2?") + prompt = acc.format_prompt() + response = vllm_generate(prompt) + acc.add_assistant(response.text, response.token_ids, response.logprobs) + + # Show what will be trained on + acc.show_messages() + + # Get episode data as tensors + episode = acc.get_data() + # episode.token_ids: torch.Tensor (long) + # episode.response_mask: torch.Tensor (bool, True = trainable) + # episode.logprobs: torch.Tensor (float) + ``` + + Args: + tokenizer: HuggingFace tokenizer with apply_chat_template + messages: Initial messages (must include system message) + max_len: Maximum sequence length + eos_id: End-of-sequence token ID + thinking: Enable tags for Qwen models + validation: Validation mode (STRICT, WARN, OFF) + """ + + def __init__( + self, + tokenizer, + messages: list[dict], + max_len: int, + eos_id: int, + thinking: bool = True, + validation: ValidationMode = ValidationMode.STRICT, + ) -> None: + self._validate_init(tokenizer, messages, max_len, eos_id) + + self.tokenizer = tokenizer + self.max_len = max_len + self.eos_id = eos_id + self.thinking = thinking + self.validation = validation + + # State + self.messages: list[dict] = [] + self._tokens: list[int] = [] + self._mask: list[bool] = [] + self._logprobs: list[float] = [] + self.truncated: bool = False + self.truncation_reason: Optional[TruncationReason] = None + + # Track message boundaries for efficient validation + # Each entry: (end_idx, role, should_end_with_eos) + self._message_ends: list[tuple[int, str, bool]] = [] + + # Setup + self._setup_anchor(messages) + self._init_messages(messages) + + def __repr__(self) -> str: + status = f", truncated" if self.truncated else "" + return f"TokenAccumulator({len(self._tokens)}/{self.max_len}{status})" + + @property + def budget(self) -> int: + """Remaining token budget.""" + return max(0, self.max_len - len(self._tokens) - self.gen_prompt_len) + + def add_user(self, content: str) -> bool: + """ + Add user message. If truncated, adds partial message (truncated to fit). 
+ + Returns: + True if not truncated, False if truncated + """ + if not isinstance(content, str): + raise TypeError(f"content must be str, got {type(content)}") + + msg = {"role": "user", "content": content} + + # Tokenize [system, user] and extract delta + full = self.tokenizer.apply_chat_template( + [self.anchor[0], msg], + add_generation_prompt=False, + tokenize=True, + enable_thinking=self.thinking, + ) + # Extract user tokens by slicing off system prefix + tokens = full[self.sys_len :] + + if not tokens: + return True + + # Check budget + budget = self.budget + if budget <= 0: + self._mark_truncated(TruncationReason.USER_TOO_LONG) + return False + + # Truncate if needed (still adds partial) + was_truncated = len(tokens) > budget + if was_truncated: + tokens = tokens[:budget] + self._mark_truncated(TruncationReason.USER_TOO_LONG) + + self.messages.append(msg) + self._add_tokens(tokens, trainable=False, role="user", ends_with_eos=False) + + return not was_truncated + + def add_assistant( + self, text: str, token_ids: list[int], logprobs: Optional[list[float]] = None + ) -> bool: + """ + Add assistant response from vLLM. If truncated, DROPS entire response (nothing added). + + Args: + text: Response text (for message log) + token_ids: Token IDs from vLLM (must end with EOS) + logprobs: Log probabilities (optional) + + Returns: + False if truncated/invalid (response dropped), True if added successfully + """ + # Type validation + if not isinstance(text, str): + raise TypeError(f"text must be str, got {type(text)}") + if not isinstance(token_ids, list): + raise TypeError(f"token_ids must be list, got {type(token_ids)}") + + # Must have tokens and end with EOS + if not token_ids: + return self._mark_truncated(TruncationReason.ASSISTANT_TOO_LONG) + if token_ids[-1] != self.eos_id: + return self._mark_truncated(TruncationReason.ASSISTANT_TOO_LONG) + + # Check budget: generation_prompt + response + suffix + total_len = self.gen_prompt_len + len(token_ids) + len(self.suffix) + if total_len > self.budget: + return self._mark_truncated(TruncationReason.ASSISTANT_TOO_LONG) + + # Validate logprobs if provided + if logprobs is not None: + if not isinstance(logprobs, list): + raise TypeError(f"logprobs must be list or None") + if len(logprobs) != len(token_ids): + raise ValueError( + f"logprobs length mismatch: {len(logprobs)} != {len(token_ids)}" + ) + + self.messages.append({"role": "assistant", "content": text}) + + # Generation prompt (not trainable) + self._add_tokens( + self.gen_prompt_tokens, + trainable=False, + logprobs=[0.0] * len(self.gen_prompt_tokens), + role="assistant_prompt", + ends_with_eos=False, + ) + + # Response tokens (trainable) + self._add_tokens( + token_ids, + trainable=True, + logprobs=logprobs, + role="assistant", + ends_with_eos=True, + ) + + # Suffix if needed (not trainable) + if self.suffix: + self._add_tokens( + self.suffix, + trainable=False, + logprobs=[0.0] * len(self.suffix), + role="assistant_suffix", + ends_with_eos=False, + ) + + return True + + def format_prompt(self) -> str: + """Format conversation for vLLM generation.""" + return self.tokenizer.apply_chat_template( + self.messages, + add_generation_prompt=True, + tokenize=False, + enable_thinking=self.thinking, + ) + + def get_data(self) -> EpisodeData: + """ + Convert to tensors, validate, and return episode data. 
+ + Returns: + EpisodeData with torch tensors + + Raises: + AssertionError/ValueError: If validation fails in STRICT mode + """ + # Convert to tensors + token_ids = torch.tensor(self._tokens, dtype=torch.long) + response_mask = torch.tensor(self._mask, dtype=torch.bool) + logprobs = torch.tensor(self._logprobs, dtype=torch.float) + + # Validate on tensors + if self.validation != ValidationMode.OFF: + self._validate(token_ids, response_mask, logprobs) + + return EpisodeData( + token_ids=token_ids, + response_mask=response_mask, + logprobs=logprobs, + is_truncated=self.truncated, + truncation_reason=( + self.truncation_reason.value if self.truncation_reason else None + ), + ) + + def show_messages(self, max_chars: int = 5000) -> None: + """ + Show token stream with trainability highlighted. + + Uses colored text runs for readability (similar to tinker-cookbook's format_colorized). + Groups consecutive tokens with same trainability and decodes together for proper + multi-byte character handling. + + Args: + max_chars: Maximum characters to show in decoded output (default: 5000) + """ + print("=" * 80) + print(f"TokenAccumulator: {len(self._tokens)}/{self.max_len} tokens") + trainable_count = sum(self._mask) + trainable_pct = 100 * trainable_count / len(self._tokens) if self._tokens else 0 + print( + f"Trainable: {trainable_count}/{len(self._tokens)} ({trainable_pct:.1f}%)" + ) + print("=" * 80) + + if not self._tokens: + print("(no tokens)") + print("=" * 80) + return + + # Show messages list + print("\nMessages:") + for i, msg in enumerate(self.messages): + role = msg["role"] + content = msg["content"] + preview = content[:100] + "..." if len(content) > 100 else content + print(f" [{i}] {role:10s} {preview!r}") + + # Show colorized token stream + print("\nToken stream:") + self._show_colorized_token_stream(max_chars) + + print("=" * 80) + + def _show_colorized_token_stream(self, max_chars: int) -> None: + """ + Show full token stream with color coding by trainability. + + Groups consecutive tokens with same trainability into "runs" and decodes + them together. This handles multi-byte characters correctly. + """ + chunks = [] + current_ids = [] + current_trainable = None + total_chars = 0 + + def flush_run(): + nonlocal total_chars + if not current_ids: + return + + # Decode entire run at once + decoded = self.tokenizer.decode(current_ids) + + # Check if we've exceeded max_chars + if total_chars >= max_chars: + return + + # Truncate if needed + if total_chars + len(decoded) > max_chars: + remaining = max_chars - total_chars + decoded = decoded[:remaining] + "..." 
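The run-grouping used by the colorized display (decode each contiguous same-trainability span in one call so multi-byte characters are never split across tokens) can be expressed compactly with `itertools.groupby`. A sketch with toy ids; in the accumulator each run would be handed to `tokenizer.decode`:

```python
from itertools import groupby


def trainability_runs(tokens: list[int], mask: list[bool]):
    """Yield (trainable, run_of_token_ids) for each contiguous run of the mask."""
    idx = 0
    for trainable, group in groupby(mask):
        run_len = sum(1 for _ in group)
        yield trainable, tokens[idx : idx + run_len]
        idx += run_len


# Toy data: only the middle span is trainable.
tokens = [150, 123, 791, 19, 374, 2, 151]
mask = [False, False, True, True, True, True, False]

for trainable, run_ids in trainability_runs(tokens, mask):
    symbol = "✓" if trainable else "·"
    print(symbol, run_ids)  # real code would print tokenizer.decode(run_ids)
```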
+ + total_chars += len(decoded) + + # Color based on trainability + if current_trainable: + color_code = "\033[92m" # Green for trainable + symbol = "✓" + else: + color_code = "\033[90m" # Gray for not trainable + symbol = "·" + + # Escape special characters for display + decoded_repr = repr(decoded)[1:-1] # Remove outer quotes + chunks.append(f"{color_code}{symbol} {decoded_repr}\033[0m") + + # Group tokens into runs + for i in range(len(self._tokens)): + trainable = self._mask[i] + + # Flush when trainability changes + if trainable != current_trainable and current_ids: + flush_run() + current_ids = [] + + current_ids.append(self._tokens[i]) + current_trainable = trainable + + # Flush final run + flush_run() + + # Print runs + if chunks: + print(" " + " ".join(chunks)) + + if total_chars >= max_chars: + print(f"\n (output truncated at {max_chars} chars)") + + def _show_colorized_tokens(self, start_idx: int, end_idx: int) -> None: + """ + DEPRECATED: Old method, kept for compatibility. + Use _show_colorized_token_stream instead. + """ + pass + + # Internal helpers + def _validate_init( + self, tokenizer, messages: list[dict], max_len: int, eos_id: int + ) -> None: + """Validate initialization parameters.""" + if not hasattr(tokenizer, "apply_chat_template"): + raise ValueError("Tokenizer must have apply_chat_template method") + if not messages: + raise ValueError("Must provide at least a system message") + if not isinstance(messages, list): + raise TypeError(f"messages must be list, got {type(messages)}") + for i, msg in enumerate(messages): + if not isinstance(msg, dict): + raise TypeError(f"Message {i} must be dict") + if "role" not in msg or "content" not in msg: + raise ValueError(f"Message {i} missing 'role' or 'content'") + if not isinstance(max_len, int) or max_len <= 0: + raise ValueError(f"max_len must be positive int, got {max_len}") + if not isinstance(eos_id, int): + raise TypeError(f"eos_id must be int, got {type(eos_id)}") + + def _setup_anchor(self, msgs: list[dict]) -> None: + """ + Setup anchor for delta tokenization and compute suffix. + + The suffix is anything after EOS in the chat template. We create a test + conversation with EOS and extract any tokens that follow it. 
+ """ + sys = ( + msgs[0] + if msgs[0]["role"] == "system" + else {"role": "system", "content": ""} + ) + self.anchor = [sys, {"role": "user", "content": ""}] + + # Compute generation prompt + without = self.tokenizer.apply_chat_template( + self.anchor, + add_generation_prompt=False, + tokenize=True, + enable_thinking=self.thinking, + ) + with_gen = self.tokenizer.apply_chat_template( + self.anchor, + add_generation_prompt=True, + tokenize=True, + enable_thinking=self.thinking, + ) + self.gen_prompt_tokens = with_gen[len(without) :] + self.gen_prompt_len = len(self.gen_prompt_tokens) + + # Compute system length + sys_tokens = self.tokenizer.apply_chat_template( + [sys], + add_generation_prompt=False, + tokenize=True, + enable_thinking=self.thinking, + ) + self.sys_len = len(sys_tokens) + + # Compute suffix by tokenizing a test conversation + test_conv = [ + sys, + {"role": "user", "content": "test"}, + {"role": "assistant", "content": "response"}, + ] + test_tokens = self.tokenizer.apply_chat_template( + test_conv, + add_generation_prompt=False, + tokenize=True, + enable_thinking=self.thinking, + ) + + # Find last EOS + eos_idx = -1 + for i in range(len(test_tokens) - 1, -1, -1): + if test_tokens[i] == self.eos_id: + eos_idx = i + break + + # Extract suffix (everything after EOS, or empty if nothing) + if eos_idx >= 0 and eos_idx < len(test_tokens) - 1: + self.suffix = test_tokens[eos_idx + 1 :] + else: + self.suffix = [] + + def _init_messages(self, msgs: list[dict]) -> None: + """Initialize with starting messages.""" + if not msgs: + return + + tokens = self.tokenizer.apply_chat_template( + msgs, + add_generation_prompt=False, + tokenize=True, + enable_thinking=self.thinking, + ) + + if len(tokens) > self.max_len: + self._mark_truncated(TruncationReason.USER_TOO_LONG) + tokens = tokens[: self.max_len] + + self.messages = msgs.copy() + self._add_tokens(tokens, trainable=False, role="initial", ends_with_eos=False) + + def _add_tokens( + self, + tokens: list[int], + trainable: bool, + logprobs: Optional[list[float]] = None, + role: str = "", + ends_with_eos: bool = False, + ) -> None: + """Add tokens to parallel arrays and track message boundary.""" + if not tokens: + return + + self._tokens.extend(tokens) + self._mask.extend([trainable] * len(tokens)) + self._logprobs.extend(logprobs if logprobs else [0.0] * len(tokens)) + + # Track message end for validation + end_idx = len(self._tokens) - 1 + self._message_ends.append((end_idx, role, ends_with_eos)) + + def _mark_truncated(self, reason: TruncationReason) -> bool: + """Mark as truncated.""" + self.truncated = True + self.truncation_reason = reason + return False + + def _validate( + self, + token_ids: torch.Tensor, + response_mask: torch.Tensor, + logprobs: torch.Tensor, + ) -> None: + """ + Run validation checks on tensors. 
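The `_setup_anchor` probe can also be run by hand to inspect what a given chat template contributes around an assistant turn. A rough sketch, assuming a Hugging Face tokenizer; the checkpoint is illustrative, and `tok.eos_token_id` may differ from the `eos_id` the accumulator is constructed with, in which case the probe should use that id instead:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative
anchor = [{"role": "system", "content": ""}, {"role": "user", "content": ""}]

without = tok.apply_chat_template(anchor, add_generation_prompt=False, tokenize=True)
with_gen = tok.apply_chat_template(anchor, add_generation_prompt=True, tokenize=True)
gen_prompt = with_gen[len(without):]  # tokens vLLM is prompted with but never returns

probe = [anchor[0], {"role": "user", "content": "test"},
         {"role": "assistant", "content": "response"}]
probe_ids = tok.apply_chat_template(probe, add_generation_prompt=False, tokenize=True)

eos_positions = [i for i, t in enumerate(probe_ids) if t == tok.eos_token_id]
suffix = probe_ids[eos_positions[-1] + 1:] if eos_positions else []

print("generation prompt:", tok.decode(gen_prompt))
print("suffix after EOS :", tok.decode(suffix))
```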
+ + Args: + token_ids: Token IDs tensor (shape: T) + response_mask: Response mask tensor (shape: T) + logprobs: Log probabilities tensor (shape: T) + """ + # Check 1: Shapes match + if not (token_ids.shape == response_mask.shape == logprobs.shape): + raise AssertionError( + f"Shape mismatch: token_ids={token_ids.shape}, " + f"mask={response_mask.shape}, logprobs={logprobs.shape}" + ) + + # Check 2: Budget not exceeded + if len(token_ids) > self.max_len: + raise ValueError(f"Budget overflow: {len(token_ids)} > {self.max_len}") + + # Check 3: Message boundaries are correct + for end_idx, role, should_end_with_eos in self._message_ends: + if should_end_with_eos: + # Token at end_idx should be eos_id + if token_ids[end_idx].item() != self.eos_id: + msg = f"{role} at {end_idx} has token {token_ids[end_idx].item()}, expected EOS {self.eos_id}" + if self.validation == ValidationMode.STRICT: + raise ValueError(msg) + print(f"WARNING: {msg}") + + # For assistant: end_idx should be trainable + if role == "assistant" and not response_mask[end_idx].item(): + msg = f"Assistant EOS at {end_idx} is not trainable" + if self.validation == ValidationMode.STRICT: + raise ValueError(msg) + print(f"WARNING: {msg}") + + # Token after EOS should not be trainable + if end_idx + 1 < len(token_ids) and response_mask[end_idx + 1].item(): + msg = ( + f"Token after EOS at {end_idx+1} is trainable (should be False)" + ) + if self.validation == ValidationMode.STRICT: + raise ValueError(msg) + print(f"WARNING: {msg}") + + # Check 4: Prefix consistency (incremental == full tokenization) + # DISABLED: Qwen always adds think tags to LAST assistant message only, + # but in incremental accumulation every assistant response IS the last one + # at the time we add it. This causes mismatches: + # - thinking=True: missing 4 tokens (last gets think tags in full tokenization) + # - thinking=False: extra 4 tokens (first doesn't get think tags in full tokenization) + # This is expected behavior for Qwen and not a bug. 
+ # + # with self._lock: + # full_tokens = self.tokenizer.apply_chat_template( + # self.messages, add_generation_prompt=False, tokenize=True, enable_thinking=self.thinking + # ) + # + # accumulated_len = len(token_ids) + # expected_len = len(full_tokens) + # + # if accumulated_len != expected_len: + # msg = ( + # f"Prefix consistency failed: " + # f"accumulated={accumulated_len} tokens, " + # f"expected={expected_len}" + # ) + # if self.validation == ValidationMode.STRICT: + # raise AssertionError(msg) + # print(f"WARNING: {msg}") diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index 2f9983b56..306a50f26 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -15,9 +15,10 @@ import torch from forge.controller import ForgeActor +from forge.data.common import CROSS_ENTROPY_IGNORE_IDX from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer -from forge.util.ops import compute_logprobs +from forge.util.ops import compute_logprobs, create_shifted_targets from monarch.actor import current_rank, current_size, endpoint from torch.distributed.tensor import DTensor @@ -126,21 +127,23 @@ async def setup(self): @endpoint async def forward( - self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool + self, + input_ids: torch.Tensor, + return_logprobs: bool, + loss_mask: torch.Tensor = None, ) -> torch.Tensor: """ Args: - input_ids (torch.Tensor): input token ids with shape [group_size, req + res length]. - max_req_tokens (int): maximum request length. - return_logprobs (bool): whether to return log probabilities instead of raw logits. - - return_logprobs flag significantly impacts the amount of data transferred to the caller: - - When False: Returns logits with shape [group_size, req + res_length, vocab_size]. - This includes the full vocabulary distribution for each token position. - - - When True: Returns log probabilities with shape [group_size, req_length]. - This only includes probabilities for the request tokens, significantly reducing memory - usage and transfer overhead. + input_ids: Input token ids [batch, seq_len] + return_logprobs: Whether to return logprobs + return_logprobs flag significantly impacts the amount of data transferred to the caller: + - When False: Returns logits with shape [group_size, req + res_length, vocab_size]. + This includes the full vocabulary distribution for each token position. + + - When True: Returns log probabilities with shape [group_size, req_length]. + This only includes probabilities for the request tokens, significantly reducing memory + usage and transfer overhead. 
+ loss_mask: Optional mask for which positions to compute logprobs [batch, seq_len] """ # Record reference model metrics record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM) @@ -188,7 +191,14 @@ async def forward( t.stop() return logits else: - logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:]) + # Create targets using utility function (loss_mask=None means all trainable) + targets = create_shifted_targets(input_ids, loss_mask) + + # Compute logprobs using updated compute_logprobs + logprobs = compute_logprobs( + logits, targets, ignore_index=CROSS_ENTROPY_IGNORE_IDX + ) + t.step("compute_logprobs") t.stop() return logprobs diff --git a/src/forge/data/common.py b/src/forge/data/common.py new file mode 100644 index 000000000..472faf34c --- /dev/null +++ b/src/forge/data/common.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# PyTorch cross_entropy default ignore index for masking positions +# Positions with this value in targets will be ignored during loss computation +CROSS_ENTROPY_IGNORE_IDX = -100 diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index f7152f065..4720f9c5b 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -7,91 +7,73 @@ import torch import torch.nn.functional as F +from forge.data.common import CROSS_ENTROPY_IGNORE_IDX + def compute_logprobs( logits: torch.Tensor, - input_ids: torch.Tensor, + targets: torch.Tensor, temperature: float = 1.0, - align: bool = True, + ignore_index: int = CROSS_ENTROPY_IGNORE_IDX, ) -> torch.Tensor: """ - Computes the log probabilities of the input tokens given the model logits and temperature. - Always converts inputs to fp32 for numerical stability. - - This function handles two common usage patterns: - - **Pattern 1: Pre-aligned logits (align=False)** - Use when logits are already aligned with input_ids, typically when you: - - Pass input_ids to the model: model(input_ids) -> logits - - The model outputs logits[i] that predict target_ids[i] - - logits.shape[1] == input_ids.shape[1] - - Example: - >>> input_ids = torch.tensor([[1, 2, 3, 4]]) # Model input - >>> target_ids = torch.tensor([[2, 3, 4, 5]]) # Shifted by 1 (next-token prediction) - >>> logits = model(input_ids) # Shape: [1, 4, vocab_size] - >>> # logits already aligned: logits[:, i] predicts target_ids[:, i] - >>> logprobs = compute_logprobs(logits, target_ids, align=False) - - **Pattern 2: Full-sequence logits needing alignment (align=True, default)** - Use when you have logits for the full sequence but only want log probs for a subset - (e.g., just the response tokens, not the prompt). 
The function will: - - Slice logits to match the length of input_ids - - Take logits[:, -len(input_ids)-1:-1] to get positions that predict input_ids - - Example: - >>> # Full sequence passed to model: [prompt + response] - >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]]) # Prompt + response - >>> logits = model(full_input_ids) # Shape: [1, 6, vocab_size] - >>> # Only want log probs for response tokens - >>> response_tokens = torch.tensor([[4, 5, 6]]) # Just the response - >>> logprobs = compute_logprobs(logits, response_tokens, align=True) - >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6] - - The alignment logic ensures that when you have a full sequence but only want log - probabilities for the response portion, you don't need to re-run the model. This - is a key optimization in RL training where the prompt remains constant. + Computes the log probabilities of target tokens given the model logits. Args: - logits (`torch.Tensor`): - The model output logits of shape `(batch_size, sequence_length, vocab_size)`. - input_ids (`torch.Tensor`): - The target token ids of shape `(batch_size, target_sequence_length)`. - These are the tokens for which you want to compute log probabilities. - temperature (`float`, *optional*, defaults to 1.0): - The temperature value for scaling logits before computing log probabilities. - Higher values make the distribution more uniform, lower values more peaked. - align (`bool`, *optional*, defaults to True): - If True (default), align logits with input_ids by slicing to extract the - relevant positions from a longer sequence (Pattern 2). - If False, assume logits are already aligned with input_ids (Pattern 1). + logits: Model logits [batch, seq_len, vocab] + targets: Target token IDs [batch, seq_len] + temperature: Temperature for scaling + ignore_index: Positions with this value in targets are masked (get 0.0 logprob) Returns: - torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`. - Each element [b, i] is the log probability of input_ids[b, i] given the - corresponding logits. - - Note: - This function uses cross_entropy instead of log_softmax + gather for better - numerical stability, especially important for fp16/bf16 training. + logprobs: [batch, seq_len] - Positions with ignore_index automatically get 0.0 """ - # Align logits with input_ids if requested - if align: - # Ignore the last token from logits because it predicts the next token (-1) - # And align logits with the input tokens length. - logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) - scaled_logits = logits / temperature - - # Cast up to fp32 for numerical stability scaled_logits_fp32 = scaled_logits.float() - # get per-token log probs batch_size, seq_len, vocab_size = scaled_logits_fp32.shape logprobs = -F.cross_entropy( scaled_logits_fp32.reshape(-1, vocab_size), - input_ids.reshape(-1).long(), + targets.reshape(-1).long(), reduction="none", + ignore_index=ignore_index, ) return logprobs.reshape(batch_size, seq_len) + + +def create_shifted_targets( + input_ids: torch.Tensor, + loss_mask: torch.Tensor | None = None, + ignore_index: int = CROSS_ENTROPY_IGNORE_IDX, +) -> torch.Tensor: + """ + Create next-token prediction targets using torch.roll. + Maintains same shape as input_ids. 
+ + Args: + input_ids: [batch, seq_len] or [seq_len] - Input token IDs + loss_mask: [batch, seq_len] or [seq_len] - Trainable positions (bool or float) + If None, all positions are trainable + ignore_index: Value for masked positions (default: -100) + + Returns: + targets: Same shape as input_ids + targets[i] = input_ids[i+1] where trainable, else ignore_index + """ + if input_ids.dim() == 1: + # 1D case + targets = torch.roll(input_ids, shifts=-1, dims=0) + targets[-1] = ignore_index # Last position wraps, mask it + else: + # 2D case (batched) + targets = torch.roll(input_ids, shifts=-1, dims=-1) + targets[:, -1] = ignore_index # Last position wraps, mask it + + if loss_mask is not None: + loss_mask = loss_mask.to(input_ids.device) + targets = torch.where( + loss_mask.bool(), targets, torch.full_like(targets, ignore_index) + ) + + return targets
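Putting the two new utilities together, here is a minimal end-to-end sketch of how a caller goes from `input_ids` plus a `loss_mask` to per-token logprobs; the tiny vocabulary and random logits are stand-ins for real model output:

```python
import torch

from forge.data.common import CROSS_ENTROPY_IGNORE_IDX
from forge.util.ops import compute_logprobs, create_shifted_targets

batch, seq_len, vocab = 1, 6, 32
input_ids = torch.randint(0, vocab, (batch, seq_len))
# Only the last three positions are "response" tokens we train on.
loss_mask = torch.tensor([[False, False, False, True, True, True]])

targets = create_shifted_targets(input_ids, loss_mask)
# targets[b, i] == input_ids[b, i + 1] where loss_mask is True,
# CROSS_ENTROPY_IGNORE_IDX everywhere else (including the final position).
assert targets[0, -1].item() == CROSS_ENTROPY_IGNORE_IDX

logits = torch.randn(batch, seq_len, vocab)  # stand-in for model output
logprobs = compute_logprobs(logits, targets, ignore_index=CROSS_ENTROPY_IGNORE_IDX)

# Masked positions come back as exactly 0.0, trainable ones as real log-probs.
assert torch.all(logprobs[~loss_mask] == 0.0)
print(logprobs)
```

Because masked positions are resolved inside `compute_logprobs` via `ignore_index`, callers such as `ReferenceModel.forward` no longer need to slice request tokens off the sequence themselves.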