import os
import re
import asyncio
import aiohttp
import requests
import json
from tqdm import tqdm

from huggingface_hub import Repository

from fastapi import FastAPI, BackgroundTasks
from fastapi_utils.tasks import repeat_every
from fastapi.staticfiles import StaticFiles

from db import Database

HF_TOKEN = os.environ.get("HF_TOKEN")

database = Database()


async def check_image_url(url):
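    """Return the URL if a HEAD request succeeds with an image content type, else None."""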
    async with aiohttp.ClientSession() as session:
        async with session.head(url) as resp:
            if resp.status == 200 and resp.content_type.startswith("image/"):
                return url


def fetch_models(page=0):
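    """Fetch one page of public text-to-image models from the huggingface.co models-json listing."""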
    response = requests.get(
        f'https://huggingface.co/models-json?pipeline_tag=text-to-image&p={page}&sort=likes')
    data = response.json()
    return {
        "models": [model for model in data['models'] if not model['private']],
        "numItemsPerPage": data['numItemsPerPage'],
        "numTotalItems": data['numTotalItems'],
        "pageIndex": data['pageIndex']
    }


def fetch_model_card(model):
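    """Download the raw README.md (model card) for a model."""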
    response = requests.get(
        f'https://huggingface.co/{model["id"]}/raw/main/README.md')
    return response.text


def find_image_in_model_card(text):
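    """Extract image URLs (png/jpg/jpeg/webp) from a model card's markdown."""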
    image_regex = re.compile(r'https?://\S+\.(?:png|jpg|jpeg|webp)')
    urls = re.findall(image_regex, text)
    # Optionally validate each URL asynchronously (would require making this function async):
    # if urls:
    # tasks = []
    # for url in urls:
    #     tasks.append(check_image_url(url))

    # results = await asyncio.gather(*tasks)
    # return [result for result in results if result]
    return urls


def run_inference(endpoint, img):
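    """POST an image to an Inference API endpoint and return the parsed JSON, or [] on failure."""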
    headers = {'Authorization': f'Bearer {HF_TOKEN}',
               "X-Wait-For-Model": "true",
               "X-Use-Cache": "true"}

    response = requests.post(endpoint, headers=headers, data=img)
    return response.json() if response.ok else []


def get_all_models():
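    """Crawl every page of text-to-image models and attach images found in each model card."""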
    initial = fetch_models()
    # Round up so the last, partially filled page is not dropped
    num_pages = (initial['numTotalItems'] + initial['numItemsPerPage'] - 1) // initial['numItemsPerPage']

    print(f"Found {num_pages} pages")

    # fetch all models
    models = []
    for page in tqdm(range(0, num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        page_models = fetch_models(page)
        models += page_models['models']

    # fetch datacards and images
    print(f"Found {len(models)} models")
    final_models = []
    for model in tqdm(models):
        print(f"Fetching model {model['id']}")
        model_card = fetch_model_card(model)
        images = find_image_in_model_card(model_card)
        # style = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
        style = []
        # aesthetic = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
        aesthetic = []
        final_models.append(
            {**model, "images": images, "style": style, "aesthetic": aesthetic}
        )
    return final_models


async def sync_data():
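    """Dump the crawled models to data/models.json and insert them into the models table."""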
    models = get_all_models()

    with open("data/models.json", "w") as f:
        json.dump(models, f)

    with database.get_db() as db:
        cursor = db.cursor()
        for model in models:
            try:
                cursor.execute("INSERT INTO models (data) VALUES (?)",
                               [json.dumps(model)])
            except Exception as e:
                print(f"Failed to insert model {model['id']}: {e}")
        db.commit()


app = FastAPI()


@app.get("/sync")
async def sync(background_tasks: BackgroundTasks):
    background_tasks.add_task(sync_data)
    return "Synced data to huggingface datasets"


MAX_PAGE_SIZE = 30


@app.get("/api/models")
def get_page(page: int = 1):
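    """Return one page of models ordered by likes, together with the total page count."""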
    page = page if page > 0 else 1
    with database.get_db() as db:
        cursor = db.cursor()
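        # Keep only model ids that appear exactly once, attach the total row count
        # via a window function, and return one page ordered by likes.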
        cursor.execute("""
            SELECT *
            FROM (
                SELECT *, COUNT(*) OVER() AS total
                FROM models
                GROUP BY json_extract(data, '$.id')
                HAVING COUNT(json_extract(data, '$.id')) = 1
            )
            ORDER BY json_extract(data, '$.likes') DESC
            LIMIT ? OFFSET ?
        """, (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE))
        results = cursor.fetchall()
        total = results[0][3] if results else 0
        total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE

    return {
        "models": [json.loads(result[1]) for result in results],
        "totalPages": total_pages
    }


app.mount("/", StaticFiles(directory="static", html=True), name="static")

# @app.on_event("startup")
# @repeat_every(seconds=1800)
# async def repeat_sync():
#     await sync_data()
#     return "Synced data to huggingface datasets"
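
# A minimal sketch for running this app locally (assumptions: uvicorn is installed
# and port 7860 is the port the hosting environment expects, as on Hugging Face Spaces):
#
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)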