@@ -44,12 +44,11 @@ def dataset_docstring_header(fn):
44
44
"""
45
45
Returns docstring for a dataset based on function arguments.
46
46
47
- Assumes function signature of form (root='.data', split=<some tuple of strings>, offset=0, **kwargs)
47
+ Assumes function signature of form (root='.data', split=<some tuple of strings>, **kwargs)
48
48
"""
49
49
argspec = inspect .getfullargspec (fn )
50
50
if not (argspec .args [0 ] == "root" and
51
- argspec .args [1 ] == "split" and
52
- argspec .args [2 ] == "offset" ):
51
+ argspec .args [1 ] == "split" ):
53
52
raise ValueError ("Internal Error: Given function {} did not adhere to standard signature." .format (fn ))
54
53
default_split = argspec .defaults [1 ]
55
54
@@ -68,8 +67,6 @@ def dataset_docstring_header(fn):
68
67
By default, all three datasets are generated. Users
69
68
could also choose any subset of them, for example {} or just 'train'.
70
69
Default: {}
71
- offset: the number of the starting line.
72
- Default: 0
73
70
""" .format (fn .__name__ , "/" .join (default_split ), str (example_subset ), str (default_split ))
74
71
75
72
if isinstance (default_split , str ):
@@ -81,9 +78,7 @@ def dataset_docstring_header(fn):
81
78
root: Directory where the datasets are saved.
82
79
Default: ".data"
83
80
split: Only {default_split} is available.
84
- Default: {default_split}
85
- offset: the number of the starting line.
86
- Default: 0""" .format (fn .__name__ , default_split = default_split )
81
+ Default: {default_split}""" .format (fn .__name__ , default_split = default_split )
87
82
88
83
raise ValueError ("default_split type expected to be of string or tuple but got {}" .format (type (default_split )))
89
84
@@ -116,9 +111,7 @@ def wrap_split_argument(fn):
116
111
argspec = inspect .getfullargspec (fn )
117
112
if not (argspec .args [0 ] == "root" and
118
113
argspec .args [1 ] == "split" and
119
- argspec .args [2 ] == "offset" and
120
114
argspec .defaults [0 ] == ".data" and
121
- argspec .defaults [2 ] == 0 and
122
115
argspec .varargs is None and
123
116
argspec .varkw is None and
124
117
len (argspec .kwonlyargs ) == 0 and
@@ -133,16 +126,15 @@ def wrap_split_argument(fn):
133
126
# keyword arguments with default values only, so only a dictionary of default
134
127
# values is needed to support that behavior for new_fn as well.
135
128
fn_kwargs_dict = {}
136
- for arg , default in zip (argspec .args [3 :], argspec .defaults [3 :]):
129
+ for arg , default in zip (argspec .args [2 :], argspec .defaults [2 :]):
137
130
fn_kwargs_dict [arg ] = default
138
131
139
132
@functools .wraps (fn )
140
- def new_fn (root = '.data' , split = argspec .defaults [1 ], offset = 0 , ** kwargs ):
133
+ def new_fn (root = '.data' , split = argspec .defaults [1 ], ** kwargs ):
141
134
for arg in fn_kwargs_dict :
142
135
if arg not in kwargs :
143
136
kwargs [arg ] = fn_kwargs_dict [arg ]
144
137
kwargs ["root" ] = root
145
- kwargs ["offset" ] = offset
146
138
kwargs ["split" ] = check_default_set (split , argspec .defaults [1 ], fn .__name__ )
147
139
result = fn (** kwargs )
148
140
return wrap_datasets (tuple (result ), split )
@@ -154,32 +146,28 @@ class RawTextIterableDataset(torch.utils.data.IterableDataset):
154
146
"""Defines an abstraction for raw text iterable datasets.
155
147
"""
156
148
157
- def __init__ (self , name , full_num_lines , iterator , offset = 0 ):
149
+ def __init__ (self , name , full_num_lines , iterator ):
158
150
"""Initiate text-classification dataset.
159
151
"""
160
152
super (RawTextIterableDataset , self ).__init__ ()
161
153
self .name = name
162
154
self .full_num_lines = full_num_lines
163
155
self ._iterator = iterator
164
- self .start = offset
165
- if offset < 0 :
166
- raise ValueError ("Given offset must be non-negative, got {} instead." .format (offset ))
167
- self .num_lines = full_num_lines - offset
156
+ self .num_lines = full_num_lines
157
+ self .current_pos = None
168
158
169
159
def __iter__ (self ):
170
- for i , item in enumerate (self ._iterator ):
171
- if i < self .start :
172
- continue
173
- if self .num_lines and i >= (self .start + self .num_lines ):
174
- break
175
- yield item
160
+ return self
176
161
177
162
def __next__ (self ):
163
+ if self .current_pos == self .num_lines - 1 :
164
+ raise StopIteration
178
165
item = next (self ._iterator )
166
+ if self .current_pos is None :
167
+ self .current_pos = 0
168
+ else :
169
+ self .current_pos += 1
179
170
return item
180
171
181
172
def __len__ (self ):
182
173
return self .num_lines
183
-
184
- def get_iterator (self ):
185
- return self ._iterator
0 commit comments