A web application that generates custom Pokemon based on user-selected types and other parameters, using a conditional GAN (Generative Adversarial Network) trained on Pokemon images and RandomForestRegressor models trained on Pokemon stats.
Deployed at: https://original-pokemon-generator-project.fly.dev/
When you open the website, it may take a moment to load because the app needs to cold start. Fly.io runs apps on virtual machines with pay-as-you-go pricing, so I configured the machines to spin down automatically when idle to save money. They spin back up automatically on the next request, which can take around 15 seconds.
One commit in this project contains a large amount of refactored and added code. It was the result of a concentrated four-day effort to restructure and optimize the codebase so it could be deployed to a live website. I expected to make only a few changes and configuration tweaks, but I got so into the refactoring that I made many, many changes that I didn't commit until the end of my spree. While I typically favor smaller, more frequent commits for easier code review and tracking, I chose to consolidate this particular refactor into a single commit once the entire module was stable and tested, before pushing it to the live website.
- Generate custom Pokemon images based on type, height, weight, generation, and legendary status
- View predicted stats based on Pokemon types
- Responsive web interface
- Docker support for easy deployment
- PostgreSQL database for storing generated images
- Python 3.9+: Core programming language
- Flask: Web framework for the API
- SQLAlchemy: ORM for database operations
- PostgreSQL: Database for storing generated images
- PyTorch: Deep learning framework for the GAN model
- scikit-learn: For training stat prediction models
- pandas & numpy: Data processing and manipulation
- HTML/CSS/JavaScript: Frontend interface
- Tailwind CSS: For frontend styling
- Docker & Docker Compose: Containerization
- Gunicorn: WSGI HTTP Server
- Fly.io: Cloud deployment platform
The Pokemon Generator uses a conditional GAN architecture specifically designed to generate 256x256 pixel Pokemon images based on type and other attributes.
The generator takes a random noise vector (z_dim=100) and a condition vector as input and produces a 256x256 RGB image:
- Condition Processing: The condition vector (containing type, height, weight, generation, and legendary status) is processed through a fully connected layer to create a 40-dimensional embedding.
- Initial Projection: The noise vector and processed condition are concatenated and projected to a 4×4×1024 feature map using a fully connected layer.
- Upsampling Layers: Six transposed convolution layers progressively upsample the feature map:
  - 4×4×1024 → 8×8×512
  - 8×8×512 → 16×16×256
  - 16×16×256 → 32×32×128
  - 32×32×128 → 64×64×64
  - 64×64×64 → 128×128×32
  - 128×128×32 → 256×256×3
- Activation: Each layer except the last uses BatchNorm and ReLU activation; the final layer uses a Tanh activation to produce pixel values in the range [-1, 1].
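A minimal PyTorch sketch of the generator described above. It follows the listed layer sizes, but several details are assumptions rather than the project's code: the raw condition size (`cond_dim=40`), the 4×4 kernels with stride 2 and padding 1 (a common DCGAN choice that exactly doubles the spatial size), and BatchNorm after the initial projection.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Conditional generator: (noise, condition) -> 256x256 RGB image in [-1, 1]."""

    def __init__(self, z_dim: int = 100, cond_dim: int = 40, embed_dim: int = 40):
        super().__init__()
        # Condition vector -> 40-dimensional embedding (cond_dim is an assumed input size).
        self.embed = nn.Linear(cond_dim, embed_dim)
        # Concatenated noise + embedding -> 4x4x1024 feature map.
        self.project = nn.Sequential(
            nn.Linear(z_dim + embed_dim, 4 * 4 * 1024),
            nn.BatchNorm1d(4 * 4 * 1024),
            nn.ReLU(inplace=True),
        )
        # Five upsampling blocks: 4x4x1024 -> 128x128x32 (each doubles height and width).
        channels = [1024, 512, 256, 128, 64, 32]
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Sixth transposed convolution: 128x128x32 -> 256x256x3, squashed to [-1, 1].
        layers += [nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1), nn.Tanh()]
        self.upsample = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        h = torch.cat([z, self.embed(cond)], dim=1)
        h = self.project(h).view(-1, 1024, 4, 4)
        return self.upsample(h)
```

For example, `Generator()(torch.randn(2, 100), torch.rand(2, 40))` returns a tensor of shape (2, 3, 256, 256). With kernel 4, stride 2 and padding 1, each transposed convolution maps a side length n to (n - 1) * 2 - 2 + 4 = 2n, which is why six of them take 4×4 to 256×256.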
The discriminator evaluates whether an image is real or generated:
- Condition Processing: The condition vector is processed through a fully connected layer and reshaped to a 256×256×1 feature map.
- Image Processing: The input image (256×256×3) is concatenated with the condition map (256×256×1) to form a 256×256×4 input.
- Downsampling Layers: Six strided convolutional layers progressively downsample the image, and a final convolution reduces the 4×4×1024 feature map to a single output:
  - 256×256×4 → 128×128×32
  - 128×128×32 → 64×64×64
  - 64×64×64 → 32×32×128
  - 32×32×128 → 16×16×256
  - 16×16×256 → 8×8×512
  - 8×8×512 → 4×4×1024
  - 4×4×1024 → 1×1×1
- Activation: LeakyReLU activation is used throughout, with BatchNorm for normalization and a final Sigmoid activation to output a probability.
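A matching PyTorch sketch of the discriminator. The 4×4 kernels (stride 2 and padding 1, which halve the spatial size), the LeakyReLU slope of 0.2, skipping BatchNorm on the first layer, and `cond_dim=40` are common DCGAN-style assumptions, not details confirmed by the project.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Conditional discriminator: (image, condition) -> probability the image is real."""

    def __init__(self, cond_dim: int = 40, image_size: int = 256):
        super().__init__()
        self.image_size = image_size
        # Condition vector -> one extra 256x256 "condition channel".
        self.cond_map = nn.Linear(cond_dim, image_size * image_size)
        # Six stride-2 convolutions: 256x256x4 -> 4x4x1024.
        channels = [4, 32, 64, 128, 256, 512, 1024]
        layers = []
        for i, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False))
            if i > 0:  # DCGAN convention: no BatchNorm on the first layer (assumption)
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final 4x4 convolution collapses 4x4x1024 to 1x1x1, then Sigmoid -> probability.
        layers += [nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, image: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        cond_map = self.cond_map(cond).view(-1, 1, self.image_size, self.image_size)
        x = torch.cat([image, cond_map], dim=1)  # 3 image channels + 1 condition channel
        return self.net(x).view(-1, 1)
```

Feeding the condition in as a fourth image channel lets every convolution see it directly, which is one straightforward way to condition a convolutional discriminator.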
The GAN is trained using a standard adversarial process with several optimizations for stability:
- Data Preparation: Pokemon images are loaded and paired with their type, height, weight, generation, and legendary status.
- Condition Vector: For each Pokemon, a condition vector is created by:
  - One-hot encoding the primary and secondary types
  - Normalizing height, weight, and generation
  - Adding a binary flag for legendary status
- Training Loop:
  - The discriminator is trained to classify real and fake images
  - The generator is trained to produce images that fool the discriminator
  - Label smoothing (0.9 for real labels) is used to improve stability
  - Adam optimizer with learning rate 2e-4 and betas (0.5, 0.999) is used
  - Cosine annealing learning rate scheduling is applied
- Checkpointing: Model checkpoints are saved after each epoch, allowing training to be resumed if interrupted.
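The condition-vector construction can be sketched in plain Python as follows. The type list, its ordering, and the min-max bounds are placeholders (the README does not say how values are normalized), and `build_condition_vector` is a hypothetical name rather than a function from the project.

```python
from typing import List, Optional, Tuple

# Assumed type vocabulary and ordering (hypothetical; the project's order may differ).
TYPES = ["Normal", "Fire", "Water", "Grass", "Electric", "Ice", "Fighting", "Poison",
         "Ground", "Flying", "Psychic", "Bug", "Rock", "Ghost", "Dragon", "Dark",
         "Steel", "Fairy"]

# Placeholder min-max bounds (hypothetical, not taken from the project).
HEIGHT_RANGE = (0.0, 20.0)      # metres
WEIGHT_RANGE = (0.0, 1000.0)    # kilograms
GENERATION_RANGE = (1, 9)


def one_hot(type_name: Optional[str]) -> List[float]:
    """One-hot vector over TYPES; all zeros when there is no (secondary) type."""
    vec = [0.0] * len(TYPES)
    if type_name is not None:
        vec[TYPES.index(type_name)] = 1.0
    return vec


def min_max(value: float, bounds: Tuple[float, float]) -> float:
    """Scale value into [0, 1] using the given bounds, clipping outliers."""
    lo, hi = bounds
    return min(max((value - lo) / (hi - lo), 0.0), 1.0)


def build_condition_vector(primary: str, secondary: Optional[str], height: float,
                           weight: float, generation: int, legendary: bool) -> List[float]:
    """Concatenate: primary one-hot, secondary one-hot, 3 normalized values, legendary flag."""
    return (one_hot(primary)
            + one_hot(secondary)
            + [min_max(height, HEIGHT_RANGE),
               min_max(weight, WEIGHT_RANGE),
               min_max(generation, GENERATION_RANGE)]
            + [1.0 if legendary else 0.0])
```

With 18 types this yields 18 + 18 + 3 + 1 = 40 values; for instance, a non-legendary Generation 1 Fire/Flying Pokemon gets 1.0 at the Fire slot of the first block and at the Flying slot of the second block.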
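One training iteration with the settings listed above (label smoothing at 0.9, Adam with learning rate 2e-4 and betas (0.5, 0.999), cosine annealing, per-epoch checkpoints) might look like the following PyTorch sketch. It assumes `generator(z, cond)` and `discriminator(image, cond)` call signatures and a discriminator that outputs probabilities (as a final Sigmoid does); the function names and checkpoint keys are hypothetical.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()  # expects probabilities, i.e. a discriminator ending in Sigmoid


def make_optimizers(generator, discriminator, epochs, lr=2e-4):
    """Adam (betas 0.5, 0.999) plus a cosine-annealing schedule for each network."""
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    # T_max=epochs assumes scheduler.step() is called once per epoch.
    sched_g = torch.optim.lr_scheduler.CosineAnnealingLR(opt_g, T_max=epochs)
    sched_d = torch.optim.lr_scheduler.CosineAnnealingLR(opt_d, T_max=epochs)
    return opt_g, opt_d, sched_g, sched_d


def train_step(generator, discriminator, opt_g, opt_d, real, cond, z_dim=100, real_label=0.9):
    """One adversarial update; returns (discriminator loss, generator loss)."""
    batch, device = real.size(0), real.device
    smoothed_real = torch.full((batch, 1), real_label, device=device)  # label smoothing
    zeros = torch.zeros(batch, 1, device=device)
    ones = torch.ones(batch, 1, device=device)

    fake = generator(torch.randn(batch, z_dim, device=device), cond)

    # Discriminator: push real images toward 0.9 and generated images toward 0.
    opt_d.zero_grad()
    loss_d = (bce(discriminator(real, cond), smoothed_real)
              + bce(discriminator(fake.detach(), cond), zeros))
    loss_d.backward()
    opt_d.step()

    # Generator: try to make the discriminator output 1 for generated images.
    opt_g.zero_grad()
    loss_g = bce(discriminator(fake, cond), ones)
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()


def save_checkpoint(path, epoch, generator, discriminator, opt_g, opt_d, sched_g, sched_d):
    """Save everything needed to resume training after an interruption."""
    torch.save({
        "epoch": epoch,
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
        "opt_g": opt_g.state_dict(), "opt_d": opt_d.state_dict(),
        "sched_g": sched_g.state_dict(), "sched_d": sched_d.state_dict(),
    }, path)


def load_checkpoint(path, generator, discriminator, opt_g, opt_d, sched_g, sched_d):
    """Restore all state and return the epoch to resume from."""
    ckpt = torch.load(path, map_location="cpu")
    generator.load_state_dict(ckpt["generator"])
    discriminator.load_state_dict(ckpt["discriminator"])
    opt_g.load_state_dict(ckpt["opt_g"])
    opt_d.load_state_dict(ckpt["opt_d"])
    sched_g.load_state_dict(ckpt["sched_g"])
    sched_d.load_state_dict(ckpt["sched_d"])
    return ckpt["epoch"] + 1
```

In an epoch loop you would call `train_step` for each batch, then step both schedulers and call `save_checkpoint` once per epoch; a resume option could call `load_checkpoint` and continue from the returned epoch. Smoothing is applied only to the discriminator's real targets, so the generator still aims for a full 1.0.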
# Start training from scratch
python -m src.training.train --epochs 1000 --batch_size 16 --image_size 256
# Resume training from a checkpoint
python -m src.training.train --epochs 1000 --batch_size 16 --image_size 256 --resume

The application uses scikit-learn's RandomForestRegressor models to predict Pokemon stats based on their types:
- For each stat (HP, Attack, Defense, Sp. Atk, Sp. Def, Speed), a separate RandomForestRegressor model is trained on type combinations.
- Training data is derived from the average stats of existing Pokemon with each type combination.
- The models predict reasonable stat values for new type combinations, which are then displayed in the UI.
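A sketch of the per-stat approach with scikit-learn and pandas, assuming a CSV with `Type 1`, `Type 2` and stat columns named as in the common Pokemon stats dataset (an assumption about the project's data file). Types are multi-hot encoded here; the feature encoding, the hyperparameters, and the function names are illustrative choices, not the project's implementation.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Column names are assumptions about the project's CSV layout.
STATS = ["HP", "Attack", "Defense", "Sp. Atk", "Sp. Def", "Speed"]
# Assumed type vocabulary; the project's ordering may differ.
TYPES = ["Normal", "Fire", "Water", "Grass", "Electric", "Ice", "Fighting", "Poison",
         "Ground", "Flying", "Psychic", "Bug", "Rock", "Ghost", "Dragon", "Dark",
         "Steel", "Fairy"]


def encode_types(type1, type2=None):
    """Multi-hot feature vector with a 1 for each of the Pokemon's types."""
    vec = np.zeros(len(TYPES))
    vec[TYPES.index(type1)] = 1.0
    if type2:  # None or "" means single-type
        vec[TYPES.index(type2)] = 1.0
    return vec


def train_stat_models(df: pd.DataFrame) -> dict:
    """Fit one RandomForestRegressor per stat on per-type-combination averages."""
    # Give single-type Pokemon an empty secondary type so groupby does not drop them.
    combos = (df.fillna({"Type 2": ""})
                .groupby(["Type 1", "Type 2"], as_index=False)[STATS]
                .mean())
    X = np.stack([encode_types(t1, t2) for t1, t2 in zip(combos["Type 1"], combos["Type 2"])])
    models = {}
    for stat in STATS:
        model = RandomForestRegressor(n_estimators=100, random_state=0)
        model.fit(X, combos[stat].to_numpy())
        models[stat] = model
    return models


def predict_stats(models: dict, type1: str, type2=None) -> dict:
    """Predict each stat for a (possibly unseen) type combination, rounded to an int."""
    x = encode_types(type1, type2).reshape(1, -1)
    return {stat: int(round(model.predict(x)[0])) for stat, model in models.items()}
```

Because the encoding is multi-hot, Fire/Flying and Flying/Fire map to the same features, and an unseen combination still gets a prediction from the trees that split on its individual types.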
The application uses environment variables for configuration. These must be set in a .env file in the root directory.
Here are all the required environment variables:
- FLASK_ENV: Application environment (development, production)
- FLASK_APP: Main application file (app.py)
- FLASK_DEBUG: Enable debug mode (1 for enabled, 0 for disabled)
- SECRET_KEY: Secret key for Flask sessions and security
- PORT: Port to run the application on (8080 recommended)

- POSTGRES_USER: PostgreSQL username
- POSTGRES_PASSWORD: PostgreSQL password
- POSTGRES_DB: PostgreSQL database name
- DATABASE_URL: Full PostgreSQL connection string (postgresql://username:password@host:port/dbname)

- APP_VERSION: Application version number
- MODEL_VERSION: ML model version number
- CHECKPOINT_PATH: Path to the model checkpoint file
- POKEMON_DATA_PATH: Path to the Pokemon data CSV file
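Since every variable in the list above is required, it can help to fail fast at startup when one is missing. The helper below is hypothetical (not part of the project); it only inspects `os.environ`, so the .env file must already have been loaded, for example through Docker Compose's `env_file` option or `load_dotenv()` from the python-dotenv package.

```python
import os
from typing import List, Mapping

# Names copied from the list of required environment variables.
REQUIRED_VARS = [
    "FLASK_ENV", "FLASK_APP", "FLASK_DEBUG", "SECRET_KEY", "PORT",
    "POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DB", "DATABASE_URL",
    "APP_VERSION", "MODEL_VERSION", "CHECKPOINT_PATH", "POKEMON_DATA_PATH",
]


def missing_env_vars(env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the required variables that are unset or empty, in list order."""
    return [name for name in REQUIRED_VARS if not env.get(name)]


def require_env(env: Mapping[str, str] = os.environ) -> None:
    """Raise a single error naming every missing variable."""
    missing = missing_env_vars(env)
    if missing:
        raise RuntimeError("Missing required environment variables: " + ", ".join(missing))
```

Calling `require_env()` before creating the Flask app reports all missing settings at once instead of failing later on the first one that is used.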
The easiest way to run the application is with Docker Compose:
# Clone the repository
git clone https://github.com/yourusername/pokemon-generator.git
cd pokemon-generator
# Create and configure .env file
# Start the application with Docker Compose
docker-compose up --build

The application will be available at http://localhost:5001
For local development without Docker, you can use the local SQLite mode:
# Clone the repository
git clone https://github.com/yourusername/pokemon-generator.git
cd pokemon-generator
# Set up a virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt
# Run the application with SQLite database
./run_local.py

The application will be available at http://localhost:5001

Example .env file:
# Flask settings
FLASK_ENV=development
FLASK_APP=app.py
FLASK_DEBUG=1
SECRET_KEY=your_secure_secret_key
PORT=5001
# PostgreSQL database settings
POSTGRES_USER=postgres
POSTGRES_PASSWORD=your_password
POSTGRES_DB=pokemon_dev
DATABASE_URL=postgresql://postgres:your_password@db:5432/pokemon_dev
# Application settings
APP_VERSION=1.2.0
MODEL_VERSION=2.0.0
CHECKPOINT_PATH=models/checkpoint.pth
POKEMON_DATA_PATH=data/Pokemon_stats.csv
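Note that `DATABASE_URL` repeats the `POSTGRES_*` values, so the two can drift apart when only one is edited. A hypothetical consistency check using only the standard library's `urllib.parse` might look like this (the function name is illustrative):

```python
from typing import List, Mapping
from urllib.parse import unquote, urlsplit


def database_url_mismatches(env: Mapping[str, str]) -> List[str]:
    """Return human-readable problems where DATABASE_URL disagrees with POSTGRES_* settings."""
    url = urlsplit(env["DATABASE_URL"])
    problems = []
    # SQLAlchemy 1.4+ expects "postgresql", optionally with a driver suffix like "+psycopg2".
    if url.scheme.split("+")[0] != "postgresql":
        problems.append(f"unexpected scheme {url.scheme!r}")
    # Credentials in a URL may be percent-encoded, so decode before comparing.
    user = unquote(url.username) if url.username else None
    password = unquote(url.password) if url.password else None
    if user != env.get("POSTGRES_USER"):
        problems.append("username does not match POSTGRES_USER")
    if password != env.get("POSTGRES_PASSWORD"):
        problems.append("password does not match POSTGRES_PASSWORD")
    if url.path.lstrip("/") != env.get("POSTGRES_DB"):
        problems.append("database name does not match POSTGRES_DB")
    return problems
```

In the example above the host is `db`, which presumably refers to the database service in Docker Compose; a URL pointing at that host will not resolve when the app runs outside Compose.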
If you prefer to run the application without Docker:
- Ensure you have Python 3.9+ installed
- Set up a virtual environment:
  python -m venv venv
  source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies:
  pip install -r requirements.txt
- Set up the database:
  python scripts/db_migrate.py
- Run the application:
  python main.py
The application uses PostgreSQL. The database schema is generated from a .template file using environment variables, so make sure the .env file is present and correctly configured.
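The template file and its placeholder syntax are not shown here. Assuming `$NAME` / `${NAME}` placeholders, the rendering step could be done with Python's standard `string.Template`, as in this hypothetical sketch (function names and paths are illustrative):

```python
import os
from pathlib import Path
from string import Template
from typing import Mapping


def render_schema(template_text: str, env: Mapping[str, str] = os.environ) -> str:
    """Replace $NAME / ${NAME} placeholders with environment values.

    Template.substitute raises KeyError when a placeholder has no value,
    which surfaces a misconfigured .env early. Write $$ for a literal $.
    """
    return Template(template_text).substitute(env)


def render_schema_file(template_path: str, output_path: str,
                       env: Mapping[str, str] = os.environ) -> None:
    """Read a schema template and write the rendered SQL (paths are hypothetical)."""
    text = Path(template_path).read_text(encoding="utf-8")
    Path(output_path).write_text(render_schema(text, env), encoding="utf-8")
```

Using `substitute` rather than `safe_substitute` is deliberate here: a missing variable becomes an immediate error instead of an unexpanded `$NAME` left in the SQL. SQL that itself contains `$` (for example PostgreSQL dollar quoting) would need to be escaped as `$$` in the template.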
- GAN Training Stability: Training GANs on high-resolution (256x256) images was challenging, and runs often errored out, so I had to constantly monitor training and tune hyperparameters and regularization.
- Experience: I hadn't had much experience with machine learning, much less with neural networks, so getting started was difficult. Documentation and AI tools helped a lot as I learned how to build and debug the GAN architecture.
- Stat Prediction Generalization: Predicting reasonable stats for new type combinations depends heavily on the RandomForest models, and getting accurate predictions was a struggle that required a lot of fine-tuning.
- Train For More Epochs: My model does not yet generate very accurate images because it underfits the data; I trained it for only around 115 epochs. This is largely due to the sheer size of the dataset and the significant noise in the images. Training for more epochs would let the model better capture the diversity and complexity of the example images, find patterns, and generate more accurate images.
- Better GAN Architectures: Experiment with more advanced GAN variants (e.g., StyleGAN, BigGAN) for improved image quality.
- User Customization: Gather data on more features and traits so users can further customize generated Pokemon or make them more specific (e.g., color palette, body shape, accessories).
- Interactive Training: Enable users to provide feedback on generated images to guide future training (reinforcement learning or active learning).
- API Endpoints: Add more RESTful API endpoints for programmatic access to image generation and stat prediction.
- Live Training Visualization: Integrate live charts for loss curves and sample generations during training.
- Pokemon data from PokeAPI
- PyTorch for deep learning framework
- Flask for web framework
- Tailwind for frontend design