-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathmodeling.py
More file actions
1923 lines (1702 loc) · 67.8 KB
/
modeling.py
File metadata and controls
1923 lines (1702 loc) · 67.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DalleBart model. """
import math
from functools import partial
from typing import Any, Dict, Optional, Tuple
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from flax.core.frozen_dict import unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.linear import PrecisionLike
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import custom_jvp, lax
from jax.random import PRNGKey
from transformers.generation_flax_utils import FlaxSampleOutput
from transformers.modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
)
from transformers.modeling_flax_utils import ACT2FN
from transformers.models.bart.modeling_flax_bart import (
FlaxBartAttention,
FlaxBartForConditionalGeneration,
FlaxBartForConditionalGenerationModule,
FlaxBartModule,
)
from transformers.utils import logging
from .configuration import DalleBartConfig
from .utils import PretrainedFromWandbMixin
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
# Short alias for flax.linen.partitioning.remat (rematerialization, i.e.
# gradient checkpointing); used below to wrap the encoder/decoder layers when
# config.gradient_checkpointing is enabled.
remat = nn_partitioning.remat
def smelu(beta: Any = 1.0):
    """
    Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
    https://arxiv.org/abs/2202.06499

    Smooth ReLU (SmeLU) with half-width ``beta``::

        f(x) = 0                          if x <= -beta
        f(x) = (x + beta)**2 / (4 * beta) if -beta < x < beta
        f(x) = x                          if x >= beta

    Its derivative is 0, (x + beta) / (2 * beta) and 1 on the same three
    regions, so it is continuous at both boundaries.

    Args:
        beta: half-width of the quadratic transition region. Expected to be
            positive; ``beta == 0`` would divide by zero.

    Returns:
        A jitted element-wise activation function with a custom JVP rule.
    """

    @custom_jvp
    @jax.jit
    def _smelu(x: Any) -> Any:
        # Clipping into [-beta, beta] makes the quadratic branch exactly 0 for
        # x <= -beta (and keeps it finite for large |x|), so one `where`
        # covers all three regions.
        x_clipped = jnp.clip(x, -beta, beta)
        return jnp.where(x >= beta, x, jnp.square(x_clipped + beta) / (4 * beta))

    def _smelu_derivative(x: Any) -> Any:
        # d/dx of the piecewise definition above; same clipping trick.
        x_clipped = jnp.clip(x, -beta, beta)
        return jnp.where(x >= beta, 1.0, (x_clipped + beta) / (2 * beta))

    # Forward-mode rule: the incoming tangent is scaled by the local slope.
    _smelu.defjvps(lambda g, ans, x: g * _smelu_derivative(x))
    return _smelu
# Register SmeLU (default beta=1.0) under the key "smelu" in transformers'
# ACT2FN table, so config.activation_function == "smelu" resolves in GLU/FFN.
# Note: this mutates a dict imported from transformers at import time.
ACT2FN.update({"smelu": smelu()})
# deepnet initialization
def deepnet_init(gain=1):
    """
    Return a Glorot-normal (Xavier-normal) initializer whose output is scaled by ``gain``.

    All arguments given to the returned initializer are forwarded unchanged
    to ``jax.nn.initializers.glorot_normal()``.
    """
    glorot = jax.nn.initializers.glorot_normal()

    def scaled_glorot(*args, **kwargs):
        return glorot(*args, **kwargs) * gain

    return scaled_glorot
# deepnet gain
# Per-stack DeepNet constants as functions of the config, with
# N = config.encoder_layers and M = config.decoder_layers:
#   encoder: alpha = 0.81 * (N^4 * M)^(1/16), beta = 0.87 * (N^4 * M)^(-1/16)
#   decoder: alpha = (3 * M)^(1/4),           beta = (12 * M)^(-1/4)
# "alpha" multiplies residual branches; "beta" is the deepnet_init gain.
deepnet_gain = dict(
    encoder=dict(
        alpha=lambda config: 0.81
        * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
        beta=lambda config: 0.87
        * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
    ),
    decoder=dict(
        alpha=lambda config: (3 * config.decoder_layers) ** 0.25,
        beta=lambda config: (12 * config.decoder_layers) ** -0.25,
    ),
)
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467).

    Divides the input by the root-mean-square of its last axis (plus
    ``epsilon`` under the square root) and optionally multiplies by a learned
    per-feature ``scale``. Unlike LayerNorm, no mean is subtracted and no
    bias is added. Statistics are computed in at least float32; the result
    is cast to ``dtype``.
    Adapted from flax.linen.LayerNorm.
    """

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32
    use_scale: bool = True
    scale_init: Any = jax.nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        last_axis = (-1,)
        mean_square = self._compute_rms_sq(x, last_axis)
        return self._normalize(
            self,
            x,
            mean_square,
            last_axis,
            last_axis,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_scale,
            self.scale_init,
        )

    def _compute_rms_sq(self, x, axes):
        """Mean of ``x**2`` over ``axes``, computed in a dtype of at least float32."""
        stats_dtype = jnp.promote_types(jnp.float32, jnp.result_type(x))
        upcast = jnp.asarray(x, stats_dtype)
        return jnp.mean(jax.lax.square(upcast), axes)

    def _normalize(
        self,
        mdl,
        x,
        rms_sq,
        reduction_axes,
        feature_axes,
        dtype,
        param_dtype,
        epsilon,
        use_scale,
        scale_init,
    ):
        """Scale ``x`` by ``rsqrt(rms_sq + epsilon)`` and an optional ``scale`` param owned by ``mdl``."""
        reduce_set = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
        feature_set = nn.normalization._canonicalize_axes(x.ndim, feature_axes)

        # keep reduced axes as size-1 dims so the statistic broadcasts over x
        stats_shape = [1 if i in reduce_set else n for i, n in enumerate(x.shape)]
        # scale broadcasts along every non-feature axis
        feature_shape = [n if i in feature_set else 1 for i, n in enumerate(x.shape)]
        reduced_feature_shape = [x.shape[ax] for ax in feature_set]

        inv_rms = lax.rsqrt(rms_sq.reshape(stats_shape) + epsilon)
        if use_scale:
            scale = mdl.param("scale", scale_init, reduced_feature_shape, param_dtype)
            inv_rms = inv_rms * scale.reshape(feature_shape)
        return jnp.asarray(inv_rms * x, dtype)
