forked from thinking-machines-lab/tinker-cookbook
# TINKER DOCUMENTATION
This file contains the core Tinker documentation (index, quickstart, and losses).
## File: index.mdx
# Tinker: a training API for researchers and developers
Tinker lets you focus on what matters in LLM fine-tuning – your data and algorithms – while we handle the heavy lifting of distributed training.
You write a simple loop that runs on your CPU-only machine, including the data or environment and the loss function. We figure out how to make the training work on a bunch of GPUs, doing the exact computation you specified, efficiently. To change the model you're working with, you only need to change a single string in your code.
Tinker gives you full control over the training loop and all the algorithmic details. It's not a magic black box that makes fine-tuning "easy". It's a clean abstraction that shields you from the complexity of distributed training while preserving your control.
Here's how the division of responsibilities works in practice:
| **You focus on** | **You write** | **We handle** |
|---|---|---|
| 📊 **Datasets and RL environments**<br />Your custom training data | 💻 **Simple Python script**<br />Runs on your CPU | ⚡ **Efficient distributed training of large models**<br />Llama, Qwen, and more |
| 🎯 **Training logic**<br />Your loss functions, training loop, and evals | 🔧 **API calls**<br />`forward_backward()`<br />`optim_step()`<br />`sample()` | 🛡️ **Reliability**<br />Hardware failures handled transparently |
## Features
What the Tinker service currently supports:
- Tinker lets you fine-tune open-weight models like the Qwen and Llama series, including large mixture-of-experts models like Qwen3-235B-A22B.
- Tinker implements low-rank adaptation (LoRA) fine-tuning, not full fine-tuning. However, we believe that LoRA gives the same performance as full fine-tuning for many important use cases, especially in RL (see [LoRA Without Regret](https://thinkingmachines.ai/blog/lora/)).
- You can download the weights of your trained model to use outside of Tinker, for example with your inference provider of choice.
## A quick look at functionality
Tinker's main functionality is contained in a few key functions:
- `forward_backward`: feed in your data and loss function, and we'll compute and accumulate the gradients for you.
- `optim_step`: update your model using the accumulated gradients.
- `sample`: generate outputs from your trained model.
- other functions for saving and loading weights and optimizer state
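The accumulate-then-update split between `forward_backward` and `optim_step` can be illustrated with a toy, purely local stand-in. This is a minimal sketch, not the Tinker client: the `ToyTrainer` class, its one-parameter linear model, the squared-error loss, and the plain-SGD update are all assumptions made for illustration.

```python
class ToyTrainer:
    """Hypothetical local stand-in for illustration; NOT the Tinker client."""

    def __init__(self, weight: float = 0.0) -> None:
        self.weight = weight
        self.grad = 0.0  # accumulated gradient

    def forward_backward(self, batch: list[tuple[float, float]]) -> float:
        """Add the gradient of the summed squared error to the accumulator."""
        loss = 0.0
        for x, y in batch:
            err = self.weight * x - y
            loss += err ** 2
            self.grad += 2.0 * err * x
        return loss

    def optim_step(self, learning_rate: float) -> None:
        """Apply one plain-SGD update, then clear the accumulator (toy assumption)."""
        self.weight -= learning_rate * self.grad
        self.grad = 0.0


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]

# Two micro-batches accumulated before a single optimizer step...
accumulated = ToyTrainer()
accumulated.forward_backward(data[:2])
accumulated.forward_backward(data[2:])
accumulated.optim_step(learning_rate=0.01)

# ...give the same update as one step on the full batch.
full = ToyTrainer()
full.forward_backward(data)
full.optim_step(learning_rate=0.01)

print(accumulated.weight, full.weight)
```

In this toy, calling `forward_backward` several times before one `optim_step` is equivalent to a single larger batch, which is the usual reason to separate the two calls.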
## What's next?
Some features we expect to support in the future:
- Image input for applicable models
- Full fine-tuning
---
## File: losses.mdx
# Loss functions in Tinker
For most use cases, you can use the Tinker API's built-in loss functions by passing in a string identifier to `forward_backward`, which supports cross entropy and policy gradient objectives. When you need more control, `forward_backward_custom` enables arbitrary differentiable loss functions at the cost of an additional forward pass; we explain both approaches in this doc.
When you call `forward_backward`, you specify a loss function using a string that selects from a predetermined set of options, comprising the most common losses used for language model training.
- **Input:** `forward_backward` expects a certain set of input tensors, passed in via `datum.loss_fn_inputs`, which is a dict mapping `str` to either a numpy or torch tensor
- **Output:** `forward_backward` returns a `ForwardBackwardOutput`, which has a set of output tensors in `fwd_bwd_result.loss_fn_outputs`
For an example of using `forward_backward`, see `rl/train.py` in the Cookbook:
```python
from typing import List

import tinker
import torch


async def forward_backward(
    training_client: tinker.TrainingClient,
    batch_d: List[tinker.Datum],
) -> List[torch.Tensor]:
    """Accumulate gradients on a minibatch of data"""
    # remove_mask is a Cookbook helper that is not shown in this excerpt
    fwd_bwd_future = await training_client.forward_backward_async(
        list(map(remove_mask, batch_d)), loss_fn="importance_sampling"
    )
    fwd_bwd_result = await fwd_bwd_future.result_async()

    # Extract training logprobs from loss_fn_outputs
    training_logprobs_D: list[torch.Tensor] = []
    for output in fwd_bwd_result.loss_fn_outputs:
        training_logprobs = output["logprobs"].to_torch()
        training_logprobs_D.append(training_logprobs)

    return training_logprobs_D
```
## Basic loss functions
Currently, the Tinker API supports `cross_entropy` (for supervised learning), `importance_sampling` (for RL), and `ppo` (for RL).
All tensors below have shape `(N,)` where `N` is `model_input.length`. They can be provided as `numpy.ndarray` or `torch.Tensor`, and the return values will use the same tensor type.
### Supervised learning: `cross_entropy`
For SL, we implement the standard cross-entropy loss (i.e., negative-log-likelihood), which optimizes the policy $p_\theta$ to maximize the log-probability of the tokens $x$:
$$
\mathcal{L}(\theta) = -\mathbb{E}_x[\log p_\theta(x)]
$$
In practice, this looks like `-(weights * logp(target_tokens)).sum()`, where `weights` is either 0 or 1, typically generated from `renderers.build_supervised_example` (i.e., to specify the desired assistant turns to train on).
- **Input tensors:**
  - `target_tokens: array[(N,), int]` - Target token IDs
  - `weights: array[(N,), float]` - Token-level loss weights (typically from the renderer)
- **Output tensors:**
  - `logprobs: array[(N,), float]` - Log probabilities of predicted tokens
- **Output diagnostics:**
  - `loss:sum` (scalar) - Sum of weighted cross-entropy losses
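To make the expression `-(weights * logp(target_tokens)).sum()` concrete, here is a small worked example in plain Python. It is a sketch only: the function name `weighted_cross_entropy`, the four-token sequence, and its probabilities are made up for illustration, and plain floats stand in for tensors.

```python
import math


def weighted_cross_entropy(target_logprobs: list[float], weights: list[float]) -> float:
    """Return -(weights * logp(target_tokens)).sum() for a single sequence."""
    if len(target_logprobs) != len(weights):
        raise ValueError("target_logprobs and weights must have the same length")
    return -sum(w * lp for w, lp in zip(weights, target_logprobs))


# Hypothetical 4-token sequence: two prompt tokens (weight 0)
# followed by two assistant tokens (weight 1).
target_logprobs = [math.log(0.5), math.log(0.25), math.log(0.5), math.log(0.125)]
weights = [0.0, 0.0, 1.0, 1.0]

loss = weighted_cross_entropy(target_logprobs, weights)
# Only the weighted tokens contribute: -(log 0.5 + log 0.125) = log 16
print(loss)
```

Setting a weight to 0 removes that token from the loss entirely, which is how prompt tokens are excluded from training.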
### Policy gradient: `importance_sampling`
For RL, we implement a common variant of the policy gradient objective, used in practical settings where the *learner policy* $p$ may differ from the *sampling policy* $q$. Such a mismatch is common, for example because of [non-determinism](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/). To remove the bias caused by this difference, we can use a modified "importance sampling" objective:
$$
\nabla \mathbb{E}_{x\sim p_\theta}\bigl[r(x) \bigr] = \mathbb{E}_{x\sim q}\Bigl[r(x) \cdot \frac{\nabla p_\theta(x)}{q(x)}\Bigr]
$$
which is an unbiased estimator of the gradient of the expected reward under $p_\theta$. This is implemented as `(exp(new_logprobs - logprobs) * advantages).sum()`, where `advantages` may additionally subtract a baseline from the rewards. Note this only works in the bandit setting, which is common in both RLHF and RLVR setups.
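The identity above can be checked numerically on a tiny discrete example. The sketch below assumes a softmax policy over three outcomes (an illustrative choice, not something the Tinker API prescribes). It compares the importance-sampled gradient, summed exactly over all outcomes, with a finite-difference gradient of the expected reward; all names and numbers are hypothetical.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_reward(logits: list[float], rewards: list[float]) -> float:
    """E_{x~p_theta}[r(x)] for a softmax policy."""
    return sum(p * r for p, r in zip(softmax(logits), rewards))


def is_gradient(logits: list[float], q: list[float], rewards: list[float]) -> list[float]:
    """E_{x~q}[r(x) * grad p_theta(x) / q(x)], computed exactly over all outcomes."""
    p = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for x in range(n):
        for k in range(n):
            # d p_theta(x) / d logit_k for a softmax policy
            dp = p[x] * ((1.0 if x == k else 0.0) - p[k])
            grad[k] += q[x] * rewards[x] * dp / q[x]
    return grad


def finite_difference_gradient(logits: list[float], rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Central-difference gradient of the expected reward with respect to the logits."""
    grad = []
    for k in range(len(logits)):
        up, down = list(logits), list(logits)
        up[k] += eps
        down[k] -= eps
        grad.append((expected_reward(up, rewards) - expected_reward(down, rewards)) / (2 * eps))
    return grad


logits = [0.2, -0.5, 1.0]   # learner policy parameters (theta)
q = [0.5, 0.3, 0.2]         # sampling policy, deliberately different from p_theta
rewards = [1.0, 0.0, 2.0]

is_grad = is_gradient(logits, q, rewards)
fd_grad = finite_difference_gradient(logits, rewards)
print(is_grad)
print(fd_grad)
```

The two gradients agree up to finite-difference error even though `q` differs from the learner policy, which is the point of the importance weight.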
- **Input tensors:**
  - `target_tokens: array[(N,), int]` - Target token IDs (from the sampler $q$)
  - `logprobs: array[(N,), float]` - Reference log probabilities $q$ for the target tokens
  - `advantages: array[(N,), float]` - Advantage values for RL
- **Output tensors:**
  - `logprobs: array[(N,), float]` - Log probabilities $p$ for the target tokens
- **Output diagnostics:**
  - `loss:sum` (scalar) - Sum of importance-weighted policy gradient losses
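As a worked example of the expression `(exp(new_logprobs - logprobs) * advantages).sum()`, the sketch below evaluates it on a hypothetical three-token sequence using plain Python floats in place of tensors. The function name and all numbers are illustrative assumptions.

```python
import math


def importance_sampling_objective(
    new_logprobs: list[float],
    sampling_logprobs: list[float],
    advantages: list[float],
) -> float:
    """Return (exp(new_logprobs - logprobs) * advantages).sum() for one sequence."""
    if not (len(new_logprobs) == len(sampling_logprobs) == len(advantages)):
        raise ValueError("all inputs must have the same length")
    return sum(
        math.exp(new - old) * adv
        for new, old, adv in zip(new_logprobs, sampling_logprobs, advantages)
    )


sampling_logprobs = [math.log(0.5), math.log(0.2), math.log(0.4)]  # from the sampler q
new_logprobs = [math.log(0.6), math.log(0.1), math.log(0.4)]       # from the learner p
advantages = [1.0, -1.0, 0.5]

# When learner and sampler agree, every ratio is 1 and the result is sum(advantages).
on_policy = importance_sampling_objective(sampling_logprobs, sampling_logprobs, advantages)

# Otherwise each advantage is reweighted by p/q: 1.2 * 1.0 + 0.5 * (-1.0) + 1.0 * 0.5
off_policy = importance_sampling_objective(new_logprobs, sampling_logprobs, advantages)
print(on_policy, off_policy)
```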
**Addendum:** Let's consider naively applying the policy gradient objective when $q \neq p$:
$$
\begin{align*}
\mathbb{E}_{x\sim q}\bigl[ r(x) \cdot \nabla \log p_\theta(x) \bigr] &= \sum_x q(x) r(x) \cdot \nabla \log p_\theta(x) \\
&= \sum_x q(x) (r(x) - \bar{r}) \nabla \log p_\theta(x) + \sum_x q(x) \bar{r} \cdot \nabla \log p_\theta(x) \\
&= \mathbb{E}_{x\sim q}\bigl[(r(x) - \bar{r}) \nabla \log p_\theta(x)\bigr] - \bar{r} \cdot \nabla KL(q \Vert p)
\end{align*}
$$
where $\bar{r} = \sum_x q(x) r(x)$, effectively an average-reward baseline.
- The first expectation term resembles a pseudo-policy gradient, increasing the log-likelihood of tokens $x$ which achieve higher-than-average rewards. (It is not an actual policy gradient, because $q \neq p$.)
- The second KL term is effectively a bias term which can destabilize RL optimization. This bias increases as either the divergence $KL(q \Vert p)$ grows, or as the average reward $\bar{r}$ shifts.
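The decomposition in this addendum can be verified numerically. The sketch below assumes a softmax policy over three outcomes (an illustrative choice). It computes the naive gradient and the baseline-subtracted term analytically, and the KL gradient by finite differences, so the check does not simply repeat the algebra. All names and values are hypothetical.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_q_p(q: list[float], logits: list[float]) -> float:
    """KL(q || p_theta) for a softmax policy; assumes every q(x) > 0."""
    p = softmax(logits)
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p))


def grad_log_p(logits: list[float], x: int) -> list[float]:
    """Gradient of log p_theta(x) with respect to the logits."""
    p = softmax(logits)
    return [(1.0 if x == k else 0.0) - p[k] for k in range(len(logits))]


def decompose(logits, q, rewards, eps=1e-6):
    n = len(logits)
    r_bar = sum(qi * ri for qi, ri in zip(q, rewards))  # average-reward baseline
    naive = [0.0] * n     # E_q[r * grad log p]
    centered = [0.0] * n  # E_q[(r - r_bar) * grad log p]
    for x in range(n):
        g = grad_log_p(logits, x)
        for k in range(n):
            naive[k] += q[x] * rewards[x] * g[k]
            centered[k] += q[x] * (rewards[x] - r_bar) * g[k]
    # Central-difference gradient of KL(q || p_theta) with respect to the logits
    grad_kl = []
    for k in range(n):
        up, down = list(logits), list(logits)
        up[k] += eps
        down[k] -= eps
        grad_kl.append((kl_q_p(q, up) - kl_q_p(q, down)) / (2 * eps))
    bias = [-r_bar * g for g in grad_kl]
    return naive, centered, bias


logits = [0.2, -0.5, 1.0]
q = [0.5, 0.3, 0.2]
rewards = [1.0, 0.0, 2.0]

naive, centered, bias = decompose(logits, q, rewards)
reconstructed = [c + b for c, b in zip(centered, bias)]
print(naive)
print(reconstructed)
```

Here `bias` is the $-\bar{r} \cdot \nabla KL(q \Vert p)$ term; it vanishes when $q = p$ or when $\bar{r} = 0$, in line with the two bullet points above.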
## Flexible loss functions: `forward_backward_custom`
For use-cases outside of the above, we've provided the more flexible (but slower) methods `forward_backward_custom` and `forward_backward_custom_async` to compute a more general class of loss functions.
### Usage
Here's a simple example of a custom loss function:
```python
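import torch
from tinker import Datum

def logprob_squared_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    loss = torch.cat(logprobs).pow(2).sum()  # logprobs is a list (one tensor per sequence), so flatten first
    return loss, {"logprob_squared_loss": loss.item()}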
```
You can call this loss function with `forward_backward_custom` like:
```python
loss, metrics = training_client.forward_backward_custom(data, logprob_squared_loss)
```
You can also define loss functions that operate on multiple sequences at a time. For example, a loss function that computes the variance across sequences (of little practical use, but illustrative) can be implemented as:
```python
def variance_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    flat_logprobs = torch.cat(logprobs)
    variance = torch.var(flat_logprobs)
    return variance, {"variance_loss": variance.item()}
```
A more practical use case would be to compute a Bradley-Terry loss on pairwise comparison data -- a classic approach in RL from human feedback, as introduced/popularized by [Learning to Summarize](https://arxiv.org/abs/2009.01325). Similarly, we can also implement [Direct Preference Optimization](https://arxiv.org/abs/2305.18290), which also computes a loss involving pairs of sequences; see the [DPO guide](/preferences/dpo-guide) for more details.
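For reference, the Bradley-Terry loss on a (chosen, rejected) pair is $-\log \sigma(s_\text{chosen} - s_\text{rejected})$. The sketch below evaluates it with plain Python floats to show the arithmetic only; it is not a drop-in `forward_backward_custom` loss. The function name, the choice to average over pairs, and the idea of using one scalar score per sequence (for instance a sum of token log probabilities) are assumptions made for illustration.

```python
import math


def bradley_terry_loss(chosen_scores: list[float], rejected_scores: list[float]) -> float:
    """Mean of -log(sigmoid(s_chosen - s_rejected)) over all pairs."""
    if len(chosen_scores) != len(rejected_scores) or not chosen_scores:
        raise ValueError("need the same non-zero number of chosen and rejected scores")
    total = 0.0
    for chosen, rejected in zip(chosen_scores, rejected_scores):
        margin = chosen - rejected
        # -log(sigmoid(m)) = log(1 + exp(-m)), written in an overflow-safe form
        total += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return total / len(chosen_scores)


# Hypothetical sequence-level scores for three comparison pairs.
chosen = [-3.0, -1.5, -2.0]
rejected = [-4.0, -1.5, -1.0]

loss = bradley_terry_loss(chosen, rejected)
tie_loss = bradley_terry_loss([0.0], [0.0])  # equal scores give log(2)
print(loss, tie_loss)
```

A differentiable version would apply the same formula to tensors inside a custom loss function, with the pairing of sequences in `data` decided by the caller.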
If you're using a custom loss function that you think is generally useful, please let us know, and we'll add it to the list of built-in loss functions.
We detail the `async` versions of these methods in the [Async and Futures](./async) section of these docs.
### How `forward_backward_custom` works
---