Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import re | |
import aiohttp | |
import requests | |
import json | |
import subprocess | |
import asyncio | |
from io import BytesIO | |
import uuid | |
from math import ceil | |
from tqdm import tqdm | |
from pathlib import Path | |
from huggingface_hub import Repository | |
from PIL import Image, ImageOps | |
from fastapi import FastAPI, BackgroundTasks | |
from fastapi_utils.tasks import repeat_every | |
from fastapi.middleware.cors import CORSMiddleware | |
import boto3 | |
from db import Database | |
# --- Configuration read from the environment ---------------------------------
# AWS credentials and bucket for the gallery thumbnails (consumed by the `s3`
# client below and by the S3 upload of resized images). Each is None when the
# corresponding environment variable is unset.
# NOTE(review): the secret is read from AWS_SECRET_KEY rather than boto's usual
# AWS_SECRET_ACCESS_KEY name — confirm this matches the deployment's secrets.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
AWS_S3_BUCKET_NAME = os.getenv('AWS_S3_BUCKET_NAME')
# Hugging Face token; sent as a Bearer token by run_inference.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): S3_DATA_FOLDER is not referenced anywhere in the visible code.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the dataset repo: the Database is opened on it, the JSON
# dumps (models.json / models_temp.json) are written into it, and the git
# commit/push in sync_data runs with it as the working directory.
DB_FOLDER = Path("diffusers-gallery-data")
# S3 client built from the explicit credentials above.
s3 = boto3.client(service_name='s3',
                  aws_access_key_id=AWS_ACCESS_KEY_ID,
                  aws_secret_access_key=AWS_SECRET_KEY)
# Import-time side effects, in this order: clone (or reuse) the dataset repo
# into DB_FOLDER, pull the latest commit, then open the Database on that folder.
# NOTE(review): authentication uses `use_auth_token=True` (locally stored Hub
# credentials); HF_TOKEN is not passed explicitly here — confirm the token is
# picked up in the deployed environment. Newer huggingface_hub releases also
# deprecate `Repository`/`use_auth_token`; check the pinned version.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()
# presumably the Database reads/writes files inside DB_FOLDER — confirm in db.py
database = Database(DB_FOLDER)
async def upload_resize_image_url(session, image_url):
    """Download an image, turn it into a 400x400 JPEG thumbnail and upload it to S3.

    Args:
        session: An open aiohttp client session used for the download.
        image_url: Absolute URL of the candidate image (scraped from a model card).

    Returns:
        The generated file name ``"<uuid4>.jpg"`` (stored under the
        ``diffusers-gallery/`` key prefix in ``AWS_S3_BUCKET_NAME``) on success.
        ``None`` when the URL does not answer with HTTP 200 and an ``image*``
        content type, or when the download or image decoding fails.
    """
    print(f"Uploading image {image_url}")
    try:
        async with session.get(image_url) as response:
            # A missing Content-Type header must not raise KeyError.
            content_type = response.headers.get('content-type', '')
            if response.status != 200 or not content_type.startswith('image'):
                return None
            payload = await response.read()
        image = Image.open(BytesIO(payload)).convert('RGB')
    except (aiohttp.ClientError, asyncio.TimeoutError, OSError,
            Image.DecompressionBombError) as err:
        # URLs come from free-form model cards: dead links, truncated or
        # non-image payloads are expected. Skip this image instead of letting
        # the error abort every other upload gathered alongside it.
        # (PIL's UnidentifiedImageError is a subclass of OSError.)
        print(f"Skipping image {image_url}: {err!r}")
        return None

    # Center-crop to the 1:1 aspect ratio and scale to exactly 400x400.
    thumbnail = ImageOps.fit(image, (400, 400), Image.LANCZOS)
    buffer = BytesIO()
    thumbnail.save(buffer, format="JPEG")
    buffer.seek(0)
    fname = f'{uuid.uuid4()}.jpg'
    # Thumbnails are immutable (fresh uuid per upload), so cache them for a year.
    s3.upload_fileobj(Fileobj=buffer, Bucket=AWS_S3_BUCKET_NAME, Key="diffusers-gallery/" + fname,
                      ExtraArgs={"ContentType": "image/jpeg", "CacheControl": "max-age=31536000"})
    return fname
