---
license: apache-2.0
datasets:
- ylecun/mnist
- uoft-cs/cifar10
- uoft-cs/cifar100
language:
- en
metrics:
- accuracy
pipeline_tag: text-to-image
tags:
- diffusion
- unet
- res
---
<a id="readme-top"></a>
<!-- PROJECT SHIELDS -->
<!-- PROJECT LOGO -->
<br />
<div align="center">
<a href="https://github.com/Yavuzhan-Baykara/Stable-Diffusion">
</a>
<h3 align="center">Diffusion Model Sampler</h3>
<p align="center">
An implementation of a diffusion model sampler using a UNet transformer to generate handwritten digit samples.
<br />
<a href="https://github.com/Yavuzhan-Baykara/Stable-Diffusion"><strong>Explore the docs »</strong></a>
<br />
<br />
<a href="https://github.com/Yavuzhan-Baykara/Stable-Diffusion">View Demo</a>
<a href="https://github.com/Yavuzhan-Baykara/Stable-Diffusion/issues/new?labels=bug&template=bug-report---.md">Report Bug</a>
<a href="https://github.com/Yavuzhan-Baykara/Stable-Diffusion/issues/new?labels=enhancement&template=feature-request---.md">Request Feature</a>
</p>
</div>
<!-- TABLE OF CONTENTS -->
<details>
<summary>Table of Contents</summary>
<ol>
<li>
<a href="#about-the-project">About The Project</a>
<ul>
<li><a href="#built-with">Built With</a></li>
</ul>
</li>
<li>
<a href="#getting-started">Getting Started</a>
<ul>
<li><a href="#prerequisites">Prerequisites</a></li>
<li><a href="#installation">Installation</a></li>
</ul>
</li>
<li><a href="#usage">Usage</a></li>
<li><a href="#results">Results</a></li>
<li><a href="#roadmap">Roadmap</a></li>
<li><a href="#contributing">Contributing</a></li>
<li><a href="#license">License</a></li>
<li><a href="#contact">Contact</a></li>
<li><a href="#acknowledgments">Acknowledgments</a></li>
</ol>
</details>
<!-- ABOUT THE PROJECT -->
## About The Project
Diffusion models can generate high-quality samples in domains such as images and audio. This project uses a UNet transformer-based diffusion model to generate samples of handwritten digits. The process involves:
1. Setting up the model and loading pre-trained weights.
2. Generating samples for each digit.
3. Creating a GIF to visualize the generated samples.
<div align="center">
<img src="./digit_samples.gif" alt="MNIST GIF" width="200" height="200" style="display:inline-block;">
<img src="./digit_samples_cifar.gif" alt="CIFAR-10 GIF" width="200" height="200" style="display:inline-block;">
</div>
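
As a rough illustration of step 3, the sketch below writes a list of images to an animated GIF with Pillow. It is not the repository's code: `make_placeholder_frames` and `frames_to_gif` are hypothetical names, and the flat gray frames only stand in for samples that the trained model would produce.

```python
import os
import tempfile

from PIL import Image


def make_placeholder_frames(count=10, size=28):
    """Stand-in for model output (assumption): one flat grayscale image per digit."""
    step = 255 // max(count - 1, 1)
    return [Image.new("L", (size, size), color=i * step) for i in range(count)]


def frames_to_gif(frames, path, duration_ms=500):
    """Write a list of PIL images to an animated GIF that loops forever."""
    if not frames:
        raise ValueError("need at least one frame")
    first, rest = frames[0], frames[1:]
    # save_all=True writes every frame; loop=0 means the animation repeats indefinitely.
    first.save(path, save_all=True, append_images=rest, duration=duration_ms, loop=0)
    return path


if __name__ == "__main__":
    # Write the demo GIF to a temporary directory so nothing in the repo is overwritten.
    out_path = os.path.join(tempfile.mkdtemp(), "demo_samples.gif")
    frames_to_gif(make_placeholder_frames(), out_path)
    print("Wrote", out_path)
```

In practice the placeholder frames would be replaced by images converted from the sampler's output tensors, one per digit.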
<p align="right">(<a href="#readme-top">back to top</a>)</p>
### Built With
#### AI and Machine Learning Libraries
<div align="center">
<img src="https://icon.icepanel.io/Technology/svg/TensorFlow.svg" alt="TensorFlow" width="40" height="40" style="display:inline-block;">
<img src="https://icon.icepanel.io/Technology/svg/PyTorch.svg" alt="PyTorch" width="40" height="40" style="display:inline-block;">
<img src="https://icon.icepanel.io/Technology/svg/NumPy.svg" alt="NumPy" width="40" height="40" style="display:inline-block;">
<img src="https://icon.icepanel.io/Technology/svg/Matplotlib.svg" alt="Matplotlib" width="40" height="40" style="display:inline-block;">
<img src="https://img.shields.io/badge/Pillow-5A9?style=for-the-badge&logo=pillow&logoColor=white" alt="Pillow" width="40" height="40" style="display:inline-block;">
</div>
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- GETTING STARTED -->
## Getting Started
To get a local copy up and running, follow these steps.
### Prerequisites
Ensure you have the following prerequisites installed:
* Python 3.8 or higher
* CUDA-enabled GPU (optional but recommended)
* The following Python libraries:
  - torch
  - torchvision
  - numpy
  - Pillow
  - matplotlib
### Installation
1. Clone the repository:
```sh
git clone https://github.com/Yavuzhan-Baykara/Stable-Diffusion.git
cd Stable-Diffusion
```
2. Install the required Python libraries:
```sh
pip install torch torchvision numpy Pillow matplotlib
```
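
After installing, a quick check along these lines can confirm that the prerequisites are visible to your interpreter. This is a minimal sketch rather than part of the repository: `missing_modules` and `python_is_supported` are hypothetical helper names, and the list uses import names, so Pillow appears as `PIL`.

```python
import importlib.util
import sys

# Import names for the libraries listed under Prerequisites (Pillow is imported as "PIL").
REQUIRED_MODULES = ["torch", "torchvision", "numpy", "PIL", "matplotlib"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_is_supported(minimum=(3, 8)):
    """Check that the running interpreter meets the minimum version."""
    return sys.version_info >= minimum


if __name__ == "__main__":
    print("Python 3.8+:", python_is_supported())
    missing = missing_modules()
    print("Missing libraries:", ", ".join(missing) if missing else "none")
    if "torch" not in missing:
        import torch

        # A GPU is optional; this only reports whether PyTorch can see one.
        print("CUDA available:", torch.cuda.is_available())
```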
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- USAGE -->
## Usage
To train the UNet transformer with different datasets and samplers, use the following command:
```sh
python train.py <dataset> <sampler> <epoch> <batch_size>