|
7 | 7 | import pytest
|
8 | 8 | import torch
|
9 | 9 | from tests.test_utils import assert_expected
|
10 |
| -from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss |
| 10 | +from torchtune.modules.loss import ( |
| 11 | + ForwardKLLoss, |
| 12 | + ForwardKLWithChunkedOutputLoss, |
| 13 | + ReverseKLLoss, |
| 14 | + ReverseKLWithChunkedOutputLoss, |
| 15 | + SymmetricKLLoss, |
| 16 | + SymmetricKLWithChunkedOutputLoss, |
| 17 | +) |
11 | 18 | from torchtune.training.seed import set_seed
|
12 | 19 |
|
13 | 20 |
|
@@ -140,3 +147,203 @@ def test_forward_kl_loss_expected(self):
|
140 | 147 | # assert
|
141 | 148 | assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
|
142 | 149 | assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
|
| 150 | + |
| 151 | + |
| 152 | +class TestReverseKLWithChunkedOutputLoss: |
| 153 | + def test_reverse_kl_loss(self): |
| 154 | + # Create a sample input and label |
| 155 | + ignore_index = -100 |
| 156 | + batch_size = 3 |
| 157 | + num_tokens = 50 |
| 158 | + vocab_size = 50 |
| 159 | + logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16) |
| 160 | + teacher_logits = torch.randn( |
| 161 | + batch_size, num_tokens, vocab_size, dtype=torch.bfloat16 |
| 162 | + ) |
| 163 | + labels = torch.randint( |
| 164 | + 0, vocab_size, (batch_size, num_tokens), dtype=torch.long |
| 165 | + ) |
| 166 | + |
| 167 | + # add random ignore index to random tokens in the label |
| 168 | + random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) |
| 169 | + labels[random_indices < num_tokens // 5] = ignore_index |
| 170 | + |
| 171 | + # chunked RKL |
| 172 | + chunked_rkl_loss = ReverseKLWithChunkedOutputLoss( |
| 173 | + num_output_chunks=8, ignore_index=ignore_index |
| 174 | + ) |
| 175 | + logits_chunks = logits.chunk(chunked_rkl_loss.num_output_chunks, dim=1) |
| 176 | + teacher_logits_chunks = teacher_logits.chunk( |
| 177 | + chunked_rkl_loss.num_output_chunks, dim=1 |
| 178 | + ) |
| 179 | + chunked_loss = chunked_rkl_loss(logits_chunks, teacher_logits_chunks, labels) |
| 180 | + |
| 181 | + # vanilla RKL |
| 182 | + rkl_loss = ReverseKLLoss(ignore_index=ignore_index) |
| 183 | + logits = logits.reshape(-1, logits.size(-1)) |
| 184 | + teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) |
| 185 | + labels = labels.reshape(-1) |
| 186 | + standard_loss = rkl_loss(logits, teacher_logits, labels) |
| 187 | + |
| 188 | + # Assert |
| 189 | + assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) |
| 190 | + |
| 191 | + def test_reverse_kl_loss_expected(self): |
| 192 | + student_logits = torch.tensor( |
| 193 | + [ |
| 194 | + [ |
| 195 | + [1.1250, -0.4102, -0.0879, -2.5000], |
| 196 | + [0.2676, 0.3535, 0.8711, -1.4688], |
| 197 | + [-0.1084, 1.6641, 0.0084, 0.1196], |
| 198 | + [0.5000, -0.6406, -0.2236, -1.5938], |
| 199 | + ], |
| 200 | + [ |
| 201 | + [-1.5312, -1.9219, 0.0000, -0.5039], |
| 202 | + [-1.5391, 1.5312, 0.5820, 0.2695], |
| 203 | + [-0.3887, 1.2188, 0.0000, 0.6055], |
| 204 | + [0.5000, 1.3828, 0.1309, -1.0312], |
| 205 | + ], |
| 206 | + ], |
| 207 | + dtype=torch.bfloat16, |
| 208 | + ) |
| 209 | + teacher_logits = torch.tensor( |
| 210 | + [ |
| 211 | + [ |
| 212 | + [-0.0381, -1.2578, -1.2031, 0.0947], |
| 213 | + [-0.7852, 0.4492, 1.5547, 0.0972], |
| 214 | + [0.8203, 0.0012, 0.7656, 0.3477], |
| 215 | + [-1.5781, 0.4297, 0.5977, 0.3926], |
| 216 | + ], |
| 217 | + [ |
| 218 | + [1.5156, 0.1641, 2.0781, -0.7734], |
| 219 | + [-0.5898, 0.4453, -0.7969, 0.6328], |
| 220 | + [0.6289, -0.8359, 0.9258, 0.2109], |
| 221 | + [0.0006, 0.5195, 3.2344, -1.5781], |
| 222 | + ], |
| 223 | + ], |
| 224 | + dtype=torch.bfloat16, |
| 225 | + ) |
| 226 | + labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]]) |
| 227 | + expected_loss = torch.tensor(0.6775, dtype=torch.float32) |
| 228 | + |
| 229 | + # chunked RKL loss |
| 230 | + chunked_rkl_loss = ReverseKLWithChunkedOutputLoss( |
| 231 | + num_output_chunks=2, ignore_index=-100 |
| 232 | + ) |
| 233 | + student_logits_chunks = student_logits.chunk( |
| 234 | + chunked_rkl_loss.num_output_chunks, dim=1 |
| 235 | + ) |
| 236 | + teacher_logits_chunks = teacher_logits.chunk( |
| 237 | + chunked_rkl_loss.num_output_chunks, dim=1 |
| 238 | + ) |
| 239 | + chunked_loss = chunked_rkl_loss( |
| 240 | + student_logits_chunks, teacher_logits_chunks, labels |
| 241 | + ) |
| 242 | + |
| 243 | + # vanilla RKL loss |
| 244 | + rkl_loss = ReverseKLLoss(ignore_index=-100) |
| 245 | + standard_loss = rkl_loss(student_logits, teacher_logits, labels) |
| 246 | + |
| 247 | + # assert |
| 248 | + assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) |
| 249 | + assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) |
| 250 | + |
| 251 | + |
| 252 | +class TestSymmetricKLWithChunkedOutputLoss: |
| 253 | + def test_symmetric_kl_loss(self): |
| 254 | + # Create a sample input and label |
| 255 | + ignore_index = -100 |
| 256 | + batch_size = 3 |
| 257 | + num_tokens = 50 |
| 258 | + vocab_size = 50 |
| 259 | + logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16) |
| 260 | + teacher_logits = torch.randn( |
| 261 | + batch_size, num_tokens, vocab_size, dtype=torch.bfloat16 |
| 262 | + ) |
| 263 | + labels = torch.randint( |
| 264 | + 0, vocab_size, (batch_size, num_tokens), dtype=torch.long |
| 265 | + ) |
| 266 | + |
| 267 | + # add random ignore index to random tokens in the label |
| 268 | + random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) |
| 269 | + labels[random_indices < num_tokens // 5] = ignore_index |
| 270 | + |
| 271 | + # chunked Symmetric KL |
| 272 | + chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss( |
| 273 | + num_output_chunks=8, ignore_index=ignore_index |
| 274 | + ) |
| 275 | + logits_chunks = logits.chunk(chunked_sym_kl_loss.num_output_chunks, dim=1) |
| 276 | + teacher_logits_chunks = teacher_logits.chunk( |
| 277 | + chunked_sym_kl_loss.num_output_chunks, dim=1 |
| 278 | + ) |
| 279 | + chunked_loss = chunked_sym_kl_loss(logits_chunks, teacher_logits_chunks, labels) |
| 280 | + |
| 281 | + # vanilla Symmetric KL |
| 282 | + sym_kl_loss = SymmetricKLLoss(ignore_index=ignore_index) |
| 283 | + logits = logits.reshape(-1, logits.size(-1)) |
| 284 | + teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) |
| 285 | + labels = labels.reshape(-1) |
| 286 | + standard_loss = sym_kl_loss(logits, teacher_logits, labels) |
| 287 | + |
| 288 | + # Assert |
| 289 | + assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) |
| 290 | + |
| 291 | + def test_symmetric_kl_loss_expected(self): |
| 292 | + student_logits = torch.tensor( |
| 293 | + [ |
| 294 | + [ |
| 295 | + [1.1250, -0.4102, -0.0879, -2.5000], |
| 296 | + [0.2676, 0.3535, 0.8711, -1.4688], |
| 297 | + [-0.1084, 1.6641, 0.0084, 0.1196], |
| 298 | + [0.5000, -0.6406, -0.2236, -1.5938], |
| 299 | + ], |
| 300 | + [ |
| 301 | + [-1.5312, -1.9219, 0.0000, -0.5039], |
| 302 | + [-1.5391, 1.5312, 0.5820, 0.2695], |
| 303 | + [-0.3887, 1.2188, 0.0000, 0.6055], |
| 304 | + [0.5000, 1.3828, 0.1309, -1.0312], |
| 305 | + ], |
| 306 | + ], |
| 307 | + dtype=torch.bfloat16, |
| 308 | + ) |
| 309 | + teacher_logits = torch.tensor( |
| 310 | + [ |
| 311 | + [ |
| 312 | + [-0.0381, -1.2578, -1.2031, 0.0947], |
| 313 | + [-0.7852, 0.4492, 1.5547, 0.0972], |
| 314 | + [0.8203, 0.0012, 0.7656, 0.3477], |
| 315 | + [-1.5781, 0.4297, 0.5977, 0.3926], |
| 316 | + ], |
| 317 | + [ |
| 318 | + [1.5156, 0.1641, 2.0781, -0.7734], |
| 319 | + [-0.5898, 0.4453, -0.7969, 0.6328], |
| 320 | + [0.6289, -0.8359, 0.9258, 0.2109], |
| 321 | + [0.0006, 0.5195, 3.2344, -1.5781], |
| 322 | + ], |
| 323 | + ], |
| 324 | + dtype=torch.bfloat16, |
| 325 | + ) |
| 326 | + labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]]) |
| 327 | + expected_loss = torch.tensor(1.1992, dtype=torch.float32) |
| 328 | + |
| 329 | + # chunked Symmetric KL loss |
| 330 | + chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss( |
| 331 | + num_output_chunks=2, ignore_index=-100 |
| 332 | + ) |
| 333 | + student_logits_chunks = student_logits.chunk( |
| 334 | + chunked_sym_kl_loss.num_output_chunks, dim=1 |
| 335 | + ) |
| 336 | + teacher_logits_chunks = teacher_logits.chunk( |
| 337 | + chunked_sym_kl_loss.num_output_chunks, dim=1 |
| 338 | + ) |
| 339 | + chunked_loss = chunked_sym_kl_loss( |
| 340 | + student_logits_chunks, teacher_logits_chunks, labels |
| 341 | + ) |
| 342 | + |
| 343 | + # vanilla Symmetric KL loss |
| 344 | + sym_kl_loss = SymmetricKLLoss(ignore_index=-100) |
| 345 | + standard_loss = sym_kl_loss(student_logits, teacher_logits, labels) |
| 346 | + |
| 347 | + # assert |
| 348 | + assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) |
| 349 | + assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) |
0 commit comments