def fetch_models(page=0, *, timeout=30):
    """Fetch one page of the Hub's text-to-image model listing.

    Args:
        page: 0-based listing page index.
        timeout: Seconds to wait for the listing endpoint before giving up
            (keyword-only; keeps a stalled connection from hanging the sync).

    Returns:
        A dict with ``models`` (public entries only), ``numItemsPerPage``,
        ``numTotalItems`` and ``pageIndex`` copied from the response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    response = requests.get(
        'https://huggingface.co/models-json',
        params={'pipeline_tag': 'text-to-image', 'p': page},
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of a JSON/KeyError on an error page.
    response.raise_for_status()
    data = response.json()
    return {
        # Entries without a 'private' flag are treated as public.
        "models": [model for model in data['models'] if not model.get('private', False)],
        "numItemsPerPage": data['numItemsPerPage'],
        "numTotalItems": data['numTotalItems'],
        "pageIndex": data['pageIndex'],
    }
def fetch_model_card(model, *, timeout=30):
    """Return the raw ``README.md`` (model card) of a Hub model repo.

    Args:
        model: Listing entry; only ``model["id"]`` is used.
        timeout: Seconds to wait for the response (keyword-only).

    Returns:
        The README text, or ``""`` when the request answers with an error status
        (e.g. the repo has no README). This keeps error-page bodies from being
        scanned for image URLs as if they were a model card.

    Raises:
        requests.RequestException: on connection failures or timeouts.
    """
    response = requests.get(
        f'https://huggingface.co/{model["id"]}/raw/main/README.md', timeout=timeout)
    if not response.ok:
        return ""
    return response.text
# Absolute http(s) URL ending in ".png", ".jpg", ".jpeg" or ".webp" (any case).
# Non-greedy so adjacent markdown such as "[![](a.png)](b.png)" yields two URLs
# instead of one run spanning both; the explicit "\." rejects names like "apng".
IMAGE_URL_REGEX = re.compile(r'https?://\S+?\.(?:png|jpe?g|webp)', re.IGNORECASE)


async def find_image_in_model_card(text, max_images=3):
    """Upload thumbnails for the first image URLs found in a model card.

    Args:
        text: Model card (README) text; ``None`` or ``""`` yields no images.
        max_images: Maximum number of distinct URLs to process, in order of
            first appearance.

    Returns:
        A list with one entry per processed URL: the uploaded file name, or
        ``None`` when that image could not be fetched or uploaded. An empty list
        when no image URL is found.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order, so an image
    # referenced several times in a card is uploaded only once.
    urls = list(dict.fromkeys(IMAGE_URL_REGEX.findall(text or "")))
    if not urls:
        return []
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(
            *(upload_resize_image_url(session, url) for url in urls[:max_images]))
def run_inference(endpoint, img):
    """POST a raw payload to an inference endpoint using the HF_TOKEN bearer token.

    The request asks the service to wait for the model to load and allows
    cached results. This function is synchronous.

    Args:
        endpoint: Full URL of the inference endpoint.
        img: Request body sent as-is (``data=``).

    Returns:
        The decoded JSON response when the status is below 400, otherwise ``[]``.
    """
    request_headers = {
        'Authorization': f'Bearer {HF_TOKEN}',
        "X-Wait-For-Model": "true",
        "X-Use-Cache": "true",
    }
    reply = requests.post(endpoint, headers=request_headers, data=img)
    if not reply.ok:
        return []
    return reply.json()
