
Commit 63205da

Author: Joe Cummings
Merge remote-tracking branch 'upstream/main' into multi-node-support
2 parents b56b6be + e6b9064

File tree

8 files changed (+790 -4 lines)


docs/source/tune_cli.rst (+71 -2)
@@ -17,15 +17,15 @@ with a short description of each.
 .. code-block:: bash

     $ tune --help
-    usage: tune [-h] {download,ls,cp,run,validate} ...
+    usage: tune [-h] {download,ls,cp,run,validate,cat} ...

     Welcome to the torchtune CLI!

     options:
       -h, --help            show this help message and exit

     subcommands:
-      {download,ls,cp,run,validate}
+      {download,ls,cp,run,validate,cat}
         download            Download a model from the Hugging Face Hub.
         ls                  List all built-in recipes and configs
     ...
@@ -233,3 +233,72 @@ The ``tune validate <config>`` command will validate that your config is formatt
     # If you've copied over a built-in config and want to validate custom changes
     $ tune validate my_configs/llama3/8B_full.yaml
     Config is well-formed!
+
+.. _tune_cat_cli_label:
+
+Inspect a config
+---------------------
+
+The ``tune cat <config>`` command pretty-prints a configuration file, making it easy to use ``tune run`` with confidence. It is useful for inspecting the structure and contents of a config before running a recipe, ensuring that all parameters are correctly set.
+
+You can also use the ``--sort`` option to print the config in sorted order, which helps in quickly locating specific keys.
+
+.. list-table::
+   :widths: 30 60
+
+   * - \--sort
+     - Print the config in sorted order.
+
+**Workflow Example**
+
+1. **List all available configs:**
+
+   Use the ``tune ls`` command to list all the built-in recipes and configs within torchtune.
+
+   .. code-block:: bash
+
+      $ tune ls
+      RECIPE                          CONFIG
+      full_finetune_single_device     llama2/7B_full_low_memory
+                                      code_llama2/7B_full_low_memory
+                                      llama3/8B_full_single_device
+                                      mistral/7B_full_low_memory
+                                      phi3/mini_full_low_memory
+      full_finetune_distributed       llama2/7B_full
+                                      llama2/13B_full
+                                      llama3/8B_full
+                                      llama3/70B_full
+      ...
+
+2. **Inspect the contents of a config:**
+
+   Use the ``tune cat`` command to pretty-print the contents of a specific config. This helps you understand the structure and parameters of the config.
+
+   .. code-block:: bash
+
+      $ tune cat llama2/7B_full
+      output_dir: /tmp/torchtune/llama2_7B/full
+      tokenizer:
+        _component_: torchtune.models.llama2.llama2_tokenizer
+        path: /tmp/Llama-2-7b-hf/tokenizer.model
+        max_seq_len: null
+      ...
+
+   You can also print the config in sorted order:
+
+   .. code-block:: bash
+
+      $ tune cat llama2/7B_full --sort
+
+3. **Run a recipe with parameter overrides:**
+
+   After inspecting the config, you can use the ``tune run`` command to run a recipe with it. You can also override specific parameters directly from the command line. For example, to override the ``output_dir`` parameter:
+
+   .. code-block:: bash
+
+      $ tune run full_finetune_distributed --config llama2/7B_full output_dir=./
+
+   Learn more about config overrides :ref:`here <cli_override>`.
+
+.. note::
+    You can find all the cat-able configs via the ``tune ls`` command.
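
Editor's aside on the doc change above: conceptually, ``tune cat`` loads a YAML config and pretty-prints it, with ``--sort`` controlling key ordering. The following is only a minimal sketch of that behavior using plain PyYAML; it is not torchtune's implementation (the real command also resolves built-in config names such as ``llama2/7B_full`` and emits the error messages exercised by the tests below).

    # Minimal sketch only -- not torchtune's implementation.
    # Assumes PyYAML is installed and the config is passed as a file path.
    import sys

    import yaml


    def cat_config(path: str, sort_keys: bool = False) -> None:
        """Load a YAML config from `path` and pretty-print it to stdout."""
        try:
            with open(path) as f:
                cfg = yaml.safe_load(f)
        except yaml.YAMLError as e:
            # Mirrors the "Error parsing YAML file" failure mode tested below
            print(f"Error parsing YAML file: {e}", file=sys.stderr)
            raise SystemExit(1)
        # sort_keys=True corresponds to the --sort flag described above
        print(yaml.safe_dump(cfg, sort_keys=sort_keys, default_flow_style=False), end="")


    if __name__ == "__main__":
        cat_config(sys.argv[1], sort_keys="--sort" in sys.argv[2:])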

tests/torchtune/_cli/test_cat.py (+81)
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import runpy
+import sys
+
+import pytest
+from tests.common import TUNE_PATH
+
+
+class TestTuneCatCommand:
+    """This class tests the `tune cat` command."""
+
+    def test_cat_valid_config(self, capsys, monkeypatch):
+        testargs = "tune cat llama2/7B_full".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        output = captured.out.rstrip("\n")
+
+        # Check for key sections that should be in the YAML output
+        assert "output_dir:" in output
+        assert "tokenizer:" in output
+        assert "model:" in output
+
+    def test_cat_recipe_name_shows_error(self, capsys, monkeypatch):
+        testargs = "tune cat full_finetune_single_device".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        output = captured.out.rstrip("\n")
+
+        assert "is a recipe, not a config" in output
+
+    def test_cat_non_existent_config(self, capsys, monkeypatch):
+        testargs = "tune cat non_existent_config".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+
+        with pytest.raises(SystemExit):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        err = captured.err.rstrip("\n")
+
+        assert (
+            "Invalid config format: 'non_existent_config'. Must be YAML (.yaml/.yml)"
+            in err
+        )
+
+    def test_cat_invalid_yaml_file(self, capsys, monkeypatch, tmpdir):
+        invalid_yaml = tmpdir / "invalid.yaml"
+        invalid_yaml.write_text("invalid: yaml: file", encoding="utf-8")
+
+        testargs = f"tune cat {invalid_yaml}".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+
+        with pytest.raises(SystemExit):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        err = captured.err.rstrip("\n")
+
+        assert "Error parsing YAML file" in err
+
+    def test_cat_external_yaml_file(self, capsys, monkeypatch, tmpdir):
+        valid_yaml = tmpdir / "external.yaml"
+        valid_yaml.write_text("key: value", encoding="utf-8")
+
+        testargs = f"tune cat {valid_yaml}".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        output = captured.out.rstrip("\n")
+
+        assert "key: value" in output

tests/torchtune/modules/loss/test_kd_losses.py (+208 -1)
@@ -7,7 +7,14 @@
 import pytest
 import torch
 from tests.test_utils import assert_expected
-from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss
+from torchtune.modules.loss import (
+    ForwardKLLoss,
+    ForwardKLWithChunkedOutputLoss,
+    ReverseKLLoss,
+    ReverseKLWithChunkedOutputLoss,
+    SymmetricKLLoss,
+    SymmetricKLWithChunkedOutputLoss,
+)
 from torchtune.training.seed import set_seed

@@ -140,3 +147,203 @@ def test_forward_kl_loss_expected(self):
         # assert
         assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
         assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
+
+
+class TestReverseKLWithChunkedOutputLoss:
+    def test_reverse_kl_loss(self):
+        # Create a sample input and label
+        ignore_index = -100
+        batch_size = 3
+        num_tokens = 50
+        vocab_size = 50
+        logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
+        teacher_logits = torch.randn(
+            batch_size, num_tokens, vocab_size, dtype=torch.bfloat16
+        )
+        labels = torch.randint(
+            0, vocab_size, (batch_size, num_tokens), dtype=torch.long
+        )
+
+        # add random ignore index to random tokens in the label
+        random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
+        labels[random_indices < num_tokens // 5] = ignore_index
+
+        # chunked RKL
+        chunked_rkl_loss = ReverseKLWithChunkedOutputLoss(
+            num_output_chunks=8, ignore_index=ignore_index
+        )
+        logits_chunks = logits.chunk(chunked_rkl_loss.num_output_chunks, dim=1)
+        teacher_logits_chunks = teacher_logits.chunk(
+            chunked_rkl_loss.num_output_chunks, dim=1
+        )
+        chunked_loss = chunked_rkl_loss(logits_chunks, teacher_logits_chunks, labels)
+
+        # vanilla RKL
+        rkl_loss = ReverseKLLoss(ignore_index=ignore_index)
+        logits = logits.reshape(-1, logits.size(-1))
+        teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
+        labels = labels.reshape(-1)
+        standard_loss = rkl_loss(logits, teacher_logits, labels)
+
+        # Assert
+        assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
+
+    def test_reverse_kl_loss_expected(self):
+        student_logits = torch.tensor(
+            [
+                [
+                    [1.1250, -0.4102, -0.0879, -2.5000],
+                    [0.2676, 0.3535, 0.8711, -1.4688],
+                    [-0.1084, 1.6641, 0.0084, 0.1196],
+                    [0.5000, -0.6406, -0.2236, -1.5938],
+                ],
+                [
+                    [-1.5312, -1.9219, 0.0000, -0.5039],
+                    [-1.5391, 1.5312, 0.5820, 0.2695],
+                    [-0.3887, 1.2188, 0.0000, 0.6055],
+                    [0.5000, 1.3828, 0.1309, -1.0312],
+                ],
+            ],
+            dtype=torch.bfloat16,
+        )
+        teacher_logits = torch.tensor(
+            [
+                [
+                    [-0.0381, -1.2578, -1.2031, 0.0947],
+                    [-0.7852, 0.4492, 1.5547, 0.0972],
+                    [0.8203, 0.0012, 0.7656, 0.3477],
+                    [-1.5781, 0.4297, 0.5977, 0.3926],
+                ],
+                [
+                    [1.5156, 0.1641, 2.0781, -0.7734],
+                    [-0.5898, 0.4453, -0.7969, 0.6328],
+                    [0.6289, -0.8359, 0.9258, 0.2109],
+                    [0.0006, 0.5195, 3.2344, -1.5781],
+                ],
+            ],
+            dtype=torch.bfloat16,
+        )
+        labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]])
+        expected_loss = torch.tensor(0.6775, dtype=torch.float32)
+
+        # chunked RKL loss
+        chunked_rkl_loss = ReverseKLWithChunkedOutputLoss(
+            num_output_chunks=2, ignore_index=-100
+        )
+        student_logits_chunks = student_logits.chunk(
+            chunked_rkl_loss.num_output_chunks, dim=1
+        )
+        teacher_logits_chunks = teacher_logits.chunk(
+            chunked_rkl_loss.num_output_chunks, dim=1
+        )
+        chunked_loss = chunked_rkl_loss(
+            student_logits_chunks, teacher_logits_chunks, labels
+        )
+
+        # vanilla RKL loss
+        rkl_loss = ReverseKLLoss(ignore_index=-100)
+        standard_loss = rkl_loss(student_logits, teacher_logits, labels)
+
+        # assert
+        assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
+        assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
+
+
+class TestSymmetricKLWithChunkedOutputLoss:
+    def test_symmetric_kl_loss(self):
+        # Create a sample input and label
+        ignore_index = -100
+        batch_size = 3
+        num_tokens = 50
+        vocab_size = 50
+        logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
+        teacher_logits = torch.randn(
+            batch_size, num_tokens, vocab_size, dtype=torch.bfloat16
+        )
+        labels = torch.randint(
+            0, vocab_size, (batch_size, num_tokens), dtype=torch.long
+        )
+
+        # add random ignore index to random tokens in the label
+        random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
+        labels[random_indices < num_tokens // 5] = ignore_index
+
+        # chunked Symmetric KL
+        chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss(
+            num_output_chunks=8, ignore_index=ignore_index
+        )
+        logits_chunks = logits.chunk(chunked_sym_kl_loss.num_output_chunks, dim=1)
+        teacher_logits_chunks = teacher_logits.chunk(
+            chunked_sym_kl_loss.num_output_chunks, dim=1
+        )
+        chunked_loss = chunked_sym_kl_loss(logits_chunks, teacher_logits_chunks, labels)
+
+        # vanilla Symmetric KL
+        sym_kl_loss = SymmetricKLLoss(ignore_index=ignore_index)
+        logits = logits.reshape(-1, logits.size(-1))
+        teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
+        labels = labels.reshape(-1)
+        standard_loss = sym_kl_loss(logits, teacher_logits, labels)
+
+        # Assert
+        assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
+
+    def test_symmetric_kl_loss_expected(self):
+        student_logits = torch.tensor(
+            [
+                [
+                    [1.1250, -0.4102, -0.0879, -2.5000],
+                    [0.2676, 0.3535, 0.8711, -1.4688],
+                    [-0.1084, 1.6641, 0.0084, 0.1196],
+                    [0.5000, -0.6406, -0.2236, -1.5938],
+                ],
+                [
+                    [-1.5312, -1.9219, 0.0000, -0.5039],
+                    [-1.5391, 1.5312, 0.5820, 0.2695],
+                    [-0.3887, 1.2188, 0.0000, 0.6055],
+                    [0.5000, 1.3828, 0.1309, -1.0312],
+                ],
+            ],
+            dtype=torch.bfloat16,
+        )
+        teacher_logits = torch.tensor(
+            [
+                [
+                    [-0.0381, -1.2578, -1.2031, 0.0947],
+                    [-0.7852, 0.4492, 1.5547, 0.0972],
+                    [0.8203, 0.0012, 0.7656, 0.3477],
+                    [-1.5781, 0.4297, 0.5977, 0.3926],
+                ],
+                [
+                    [1.5156, 0.1641, 2.0781, -0.7734],
+                    [-0.5898, 0.4453, -0.7969, 0.6328],
+                    [0.6289, -0.8359, 0.9258, 0.2109],
+                    [0.0006, 0.5195, 3.2344, -1.5781],
+                ],
+            ],
+            dtype=torch.bfloat16,
+        )
+        labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]])
+        expected_loss = torch.tensor(1.1992, dtype=torch.float32)
+
+        # chunked Symmetric KL loss
+        chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss(
+            num_output_chunks=2, ignore_index=-100
+        )
+        student_logits_chunks = student_logits.chunk(
+            chunked_sym_kl_loss.num_output_chunks, dim=1
+        )
+        teacher_logits_chunks = teacher_logits.chunk(
+            chunked_sym_kl_loss.num_output_chunks, dim=1
+        )
+        chunked_loss = chunked_sym_kl_loss(
+            student_logits_chunks, teacher_logits_chunks, labels
+        )
+
+        # vanilla Symmetric KL loss
+        sym_kl_loss = SymmetricKLLoss(ignore_index=-100)
+        standard_loss = sym_kl_loss(student_logits, teacher_logits, labels)
+
+        # assert
+        assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
+        assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
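
Editor's aside on the new loss tests: in the usual knowledge-distillation convention, forward KL compares the teacher distribution against the student's, reverse KL swaps the direction, and a symmetric variant combines the two (the exact weighting is an implementation detail of SymmetricKLLoss). The chunked variants accept logits split along the sequence dimension, as the tests do with ``.chunk(num_output_chunks, dim=1)``, which keeps peak memory lower while matching the unchunked result. The sketch below shows only the standard reverse KL computation for flattened logits; it is an illustration under those assumptions, not torchtune's ReverseKLLoss, and masking/normalization details may differ.

    # Reference sketch of per-token reverse KL (student || teacher), assuming
    # flattened [num_tokens, vocab_size] logits and [num_tokens] labels.
    # Not torchtune's implementation; normalization details may differ.
    import torch
    import torch.nn.functional as F


    def reverse_kl(student_logits, teacher_logits, labels, ignore_index=-100):
        log_q = F.log_softmax(student_logits.float(), dim=-1)  # student log-probs
        log_p = F.log_softmax(teacher_logits.float(), dim=-1)  # teacher log-probs
        q = log_q.exp()
        # reverse KL per token: sum_v q(v) * (log q(v) - log p(v))
        per_token = (q * (log_q - log_p)).sum(dim=-1)
        mask = labels != ignore_index
        # average over non-ignored tokens (torchtune may normalize differently)
        return per_token[mask].mean()

A symmetric variant would combine this with the forward direction, ``(p * (log_p - log_q)).sum(dim=-1)``, over the same mask.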
