discriminator.py
import os
import importlib

import tensorflow as tf
import texar as tx

from utils import model_utils

class Discriminator(object):
    """GPT-2 language-model discriminator built from Texar modules."""

    def __init__(self, gpt2_config):
        vocab_size = gpt2_config.vocab_size

        self.word_embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size,
            hparams=gpt2_config.embed)

        self.pos_embedder = tx.modules.PositionEmbedder(
            position_size=gpt2_config.position_size,
            hparams=gpt2_config.pos_embed)

        # Ties the output (softmax) layer with the input word embedding.
        output_layer = tf.transpose(self.word_embedder.embedding, (1, 0))

        self.decoder = tx.modules.TransformerDecoder(
            vocab_size=vocab_size,
            output_layer=output_layer,
            hparams=gpt2_config.decoder)

    def init_model(self, sess, ckpt_path):
        """Restores pre-trained GPT-2 weights from `ckpt_path` into `sess`."""
        tf.logging.info('Discriminator, restore from {}'.format(ckpt_path))
        model_utils.init_gpt2_checkpoint(sess, ckpt_path)
        print("\nFinished loading\n")

    def compute_loss(self, soft_ids, length):
        """Computes the language-model cross-entropy of `soft_ids`.

        Args:
            soft_ids: Soft token distributions of shape
                `[batch_size, max_time, vocab_size]`.
            length: Sequence lengths of shape `[batch_size]`.

        Returns:
            A scalar loss, averaged across time steps and the batch.
        """
        batch_size = tf.shape(soft_ids)[0]
        seq_len = tf.fill([batch_size], tf.shape(soft_ids)[1])
        pos_embeds = self.pos_embedder(sequence_length=seq_len)
        # pos_embeds = tf.stop_gradient(pos_embeds)
        input_embeds = self.word_embedder(
            soft_ids=soft_ids, stop_gradient=False) + pos_embeds

        outputs = self.decoder(
            inputs=input_embeds, decoding_strategy='train_greedy')

        # Shift by one position so the logits at step t are scored against
        # the (soft) token at step t + 1.
        loss = tx.losses.sequence_softmax_cross_entropy(
            labels=soft_ids[:, 1:],
            logits=outputs.logits[:, :-1, :],
            sequence_length=length - 1,
            average_across_timesteps=True,
            sum_over_timesteps=False,
            average_across_batch=True,
            sum_over_batch=False,
            stop_gradient_to_label=False)  # TODO
        return loss
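

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes a
# GPT-2 config module exposing the attributes used above (`vocab_size`,
# `embed`, `position_size`, `pos_embed`, `decoder`); the module name
# 'configs.config_model' and the checkpoint path below are hypothetical.
# The block only runs when this file is executed directly.
if __name__ == '__main__':
    gpt2_config = importlib.import_module('configs.config_model')
    vocab_size = gpt2_config.vocab_size

    # Soft token distributions, e.g. a generator's softmax outputs:
    # shape [batch_size, max_time, vocab_size].
    soft_ids = tf.placeholder(tf.float32, [None, None, vocab_size])
    length = tf.placeholder(tf.int32, [None])

    discriminator = Discriminator(gpt2_config)
    loss = discriminator.compute_loss(soft_ids, length)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        discriminator.init_model(sess, ckpt_path='gpt2_model/model.ckpt')
        # loss_val = sess.run(loss, feed_dict={soft_ids: ..., length: ...})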