Skip to content
Open
6 changes: 5 additions & 1 deletion applications/vision/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
parser.add_argument(
'--num-classes', action='store', default=1000, type=int,
help='number of ImageNet classes (default: 1000)', metavar='NUM')
parser.add_argument(
'--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
lbann.contrib.args.add_optimizer_arguments(parser)
args = parser.parse_args()

Expand Down Expand Up @@ -64,7 +67,8 @@
opt = lbann.contrib.args.create_optimizer(args)

# Setup data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
data_path=args.data_path)

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size)
Expand Down
49 changes: 29 additions & 20 deletions applications/vision/data/imagenet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import lbann
import lbann.contrib.launcher

def make_data_reader(num_classes=1000, small_testing=False):
def make_data_reader(num_classes=1000, small_testing=False, data_path=None):

# Load Protobuf message from file
current_dir = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -18,27 +18,36 @@ def make_data_reader(num_classes=1000, small_testing=False):
google.protobuf.text_format.Merge(f.read(), message)
message = message.data_reader

# Paths to ImageNet data
# Note: Paths are only known for some compute centers
compute_center = lbann.contrib.launcher.compute_center()
if compute_center == 'lc':
from lbann.contrib.lc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train',
num_classes=num_classes)
train_label_file = imagenet_labels(data_set='train',
num_classes=num_classes)
test_data_dir = imagenet_dir(data_set='val',
num_classes=num_classes)
test_label_file = imagenet_labels(data_set='val',

if data_path is not None:
print("Setting up data reader")
train_data_dir = os.path.join(data_path, 'train')
test_data_dir = os.path.join(data_path, 'val')
train_label_file = os.path.join(data_path, 'labels/train.txt')
test_label_file = os.path.join(data_path, 'labels/val.txt')

elif lbann.contrib.launcher.compute_center() in ['lc', 'nersc']:
# Paths to ImageNet data
# Note: Paths are only known for some compute centers
compute_center = lbann.contrib.launcher.compute_center()
if compute_center == 'lc':
from lbann.contrib.lc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train',
num_classes=num_classes)
elif compute_center == 'nersc':
from lbann.contrib.nersc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train')
train_label_file = imagenet_labels(data_set='train')
test_data_dir = imagenet_dir(data_set='val')
test_label_file = imagenet_labels(data_set='val')
train_label_file = imagenet_labels(data_set='train',
num_classes=num_classes)
test_data_dir = imagenet_dir(data_set='val',
num_classes=num_classes)
test_label_file = imagenet_labels(data_set='val',
num_classes=num_classes)
elif compute_center == 'nersc':
from lbann.contrib.nersc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train')
train_label_file = imagenet_labels(data_set='train')
test_data_dir = imagenet_dir(data_set='val')
test_label_file = imagenet_labels(data_set='val')
else:
raise RuntimeError(f'ImageNet data paths are unknown for current compute center ({compute_center})')
raise RuntimeError(f'ImageNet data paths are unknown for current compute center ({compute_center}). Set "--data-path" to the location of your dataset.')

# Check that data paths are accessible
if not os.path.isdir(train_data_dir):
Expand Down
8 changes: 6 additions & 2 deletions applications/vision/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ def get_args():
parser.add_argument("--print-matrix-summary", dest="print_matrix_summary",
action="store_const",
const=True, default=False)
parser.add_argument('--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
args = parser.parse_args()
return args

Expand All @@ -438,7 +440,7 @@ def set_up_experiment(args,
labels):
algo = lbann.BatchedIterativeOptimizer("sgd", epoch_count=args.num_epochs)


# Set up objective function
cross_entropy = lbann.CrossEntropy([probs, labels])
layers = list(lbann.traverse_layer_graph(input_))
Expand Down Expand Up @@ -472,7 +474,9 @@ def set_up_experiment(args,
callbacks=callbacks)

# Set up data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, small_testing=True)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
small_testing=True,
data_path=args.data_path)

percentage = 0.001 * 2 * (args.mini_batch_size / 16) * 2

Expand Down
6 changes: 5 additions & 1 deletion applications/vision/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
parser.add_argument(
'--random-seed', action='store', default=0, type=int,
help='random seed for LBANN RNGs', metavar='NUM')
parser.add_argument(
'--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
lbann.contrib.args.add_optimizer_arguments(parser, default_learning_rate=0.1)
args = parser.parse_args()

Expand Down Expand Up @@ -145,7 +148,8 @@
opt = lbann.contrib.args.create_optimizer(args)

# Setup data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
data_path=args.data_path)

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size, random_seed=args.random_seed)
Expand Down
13 changes: 13 additions & 0 deletions docs/data_ingestion.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
.. role:: bash(code)
:language: bash
.. role:: python(code)
:language: python

Data Ingestion
==============

Expand Down Expand Up @@ -27,6 +32,14 @@ Legacy Data Readers
Some of the legacy data readers are the ``MNIST``, ``ImageNet``, and
``CIFAR10`` data readers.

.. note:: The ImageNet data reader uses default data set paths that
are only known for some compute centers. If the data set is not
found, pass :bash:`--data-path` to :code:`resnet.py`,
:code:`alexnet.py`, or :code:`densenet.py` to point to the top
level of the data set. The data set directory must contain
:code:`labels/train.txt`, :code:`labels/val.txt`,
:code:`train/`, and :code:`val/`.


"New" Data Readers
-------------------
Expand Down