Hi,
For background: in an encoder-decoder translation model based on "Attention Is All You Need", I wanted to mask out the padding in the sentences passed to the encoder's MultiHeadAttention. However, I noticed that the -neginf value computed for the mask, which is based on the eltype of the logits, might cause issues and lead to NaN.
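To illustrate one way a negative-infinity fill value can produce NaN, here is a minimal plain-Julia sketch. It is not taken from Flux or NNlib: `naive_softmax` is a hypothetical stand-in, and the assumption that `true` in the mask means "attend" is mine. It only shows the general mechanism (a query whose keys are all masked) and may or may not be what happens in the linked example.

```julia
# Hypothetical stand-in for a softmax over the key dimension (dims = 1).
# This is NOT NNlib's implementation, just the textbook max-shifted form.
function naive_softmax(x::AbstractMatrix)
    shifted = x .- maximum(x; dims=1)   # -Inf - (-Inf) = NaN if a whole column is -Inf
    e = exp.(shifted)
    return e ./ sum(e; dims=1)
end

# Logits laid out as (kv_len = 3, q_len = 2).
logits = Float32[0.1 0.2; 0.3 0.4; 0.5 0.6]

# Assumption: `true` means "attend". Column 2 (second query) is fully masked.
mask = Bool[1 0; 1 0; 0 0]

# Fill value derived from the logits' eltype, here -Inf32.
neginf = typemin(eltype(logits))
probs = naive_softmax(ifelse.(mask, logits, neginf))

# Same masking with a large finite fill value instead (an illustrative choice).
probs_finite = naive_softmax(ifelse.(mask, logits, -1f9))

println(any(isnan, probs[:, 1]))    # partially masked column
println(all(isnan, probs[:, 2]))    # fully masked column
println(all(isfinite, probs_finite))
```

In this sketch the partially masked column stays finite, while the fully masked column becomes NaN because every entry is `-Inf`. A large finite fill value avoids the NaN here, but the fully masked query then attends uniformly, so whether that is acceptable depends on the model.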
Minimal example is provided here https://github.com/mashu/NaNTracker.jl
The result is:
caused by: DomainError with Float32[-7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; … ; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7;;; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; … ; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7;;; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; … ; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7;;; … ;;; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; … ; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7;;; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; … ; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7;;; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; -2.38419f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; … ; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7]:
NaN on gradient input for layer: KeyPath(:mha, :out_proj)
Stacktrace:
[1] (::Main.NaNTracker.var"#pb_check#2"{DebugWrapper{…}, Zygote.var"#ad_pullback#58"{…}})(Δ::Array{Float32, 3})
@ Main.NaNTracker ~/NaNTracker.jl/src/NaNTracker.jl:26
[2] ZBack
@ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
[3] #_#334
@ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:129 [inlined]
[4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[5] MultiHeadAttention
@ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:120 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Array{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[7] MultiHeadAttention
@ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:120 [inlined]
[8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[9] #_#332
@ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:115 [inlined]
[10] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Array{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[11] MultiHeadAttention
@ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:115 [inlined]
[12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[13] #_#5
@ ~/NaNTracker.jl/src/Example.jl:24 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float32, 3, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[15] EncoderOnly
@ ~/NaNTracker.jl/src/Example.jl:22 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float32, 3, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[17] #14
@ ~/NaNTracker.jl/src/Example.jl:47 [inlined]
[18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[19] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
[20] withgradient(f::Function, args::EncoderOnly)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:213
[21] testit()
@ Main ~/NaNTracker.jl/src/Example.jl:46
[22] with_logging(::Function)
@ Main.NaNTracker ~/NaNTracker.jl/src/NaNTracker.jl:36
[23] top-level scope
@ ~/NaNTracker.jl/src/Example.jl:50
I hope this is not a misunderstanding on my part of what the mask should look like, but to be honest, the Flux documentation could use a couple of examples for this particular use case, in addition to just make_causal_mask.
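As a starting point for such an example, here is a hedged sketch of a key-padding mask built in plain Julia. `padding_mask` is a hypothetical helper, not a Flux or NNlib function. It assumes that the mask is Boolean with `true` meaning "attend", and that it must be broadcastable to `(kv_len, q_len, nheads, batch_size)`; both assumptions should be checked against the installed Flux/NNlib version.

```julia
# Hypothetical helper (not part of Flux): build a key-padding mask from the
# unpadded length of each sequence in the batch.
# Assumptions: `true` = attend, and the mask should broadcast to
# (kv_len, q_len, nheads, batch_size).
function padding_mask(lengths::AbstractVector{<:Integer}, kv_len::Integer)
    positions = reshape(1:kv_len, kv_len, 1, 1, 1)   # key positions along dim 1
    lens = reshape(lengths, 1, 1, 1, :)              # one length per batch element
    return positions .<= lens                        # true where the key is a real token
end

# Two sequences padded to length 5: the first has 3 real tokens, the second 5.
mask = padding_mask([3, 5], 5)

println(size(mask))
println(vec(mask[:, 1, 1, 1]))
```

Because this mask only hides padded keys, every query still sees at least one real key as long as no sequence has length 0, which avoids a fully masked query. Under the assumptions above it would be passed as the `mask` keyword when calling the `MultiHeadAttention` layer.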