-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathtransformer.py
969 lines (824 loc) · 34.3 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
# Copyright (c) 2021 EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
from .norms import get_norm
from megatron import mpu
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.activations import get_activation
from megatron.model.utils import exists, get_fusion_type
from megatron.model.positional_embeddings import (
RotaryEmbedding,
apply_rotary_pos_emb_torch,
apply_rotary_pos_emb,
AliBi,
)
from megatron.model.fused_bias_dropout import (
get_bias_dropout_add,
bias_dropout_add_fused_train,
bias_dropout_add_fused_inference,
)
from megatron.model.utils import configure_sparse_attention
# Flags required to enable the JIT fusion kernels.
# These are module-level calls into private ``torch._C`` setters, so they run
# as soon as this module is imported: the JIT profiling mode and profiling
# executor are switched off, and the "can fuse" overrides are switched on for
# both CPU and GPU.
# NOTE(review): these are private PyTorch hooks - confirm they still exist when
# upgrading torch.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                     unmasked-attention-scores, attention-mask)
"""
class ParallelMLP(nn.Module):
    """Feed-forward sub-layer of a transformer block.

    The hidden state (size h) is expanded by ``dense_h_to_4h``, run through
    the configured activation and projected back to h by ``dense_4h_to_h``.
    No dropout is applied inside this module; ``forward`` hands back the
    output projection's result and its bias as a pair, exactly as
    ``dense_4h_to_h`` returned them.
    """

    def __init__(
        self, neox_args, init_method, output_layer_init_method, parallel_output=False
    ):
        """Build the expansion and contraction projections.

        Args:
            neox_args: configuration object; ``activation``,
                ``bias_gelu_fusion`` and ``hidden_size`` are read here.
            init_method: initializer passed to the expansion projection.
            output_layer_init_method: initializer passed to the output
                projection.
            parallel_output: forwarded to the output ``RowParallelLinear``.
        """
        super().__init__()

        self.activation_func = get_activation(neox_args)
        self.activation_type = neox_args.activation
        self.bias_gelu_fusion = neox_args.bias_gelu_fusion

        hidden = neox_args.hidden_size
        if self.activation_type == "geglu":
            # auto scale so geglu has equal parameters
            # NOTE(review): int(4 * 2 / 3) truncates to 2, so the multiplier
            # actually used is 2 rather than 8/3. Kept as is, since changing
            # it would change parameter shapes.
            ff_mult = int(4 * 2 / 3)
            # The expansion emits twice the width the output projection
            # consumes.
            ff_dim = int(ff_mult * hidden) * 2
            ff_dim_in = ff_dim // 2
        else:
            ff_mult = 4
            ff_dim = ff_mult * hidden
            ff_dim_in = ff_dim

        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=hidden,
            output_size=ff_dim,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
        )
        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=ff_dim_in,
            output_size=hidden,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            parallel_output=parallel_output,
        )

    def forward(self, hidden_states):
        """Apply expansion, activation and contraction.

        Returns:
            The ``(output, output_bias)`` pair produced by ``dense_4h_to_h``.
        """
        # [s, b, 4hp]
        expanded, expanded_bias = self.dense_h_to_4h(hidden_states)

        # Fused gelu and geglu receive the bias as a separate argument; every
        # other activation gets the bias added beforehand.
        fused_gelu = self.activation_type == "gelu" and self.bias_gelu_fusion
        if fused_gelu or self.activation_type == "geglu":
            activated = self.activation_func(expanded, expanded_bias)
        else:
            activated = self.activation_func(expanded + expanded_bias)

        # [s, b, h]
        projected, projected_bias = self.dense_4h_to_h(activated)
        return projected, projected_bias
class LLaMAParallelMLP(nn.Module):
    """Gated, bias-free MLP in the LLaMA style.

    Two column-parallel projections (``w1`` and ``w3``) expand the hidden
    state; the activation is applied to the ``w1`` branch, multiplied
    element-wise with the ``w3`` branch, and ``w2`` projects the product back
    to the hidden size. All three projections are created with ``bias=False``.

    The intermediate width is ``int(2 * hidden_size * 4 / 3)`` rounded up to
    the next multiple of ``multiple_of``.
    """

    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
        parallel_output=False,
        multiple_of=256,
    ):
        """Create the three projections.

        Args:
            neox_args: configuration object; ``activation`` and
                ``hidden_size`` are read here.
            init_method: initializer for ``w1`` and ``w3``.
            output_layer_init_method: initializer for ``w2``.
            parallel_output: forwarded to ``w2``.
            multiple_of: the intermediate width is rounded up to a multiple
                of this value.
        """
        super().__init__()

        self.activation_func = get_activation(neox_args)
        self.activation_type = neox_args.activation
        self.multiple_of = multiple_of

        # Round the raw width up to the next multiple of ``multiple_of``.
        raw_width = int(2 * neox_args.hidden_size * 4 / 3)
        num_chunks = (raw_width + multiple_of - 1) // multiple_of
        ff_dim = self.multiple_of * num_chunks

        # w1 and w3 are configured identically.
        expand_kwargs = dict(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=ff_dim,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            bias=False,
        )
        self.w1 = mpu.ColumnParallelLinear(**expand_kwargs)
        self.w3 = mpu.ColumnParallelLinear(**expand_kwargs)
        self.w2 = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=ff_dim,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            parallel_output=parallel_output,
            bias=False,
        )

    def forward(self, hidden_states):
        """Return whatever ``w2`` returns for ``act(w1(x)) * w3(x)``."""
        gate, _ = self.w1(hidden_states)
        value, _ = self.w3(hidden_states)
        gated = self.activation_func(gate) * value
        return self.w2(gated)
class ParallelLinear(nn.Module):
    """
    A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size
    """

    def __init__(
        self,
        neox_args,
        parallel_output=True,
        init_method=nn.init.xavier_normal_,
        is_last_layer=False,
    ):
        """Create the output projection.

        Args:
            neox_args: configuration object; ``output_layer_parallelism``,
                ``hidden_size`` and ``padded_vocab_size`` are read here.
            parallel_output: when True the output is left un-gathered
                (``gather_output`` is set to ``not parallel_output``).
            init_method: weight initializer for the projection.
            is_last_layer: forwarded as ``mup_rescale_parameters``.

        Raises:
            NotImplementedError: if ``neox_args.output_layer_parallelism`` is
                anything other than ``"column"``.
        """
        super().__init__()
        parallelism = neox_args.output_layer_parallelism
        if parallelism != "column":
            # Without this check ``final_linear`` would never be created and
            # the failure would only surface later as an AttributeError inside
            # forward(). The row-parallel variant is disabled upstream.
            raise NotImplementedError(
                "Output layer parallelism over the hidden dim is currently broken "
                "(https://github.com/EleutherAI/gpt-neox/issues/905). Please run "
                'with output_layer_parallelism = "column" until this issue is fixed.'
            )

        self.final_linear = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=neox_args.padded_vocab_size,
            bias=False,
            init_method=init_method,
            gather_output=not parallel_output,
            skip_bias_add=False,
            mup_rescale_parameters=is_last_layer,  # rescale params only called if neox_args.use_mup = True, despite it not being included here
        )

    def forward(self, hidden_states):
        """Return the result of ``final_linear`` applied to ``hidden_states``."""
        return self.final_linear(hidden_states)
class ParallelSelfAttention(nn.Module):
    """Parallel self-attention layer.

    ``forward`` receives hidden states laid out as [sq, b, h] (see the shape
    comments in the methods) and returns the result of the ``dense`` output
    projection together with its bias.

    The attention kernel is selected per layer from
    ``neox_args.attention_config[layer_number]``: ``"global"`` uses the masked
    softmax path (``attention``), ``"flash"`` uses ``flash_attention`` and any
    other value uses ``sparse_attention``.
    """

    def __init__(
        self,
        neox_args,
        attention_mask_func,
        init_method,
        output_layer_init_method,
        layer_number,
        rpe=None,
        rotary=False,
        use_cache=False,
        parallel_output=False,
    ):
        """Build the projections and the selected attention kernel.

        Args:
            neox_args: configuration object.
            attention_mask_func: masking callback handed to the fused softmax.
            init_method: initializer for the fused query/key/value projection.
            output_layer_init_method: initializer for the output projection.
            layer_number: index into ``neox_args.attention_config``; also used
                as the scaling coefficient when
                ``apply_query_key_layer_scaling`` is set.
            rpe: optional relative-position-bias callable taking
                ``(query_len, key_len)``.
            rotary: whether to create a rotary embedding.
            use_cache: when True, ``forward`` also returns the stacked
                key/value tensors.
            parallel_output: forwarded to the output projection.
        """
        super().__init__()

        self.fp16 = neox_args.precision == "fp16"
        self.bf16 = neox_args.precision == "bfloat16"
        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = layer_number

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            neox_args.hidden_size, neox_args.num_attention_heads
        )
        self.num_attention_heads_per_partition = mpu.divide(
            neox_args.num_attention_heads, world_size
        )
        self.pos_emb = neox_args.pos_emb

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=3 * neox_args.hidden_size,
            gather_output=False,
            init_method=init_method,
            bias=neox_args.use_bias_in_attn_linear,
        )

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = max(1, self.layer_number)
            self.norm_factor *= coeff

        # TODO
        # right now there's no way to correctly set use_mup here, possible options:
        # - refactor model init (hard)
        # - do this via another config argument, e.g. "mup_norm_factor" (probably easy)
        # - ignore, this never changed anything in my experiments
        #
        # if neox_args.use_mup:
        #    self.norm_factor = self.hidden_size_per_attention_head

        self.rpe = rpe

        if self.pos_emb == "alibi":
            self.alibi_embed = AliBi(
                neox_args.num_attention_heads,
                neox_args.model_parallel_size,
                mpu.get_model_parallel_rank(),
            )

        # TODO: this arg shouldn't need to be passed in - get from neox_args
        if rotary:
            if neox_args.rotary_pct == 1:
                # None means "rotate every dimension of the head".
                self.rotary_ndims = None
            else:
                assert neox_args.rotary_pct < 1
                self.rotary_ndims = int(
                    self.hidden_size_per_attention_head * neox_args.rotary_pct
                )
            dim = (
                self.rotary_ndims
                if self.rotary_ndims is not None
                else self.hidden_size_per_attention_head
            )
            self.rotary_emb = RotaryEmbedding(
                dim,
                base=neox_args.rotary_emb_base,
                max_seq_len=neox_args.seq_length,
                precision=neox_args.params_dtype,
            )
        else:
            self.rotary_emb = None

        self.attention_type = neox_args.attention_config[layer_number]
        self.use_flash_attention = self.attention_type == "flash"
        self.sparse = self.attention_type not in ("global", "flash")
        if self.sparse:
            self.sparse_attn = configure_sparse_attention(
                neox_args,
                self.attention_type,
                self.num_attention_heads_per_partition,
                mpu=mpu,
            )
        else:
            if self.use_flash_attention:
                # Imported lazily so the flash kernels are only required when
                # a layer actually uses them.
                from megatron.model.flash_attention import (
                    # flash_attn_unpadded_qkvpacked_func_cuda,
                    # flash_attn_unpadded_kvpacked_func_cuda,
                    # Change of function names going from flash attention 1 -> flash attention 2
                    flash_attn_varlen_qkvpacked_func,
                    flash_attn_varlen_kvpacked_func,
                    flash_attn_unpadded_unpacked_func_triton,
                )

                self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
                self.flash_qkv_fn = flash_attn_varlen_qkvpacked_func
                self.flash_kv_fn = flash_attn_varlen_kvpacked_func
            else:
                self.scale_mask_softmax = FusedScaleMaskSoftmax(
                    input_in_fp16=self.fp16,
                    input_in_bf16=self.bf16,
                    fusion_type=get_fusion_type(neox_args),
                    mask_func=self.attention_mask_func,
                    softmax_in_fp32=self.attention_softmax_in_fp32,
                    scale=coeff,
                )

            # Dropout. Note that for a single iteration, this layer will generate
            # different outputs on different number of parallel partitions but
            # on average it should not be partition dependent.
            self.dropout_p = neox_args.attention_dropout
            self.attention_dropout = nn.Dropout(self.dropout_p)

        # Output.
        self.dense = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            parallel_output=parallel_output,
            bias=neox_args.use_bias_in_attn_linear,
        )

    def attention(
        self, query_layer, key_layer, value_layer, layer_past, attention_mask
    ):
        """Masked-softmax attention.

        Args:
            query_layer: [sq, b, np, hn]
            key_layer: [sk, b, np, hn]
            value_layer: [sk, b, np, hn]
            layer_past: accepted for interface compatibility; not read here.
            attention_mask: mask passed to ``scale_mask_softmax``.

        Returns:
            Context layer of shape [b, np, sq, hn].
        """
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(
            output_size[2], output_size[0] * output_size[1], -1
        )
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # Allocated on the query's own device so that baddbmm always sees all
        # operands on one device, rather than assuming the current CUDA device.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the (uninitialized) preallocated tensor is ignored.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if self.use_cache:
            with torch.no_grad():
                attention_mask = attention_mask[
                    ..., : attention_scores.size(3), : attention_scores.size(3)
                ]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if exists(self.rpe):
            # query_layer / key_layer are [s, b * np, hn] at this point, so
            # dim 0 is the sequence length.
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
            attention_scores += rpe  # [1, np, sq, sk]

        if self.pos_emb == "alibi":
            attention_scores = self.alibi_embed(attention_scores)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        return context_layer

    def flash_attention(self, query_layer, key_layer, value_layer):
        """Attention through the flash-attention kernels.

        Inputs are [s, b, np, hn]; the result is [b, np, sq, hn]. Without
        alibi the packed kernels are used (key/value-packed when not training,
        fully packed q/k/v when training); with alibi the triton kernel is
        called with an explicit bias.
        """
        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        if self.pos_emb != "alibi":

            # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn]
            key_layer = key_layer.transpose(0, 1).reshape(
                output_size[0] * output_size[3], 1, output_size[1], -1
            )
            value_layer = value_layer.transpose(0, 1).reshape(
                output_size[0] * output_size[3], 1, output_size[1], -1
            )

            batch_size = output_size[0]
            max_seqlen_q = output_size[2]
            max_seqlen_k = output_size[3]

            # Cumulative sequence boundaries: every sequence in the batch has
            # the same length, so the offsets are multiples of that length.
            cu_seqlens_q = torch.arange(
                0,
                (batch_size + 1) * max_seqlen_q,
                step=max_seqlen_q,
                dtype=torch.int32,
                device=query_layer.device,
            )

            cu_seqlens_k = torch.arange(
                0,
                (batch_size + 1) * max_seqlen_k,
                step=max_seqlen_k,
                dtype=torch.int32,
                device=key_layer.device,
            )

            if not self.training:

                # [sq, b, np, hn] -> [b * sq, np, hn]
                query_layer = query_layer.transpose(0, 1).reshape(
                    output_size[0] * output_size[2], output_size[1], -1
                )

                # Combined k/v into [b * sk, 2, np, hn].
                kv = torch.cat([key_layer, value_layer], dim=1)

                output = self.flash_kv_fn(
                    query_layer,
                    kv,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    max_seqlen_q,
                    max_seqlen_k,
                    self.dropout_p if self.training else 0.0,
                    softmax_scale=None,
                    causal=True,
                )

            else:

                # [sq, b, np, hn] -> [b * sq, 1, np, hn]
                query_layer = query_layer.transpose(0, 1).reshape(
                    output_size[0] * output_size[2], 1, output_size[1], -1
                )

                # Combined q/k/v into [b * s, 3, np, hn].
                qkv = torch.cat([query_layer, key_layer, value_layer], dim=1)

                output = self.flash_qkv_fn(
                    qkv,
                    cu_seqlens_q,
                    max_seqlen_q,
                    self.dropout_p if self.training else 0.0,
                    softmax_scale=None,
                    causal=True,
                )

            # [b * sq, np, hn] -> [b, sq, np, hn]
            matmul_result = output.view(
                output_size[0], output_size[2], output.shape[1], output.shape[2]
            )
            # [b, sq, np, hn] -> [b, np, sq, hn]
            matmul_result = matmul_result.transpose(1, 2)

        else:
            # [sq, b, np, hn] -> [b, sq, np, hn]
            sq = query_layer.size(0)
            b = query_layer.size(1)
            sk = key_layer.size(0)

            query_layer = query_layer.transpose(0, 1)
            key_layer = key_layer.transpose(0, 1)
            value_layer = value_layer.transpose(0, 1)

            # Replicate the alibi bias across the batch dimension.
            bias = self.alibi_embed.bias(sq, sk, query_layer.device, query_layer.dtype)
            bias = bias.unsqueeze(0).tile((b, 1, 1, 1))

            matmul_result = self.flash_triton_fn(
                query_layer, key_layer, value_layer, bias=bias, causal=True
            )
            matmul_result = matmul_result.transpose(1, 2)

        return matmul_result

    def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
        """Attention through the configured sparse-attention module.

        Inputs are [sq, b, np, hn]; they are permuted to [b, np, s, hn] before
        being handed to ``sparse_attn``.
        """
        # TODO: sparse attn dropout?
        # TODO: pad to block size
        # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn]
        query_layer, key_layer, value_layer = map(
            lambda t: t.permute(1, 2, 0, 3).contiguous(),
            (query_layer, key_layer, value_layer),
        )
        # output shape [b, np(heads), sq, hn]
        attn_mask = attention_mask.to(query_layer.dtype) * -10000
        if exists(self.rpe):
            # After the permute above q/k are [b, np, s, hn]: the sequence
            # lengths are on dim 2 (dim 0 is the batch size, which is what
            # the relative position bias must NOT be sized by).
            rpe = self.rpe(query_layer.size(2), key_layer.size(2))
        else:
            rpe = None
        return self.sparse_attn(
            query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe
        )

    def forward(self, hidden_states, attention_mask, layer_past=None):
        """Run self-attention.

        Args:
            hidden_states: [sq, b, h]
            attention_mask: mask forwarded to the attention kernel (unused on
                the flash path).
            layer_past: optional cached (key, value) pair from earlier steps;
                when non-empty it is prepended along the sequence dimension.

        Returns:
            ``(output, bias)`` from the ``dense`` projection. When
            ``use_cache`` is set, ``output`` is the list
            ``[output, present]`` where ``present`` stacks the current key
            and value tensors.
        """
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
            mixed_x_layer, 3
        )

        if exists(self.rotary_emb):
            if exists(self.rotary_ndims):
                # partial rotary
                query_rot, query_pass = (
                    query_layer[..., : self.rotary_ndims],
                    query_layer[..., self.rotary_ndims :],
                )
                key_rot, key_pass = (
                    key_layer[..., : self.rotary_ndims],
                    key_layer[..., self.rotary_ndims :],
                )
            else:
                # full rotary
                query_rot, key_rot = query_layer, key_layer
            apply_rotary_fn = (
                apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
            )

            # Positions continue after whatever is already in the cache.
            seq_len = key_layer.shape[0]
            offset = 0
            if exists(layer_past) and layer_past.numel() > 0:
                offset = layer_past[0].shape[0]
                seq_len += offset
            cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
            query_layer, key_layer = apply_rotary_fn(
                query_rot, key_rot, cos, sin, offset=offset
            )

            if exists(self.rotary_ndims):
                # Re-attach the dimensions that were not rotated.
                query_layer = torch.cat((query_layer, query_pass), dim=-1)
                key_layer = torch.cat((key_layer, key_pass), dim=-1)

        # ==================================
        # Cache key and value for inference
        # ==================================

        if exists(layer_past) and layer_past.numel() > 0:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat(
                (past_value.type_as(value_layer), value_layer), dim=0
            )

        if self.use_cache:
            present = torch.stack((key_layer, value_layer))

        if self.use_flash_attention:
            context_layer = self.flash_attention(query_layer, key_layer, value_layer)
        elif not self.sparse:
            context_layer = self.attention(
                query_layer, key_layer, value_layer, layer_past, attention_mask
            )
        else:
            context_layer = self.sparse_attention(
                query_layer, key_layer, value_layer, attention_mask
            )

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if self.use_cache:
            output = [output, present]

        return output, bias
class ParallelTransformerLayer(nn.Module):
    """A single transformer layer.

    Combines a ``ParallelSelfAttention`` sub-layer and an MLP sub-layer with
    their layer norms and residual connections, and returns a tensor of the
    same size as its input. Two residual layouts are supported:

    * GPT-J style (``neox_args.gpt_j_residual``):
      ``x + attn(ln(x)) + mlp(ln(x))``, with one reduction over the summed
      attention and MLP outputs.
    * Sequential: ``x = x + attn(ln1(x)); x = x + mlp(ln2(x))``.
    """

    def __init__(
        self,
        neox_args,
        attention_mask_func,
        init_method,
        output_layer_init_method,
        layer_number,
        rpe=None,
        rotary=False,
        use_cache=False,
    ):
        """Build the norms, the attention sub-layer and the MLP.

        Raises:
            KeyError: if ``neox_args.mlp_type`` is neither ``"regular"`` nor
                ``"llama"``.
        """
        super().__init__()
        self.layer_number = layer_number

        norm, eps = get_norm(neox_args)

        # Layernorm on the input data.
        self.input_layernorm = norm(neox_args.hidden_size, eps=eps)
        self.use_cache = use_cache

        self.hidden_dropout = neox_args.hidden_dropout
        self.bias_dropout_fusion = neox_args.bias_dropout_fusion
        self.gpt_j_residual = neox_args.gpt_j_residual
        self.gpt_j_tied = neox_args.gpt_j_tied
        self.mlp_type = neox_args.mlp_type

        if self.gpt_j_residual:
            # Only the GPT-J layout performs the reduction itself (the
            # sub-layers are then built with parallel_output=True).
            self.reduce = mpu.mappings.reduce_from_model_parallel_region

        # Self attention.
        self.attention = ParallelSelfAttention(
            neox_args=neox_args,
            attention_mask_func=attention_mask_func,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            rpe=rpe,
            use_cache=self.use_cache,
            rotary=rotary,
            parallel_output=self.gpt_j_residual,
        )

        # Layernorm on the output of the attention layer.
        # If GPT-J residuals are used, this is superfluous but leaving it in
        # leads to cleaner code
        self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)

        # MLP
        if neox_args.mlp_type == "regular":
            self.mlp = ParallelMLP(
                neox_args=neox_args,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                parallel_output=self.gpt_j_residual,
            )
        elif neox_args.mlp_type == "llama":
            self.mlp = LLaMAParallelMLP(
                neox_args=neox_args,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                parallel_output=self.gpt_j_residual,
            )
        else:
            raise KeyError(neox_args.mlp_type)

        self.layer_past = None  # used to cache k/v pairs in inference

    def _get_bias_dropout(self):
        """Pick the bias+dropout+residual function for the current mode."""
        if self.bias_dropout_fusion:
            fn = (
                bias_dropout_add_fused_train
                if self.training
                else bias_dropout_add_fused_inference
            )
        else:
            fn = get_bias_dropout_add(self.training)
        return fn

    def forward(self, x, attention_mask, layer_past=None):
        """Apply the layer.

        Args:
            x: input hidden states.
            attention_mask: mask forwarded to the attention sub-layer.
            layer_past: optional cached key/value pair; defaults to the pair
                stored on the layer by a previous call with ``use_cache``.

        Returns:
            The layer output, same size as ``x``.
        """
        layer_past = layer_past if layer_past is not None else self.layer_past
        bias_dropout_fn = self._get_bias_dropout()
        # x: [b, s, h]
        # NOTE(review): the attention sub-layer documents its input as
        # [sq, b, h]; confirm which layout callers actually use.
        if self.gpt_j_residual:
            # pseudocode:
            # x = x + attn(ln(x)) + mlp(ln(x))
            # this means we can avoid doing the allreduce in the attn / mlp outputs
            # to save communication time (we can do a single allreduce after we add mlp / attn outputs).
            # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but
            # we preserve the functionality for backwards compatibility

            residual = x
            # applies the correct normalization depending on if the norms are tied
            if self.gpt_j_tied:
                x = self.input_layernorm(x)
                x1, x2 = x, x
            else:
                x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x)

            # attention operator
            attention_output, attention_bias = self.attention(
                x1, attention_mask, layer_past=layer_past
            )
            if self.use_cache:
                attention_output, presents = attention_output
                self.layer_past = presents

            with torch.enable_grad():
                if attention_bias is not None:
                    attention_output = bias_dropout_fn(
                        attention_output,
                        bias=attention_bias.expand_as(attention_output),
                        residual=None,
                        prob=self.hidden_dropout,
                    )
                else:
                    # Bias-free attention projection: there is nothing to
                    # expand, so just apply dropout (same handling as the
                    # sequential branch below, minus the residual).
                    attention_output = torch.nn.functional.dropout(
                        attention_output,
                        p=self.hidden_dropout,
                        training=self.training,
                    )

            # mlp operator
            mlp_output, mlp_bias = self.mlp(x2)

            with torch.enable_grad():
                if mlp_bias is not None:
                    output = bias_dropout_fn(
                        mlp_output,
                        bias=mlp_bias.expand_as(mlp_output),
                        residual=attention_output,
                        prob=self.hidden_dropout,
                    )
                else:
                    # Bias-free MLP (the llama MLP): mirror the sequential
                    # branch, which adds the MLP output without dropout.
                    output = mlp_output + attention_output

            # output = (x + attn(ln(x)) + mlp(ln(x))
            output = residual + self.reduce(output)
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))

            residual = x

            # x = x + attn(ln1(x))
            attention_output, attention_bias = self.attention(
                self.input_layernorm(x), attention_mask, layer_past=layer_past
            )
            if self.use_cache:
                attention_output, presents = attention_output
                self.layer_past = presents
            with torch.enable_grad():
                if attention_bias is not None:
                    # Use special bias_dropout_fn if we have a bias term from the above attention layer
                    attention_output = bias_dropout_fn(
                        attention_output,
                        bias=attention_bias.expand_as(residual),
                        residual=residual,
                        prob=self.hidden_dropout,
                    )
                else:
                    # Otherwise just apply dropout + residual
                    attention_output = (
                        torch.nn.functional.dropout(
                            attention_output,
                            p=self.hidden_dropout,
                            training=self.training,
                        )
                        + residual
                    )

            # output = x + mlp(ln2(x))
            mlp_output, mlp_bias = self.mlp(
                self.post_attention_layernorm(attention_output)
            )

            with torch.enable_grad():
                if self.mlp_type == "llama":
                    # No dropout either
                    assert mlp_bias is None
                    output = mlp_output + attention_output
                else:
                    output = bias_dropout_fn(
                        mlp_output,
                        bias=mlp_bias.expand_as(attention_output),
                        residual=attention_output,
                        prob=self.hidden_dropout,
                    )

        return output
class ParallelTransformerLayerPipe(ParallelTransformerLayer):
    """Pipeline-friendly wrapper around ParallelTransformerLayer.

    Takes a ``(hidden_states, attention_mask)`` pair and returns a pair again,
    so the attention mask travels on to the next pipeline stage.
    """

    def forward(self, args):
        """Run the layer on ``args[0]`` and return ``(output, attention_mask)``."""
        assert (
            len(args) == 2
        ), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask"
        hidden_states, attention_mask = args
        # we are returning just [hidden_states, mask]
        layer_output = super().forward(hidden_states, attention_mask)
        return layer_output, attention_mask
class ParallelLinearPipe(ParallelLinear):
    """Pipeline-friendly wrapper around ParallelLinear.

    Accepts a single tensor and returns only the logits, dropping the second
    element of the pair that ``ParallelLinear.forward`` returns.
    """

    def forward(self, args):
        """Return the logits for the hidden-state tensor ``args``."""
        assert isinstance(
            args, torch.Tensor
        ), "ParallelLinearPipe expects a single argument - hidden_states"
        projected = super().forward(args)
        # The parent returns a (logits, bias) pair; only the logits are kept.
        logits, _unused_bias = projected
        return logits
class NormPipe(nn.Module):
    """Wraps a normalization layer so it can be used as a pipeline stage.

    The wrapped norm is built as ``norm_class(hidden_size, eps=eps)`` and is
    applied to the single tensor the stage receives.
    """

    def __init__(self, norm_class, hidden_size, eps):
        super().__init__()
        self.norm = norm_class(hidden_size, eps=eps)

    def forward(self, args):
        """Normalize ``args``; a tuple input is rejected."""
        received_tuple = isinstance(args, tuple)
        assert (
            not received_tuple
        ), "NormPipe should only receive a single tensor as input"
        return self.norm(args)
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """Compute LM logits by multiplying with the word embedding weights.

    Args:
        input_: hidden states to project.
        word_embeddings_weight: weight matrix passed to ``F.linear``.
        parallel_output: when truthy the per-partition logits are returned as
            is; otherwise they are gathered across the model-parallel region.
        bias: optional bias passed to ``F.linear``.
    """
    # Parallel logits.
    replicated_input = mpu.copy_to_model_parallel_region(input_)

    # Matrix multiply; the bias is only passed along when one was given.
    linear_args = (replicated_input, word_embeddings_weight)
    if bias is not None:
        linear_args = linear_args + (bias,)
    logits_parallel = F.linear(*linear_args)

    # Gather if needed.
    if not parallel_output:
        return mpu.gather_from_model_parallel_region(logits_parallel)
    return logits_parallel