
AvrahamTsaban/torch


PyTorch MNIST & Fashion-MNIST Classifier

A PyTorch-based image classification project for training neural networks on MNIST and Fashion-MNIST datasets, with support for custom datasets.

Project Overview

This project implements several neural network architectures to classify handwritten digits (MNIST) and fashion items (Fashion-MNIST). It includes functionality to train models, evaluate them on standard test sets, and test on custom user-provided images.

Project Structure

.
├── main.py              # Main training and evaluation pipeline
├── models.py            # Neural network architecture definitions
├── datasets.py          # Dataset loading and preprocessing
├── custom_dataset.py    # Custom dataset class and image preprocessing
├── data/
│   ├── MNIST/          # MNIST dataset (auto-downloaded)
│   ├── FashionMNIST/   # Fashion-MNIST dataset (auto-downloaded)
│   ├── myNums/         # Custom handwritten digit images
│   └── myFashion/      # Custom fashion item images
└── examples/           # Example outputs and results

Features

  • Multiple Model Architectures: Three progressively larger neural networks

  • Dataset Support:

    • MNIST handwritten digits
    • Fashion-MNIST clothing items
    • Custom user-provided images via CustomDataset
  • Image Preprocessing: mnistify function converts any image to MNIST format (28×28 grayscale, inverted, autocontrasted)

  • Training Features:

    • Automatic device detection (CPU/CUDA)
    • AdamW optimizer with weight decay
    • Cross-entropy loss
    • Accuracy and loss tracking per epoch
  • Visualization: plot_results function generates accuracy graphs over training epochs
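The README does not show the signature of plot_results. The helper below is hypothetical (`plot_accuracy` is not from the repo). It is a minimal matplotlib sketch of an accuracy-per-epoch plot, assuming per-epoch accuracies are kept as plain lists.

```python
import matplotlib.pyplot as plt


def plot_accuracy(train_acc, val_acc, title):
    """Hypothetical helper: one figure with train and validation accuracy per epoch."""
    epochs = range(1, len(train_acc) + 1)  # epochs numbered from 1
    fig, ax = plt.subplots()
    ax.plot(epochs, train_acc, marker="o", label="train")
    ax.plot(epochs, val_acc, marker="o", label="validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("accuracy")
    ax.set_title(title)
    ax.legend()
    return fig
```

Call `plt.show()` afterwards to display the figure interactively. Use `fig.savefig(...)` to write it to disk instead.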

Quick Start

Installation

pip install torch torchvision pillow matplotlib

Running the Project

python main.py

This will:

  1. Train each model (Base_Model, Long_Model, Giant_Model) on each dataset (Fashion-MNIST, MNIST) for 20 epochs
  2. Evaluate on built-in test set (validation) and custom images
  3. Display accuracy plots for each model × dataset combination

Using Custom Images

Requirements:

  1. Place images in data/myNums/ or data/myFashion/
  2. Name files with the label as the first character (e.g., 3_handwritten.png for digit 3)
  3. Images must have a white or transparent background
  4. Supported format: PNG files
  5. The mnistify function will automatically preprocess them to MNIST format
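A small checker like the one below can catch naming mistakes before training. `check_custom_images` is hypothetical and not part of the repo. It assumes only the rules listed above: a `.png` extension and a digit as the first character of the filename.

```python
from pathlib import Path


def check_custom_images(folder):
    """Hypothetical helper: return (valid, problems) for the files in `folder`.

    valid    -- list of (filename, label) pairs
    problems -- list of human-readable messages
    """
    valid, problems = [], []
    for path in sorted(Path(folder).iterdir()):
        if not path.is_file():  # skip subfolders such as classified/
            continue
        if path.suffix.lower() != ".png":
            problems.append(f"{path.name}: not a PNG file")
        elif path.name[0] not in "0123456789":
            problems.append(f"{path.name}: first character is not a digit label")
        else:
            valid.append((path.name, int(path.name[0])))
    return valid, problems
```

For example, `check_custom_images("data/myNums")` lists each usable image with its label, plus anything that would be skipped or mislabeled.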

API Reference

Main Functions

deep_learning (main.py)

Trains and tests a model on the given datasets.

Parameters:

  • model: Model class to instantiate
  • train_d: Training dataset
  • test_d1: Built-in test dataset (validation)
  • test_d2: Custom test dataset (optional)
  • device: "cpu" or "cuda" (default: "cpu")
  • epochs: Number of training epochs (default: 5)

Returns: (train_results, test_results, test2_results)

Trains the model for one epoch.

Returns: Training accuracy for the epoch
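Here is a minimal sketch of one training epoch, assuming a standard PyTorch loop with the loss function and optimizer passed in by the caller. The name `train_one_epoch` and its argument order are hypothetical, not the repo's actual signature.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train_one_epoch(loader: DataLoader, model: nn.Module, loss_fn: nn.Module,
                    optimizer: torch.optim.Optimizer, device: str = "cpu") -> float:
    """Hypothetical helper: run one pass over `loader` and return training accuracy."""
    model.train()  # training mode (matters for dropout/batchnorm layers)
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)          # shape: (batch, 10)
        loss = loss_fn(logits, labels)  # e.g. nn.CrossEntropyLoss()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Count predictions made before this step's weight update
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total
```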

test (main.py)

Evaluates a model on a test dataset.

Parameters:

  • verbose: If True, saves classified images to {dataset.path}/classified/{predicted_label}/sample_{i}.png

Returns: Test accuracy
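An evaluation pass typically looks like the sketch below, under the same assumptions as a standard PyTorch loop. `evaluate` is a hypothetical name. Unlike the repo's test, it returns the mean loss alongside the accuracy, and it omits the verbose image-saving mode.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()  # gradients are not needed during evaluation
def evaluate(loader: DataLoader, model: nn.Module, loss_fn: nn.Module,
             device: str = "cpu") -> tuple[float, float]:
    """Hypothetical helper: return (accuracy, mean loss) over `loader`."""
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        # Assumes reduction="mean" (the default), so re-weight by batch size
        loss_sum += loss_fn(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total, loss_sum / total
```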

Custom Dataset

CustomDataset (custom_dataset.py)

A PyTorch Dataset for loading and preprocessing custom images.

Parameters:

  • path: Directory containing PNG images
  • transform: Optional torchvision transform
  • img_process: Preprocessing function (default: mnistify)

Image Naming Convention: First character of filename is the label (e.g., 7_shoe.png → label 7)
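A dataset following the parameters and naming convention above could look like this sketch. `CustomDatasetSketch` is hypothetical, not the repo's class. It assumes one PNG per sample and applies img_process before transform. Here img_process defaults to no preprocessing so the block stays self-contained, whereas the repo defaults it to mnistify.

```python
from pathlib import Path
from typing import Callable, Optional

from PIL import Image
from torch.utils.data import Dataset


class CustomDatasetSketch(Dataset):
    """Hypothetical re-implementation: one PNG per sample, label = first filename character."""

    def __init__(self, path, transform: Optional[Callable] = None,
                 img_process: Optional[Callable] = None):
        self.path = Path(path)
        self.files = sorted(self.path.glob("*.png"))
        self.transform = transform
        self.img_process = img_process  # the repo defaults this to mnistify

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        with Image.open(file) as img:
            image = img.copy()  # copy() loads pixel data before the file is closed
        if self.img_process is not None:
            image = self.img_process(image)
        if self.transform is not None:
            image = self.transform(image)
        label = int(file.name[0])  # e.g. 7_shoe.png -> 7
        return image, label
```

A typical call would be `CustomDatasetSketch("data/myNums", transform=torchvision.transforms.ToTensor(), img_process=mnistify)`.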

mnistify (custom_dataset.py)

Converts any PIL image to MNIST format:

  • Converts to 28×28 grayscale
  • Inverts colors (white background → black background)
  • Applies autocontrast
  • Crops to bounding box
  • Resamples using Lanczos interpolation
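The steps above can be approximated with Pillow's ImageOps module. This is not the repo's mnistify. The order of operations (crop before resizing), the handling of transparency, and the name `mnistify_sketch` are assumptions. The sketch also stretches the cropped box straight to 28×28, which distorts tall or wide strokes. The original MNIST preprocessing instead fit each digit into a 20×20 box, preserving aspect ratio, and centered it in the 28×28 frame.

```python
from PIL import Image, ImageOps


def mnistify_sketch(img: Image.Image) -> Image.Image:
    """Hypothetical approximation of mnistify: 28x28, grayscale, light strokes on black."""
    # Assumption: transparent pixels are background, so flatten onto white first
    rgba = img.convert("RGBA")
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    gray = Image.alpha_composite(background, rgba).convert("L")
    # Invert: white background -> black, dark strokes -> bright
    inverted = ImageOps.invert(gray)
    # Stretch contrast so the strokes span the full 0..255 range
    contrasted = ImageOps.autocontrast(inverted)
    # Crop to the bounding box of non-black pixels (None for an empty image)
    bbox = contrasted.getbbox()
    if bbox is not None:
        contrasted = contrasted.crop(bbox)
    # Resample to 28x28 with Lanczos interpolation
    return contrasted.resize((28, 28), Image.LANCZOS)
```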

Models

Long_Model and Giant_Model inherit from Base_Model. Every model uses a flatten layer followed by fully connected layers with ReLU activations, and each has a name attribute for identification.

  • Base_Model (name='Base'): 28×28 → 64 → 32 → 10
  • Long_Model (name='Large'): 28×28 → 1024 → 1024 → 1024 → 10
  • Giant_Model (name='Giant'): 28×28 → 16384 → 16384 → 8192 → 512 → 10

All models are also available as a list via models (exported from models.py).
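Counting parameters is a quick way to compare the three architectures. This sketch assumes each arrow in the list above is one nn.Linear layer with a bias; flatten and ReLU add no parameters. `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(sizes):
    # Each consecutive pair (n_in, n_out) is assumed to be an nn.Linear with bias:
    # n_in * n_out weights + n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


ARCHITECTURES = {
    "Base": [28 * 28, 64, 32, 10],
    "Large": [28 * 28, 1024, 1024, 1024, 10],
    "Giant": [28 * 28, 16384, 16384, 8192, 512, 10],
}

for name, sizes in ARCHITECTURES.items():
    n = mlp_param_count(sizes)
    # float32 weights take 4 bytes each
    print(f"{name}: {n:,} parameters, ~{n * 4 / 1e6:.1f} MB of float32 weights")
```

Under these assumptions the counts are:

  • Base: about 53 thousand parameters
  • Large: about 2.9 million
  • Giant: about 420 million

The Giant model's float32 weights alone take roughly 1.7 GB. Training it with AdamW also keeps gradients and two moment buffers per parameter, about four times that in total.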

Example Usage

Training a Custom Model

from models import Base_Model
from datasets import number_train, number_test
from main import deep_learning

train_acc, test_acc, _ = deep_learning(
    model=Base_Model,
    train_d=number_train,
    test_d1=number_test,
    device="cuda",
    epochs=10
)

Processing Custom Images

from custom_dataset import mnistify
from PIL import Image

with Image.open("my_digit.png") as img:
    processed = mnistify(img)
    processed.save("my_digit_mnist.png")

Using the Verbose Test Mode

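from main import test
from models import Base_Model
from datasets import my_number_test
from torch import nn
from torch.utils.data import DataLoader

# Assumption: Base_Model takes no constructor arguments.
# Train it (or load trained weights) first, otherwise predictions are random.
model = Base_Model()
loss_fn = nn.CrossEntropyLoss()

# This will save classified images to data/myNums/classified/{label}/
test(
    DataLoader(my_number_test, batch_size=32),
    model,
    loss_fn,
    device="cpu",
    verbose=True
)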

Available Datasets

From datasets.py:

Dataset Bundles

Datasets are grouped into Dataset_Bundle objects for easy iteration:

from datasets import sets  # list of Dataset_Bundle objects

for dataset in sets:
    print(dataset.name)        # e.g. 'Fashion', 'Number'
    print(dataset.train)       # training dataset
    print(dataset.validation)  # built-in test dataset
    print(dataset.test)        # custom test dataset

Available bundles in sets:

  • Fashion: fashion_train, fashion_test, my_fashion_test
  • Number: number_train, number_test, my_number_test
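The grouping described above can be modeled with a dataclass. `DatasetBundleSketch` and `summarize` are hypothetical stand-ins, not the repo's Dataset_Bundle. Only the attribute names (name, train, validation, test) come from this README.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DatasetBundleSketch:
    """Hypothetical stand-in for Dataset_Bundle: groups the splits of one dataset."""
    name: str                   # e.g. "Fashion" or "Number"
    train: Any                  # training dataset
    validation: Any             # built-in test dataset
    test: Optional[Any] = None  # custom test dataset, if any


def summarize(bundles):
    # Works with anything that supports len(), such as torch Datasets or lists
    return {b.name: (len(b.train), len(b.validation),
                     len(b.test) if b.test is not None else 0)
            for b in bundles}
```

With the real MNIST and Fashion-MNIST splits, the first two numbers would be 60,000 training and 10,000 test images for each bundle.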

License

GNU General Public License v3.0 - see LICENSE

Notes

  • The MNIST and Fashion-MNIST datasets are downloaded automatically on first run
  • Training uses AdamW optimizer (lr=1e-3, weight_decay=1e-3)
  • Batch size is set to 128
  • Models automatically move to available accelerator (CUDA if available)
  • Special thanks to Claude Sonnet for writing this file in a minute
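The settings listed above might be wired together roughly as follows. `make_training_setup` is a hypothetical helper. Only the values come from this README: batch size 128, AdamW with lr=1e-3 and weight_decay=1e-3, and CUDA when available.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

BATCH_SIZE = 128


def make_training_setup(model: nn.Module, train_set: Dataset):
    """Hypothetical helper: pick a device, move the model, build optimizer and loader."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
    loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    return device, model, optimizer, loader
```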
