radames's picture
first
a101d9b
raw
history blame
4.66 kB
import os
import re
import asyncio
import aiohttp
import requests
import json
from tqdm import tqdm
from huggingface_hub import Repository
from fastapi import FastAPI, BackgroundTasks
from fastapi_utils.tasks import repeat_every
from fastapi.staticfiles import StaticFiles
from db import Database
# Hugging Face API token read from the environment (None when unset locally).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Shared database handle; get_db() yields a connection context manager (see db module).
database = Database()
async def check_image_url(url):
    """HEAD-check *url*; return it when the response is a 200 image, else None."""
    async with aiohttp.ClientSession() as session:
        async with session.head(url) as response:
            is_image = (
                response.status == 200
                and response.content_type.startswith("image/")
            )
            return url if is_image else None
def fetch_models(page=0):
    """Fetch one page of public text-to-image models from the Hub, sorted by likes.

    Args:
        page: zero-based page index.

    Returns:
        dict with "models" (public models only) plus the paging metadata
        fields numItemsPerPage, numTotalItems and pageIndex.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the Hub does not answer within 30 seconds.
    """
    response = requests.get(
        'https://huggingface.co/models-json',
        params={'pipeline_tag': 'text-to-image', 'p': page, 'sort': 'likes'},
        timeout=30,  # never hang the crawl on a stalled connection
    )
    # Fail loudly on HTTP errors instead of raising an opaque JSON decode error.
    response.raise_for_status()
    data = response.json()
    return {
        "models": [model for model in data['models'] if not model['private']],
        "numItemsPerPage": data['numItemsPerPage'],
        "numTotalItems": data['numTotalItems'],
        "pageIndex": data['pageIndex'],
    }
def fetch_model_card(model):
    """Return the raw README.md text for *model* (a dict with an 'id' key).

    Raises:
        requests.Timeout: if the Hub does not answer within 30 seconds.
    """
    response = requests.get(
        f'https://huggingface.co/{model["id"]}/raw/main/README.md',
        timeout=30,  # never hang the crawl on a stalled connection
    )
    return response.text
# Absolute URLs ending in an image file extension. The lazy `\S+?` plus the
# required dot fixes the original pattern, which matched any URL merely
# *ending* in the letters png/jpg/jpeg/webp (e.g. "https://x.co/shapng").
_IMAGE_URL_RE = re.compile(r'https?://\S+?\.(?:png|jpg|jpeg|webp)')


def find_image_in_model_card(text):
    """Return all image URLs (png/jpg/jpeg/webp) found in *text*, in order."""
    return _IMAGE_URL_RE.findall(text)
def run_inference(endpoint, img):
    """POST *img* bytes to a HF inference *endpoint*; return parsed JSON, or [] on a non-2xx response.

    Uses the module-level HF_TOKEN for authorization and asks the API to
    wait for model load and to serve cached results.

    Raises:
        requests.Timeout: if the endpoint does not answer within 60 seconds.
    """
    headers = {
        'Authorization': f'Bearer {HF_TOKEN}',
        "X-Wait-For-Model": "true",
        "X-Use-Cache": "true",
    }
    # timeout so a cold model cannot stall the whole crawl indefinitely
    response = requests.post(endpoint, headers=headers, data=img, timeout=60)
    return response.json() if response.ok else []
def get_all_models():
    """Crawl every page of text-to-image models and attach README image URLs.

    Returns:
        list of model dicts from the Hub, each extended with "images" (URLs
        scraped from the model card) and empty "style"/"aesthetic" placeholder
        lists (inference-based scoring is currently disabled).
    """
    initial = fetch_models()
    per_page = initial['numItemsPerPage']
    # Ceiling division: the original floor division dropped the final
    # partial page (e.g. 95 items / 30 per page lost the last 5 models).
    num_pages = (initial['numTotalItems'] + per_page - 1) // per_page
    print(f"Found {num_pages} pages")

    # fetch all models
    models = []
    for page in tqdm(range(num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        models += fetch_models(page)['models']

    # fetch model cards and extract images
    print(f"Found {len(models)} models")
    final_models = []
    for model in tqdm(models):
        print(f"Fetching model {model['id']}")
        model_card = fetch_model_card(model)
        images = find_image_in_model_card(model_card)
        # Keep empty placeholders so the downstream JSON schema stays stable
        # even though run_inference scoring is switched off for now.
        final_models.append(
            {**model, "images": images, "style": [], "aesthetic": []}
        )
    return final_models
async def sync_data():
    """Re-crawl the Hub, dump all models to data/models.json, and refresh the DB.

    Insert failures are logged per-row and skipped so one bad record cannot
    abort the whole sync; everything else is committed at the end.
    """
    models = get_all_models()
    with open("data/models.json", "w") as f:
        json.dump(models, f)
    with database.get_db() as db:
        cursor = db.cursor()
        for model in models:
            try:
                cursor.execute("INSERT INTO models (data) VALUES (?)",
                               [json.dumps(model)])
            except Exception as e:
                # Include the exception itself — the original print dropped it,
                # hiding the actual reason the row failed.
                print(model['id'], model, e)
        db.commit()
app = FastAPI()


@app.get("/sync")
async def sync(background_tasks: BackgroundTasks):
    """Schedule a background re-crawl of the Hub and return immediately."""
    background_tasks.add_task(sync_data)
    return "Synced data to huggingface datasets"
# Page size for /api/models responses.
MAX_PAGE_SIZE = 30


@app.get("/api/models")
def get_page(page: int = 1):
    """Return one page of de-duplicated models ordered by likes, plus total page count.

    Args:
        page: 1-based page number; values < 1 are clamped to 1.

    Returns:
        dict with "models" (parsed JSON payloads) and "totalPages".
    """
    page = page if page > 0 else 1
    with database.get_db() as db:
        cursor = db.cursor()
        # Inner query: collapse rows by model id, keep only ids that occur
        # exactly once, and expose the surviving-row count on every row via
        # COUNT(*) OVER(). Outer query orders by likes and applies paging.
        cursor.execute("""
        SELECT *
        FROM (
            SELECT *, COUNT(*) OVER() AS total
            FROM models
            GROUP BY json_extract(data, '$.id')
            HAVING COUNT(json_extract(data, '$.id')) = 1
        )
        ORDER BY json_extract(data, '$.likes') DESC
        LIMIT ? OFFSET ?
        """, (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE))
        results = cursor.fetchall()
        # NOTE(review): positional indices assume the models table puts the
        # JSON payload at column 1 and the windowed `total` at column 3 —
        # confirm against the schema in the db module.
        total = results[0][3] if results else 0
        total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE
        return {
            "models": [json.loads(result[1]) for result in results],
            "totalPages": total_pages
        }
# Serve the static frontend at "/"; mounted after the API routes above so
# they take precedence over the catch-all static handler.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
# @app.on_event("startup")
# @repeat_every(seconds=1800)
# def repeat_sync():
#     sync_data()  # NOTE(review): original referenced undefined sync_rooms_to_dataset
#     return "Synced data to huggingface datasets"