def norm(type, *args, **kwargs):
    """
    Build a normalization module by name.

    Args:
        type: ``"rmsnorm"`` (RMSNorm) or ``"layernorm"`` (flax ``nn.LayerNorm``).
        *args, **kwargs: forwarded unchanged to the selected module's constructor.

    Raises:
        ValueError: if ``type`` is not one of the names above.
    """
    if type == "rmsnorm":
        norm_cls = RMSNorm
    elif type == "layernorm":
        norm_cls = nn.LayerNorm
    else:
        raise ValueError(f"Unknown norm type {type}")
    return norm_cls(*args, **kwargs)
def dot_product_attention_weights(
    query: Any,
    key: Any,
    bias: Optional[Any] = None,
    mask: Optional[Any] = None,
    embed_pos: Optional[Any] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Any = jnp.float32,
    precision: PrecisionLike = None,
    sinkhorn_iters: int = 1,
    is_encoder: bool = False,
):
    """
    Compute attention probabilities from ``query`` and ``key``.

    Adapted from ``flax.linen.attention.dot_product_attention_weights``.
    Masking is expected to be folded into ``bias`` already. ``mask`` itself
    is only consulted by the Sinkhorn branch, where it is re-applied after
    every normalization step.

    Args:
        query: array laid out as ``(..., q_length, num_heads, depth)``.
        key: array laid out as ``(..., kv_length, num_heads, depth)``.
        bias: optional additive term broadcastable to the logits.
        mask: optional boolean mask, used only for Sinkhorn normalization.
        embed_pos: optional additive relative-position bias.
        broadcast_dropout: share one dropout mask across batch and head dims.
        dropout_rng: PRNG key used when dropout is active.
        dropout_rate: probability of dropping a weight.
        deterministic: disables dropout when True.
        dtype: dtype of the returned weights.
        precision: einsum precision.
        sinkhorn_iters: alternating row/column normalizations. Only used
            when ``is_encoder`` is set and the value is not 1; otherwise a
            softmax is used.
        is_encoder: whether the caller is an encoder (non-causal) layer.

    Returns:
        Weights shaped ``(..., num_heads, q_length, kv_length)``.
    """
    assert query.ndim == key.ndim, "q, k must have same rank."
    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # scaled dot product: (..., q, h, d) x (..., k, h, d) -> (..., h, q, k)
    scale = jnp.sqrt(query.shape[-1]).astype(dtype)
    logits = jnp.einsum(
        "...qhd,...khd->...hqk", query / scale, key, precision=precision
    )

    # additive terms, in order: masking/attention bias, then relative positions
    for additive in (bias, embed_pos):
        if additive is not None:
            logits = logits + additive

    use_sinkhorn = is_encoder and sinkhorn_iters != 1
    if not use_sinkhorn:
        # sinkhorn does not work for causal (leaks info of future tokens into past)
        weights = jax.nn.softmax(logits).astype(dtype)
    else:
        # adapted from https://github.com/lucidrains/sinkhorn-transformer
        # even steps normalize over keys (last axis), odd steps over queries
        for step in range(sinkhorn_iters):
            axis = -1 if step % 2 == 0 else -2
            logits = logits - jax.nn.logsumexp(logits, axis=axis, keepdims=True)
            if mask is not None:
                logits = jnp.where(mask, logits, -jnp.inf)
        weights = jnp.exp(logits).astype(dtype)

    if deterministic or not dropout_rate > 0.0:
        return weights

    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
        # dropout is broadcast across the batch + head dimensions
        keep_shape = (1,) * (key.ndim - 2) + weights.shape[-2:]
    else:
        keep_shape = weights.shape
    keep = jax.random.bernoulli(dropout_rng, keep_prob, keep_shape)
    multiplier = keep.astype(weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
    return weights * multiplier
class FlaxBartAttention(FlaxBartAttention):
    """
    Edits:
    - causal mask is used only in decoder and considers image_length
    - scale attention heads per NormFormer paper

    Multi-head attention that subclasses (and shadows the name of) the
    transformers ``FlaxBartAttention`` imported above. Optional features are
    switched on through ``config``:
    - DeepNet initialization (``use_deepnet_scaling``)
    - per-head output scaling (``use_head_scale``)
    - cosine attention with a learned ``tau`` (``use_cosine_attention``)
    - learned relative-position biases (``use_swin_position_embeddings``)
    - Sinkhorn normalization for encoder layers (``sinkhorn_iters``)

    NOTE(review): attributes used below but not declared here (``config``,
    ``embed_dim``, ``num_heads``, ``dropout``, ``causal``, ``bias``,
    ``dtype``) and the helpers ``_split_heads``, ``_merge_heads`` and
    ``_concatenate_to_cache`` are presumably inherited from the transformers
    base class. Confirm against the installed transformers version.
    """

    # True for encoder attention: selects the encoder DeepNet gain and allows
    # Sinkhorn normalization in dot_product_attention_weights.
    is_encoder: bool = False
    # Query / key lengths that size the relative-position bias table (read
    # only when config.use_swin_position_embeddings is set).
    q_length: Optional[int] = None
    k_length: Optional[int] = None

    def setup(self) -> None:
        """Create the projections, optional per-head parameters and the causal mask."""
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # All projections share features / bias / dtype; only kernel_init differs.
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
        )

        # DeepNet "beta" gain: used for v_proj and out_proj only, while q_proj
        # and k_proj use an unscaled deepnet_init().
        gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
            self.config
        )

        self.q_proj = dense(
            kernel_init=deepnet_init()
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.k_proj = dense(
            kernel_init=deepnet_init()
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.v_proj = dense(
            kernel_init=deepnet_init(gain)
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.out_proj = dense(
            kernel_init=deepnet_init(gain)
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        # NOTE(review): dropout_layer is not used in __call__; attention dropout
        # is applied inside dot_product_attention_weights instead.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.config.use_head_scale:
            # per-head multiplicative scale on the attention output (NormFormer)
            self.head_scale = self.param(
                "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
            )

        if self.config.use_cosine_attention:
            # per-head temperature, initialized to config.tau_init
            self.tau = self.param(
                "tau",
                jax.nn.initializers.constant(self.config.tau_init),
                (1, self.num_heads, 1, 1),
            )

        if self.config.use_swin_position_embeddings:
            # one table row per query position, holding k_length * num_heads biases
            self.rel_bias = nn.Embed(
                self.q_length,
                self.k_length * self.num_heads,
                embedding_init=deepnet_init()
                if self.config.use_deepnet_scaling
                else jax.nn.initializers.normal(self.config.init_std),
            )

        if self.causal:
            # used only in decoder
            # Boolean causal mask sized by config.image_length; sliced per call
            # in __call__ to the current query/key lengths.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
            )

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: source of the queries (and of keys/values for self-attention).
            key_value_states: if given, keys/values are projected from it (cross-attention).
            attention_mask: optional mask; entries > 0 are attended, the rest
                get a -inf bias.
            init_cache: for causal layers, create/update the decoding cache.
            deterministic: when True, no dropout RNG is drawn and attention
                dropout is skipped.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # NOTE(review): presumably (batch, length, embed_dim) ->
        # (batch, length, num_heads, head_dim), matching the "...qhd" einsum layout below.
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # cached decoding: take mask rows starting at the current cache
                # index, spanning the full cached key length
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask,
                    (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length),
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(
                causal_mask, (batch_size,) + causal_mask.shape[1:]
            )

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias:
            # 0 where attention_mask > 0, -inf elsewhere
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.config.use_cosine_attention:
            # normalize q and k
            # (unit L2 norm over the last axis; 1e-8 avoids division by zero)
            query_states = query_states / (
                jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
            )
            key_states = key_states / (
                jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
            )

        # relative position embeddings
        if self.config.use_swin_position_embeddings:
            # NOTE(review): embed_pos always spans self.q_length query rows, but it
            # is added to logits whose query dim is query_states.shape[1]. These
            # only line up when the full sequence is processed at once; confirm
            # this option is never combined with cached (one-step) decoding.
            position_ids = jnp.arange(self.q_length)
            embed_pos = self.rel_bias(position_ids)
            embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
        else:
            embed_pos = None

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            mask=attention_mask,
            embed_pos=embed_pos,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
            sinkhorn_iters=self.config.sinkhorn_iters,
            is_encoder=self.is_encoder,
        )
        if self.config.use_cosine_attention:
            # divide by tau
            # NOTE(review): this runs AFTER the softmax above. It therefore only
            # rescales each head's output, and the returned weights no longer sum
            # to 1. Cosine attention is usually described with the temperature
            # applied to the logits before softmax; confirm which is intended.
            attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        if self.config.use_head_scale:
            # per Normformer
            attn_output = attn_output * self.head_scale
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
class GLU(nn.Module):
    """From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202

    Gated feed-forward block:
    ``Dense_out(dropout(act(Dense_gate(x)) * Dense_lin(x)))``.
    Norms are inserted depending on ``config.ln_positions``, and dropout is
    applied after the output projection.
    """

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        cfg = self.config
        gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](cfg)
        kernel_init = (
            deepnet_init(gain)
            if cfg.use_deepnet_scaling
            else jax.nn.initializers.normal(cfg.init_std)
        )
        # Factories for this block's layers. Submodules are still created inline,
        # in the same order as before: norm?, Dense, Dense, norm?, Dropout,
        # Dense, norm?, Dropout.
        dense = partial(
            nn.Dense, dtype=self.dtype, use_bias=cfg.use_bias, kernel_init=kernel_init
        )
        configured_norm = partial(
            norm,
            cfg.ln_type,
            dtype=self.dtype,
            epsilon=1e-05,
            use_scale=cfg.force_ln_scale,
        )

        if cfg.ln_positions in ["normformer", "cogview", "preln"]:
            x = configured_norm()(x)

        gate = ACT2FN[cfg.activation_function](dense(self.ffn_dim)(x))
        linear = dense(self.ffn_dim)(x)
        x = gate * linear

        if cfg.ln_positions in ["normformer"]:
            x = configured_norm()(x)
        x = nn.Dropout(rate=cfg.activation_dropout)(x, deterministic=deterministic)

        x = dense(self.embed_dim)(x)
        if cfg.ln_positions in ["swinv2", "cogview"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        return nn.Dropout(rate=cfg.dropout)(x, deterministic=deterministic)
class FFN(nn.Module):
    """Simple FFN layer

    ``Dense_out(dropout(act(Dense_in(x))))``. Norms are inserted depending on
    ``config.ln_positions``, and dropout is applied after the output projection.
    """

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        cfg = self.config
        gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](cfg)
        kernel_init = (
            deepnet_init(gain)
            if cfg.use_deepnet_scaling
            else jax.nn.initializers.normal(cfg.init_std)
        )
        # Factories for this block's layers. Submodules are still created inline,
        # in the same order as before: norm?, Dense, norm?, Dropout, Dense,
        # norm?, Dropout.
        dense = partial(
            nn.Dense, dtype=self.dtype, use_bias=cfg.use_bias, kernel_init=kernel_init
        )
        configured_norm = partial(
            norm,
            cfg.ln_type,
            dtype=self.dtype,
            epsilon=1e-05,
            use_scale=cfg.force_ln_scale,
        )

        if cfg.ln_positions in ["normformer", "cogview", "preln"]:
            x = configured_norm()(x)

        x = ACT2FN[cfg.activation_function](dense(self.ffn_dim)(x))

        if cfg.ln_positions in ["normformer"]:
            x = configured_norm()(x)
        x = nn.Dropout(rate=cfg.activation_dropout)(x, deterministic=deterministic)

        x = dense(self.embed_dim)(x)
        if cfg.ln_positions in ["swinv2", "cogview"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        return nn.Dropout(rate=cfg.dropout)(x, deterministic=deterministic)
class FlaxBartEncoderLayer(nn.Module):
    """
    Edits:
    - no bias
    - use custom FlaxBartAttention

    One encoder block: self-attention followed by a feed-forward block (GLU or
    FFN), each in a residual connection. The residual is multiplied by the
    DeepNet encoder "alpha" gain when ``config.use_deepnet_scaling`` is set.
    Norm placement follows ``config.ln_positions``. ``deterministic`` is
    forwarded to the attention module so that ``config.attention_dropout``
    is applied during training.
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    # append a final norm to the block (chosen by the layer collection)
    add_norm: bool = False
    # whether that final norm has a learned scale (config.force_ln_scale also forces it)
    use_scale: bool = True

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Returns:
            ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is set. Under ``config.use_scan`` that tuple
            is returned as ``(outputs, None)``.
        """
        if self.config.use_scan:
            # under nn.scan the incoming carry is a 1-tuple; unwrap it
            hidden_states = hidden_states[0]

        res_gain = (
            deepnet_gain["encoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states
        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=True,
            q_length=self.config.max_text_length,
            k_length=self.config.max_text_length,
        )(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            # without this, attention always ran with its default
            # deterministic=True and attention_dropout was never applied
            deterministic=deterministic,
        )

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        # feed-forward sub-block
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if self.config.use_scan:
            # nn.scan expects (carry, per-step output)
            outputs = (outputs, None)

        return outputs
class FlaxBartDecoderLayer(nn.Module):
    """
    Edits:
    - no bias
    - use custom FlaxBartAttention

    One decoder block, with each stage in a residual connection:
    - causal self-attention
    - optional cross-attention over ``encoder_hidden_states``
    - a feed-forward block (GLU or FFN)

    The residual is multiplied by the DeepNet decoder "alpha" gain when
    ``config.use_deepnet_scaling`` is set. Norm placement follows
    ``config.ln_positions``. ``deterministic`` is forwarded to both attention
    modules so that ``config.attention_dropout`` is applied during training.
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    # append a final norm to the block (chosen by the layer collection)
    add_norm: bool = False
    # whether that final norm has a learned scale (config.force_ln_scale also forces it)
    use_scale: bool = True

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Returns:
            ``(hidden_states,)``, or ``(hidden_states, self_attn_weights,
            cross_attn_weights)`` when ``output_attentions`` is set.
            ``cross_attn_weights`` is None without encoder states. Under
            ``config.use_scan`` that tuple is returned as ``(outputs, None)``.
        """
        if self.config.use_scan:
            # under nn.scan the incoming carry is a 1-tuple; unwrap it
            hidden_states = hidden_states[0]

        res_gain = (
            deepnet_gain["decoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states

        # Self Attention
        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=False,
            q_length=self.config.image_length,
            k_length=self.config.image_length,
        )(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            init_cache=init_cache,
            # without this, attention always ran with its default
            # deterministic=True and attention_dropout was never applied
            deterministic=deterministic,
        )

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        # Cross Attention
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            if self.config.ln_positions in ["normformer", "cogview", "preln"]:
                hidden_states = norm(
                    self.config.ln_type,
                    dtype=self.dtype,
                    epsilon=1e-05,
                    use_scale=self.config.force_ln_scale,
                )(hidden_states)
            hidden_states, cross_attn_weights = FlaxBartAttention(
                config=self.config,
                embed_dim=embed_dim,
                num_heads=self.config.decoder_attention_heads,
                dropout=self.config.attention_dropout,
                bias=self.config.use_bias,
                dtype=self.dtype,
                is_encoder=False,
                q_length=self.config.image_length,
                k_length=self.config.max_text_length,
            )(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                deterministic=deterministic,
            )
            if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)
            hidden_states = nn.Dropout(rate=self.config.dropout)(
                hidden_states, deterministic=deterministic
            )
            hidden_states = residual * res_gain + hidden_states
            if self.config.ln_positions in ["postln"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)

        # Feed forward
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights, cross_attn_weights)

        if self.config.use_scan:
            # nn.scan expects (carry, per-step output)
            outputs = (outputs, None)

        return outputs
class FlaxBartEncoderLayerCollection(nn.Module):
    """
    Edits:
    - use custom FlaxBartEncoderLayer
    - allow Gradient Checkpointing (nn.remat)

    Stack of ``config.encoder_layers`` encoder layers. The layers are applied
    either with ``nn.scan`` (``config.use_scan``, which needs identical
    layers and no per-layer outputs) or as an unrolled Python loop.
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run all encoder layers.

        Returns:
            ``FlaxBaseModelOutput`` when ``return_dict`` is set; otherwise a
            tuple of the non-None values among (last hidden state, all hidden
            states, all self-attention weights).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        n_layers = self.config.encoder_layers
        layer = (
            remat(
                FlaxBartEncoderLayer,
                # output_attentions and deterministic are Python flags
                static_argnums=(2, 3),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxBartEncoderLayer
        )

        if self.config.use_scan:
            # all blocks are the same so we use nn.scan
            assert not output_attentions, "cannot scan with output_attentions"
            assert not output_hidden_states, "cannot scan with output_hidden_states"
            hidden_states = (hidden_states,)
            # we use a scale on all norms (even last layer) to allow scanning
            # NOTE(review): only the "postln" final norm is added here. The
            # every-6-layers "swinv2" norms of the unrolled path below are not
            # reproduced under scan.
            hidden_states, _ = nn.scan(
                layer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=n_layers,
            )(
                self.config,
                dtype=self.dtype,
                add_norm=self.config.ln_positions == "postln",
                name="FlaxBartEncoderLayers",
            )(
                hidden_states,
                attention_mask,
                output_attentions,
                deterministic,
            )
            hidden_states = hidden_states[0]
        else:
            for i in range(n_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                # final layernorm on the output of the last layer
                # or every 6 layers for Swin v2
                add_norm = self.config.ln_positions == "postln" or (
                    self.config.ln_positions == "swinv2"
                    and ((i + 1) % 6 == 0)
                    and (i != n_layers - 1)
                )
                # we don't need to scale the norm for the last layer
                use_scale = i != n_layers - 1
                layer_outputs = layer(
                    self.config,
                    dtype=self.dtype,
                    add_norm=add_norm,
                    use_scale=use_scale,
                    name=f"FlaxBartEncoderLayer_{i}",
                )(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # add hidden states from the last layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = [
            hidden_states,
            all_hidden_states,
            all_self_attns,
        ]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class FlaxBartDecoderLayerCollection(nn.Module):
config: DalleBartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
"""
Edits:
- use custom FlaxBartDecoderLayer
- allow Gradient Checkpointing (nn.remat)
"""
@nn.compact
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = (
() if (output_attentions and encoder_hidden_states is not None) else None
)
n_layers = self.config.decoder_layers
layer = (
remat(
FlaxBartDecoderLayer,
static_argnums=(4, 5, 6),
prevent_cse=not self.config.use_scan,
)
if self.config.gradient_checkpointing
else FlaxBartDecoderLayer
)
if self.config.use_scan:
# all blocks are the same so we use nn.scan
assert not output_attentions, "cannot scan with output_attentions"
assert not output_hidden_states, "cannot scan with output_hidden_states"
hidden_states = (hidden_states,)
# we use a scale on all norms (even last layer) to allow scanning
hidden_states, _ = nn.scan(
layer,
variable_axes={"params": 0, "cache": 0},
split_rngs={"params": True, "dropout": True},