@@ -1,6 +1,6 @@
 import logging
 from enum import Enum, auto
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 from torch._decomp import register_decomposition
@@ -435,6 +435,137 @@ def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
     return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])
 
 
+@register_torch_trt_decomposition(aten.view.default, registry=TORCH_TRT_DECOMPOSITIONS)
+def view_decomposition(x: torch.Tensor, size: List[torch.SymInt]) -> torch.Tensor:
+    return aten._reshape_copy.default(x, size)
+
+
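Note on the hunk above: rewriting aten.view.default as _reshape_copy trades view aliasing for an explicit copy, which fits TensorRT's lack of view semantics. A minimal eager sketch of what the rewrite implies (the copy-vs-alias assertion reflects _reshape_copy's copying semantics, not a guarantee this patch adds):

import torch

x = torch.arange(6.0)
y = torch.ops.aten._reshape_copy.default(x, [2, 3])
assert torch.equal(y, x.view(2, 3))   # same values as a view
assert y.data_ptr() != x.data_ptr()   # but materialized as a copy, not an alias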
+@register_torch_trt_decomposition(
+    aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    *,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    L, S = query.size(-2), key.size(-2)
+    device = query.device
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal:
+        assert attn_mask is None, "attn_mask must be None when is_causal=True"
+        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias = attn_mask + attn_bias
+
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1)
+
+    if scale is None:
+        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
+        attn_weight = attn_weight / scale
+    else:
+        attn_weight = attn_weight * scale
+
+    attn_weight = attn_weight + attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    return attn_weight @ value
+
+
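Note: dropout_p is accepted but never applied, which is consistent with inference-only lowering, so the decomposition should match the fused op numerically whenever dropout is zero. A quick eager check, assuming the registration decorator returns the function unchanged (as torch._decomp.register_decomposition does) so it is directly callable:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = scaled_dot_product_attention_decomposition(q, k, v, is_causal=True)
# loose tolerances: the decomposed matmul/softmax path reduces in a different order
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)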
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_flash_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.SymInt,
+    torch.SymInt,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, None, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None, 0, 0, None, None, None
+
+
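The fused flash kernel returns a 9-tuple (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask); this decomposition computes only the output and stubs out the rest, which is safe only under the assumption that lowered graphs consume element 0 alone. A sketch of the resulting return value, reusing q, k, v from the check above:

outs = scaled_dot_product_flash_attention_decomposition(q, k, v, is_causal=True)
attn = outs[0]                          # the only element a lowered graph should read
assert outs[4] == 0 and outs[5] == 0    # max_q / max_k placeholders
assert all(o is None for o in outs[1:4] + outs[6:])  # tensor outputs stubbed out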
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_efficient_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor],
+    compute_log_sumexp: bool,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None
+
+
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_cudnn_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor],
+    compute_log_sumexp: bool,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.SymInt,
+    torch.SymInt,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None, 0, 0, None, None, None
+
+
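The efficient and cuDNN variants follow the same pattern as the flash one: delegate to the base decomposition and stub out the auxiliary outputs. End to end, registering all of these should eliminate every fused SDPA op from an exported graph; a hedged sketch of that check (assuming TORCH_TRT_DECOMPOSITIONS is an OpOverload-to-callable dict of the shape run_decompositions expects):

import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

ep = torch.export.export(Attn(), (q, k, v))
ep = ep.run_decompositions(TORCH_TRT_DECOMPOSITIONS)
# every fused/composite SDPA call should now be matmul + softmax primitives
assert not any(
    "scaled_dot_product" in str(n.target) for n in ep.graph_module.graph.nodes
)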
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]: