This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 77e2ab8

sandeepkrishnamurthy-dev and roywei authored and committed
Add MXNet Backend (#59)
* Adding MXNet backend template. Adding all basic Variable and Tensor operations (#1)
* add activation functions
* add activation functions
* fix some legacy
* fix some legacy
* cross entropy
* cross entropy
* fix name scoping introduced in 2.0
* fix name scoping introduced in 2.0
* Add dropout, l2_normalization, random_normal/uniform/binomial (#2)
* remove the logic for hacking RNN
* remove the logic for hacking RNN
* add pooling with utils
* add pooling with utils
* minor
* lint and name scope fix
* fix access protected var
* fix add neighbor, removed __eq__ in KerasSymbol
* fix eval function, unittest for placeholder and variable
* add unittests
* fix bug
* fix bug
* fix
* add some temporary fixes in mxnet backend. undo change to the pytest.ini
* mxnet_backend graph fix, layer support (#3)
* add activation functions
* fix some legacy
* cross entropy
* fix name scoping introduced in 2.0
* Add dropout, l2_normalization, random_normal/uniform/binomial (#2)
* remove the logic for hacking RNN
* add pooling with utils
* add activation functions
* fix some legacy
* cross entropy
* fix name scoping introduced in 2.0
* remove the logic for hacking RNN
* add pooling with utils
* minor
* lint and name scope fix
* fix access protected var
* fix add neighbor, removed __eq__ in KerasSymbol
* fix eval function, unittest for placeholder and variable
* add unittests
* fix bug
* fix bug
* fix
* add some temporary fixes in mxnet backend. undo change to the pytest.ini
* Keras function not working is a known issue, add skip in the test
* fix random_uniform/constant
* fix legacy randomize methods
* Fix MXNet backend operator bugs. Enabled Keras backend tests
* add bias
* Add Amazon copyrights to License (#6)
* fix
* fix
* fix backend for mlp
* fix context management, add optimizers
* minor change
* undo changes on example
* fix eval
* minor cleanup
* fix some property usage
* fixing AlphaDropout, not finished yet
* add mx model instantiate
* modifies training model construct logic, fix some tests. fix reshape layer.
* minor fix
* fix bias_add
* more fix on Dense and bias_add
* In progress commit
* fix comment
* small fix
* remove pytest.skip in conv3d. But it failed with theano backend in my workspace though.
* Add conv2d and in_topk operator for mxnet backend (#11)
* Skip BatchDot tests for Theano backend. (#12)
* BatchDot, Basic Batchnorm, Fix BiasAdd, Fix Conv2D, CodeCleanup (#14)
* Fix Conv2d shape issues and enable Conv2D UTs
* Remove redundant mxnet only unit tests
* Adding batch_dot, remove deconv, code comments and cleanup
* Remove buggy conv1d implementation
* Fix CR comments. Fix lint check issues
* Move mxnet specific code from keras engine to mxnet_backend. (#15)
* Move MXNet optimizers from keras optimizers to mxnet backend (#16)
* Fix bug in reshape. Minor rename to avoid local conflicts
* Bug fixes and enable/skip all Keras tests for mxnet backend (#21)
* test results - 374 passed, 235 skipped in 114.44 seconds
* fix/skip keras tests - tests/integration_tests, tests/keras/applications
* fix/skip keras tests - tests/keras/engine/test_topology
* fix/skip keras tests - tests/keras/engine/test_training
* fix/skip keras tests - tests/keras/legacy/
* fix/skip keras tests - tests/keras/preprocessing
* fix/skip keras tests - tests/keras/utils/
* Fix CR comments
* Fix issues in zero_padding. Fix/Enable tests/layers/convolutional_test
* Add momentum to batchnorm. Enable/skip tests in layers/core, local, merge, noise, normalization
* Skip RNN tests in keras/tests/layers/recurrent_test, wrappers_test
* Fix bug in spatial padding, enable/skip tests in loss, optimizers, callback, loss_weighting, model_saving
* Fix mxnet backend multi-gpu training (#31): fixing bug for mxnet backend to use multiple GPUs
* Fix performance issue - Batchnormalization, Conv operator (#35)
* Fix default axis for batchnorm layer for channels_first data_format
* Performance improvement by avoiding kernel transpose in conv operation for channels_first format
* Fix model - architecture, weights and both, load and save. (#36)
* Prepare initial version of mxnet related documentation in keras (#38)
* Skip failing unit tests for unsupported functionality in mxnet backend
* Fix pep tests reported by CI
* Use pytest module skip, revert kernel_shape logic
* remove data_format param from bias_add API
* Allow Predict() without compile for mxnet backend and enable tests. contributor - roywei@
* Fix bug - mxnet backend should not override keras config data_format to channels_first. Only warn of low performance
* Conv3d() operator implementation for Keras2.0 using MXNet backend (#40)
* conv3d implementation for keras2.0 as MXNet backend
* conv3d implementation/testing for keras2.0 using MXNet backend
* keeping -n option in pytest.ini file
* fixed comments given by Sandeep
* Add Conv1D support for MXNet backend (#44)
* Add Conv1D support for MXNet backend
* Fix CR comments
* Conv2d transpose (#47)
* add conv2d_transpose
* conv2d transpose for both channels, enabled test case
* add detailed comments and examples, fix style issue
* enable test case in topology
* Enable performance optimization for conv operators with MXNet backend. Make MXNet default backend with this branch (#48)
* Fix conv kernel shape bug for TF backend. (#50)
* Add support for keras multi_gpu_model() API with MXNet backend (#49) (usage sketch after this list)
* Add support for keras multi_gpu_model() API with MXNet backend. Autoset GPU0 context on GPU machine
* Fix typo
* Add SAME padding mode support for pooling operator. (#51)
* Add rnn() operator for MXNet backend with unrolling and masking feature (#46)
* Adding rnn() operator in Keras2.0 with MXNet as backend with unroll=True and Masking=True/False and enabled relevant testcases. Also, modified a couple of operators.
* Modified comments
* Added comments to a method
* Enable categorical crossentropy testcases and made minor changes
* Modified message
* nit
* Added detailed description of handling variable length input in RNN
* Skip conv2d_transpose and conv3d_transpose test-case for MXNet backend and minor changes in rnn()
* Adamax and NAdam optimizer for MXNet backend (#54)
* Add Adamax optimizer for MXNet backend
* Fix lr and adamax params
* Add Nadam optimizer for mxnet backend
* Add Conv3d transpose (#52)
* conv3d transpose, enabled test case
* update kernel shape
* replace conv2d_transpose conv3d_transpose with convnd_transpose
* update value errors with MXNet Backend info, fix typo
* add check for conv3d transpose only supports gpu with cudnn
* update context check
* disable conv3d transpose test
* fix typo in comment
* Rebase to latest Keras - April 3, 2018
* Add build badges
* Fix multi_gpu API bug for CPU. Fix PEP. (#64)
* Fix multi_gpu API bug for CPU. Fix PEP.
* fix embedding layer bug (#61)
* fix embedding bug
* addressed comments, enabled more test cases
* add keras test
* reduce line length
* fix style, add blank lines
* Benchmark (#55)
* add conv2d_transpose
* conv2d transpose for both channels, enabled test case
* add detailed comments and examples, fix style issue
* add benchmark scripts for resnet and imagenet data
* combine scripts
* fix args
* fix num of gpus
* update log
* multi_gpu_model only support tf
* add benchmark scripts for synthetic data
* update readme and scripts
* add mxnet training result table
* update on readme
* add cifar10 dataset and enable various resnet layers
* fix compile for mxnet multiple gpu
* update callbacks
* update synthetic data script, add credits
* undo new line
* update readme, addressed pr comments
* update readme
* benchmark scripts style fix (#66)
* style fix
* remove unused import, fix line too long
* addressed pr comments
* Added keras util API for conversion of data tensor from channels_last to channels_first using MXNet backend (#65)
* Added keras util API for conversion of data tensor from channels_last to channels_first using MXNet backend
* Modified comments
* Addressed review comments and made the API more generic across backends
* Removed shape check
* Modified comments
* Added edge cases
* moved helper method as nested
* Added RNN benchmark scripts (#69)
* Added RNN benchmark scripts
* Fixed new line in bash script
* Removed different backend code and modified comments
* Removed spacing
* Automated the wikiText2 download script
* Added dataset_util functionality to have more flexible code
* Added minor comments
* modified minor comments
* Fixed the multi-gpu context (#68)
* Update benchmark result (#70)
* update benchmark result
* update result
* simplify folder structure
* add image result
* add note
* add note
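Several of the squashed items above (e.g. #49, #64, #68) add support for Keras's `multi_gpu_model()` API on the MXNet backend. A minimal usage sketch is shown below, assuming a GPU machine with the MXNet backend configured; the ResNet50 model, SGD optimizer, and random data are illustrative placeholders rather than anything taken from this commit.

```python
# Hedged sketch of the multi_gpu_model() workflow enabled for the MXNet backend.
# Assumes image_data_format == "channels_first" and a machine with 4 GPUs.
import numpy as np
from keras.applications.resnet50 import ResNet50
from keras.utils import multi_gpu_model

# Build a single-device model (placeholder choice of architecture).
model = ResNet50(weights=None, input_shape=(3, 224, 224), classes=1000)

# Replicate the model across 4 GPUs; each batch is split across the replicas.
parallel_model = multi_gpu_model(model, gpus=4)
parallel_model.compile(loss='categorical_crossentropy', optimizer='sgd')

# Dummy channels_first images and labels, just to show the call pattern.
x = np.random.random((128, 3, 224, 224))
y = np.random.random((128, 1000))
parallel_model.fit(x, y, epochs=1, batch_size=128)
```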
1 parent ef13db0 commit 77e2ab8


69 files changed: +7165 −177 lines

.travis.yml

Lines changed: 13 additions & 2 deletions

@@ -21,6 +21,10 @@ matrix:
       env: KERAS_BACKEND=cntk PYTHONWARNINGS=ignore
     - python: 3.6
       env: KERAS_BACKEND=cntk PYTHONWARNINGS=ignore
+    - python: 2.7
+      env: KERAS_BACKEND=mxnet PYTHONWARNINGS=ignore
+    - python: 3.6
+      env: KERAS_BACKEND=mxnet PYTHONWARNINGS=ignore
 install:
   # code below is taken from http://conda.pydata.org/docs/travis.html
   # We do this conditionally because it saves us some downloading if the
@@ -38,7 +42,7 @@ install:
   # Useful for debugging any issues with conda
   - conda info -a

-  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytest pandas
+  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION nose scipy matplotlib pandas pytest h5py
   - source activate test-environment
   - pip install --only-binary=numpy,scipy numpy nose scipy matplotlib h5py theano
   - conda install mkl mkl-service
@@ -57,7 +61,11 @@ install:

   # install TensorFlow (CPU version).
   - pip install tensorflow
-
+
+  # install Apache MXNet (CPU version).
+  - pip install mxnet
+  - pip install --upgrade numpy
+
   # install cntk
   - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
       pip install https://cntk.ai/PythonWheel/CPU-Only/cntk-2.3.1-cp27-cp27mu-linux_x86_64.whl;
@@ -78,6 +86,9 @@ install:
   - if [[ "$KERAS_BACKEND" != "cntk" ]]; then
       echo ' keras/backend/cntk_backend.py' >> .coveragerc;
     fi
+  - if [[ "$KERAS_BACKEND" != "mxnet" ]]; then
+      echo ' keras/backend/mxnet_backend.py' >> .coveragerc;
+    fi

   # detect whether core files are changed or not
   - export CORE_CHANGED=False;
LICENSE

Lines changed: 8 additions & 0 deletions

@@ -12,6 +12,14 @@ All contributions by Microsoft:
 Copyright (c) 2017 - 2018, Microsoft, Inc.
 All rights reserved.

+All contributions by Amazon:
+Copyright (c) 2017 Amazon.com, Inc. or its affiliates
+All rights reserved.
+
+All contributions by Amazon:
+Copyright (c) 2017 Amazon.com, Inc. or its affiliates
+All rights reserved.
+
 All other contributions:
 Copyright (c) 2015 - 2018, the respective contributors.
 All rights reserved.

README.md

Lines changed: 7 additions & 3 deletions

@@ -2,12 +2,15 @@

 ![Keras logo](https://s3.amazonaws.com/keras.io/img/keras-logo-2018-large-1200.png)

-[![Build Status](https://travis-ci.org/keras-team/keras.svg?branch=master)](https://travis-ci.org/keras-team/keras)
+| ubuntu/python-2.7 | ubuntu/python-3.5 |
+|---------|---------|
+| ![Python3 Build Status](https://codebuild.us-east-1.amazonaws.com/badges?uuid=eyJlbmNyeXB0ZWREYXRhIjoidHBzRFVlMG5SMGFQRTVzMUhxejNIK2dZRU1kb3p2c0JIbTVObDZtdDgxYThYdjRCZlg0RGF1eCsrSUtGQmgwYkFkZzJaT1BrdHpqcVJqcWE2aSt6QmRnPSIsIml2UGFyYW1ldGVyU3BlYyI6IklPMmRORld4TDYrdWNrWDciLCJtYXRlcmlhbFNldFNlcmlhbCI6MX0%3D&branch=master) | ![Python2 Build Status](https://codebuild.us-east-1.amazonaws.com/badges?uuid=eyJlbmNyeXB0ZWREYXRhIjoibHFOTlladW1VK050SFBST1N0UUtNOGdOV24vM25hVUJDQVVKNitvSFpXTFZ4RzlvUXppdHU4RytRR3hLdk1nSDd2VHlTSlZ5ZTlCUC9GdWdscHZRRFBNPSIsIml2UGFyYW1ldGVyU3BlYyI6IjZrQksycy9aWWV5QXh1MkoiLCJtYXRlcmlhbFNldFNlcmlhbCI6MX0%3D&branch=master) |
 [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/keras-team/keras/blob/master/LICENSE)

 ## You have just found Keras.

-Keras is a high-level neural networks API, written in Python and capable of running on top of [TensorFlow](https://github.com/tensorflow/tensorflow), [CNTK](https://github.com/Microsoft/cntk), or [Theano](https://github.com/Theano/Theano). It was developed with a focus on enabling fast experimentation. *Being able to go from idea to result with the least possible delay is key to doing good research.*
+Keras is a high-level neural networks API, written in Python and capable of running on top of [TensorFlow](https://github.com/tensorflow/tensorflow), [CNTK](https://github.com/Microsoft/cntk), [Apache MXNet](https://github.com/apache/incubator-mxnet/), or [Theano](https://github.com/Theano/Theano). It was developed with a focus on enabling fast experimentation. *Being able to go from idea to result with the least possible delay is key to doing good research.*

 Use Keras if you need a deep learning library that:

@@ -117,6 +120,7 @@ Before installing Keras, please install one of its backend engines: TensorFlow,
 - [TensorFlow installation instructions](https://www.tensorflow.org/install/).
 - [Theano installation instructions](http://deeplearning.net/software/theano/install.html#install).
 - [CNTK installation instructions](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-cntk-on-your-machine).
+- [MXNet installation instructions](http://mxnet.incubator.apache.org/install/index.html).

 You may also consider installing the following **optional dependencies**:

@@ -155,7 +159,7 @@ sudo python setup.py install
 ------------------


-## Using a different backend than TensorFlow
+## Switching from TensorFlow to CNTK, MXNet or Theano

 By default, Keras will use TensorFlow as its tensor manipulation library. [Follow these instructions](https://keras.io/backend/) to configure the Keras backend.
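The backend switch described in the README change above relies on the standard Keras configuration. A minimal sketch of selecting the MXNet backend programmatically; the `KERAS_BACKEND` environment variable and `~/.keras/keras.json` are the standard Keras mechanisms, and `channels_first` is the data format this commit recommends for MXNet performance.

```python
# Minimal sketch: select the MXNet backend before importing Keras.
# Equivalent to setting "backend": "mxnet" in ~/.keras/keras.json.
import os
os.environ["KERAS_BACKEND"] = "mxnet"

from keras import backend as K

print(K.backend())            # expected: "mxnet"
# "image_data_format" can also be set in ~/.keras/keras.json;
# channels_first is the recommended format for MXNet performance.
print(K.image_data_format())
```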

benchmark/README.md

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
# Keras Benchmarks

## Overview
The benchmark module aims to provide a performance comparison across Keras backends using various models and
datasets on CPU, single-GPU, and multi-GPU machines.
Currently supported backends: TensorFlow, Apache MXNet.

## Setup
To install the MXNet backend, refer to
[Installation](https://github.com/awslabs/keras-apache-mxnet/wiki/Installation#1-install-keras-with-apache-mxnet-backend).

To switch between backends, refer to
[configure Keras backend](https://github.com/awslabs/keras-apache-mxnet/wiki/Installation#2-configure-keras-backend).

## CNN Benchmarks
We provide benchmark scripts to run on the CIFAR-10, ImageNet, and Synthetic (randomly generated) datasets.

### CIFAR-10 Dataset
The [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset has 60000 32x32 color images in 10 classes.
The [training scripts](https://github.com/awslabs/keras-apache-mxnet/blob/master/benchmark/image-classification/benchmark_resnet.py)
will automatically download the dataset; you need to provide the dataset name, ResNet version
(1 or 2), number of layers (20, 56, or 110), and number of GPUs to use.

Example Usage:

`python benchmark_resnet.py --dataset cifar10 --version 1 --layers 56 --gpus 4`

### ImageNet Dataset
First, download the ImageNet dataset from [here](http://image-net.org/download); there are 1.4 million images in total
with 1000 classes, each class in its own subfolder. In this script, each image is processed to size 256x256.

Since the ImageNet dataset is too large to fit into memory, there are two training modes for such data:
[`train_on_batch`](https://keras.io/models/sequential/#train_on_batch) and
[`fit_generator`](https://keras.io/models/sequential/#fit_generator);
we recommend `train_on_batch` since it is more efficient on multiple GPUs.
(Refer to the [Keras documentation](https://keras.io/getting-started/faq/#how-can-i-use-keras-with-datasets-that-dont-fit-in-memory)
and Keras issues [#9502](https://github.com/keras-team/keras/issues/9502),
[#9204](https://github.com/keras-team/keras/issues/9204), and [#9647](https://github.com/keras-team/keras/issues/9647).)
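A minimal sketch of the two training modes; the tiny stand-in model, directory path, and batch size below are placeholders (the benchmark itself trains ResNet50V1 through `benchmark_resnet.py`).

```python
# Sketch of the two Keras training modes above for data that does not fit in memory.
from keras.models import Sequential
from keras.layers import Conv2D, GlobalAveragePooling2D, Dense
from keras.preprocessing.image import ImageDataGenerator

# Tiny stand-in model; shape assumes channels_last (use (3, 256, 256) with channels_first).
model = Sequential([
    Conv2D(16, (3, 3), activation='relu', input_shape=(256, 256, 3)),
    GlobalAveragePooling2D(),
    Dense(1000, activation='softmax'),
])
model.compile(loss='categorical_crossentropy', optimizer='sgd')

# One subfolder per class, as described above; the path is a placeholder.
datagen = ImageDataGenerator(rescale=1. / 255)
train_gen = datagen.flow_from_directory('/data/imagenet/train/',
                                        target_size=(256, 256), batch_size=32)

# Mode 1: fit_generator -- Keras drives the generator.
model.fit_generator(train_gen, steps_per_epoch=len(train_gen), epochs=1)

# Mode 2: train_on_batch -- manual loop over batches (recommended above for multi-GPU runs).
for _ in range(len(train_gen)):
    x_batch, y_batch = next(train_gen)
    model.train_on_batch(x_batch, y_batch)
```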
Compared to CIFAR-10, you need to provide two additional parameters: the training mode and the path to the ImageNet dataset.

Example usage:

`python benchmark_resnet.py --dataset imagenet --version 1 --layers 56 --gpus 4 --train_mode train_on_batch --data_path home/ubuntu/imagenet/train/`

### Synthetic Dataset
We used benchmark scripts from the
[TensorFlow Benchmark](https://github.com/tensorflow/benchmarks/tree/keras-benchmarks/scripts/keras_benchmarks)
official repo, modified slightly for our use case.

Directly run the shell script to launch the benchmark; provide one of the configurations in config.json and whether
you want to benchmark inference speed (True or False).

Example Usage:

`sh run_<backend-type>_backend.sh gpu_config False`

### CNN Benchmark Results
Here we list the MXNet backend training speed on CIFAR-10, ImageNet, and Synthetic Data using the
ResNet50V1 model, on CPU and on 1, 4, and 8 GPUs, using AWS instances.
Hardware specifications of the instances can be found [here](https://aws.amazon.com/ec2/instance-types/).

For more detailed benchmark results, please refer to [CNN results](https://github.com/awslabs/keras-apache-mxnet/tree/keras2_mxnet_backend/benchmark/benchmark_result/CNN_result.md).

|||
| ------ | ------ |
| Keras Version | 2.1.5 |
| MXNet Version | 1.1.0 |
| Data Format | Channel first |

| Instance | GPU used | Package | CIFAR-10 | ImageNet | Synthetic Data |
| ------ | ------ | ------ | ------ | ------ | ------ |
| C5.18xLarge | 0 | mxnet-mkl | 87 | N/A | 9 |
| P3.8xLarge | 1 | mxnet-cu90 | N/A | 165 | 229 |
| P3.8xLarge | 4 | mxnet-cu90 | 1792 | 538 | 728 |
| P3.16xLarge | 8 | mxnet-cu90 | 1618 | 728 | 963 |

![MXNet backend training speed](https://github.com/roywei/keras/blob/benchmark_result/benchmark/benchmark_result/mxnet_backend_training_speed.png)

Note: the X-axis is the number of GPUs used; the Y-axis is training speed (images/second).

## RNN Benchmarks

We provide benchmark scripts to run on the Synthetic (randomly generated), Nietzsche, and WikiText-2 character-level datasets.

Directly run the shell script to launch the benchmark; provide one of the configurations in config.json and whether you want to benchmark inference speed (True or False).

Example Usage:

`sh run_<backend-type>_backend.sh gpu_config False`

### Synthetic Dataset

We used benchmark scripts from the [TensorFlow Benchmark](https://github.com/tensorflow/benchmarks/tree/keras-benchmarks/scripts/keras_benchmarks) official repo, modified slightly for our use case.

### Nietzsche Dataset

We used the official Keras LSTM example script [lstm_text_generation.py](https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py), modified slightly for our use case.

### WikiText-2 Dataset

We used the official WikiText-2 character-level dataset from this [link](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset).

The `lstm_text_generation_wikitext2.py` script uses a copy of the dataset hosted on an S3 bucket at this [link](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip) (the WikiText-2 raw character-level data).

### RNN Benchmark Results

Here we list the results on the Synthetic, Nietzsche, and WikiText-2 datasets using a Sequential model (LSTM) on Amazon AWS C5.xLarge (CPU) and P3.8xLarge (1 and 4 GPUs) instances with the MXNet backend. Batch size is 128. For more details about the instance configurations, please refer to [P3](https://aws.amazon.com/ec2/instance-types/p3/) and [C5](https://aws.amazon.com/ec2/instance-types/c5/).

| Instance | GPUs | Data Set | Speed/Epoch <br />(Lower is better) |
| ---------- | ---- | ---------- | ----------------------------------- |
| C5.xLarge | 0 | Synthetic | 91 sec - 2ms/step |
| P3.8xLarge | 1 | Synthetic | 13 sec - 264us/step |
| P3.8xLarge | 4 | Synthetic | 12 sec - 241us/step |
| C5.xLarge | 0 | Nietzsche | 352 sec - 2ms/step |
| P3.8xLarge | 1 | Nietzsche | 53 sec - 265us/step |
| P3.8xLarge | 4 | Nietzsche | 47 sec - 236us/step |
| C5.xLarge | 0 | WikiText-2 | 6410 sec - 2ms/step |
| P3.8xLarge | 1 | WikiText-2 | 882 sec - 264us/step |
| P3.8xLarge | 4 | WikiText-2 | 794 sec - 235us/step |

## Credits

Synthetic Data scripts modified from
[TensorFlow Benchmarks](https://github.com/tensorflow/benchmarks/tree/keras-benchmarks)

## Reference
[1] [TensorFlow Benchmarks](https://github.com/tensorflow/benchmarks/tree/keras-benchmarks)

benchmark/__init__.py

Whitespace-only changes.
benchmark/benchmark_result/CNN_result.md

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
# Detailed CNN Benchmark Results
## CIFAR-10 Dataset
### Configuration
|||
|---|---|
| Data Set | [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) |
| Keras Version | 2.1.5 |
| TensorFlow Version | 1.7.0 |
| MXNet Version | 1.1.0 |
| Training Method | [`fit`](https://keras.io/models/model/#fit) |
| Training Scripts | [Simple CNN Script](https://github.com/awslabs/keras-apache-mxnet/blob/master/examples/CIFAR-10_cnn.py), [ResNet Script](https://github.com/awslabs/keras-apache-mxnet/blob/master/benchmark/image-classification/benchmark_resnet.py) |

### Results

| Instance Type | GPU used | Model | Backend | Package | Batch Size | Data Format | Speed (images/s) |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| C5.xLarge | 0 | Simple CNN | MXNet | mxnet-mkl | 32 | channel last | 253 |
| C5.xLarge | 0 | Simple CNN | MXNet | mxnet-mkl | 32 | channel first | 223 |
| C5.xLarge | 0 | Simple CNN | TensorFlow | tensorflow | 32 | channel last | 309 |
| C5.xLarge | 0 | Simple CNN | TensorFlow | tensorflow | 32 | channel first | 101 |
| C5.18xLarge | 0 | Simple CNN | MXNet | mxnet-mkl | 32 | channel last | 845 |
| C5.18xLarge | 0 | Simple CNN | MXNet | mxnet-mkl | 32 | channel first | 936 |
| C5.18xLarge | 0 | ResNet50V1 | TensorFlow | tensorflow | 32 | channel last | 59 |
| C5.18xLarge | 0 | ResNet50V1 | TensorFlow | tensorflow | 32 | channel first | 41 |
| C5.18xLarge | 0 | ResNet50V1 | MXNet | mxnet-mkl | 32 | channel last | 48 |
| C5.18xLarge | 0 | ResNet50V1 | MXNet | mxnet-mkl | 32 | channel first | 87 |
| P3.8xLarge | 4 | ResNet50V1 | TensorFlow | tensorflow-gpu | 128 | channel last | 1020 |
| P3.8xLarge | 4 | ResNet50V1 | MXNet | mxnet-cu90 | 128 | channel first | 1792 |
| P3.8xLarge | 8 | ResNet50V1 | TensorFlow | tensorflow-gpu | 256 | channel last | 962 |
| P3.16xLarge | 8 | ResNet50V1 | MXNet | mxnet-cu90 | 256 | channel first | 1618 |

## ImageNet Dataset

### Configuration
|||
|---|---|
| Data Set | [ImageNet](http://image-net.org) |
| Model | ResNet50V1 |
| Keras Version | 2.1.3 |
| TensorFlow Version | 1.6.0rc1 |
| MXNet Version | 1.1.0 |
| Training Method | [`train_on_batch`](https://keras.io/models/sequential/#train_on_batch), [`fit_generator`](https://keras.io/models/sequential/#fit_generator) |
| Training Scripts | [ResNet Script](https://github.com/awslabs/keras-apache-mxnet/blob/master/benchmark/image-classification/benchmark_resnet.py) |

### Results

| Instance | GPU used | Backend | Package | Method | Batch Size | Data Format | Speed (images/s) |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| P3.8xLarge | 1 | TensorFlow | tensorflow-gpu | `train_on_batch` | 32 | channel last | 50 |
| P3.8xLarge | 1 | MXNet | mxnet-cu90 | `train_on_batch` | 32 | channel first | 165 |
| P3.8xLarge | 4 | TensorFlow | tensorflow-gpu | `train_on_batch` | 128 | channel last | 162 |
| P3.8xLarge | 4 | MXNet | mxnet-cu90 | `train_on_batch` | 128 | channel first | 538 |
| P3.16xLarge | 8 | TensorFlow | tensorflow-gpu | `train_on_batch` | 256 | channel last | 212 |
| P3.16xLarge | 8 | MXNet | mxnet-cu90 | `train_on_batch` | 256 | channel first | 728 |
| P3.8xLarge | 1 | TensorFlow | tensorflow-gpu | `fit_generator` | 32 | channel last | 53 |
| P3.8xLarge | 1 | MXNet | mxnet-cu90 | `fit_generator` | 32 | channel first | 73 |
| P3.8xLarge | 4 | TensorFlow | tensorflow-gpu | `fit_generator` | 128 | channel last | 173 |
| P3.8xLarge | 4 | MXNet | mxnet-cu90 | `fit_generator` | 128 | channel first | 197 |

## Synthetic Dataset

### Configuration
|||
|---|---|
| Data Set | Random 256x256 color images, 1000 classes |
| Model | ResNet50V1 |
| Keras Version | 2.1.3 |
| TensorFlow Version | 1.6.0rc1 |
| MXNet Version | 1.1.0 |
| Training Method | [`fit`](https://keras.io/models/model/#fit) |
| Training Scripts | [ResNet Script](https://github.com/awslabs/keras-apache-mxnet/tree/keras2_mxnet_backend/benchmark/synthetic) |

### Results

| Instance | GPU used | Backend | Package | Batch Size | Data Format | Speed (images/s) |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| C5.18xLarge | 0 | TensorFlow | tensorflow | 32 | channel first | 4 |
| C5.18xLarge | 0 | MXNet | mxnet-mkl | 32 | channel first | 9 |
| P3.8xLarge | 1 | TensorFlow | tensorflow-gpu | 32 | channel first | 198 |
| P3.8xLarge | 1 | MXNet | mxnet-cu90 | 32 | channel first | 229 |
| P3.8xLarge | 4 | TensorFlow | tensorflow-gpu | 128 | channel first | 448 |
| P3.8xLarge | 4 | MXNet | mxnet-cu90 | 128 | channel first | 728 |
| P3.16xLarge | 8 | TensorFlow | tensorflow-gpu | 256 | channel first | 346 |
| P3.16xLarge | 8 | MXNet | mxnet-cu90 | 256 | channel first | 963 |
| C5.18xLarge | 0 | TensorFlow | tensorflow | 32 | channel last | 4 |
| C5.18xLarge | 0 | MXNet | mxnet-mkl | 32 | channel last | 3 |
| P3.8xLarge | 1 | TensorFlow | tensorflow-gpu | 32 | channel last | 164 |
| P3.8xLarge | 1 | MXNet | mxnet-cu90 | 32 | channel last | 18 |
| P3.8xLarge | 4 | TensorFlow | tensorflow-gpu | 128 | channel last | 409 |
| P3.8xLarge | 4 | MXNet | mxnet-cu90 | 128 | channel last | 73 |
| P3.16xLarge | 8 | TensorFlow | tensorflow-gpu | 256 | channel last | 164 |
| P3.16xLarge | 8 | MXNet | mxnet-cu90 | 256 | channel last | 18 |
Binary image file (9.61 KB) — content not rendered.

benchmark/scripts/__init__.py

Whitespace-only changes.
