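"""Sync bot for the diffusers gallery.

Crawls text-to-image models from the Hugging Face Hub, extracts preview
images from their model cards, resizes and uploads them to S3, runs an
aesthetic/style/NSFW classifier on them, and stores everything in a SQLite
database that is pushed back to the
huggingface-projects/diffusers-gallery-data dataset repository.
"""
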
import asyncio
import json
import os
import re
import subprocess
import uuid
from datetime import datetime
from enum import Enum
from io import BytesIO
from math import ceil
from pathlib import Path
from typing import Optional

import aiohttp
import boto3
import requests
import yaml
from fastapi import FastAPI, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi_utils.tasks import repeat_every
from huggingface_hub import Repository
from PIL import Image, ImageOps
from tqdm import tqdm

from db import Database

AWS_ACCESS_KEY_ID = os.getenv("MY_AWS_ACCESS_KEY_ID")
AWS_SECRET_KEY = os.getenv("MY_AWS_SECRET_KEY")
AWS_S3_BUCKET_NAME = os.getenv("MY_AWS_S3_BUCKET_NAME")
HF_TOKEN = os.environ.get("HF_TOKEN")

S3_DATA_FOLDER = Path("sd-multiplayer-data")
DB_FOLDER = Path("diffusers-gallery-data")

CLASSIFIER_URL = (
    "https://radames-aesthetic-style-nsfw-classifier.hf.space/run/inference"
)
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
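# Resized preview images are uploaded to S3 and served publicly through the
# CloudFront distribution above.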

s3 = boto3.client(
    service_name="s3",
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_KEY,
)

repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()

database = Database(DB_FOLDER)


async def upload_resize_image_url(session, image_url):
    print(f"Uploading image {image_url}")
    try:
        async with session.get(image_url) as response:
            if response.status == 200 and (
                response.headers["content-type"].startswith("image")
                or response.headers["content-type"].startswith("application")
            ):
                image = Image.open(BytesIO(await response.read())).convert("RGB")
                # crop and resize to a 400x400 thumbnail, preserving aspect ratio
                image = ImageOps.fit(image, (400, 400), Image.LANCZOS)
                image_bytes = BytesIO()
                image.save(image_bytes, format="JPEG")
                image_bytes.seek(0)
                fname = f"{uuid.uuid4()}.jpg"
                s3.upload_fileobj(
                    Fileobj=image_bytes,
                    Bucket=AWS_S3_BUCKET_NAME,
                    Key="diffusers-gallery/" + fname,
                    ExtraArgs={
                        "ContentType": "image/jpeg",
                        "CacheControl": "max-age=31536000",
                    },
                )
                return fname
    except Exception as e:
        print(f"Error uploading image {image_url}: {e}")
    return None


def fetch_models(page=0):
    response = requests.get(
        f"https://huggingface.co/models-json?pipeline_tag=text-to-image&p={page}"
    )
    data = response.json()
    return {
        "models": [model for model in data["models"] if not model["private"]],
        "numItemsPerPage": data["numItemsPerPage"],
        "numTotalItems": data["numTotalItems"],
        "pageIndex": data["pageIndex"],
    }


def fetch_model_card(model_id):
    response = requests.get(f"https://huggingface.co/{model_id}/raw/main/README.md")
    return response.text


REGEX = re.compile(r"---(.*?)---", re.DOTALL)


def get_yaml_data(text_content):
    matches = REGEX.findall(text_content)
    yaml_block = matches[0].strip() if matches else None
    if yaml_block:
        try:
            return yaml.safe_load(yaml_block)
        except yaml.YAMLError as exc:
            print(exc)
    return {}
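
# Model cards typically start with a YAML front-matter block between "---"
# fences, for example:
#
#   ---
#   license: creativeml-openrail-m
#   tags:
#     - text-to-image
#   ---
#
# get_yaml_data() extracts that block and parses it into a dict; the tag
# filter in get_page() below matches against it ($.meta.tags).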


async def find_image_in_model_card(text):
    # look for direct links to png/jpg/jpeg/webp files in the model card
    image_regex = re.compile(r"https?://\S+\.(?:png|jpg|jpeg|webp)")
    urls = image_regex.findall(text)
    if not urls:
        return []
    # upload and resize at most the first three images concurrently
    async with aiohttp.ClientSession() as session:
        tasks = [
            asyncio.ensure_future(upload_resize_image_url(session, image_url))
            for image_url in urls[0:3]
        ]
        return await asyncio.gather(*tasks)


def run_classifier(images):
    images = [i for i in images if i is not None]
    if len(images) > 0:
        # classify only the first image
        images_urls = [ASSETS_URL + images[0]]
        response = requests.post(
            CLASSIFIER_URL,
            json={
                "data": [
                    {"urls": images_urls},  # json urls: list of image urls
                    False,  # enable/disable gallery image output
                    None,  # single image input
                    None,  # files input
                ]
            },
        ).json()
        # response "data" is [[{img0}, {img1}, ...], Label, Gallery];
        # take the first image's label/score rows
        class_data = response["data"][0][0]
        class_data_parsed = {row["label"]: round(row["score"], 3) for row in class_data}
        return class_data_parsed
    else:
        return {}
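
# The classifier labels referenced elsewhere in this file (an assumption,
# based on how the scores are queried below): "anime", "3d", "real_life",
# "explicit" and "suggestive", each with a score in [0, 1].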


async def get_all_new_models():
    initial = fetch_models(0)
    num_pages = ceil(initial["numTotalItems"] / initial["numItemsPerPage"])
    print(
        f"Total items: {initial['numTotalItems']} - Items per page: {initial['numItemsPerPage']}"
    )
    print(f"Found {num_pages} pages")

    # fetch all pages of models
    new_models = []
    for page in tqdm(range(num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        page_models = fetch_models(page)
        new_models += page_models["models"]
    return new_models


async def sync_data():
    print("Fetching models")
    repo.git_pull()
    all_models = await get_all_new_models()
    print(f"Found {len(all_models)} models")
    # save the full model list so all ids are available
    with open(DB_FOLDER / "models.json", "w") as f:
        json.dump(all_models, f)

    # with open(DB_FOLDER / "models.json", "r") as f:
    #     new_models = json.load(f)

    new_models_ids = [model["id"] for model in all_models]

    # get ids of models already in the database
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("SELECT id FROM models")
        existing_models = [row["id"] for row in cursor.fetchall()]
    models_ids_to_add = list(set(new_models_ids) - set(existing_models))

    # keep only the models that are not in the database yet
    models = [model for model in all_models if model["id"] in models_ids_to_add]
    print(f"Found {len(models)} new models")
    for model in tqdm(models):
        model_id = model["id"]
        print(f"\n\nFetching model {model_id}")
        likes = model["likes"]
        downloads = model["downloads"]
        print("Fetching model card")
        model_card = fetch_model_card(model_id)
        print("Parsing model card")
        model_card_data = get_yaml_data(model_card)
        print("Finding images in model card")
        images = await find_image_in_model_card(model_card)
        classifier = run_classifier(images)
        print(images, classifier)
        # insert model row with image and classifier data
        with database.get_db() as db:
            cursor = db.cursor()
            cursor.execute(
                "INSERT INTO models(id, data, likes, downloads) VALUES (?, ?, ?, ?)",
                [
                    model_id,
                    json.dumps(
                        {
                            **model,
                            "meta": model_card_data,
                            "images": images,
                            "class": classifier,
                        }
                    ),
                    likes,
                    downloads,
                ],
            )
            db.commit()
print("\n\n\n\nTry to update images again\n\n\n") | |
with database.get_db() as db: | |
cursor = db.cursor() | |
cursor.execute("SELECT * from models") | |
to_all_models = list(cursor.fetchall()) | |
models_no_images = [] | |
for model in to_all_models: | |
model_data = json.loads(model["data"]) | |
images = model_data["images"] | |
filtered_images = [x for x in images if x is not None] | |
if len(filtered_images) == 0: | |
models_no_images.append(model) | |
for model in tqdm(models_no_images): | |
model_id = model["id"] | |
model_data = json.loads(model["data"]) | |
print(f"\n\nFetching model {model_id}") | |
model_card = fetch_model_card(model_id) | |
print("Parsing model card") | |
model_card_data = get_yaml_data(model_card) | |
print("Finding images in model card") | |
images = await find_image_in_model_card(model_card) | |
classifier = run_classifier(images) | |
model_data["images"] = images | |
model_data["class"] = classifier | |
model_data["meta"] = model_card_data | |
# update model row with image and classifier data | |
with database.get_db() as db: | |
cursor = db.cursor() | |
cursor.execute( | |
"UPDATE models SET data = ? WHERE id = ?", | |
[json.dumps(model_data), model_id], | |
) | |
db.commit() | |
print("Update likes and downloads") | |
for model in tqdm(all_models): | |
model_id = model["id"] | |
likes = model["likes"] | |
downloads = model["downloads"] | |
with database.get_db() as db: | |
cursor = db.cursor() | |
cursor.execute( | |
"UPDATE models SET likes = ?, downloads = ? WHERE id = ?", | |
[likes, downloads, model_id], | |
) | |
db.commit() | |
print("Updating DB repository") | |
time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
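    # amending and force-pushing overwrites the previous commit, so the
    # dataset repo's history does not grow with every sync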
cmd = f"git add . && git commit --amend -m 'update at {time}' && git push --force" | |
print(cmd) | |
subprocess.Popen(cmd, cwd=DB_FOLDER, shell=True) | |


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# @app.get("/sync")
# async def sync(background_tasks: BackgroundTasks):
#     await sync_data()
#     return "Synced data to huggingface datasets"


MAX_PAGE_SIZE = 30


class Sort(str, Enum):
    trending = "trending"
    recent = "recent"
    likes = "likes"


class Style(str, Enum):
    all = "all"
    anime = "anime"
    s3D = "3d"
    realistic = "realistic"
    nsfw = "nsfw"


# NOTE: the route decorator for this endpoint appears to have been lost in
# extraction; the path below is an assumption.
@app.get("/api/models")
def get_page(
    page: int = 1,
    sort: Sort = Sort.trending,
    style: Style = Style.all,
    tag: Optional[str] = None,
):
    page = page if page > 0 else 1
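    # MYPOWER is assumed to be a custom scalar function registered on the
    # SQLite connection inside db.Database, roughly:
    #   conn.create_function("MYPOWER", 2, lambda base, exp: base ** exp)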
    if sort == Sort.trending:
        # trending: likes decayed by the squared age in days since last modification
        sort_query = "likes / MYPOWER((JULIANDAY('now') - JULIANDAY(datetime(json_extract(data, '$.lastModified')))) + 2, 2) DESC"
    elif sort == Sort.recent:
        sort_query = "datetime(json_extract(data, '$.lastModified')) DESC"
    elif sort == Sort.likes:
        sort_query = "likes DESC"

    # note: the "isNFSW" spelling is kept as-is; it matches the SQL alias
    # computed below and the JSON key returned to clients
    if style == Style.all:
        style_query = "isNFSW = false"
    elif style == Style.anime:
        style_query = "json_extract(data, '$.class.anime') > 0.1 AND isNFSW = false"
    elif style == Style.s3D:
        style_query = "json_extract(data, '$.class.3d') > 0.1 AND isNFSW = false"
    elif style == Style.realistic:
        style_query = "json_extract(data, '$.class.real_life') > 0.1 AND isNFSW = false"
    elif style == Style.nsfw:
        style_query = "isNFSW = true"

    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute(
            f"""
        SELECT *,
               COUNT(*) OVER() AS total,
               isNFSW
        FROM (
            SELECT *,
                   json_extract(data, '$.class.explicit') > 0.3 OR json_extract(data, '$.class.suggestive') > 0.3 AS isNFSW
            FROM models
        ) AS subquery
        WHERE (? IS NULL AND likes > 3 OR ? IS NOT NULL)
              AND {style_query}
              AND (? IS NULL OR EXISTS (
                  SELECT 1
                  FROM json_each(json_extract(data, '$.meta.tags'))
                  WHERE json_each.value = ?
              ))
        ORDER BY {sort_query}
        LIMIT {MAX_PAGE_SIZE} OFFSET {(page - 1) * MAX_PAGE_SIZE};
        """,
            (tag, tag, tag, tag),
        )
        results = cursor.fetchall()
        total = results[0]["total"] if results else 0
        total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE
        models_data = []
        for result in results:
            data = json.loads(result["data"])
            # clean null image entries
            data["images"] = [x for x in data["images"] if x is not None]
            # refresh downloads and likes from the db columns
            data["downloads"] = result["downloads"]
            data["likes"] = result["likes"]
            data["isNFSW"] = bool(result["isNFSW"])
            models_data.append(data)

    return {"models": models_data, "totalPages": total_pages}


# NOTE: the route decorator appears to have been lost in extraction; "/" is
# the natural landing page for this bot.
@app.get("/")
def read_root():
    # return a minimal html landing page
    return HTMLResponse(
        """
    <p>Just a bot that syncs data for the diffusers gallery. Please go to
    <a href="https://huggingface.co/spaces/huggingface-projects/diffusers-gallery" target="_blank" rel="noopener noreferrer">https://huggingface.co/spaces/huggingface-projects/diffusers-gallery</a>
    </p>"""
    )


# NOTE: the scheduling decorators appear to have been lost in extraction;
# repeat_every is imported above, so the sync was presumably run on a timer.
# The 24-hour interval is an assumption.
@app.on_event("startup")
@repeat_every(seconds=60 * 60 * 24, wait_first=False)
async def repeat_sync():
    await sync_data()
    return "Synced data to huggingface datasets"