@@ -16,6 +16,7 @@
 import json
 import copy
 import glob
+import fnmatch
 from typing import Union
 from collections.abc import Sequence
 from decimal import getcontext, Decimal
@@ -31,7 +32,9 @@
 from agml.utils.data import load_public_sources
 from agml.utils.general import NoArgument, resolve_list_value
 from agml.utils.random import inject_random_state
-from agml.backend.config import data_save_path, synthetic_data_save_path
+from agml.backend.config import (
+    data_save_path, synthetic_data_save_path, SUPER_BASE_DIR
+)
 from agml.backend.experimental import AgMLExperimentalFeatureWrapper
 from agml.backend.tftorch import (
     get_backend, set_backend,
@@ -101,14 +104,25 @@ class AgMLDataLoader(AgMLSerializable, metaclass = AgMLDataLoaderMeta):
     See the methods for examples on how to use an `AgMLDataLoader` effectively.
     """
     serializable = frozenset((
-        'info', 'builder', 'manager', 'train_data',
-        'val_data', 'test_data', 'is_split', 'meta_properties'))
+        'info', 'builder', 'manager', 'train_data', 'train_content', 'val_data',
+        'val_content', 'test_data', 'test_content', 'is_split', 'meta_properties'))
 
     def __new__(cls, dataset, **kwargs):
         # If a single dataset is passed, then we use the base `AgMLDataLoader`.
         # However, if an iterable of datasets is passed, then we need to
         # dispatch to the subclass `AgMLMultiDatasetLoader` for them.
         if isinstance(dataset, (str, DatasetMetadata)):
+            if '*' in dataset:  # enables wildcard search for datasets
+                valid_datasets = fnmatch.filter(load_public_sources().keys(), dataset)
+                if len(valid_datasets) == 0:
+                    raise ValueError(
+                        f"Wildcard search for dataset '{dataset}' yielded no results.")
+                if len(valid_datasets) == 1:
+                    log(f"Wildcard search for dataset '{dataset}' yielded only "
+                        f"one result. Returning a regular, single-element data loader.")
+                    return super(AgMLDataLoader, cls).__new__(cls)
+                from agml.data.multi_loader import AgMLMultiDatasetLoader
+                return AgMLMultiDatasetLoader(valid_datasets, **kwargs)
             return super(AgMLDataLoader, cls).__new__(cls)
         elif isinstance(dataset, Sequence):
             if len(dataset) == 1:
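
With this change, passing a glob-style pattern as the dataset name dispatches to `AgMLMultiDatasetLoader` when the pattern matches more than one public source, and falls back to a regular loader for a single match. A minimal usage sketch (the patterns and dataset names here are illustrative, not guaranteed to exist in the public source list):

```python
import agml

# A wildcard pattern is matched against all public dataset names with
# `fnmatch`; multiple matches yield an `AgMLMultiDatasetLoader`.
multi_loader = agml.data.AgMLDataLoader('apple_*')

# A pattern with exactly one match returns a regular single-dataset
# loader, and a pattern with no matches raises a ValueError.
single_loader = agml.data.AgMLDataLoader('apple_flower_*')
```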
@@ -152,8 +166,11 @@ def __init__(self, dataset, **kwargs):
         # If the dataset is split, then the `AgMLDataLoader`s with the
         # split and reduced data are stored as accessible class properties.
         self._train_data = None
+        self._train_content = None
         self._val_data = None
+        self._val_content = None
         self._test_data = None
+        self._test_content = None
         self._is_split = False
 
         # Set the direct access metadata properties like `num_images` and
@@ -208,7 +225,9 @@ def custom(cls, name, dataset_path = None, classes = None, **kwargs):
         Parameters
         ----------
         name : str
-            A name for the custom dataset (this can be any valid string).
+            A name for the custom dataset (this can be any valid string). This
+            can also be a path to the dataset (in which case the name will be
+            the base directory inferred from the path).
         dataset_path : str, optional
             A custom path to load the dataset from. If this is not passed,
             we will assume that the dataset is at the traditional path:
@@ -231,6 +250,11 @@ def custom(cls, name, dataset_path = None, classes = None, **kwargs):
                 f"a string that is not an existing dataset in "
                 f"the AgML public data source repository.")
 
+        # Check if the `name` is itself the path to the dataset.
+        if os.path.exists(name):
+            dataset_path = name
+            name = os.path.basename(name)
+
         # Locate the path to the dataset.
         if dataset_path is None:
             dataset_path = os.path.abspath(os.path.join(data_save_path(), name))
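
The added check lets `custom()` accept a filesystem path directly as `name`. A sketch under the assumption of a hypothetical local dataset directory:

```python
from agml.data import AgMLDataLoader

# Passing a path as `name`: the dataset name becomes the directory's
# basename ('my_grapes'), and the path itself is used as `dataset_path`.
loader = AgMLDataLoader.custom('/data/my_grapes', classes = ['grape'])

# Equivalent explicit form, as before this change:
loader = AgMLDataLoader.custom(
    'my_grapes', dataset_path = '/data/my_grapes', classes = ['grape'])
```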
@@ -624,7 +648,7 @@ def train_data(self):
         if isinstance(self._train_data, AgMLDataLoader):
             return self._train_data
         self._train_data = self._generate_split_loader(
-            self._train_data, split = 'train')
+            self._train_content, split = 'train')
         return self._train_data
 
     @property
@@ -633,7 +657,7 @@ def val_data(self):
         if isinstance(self._val_data, AgMLDataLoader):
             return self._val_data
         self._val_data = self._generate_split_loader(
-            self._val_data, split = 'val')
+            self._val_content, split = 'val')
         return self._val_data
 
     @property
@@ -642,7 +666,7 @@ def test_data(self):
         if isinstance(self._test_data, AgMLDataLoader):
             return self._test_data
         self._test_data = self._generate_split_loader(
-            self._test_data, split = 'test')
+            self._test_content, split = 'test')
         return self._test_data
 
     def eval(self) -> "AgMLDataLoader":
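
All three accessors now read from the stored `*_content` attributes instead of overwriting the loader cache in place, so repeated access returns the same cached split loader. Usage is unchanged; a brief sketch (dataset name illustrative):

```python
loader = agml.data.AgMLDataLoader('apple_flower_segmentation')
loader.split(train = 0.8, val = 0.1, test = 0.1)

# First access builds the split loader from `_train_content` and caches
# it; subsequent accesses return the cached `AgMLDataLoader`.
train_loader = loader.train_data
val_loader = loader.val_data
test_loader = loader.test_data
```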
@@ -980,8 +1004,7 @@ def __call__(self, contents, name):
         # Re-map the annotation ID.
         category_ids = annotations['category_id']
         category_ids[np.where(category_ids == 0)[0]] = 1  # fix
-        new_ids = np.array([self._map[c]
-                            for c in category_ids])
+        new_ids = np.array([self._map[c] for c in category_ids])
         annotations['category_id'] = new_ids
         return image, annotations
 
@@ -1175,14 +1198,87 @@ def split(self, train = None, val = None, test = None, shuffle = True):
 
             # Build new `DataBuilder`s and `DataManager`s for the split data.
             for split, content in contents.items():
-                setattr(self, f'_{split}_data', content)
+                setattr(self, f'_{split}_content', content)
 
         # Otherwise, raise an error for an invalid type.
         else:
             raise TypeError(
                 "Expected either only ints or only floats when generating "
                 f"a data split, got {[type(i) for i in arg_dict.values()]}.")
 
+    def save_split(self, name, overwrite = False):
+        """Saves the current split of data to an internal location.
+
+        This method can be used to save the current split of data to an
+        internal file, such that the same split can be later loaded using
+        the `load_split` method (for reproducibility). This method will only
+        save the actual split of data, not any of the transforms or other
+        parameters which have been applied to the loader.
+
+        Parameters
+        ----------
+        name : str
+            The name of the split to save. This name will be used to identify
+            the split when loading it later.
+        overwrite : bool
+            Whether to overwrite an existing split with the same name.
+        """
+        # Ensure that there exist data splits (train/val/test data).
+        if (
+            self._train_content is None
+            and self._val_content is None
+            and self._test_content is None
+        ):
+            raise NotImplementedError("Cannot save a split of data when no "
+                                      "split has been generated.")
+
+        # Get each of the individual splits.
+        splits = {'train': self._train_content,
+                  'val': self._val_content,
+                  'test': self._test_content}
+
+        # Save the split to the internal location.
+        split_dir = os.path.join(SUPER_BASE_DIR, 'splits', self.name)
+        os.makedirs(split_dir, exist_ok = True)
+        if os.path.exists(os.path.join(split_dir, f'{name}.json')):
+            if not overwrite:
+                raise FileExistsError(f"A split with the name {name} already exists.")
+        with open(os.path.join(split_dir, f'{name}.json'), 'w') as f:
+            json.dump(splits, f)
+
+    def load_split(self, name, **kwargs):
+        """Loads a previously saved split of data.
+
+        This method can be used to load a previously saved split of data
+        if the split was saved using the `save_split` method. This method
+        will only load the actual split of data, not any of the transforms
+        or other parameters which have been applied to the loader. You can
+        use the traditional split accessors (`train_data`, `val_data`, and
+        `test_data`) to access the loaded data.
+
+        Parameters
+        ----------
+        name : str
+            The name of the split to load. This name will be used to identify
+            the split to load.
+        """
+        if kwargs.get('manual_split_set', False):
+            splits = kwargs['manual_split_set']
+
+        else:
+            # Ensure that the split exists.
+            split_dir = os.path.join(SUPER_BASE_DIR, 'splits', self.name)
+            if not os.path.exists(os.path.join(split_dir, f'{name}.json')):
+                raise FileNotFoundError(f"Could not find a split with the name {name}.")
+
+            # Load the split from the internal location.
+            with open(os.path.join(split_dir, f'{name}.json'), 'r') as f:
+                splits = json.load(f)
+
+        # Set the split contents.
+        for split, content in splits.items():
+            setattr(self, f'_{split}_content', content)
+
     def batch(self, batch_size = None):
         """Batches sets of image and annotation data according to a size.
 
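
Together, `save_split` and `load_split` make a random split reproducible across sessions by serializing the train/val/test contents to `SUPER_BASE_DIR/splits/<dataset_name>/<split_name>.json`. A usage sketch (dataset and split names illustrative):

```python
# Session 1: generate a split and persist it.
loader = agml.data.AgMLDataLoader('apple_flower_segmentation')
loader.split(train = 0.8, val = 0.1, test = 0.1, shuffle = True)
loader.save_split('experiment_1')

# Session 2: restore the identical split on a fresh loader.
loader = agml.data.AgMLDataLoader('apple_flower_segmentation')
loader.load_split('experiment_1')
train_data, val_data = loader.train_data, loader.val_data
```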
@@ -1611,11 +1707,9 @@ def export_torch(self, **loader_kwargs):
 
         # The `DataLoader` automatically batches objects using its
         # own mechanism, so we remove batching from this DataLoader.
-        batch_size = loader_kwargs.pop(
-            'batch_size', obj._manager._batch_size)
+        batch_size = loader_kwargs.pop('batch_size', obj._manager._batch_size)
         obj.batch(None)
-        shuffle = loader_kwargs.pop(
-            'shuffle', obj._manager._shuffle)
+        shuffle = loader_kwargs.pop('shuffle', obj._manager._shuffle)
 
         # The `collate_fn` for object detection is different because
         # the COCO JSON dictionaries each have different formats. So,
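
Because `batch_size` and `shuffle` are popped from `loader_kwargs` (defaulting to the loader's own settings) and AgML-side batching is disabled with `obj.batch(None)`, the returned PyTorch `DataLoader` handles both. A sketch of the expected call, with standard `torch.utils.data.DataLoader` keyword arguments:

```python
loader = agml.data.AgMLDataLoader('apple_flower_segmentation')
loader.split(train = 0.8, val = 0.1, test = 0.1)

# Batching and shuffling now happen inside the PyTorch DataLoader;
# any extra keyword arguments are forwarded to its constructor.
torch_loader = loader.train_data.export_torch(
    batch_size = 16, shuffle = True, num_workers = 2)
for images, annotations in torch_loader:
    pass  # standard training loop body
```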