async def get_all_models():
    """Collect every public text-to-image model and attach thumbnail images.

    Walks all listing pages via ``fetch_models``, checkpoints the raw listing to
    ``DB_FOLDER / "models_temp.json"``, then fetches each model's card and
    uploads thumbnails for the image URLs found in it.

    Returns:
        The listing entries, each extended with ``"images"`` (the result of
        ``find_image_in_model_card``) and ``"style"`` / ``"aesthetic"``
        (currently always empty lists).
    """
    first_page = fetch_models(0)
    total_items = first_page['numTotalItems']
    per_page = first_page['numItemsPerPage']
    # Guard against a zero page size, which would otherwise raise ZeroDivisionError.
    num_pages = ceil(total_items / per_page) if per_page else 1
    print(f"Total items: {total_items} - Items per page: {per_page}")
    print(f"Found {num_pages} pages")

    # Page 0 was already downloaded above; reuse it instead of fetching it twice.
    models = list(first_page['models'])
    for page in tqdm(range(1, num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        models += fetch_models(page)['models']
    with open(DB_FOLDER / "models_temp.json", "w") as f:
        json.dump(models, f)

    print(f"Found {len(models)} models")
    final_models = []
    for model in tqdm(models):
        print(f"Fetching model {model['id']}")
        model_card = fetch_model_card(model)
        images = await find_image_in_model_card(model_card)
        # TODO: style/aesthetic scoring is disabled. If re-enabled, note that
        # run_inference is synchronous (do not `await` it) and expects the raw
        # image payload, not the uploaded file name stored in `images`.
        final_models.append(
            {**model, "images": images, "style": [], "aesthetic": []}
        )
    return final_models
async def sync_data():
    """Rebuild the gallery data and publish it to the dataset repository.

    Steps:
      1. Collect models via ``get_all_models`` and dump them to
         ``DB_FOLDER / "models.json"``.
      2. Insert each model into the ``models`` table (one row per model id,
         with the JSON-encoded model as ``data``).
      3. In the background, stage everything in ``DB_FOLDER``, amend the tip
         commit and force-push, replacing the previous tip commit rather than
         adding a new one.
    """
    print("Fetching models")
    models = await get_all_models()
    with open(DB_FOLDER / "models.json", "w") as f:
        json.dump(models, f)

    print("Updating database")
    with database.get_db() as db:
        cursor = db.cursor()
        for model in models:
            try:
                cursor.execute("INSERT INTO models(id, data) VALUES (?, ?)",
                               [model['id'], json.dumps(model)])
            except Exception as e:
                # Best-effort per row: report why this model was skipped and
                # continue with the rest instead of silently dumping the model.
                # NOTE(review): if `id` is unique, already-stored models fail
                # here and keep their old data — consider an upsert once the
                # table schema (db.py) is confirmed.
                print(f"Failed to insert model {model.get('id')}: {e!r}")
        db.commit()

    print("Updating repository")
    # Fire-and-forget: the git commands run in a background shell and their exit
    # status is not checked. The command string is a constant (no untrusted input).
    subprocess.Popen(
        "git add . && git commit --amend -m 'update' && git push --force", cwd=DB_FOLDER, shell=True)
app = FastAPI()

# Fully permissive CORS: every origin, method and header is accepted, with
# credentials allowed.
# NOTE(review): wildcard origins combined with credentials — confirm the
# front-end actually needs credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Manual sync endpoint (currently disabled):
# @ app.get("/sync")
# async def sync(background_tasks: BackgroundTasks):
#     background_tasks.add_task(sync_data)
#     return "Synced data to huggingface datasets"

# Number of models returned per page by get_page.
MAX_PAGE_SIZE = 30
def get_page(page: int = 1):
    """Return one page of gallery models, most recently modified first.

    Only models whose stored data has more than 5 ``likes`` are included,
    ordered by ``lastModified`` descending, ``MAX_PAGE_SIZE`` per page.

    Args:
        page: 1-based page number; values below 1 are treated as 1.

    Returns:
        ``{"models": [model dicts], "totalPages": int}``. ``totalPages`` counts
        all matching models, even when ``page`` is past the last page.
    """
    # NOTE(review): no route decorator is visible on this function; if it is
    # meant to be served by `app`, confirm it is registered (e.g. @app.get).
    page = page if page > 0 else 1
    with database.get_db() as db:
        cursor = db.cursor()
        # Explicit columns keep the window-function total at a fixed index
        # regardless of how many columns the `models` table has.
        cursor.execute("""
            SELECT data, COUNT(*) OVER() AS total
            FROM models
            WHERE json_extract(data, '$.likes') > 5
            ORDER BY datetime(json_extract(data, '$.lastModified')) DESC
            LIMIT ? OFFSET ?
        """, (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE))
        results = cursor.fetchall()
        if results:
            total = results[0][1]
        else:
            # COUNT(*) OVER() only appears on returned rows, so a page past the
            # end would otherwise report zero total pages.
            cursor.execute("""
                SELECT COUNT(*)
                FROM models
                WHERE json_extract(data, '$.likes') > 5
            """)
            total = cursor.fetchone()[0]
    total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE
    return {
        "models": [json.loads(row[0]) for row in results],
        "totalPages": total_pages,
    }
def read_root():
    """Return a short plain-text description of this service."""
    # NOTE(review): no route decorator is visible here; confirm how it is served.
    banner = "Just a bot to sync data from diffusers gallery"
    return banner
# @app.on_event("startup") | |
# @repeat_every(seconds=60 * 60 * 24, wait_first=True) | |
# async def repeat_sync(): | |
# await sync_data() | |
# return "Synced data to huggingface datasets" | |