Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from enum import Enum | |
import os | |
import re | |
import aiohttp | |
import requests | |
import json | |
import subprocess | |
import asyncio | |
from io import BytesIO | |
import uuid | |
from math import ceil | |
from tqdm import tqdm | |
from pathlib import Path | |
from huggingface_hub import Repository | |
from PIL import Image, ImageOps | |
from fastapi import FastAPI, BackgroundTasks | |
from fastapi.responses import HTMLResponse | |
from fastapi_utils.tasks import repeat_every | |
from fastapi.middleware.cors import CORSMiddleware | |
import boto3 | |
from datetime import datetime | |
from db import Database | |
# --- Configuration, read from the environment at import time ----------------
AWS_ACCESS_KEY_ID = os.getenv("MY_AWS_ACCESS_KEY_ID")
AWS_SECRET_KEY = os.getenv("MY_AWS_SECRET_KEY")
AWS_S3_BUCKET_NAME = os.getenv("MY_AWS_S3_BUCKET_NAME")
# NOTE(review): HF_TOKEN is not referenced anywhere else in this file.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Local folders. S3_DATA_FOLDER is not referenced elsewhere in this file;
# DB_FOLDER is the local clone of the dataset repo that backs the Database.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
DB_FOLDER = Path("diffusers-gallery-data")

# Remote endpoints: the style/NSFW classifier Space and the CDN prefix under
# which uploaded thumbnails are served.
CLASSIFIER_URL = "https://radames-aesthetic-style-nsfw-classifier.hf.space/run/inference"
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"

# --- Import-time side effects -----------------------------------------------
# S3 client used to store the resized thumbnails.
s3 = boto3.client(
    service_name="s3",
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_KEY,
)

# Local checkout of the dataset repo holding the gallery database; it is
# refreshed right away so the Database below starts from the latest data.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()

database = Database(DB_FOLDER)
async def upload_resize_image_url(session, image_url):
    """Download an image, fit it to a 400x400 JPEG and upload it to S3.

    Best-effort: every failure is reported and turned into ``None``, so one
    bad URL never aborts the caller's ``asyncio.gather``.

    Args:
        session: an ``aiohttp.ClientSession`` (or anything whose ``get``
            returns an async context manager yielding a response with
            ``status``, ``headers`` and an awaitable ``read()``).
        image_url: absolute URL of the source image.

    Returns:
        The generated file name (``<uuid4>.jpg``, stored under the
        ``diffusers-gallery/`` key prefix), or ``None`` when the response was
        not a 200 image/application payload or any step failed.
    """
    print(f"Uploading image {image_url}")
    try:
        async with session.get(image_url) as response:
            content_type = response.headers.get("content-type", "")
            # "application" is accepted as well, presumably for images served
            # as application/octet-stream.
            if response.status != 200 or not content_type.startswith(("image", "application")):
                return None
            payload = await response.read()

        image = Image.open(BytesIO(payload)).convert("RGB")
        # ImageOps.fit crops to the target aspect ratio, then resizes.
        image = ImageOps.fit(image, (400, 400), Image.LANCZOS)
        image_bytes = BytesIO()
        image.save(image_bytes, format="JPEG")
        image_bytes.seek(0)
        fname = f"{uuid.uuid4()}.jpg"

        # boto3's upload is a blocking call; run it in the default executor so
        # the event loop (and the other gathered downloads) keep progressing.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: s3.upload_fileobj(
                Fileobj=image_bytes,
                Bucket=AWS_S3_BUCKET_NAME,
                Key="diffusers-gallery/" + fname,
                ExtraArgs={"ContentType": "image/jpeg", "CacheControl": "max-age=31536000"},
            ),
        )
        return fname
    except Exception as e:
        print(f"Error uploading image {image_url}: {e}")
        return None
