
JAX-Diffusion is a high-performance image generation model implemented in JAX. It is an optimized diffusion model inspired by minDiffusion.


carrycooldude/JAX-Diffusion


JAX-Diffusion

JAX-Diffusion on Hugging Face 🤗

JAX-Diffusion is a project that implements diffusion models using JAX, a high-performance numerical computing library. Diffusion models are a class of generative models that learn to reverse a gradual noising process. They have gained popularity for their ability to generate high-quality samples.
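As a rough illustration of the forward (noising) process that a diffusion model learns to reverse, here is a minimal JAX sketch. It assumes a DDPM-style linear beta schedule from 1e-4 to 0.02 over T = 1000 steps (the defaults from the original DDPM paper). The names `betas`, `alpha_bars`, and `q_sample` are hypothetical and not taken from this repository.

```python
import jax
import jax.numpy as jnp

# Assumption: DDPM linear beta schedule (1e-4 to 0.02, T = 1000).
# JAX-Diffusion's actual schedule may differ.
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)  # alpha_bar_t = prod_{s <= t} alpha_s

def q_sample(x0, t, noise):
    """Sample x_t ~ q(x_t | x_0) in closed form:
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a_bar = alpha_bars[t]
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise

key = jax.random.PRNGKey(0)
x0 = jnp.ones((28, 28, 1))              # toy "image"
noise = jax.random.normal(key, x0.shape)
x_early = q_sample(x0, 0, noise)        # nearly identical to x0
x_late = q_sample(x0, T - 1, noise)     # nearly pure Gaussian noise
```

Because `alpha_bars` decays toward zero, a training example at any timestep can be produced in one step rather than by applying noise T times.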

🚀 Live Demo: Try it on Hugging Face Spaces


[Screenshot from 2025-04-06 14:32:42]


Features

  • Implementation of diffusion models in JAX.
  • High-performance and scalable computations.
  • Modular and extensible codebase.
  • Interactive Gradio app for easy experimentation.

Installation

  1. Clone the repository:

    git clone https://github.com/carrycooldude/JAX-Diffusion.git
    cd JAX-Diffusion
  2. Install dependencies:

    pip install -r requirements.txt

Usage

Train a diffusion model:

python train.py --config configs/default.yaml
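This README does not describe what `train.py` does internally. Assuming a DDPM-style objective (predict the noise added at a random timestep), one training step could look like the sketch below. `init_params`, `eps_model`, `loss_fn`, `train_step`, and `LR` are hypothetical, and the one-layer noise predictor is only a stand-in for a real U-Net.

```python
import jax
import jax.numpy as jnp

# Assumption: DDPM linear schedule, as in the forward-process sketch.
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alpha_bars = jnp.cumprod(1.0 - betas)
LR = 0.1  # hypothetical learning rate for plain SGD

def init_params(key, dim):
    # Hypothetical toy noise predictor: one linear layer on [x_t, t / T].
    w = 0.01 * jax.random.normal(key, (dim + 1, dim))
    return {"w": w, "b": jnp.zeros((dim,))}

def eps_model(params, x_t, t):
    # x_t: (batch, dim) flattened images; t: (batch,) integer timesteps.
    t_feat = (t / T)[:, None]
    h = jnp.concatenate([x_t, t_feat], axis=-1)
    return h @ params["w"] + params["b"]

def loss_fn(params, x0, key):
    # DDPM "simple" loss: MSE between true and predicted noise.
    k_t, k_n = jax.random.split(key)
    t = jax.random.randint(k_t, (x0.shape[0],), 0, T)
    noise = jax.random.normal(k_n, x0.shape)
    a_bar = alpha_bars[t][:, None]
    x_t = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
    return jnp.mean((eps_model(params, x_t, t) - noise) ** 2)

@jax.jit
def train_step(params, x0, key):
    loss, grads = jax.value_and_grad(loss_fn)(params, x0, key)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

key = jax.random.PRNGKey(0)
k_init, k_data, key = jax.random.split(key, 3)
params0 = init_params(k_init, dim=16)
x0 = jax.random.normal(k_data, (32, 16))  # toy batch of flattened "images"
params = params0
for _ in range(200):
    key, sub = jax.random.split(key)
    params, loss = train_step(params, x0, sub)
```

Plain SGD via `jax.tree_util.tree_map` keeps the sketch dependency-free. A real training setup would more likely use an optimizer library such as Optax.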

Generate samples:

python generate.py --model checkpoints/model.pth
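Sampling runs the noising process in reverse, one step at a time. The sketch below assumes standard DDPM ancestral sampling with the same linear schedule. This is not confirmed for `generate.py`. `p_sample_loop` and the `eps_model(x, t)` signature are hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumption: same DDPM linear schedule as in training (1e-4 to 0.02, T = 1000).
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)

def p_sample_loop(eps_model, shape, key):
    """DDPM ancestral sampling: start from x_T ~ N(0, I) and apply
    x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t) + sigma_t * z,
    with sigma_t = sqrt(beta_t) and no noise added on the final step."""
    key, k_init = jax.random.split(key)
    x = jax.random.normal(k_init, shape)

    def body(i, carry):
        x, key = carry
        t = T - 1 - i  # iterate t = T-1, ..., 0
        key, k_z = jax.random.split(key)
        eps = eps_model(x, t)
        mean = (x - betas[t] / jnp.sqrt(1.0 - alpha_bars[t]) * eps) / jnp.sqrt(alphas[t])
        sigma = jnp.where(t > 0, jnp.sqrt(betas[t]), 0.0)
        z = jax.random.normal(k_z, shape)
        return mean + sigma * z, key

    x, _ = jax.lax.fori_loop(0, T, body, (x, key))
    return x

# Sanity check with a hypothetical "oracle" predictor. If every training image
# were all zeros, then x_t = sqrt(1 - alpha_bar_t) * eps, so the ideal noise
# prediction is x_t / sqrt(1 - alpha_bar_t) and sampling should return ~0.
oracle = lambda x, t: x / jnp.sqrt(1.0 - alpha_bars[t])
sample = p_sample_loop(oracle, (4, 8), jax.random.PRNGKey(0))
```

`jax.lax.fori_loop` traces the loop body once instead of dispatching T separate Python iterations.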

Run the Gradio App:

python app.py

This will launch a Gradio interface where you can generate samples interactively.

Gradio App Preview

Input | Output
Enter a text prompt and set the number of diffusion steps | Generates an image using a simple JAX diffusion model
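The app lets the user choose the number of diffusion steps. One common way to support a user-chosen step count without retraining, used by DDIM-style samplers, is to denoise on an evenly spaced subsequence of the training timesteps. Whether this app does so is an assumption. The sketch below covers only the schedule selection and uses a hypothetical `timestep_subsequence` helper with T = 1000.

```python
import jax.numpy as jnp

def timestep_subsequence(num_steps, T=1000):
    """Evenly spaced, descending timesteps from T-1 down to 0 for a sampler
    that runs `num_steps` denoising steps instead of all T."""
    if not 1 <= num_steps <= T:
        raise ValueError(f"num_steps must be in [1, {T}], got {num_steps}")
    # Descending, because sampling runs from noisy (large t) to clean (t = 0).
    ts = jnp.linspace(T - 1, 0, num_steps)
    return jnp.round(ts).astype(jnp.int32)

ts = timestep_subsequence(50)  # 50 distinct timesteps, 999 down to 0
```

Fewer steps trade some sample quality for proportionally faster generation.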

Contributing

Contributions are welcome! Please follow these steps:

  1. Fork the repository.
  2. Create a new branch for your feature or bug fix.
  3. Submit a pull request with a clear description of your changes.

License

This project is licensed under the MIT License. See the LICENSE file for details.

Acknowledgments

  • JAX for providing the foundation for numerical computing.
  • Gradio for making it easy to build interactive demos.
  • The research community for advancements in diffusion models.
