---
license: apache-2.0
datasets:
- ylecun/mnist
- uoft-cs/cifar10
- uoft-cs/cifar100
language:
- en
metrics:
- accuracy
pipeline_tag: text-to-image
tags:
- diffusion
- unet
- res
---
# Diffusion Model Sampler
An implementation of a diffusion model sampler using a UNet transformer to generate handwritten digit samples.
Explore the docs » · View Demo · Report Bug · Request Feature
## Table of Contents

- [About The Project](#about-the-project)
- [Built With](#built-with)
- [Getting Started](#getting-started)
  - [Prerequisites](#prerequisites)
  - [Installation](#installation)
- [Usage](#usage)
## About The Project
Diffusion models have shown great promise in generating high-quality samples across a variety of domains. This project uses a UNet transformer-based diffusion model to generate samples of handwritten digits; a sketch of the pipeline follows the list below. The process involves:
- Setting up the model and loading pre-trained weights.
- Generating samples for each digit.
- Creating a GIF to visualize the generated samples.
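A minimal sketch of that pipeline is shown below, assuming a DDPM-style linear noise schedule and a UNet that takes a noisy batch and a timestep and predicts the added noise. The `UNet` class, checkpoint file name, schedule constants, and model signature are illustrative assumptions rather than this repository's exact API:

```python
import torch
from PIL import Image

# Assumed DDPM hyperparameters (not taken from the repository).
T = 1000                                    # number of diffusion steps
betas = torch.linspace(1e-4, 0.02, T)       # linear noise schedule
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def sample(model, shape=(1, 1, 28, 28), device="cpu"):
    """Run the DDPM reverse process and keep intermediate frames for a GIF."""
    x = torch.randn(shape, device=device)
    frames = []
    for t in reversed(range(T)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)              # assumed signature: returns predicted noise
        coef = (1.0 - alphas[t]) / torch.sqrt(1.0 - alphas_bar[t])
        mean = (x - coef * eps) / torch.sqrt(alphas[t])
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + torch.sqrt(betas[t]) * noise
        if t % 100 == 0:                     # keep a frame every 100 steps
            img = ((x[0, 0].clamp(-1, 1) + 1) * 127.5).to(torch.uint8).cpu().numpy()
            frames.append(Image.fromarray(img, mode="L"))
    return x, frames

# Hypothetical usage: class name, checkpoint path, and output file are placeholders.
# model = UNet()
# model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))
# _, frames = sample(model)
# frames[0].save("samples.gif", save_all=True, append_images=frames[1:], duration=200)
```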
## Built With
- PyTorch / torchvision
- NumPy
- Pillow
- Matplotlib
## Getting Started
To get a local copy up and running, follow these steps.
### Prerequisites
Ensure you have the following prerequisites installed:
- Python 3.8 or higher
- CUDA-enabled GPU (optional but recommended)
- The following Python libraries:
  - torch
  - torchvision
  - numpy
  - Pillow
  - matplotlib
### Installation
- Clone the repository:

```sh
git clone https://github.com/Yavuzhan-Baykara/Stable-Diffusion.git
cd Stable-Diffusion
```
- Install the required Python libraries:

```sh
pip install torch torchvision numpy Pillow matplotlib
```
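Since a CUDA-enabled GPU is listed as optional above, you may want to verify that the installed PyTorch build can see it. This one-liner is a generic check, not a step from the repository's own instructions:

```sh
python -c "import torch; print(torch.cuda.is_available())"
```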
## Usage
To train the UNet transformer with different datasets and samplers, use the following command:

```sh
python train.py <dataset> <sampler> <epoch> <batch_size>
```
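For example, a run on MNIST for 50 epochs with a batch size of 128 might look like the line below. The accepted `<dataset>` and `<sampler>` strings are assumptions here (`mnist` mirrors the datasets listed in the metadata and `ddpm` is a common sampler name), so check `train.py` for the exact values it expects:

```sh
python train.py mnist ddpm 50 128
```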