diff --git "a/notebooks/populate_dataset.ipynb" "b/notebooks/populate_dataset.ipynb"
new file mode 100644--- /dev/null
+++ "b/notebooks/populate_dataset.ipynb"
@@ -0,0 +1,446 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "3b5fbb8f-5789-45db-a551-f0e6633b4f46",
+ "metadata": {},
+ "source": [
+ "# Populate an HDF5 dataset with base64 Pokémon images keyed by energy type"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a9234c78-1ac5-4c71-b11b-3a81be57f3f3",
+ "metadata": {},
+ "source": [
+ "Used in [**This Pokémon Does Not Exist**](https://huggingface.co/spaces/ronvolutional/ai-pokemon-card)\n",
+ "\n",
+ "Model fine-tuned by [**Max Woolf**](https://huggingface.co/minimaxir/ai-generated-pokemon-rudalle)\n",
+ "\n",
+ "ruDALL-E by [**Sber**](https://rudalle.ru/en)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "09850949-235d-4997-b8e3-c5b2aeffe109",
+ "metadata": {},
+ "source": [
+ "## Initialise datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "cead6fa8-e9ef-4672-bbfb-beadcaf5f3a0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import h5py\n",
+ "\n",
+ "# Location of the HDF5 file that the rest of the notebook reads from and appends to.\n",
+ "datasets_dir = './datasets'\n",
+ "datasets_file = 'pregenerated_pokemon.h5'\n",
+ "h5_file = os.path.join(datasets_dir, datasets_file)\n",
+ "\n",
+ "# Energy types; each name is used as the key of one dataset inside the HDF5 file.\n",
+ "energy_types = ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "id": "2df90e94-15c0-4eb6-914e-875ec80b7c24",
+ "metadata": {},
+ "source": [
+ "# Only run if the datasets file does not exist\n",
+ "# (mode 'x' asks h5py to create the file and to fail if it is already there).\n",
+ "\n",
+ "# Make sure the parent directory exists before h5py tries to create the file in it.\n",
+ "os.makedirs(datasets_dir, exist_ok=True)\n",
+ "\n",
+ "with h5py.File(h5_file, 'x') as datasets:\n",
+ "    for energy in energy_types:\n",
+ "        # One initially empty, row-resizable (N, 1) column of UTF-8 strings per energy type.\n",
+ "        datasets.create_dataset(energy, (0,1), h5py.string_dtype(encoding='utf-8'), maxshape=(None,1))\n",
+ "\n",
+ "    print(datasets.keys())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cdd3eb59-bbf5-4b85-b6bc-35f591317b47",
+ "metadata": {},
+ "source": [
+ "### Dataset functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "fca1947f-8a66-4636-8049-99d6ff0ace93",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "from time import gmtime, strftime, time\n",
+ "from random import choices, randint\n",
+ "from IPython import display\n",
+ "\n",
+ "def get_stats(h5_file=h5_file):\n",
+ "    \"\"\"Summarise an HDF5 file: element count per dataset, total count and size on disk.\n",
+ "\n",
+ "    Args:\n",
+ "        h5_file: Path of the HDF5 file to inspect; defaults to the notebook's `h5_file`.\n",
+ "\n",
+ "    Returns:\n",
+ "        dict with `size_counts` (dataset name -> number of elements), `size_total`\n",
+ "        (sum of those counts) and `size_mb` (file size in bytes / 1024**2, rounded\n",
+ "        to one decimal place).\n",
+ "    \"\"\"\n",
+ "    with h5py.File(h5_file, 'r') as datasets:\n",
+ "        # int() works whether `.size` is a plain Python int or a NumPy integer;\n",
+ "        # `.item()` only exists on the latter.\n",
+ "        size_counts = {key: int(dataset.size) for key, dataset in datasets.items()}\n",
+ "\n",
+ "    # The counts are read once and reused for the total instead of re-reading every dataset.\n",
+ "    return {\n",
+ "        \"size_counts\": size_counts,\n",
+ "        \"size_total\": sum(size_counts.values()),\n",
+ "        \"size_mb\": round(os.path.getsize(h5_file) / 1024**2, 1)\n",
+ "    }\n",
+ "\n",
+ "\n",
+ "def add_row(energy, image):\n",
+ "    \"\"\"Append `image` as a new last row of the `energy` dataset in the HDF5 file.\"\"\"\n",
+ "    with h5py.File(h5_file, 'r+') as h5:\n",
+ "        column = h5[energy]\n",
+ "        # Grow axis 0 by one slot, then fill the slot that was just added.\n",
+ "        grown_length = column.size + 1\n",
+ "        column.resize(grown_length, axis=0)\n",
+ "        column[-1] = image\n",
+ "\n",
+ "\n",
+ "def get_image(energy=None, row=None):\n",
+ "    \"\"\"Return one stored image string, picked at random unless pinned by the arguments.\n",
+ "\n",
+ "    Args:\n",
+ "        energy: Key of the dataset to read from; a random entry of `energy_types`\n",
+ "            is used when it is None or empty.\n",
+ "        row: Zero-based row index; a random row of that dataset is used when None.\n",
+ "\n",
+ "    Returns:\n",
+ "        The string stored in column 0 of the selected row.\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: If a random row is requested from a dataset that has no rows.\n",
+ "    \"\"\"\n",
+ "    if not energy:\n",
+ "        energy = choices(energy_types)[0]\n",
+ "\n",
+ "    with h5py.File(h5_file, 'r') as datasets:\n",
+ "        dataset = datasets[energy]\n",
+ "\n",
+ "        # Compare against None: `not row` is also true for row 0, which made the\n",
+ "        # first image impossible to request explicitly.\n",
+ "        if row is None:\n",
+ "            row_count = dataset.shape[0]\n",
+ "            if row_count == 0:\n",
+ "                raise ValueError(f\"No images stored for energy type '{energy}'\")\n",
+ "            row = randint(0, row_count - 1)\n",
+ "\n",
+ "        return dataset.asstr()[row][0]\n",
+ "\n",
+ "def pretty_time(seconds):\n",
+ "    \"\"\"Format a duration given in seconds as text such as `1h 2m 3s`.\n",
+ "\n",
+ "    Fractions of a second are dropped, zero-valued units are omitted, and a\n",
+ "    duration below one second (including negative input) is shown as `0s`.\n",
+ "\n",
+ "    Args:\n",
+ "        seconds: Duration in seconds (int or float).\n",
+ "\n",
+ "    Returns:\n",
+ "        The formatted string, without leading or trailing whitespace.\n",
+ "    \"\"\"\n",
+ "    # Floor before splitting so every unit is a whole number; testing the unfloored\n",
+ "    # floats produced parts like '0s' and left a trailing space after 'Xm '.\n",
+ "    whole_seconds = max(0, math.floor(seconds))\n",
+ "    m, s = divmod(whole_seconds, 60)\n",
+ "    h, m = divmod(m, 60)\n",
+ "    parts = [f\"{value}{unit}\" for value, unit in ((h, \"h\"), (m, \"m\"), (s, \"s\")) if value]\n",
+ "    return \" \".join(parts) or \"0s\"\n",
+ " \n",
+ "def populate_dataset(batches=1, batch_size=1, image_cap=100_000, filesize_cap=4_000):\n",
+ "    \"\"\"Generate images for every energy type in batches while printing live progress.\n",
+ "\n",
+ "    Args:\n",
+ "        batches: Maximum number of passes over `energy_types`.\n",
+ "        batch_size: Number of images generated per energy type in each pass.\n",
+ "        image_cap: No new batch starts once the file holds this many images.\n",
+ "        filesize_cap: No new batch starts once the file reaches this size\n",
+ "            (same unit as `get_stats()['size_mb']`).\n",
+ "\n",
+ "    The caps are only checked between batches, so a batch that has started\n",
+ "    always runs through all energy types.\n",
+ "    \"\"\"\n",
+ "    initial_stats = get_stats()\n",
+ "    planned_images = batches * batch_size * len(energy_types)\n",
+ "\n",
+ "    iterations = 0\n",
+ "    start_time = time()\n",
+ "\n",
+ "    while iterations < batches:\n",
+ "        # Read the stats once per check instead of reopening the file for each cap.\n",
+ "        stats = get_stats()\n",
+ "        if stats['size_total'] >= image_cap or stats['size_mb'] >= filesize_cap:\n",
+ "            break\n",
+ "\n",
+ "        for energy in energy_types:\n",
+ "            current = get_stats()\n",
+ "            new_images_count = (current['size_total'] - initial_stats['size_total'])\n",
+ "            new_mb_count = round(current['size_mb'] - initial_stats['size_mb'], 1)\n",
+ "            elapsed = time() - start_time\n",
+ "            # Extrapolate the total run time from the average time per image so far.\n",
+ "            eta_total = elapsed / (new_images_count or 1) * planned_images\n",
+ "\n",
+ "            # Replace the previous progress report rather than appending to it.\n",
+ "            display.clear_output(wait=True)\n",
+ "            if new_images_count:\n",
+ "                print(f\"ETA: {pretty_time(eta_total - elapsed)} left of {pretty_time(eta_total)}\")\n",
+ "            print(f\"Images in dataset: {current['size_total']}{f' (+{new_images_count})' if new_images_count else ''}\")\n",
+ "            print(f\"Size of dataset: {current['size_mb']}MB{f' (+{new_mb_count}MB)' if new_mb_count else ''}\")\n",
+ "            print(f\"Batch {iterations + 1} of {batches}:\")\n",
+ "            print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Generating {batch_size} {energy} Pokémon...\")\n",
+ "\n",
+ "            generate_pokemon(energy, batch_size)\n",
+ "\n",
+ "        iterations += 1\n",
+ "\n",
+ "    new_stats = get_stats()\n",
+ "    elapsed = time() - start_time\n",
+ "\n",
+ "    display.clear_output(wait=True)\n",
+ "    # Report the batches that actually ran: a cap can end the loop before `batches` is reached.\n",
+ "    print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Finished populating dataset with {iterations} {'batch' if iterations == 1 else 'batches'} after {pretty_time(elapsed)}\")\n",
+ "    print(f\"Images in dataset: {new_stats['size_total']} (+{new_stats['size_total'] - initial_stats['size_total']})\")\n",
+ "    print(f\"Size of dataset: {new_stats['size_mb']}MB (+{round(new_stats['size_mb'] - initial_stats['size_mb'], 1)}MB)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b0da582a-2b29-4df6-8a3f-ddd56377af16",
+ "metadata": {},
+ "source": [
+ "## Load Pokémon model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ad435db5-bd35-4440-87b2-5a108f4ae385",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from rudalle import get_rudalle_model, get_tokenizer, get_vae\n",
+ "from huggingface_hub import cached_download, hf_hub_url\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "5a2a4e6a-2086-4f98-b2b5-65a3631be61e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GPUs available: 1\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Report how many CUDA devices PyTorch can see before the model is loaded.\n",
+ "print(f\"GPUs available: {torch.cuda.device_count()}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "720df30d-f42c-406a-92ba-4a465f6ff1d3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n",
+ "vae --> ready\n",
+ "tokenizer --> ready\n",
+ "GPU[0] memory: 11263Mib\n",
+ "GPU[0] memory reserved: 5144Mib\n",
+ "GPU[0] memory allocated: 2767Mib\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Use the GPU (with fp16 enabled) when CUDA is available, otherwise fall back to the CPU.\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "fp16 = torch.cuda.is_available()\n",
+ "map_location = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
+ "\n",
+ "# Download the checkpoint from the Hugging Face Hub into file_dir under a fixed file name.\n",
+ "file_dir = \"./models\"\n",
+ "file_name = \"pytorch_model.bin\"\n",
+ "config_file_url = hf_hub_url(repo_id=\"minimaxir/ai-generated-pokemon-rudalle\", filename=file_name)\n",
+ "cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)\n",
+ "\n",
+ "# Build the 'Malevich' model with pretrained=False, then load the downloaded state dict into it.\n",
+ "model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)\n",
+ "model.load_state_dict(torch.load(f\"{file_dir}/{file_name}\", map_location=map_location))\n",
+ "\n",
+ "vae = get_vae().to(device)\n",
+ "tokenizer = get_tokenizer()\n",
+ "\n",
+ "# The memory report queries CUDA device 0, so it is skipped on the CPU fallback path\n",
+ "# instead of failing after the model has already been loaded.\n",
+ "if torch.cuda.is_available():\n",
+ "    print(f\"GPU[0] memory: {int(torch.cuda.get_device_properties(0).total_memory / 1024**2)}Mib\")\n",
+ "    print(f\"GPU[0] memory reserved: {int(torch.cuda.memory_reserved(0) / 1024**2)}Mib\")\n",
+ "    print(f\"GPU[0] memory allocated: {int(torch.cuda.memory_allocated(0) / 1024**2)}Mib\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "88d413a4-a8c9-401e-9cb5-32a1ae34c179",
+ "metadata": {},
+ "source": [
+ "### Model functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "0624c686-c75f-46c6-afae-bbe6c455caa1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import base64\n",
+ "from io import BytesIO\n",
+ "from time import gmtime, strftime, time\n",
+ "from rudalle.pipelines import generate_images\n",
+ "\n",
+ "def english_to_russian(english):\n",
+ "    \"\"\"Look up the Russian text for an English energy-type name.\n",
+ "\n",
+ "    The name is lower-cased before the lookup; an unknown name raises KeyError.\n",
+ "    \"\"\"\n",
+ "    russian_by_energy = dict(\n",
+ "        colorless=\"Покемон нормального типа\",\n",
+ "        dragon=\"Покемон типа дракона\",\n",
+ "        darkness=\"Покемон темного типа\",\n",
+ "        fairy=\"Покемон фея\",\n",
+ "        fighting=\"Покемон боевого типа\",\n",
+ "        fire=\"Покемон огня\",\n",
+ "        grass=\"Покемон трава\",\n",
+ "        lightning=\"Покемон электрического типа\",\n",
+ "        metal=\"Покемон из стали типа\",\n",
+ "        psychic=\"Покемон психического типа\",\n",
+ "        water=\"Покемон в воду\",\n",
+ "    )\n",
+ "    lookup_key = english.lower()\n",
+ "\n",
+ "    return russian_by_energy[lookup_key]\n",
+ "\n",
+ "\n",
+ "# Generate `num` images for one energy type and store each one, encoded as a\n",
+ "# base64 JPEG data URI, through add_row().\n",
+ "# NOTE(review): the original indentation is not recoverable from this view, so it is\n",
+ "# unclear whether the generation steps sit inside the `if`; if they do not, an unknown\n",
+ "# energy would reach generate_images() with `russian_prompt` unset — confirm.\n",
+ "def generate_pokemon(energy, num=1):\n",
+ " if energy in energy_types:\n",
+ " russian_prompt = english_to_russian(energy)\n",
+ " \n",
+ " images, _ = generate_images(russian_prompt, tokenizer, model, vae, top_k=2048, images_num=num, top_p=0.995)\n",
+ " \n",
+ " for image in images:\n",
+ " # Encode into an in-memory buffer so nothing is written to disk besides the HDF5 row.\n",
+ " buffer = BytesIO()\n",
+ " image.save(buffer, format=\"JPEG\", quality=100, optimize=True)\n",
+ " base64_bytes = base64.b64encode(buffer.getvalue())\n",
+ " base64_string = base64_bytes.decode(\"UTF-8\")\n",
+ " # Prefix the payload so the stored string is a complete data URI.\n",
+ " base64_image = \"data:image/jpeg;base64,\" + base64_string\n",
+ " add_row(energy, base64_image)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7b309b8e-0c34-4a80-a411-8f093a105494",
+ "metadata": {},
+ "source": [
+ "## Populate dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a96d026f-e300-40c7-88a6-20be0012c584",
+ "metadata": {},
+ "source": [
+ "Total number of images per population = `batches` × `len(energy_types)` (11) × `batch_size`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "74d2c50a-93ae-4040-89a0-87f818187bbb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2022-03-16 05:07:48 Finished populating dataset with 1 batch after 10m 8s\n",
+ "Images in dataset: 5082 (+66)\n",
+ "Size of dataset: 199.8MB (+2.5MB)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Run settings: number of passes over all energy types, and images per type in each pass.\n",
+ "batches = 1\n",
+ "batch_size = 6\n",
+ "# Upper bounds checked before each batch starts; no new batch begins once either is reached.\n",
+ "image_cap = 100_000\n",
+ "filesize_cap = 4_000 # MB\n",
+ "\n",
+ "populate_dataset(batches, batch_size, image_cap, filesize_cap)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4eb6f750-aa2d-42b9-89dc-2cb103e8869e",
+ "metadata": {},
+ "source": [
+ "## Getting images"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9365ac7b-36e4-42b9-8380-a039d556356b",
+ "metadata": {},
+ "source": [
+ "### Random image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "dca3bc8e-1f65-4566-8385-bf6b9f20eaaf",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# NOTE(review): the f-string is empty in this view; presumably it should embed get_image() in an <img> tag — confirm.\n",
+ "display.HTML(f'')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f3b9278e-b6dc-4bea-9ef6-3aa312739718",
+ "metadata": {},
+ "source": [
+ "### Random image of specific energy type"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "b6e98817-19f7-44f3-8970-0a11b01cf37b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# NOTE(review): the f-string is empty in this view; presumably it should embed get_image(<energy>) in an <img> tag — confirm.\n",
+ "display.HTML(f'')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34ca6f96-625b-459a-ba90-2c58a1d0ea47",
+ "metadata": {},
+ "source": [
+ "### Specific image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "5a60fc60-4198-4318-a5b4-26095cb2c0bb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# NOTE(review): the f-string is empty in this view; presumably it should embed get_image(<energy>, <row>) in an <img> tag — confirm.\n",
+ "display.HTML(f'')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}