John6666 commited on
Commit
f519c86
β€’
1 Parent(s): 0d35d04

Upload 13 files

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +68 -65
  3. multit2i.py +31 -14
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Free Multi Models Text-to-Image Heavy-Armed Demo V2
3
  emoji: 🌐🌊
4
  colorFrom: blue
5
  colorTo: purple
 
1
  ---
2
+ title: Free Multi Models Text-to-Image Heavy-Armed Demo V3
3
  emoji: 🌐🌊
4
  colorFrom: blue
5
  colorTo: purple
app.py CHANGED
@@ -11,81 +11,84 @@ from tagger.v2 import V2_ALL_MODELS, v2_random_prompt
11
  from tagger.utils import (V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
12
  V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
13
 
14
- max_images = 8
15
  MAX_SEED = 2**32-1
16
  load_models(models)
17
 
18
  css = """
19
  .model_info { text-align: center; }
20
- .output { width=112px; height=112px; !important; }
21
- .gallery { width=100%; min_height=768px; !important; }
22
  """
23
 
24
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
25
- with gr.Column():
26
- with gr.Group():
27
- model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
28
- model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_classes="model_info")
29
- with gr.Group():
30
- with gr.Accordion("Prompt from Image File", open=False):
31
- tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
32
- with gr.Accordion(label="Advanced options", open=False):
33
- tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
34
- tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
35
- tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
36
- tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
37
- tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
38
- tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
39
- tagger_generate_from_image = gr.Button(value="Generate Tags from Image")
40
- with gr.Row():
41
- v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
42
- v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
43
- random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
44
- clear_prompt = gr.Button(value="Clear Prompt πŸ—‘οΈ", size="sm", scale=1)
45
- prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
46
- neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="")
47
- with gr.Accordion("Advanced options", open=False):
48
- with gr.Row():
49
- width = gr.Slider(label="Width", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
50
- height = gr.Slider(label="Height", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
51
- with gr.Row():
52
- steps = gr.Slider(label="Number of inference steps", info="If 0, the default value is used.", maximum=100, step=1, value=0)
53
- cfg = gr.Slider(label="Guidance scale", info="If 0, the default value is used.", maximum=30.0, step=0.1, value=0)
54
- seed = gr.Slider(label="Seed", info="Randomize Seed if -1.", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
55
- with gr.Accordion("Recommended Prompt", open=False):
56
- recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
57
- with gr.Row():
58
- positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
59
- positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
60
- negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
61
- negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
62
- with gr.Accordion("Prompt Transformer", open=False):
63
  with gr.Row():
64
- v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
65
- v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
66
- v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
67
- with gr.Row():
68
- v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
69
- v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
70
- v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
71
- v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
72
- v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
73
- image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=1)
74
- with gr.Row():
75
- run_button = gr.Button("Generate Image", scale=6)
76
- random_button = gr.Button("Random Model 🎲", scale=3)
77
- stop_button = gr.Button('Stop', interactive=False, scale=1)
78
- with gr.Column():
79
- with gr.Group():
 
 
 
 
 
 
80
  with gr.Row():
81
- output = [gr.Image(label='', elem_classes="output", type="filepath", format="png",
82
- show_download_button=True, show_share_button=False, show_label=False,
83
- interactive=False, min_width=80, visible=True) for _ in range(max_images)]
84
- with gr.Group():
85
- results = gr.Gallery(label="Gallery", elem_classes="gallery", interactive=False, show_download_button=True, show_share_button=False,
86
- container=True, format="png", object_fit="cover", columns=2, rows=2)
87
- image_files = gr.Files(label="Download", interactive=False)
88
- clear_results = gr.Button("Clear Gallery / Download πŸ—‘οΈ")
 
 
 
 
 
 
 
 
 
89
  with gr.Column():
90
  examples = gr.Examples(
91
  examples = [
 
11
  from tagger.utils import (V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
12
  V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
13
 
14
+ max_images = 6
15
  MAX_SEED = 2**32-1
16
  load_models(models)
17
 
18
  css = """
19
  .model_info { text-align: center; }
20
+ .output { width=112px; height=112px; max_width=112px; max_height=112px; !important; }
21
+ .gallery { min_width=512px; min_height=512px; max_height=1024px; !important; }
22
  """
23
 
24
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
25
+ with gr.Row():
26
+ with gr.Column(scale=10):
27
+ with gr.Group():
28
+ with gr.Accordion("Prompt from Image File", open=False):
29
+ tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
30
+ with gr.Accordion(label="Advanced options", open=False):
31
+ with gr.Row():
32
+ tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
33
+ tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
34
+ tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
35
+ with gr.Row():
36
+ tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
37
+ tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
38
+ tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
39
+ tagger_generate_from_image = gr.Button(value="Generate Tags from Image", variant="secondary")
40
+ with gr.Accordion("Prompt Transformer", open=False):
41
+ with gr.Row():
42
+ v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
43
+ v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
44
+ v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
45
+ with gr.Row():
46
+ v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
47
+ v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
48
+ v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
49
+ v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
50
+ v2_copy = gr.Button(value="Copy to clipboard", variant="secondary", size="sm", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
51
  with gr.Row():
52
+ v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
53
+ v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
54
+ random_prompt = gr.Button(value="Extend Prompt 🎲", variant="secondary", size="sm", scale=1)
55
+ clear_prompt = gr.Button(value="Clear Prompt πŸ—‘οΈ", variant="secondary", size="sm", scale=1)
56
+ prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
57
+ with gr.Accordion("Advanced options", open=False):
58
+ neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="")
59
+ with gr.Row():
60
+ width = gr.Slider(label="Width", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
61
+ height = gr.Slider(label="Height", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
62
+ with gr.Row():
63
+ steps = gr.Slider(label="Number of inference steps", info="If 0, the default value is used.", maximum=100, step=1, value=0)
64
+ cfg = gr.Slider(label="Guidance scale", info="If 0, the default value is used.", maximum=30.0, step=0.1, value=0)
65
+ seed = gr.Slider(label="Seed", info="Randomize Seed if -1.", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
66
+ recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
67
+ with gr.Row():
68
+ positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
69
+ positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
70
+ negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
71
+ negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
72
+
73
+ image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=1)
74
  with gr.Row():
75
+ run_button = gr.Button("Generate Image", scale=6)
76
+ random_button = gr.Button("Random Model 🎲", variant="secondary", scale=3)
77
+ stop_button = gr.Button('Stop', variant="secondary", interactive=False, scale=1)
78
+ with gr.Group():
79
+ model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
80
+ model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_classes="model_info")
81
+ with gr.Column(scale=10):
82
+ with gr.Group():
83
+ with gr.Row():
84
+ output = [gr.Image(label='', elem_classes="output", type="filepath", format="png",
85
+ show_download_button=True, show_share_button=False, show_label=False, container=False,
86
+ interactive=False, min_width=80, visible=True) for _ in range(max_images)]
87
+ with gr.Group():
88
+ results = gr.Gallery(label="Gallery", elem_classes="gallery", interactive=False, show_download_button=True, show_share_button=False,
89
+ container=True, format="png", object_fit="cover", columns=2, rows=2)
90
+ image_files = gr.Files(label="Download", interactive=False)
91
+ clear_results = gr.Button("Clear Gallery / Download πŸ—‘οΈ", variant="secondary")
92
  with gr.Column():
93
  examples = gr.Examples(
94
  examples = [
multit2i.py CHANGED
@@ -35,7 +35,7 @@ def is_repo_name(s):
35
 
36
  def get_status(model_name: str):
37
  from huggingface_hub import InferenceClient
38
- client = InferenceClient(timeout=10)
39
  return client.get_model_status(model_name)
40
 
41
 
@@ -54,7 +54,7 @@ def is_loadable(model_name: str, force_gpu: bool = False):
54
 
55
  def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False):
56
  from huggingface_hub import HfApi
57
- api = HfApi()
58
  default_tags = ["diffusers"]
59
  if not sort: sort = "last_modified"
60
  limit = limit * 20 if check_status and force_gpu else limit * 5
@@ -67,7 +67,7 @@ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="l
67
  print(e)
68
  return models
69
  for model in model_infos:
70
- if not model.private and not model.gated:
71
  loadable = is_loadable(model.id, force_gpu) if check_status else True
72
  if not_tag and not_tag in model.tags or not loadable: continue
73
  models.append(model.id)
@@ -77,7 +77,7 @@ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="l
77
 
78
  def get_t2i_model_info_dict(repo_id: str):
79
  from huggingface_hub import HfApi
80
- api = HfApi()
81
  info = {"md": "None"}
82
  try:
83
  if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
@@ -86,14 +86,15 @@ def get_t2i_model_info_dict(repo_id: str):
86
  print(f"Error: Failed to get {repo_id}'s info.")
87
  print(e)
88
  return info
89
- if model.private or model.gated: return info
90
  try:
91
  tags = model.tags
92
  except Exception as e:
93
  print(e)
94
  return info
95
  if not 'diffusers' in model.tags: return info
96
- if 'diffusers:StableDiffusionXLPipeline' in tags: info["ver"] = "SDXL"
 
97
  elif 'diffusers:StableDiffusionPipeline' in tags: info["ver"] = "SD1.5"
98
  elif 'diffusers:StableDiffusion3Pipeline' in tags: info["ver"] = "SD3"
99
  else: info["ver"] = "Other"
@@ -109,7 +110,8 @@ def get_t2i_model_info_dict(repo_id: str):
109
 
110
 
111
  def rename_image(image_path: str | None, model_name: str, save_path: str | None = None):
112
- from PIL import Image
 
113
  from datetime import datetime, timezone, timedelta
114
  if image_path is None: return None
115
  dt_now = datetime.now(timezone(timedelta(hours=9)))
@@ -352,7 +354,7 @@ def warm_model(model_name: str):
352
 
353
  # https://huggingface.co/docs/api-inference/detailed_parameters
354
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
355
- def infer_body(client: InferenceClient | gr.Interface, prompt: str, neg_prompt: str | None = None,
356
  height: int | None = None, width: int | None = None,
357
  steps: int | None = None, cfg: int | None = None, seed: int = -1):
358
  png_path = "image.png"
@@ -372,7 +374,7 @@ def infer_body(client: InferenceClient | gr.Interface, prompt: str, neg_prompt:
372
  return str(Path(png_path).resolve())
373
  except Exception as e:
374
  print(e)
375
- return None
376
 
377
 
378
  async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
@@ -392,11 +394,17 @@ async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
392
  await asyncio.sleep(0)
393
  try:
394
  result = await asyncio.wait_for(task, timeout=timeout)
395
- except (Exception, asyncio.TimeoutError) as e:
396
  print(e)
397
  print(f"Task timed out: {model_name}")
398
  if not task.done(): task.cancel()
399
  result = None
 
 
 
 
 
 
400
  if task.done() and result is not None:
401
  with lock:
402
  image = rename_image(result, model_name, save_path)
@@ -404,20 +412,25 @@ async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
404
  return None
405
 
406
 
 
407
  def infer_fn(model_name: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
408
  width: int | None = None, steps: int | None = None, cfg: int | None = None, seed: int = -1,
409
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
410
  if model_name == 'NA':
411
  return None
412
  try:
413
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
 
414
  loop = asyncio.new_event_loop()
 
 
415
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
416
  steps, cfg, seed, save_path, inference_timeout))
417
  except (Exception, asyncio.CancelledError) as e:
418
  print(e)
419
- print(f"Task aborted: {model_name}")
420
  result = None
 
421
  finally:
422
  loop.close()
423
  return result
@@ -432,14 +445,18 @@ def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str | None = N
432
  random.seed()
433
  model_name = random.choice(list(loaded_models.keys()))
434
  try:
435
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
 
436
  loop = asyncio.new_event_loop()
 
 
437
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
438
  steps, cfg, seed, save_path, inference_timeout))
439
  except (Exception, asyncio.CancelledError) as e:
440
  print(e)
441
- print(f"Task aborted: {model_name}")
442
  result = None
 
443
  finally:
444
  loop.close()
445
  return result
 
35
 
36
  def get_status(model_name: str):
37
  from huggingface_hub import InferenceClient
38
+ client = InferenceClient(token=HF_TOKEN, timeout=10)
39
  return client.get_model_status(model_name)
40
 
41
 
 
54
 
55
  def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False):
56
  from huggingface_hub import HfApi
57
+ api = HfApi(token=HF_TOKEN)
58
  default_tags = ["diffusers"]
59
  if not sort: sort = "last_modified"
60
  limit = limit * 20 if check_status and force_gpu else limit * 5
 
67
  print(e)
68
  return models
69
  for model in model_infos:
70
+ if not model.private and not model.gated or HF_TOKEN is not None:
71
  loadable = is_loadable(model.id, force_gpu) if check_status else True
72
  if not_tag and not_tag in model.tags or not loadable: continue
73
  models.append(model.id)
 
77
 
78
  def get_t2i_model_info_dict(repo_id: str):
79
  from huggingface_hub import HfApi
80
+ api = HfApi(token=HF_TOKEN)
81
  info = {"md": "None"}
82
  try:
83
  if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
 
86
  print(f"Error: Failed to get {repo_id}'s info.")
87
  print(e)
88
  return info
89
+ if model.private or model.gated and HF_TOKEN is None: return info
90
  try:
91
  tags = model.tags
92
  except Exception as e:
93
  print(e)
94
  return info
95
  if not 'diffusers' in model.tags: return info
96
+ if 'diffusers:FluxPipeline' in tags: info["ver"] = "FLUX.1"
97
+ elif 'diffusers:StableDiffusionXLPipeline' in tags: info["ver"] = "SDXL"
98
  elif 'diffusers:StableDiffusionPipeline' in tags: info["ver"] = "SD1.5"
99
  elif 'diffusers:StableDiffusion3Pipeline' in tags: info["ver"] = "SD3"
100
  else: info["ver"] = "Other"
 
110
 
111
 
112
  def rename_image(image_path: str | None, model_name: str, save_path: str | None = None):
113
+ from PIL import Image, ImageFile
114
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
115
  from datetime import datetime, timezone, timedelta
116
  if image_path is None: return None
117
  dt_now = datetime.now(timezone(timedelta(hours=9)))
 
354
 
355
  # https://huggingface.co/docs/api-inference/detailed_parameters
356
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
357
+ def infer_body(client: InferenceClient | gr.Interface | object, prompt: str, neg_prompt: str | None = None,
358
  height: int | None = None, width: int | None = None,
359
  steps: int | None = None, cfg: int | None = None, seed: int = -1):
360
  png_path = "image.png"
 
374
  return str(Path(png_path).resolve())
375
  except Exception as e:
376
  print(e)
377
+ raise Exception(e)
378
 
379
 
380
  async def infer(model_name: str, prompt: str, neg_prompt: str | None = None,
 
394
  await asyncio.sleep(0)
395
  try:
396
  result = await asyncio.wait_for(task, timeout=timeout)
397
+ except asyncio.TimeoutError as e:
398
  print(e)
399
  print(f"Task timed out: {model_name}")
400
  if not task.done(): task.cancel()
401
  result = None
402
+ raise Exception(f"Task timed out: {model_name}")
403
+ except Exception as e:
404
+ print(e)
405
+ if not task.done(): task.cancel()
406
+ result = None
407
+ raise Exception(e)
408
  if task.done() and result is not None:
409
  with lock:
410
  image = rename_image(result, model_name, save_path)
 
412
  return None
413
 
414
 
415
+ # https://github.com/aio-libs/pytest-aiohttp/issues/8 # also AsyncInferenceClient is buggy.
416
  def infer_fn(model_name: str, prompt: str, neg_prompt: str | None = None, height: int | None = None,
417
  width: int | None = None, steps: int | None = None, cfg: int | None = None, seed: int = -1,
418
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
419
  if model_name == 'NA':
420
  return None
421
  try:
422
+ loop = asyncio.get_running_loop()
423
+ except Exception:
424
  loop = asyncio.new_event_loop()
425
+ try:
426
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
427
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
428
  steps, cfg, seed, save_path, inference_timeout))
429
  except (Exception, asyncio.CancelledError) as e:
430
  print(e)
431
+ print(f"Task aborted: {model_name}, Error: {e}")
432
  result = None
433
+ raise gr.Error(f"Task aborted: {model_name}, Error: {e}")
434
  finally:
435
  loop.close()
436
  return result
 
445
  random.seed()
446
  model_name = random.choice(list(loaded_models.keys()))
447
  try:
448
+ loop = asyncio.get_running_loop()
449
+ except Exception:
450
  loop = asyncio.new_event_loop()
451
+ try:
452
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
453
  result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
454
  steps, cfg, seed, save_path, inference_timeout))
455
  except (Exception, asyncio.CancelledError) as e:
456
  print(e)
457
+ print(f"Task aborted: {model_name}, Error: {e}")
458
  result = None
459
+ raise gr.Error(f"Task aborted: {model_name}, Error: {e}")
460
  finally:
461
  loop.close()
462
  return result