
Commit 73e646f

Merge branch 'public-dev' into public-main
2 parents: 0c380be + 906240f

File tree

11 files changed: +269 -48 lines


agml/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = '0.5.1'
+__version__ = '0.5.2'
 __all__ = ['data', 'synthetic', 'backend', 'viz', 'io']
 
 
```

agml/data/loader.py

Lines changed: 108 additions & 14 deletions
```diff
@@ -16,6 +16,7 @@
 import json
 import copy
 import glob
+import fnmatch
 from typing import Union
 from collections.abc import Sequence
 from decimal import getcontext, Decimal
@@ -31,7 +32,9 @@
 from agml.utils.data import load_public_sources
 from agml.utils.general import NoArgument, resolve_list_value
 from agml.utils.random import inject_random_state
-from agml.backend.config import data_save_path, synthetic_data_save_path
+from agml.backend.config import (
+    data_save_path, synthetic_data_save_path, SUPER_BASE_DIR
+)
 from agml.backend.experimental import AgMLExperimentalFeatureWrapper
 from agml.backend.tftorch import (
     get_backend, set_backend,
```

```diff
@@ -101,14 +104,25 @@ class AgMLDataLoader(AgMLSerializable, metaclass = AgMLDataLoaderMeta):
     See the methods for examples on how to use an `AgMLDataLoader` effectively.
     """
     serializable = frozenset((
-        'info', 'builder', 'manager', 'train_data',
-        'val_data', 'test_data', 'is_split', 'meta_properties'))
+        'info', 'builder', 'manager', 'train_data', 'train_content', 'val_data',
+        'val_content', 'test_data', 'test_content', 'is_split', 'meta_properties'))
 
     def __new__(cls, dataset, **kwargs):
         # If a single dataset is passed, then we use the base `AgMLDataLoader`.
         # However, if an iterable of datasets is passed, then we need to
         # dispatch to the subclass `AgMLMultiDatasetLoader` for them.
         if isinstance(dataset, (str, DatasetMetadata)):
+            if '*' in dataset: # enables wildcard search for datasets
+                valid_datasets = fnmatch.filter(load_public_sources().keys(), dataset)
+                if len(valid_datasets) == 0:
+                    raise ValueError(
+                        f"Wildcard search for dataset '{dataset}' yielded no results.")
+                if len(valid_datasets) == 1:
+                    log(f"Wildcard search for dataset '{dataset}' yielded only "
+                        f"one result. Returning a regular, single-element data loader.")
+                    return super(AgMLDataLoader, cls).__new__(cls)
+                from agml.data.multi_loader import AgMLMultiDatasetLoader
+                return AgMLMultiDatasetLoader(valid_datasets, **kwargs)
             return super(AgMLDataLoader, cls).__new__(cls)
         elif isinstance(dataset, Sequence):
             if len(dataset) == 1:
```
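
The new branch above dispatches glob-style patterns through `fnmatch` against the public data sources: two or more matches return an `AgMLMultiDatasetLoader`, exactly one match falls back to a regular loader, and zero matches raise a `ValueError`. A minimal usage sketch (the `grape_*` pattern and whatever it matches are illustrative, not taken from this commit):

```python
import agml

# A wildcard pattern loads every public dataset whose name matches it.
loaders = agml.data.AgMLDataLoader('grape_*')
```
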
```diff
@@ -152,8 +166,11 @@ def __init__(self, dataset, **kwargs):
         # If the dataset is split, then the `AgMLDataLoader`s with the
         # split and reduced data are stored as accessible class properties.
         self._train_data = None
+        self._train_content = None
         self._val_data = None
+        self._val_content = None
         self._test_data = None
+        self._test_content = None
         self._is_split = False
 
         # Set the direct access metadata properties like `num_images` and
@@ -208,7 +225,9 @@ def custom(cls, name, dataset_path = None, classes = None, **kwargs):
         Parameters
         ----------
         name : str
-            A name for the custom dataset (this can be any valid string).
+            A name for the custom dataset (this can be any valid string). This
+            can also be a path to the dataset (in which case the name will be
+            the base directory inferred from the path).
         dataset_path : str, optional
             A custom path to load the dataset from. If this is not passed,
             we will assume that the dataset is at the traditional path:
@@ -231,6 +250,11 @@ def custom(cls, name, dataset_path = None, classes = None, **kwargs):
             f"a string that is not an existing dataset in "
             f"the AgML public data source repository.")
 
+        # Check if the `name` is itself the path to the dataset.
+        if os.path.exists(name):
+            dataset_path = name
+            name = os.path.basename(name)
+
         # Locate the path to the dataset.
         if dataset_path is None:
             dataset_path = os.path.abspath(os.path.join(data_save_path(), name))
```
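
With the hunk above, `AgMLDataLoader.custom` also accepts the dataset directory itself in place of a name, inferring the name from the base directory. A short sketch, assuming a hypothetical local dataset at `/data/my_orchard_images`:

```python
import agml

# Since the argument is an existing path, it is used as `dataset_path`
# and the name 'my_orchard_images' is inferred via os.path.basename.
loader = agml.data.AgMLDataLoader.custom('/data/my_orchard_images')
```
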
```diff
@@ -624,7 +648,7 @@ def train_data(self):
         if isinstance(self._train_data, AgMLDataLoader):
             return self._train_data
         self._train_data = self._generate_split_loader(
-            self._train_data, split = 'train')
+            self._train_content, split = 'train')
         return self._train_data
 
     @property
@@ -633,7 +657,7 @@ def val_data(self):
         if isinstance(self._val_data, AgMLDataLoader):
             return self._val_data
         self._val_data = self._generate_split_loader(
-            self._val_data, split = 'val')
+            self._val_content, split = 'val')
         return self._val_data
 
     @property
@@ -642,7 +666,7 @@ def test_data(self):
         if isinstance(self._test_data, AgMLDataLoader):
             return self._test_data
         self._test_data = self._generate_split_loader(
-            self._test_data, split = 'test')
+            self._test_content, split = 'test')
         return self._test_data
 
     def eval(self) -> "AgMLDataLoader":
```

```diff
@@ -980,8 +1004,7 @@ def __call__(self, contents, name):
         # Re-map the annotation ID.
         category_ids = annotations['category_id']
         category_ids[np.where(category_ids == 0)[0]] = 1 # fix
-        new_ids = np.array([self._map[c]
-                            for c in category_ids])
+        new_ids = np.array([self._map[c] for c in category_ids])
         annotations['category_id'] = new_ids
         return image, annotations
```

```diff
@@ -1175,14 +1198,87 @@ def split(self, train = None, val = None, test = None, shuffle = True):
 
         # Build new `DataBuilder`s and `DataManager`s for the split data.
         for split, content in contents.items():
-            setattr(self, f'_{split}_data', content)
+            setattr(self, f'_{split}_content', content)
 
         # Otherwise, raise an error for an invalid type.
         else:
             raise TypeError(
                 "Expected either only ints or only floats when generating "
                 f"a data split, got {[type(i) for i in arg_dict.values()]}.")
 
+    def save_split(self, name, overwrite = False):
+        """Saves the current split of data to an internal location.
+
+        This method can be used to save the current split of data to an
+        internal file, such that the same split can be later loaded using
+        the `load_split` method (for reproducibility). This method will only
+        save the actual split of data, not any of the transforms or other
+        parameters which have been applied to the loader.
+
+        Parameters
+        ----------
+        name: str
+            The name of the split to save. This name will be used to identify
+            the split when loading it later.
+        overwrite: bool
+            Whether to overwrite an existing split with the same name.
+        """
+        # Ensure that there exist data splits (train/val/test data).
+        if (
+            self._train_content is None
+            and self._val_content is None
+            and self._test_content is None
+        ):
+            raise NotImplementedError("Cannot save a split of data when no "
+                                      "split has been generated.")
+
+        # Get each of the individual splits.
+        splits = {'train': self._train_content,
+                  'val': self._val_content,
+                  'test': self._test_content}
+
+        # Save the split to the internal location.
+        split_dir = os.path.join(SUPER_BASE_DIR, 'splits', self.name)
+        os.makedirs(split_dir, exist_ok = True)
+        if os.path.exists(os.path.join(split_dir, f'{name}.json')):
+            if not overwrite:
+                raise FileExistsError(f"A split with the name {name} already exists.")
+        with open(os.path.join(split_dir, f'{name}.json'), 'w') as f:
+            json.dump(splits, f)
+
+    def load_split(self, name, **kwargs):
+        """Loads a previously saved split of data.
+
+        This method can be used to load a previously saved split of data
+        if the split was saved using the `save_split` method. This method
+        will only load the actual split of data, not any of the transforms
+        or other parameters which have been applied to the loader. You can
+        use the traditional split accessors (`train_data`, `val_data`, and
+        `test_data`) to access the loaded data.
+
+        Parameters
+        ----------
+        name: str
+            The name of the split to load. This name will be used to identify
+            the split to load.
+        """
+        if kwargs.get('manual_split_set', False):
+            splits = kwargs['manual_split_set']
+
+        else:
+            # Ensure that the split exists.
+            split_dir = os.path.join(SUPER_BASE_DIR, 'splits', self.name)
+            if not os.path.exists(os.path.join(split_dir, f'{name}.json')):
+                raise FileNotFoundError(f"Could not find a split with the name {name}.")
+
+            # Load the split from the internal location.
+            with open(os.path.join(split_dir, f'{name}.json'), 'r') as f:
+                splits = json.load(f)
+
+        # Set the split contents.
+        for split, content in splits.items():
+            setattr(self, f'_{split}_content', content)
+
     def batch(self, batch_size = None):
         """Batches sets of image and annotation data according to a size.
```
11881284
```diff
@@ -1611,11 +1707,9 @@ def export_torch(self, **loader_kwargs):
 
         # The `DataLoader` automatically batches objects using its
         # own mechanism, so we remove batching from this DataLoader.
-        batch_size = loader_kwargs.pop(
-            'batch_size', obj._manager._batch_size)
+        batch_size = loader_kwargs.pop('batch_size', obj._manager._batch_size)
         obj.batch(None)
-        shuffle = loader_kwargs.pop(
-            'shuffle', obj._manager._shuffle)
+        shuffle = loader_kwargs.pop('shuffle', obj._manager._shuffle)
 
         # The `collate_fn` for object detection is different because
         # the COCO JSON dictionaries each have different formats. So,
```
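
This hunk is purely a formatting cleanup, but it highlights the existing behavior: `batch_size` and `shuffle` are popped from the keyword arguments (defaulting to the loader's own settings), AgML-side batching is disabled, and the values are forwarded to the PyTorch `DataLoader`. A minimal sketch (dataset name illustrative):

```python
import agml

loader = agml.data.AgMLDataLoader('apple_flower_segmentation')
# Batching/shuffling are delegated to torch.utils.data.DataLoader,
# so they are passed here rather than via loader.batch()/loader.shuffle().
torch_loader = loader.export_torch(batch_size = 16, shuffle = True)
```
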
