@@ -867,18 +867,18 @@ def dict_pack(example):
867
867
868
868
def _standardize (self , dataset , keys ):
869
869
"""Force dataset structure into a tuple of Tensors."""
870
- shapes = tf .compat . v1 . data .get_output_shapes (dataset )
870
+ shapes = tf .data .get_output_shapes (dataset )
871
871
872
872
if isinstance (shapes , dict ):
873
873
keys = keys or tuple (shapes .keys ())
874
874
dataset = dataset .map (lambda x : tuple (x [k ] for k in keys ))
875
- shapes = tf .compat . v1 . data .get_output_shapes (dataset )
875
+ shapes = tf .data .get_output_shapes (dataset )
876
876
877
877
if not all (isinstance (i , tf .TensorShape ) for i in shapes ):
878
878
# Internally this class expects tuples of Tensors, even for the degenerate
879
879
# case of a single sequence.
880
880
dataset = dataset .map (lambda x : (x ,))
881
- shapes = tf .compat . v1 . data .get_output_shapes (dataset )
881
+ shapes = tf .data .get_output_shapes (dataset )
882
882
883
883
for s in shapes :
884
884
if not s .is_compatible_with (tf .TensorShape ([None ])):
@@ -890,7 +890,7 @@ def _standardize(self, dataset, keys):
890
890
if self ._chop_long_sequences and len (shapes ) != 1 :
891
891
raise ValueError ("chop_long_sequences expects a single sequence dataset." )
892
892
893
- token_types = tf .compat . v1 . data .get_output_types (dataset )
893
+ token_types = tf .data .get_output_types (dataset )
894
894
if len (set (token_types )) > 1 :
895
895
raise ValueError ("Inconsistent dtypes: {}" .format (token_types ))
896
896
0 commit comments