Skip to content

Conversation

@SolitaryThinker
Copy link

@SolitaryThinker SolitaryThinker commented Dec 9, 2024

First attempt at logit extraction. Does not include the softmax normalization. 15-30% slowdown compared to unmodified triton kernel.

Note: this PR currently contains the version that casts to bf16 before writing out logits

@SolitaryThinker
Copy link
Author

SolitaryThinker commented Dec 9, 2024

BF16
Raw benchmark numbers over 1k iterations of kernel

Format: batch_size, seq_len, block_size, head_size, expose_time(ms), unmodified_time(ms), slowdown(%)
1, 128, 16, 128, 0.12, 0.10, 17.82
1, 512, 16, 128, 0.11, 0.10, 16.11
1, 1024, 16, 128, 0.26, 0.21, 22.69
1, 2048, 16, 128, 0.76, 0.64, 19.21
1, 4096, 16, 128, 2.32, 1.99, 16.25
4, 128, 16, 128, 0.11, 0.10, 14.86
4, 512, 16, 128, 0.24, 0.18, 31.01
4, 1024, 16, 128, 0.69, 0.57, 22.17
4, 2048, 16, 128, 2.47, 2.09, 18.57
4, 4096, 16, 128, 9.00, 7.72, 16.59
8, 128, 16, 128, 0.11, 0.10, 15.53
8, 512, 16, 128, 0.40, 0.31, 28.70
8, 1024, 16, 128, 1.29, 1.08, 19.40
8, 2048, 16, 128, 4.66, 3.91, 19.08
8, 4096, 16, 128, 17.57, 15.12, 16.18
16, 128, 16, 128, 0.17, 0.14, 14.64
16, 512, 16, 128, 0.74, 0.59, 24.85
16, 1024, 16, 128, 2.45, 2.05, 19.45
16, 2048, 16, 128, 9.16, 7.71, 18.74
16, 4096, 16, 128, 34.91, 30.07, 16.09
32, 128, 16, 128, 0.32, 0.28, 14.10
32, 512, 16, 128, 1.39, 1.12, 24.71
32, 1024, 16, 128, 4.81, 4.02, 19.43
32, 2048, 16, 128, 18.20, 15.40, 18.13

@SolitaryThinker
Copy link
Author

float32
Seems like the perf is slightly better if we don't cast to bf16

Format: batch_size, seq_len, block_size, head_size, expose_time, prefix_time, slowdown
1, 128, 16, 128, 0.11, 0.09, 17.42
1, 512, 16, 128, 0.11, 0.09, 17.41
1, 1024, 16, 128, 0.25, 0.21, 18.61
1, 2048, 16, 128, 0.71, 0.64, 12.35
1, 4096, 16, 128, 2.23, 2.00, 11.65
4, 128, 16, 128, 0.11, 0.10, 16.93
4, 512, 16, 128, 0.22, 0.18, 18.21
4, 1024, 16, 128, 0.65, 0.57, 14.95
4, 2048, 16, 128, 2.35, 2.09, 12.52
4, 4096, 16, 128, 8.67, 7.72, 12.31
8, 128, 16, 128, 0.11, 0.10, 16.04
8, 512, 16, 128, 0.36, 0.31, 16.64
8, 1024, 16, 128, 1.23, 1.08, 13.28
8, 2048, 16, 128, 4.41, 3.91, 12.66
8, 4096, 16, 128, 16.98, 15.12, 12.29
16, 128, 16, 128, 0.16, 0.14, 13.14
16, 512, 16, 128, 0.68, 0.59, 14.74
16, 1024, 16, 128, 2.32, 2.05, 13.28
16, 2048, 16, 128, 8.70, 7.72, 12.76
16, 4096, 16, 128, 33.72, 30.06, 12.16
32, 128, 16, 128, 0.32, 0.28, 12.68
32, 512, 16, 128, 1.28, 1.12, 15.08
32, 1024, 16, 128, 4.59, 4.02, 14.07
32, 2048, 16, 128, 17.32, 15.40, 12.45

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant