A PyTorch-based image classification project for training neural networks on MNIST and Fashion-MNIST datasets, with support for custom datasets.
This project implements several neural network architectures to classify handwritten digits (MNIST) and fashion items (Fashion-MNIST). It includes functionality to train models, evaluate them on standard test sets, and test on custom user-provided images.
.
├── main.py # Main training and evaluation pipeline
├── models.py # Neural network architecture definitions
├── datasets.py # Dataset loading and preprocessing
├── custom_dataset.py # Custom dataset class and image preprocessing
├── data/
│ ├── MNIST/ # MNIST dataset (auto-downloaded)
│ ├── FashionMNIST/ # Fashion-MNIST dataset (auto-downloaded)
│ ├── myNums/ # Custom handwritten digit images
│ └── myFashion/ # Custom fashion item images
└── examples/ # Example outputs and results
- Multiple Model Architectures: Three progressively larger neural networks
  - Base_Model: Lightweight model (51,210 parameters)
  - Long_Model: Medium model (3,156,490 parameters)
  - Giant_Model: Large model (536,928,778 parameters)
- Dataset Support:
  - MNIST handwritten digits
  - Fashion-MNIST clothing items
  - Custom user-provided images via CustomDataset
- Image Preprocessing:
  - mnistify function converts any image to MNIST format (28×28 grayscale, inverted, autocontrasted)
- Training Features:
  - Automatic device detection (CPU/CUDA)
  - AdamW optimizer with weight decay
  - Cross-entropy loss
  - Accuracy and loss tracking per epoch
- Visualization:
  - plot_results function generates accuracy graphs over training epochs
Install the dependencies:
pip install torch torchvision pillow matplotlib
Run the full pipeline:
python main.py
This will:
- Train each model (Base_Model, Long_Model, Giant_Model) on each dataset (Fashion-MNIST, MNIST) for 20 epochs
- Evaluate on built-in test set (validation) and custom images
- Display accuracy plots for each model × dataset combination
Requirements for custom images:
- Place images in data/myNums/ or data/myFashion/
- Name files with the label as the first character (e.g., 3_handwritten.png for digit 3)
- Images must have a white or transparent background
- Supported format: PNG files
- The mnistify function will automatically preprocess them to MNIST format
The deep_learning function trains and tests a model on the given datasets.
Parameters:
- model: Model class to instantiate
- train_d: Training dataset
- test_d1: Built-in test dataset (validation)
- test_d2: Custom test dataset (optional)
- device: "cpu" or "cuda" (default: "cpu")
- epochs: Number of training epochs (default: 5)
Returns: (train_results, test_results, test2_results)
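The train-then-evaluate flow described above can be sketched roughly as follows. This is not the project's code: TinyMLP, run_epoch, and deep_learning_sketch are hypothetical names, random tensors stand in for MNIST, and returning one accuracy value per epoch is an assumption. The optimizer settings (AdamW, lr=1e-3, weight_decay=1e-3), the batch size of 128, and the cross-entropy loss are the ones this README states.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyMLP(nn.Module):
    """Hypothetical stand-in using the 28x28 -> 64 -> 32 -> 10 layout."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        return self.net(x)


def run_epoch(loader, model, loss_fn, device, optimizer=None):
    """One pass over loader; weights are updated only if an optimizer is given."""
    training = optimizer is not None
    model.train(training)
    correct, total = 0, 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            if training:
                loss = loss_fn(logits, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / total


def deep_learning_sketch(model_cls, train_d, test_d1, test_d2=None,
                         device="cpu", epochs=5):
    """Train for `epochs` epochs; return per-epoch accuracy lists (assumed shape)."""
    model = model_cls().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
    train_loader = DataLoader(train_d, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_d1, batch_size=128)
    test2_loader = DataLoader(test_d2, batch_size=128) if test_d2 is not None else None
    train_results, test_results, test2_results = [], [], []
    for _ in range(epochs):
        train_results.append(run_epoch(train_loader, model, loss_fn, device, optimizer))
        test_results.append(run_epoch(test_loader, model, loss_fn, device))
        if test2_loader is not None:
            test2_results.append(run_epoch(test2_loader, model, loss_fn, device))
    return train_results, test_results, test2_results


# Random tensors stand in for real image data so the sketch runs offline.
device = "cuda" if torch.cuda.is_available() else "cpu"
fake_train = TensorDataset(torch.rand(256, 1, 28, 28), torch.randint(0, 10, (256,)))
fake_test = TensorDataset(torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,)))
train_acc, test_acc, custom_acc = deep_learning_sketch(
    TinyMLP, fake_train, fake_test, device=device, epochs=2
)
```

Passing the model class rather than an instance, as the parameter list describes, means every call starts from freshly initialized weights.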
Trains the model for one epoch.
Returns: Training accuracy for the epoch
The test function evaluates the model on a test dataset.
Parameters:
- verbose: If True, saves classified images to {dataset.path}/classified/{predicted_label}/sample_{i}.png
Returns: Test accuracy
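As an illustration of the output layout that verbose mode is described as producing, the path for a single sample could be assembled like this. The helper names classified_path and ensure_parent are hypothetical and only mirror the {dataset.path}/classified/{predicted_label}/sample_{i}.png pattern; how the project builds these paths internally is not shown in this README.

```python
from pathlib import Path


def classified_path(dataset_path, predicted_label, i):
    """Build {dataset_path}/classified/{predicted_label}/sample_{i}.png."""
    return Path(dataset_path) / "classified" / str(predicted_label) / f"sample_{i}.png"


def ensure_parent(path):
    """Create the per-label directory if it does not exist yet, then return path."""
    path.parent.mkdir(parents=True, exist_ok=True)
    return path


target = classified_path("data/myNums", 7, 0)
print(target.as_posix())  # data/myNums/classified/7/sample_0.png
```

Grouping outputs by predicted label makes misclassifications easy to spot: any file in folder 7 that is not a 7 is an error.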
CustomDataset is a PyTorch Dataset for loading and preprocessing custom images.
Parameters:
- path: Directory containing PNG images
- transform: Optional torchvision transform
- img_process: Preprocessing function (default: mnistify)
Image Naming Convention: First character of filename is the label (e.g., 7_shoe.png → label 7)
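The naming convention can be illustrated with a small parser. This is a sketch, not the CustomDataset implementation: label_from_filename and list_labelled_pngs are hypothetical helpers, and the check for a single ASCII digit assumes the ten-class labels 0 to 9 used by MNIST and Fashion-MNIST.

```python
from pathlib import Path


def label_from_filename(filename):
    """Return the integer label encoded as the first character of the file name."""
    first = Path(filename).name[0]
    if first not in "0123456789":
        raise ValueError(f"{filename!r} does not start with a digit label")
    return int(first)


def list_labelled_pngs(directory):
    """Pair every PNG in `directory` with its label (sorted for determinism)."""
    return [(p, label_from_filename(p)) for p in sorted(Path(directory).glob("*.png"))]


print(label_from_filename("7_shoe.png"))                     # 7
print(label_from_filename("data/myNums/3_handwritten.png"))  # 3
```

Failing loudly on a file without a leading digit is a deliberate choice here, since a silently mislabeled image would distort the reported accuracy.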
The mnistify function converts any PIL image to MNIST format:
- Converts to 28×28 grayscale
- Inverts colors (white background → black background)
- Applies autocontrast
- Crops to bounding box
- Resamples using Lanczos interpolation
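The steps listed above could be combined roughly as in the following Pillow sketch. It is not the project's mnistify: the name mnistify_sketch, the order of operations, and the flattening of transparency onto white are assumptions made for illustration.

```python
from PIL import Image, ImageOps


def mnistify_sketch(img, size=28):
    """Approximate an MNIST-style image: light strokes on black, size x size."""
    # Flatten transparency onto white so transparent backgrounds act like white ones.
    rgba = img.convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    gray = Image.alpha_composite(white, rgba).convert("L")

    inverted = ImageOps.invert(gray)             # white background becomes black
    stretched = ImageOps.autocontrast(inverted)  # spread intensities over 0..255

    bbox = stretched.getbbox()                   # bounding box of non-black pixels
    if bbox is not None:
        stretched = stretched.crop(bbox)

    return stretched.resize((size, size), Image.LANCZOS)


# Synthetic input: a dark square on a white canvas.
canvas = Image.new("RGB", (100, 80), "white")
canvas.paste((0, 0, 0), (30, 20, 60, 50))
result = mnistify_sketch(canvas)
print(result.size, result.mode)  # (28, 28) L
```

Note that resizing the cropped region directly to 28×28 stretches non-square content; whether the real function preserves aspect ratio or pads the image is not stated here.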
All models inherit from Base_Model and use a flatten layer followed by fully connected layers with ReLU activation. Each model has a name attribute for identification.
- Base_Model (name='Base'): 28×28 → 64 → 32 → 10
- Long_Model (name='Large'): 28×28 → 1024 → 1024 → 1024 → 10
- Giant_Model (name='Giant'): 28×28 → 16384 → 16384 → 8192 → 512 → 10
All models are also available as a list via models (exported from models.py).
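For a stack of fully connected layers, each layer with n_in inputs and n_out outputs contributes n_in × n_out weights plus n_out biases. The sketch below applies that arithmetic to the layer widths listed above, assuming plain linear layers with biases and no other parameters; count_mlp_parameters is a hypothetical helper, not part of models.py. For the 28×28 → 64 → 32 → 10 layout this gives 52,650, which differs from the parameter count quoted in the feature list, so the listed widths or the quoted counts may not reflect the current model definitions.

```python
def count_mlp_parameters(widths):
    """Count weights and biases of a fully connected stack with the given widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths, widths[1:]))


layouts = {
    "Base": [28 * 28, 64, 32, 10],
    "Large": [28 * 28, 1024, 1024, 1024, 10],
    "Giant": [28 * 28, 16384, 16384, 8192, 512, 10],
}
for name, widths in layouts.items():
    print(name, count_mlp_parameters(widths))
```

On an instantiated PyTorch model, the equivalent check is sum(p.numel() for p in model.parameters()).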
# Train a single model on MNIST
from models import Base_Model
from datasets import number_train, number_test
from main import deep_learning
train_acc, test_acc, _ = deep_learning(
    model=Base_Model,
    train_d=number_train,
    test_d1=number_test,
    device="cuda",  # use "cpu" if CUDA is unavailable
    epochs=10
)

# Convert a single image to MNIST format
from custom_dataset import mnistify
from PIL import Image
with Image.open("my_digit.png") as img:
    processed = mnistify(img)
    processed.save("my_digit_mnist.png")

# Evaluate on the custom digit images and save the classified samples
from main import test
from datasets import my_number_test
from torch.utils.data import DataLoader
# model is assumed to be an already trained model instance and loss_fn a loss
# such as torch.nn.CrossEntropyLoss(); neither is defined in this snippet.
# This will save classified images to data/myNums/classified/{label}/
test(
    DataLoader(my_number_test, batch_size=32),
    model,
    loss_fn,
    device="cpu",
    verbose=True
)
From datasets.py:
- fashion_train / fashion_test: FashionMNIST
- number_train / number_test: MNIST
- my_fashion_test: Custom fashion images
- my_number_test: Custom digit images
Datasets are grouped into Dataset_Bundle objects for easy iteration:
from datasets import sets  # list of Dataset_Bundle objects
for dataset in sets:
    print(dataset.name)        # e.g. 'Fashion', 'Number'
    print(dataset.train)       # training dataset
    print(dataset.validation)  # built-in test dataset
    print(dataset.test)        # custom test dataset
Available bundles in sets:
- Fashion: fashion_train, fashion_test, my_fashion_test
- Number: number_train, number_test, my_number_test
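A bundle of this shape could be modelled with a small dataclass. The sketch below is a guess at the structure based only on the four attributes used above (name, train, validation, test); DatasetBundleSketch and sets_sketch are hypothetical names, and plain lists stand in for the real torch datasets.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class DatasetBundleSketch:
    """Hypothetical stand-in for Dataset_Bundle: one name plus three splits."""
    name: str
    train: Any       # training dataset
    validation: Any  # built-in test dataset
    test: Any        # custom test dataset


# Lists of integers stand in for real torch datasets here.
sets_sketch = [
    DatasetBundleSketch("Fashion", train=[0, 1, 2], validation=[3], test=[4]),
    DatasetBundleSketch("Number", train=[5, 6, 7], validation=[8], test=[9]),
]

for bundle in sets_sketch:
    print(bundle.name, len(bundle.train), len(bundle.validation), len(bundle.test))
```

Keeping the three splits together under one name lets the training loop iterate over datasets without tracking parallel lists.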
GNU General Public License v3.0 - see LICENSE
- The MNIST and Fashion-MNIST datasets are automatically downloaded on first run
- Training uses AdamW optimizer (lr=1e-3, weight_decay=1e-3)
- Batch size is set to 128
- Models automatically move to available accelerator (CUDA if available)
- Special thanks to Claude Sonnet for writing this file in a minute