Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/server/clock_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from opentelemetry import context as otel_context
from opentelemetry import propagate, trace
from store import get_store
from trainer import Datum, LoraConfig, TrainerEngine
from training.lora_trainer_worker import LoraConfig, LoraTrainingWorker
from training.trainer_worker import Datum

tracer = trace.get_tracer(__name__)

engine = TrainerEngine()
engine = LoraTrainingWorker()


def _parse_datum(raw: dict) -> Datum:
def parse_datum(raw: dict) -> Datum:
"""Convert wire-format datum (with chunks) to our flat Datum type."""
chunks = raw.get("model_input", {}).get("chunks", [])
tokens: list[int] = []
Expand Down Expand Up @@ -98,7 +99,7 @@ async def clock_cycle_loop() -> None:
loss_fn = r["loss_fn"]
loss_config = r.get("loss_config")

typed_data = [_parse_datum(item) for item in raw_data]
typed_data = [parse_datum(item) for item in raw_data]

result = await asyncio.to_thread(engine.forward_backward, typed_data, loss_fn, loss_config, m_id)
result["type"] = "forward_backward"
Expand Down
1 change: 1 addition & 0 deletions src/server/training/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Training engine package."""
211 changes: 211 additions & 0 deletions src/server/training/fft_trainer_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Full fine-tuning trainer worker lifecycle.

import json
import math
import os
import time
from datetime import datetime
from typing import Any

import torch
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

from training.trainer_worker import BaseTrainerWorker, Datum

ENABLE_GRADIENT_CHECKPOINTING = os.getenv("ENABLE_GRADIENT_CHECKPOINTING", "1") == "1"


class FFTConfig(BaseModel):
seed: int | None = None


def trainable_model_parameters(model: PreTrainedModel) -> list[torch.nn.Parameter]:
params = [param for param in model.parameters() if param.requires_grad]
if not params:
raise ValueError("No trainable parameters found for full fine-tuning model")
return params


class FFTTrainingWorker(BaseTrainerWorker):
def __init__(self):
super().__init__()
self.model: PreTrainedModel | None = None
self.base_model_name: str | None = None
self.trainable_params: list[torch.nn.Parameter] = []
self.optimizer: torch.optim.Optimizer | None = None

def load_base_model(self, base_model_name: str) -> None:
"""Load one full model for one fine-tuning job process."""
if self.model is not None and self.base_model_name == base_model_name:
print(f"Full fine-tuning model {base_model_name} already loaded.")
return

print(f"Loading full fine-tuning model {base_model_name} to {self.device}...")
self.base_model_name = base_model_name
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

self.model = AutoModelForCausalLM.from_pretrained(base_model_name, dtype=dtype, device_map=self.device)
self.prepare_model_for_training()
print("Successfully loaded full fine-tuning model.")

def create_model(self, model_id: str | None = None, config: FFTConfig | None = None) -> None:
"""Prepare the loaded model for full fine-tuning."""
if config is not None and config.seed is not None:
torch.manual_seed(config.seed)
self.prepare_model_for_training()

def prepare_model_for_training(self) -> None:
assert self.model is not None, "Model is not loaded. Call load_base_model first."

for param in self.model.parameters():
param.requires_grad_(True)
self.trainable_params = trainable_model_parameters(self.model)

if ENABLE_GRADIENT_CHECKPOINTING:
try:
self.model.gradient_checkpointing_enable()
self.model.enable_input_require_grads()
print("Gradient checkpointing and input require grads enabled on full fine-tuning model.")
except Exception as e:
print(f"Failed to enable gradient checkpointing: {e}")

self.model.train()

def save_model(self, alias: str | None = None) -> dict[str, Any]:
assert self.model is not None, "Model must be loaded first."

tmp_dir = os.getenv("OPEN_RL_TMP_DIR", "/tmp/open-rl")
name = alias or "fft-model"
save_path = name if os.path.isabs(name) else os.path.join(tmp_dir, "fft", name)
os.makedirs(save_path, exist_ok=True)

self.model.save_pretrained(save_path)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_path)

metadata = {
"base_model": self.base_model_name,
"created_at": datetime.now().isoformat(),
"kind": "weights",
"model_id": alias,
"timestamp": time.time(),
}
with open(os.path.join(save_path, "metadata.json"), "w") as f:
json.dump(metadata, f)

print(f"Saved full fine-tuning model to {save_path}")
return {"path": save_path}

def save_state(self, model_id: str, state_path: str, include_optimizer: bool = False, kind: str = "state") -> dict[str, Any]:
assert self.model is not None, "Model must be loaded first."

os.makedirs(state_path, exist_ok=True)
self.model.save_pretrained(state_path)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(state_path)

if include_optimizer and self.optimizer is not None:
torch.save(self.optimizer.state_dict(), os.path.join(state_path, "optimizer.pt"))

metadata = {
"base_model": self.base_model_name,
"created_at": datetime.now().isoformat(),
"kind": kind,
"has_optimizer": include_optimizer and self.optimizer is not None,
"model_id": model_id,
"timestamp": time.time(),
}
with open(os.path.join(state_path, "metadata.json"), "w") as f:
json.dump(metadata, f)

print(f"Saved full fine-tuning state to {state_path}")
return {"path": state_path}

def load_from_state(self, model_id: str, state_path: str, restore_optimizer: bool = False) -> dict[str, Any]:
metadata_path = os.path.join(state_path, "metadata.json")
if not os.path.exists(metadata_path):
raise FileNotFoundError(f"No metadata.json found at {state_path}")

with open(metadata_path) as f:
metadata = json.load(f)

base_model = metadata.get("base_model")
if not base_model:
raise ValueError(f"metadata.json at {state_path} missing base_model")

self.base_model_name = base_model
self.tokenizer = AutoTokenizer.from_pretrained(state_path)
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
self.model = AutoModelForCausalLM.from_pretrained(state_path, dtype=dtype, device_map=self.device)
self.prepare_model_for_training()

if restore_optimizer and metadata.get("has_optimizer"):
optimizer_path = os.path.join(state_path, "optimizer.pt")
if os.path.exists(optimizer_path):
self.optimizer = torch.optim.AdamW(self.trainable_params, lr=1e-4)
self.optimizer.load_state_dict(torch.load(optimizer_path, map_location=self.device))
print(f"Restored optimizer state from {optimizer_path}")

print(f"Loaded full fine-tuning state from {state_path}")
return {"model_id": model_id, "is_lora": False, "base_model": base_model}

def forward_backward(self, data: list[Datum], loss_fn: str, loss_config: dict | None = None, model_id: str | None = None) -> dict[str, Any]:
assert self.model is not None, "Model must be loaded first."
return super().forward_backward(self.model, data, loss_fn, loss_config)

def optim_step(self, adam_params: dict[str, Any], model_id: str | None = None) -> dict[str, Any]:
assert self.model is not None, "Model must be loaded first."
if not self.trainable_params:
self.trainable_params = trainable_model_parameters(self.model)

if self.optimizer is None:
lr = adam_params.get("learning_rate", 1e-4)
beta1 = adam_params.get("beta1", 0.9)
beta2 = adam_params.get("beta2", 0.95)
eps = adam_params.get("eps", 1e-12)
weight_decay = adam_params.get("weight_decay", 0.0)

print(f"Initializing AdamW optimizer for full fine-tuning model with lr={lr}")
self.optimizer = torch.optim.AdamW(
self.trainable_params,
lr=lr,
betas=(beta1, beta2),
eps=eps,
weight_decay=weight_decay,
)

learning_rate = adam_params.get("learning_rate")
if learning_rate is not None:
for param_group in self.optimizer.param_groups:
param_group["lr"] = learning_rate

max_grad_norm = adam_params.get("grad_clip_norm") or math.inf
if max_grad_norm <= 0.0:
max_grad_norm = math.inf

total_norm = torch.nn.utils.clip_grad_norm_(
self.trainable_params,
max_grad_norm,
)

self.optimizer.step()
self.optimizer.zero_grad()

return {
"metrics": {
"grad_norm:mean": self.sanitize_float(total_norm.item()),
},
}

def generate(
self,
prompt_tokens: list[int],
max_tokens: int,
num_samples: int = 1,
temperature: float = 0.0,
model_id: str | None = None,
include_prompt_logprobs: bool = False,
) -> dict[str, Any]:
return super().generate(self.model, prompt_tokens, max_tokens, num_samples, temperature, include_prompt_logprobs)
Loading
Loading