Description
Prerequisites
Please answer the following questions for yourself before submitting an issue.
- I am using the latest TensorFlow Model Garden release and TensorFlow 2.
- I am reporting the issue to the correct repository. (Model Garden official or research directory)
- I checked to make sure that this issue has not been filed already.
1. The entire URL of the file you are using
https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/position_embedding.py
2. Describe the bug
In models/official/nlp/modeling/layers/position_embedding.py (lines 161 to 174 at commit db50116), why is line 167 written as follows?
inv_timescales = min_timescale * tf.exp(
    tf.cast(tf.range(num_timescales), tf.float32) *
    -log_timescale_increment)
The problem is that we multiply by min_timescale instead of 1 / min_timescale to obtain inv_timescales.
As a result, the arguments passed to the sin/cos functions are:
pos * T_min * [T_r ** 0, T_r ** (-1/dim_range), T_r ** (-2/dim_range), ..., T_r ** (-dim_range/dim_range)]
where T_r = T_max / T_min and dim_range = num_timescales - 1.
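Inverting each entry makes the consequence explicit. Here ω_i denotes the i-th inverse timescale (the factor multiplying pos), τ_i = 1 / ω_i the corresponding timescale, and d = dim_range; the symbols ω and τ are introduced only for this derivation:

```latex
\begin{aligned}
\omega_i &= T_{\min}\, T_r^{-i/d}, &
\tau_i &= \frac{1}{\omega_i} = \frac{T_r^{\,i/d}}{T_{\min}}, \qquad i = 0, \dots, d, \\
\tau_0 &= \frac{1}{T_{\min}}, &
\tau_d &= \frac{T_r}{T_{\min}} = \frac{T_{\max}}{T_{\min}^{2}}.
\end{aligned}
```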
Notably, the largest timescale is T_max / (T_min * T_min) (inverse timescale T_min * T_min / T_max), and the smallest timescale is 1 / T_min (inverse timescale T_min), instead of the intended range [T_min, T_max].
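A minimal pure-Python sketch of the quoted expression can confirm these endpoints numerically. It uses math.exp in place of tf.exp, and the function name current_inv_timescales is hypothetical. The increment log(T_max / T_min) / (num_timescales - 1) is an assumption reconstructed from the formula above (it is what makes the exponents run over -i/dim_range), not copied from the file.

```python
import math


def current_inv_timescales(num_timescales, min_timescale, max_timescale):
    """Pure-Python mirror of the quoted inv_timescales expression (sketch).

    Assumes num_timescales >= 2 so that the increment is well defined.
    """
    # Assumed increment: log(T_max / T_min) / (num_timescales - 1).
    log_timescale_increment = (
        math.log(max_timescale / min_timescale) / (num_timescales - 1))
    # As in the quoted line 167, the result is scaled by min_timescale.
    return [min_timescale * math.exp(i * -log_timescale_increment)
            for i in range(num_timescales)]


inv = current_inv_timescales(num_timescales=4, min_timescale=2.0, max_timescale=1e4)
timescales = [1.0 / w for w in inv]
print(timescales[0], timescales[-1])  # approx. 1 / T_min and T_max / T_min**2
```

With T_min = 2.0 and T_max = 1e4, the timescales run from 0.5 up to roughly 2500 instead of spanning [2, 10000].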
3. Steps to reproduce
This is a math issue; no code execution is needed.
4. Expected behavior
I think the correct expression is:
pos / T_min * [T_r ** 0, T_r ** (-1/dim_range), T_r ** (-2/dim_range), ..., T_r ** (-dim_range/dim_range)]
so that the timescales range from T_min (inverse timescale 1 / T_min) to T_max (inverse timescale 1 / T_max).
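A minimal pure-Python sketch of the proposed expression, again with math.exp standing in for tf.exp. The function name proposed_inv_timescales is hypothetical, and the increment is the same assumed log(T_max / T_min) / (num_timescales - 1) implied by the formula above.

```python
import math


def proposed_inv_timescales(num_timescales, min_timescale, max_timescale):
    """Proposed variant (sketch): scale by 1 / min_timescale, not min_timescale.

    Assumes num_timescales >= 2 so that the increment is well defined.
    """
    log_timescale_increment = (
        math.log(max_timescale / min_timescale) / (num_timescales - 1))
    return [(1.0 / min_timescale) * math.exp(i * -log_timescale_increment)
            for i in range(num_timescales)]


inv = proposed_inv_timescales(num_timescales=4, min_timescale=2.0, max_timescale=1e4)
timescales = [1.0 / w for w in inv]
print(timescales[0], timescales[-1])  # should span [min_timescale, max_timescale]
```

In the quoted TensorFlow snippet, this reading of the proposal amounts to replacing the leading min_timescale * with (1.0 / min_timescale) *; it is a sketch, not a tested patch.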
The two implementations are identical when T_min = 1.0. The test function at
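Because the two formulas differ only in the scale factor (T_min versus 1 / T_min), they coincide exactly at T_min = 1.0 and diverge for any other value, so a regression check needs a min_timescale other than 1.0 to tell them apart. A sketch of such a check, with hypothetical helper names and the same assumed increment as above:

```python
import math


def inv_timescales(num_timescales, min_timescale, max_timescale, fixed):
    """Sketch of both variants: fixed=False mirrors line 167, fixed=True the proposal."""
    increment = math.log(max_timescale / min_timescale) / (num_timescales - 1)
    scale = 1.0 / min_timescale if fixed else min_timescale
    return [scale * math.exp(-i * increment) for i in range(num_timescales)]


def variants_agree(num_timescales, min_timescale, max_timescale):
    """True if the current and proposed inverse timescales match element-wise."""
    current = inv_timescales(num_timescales, min_timescale, max_timescale, fixed=False)
    proposed = inv_timescales(num_timescales, min_timescale, max_timescale, fixed=True)
    return all(math.isclose(c, p) for c, p in zip(current, proposed))


print(variants_agree(8, 1.0, 1e4))  # True: both scale factors equal 1.0
print(variants_agree(8, 2.0, 1e4))  # False: scale factors 2.0 vs 0.5
```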
5. Additional context
Not applicable; see section 3.
6. System information
Not applicable; see section 3.
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device name if the issue happens on a mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory: