Skip to content

Commit fed9fd7

Browse files
authored
feat: add FFT trainer worker class (#113)
* refactor: move trainer and extract losses
* refactor: split base and lora trainer workers
* feat: add fft trainer worker
1 parent d581e52 commit fed9fd7

7 files changed

Lines changed: 669 additions & 320 deletions

File tree

src/server/clock_cycle.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
from opentelemetry import context as otel_context
1111
from opentelemetry import propagate, trace
1212
from store import get_store
13-
from trainer import Datum, LoraConfig, TrainerEngine
13+
from training.lora_trainer_worker import LoraConfig, LoraTrainingWorker
14+
from training.trainer_worker import Datum
1415

1516
tracer = trace.get_tracer(__name__)
1617

17-
engine = TrainerEngine()
18+
engine = LoraTrainingWorker()
1819

1920

20-
def _parse_datum(raw: dict) -> Datum:
21+
def parse_datum(raw: dict) -> Datum:
2122
"""Convert wire-format datum (with chunks) to our flat Datum type."""
2223
chunks = raw.get("model_input", {}).get("chunks", [])
2324
tokens: list[int] = []
@@ -98,7 +99,7 @@ async def clock_cycle_loop() -> None:
9899
loss_fn = r["loss_fn"]
99100
loss_config = r.get("loss_config")
100101

101-
typed_data = [_parse_datum(item) for item in raw_data]
102+
typed_data = [parse_datum(item) for item in raw_data]
102103

103104
result = await asyncio.to_thread(engine.forward_backward, typed_data, loss_fn, loss_config, m_id)
104105
result["type"] = "forward_backward"

src/server/training/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Training engine package."""
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Full fine-tuning trainer worker lifecycle.
2+
3+
import json
4+
import math
5+
import os
6+
import time
7+
from datetime import datetime
8+
from typing import Any
9+
10+
import torch
11+
from pydantic import BaseModel
12+
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
13+
14+
from training.trainer_worker import BaseTrainerWorker, Datum
15+
16+
# Gradient checkpointing defaults to on; it is disabled by setting the
# ENABLE_GRADIENT_CHECKPOINTING environment variable to anything other than "1".
ENABLE_GRADIENT_CHECKPOINTING = os.environ.get("ENABLE_GRADIENT_CHECKPOINTING", "1") == "1"
17+
18+
19+
class FFTConfig(BaseModel):
    """Options for preparing a full fine-tuning run."""

    # Optional RNG seed; FFTTrainingWorker.create_model passes it to
    # torch.manual_seed when it is not None.
    seed: int | None = None
21+
22+
23+
def trainable_model_parameters(model: PreTrainedModel) -> list[torch.nn.Parameter]:
    """Collect the parameters of ``model`` whose ``requires_grad`` flag is set.

    Args:
        model: Model whose ``parameters()`` are inspected.

    Returns:
        The gradient-requiring parameters, in ``model.parameters()`` order.

    Raises:
        ValueError: If none of the model's parameters require gradients.
    """
    trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
    if trainable:
        return trainable
    raise ValueError("No trainable parameters found for full fine-tuning model")
28+
29+
30+
def _fft_load_dtype() -> torch.dtype:
    """Return bfloat16 when a CUDA device supports it, otherwise float32."""
    # Checking is_available() first avoids touching the CUDA runtime on hosts
    # without a usable CUDA device.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32


def _fft_write_metadata(directory: str, metadata: dict[str, Any]) -> None:
    """Write ``metadata`` as ``metadata.json`` inside ``directory``."""
    with open(os.path.join(directory, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f)


class FFTTrainingWorker(BaseTrainerWorker):
    """Trainer worker that fine-tunes all parameters of one causal LM.

    The worker holds a single model, its trainable parameter list and a lazily
    created AdamW optimizer. ``self.device``, ``self.tokenizer`` and
    ``self.sanitize_float`` are not defined in this class; they are presumably
    provided by ``BaseTrainerWorker`` (TODO confirm).
    """

    def __init__(self):
        super().__init__()
        self.model: PreTrainedModel | None = None
        self.base_model_name: str | None = None
        self.trainable_params: list[torch.nn.Parameter] = []
        self.optimizer: torch.optim.Optimizer | None = None

    def load_base_model(self, base_model_name: str) -> None:
        """Load one full model for one fine-tuning job process.

        Does nothing if a model with the same name is already loaded. Worker
        state is only updated after both tokenizer and model loaded, so a
        failed load leaves the previous model (and its name) untouched.

        Args:
            base_model_name: Identifier passed to ``from_pretrained``.
        """
        if self.model is not None and self.base_model_name == base_model_name:
            print(f"Full fine-tuning model {base_model_name} already loaded.")
            return

        print(f"Loading full fine-tuning model {base_model_name} to {self.device}...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        model = AutoModelForCausalLM.from_pretrained(base_model_name, dtype=_fft_load_dtype(), device_map=self.device)

        self.tokenizer = tokenizer
        self.model = model
        self.base_model_name = base_model_name
        # Any existing optimizer is bound to the previous model's parameters;
        # drop it so optim_step builds a fresh one for the new model.
        self.optimizer = None
        self.prepare_model_for_training()
        print("Successfully loaded full fine-tuning model.")

    def create_model(self, model_id: str | None = None, config: FFTConfig | None = None) -> None:
        """Prepare the loaded model for full fine-tuning.

        Args:
            model_id: Accepted but not used by this method.
            config: Optional settings; a non-None ``seed`` is applied via
                ``torch.manual_seed``.
        """
        if config is not None and config.seed is not None:
            torch.manual_seed(config.seed)
        self.prepare_model_for_training()

    def prepare_model_for_training(self) -> None:
        """Mark every parameter trainable and put the model in train mode.

        Gradient checkpointing is enabled when ``ENABLE_GRADIENT_CHECKPOINTING``
        is true; failure to enable it is reported and otherwise ignored.

        Raises:
            AssertionError: If no model has been loaded.
        """
        assert self.model is not None, "Model is not loaded. Call load_base_model first."

        for param in self.model.parameters():
            param.requires_grad_(True)
        self.trainable_params = trainable_model_parameters(self.model)

        if ENABLE_GRADIENT_CHECKPOINTING:
            # Best effort: a model that cannot checkpoint still trains.
            try:
                self.model.gradient_checkpointing_enable()
                self.model.enable_input_require_grads()
                print("Gradient checkpointing and input require grads enabled on full fine-tuning model.")
            except Exception as e:
                print(f"Failed to enable gradient checkpointing: {e}")

        self.model.train()

    def save_model(self, alias: str | None = None) -> dict[str, Any]:
        """Save model weights, tokenizer and metadata for later use.

        Args:
            alias: Absolute path to save to, or a name placed under
                ``$OPEN_RL_TMP_DIR/fft`` (default ``/tmp/open-rl/fft``).
                Defaults to ``"fft-model"`` when empty or None.

        Returns:
            ``{"path": <directory the model was saved to>}``.

        Raises:
            AssertionError: If no model has been loaded.
        """
        assert self.model is not None, "Model must be loaded first."

        tmp_dir = os.getenv("OPEN_RL_TMP_DIR", "/tmp/open-rl")
        name = alias or "fft-model"
        save_path = name if os.path.isabs(name) else os.path.join(tmp_dir, "fft", name)
        os.makedirs(save_path, exist_ok=True)

        self.model.save_pretrained(save_path)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(save_path)

        _fft_write_metadata(
            save_path,
            {
                "base_model": self.base_model_name,
                "created_at": datetime.now().isoformat(),
                "kind": "weights",
                "model_id": alias,
                "timestamp": time.time(),
            },
        )

        print(f"Saved full fine-tuning model to {save_path}")
        return {"path": save_path}

    def save_state(self, model_id: str, state_path: str, include_optimizer: bool = False, kind: str = "state") -> dict[str, Any]:
        """Save a resumable training state to ``state_path``.

        Args:
            model_id: Identifier recorded in the metadata file.
            state_path: Directory to write to (created if missing).
            include_optimizer: Also save the optimizer state as
                ``optimizer.pt`` when an optimizer exists.
            kind: Value recorded under ``"kind"`` in the metadata.

        Returns:
            ``{"path": state_path}``.

        Raises:
            AssertionError: If no model has been loaded.
        """
        assert self.model is not None, "Model must be loaded first."

        os.makedirs(state_path, exist_ok=True)
        self.model.save_pretrained(state_path)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(state_path)

        has_optimizer = include_optimizer and self.optimizer is not None
        if has_optimizer:
            torch.save(self.optimizer.state_dict(), os.path.join(state_path, "optimizer.pt"))

        _fft_write_metadata(
            state_path,
            {
                "base_model": self.base_model_name,
                "created_at": datetime.now().isoformat(),
                "kind": kind,
                "has_optimizer": has_optimizer,
                "model_id": model_id,
                "timestamp": time.time(),
            },
        )

        print(f"Saved full fine-tuning state to {state_path}")
        return {"path": state_path}

    def load_from_state(self, model_id: str, state_path: str, restore_optimizer: bool = False) -> dict[str, Any]:
        """Replace the current model with one saved by ``save_state``.

        Any optimizer held before the call is discarded, because it refers to
        the parameters of the model being replaced. When ``restore_optimizer``
        is set and the state contains one, a new AdamW optimizer is built over
        the loaded parameters and its saved state is restored.

        Args:
            model_id: Identifier echoed back in the result.
            state_path: Directory containing ``metadata.json`` and the model.
            restore_optimizer: Whether to restore ``optimizer.pt`` if present.

        Returns:
            ``{"model_id": ..., "is_lora": False, "base_model": ...}``.

        Raises:
            FileNotFoundError: If ``metadata.json`` is missing.
            ValueError: If the metadata has no ``base_model`` entry.
        """
        metadata_path = os.path.join(state_path, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"No metadata.json found at {state_path}")

        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)

        base_model = metadata.get("base_model")
        if not base_model:
            raise ValueError(f"metadata.json at {state_path} missing base_model")

        # Load into locals first so a failed load leaves the worker unchanged.
        tokenizer = AutoTokenizer.from_pretrained(state_path)
        model = AutoModelForCausalLM.from_pretrained(state_path, dtype=_fft_load_dtype(), device_map=self.device)

        self.tokenizer = tokenizer
        self.model = model
        self.base_model_name = base_model
        # The previous optimizer (if any) tracks the old model's parameters.
        self.optimizer = None
        self.prepare_model_for_training()

        if restore_optimizer and metadata.get("has_optimizer"):
            optimizer_path = os.path.join(state_path, "optimizer.pt")
            if os.path.exists(optimizer_path):
                # The lr here is a placeholder; load_state_dict overwrites the
                # param-group hyperparameters with the saved ones.
                optimizer = torch.optim.AdamW(self.trainable_params, lr=1e-4)
                # NOTE(review): torch.load unpickles the file. Only load state
                # from trusted locations, or consider weights_only=True.
                optimizer.load_state_dict(torch.load(optimizer_path, map_location=self.device))
                self.optimizer = optimizer
                print(f"Restored optimizer state from {optimizer_path}")

        print(f"Loaded full fine-tuning state from {state_path}")
        return {"model_id": model_id, "is_lora": False, "base_model": base_model}

    def forward_backward(self, data: list[Datum], loss_fn: str, loss_config: dict | None = None, model_id: str | None = None) -> dict[str, Any]:
        """Run the base class forward/backward pass on the loaded model.

        ``model_id`` is accepted but not forwarded to the base implementation.

        Raises:
            AssertionError: If no model has been loaded.
        """
        assert self.model is not None, "Model must be loaded first."
        return super().forward_backward(self.model, data, loss_fn, loss_config)

    def optim_step(self, adam_params: dict[str, Any], model_id: str | None = None) -> dict[str, Any]:
        """Clip gradients, apply one AdamW step and clear the gradients.

        The optimizer is created on first use from ``adam_params``
        (``learning_rate``, ``beta1``, ``beta2``, ``eps``, ``weight_decay``).
        On later calls only ``learning_rate`` is re-applied.

        Args:
            adam_params: Optimizer settings; ``grad_clip_norm`` (when positive)
                is the maximum gradient norm, otherwise no clipping is applied.
            model_id: Accepted but not used by this method.

        Returns:
            ``{"metrics": {"grad_norm:mean": <total gradient norm>}}``.

        Raises:
            AssertionError: If no model has been loaded.
        """
        assert self.model is not None, "Model must be loaded first."
        if not self.trainable_params:
            self.trainable_params = trainable_model_parameters(self.model)

        if self.optimizer is None:
            lr = adam_params.get("learning_rate", 1e-4)
            beta1 = adam_params.get("beta1", 0.9)
            beta2 = adam_params.get("beta2", 0.95)
            eps = adam_params.get("eps", 1e-12)
            weight_decay = adam_params.get("weight_decay", 0.0)

            print(f"Initializing AdamW optimizer for full fine-tuning model with lr={lr}")
            self.optimizer = torch.optim.AdamW(
                self.trainable_params,
                lr=lr,
                betas=(beta1, beta2),
                eps=eps,
                weight_decay=weight_decay,
            )

        # Let the caller change the learning rate on every step.
        learning_rate = adam_params.get("learning_rate")
        if learning_rate is not None:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = learning_rate

        # A missing, zero or negative clip norm means "do not clip"; clipping
        # with an infinite bound still returns the total norm for the metric.
        max_grad_norm = adam_params.get("grad_clip_norm") or math.inf
        if max_grad_norm <= 0.0:
            max_grad_norm = math.inf

        total_norm = torch.nn.utils.clip_grad_norm_(
            self.trainable_params,
            max_grad_norm,
        )

        self.optimizer.step()
        self.optimizer.zero_grad()

        return {
            "metrics": {
                "grad_norm:mean": self.sanitize_float(total_norm.item()),
            },
        }

    def generate(
        self,
        prompt_tokens: list[int],
        max_tokens: int,
        num_samples: int = 1,
        temperature: float = 0.0,
        model_id: str | None = None,
        include_prompt_logprobs: bool = False,
    ) -> dict[str, Any]:
        """Run the base class generation routine on the loaded model.

        ``model_id`` is accepted but not forwarded to the base implementation.

        Raises:
            AssertionError: If no model has been loaded.
        """
        # Same guard as forward_backward/optim_step, so a missing model fails
        # with a clear message instead of passing None to the base class.
        assert self.model is not None, "Model must be loaded first."
        return super().generate(self.model, prompt_tokens, max_tokens, num_samples, temperature, include_prompt_logprobs)

0 commit comments

Comments
 (0)