 import tensorflow as tf
 import numpy as np

+from tensorflow_probability import distributions as tfpd
 from tensorflow.contrib.seq2seq import SampleEmbeddingHelper
 from texar.evals.bleu import sentence_bleu
 from rouge import Rouge
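Note on the new import: tf.distributions was deprecated in TensorFlow 1.x and removed in 2.x, so the Categorical distribution now comes from TensorFlow Probability. A minimal sketch of the replacement call, assuming TF 1.x graph mode with tensorflow_probability installed (the mixing weights below are made-up values, not the ones the helper actually uses):

    # Sketch only, not part of the diff: tfpd.Categorical is the drop-in
    # replacement for the removed tf.distributions.Categorical.
    import tensorflow as tf
    from tensorflow_probability import distributions as tfpd

    lambdas = tf.constant([0.4, 0.4, 0.2])             # made-up mixing weights
    sample_method_sampler = tfpd.Categorical(probs=lambdas)
    sample_method_id = sample_method_sampler.sample()   # scalar id in {0, 1, 2}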
@@ -95,19 +96,19 @@ def sample(self, time, outputs, state, name=None):
             of 'state'([decoded_ids, rnn_state])
         """
         sample_method_sampler = \
-            tf.distributions.Categorical(probs=self._lambdas)
+            tfpd.Categorical(probs=self._lambdas)
         sample_method_id = sample_method_sampler.sample()

         truth_feeding = lambda: tf.cond(
             tf.less(time, tf.shape(self._ground_truth)[1]),
-            lambda: tf.to_int32(self._ground_truth[:, time]),
+            lambda: tf.cast(self._ground_truth[:, time], tf.int32),
             lambda: tf.ones_like(self._ground_truth[:, 0],
                                  dtype=tf.int32) * self._vocab.eos_token_id)

-        self_feeding = lambda: SampleEmbeddingHelper.sample(
+        self_feeding = lambda: SampleEmbeddingHelper.sample(
             self, time, outputs, state, name)

-        reward_feeding = lambda: self._sample_by_reward(time, state)
+        reward_feeding = lambda: self._sample_by_reward(time, state)

         sample_ids = tf.cond(
             tf.logical_or(tf.equal(time, 0), tf.equal(sample_method_id, 1)),
@@ -207,9 +208,9 @@ def _get_rewards(time, prefix_ids, target_ids, ground_truth_length):

             return result

-        sampler = tf.distributions.Categorical(
+        sampler = tfpd.Categorical(
             logits=tf.py_func(_get_rewards, [
                 time, state[0], self._ground_truth,
                 self._ground_truth_length], tf.float32))
         return tf.reshape(
-            sampler.sample(), (tf.shape(self._ground_truth)[0],))
+            sampler.sample(), (tf.shape(self._ground_truth)[0],))
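For orientation, a minimal, self-contained sketch of the pattern used by _sample_by_reward above, assuming TF 1.x graph mode; _toy_rewards is a hypothetical stand-in for the BLEU/ROUGE-based _get_rewards in the diff. The per-token float rewards are passed to a Categorical as logits, so higher-reward tokens are sampled more often:

    # Hypothetical sketch; shapes and the reward function are illustrative only.
    import numpy as np
    import tensorflow as tf
    from tensorflow_probability import distributions as tfpd

    batch_size, vocab_size = 4, 1000

    def _toy_rewards(prefix_ids):
        # One float32 reward per (batch element, vocabulary item).
        return np.random.rand(prefix_ids.shape[0], vocab_size).astype(np.float32)

    prefix_ids = tf.zeros([batch_size, 7], dtype=tf.int32)
    rewards = tf.py_func(_toy_rewards, [prefix_ids], tf.float32)
    rewards.set_shape([None, vocab_size])        # py_func drops static shape info
    sampler = tfpd.Categorical(logits=rewards)   # one distribution per batch row
    sample_ids = tf.reshape(sampler.sample(), (tf.shape(prefix_ids)[0],))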