Skip to content

Commit 2795707

Browse files
authored
Merge pull request #97 from delta-rs/62-fix-unit-tests-in-mnistrs-cifar10rs-and-imagenet_v2rs
resolves #62 add mnist and cifar10 tests
2 parents 19f27b4 + 40e94f3 commit 2795707

File tree

4 files changed

+251
-193
lines changed

4 files changed

+251
-193
lines changed

Cargo.toml

-1
Original file line number | Diff line number | Diff line change
@@ -9,4 +9,3 @@ resolver = "2"
99

1010
[workspace.dependencies]
1111
tokio = { version = "1.32.0", features = ["full"] }
12-
ndarray = "0.15"

delta/src/dataset/image/cifar10.rs

+105-76
Original file line number | Diff line number | Diff line change
@@ -27,20 +27,20 @@
2727
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2828
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929
30-
use crate::common::{Tensor};
30+
use crate::common::Tensor;
31+
use crate::dataset::base::{Dataset, ImageDatasetOps};
32+
use crate::get_workspace_dir;
3133
use flate2::read::GzDecoder;
3234
use log::debug;
35+
use ndarray::{IxDyn, Shape};
3336
use std::collections::HashSet;
3437
use std::fs;
3538
use std::fs::File;
3639
use std::future::Future;
3740
use std::io::Read;
3841
use std::path::Path;
3942
use std::pin::Pin;
40-
use ndarray::{IxDyn, Shape};
4143
use tar::Archive;
42-
use crate::dataset::base::{Dataset, ImageDatasetOps};
43-
use crate::get_workspace_dir;
4444

4545
/// A struct representing the CIFAR10 dataset.
4646
pub struct Cifar10Dataset {
@@ -155,7 +155,11 @@ impl Cifar10Dataset {
155155

156156
for &file in files {
157157
let (img, lbl) = Self::parse_file(
158-
&format!("{}/.cache/dataset/cifar10/cifar-10-batches-bin/{}", env!("CARGO_MANIFEST_DIR"), file),
158+
&format!(
159+
"{}/.cache/dataset/cifar10/{}",
160+
get_workspace_dir().display(),
161+
file
162+
),
159163
total_examples / files.len(),
160164
);
161165
images.extend(img);
@@ -172,7 +176,10 @@ impl Cifar10Dataset {
172176
3,
173177
])),
174178
),
175-
Tensor::new(labels, Shape::from(IxDyn(&[total_examples, Self::CIFAR10_NUM_CLASSES]))),
179+
Tensor::new(
180+
labels,
181+
Shape::from(IxDyn(&[total_examples, Self::CIFAR10_NUM_CLASSES])),
182+
),
176183
)
177184
}
178185

@@ -397,77 +404,99 @@ impl ImageDatasetOps for Cifar10Dataset {
397404
Self {
398405
train: self.train.clone(),
399406
test: self.test.clone(),
400-
val: self.val.clone()
407+
val: self.val.clone(),
401408
}
402409
}
403410
}
404411

