Upload multit2i.py

multit2i.py  CHANGED  +10 -8
@@ -3,8 +3,10 @@ import asyncio
 from threading import RLock
 from pathlib import Path
 from huggingface_hub import InferenceClient
+import os


+HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None
 server_timeout = 600
 inference_timeout = 300

@@ -38,14 +40,14 @@ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="l
     if not sort: sort = "last_modified"
     models = []
     try:
-        model_infos = api.list_models(author=author, pipeline_tag="text-to-image",
+        model_infos = api.list_models(author=author, pipeline_tag="text-to-image", token=HF_TOKEN,
             tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit * 5)
     except Exception as e:
         print(f"Error: Failed to list models.")
         print(e)
         return models
     for model in model_infos:
-        if not model.private and not model.gated:
+        if not model.private and not model.gated and HF_TOKEN is None:
             if not_tag and not_tag in model.tags: continue
             models.append(model.id)
             if len(models) == limit: break
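
For reference, a minimal standalone sketch of the listing call this hunk touches, built only from public HfApi.list_models parameters (author, pipeline_tag, cardData, sort, limit, token); the author, limit, and filter values below are placeholders, not values from this Space:

# Sketch only (placeholder values): list text-to-image models; passing a token
# lets the listing include private/gated repos that the token can access.
import os
from huggingface_hub import HfApi

HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset
api = HfApi()
model_infos = api.list_models(author="some-author", pipeline_tag="text-to-image",
                              cardData=True, sort="last_modified", limit=20, token=HF_TOKEN)
for m in model_infos:
    if not m.private and not m.gated:  # mirror the public/ungated filter used above
        print(m.id)
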
@@ -58,7 +60,7 @@ def get_t2i_model_info_dict(repo_id: str):
     info = {"md": "None"}
     try:
         if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
-        model = api.model_info(repo_id=repo_id)
+        model = api.model_info(repo_id=repo_id, token=HF_TOKEN)
     except Exception as e:
         print(f"Error: Failed to get {repo_id}'s info.")
         print(e)
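
Likewise, a small sketch of the metadata fetch in this hunk, using HfApi.model_info with a token; the repo id is a placeholder and the fields read (tags, downloads, likes) are standard ModelInfo attributes:

# Sketch only (placeholder repo id): fetch one repo's metadata, authenticating
# when HF_TOKEN is set, and print a short summary line.
import os
from huggingface_hub import HfApi

HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi()
model = api.model_info(repo_id="some-author/some-t2i-model", token=HF_TOKEN)
print(f"{model.id} | downloads: {model.downloads} | likes: {model.likes} | tags: {', '.join(model.tags or [])}")
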
@@ -156,7 +158,7 @@ def load_model(model_name: str):
     global model_info_dict
     if model_name in loaded_models.keys(): return loaded_models[model_name]
     try:
-        loaded_models[model_name] = load_from_model(model_name)
+        loaded_models[model_name] = load_from_model(model_name, hf_token=HF_TOKEN)
         print(f"Loaded: {model_name}")
     except Exception as e:
         if model_name in loaded_models.keys(): del loaded_models[model_name]
@@ -179,12 +181,12 @@ def load_model_api(model_name: str):
     if model_name in loaded_models.keys(): return loaded_models[model_name]
     try:
         client = InferenceClient(timeout=5)
-        status = client.get_model_status(model_name)
+        status = client.get_model_status(model_name, token=HF_TOKEN)
         if status is None or status.framework != "diffusers" or status.state not in ["Loadable", "Loaded"]:
             print(f"Failed to load by API: {model_name}")
             return None
         else:
-            loaded_models[model_name] = InferenceClient(model_name, timeout=server_timeout)
+            loaded_models[model_name] = InferenceClient(model_name, token=HF_TOKEN, timeout=server_timeout)
             print(f"Loaded by API: {model_name}")
     except Exception as e:
         if model_name in loaded_models.keys(): del loaded_models[model_name]
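
A minimal sketch of the serverless Inference API pattern this hunk adjusts: probe the model's status, then keep an authenticated client. In this sketch the token is supplied via the InferenceClient constructor (a parameter the client is known to accept); the model id and timeouts are placeholders:

# Sketch only (placeholder model id): check that the model is servable with
# diffusers, then construct an authenticated client for later calls.
import os
from huggingface_hub import InferenceClient

HF_TOKEN = os.environ.get("HF_TOKEN")
model_name = "some-author/some-t2i-model"

probe = InferenceClient(token=HF_TOKEN, timeout=5)
status = probe.get_model_status(model_name)
if status.framework == "diffusers" and status.state in ("Loadable", "Loaded"):
    client = InferenceClient(model_name, token=HF_TOKEN, timeout=600)
    print(f"Loaded by API: {model_name}")
else:
    print(f"Failed to load by API: {model_name}")
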
@@ -340,9 +342,9 @@ def infer_body(client: InferenceClient | gr.Interface, prompt: str, neg_prompt:
     if cfg is not None and cfg > 0: cfg = kwargs["guidance_scale"] = cfg
     try:
         if isinstance(client, InferenceClient):
-            image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
+            image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
         elif isinstance(client, gr.Interface):
-            image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
+            image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
         else: return None
         image.save(png_path)
         return str(Path(png_path).resolve())
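
Finally, a self-contained sketch of the generation call itself: InferenceClient.text_to_image returns a PIL.Image.Image, so it can be saved directly. The prompt, size, and output path are placeholders, and the token is again supplied through the client; in the hunk above, guidance_scale arrives through kwargs when cfg is set.

# Sketch only (placeholder model, prompt, and path): one text-to-image request
# against the serverless API, saved to disk like infer_body does above.
import os
from huggingface_hub import InferenceClient

HF_TOKEN = os.environ.get("HF_TOKEN")
client = InferenceClient("some-author/some-t2i-model", token=HF_TOKEN, timeout=300)

image = client.text_to_image(prompt="a watercolor fox in a misty forest",
                             negative_prompt="low quality, blurry",
                             width=768, height=768, guidance_scale=7.0)
image.save("output.png")
print("saved:", os.path.abspath("output.png"))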