model.py
import tensorflow as tf
from attention_blocks import *
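
# Building blocks pulled in through the wildcard import above. The call
# signatures below are as used in this file; the implementations live in the
# attention_blocks module, and the parameter names are descriptive labels,
# not the module's own:
#   multihead_attention(queries, keys, values)
#   residual_connection(inputs, sublayer_output)
#   layer_norm(inputs)
#   positionwise_feedforward(inputs)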


def build_model(d_model):
    N = 6  # Number of stacked encoder/decoder layers

    # TODO: Add placeholder or Input layer
    x = tf.placeholder(tf.float32, shape=(10, 13, 512))  # (batch=10, source_len=13, d_model=512)
    # TODO: Add positional encoding (a sinusoidal sketch is given after this function)

    # Encoder (stacked N times)
    for i in range(N):
        # Self-attention sub-layer, followed by a residual connection and layer norm
        multihead_out = multihead_attention(x, x, x)
        resid1_out = residual_connection(x, multihead_out)
        norm_out = layer_norm(resid1_out)
        # Position-wise feed-forward sub-layer, followed by a residual connection and layer norm
        ppff_out = positionwise_feedforward(norm_out)
        resid2_out = residual_connection(norm_out, ppff_out)
        x = layer_norm(resid2_out)

    # TODO: Add output (output so far) layer
    y = tf.placeholder(tf.float32, shape=(10, 2, 512))  # (batch=10, target_len=2, d_model=512)

    # Decoder (stacked N times)
    for i in range(N):
        # Self-attention over the decoder input, with residual connection and layer norm
        multihead_out = multihead_attention(y, y, y)
        resid1_out = residual_connection(y, multihead_out)
        norm1_out = layer_norm(resid1_out)
        # Encoder-decoder attention: use x (the encoder output) as keys and values
        enc_dec_attn = multihead_attention(norm1_out, x, x)
        resid2_out = residual_connection(norm1_out, enc_dec_attn)
        norm2_out = layer_norm(resid2_out)
        # Position-wise feed-forward sub-layer, with residual connection and layer norm
        ppff_out = positionwise_feedforward(norm2_out)
        resid3_out = residual_connection(norm2_out, ppff_out)
        y = layer_norm(resid3_out)

    # Final linear projection of the decoder output
    y_transform = tf.layers.dense(y, d_model)
    return y_transform
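

# Illustrative sketch (assumed, not provided by attention_blocks): the
# sinusoidal positional encoding from "Attention Is All You Need", which the
# TODO inside build_model leaves unimplemented. The function name is made up
# for illustration; it assumes inputs shaped (batch, seq_len, d_model) and
# the result would be added to the placeholders before the encoder/decoder
# stacks.
def sinusoidal_positional_encoding(seq_len, d_model):
    import numpy as np  # local import so the sketch stays self-contained
    positions = np.arange(seq_len)[:, None]      # (seq_len, 1)
    dims = np.arange(d_model)[None, :]           # (1, d_model)
    # Dimension pair (2k, 2k+1) shares the frequency 1 / 10000^(2k / d_model)
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / np.float32(d_model))
    angles = positions * angle_rates             # (seq_len, d_model)
    encoding = np.zeros((seq_len, d_model), dtype=np.float32)
    encoding[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions: sine
    encoding[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions: cosine
    # Leading axis of size 1 so the constant broadcasts across the batch
    return tf.constant(encoding[None, :, :])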


def main():
    build_model(512)


if __name__ == "__main__":
    main()
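

# Usage sketch (illustrative; assumes build_model returns the final dense
# projection y_transform, as above):
#
#     output = build_model(512)
#     print(output.shape)   # (10, 2, 512) -> (batch, target_len, d_model)
#
# Actually running the graph would additionally need a tf.Session and
# feed_dict entries for the x and y placeholders created in build_model.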