405-
// #[cfg(test)]
406-
// mod tests {
407-
// use super::*;
408-
// use serial_test::serial;
409-
// use tokio::runtime::Runtime;
410-
//
411-
// fn setup() {
412-
// let workspace_dir = get_workspace_dir();
413-
// let cache_path = format!("{}/.cache/dataset/cifar10", workspace_dir.display());
414-
// if Path::new(&cache_path).exists() {
415-
// fs::remove_dir_all(&cache_path).expect("Failed to delete cache directory");
416-
// }
417-
// }
418-
//
419-
// #[test]
420-
// #[serial]
421-
// fn test_download_and_extract() {
422-
// setup();
423-
// let rt = Runtime::new().unwrap();
424-
// rt.block_on(async {
425-
// Cifar10Dataset::download_and_extract().await;
426-
// let workspace_dir = get_workspace_dir();
427-
// let cache_path = format!("{}/.cache/dataset/cifar10/cifar-10-binary", workspace_dir.display());
428-
// assert!(Path::new(&cache_path).exists(), "CIFAR-10 dataset should be downloaded and extracted");
429-
// });
430-
// }
431-
//
432-
// #[test]
433-
// #[serial]
434-
// fn test_parse_file() {
435-
// // Ensure the dataset is downloaded before parsing
436-
// test_download_and_extract();
437-
//
438-
// let (images, labels) = Cifar10Dataset::parse_file("path/to/data_batch_1.bin", 10000);
439-
// assert_eq!(images.len(), 10000 * 32 * 32 * 3, "Images should have the correct length");
440-
// assert_eq!(labels.len(), 10000 * 10, "Labels should have the correct length");
441-
// }
442-
//
443-
// #[test]
444-
// #[serial]
445-
// fn test_load_data() {
446-
// // Ensure the dataset is downloaded before loading data
447-
// test_download_and_extract();
448-
//
449-
// let dataset = Cifar10Dataset::load_data(&["data_batch_1.bin"], 10000);
450-
// assert_eq!(dataset.inputs.shape(), &[10000, 32, 32, 3], "Dataset inputs should have the correct shape");
451-
// assert_eq!(dataset.labels.shape(), &[10000, 10], "Dataset labels should have the correct shape");
452-
// }
453-
//
454-
// #[test]
455-
// #[serial]
456-
// fn test_load_train() {
457-
// let rt = Runtime::new().unwrap();
458-
// rt.block_on(async {
459-
// let dataset = Cifar10Dataset::load_train().await;
460-
// assert!(dataset.train.is_some(), "Training dataset should be loaded");
461-
// });
462-
// }
463-
//
464-
// #[test]
465-
// #[serial]
466-
// fn test_load_test() {
467-
// let rt = Runtime::new().unwrap();
468-
// rt.block_on(async {
469-
// let dataset = Cifar10Dataset::load_test().await;
470-
// assert!(dataset.test.is_some(), "Test dataset should be loaded");
471-
// });
472-
// }
473-
// }
412+
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Dimension;
    use serial_test::serial;

    /// Returns the location of `data_batch_1.bin` inside the workspace cache,
    /// formatted as a string.
    fn first_batch_path() -> String {
        format!(
            "{}/.cache/dataset/cifar10/data_batch_1.bin",
            get_workspace_dir().display()
        )
    }

    /// Deletes the CIFAR-10 cache directory when it exists, so that the
    /// download test starts without previously cached files.
    fn setup() {
        let cache_dir = format!("{}/.cache/dataset/cifar10", get_workspace_dir().display());
        let cache_dir = Path::new(&cache_dir);
        if !cache_dir.exists() {
            // Nothing cached yet, so there is nothing to clean up.
            return;
        }
        fs::remove_dir_all(cache_dir).expect("Failed to delete cache directory");
    }

    /// After clearing the cache and running `download_and_extract`, the first
    /// data batch must be present on disk.
    #[tokio::test]
    #[serial]
    async fn test_download_and_extract() {
        setup();
        Cifar10Dataset::download_and_extract().await;

        let batch_file = first_batch_path();
        assert!(
            Path::new(&batch_file).exists(),
            "CIFAR-10 dataset should be downloaded and extracted"
        );
    }

    /// Parsing the first batch with an example count of 10000 must yield
    /// image and label buffers of the expected lengths.
    #[test]
    #[serial]
    fn test_parse_file() {
        // Run the download test first so the batch file exists before parsing.
        test_download_and_extract();

        let batch_file = first_batch_path();
        let (pixel_values, label_values) = Cifar10Dataset::parse_file(&batch_file, 10000);

        assert_eq!(
            pixel_values.len(),
            10000 * 32 * 32 * 3,
            "Images should have the correct length"
        );
        assert_eq!(
            label_values.len(),
            10000 * 10,
            "Labels should have the correct length"
        );
    }

    /// Loading the first batch must produce input and label tensors with the
    /// expected shapes.
    #[test]
    #[serial]
    fn test_load_data() {
        // Run the download test first so the batch file exists before loading.
        test_download_and_extract();

        let dataset = Cifar10Dataset::load_data(&["data_batch_1.bin"], 10000);

        // Flatten each shape into a plain vector so it can be compared
        // against the expected dimensions.
        let input_dims = dataset.inputs.shape().raw_dim().as_array_view().to_vec();
        assert_eq!(
            input_dims,
            &[10000, 32, 32, 3],
            "Dataset inputs should have the correct shape"
        );

        let label_dims = dataset.labels.shape().raw_dim().as_array_view().to_vec();
        assert_eq!(
            label_dims,
            &[10000, 10],
            "Dataset labels should have the correct shape"
        );
    }

    /// `load_train` must populate the `train` split.
    #[tokio::test]
    #[serial]
    async fn test_load_train() {
        let loaded = Cifar10Dataset::load_train().await;
        assert!(loaded.train.is_some(), "Training dataset should be loaded");
    }

    /// `load_test` must populate the `test` split.
    #[tokio::test]
    #[serial]
    async fn test_load_test() {
        let loaded = Cifar10Dataset::load_test().await;
        assert!(loaded.test.is_some(), "Test dataset should be loaded");
    }
}

0 commit comments

Comments
 (0)