def fetch_models(page=0, timeout=30):
    """Fetch one page of the Hub's text-to-image model listing, public models only.

    Args:
        page: page index sent as the ``p`` query parameter (the sync starts at 0).
        timeout: seconds to wait for the server; without a timeout a stalled
            connection would hang the whole sync indefinitely.

    Returns:
        dict with ``models`` (private models removed), ``numItemsPerPage``,
        ``numTotalItems`` and ``pageIndex`` as reported by the endpoint.

    Raises:
        requests.HTTPError: on a non-2xx response (rather than an opaque JSON
            decode error on an error page).
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    response = requests.get(
        "https://huggingface.co/models-json",
        params={"pipeline_tag": "text-to-image", "p": page},
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()
    return {
        "models": [model for model in data["models"] if not model["private"]],
        "numItemsPerPage": data["numItemsPerPage"],
        "numTotalItems": data["numTotalItems"],
        "pageIndex": data["pageIndex"],
    }
def fetch_model_card(model_id, timeout=30):
    """Return the raw ``README.md`` (model card) of ``model_id`` from its ``main`` branch.

    Args:
        model_id: Hub model id such as ``"user/model"``.
        timeout: seconds to wait for the server before raising ``requests.Timeout``.

    Returns:
        The card text, or ``""`` when the server answers with a non-2xx status
        (e.g. no README), so callers never scan an error page for image URLs.
    """
    response = requests.get(
        f"https://huggingface.co/{model_id}/raw/main/README.md", timeout=timeout
    )
    if not response.ok:
        return ""
    return response.text
# Absolute http(s) URL ending in an image extension (case-insensitive). The
# character class stops at whitespace, quotes, angle brackets, parentheses and
# square brackets, so adjacent markdown/HTML images on one line are matched
# separately instead of being merged into a single bogus URL.
IMAGE_URL_REGEX = re.compile(
    r"https?://[^\s\"'<>()\[\]]+\.(?:png|jpe?g|webp)", re.IGNORECASE
)


async def find_image_in_model_card(text, max_images=3):
    """Upload thumbnails for the first images linked from a model card.

    Args:
        text: model card markdown/HTML; ``None`` or ``""`` yields ``[]``.
        max_images: maximum number of distinct image URLs to process.

    Returns:
        One entry per processed URL, in order of first appearance: the S3 file
        name returned by ``upload_resize_image_url``, or ``None`` where that
        upload failed. ``[]`` when the card links no image.
    """
    if not text:
        return []
    # dict.fromkeys de-duplicates while preserving first-appearance order, so
    # an image referenced twice in the card is uploaded only once.
    urls = list(dict.fromkeys(IMAGE_URL_REGEX.findall(text)))[:max_images]
    if not urls:
        return []
    async with aiohttp.ClientSession() as session:
        # Downloads/uploads for the selected URLs run concurrently.
        return await asyncio.gather(
            *(upload_resize_image_url(session, url) for url in urls)
        )
def run_classifier(images, timeout=120):
    """Score the first uploaded thumbnail with the style/NSFW classifier Space.

    Args:
        images: file names as returned by ``find_image_in_model_card``;
            ``None`` entries (failed uploads) are ignored.
        timeout: seconds to wait for the classifier response (a generous
            default; the call previously had no timeout at all).

    Returns:
        ``{label: score}`` with scores rounded to 3 decimals, or ``{}`` when
        there is no usable image.

    Raises:
        requests.HTTPError: if the classifier answers with a non-2xx status.
        requests.Timeout: if it does not answer within ``timeout``.
    """
    uploaded = [name for name in images if name is not None]
    if not uploaded:
        return {}
    # Only the first image is classified; it is passed as its CDN URL.
    payload = {
        "data": [
            {"urls": [ASSETS_URL + uploaded[0]]},  # json urls: list of image urls
            False,  # enable/disable gallery image output
            None,  # single image input
            None,  # files input
        ]
    }
    response = requests.post(CLASSIFIER_URL, json=payload, timeout=timeout)
    response.raise_for_status()
    # Response layout: {"data": [[{img0}, {img1}, ...], Label, Gallery]}
    class_data = response.json()["data"][0][0]
    return {row["label"]: round(row["score"], 3) for row in class_data}
async def get_all_new_models():
    """Fetch every page of the public text-to-image model listing.

    Page 0 is requested once, both to learn the totals and for its models.
    Results are de-duplicated by ``id`` (first occurrence wins) in case the
    listing shifts between page requests; a duplicated new model would
    otherwise be inserted twice by the sync.

    Returns:
        list of model dicts as returned in ``fetch_models(...)["models"]``.
    """
    initial = fetch_models(0)
    total_items = initial["numTotalItems"]
    per_page = initial["numItemsPerPage"]
    # Guard against a zero page size instead of raising ZeroDivisionError.
    num_pages = ceil(total_items / per_page) if per_page else 0
    print(f"Total items: {total_items} - Items per page: {per_page}")
    print(f"Found {num_pages} pages")

    pages = [initial["models"]]
    for page in tqdm(range(1, num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        pages.append(fetch_models(page)["models"])

    unique_models = {}
    for page_models in pages:
        for model in page_models:
            unique_models.setdefault(model["id"], model)
    return list(unique_models.values())
async def sync_data():
    """Synchronise the gallery database with the Hub and push it to the dataset repo.

    Steps:
      1. pull the dataset repo, fetch the full model listing and save it to
         ``models.json``;
      2. insert models not yet in the DB, with thumbnails and classifier scores;
      3. retry thumbnail extraction for stored models that have no image;
      4. refresh likes/downloads of stored models from the fetched listing;
      5. amend the last commit and force-push the repo in the background.
    """
    print("Fetching models")
    repo.git_pull()
    fetched_models = await get_all_new_models()
    print(f"Found {len(fetched_models)} models")
    # save list of all models for ids
    with open(DB_FOLDER / "models.json", "w") as f:
        json.dump(fetched_models, f)

    await _insert_new_models(fetched_models)
    await _retry_models_without_images()
    _update_likes_and_downloads(fetched_models)
    _push_db_repo()


async def _insert_new_models(fetched_models):
    """Insert every model of ``fetched_models`` whose id is not yet in the DB."""
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("SELECT id FROM models")
        seen_ids = {row["id"] for row in cursor.fetchall()}

    # Set membership keeps this linear; adding each new id also skips any
    # duplicate ids in the listing so the same model is never inserted twice.
    new_models = []
    for model in fetched_models:
        if model["id"] not in seen_ids:
            seen_ids.add(model["id"])
            new_models.append(model)
    print(f"Found {len(new_models)} new models")

    for model in tqdm(new_models):
        model_id = model["id"]
        model_card = fetch_model_card(model_id)
        images = await find_image_in_model_card(model_card)
        classifier = run_classifier(images)
        print(images, classifier)
        # One connection and commit per model, so each insert is persisted
        # even if a later model fails.
        with database.get_db() as db:
            cursor = db.cursor()
            cursor.execute(
                "INSERT INTO models(id, data, likes, downloads) VALUES (?, ?, ?, ?)",
                [
                    model_id,
                    json.dumps({**model, "images": images, "class": classifier}),
                    model["likes"],
                    model["downloads"],
                ],
            )
            db.commit()


async def _retry_models_without_images():
    """Re-scan the model cards of stored models that have no uploaded image."""
    print("Try to update images again")
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("SELECT id, data FROM models")
        rows = list(cursor.fetchall())

    missing = []
    for row in rows:
        model_data = json.loads(row["data"])
        if not any(image is not None for image in model_data.get("images") or []):
            missing.append((row["id"], model_data))

    for model_id, model_data in tqdm(missing):
        print("Updating model", model_id)
        model_card = fetch_model_card(model_id)
        images = await find_image_in_model_card(model_card)
        if not any(image is not None for image in images):
            continue  # still nothing usable; leave the stored row unchanged
        # Store the new thumbnails and their classification on the row.
        model_data["images"] = images
        model_data["class"] = run_classifier(images)
        with database.get_db() as db:
            cursor = db.cursor()
            cursor.execute(
                "UPDATE models SET data = ? WHERE id = ?",
                [json.dumps(model_data), model_id],
            )
            db.commit()


def _update_likes_and_downloads(fetched_models):
    """Overwrite stored likes/downloads with the values just fetched from the Hub."""
    print("Update likes and downloads")
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.executemany(
            "UPDATE models SET likes = ?, downloads = ? WHERE id = ?",
            [(model["likes"], model["downloads"], model["id"]) for model in fetched_models],
        )
        db.commit()


def _push_db_repo():
    """Stage everything, amend the previous commit and force-push, without waiting."""
    print("Updating DB repository")
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    cmd = f"git add . && git commit --amend -m 'update at {timestamp}' && git push --force"
    print(cmd)
    # Fire-and-forget: the child process is started but not awaited.
    subprocess.Popen(cmd, cwd=DB_FOLDER, shell=True)
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_credentials=True combined with a wildcard origin lets
# credentialed cross-site requests through; confirm credentials are needed
# for this read-only API.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Manual sync endpoint, currently disabled:
# @app.get("/sync")
# async def sync(background_tasks: BackgroundTasks):
#     await sync_data()
#     return "Synced data to huggingface datasets"

# Number of models returned per page by get_page().
MAX_PAGE_SIZE = 30
class Sort(str, Enum):
    """Orderings accepted by ``get_page`` (the values are the public query strings)."""

    trending = "trending"  # likes damped by time since lastModified
    recent = "recent"  # most recently modified first
    likes = "likes"  # most liked first
class Style(str, Enum):
    """Style filters accepted by ``get_page`` (the values are the public query strings).

    The 3D member is spelled ``s3D`` because a Python identifier cannot start
    with a digit; its value is ``"3d"``.
    """

    all = "all"
    anime = "anime"
    s3D = "3d"
    realistic = "realistic"
    nsfw = "nsfw"
# ORDER BY fragment per sort mode. "trending" divides likes by
# MYPOWER(days since lastModified + 2, 2); MYPOWER is presumably an SQL
# function registered by Database (x ** y) - verify in db.py.
_SORT_QUERIES = {
    Sort.trending: "likes / MYPOWER((JULIANDAY('now') - JULIANDAY(datetime(json_extract(data, '$.lastModified')))) + 2, 2) DESC",
    Sort.recent: "datetime(json_extract(data, '$.lastModified')) DESC",
    Sort.likes: "likes DESC",
}

# WHERE fragment per style filter, based on the stored classifier scores.
_STYLE_QUERIES = {
    Style.all: "isNFSW = false",
    Style.anime: "json_extract(data, '$.class.anime') > 0.1 AND isNFSW = false",
    Style.s3D: "json_extract(data, '$.class.3d') > 0.1 AND isNFSW = false",
    Style.realistic: "json_extract(data, '$.class.real_life') > 0.1 AND isNFSW = false",
    Style.nsfw: "isNFSW = true",
}


def get_page(page: int = 1, sort: Sort = Sort.trending, style: Style = Style.all):
    """Return one page of gallery models filtered by ``style`` and ordered by ``sort``.

    Only models with more than 3 likes are listed. A model is flagged
    ``isNFSW`` when its ``explicit`` or ``suggestive`` score exceeds 0.3.
    Models without classifier scores get a NULL flag and therefore match no
    style filter.

    NOTE(review): no route decorator is visible on this function in this file;
    its signature reads like a FastAPI query handler - confirm registration.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        sort: a ``Sort`` member or its string value.
        style: a ``Style`` member or its string value (e.g. ``"3d"``).

    Returns:
        ``{"models": [...], "totalPages": int}``. Each model is its stored JSON
        data with ``downloads``/``likes`` taken from the table columns and a
        boolean ``isNFSW``.

    Raises:
        ValueError: if ``sort`` or ``style`` is not a known value.
    """
    page = max(int(page), 1)
    # Normalise plain strings to members; unknown values raise ValueError
    # instead of leaving the query fragments unbound.
    sort_query = _SORT_QUERIES[Sort(sort)]
    style_query = _STYLE_QUERIES[Style(style)]

    with database.get_db() as db:
        cursor = db.cursor()
        # Only the fixed fragments above are interpolated; paging values are bound.
        cursor.execute(
            f"""
            SELECT *, COUNT(*) OVER() AS total, isNFSW
            FROM (
                SELECT * ,
                json_extract(data, '$.class.explicit') > 0.3 OR json_extract(data, '$.class.suggestive') > 0.3 AS isNFSW
                FROM models
            )
            WHERE likes > 3 AND {style_query}
            ORDER BY {sort_query}
            LIMIT ? OFFSET ?
            """,
            (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE),
        )
        results = cursor.fetchall()

    # COUNT(*) OVER() repeats the unpaginated match count on every row.
    total = results[0]["total"] if results else 0
    total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE

    models_data = []
    for result in results:
        data = json.loads(result["data"])
        # update downloads and likes from db table
        data["downloads"] = result["downloads"]
        data["likes"] = result["likes"]
        data["isNFSW"] = bool(result["isNFSW"])
        models_data.append(data)
    return {
        "models": models_data,
        "totalPages": total_pages,
    }
def read_root():
    """Return a short HTML notice pointing visitors to the diffusers gallery Space.

    NOTE(review): no route decorator is visible on this function in this file;
    it reads like the handler for the root page - confirm it is registered on ``app``.
    """
    notice = """
    <p>Just a bot to sync data from diffusers gallery please go to
    <a href="https://huggingface.co/spaces/huggingface-projects/diffusers-gallery">https://huggingface.co/spaces/huggingface-projects/diffusers-gallery</a>
    </p>"""
    return HTMLResponse(notice)
async def repeat_sync():
    """Run one full ``sync_data`` pass and report completion.

    NOTE(review): ``repeat_every`` is imported but no decorator is visible on
    this function; presumably it is meant to be scheduled periodically - confirm.
    """
    await sync_data()
    message = "Synced data to huggingface datasets"
    return message