John6666 commited on
Commit
5fbe98e
β€’
1 Parent(s): f78dcc4

Upload 13 files

Browse files
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: T2i Multi Heavy Demo
3
- emoji: πŸ†
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.39.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: Free Multi Models Text-to-Image Heavy-Armed Demo
3
+ emoji: 🌐🌊
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.39.0
8
+ app_file: app.py
9
+ short_description: Text-to-Image
10
+ pinned: true
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from multit2i import (
3
+ load_models,
4
+ infer_multi,
5
+ infer_multi_random,
6
+ save_gallery_images,
7
+ change_model,
8
+ get_model_info_md,
9
+ loaded_models,
10
+ get_positive_prefix,
11
+ get_positive_suffix,
12
+ get_negative_prefix,
13
+ get_negative_suffix,
14
+ get_recom_prompt_type,
15
+ set_recom_prompt_preset,
16
+ )
17
+ from model import models
18
+
19
+ from tagger.tagger import (
20
+ predict_tags_wd,
21
+ remove_specific_prompt,
22
+ convert_danbooru_to_e621_prompt,
23
+ insert_recom_prompt,
24
+ )
25
+ from tagger.fl2sd3longcap import predict_tags_fl2_sd3
26
+ from tagger.v2 import (
27
+ V2_ALL_MODELS,
28
+ v2_random_prompt,
29
+ )
30
+ from tagger.utils import (
31
+ V2_ASPECT_RATIO_OPTIONS,
32
+ V2_RATING_OPTIONS,
33
+ V2_LENGTH_OPTIONS,
34
+ V2_IDENTITY_OPTIONS,
35
+ )
36
+
37
+
38
+ load_models(models, 10)
39
+ #load_models(models, 20) # Fetching 20 models at the same time. default: 5 *This option is not working so far.
40
+
41
+
42
+ css = """
43
+ #model_info { text-align: center; display:block; }
44
+ """
45
+
46
+ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
47
+ with gr.Column():
48
+ with gr.Accordion("Advanced settings", open=True):
49
+ with gr.Accordion("Recommended Prompt", open=False):
50
+ recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
51
+ with gr.Row():
52
+ positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
53
+ positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
54
+ negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
55
+ negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
56
+ with gr.Accordion("Prompt Transformer", open=False):
57
+ v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
58
+ v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
59
+ v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
60
+ v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
61
+ v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
62
+ v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
63
+ with gr.Accordion("Model", open=True):
64
+ model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0])
65
+ model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_id="model_info")
66
+ with gr.Group():
67
+ with gr.Accordion("Prompt from Image File", open=False):
68
+ tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
69
+ with gr.Accordion(label="Advanced options", open=False):
70
+ tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
71
+ tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
72
+ tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
73
+ tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
74
+ tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
75
+ tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
76
+ tagger_generate_from_image = gr.Button(value="Generate Tags from Image")
77
+ with gr.Row():
78
+ v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
79
+ v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
80
+ random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
81
+ clear_prompt = gr.Button(value="Clear Prompt πŸ—‘οΈ", size="sm", scale=1)
82
+ prompt = gr.Text(label="Prompt", lines=1, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
83
+ neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
84
+ with gr.Row():
85
+ run_button = gr.Button("Generate Image", scale=6)
86
+ random_button = gr.Button("Random Model 🎲", scale=3)
87
+ image_num = gr.Number(label="Count", minimum=1, maximum=16, value=1, step=1, interactive=True, scale=1)
88
+ results = gr.Gallery(label="Gallery", interactive=False, show_download_button=True, show_share_button=False,
89
+ container=True, format="png", object_fit="contain")
90
+ image_files = gr.Files(label="Download", interactive=False)
91
+ clear_results = gr.Button("Clear Gallery / Download")
92
+ examples = gr.Examples(
93
+ examples = [
94
+ ["souryuu asuka langley, 1girl, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors"],
95
+ ["sailor moon, magical girl transformation, sparkles and ribbons, soft pastel colors, crescent moon motif, starry night sky background, shoujo manga style"],
96
+ ["kafuu chino, 1girl, solo"],
97
+ ["1girl"],
98
+ ["beautiful sunset"],
99
+ ],
100
+ inputs=[prompt],
101
+ )
102
+ gr.Markdown(
103
+ f"""This demo was created in reference to the following demos.
104
+ - [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood).
105
+ - [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL).
106
+ <br>The first startup takes a mind-boggling amount of time, but not so much after the second.
107
+ This is due to the time it takes for Gradio to generate an example image to cache.
108
+ """
109
+ )
110
+ gr.DuplicateButton(value="Duplicate Space")
111
+
112
+ model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)
113
+ gr.on(
114
+ triggers=[run_button.click, prompt.submit],
115
+ fn=infer_multi,
116
+ inputs=[prompt, neg_prompt, results, image_num, model_name,
117
+ positive_prefix, positive_suffix, negative_prefix, negative_suffix],
118
+ outputs=[results],
119
+ queue=True,
120
+ show_progress="full",
121
+ show_api=True,
122
+ ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
123
+ gr.on(
124
+ triggers=[random_button.click],
125
+ fn=infer_multi_random,
126
+ inputs=[prompt, neg_prompt, results, image_num,
127
+ positive_prefix, positive_suffix, negative_prefix, negative_suffix],
128
+ outputs=[results],
129
+ queue=True,
130
+ show_progress="full",
131
+ show_api=True,
132
+ ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
133
+ clear_prompt.click(lambda: (None, None, None), None, [prompt, v2_series, v2_character], queue=False, show_api=False)
134
+ clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
135
+ recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
136
+ [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
137
+ random_prompt.click(v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
138
+ v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], queue=False, show_api=False)
139
+ tagger_generate_from_image.click(
140
+ predict_tags_wd,
141
+ [tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
142
+ [v2_series, v2_character, prompt, gr.Button(visible=False)],
143
+ show_api=False,
144
+ ).success(
145
+ predict_tags_fl2_sd3, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
146
+ ).success(
147
+ remove_specific_prompt, [prompt, tagger_keep_tags], [prompt], queue=False, show_api=False,
148
+ ).success(
149
+ convert_danbooru_to_e621_prompt, [prompt, tagger_tag_type], [prompt], queue=False, show_api=False,
150
+ ).success(
151
+ insert_recom_prompt, [prompt, neg_prompt, tagger_recom_prompt], [prompt, neg_prompt], queue=False, show_api=False,
152
+ )
153
+
154
+ demo.queue()
155
+ demo.launch()
model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multit2i import find_model_list
2
+
3
+
4
+ models = [
5
+ 'yodayo-ai/kivotos-xl-2.0',
6
+ 'yodayo-ai/holodayo-xl-2.1',
7
+ 'cagliostrolab/animagine-xl-3.1',
8
+ 'votepurchase/ponyDiffusionV6XL',
9
+ 'eienmojiki/Anything-XL',
10
+ 'eienmojiki/Starry-XL-v5.2',
11
+ 'digiplay/majicMIX_sombre_v2',
12
+ 'digiplay/majicMIX_realistic_v7',
13
+ 'votepurchase/counterfeitV30_v30',
14
+ 'Meina/MeinaMix_V11',
15
+ 'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
16
+ 'kayfahaarukku/UrangDiffusion-1.1',
17
+ 'Raelina/Rae-Diffusion-XL-V2',
18
+ 'Raelina/Raemu-XL-V4',
19
+ ]
20
+
21
+
22
+ models = ['yodayo-ai/kivotos-xl-2.0', 'Raelina/Rae-Diffusion-XL-V2']
23
+
24
+
25
+ # Examples:
26
+ #models = ['yodayo-ai/kivotos-xl-2.0', 'yodayo-ai/holodayo-xl-2.1'] # specific models
27
+ #models = find_model_list("John6666", [], "", "last_modified", 20) # John6666's latest 20 models
28
+ #models = find_model_list("John6666", ["anime"], "", "last_modified", 20) # John6666's latest 20 models with 'anime' tag
29
+ #models = find_model_list("John6666", [], "anime", "last_modified", 20) # John6666's latest 20 models without 'anime' tag
30
+ #models = find_model_list("", [], "", "last_modified", 20) # latest 20 text-to-image models of huggingface
31
+ #models = find_model_list("", [], "", "downloads", 20) # monthly most downloaded 20 text-to-image models of huggingface
32
+
multit2i.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import asyncio
3
+ from threading import RLock, Thread
4
+ from pathlib import Path
5
+
6
+
7
+ lock = RLock()
8
+ loaded_models = {}
9
+ model_info_dict = {}
10
+
11
+
12
+ def to_list(s):
13
+ return [x.strip() for x in s.split(",")]
14
+
15
+
16
+ def list_sub(a, b):
17
+ return [e for e in a if e not in b]
18
+
19
+
20
+ def list_uniq(l):
21
+ return sorted(set(l), key=l.index)
22
+
23
+
24
+ def is_repo_name(s):
25
+ import re
26
+ return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
27
+
28
+
29
+ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30):
30
+ from huggingface_hub import HfApi
31
+ api = HfApi()
32
+ default_tags = ["diffusers"]
33
+ if not sort: sort = "last_modified"
34
+ models = []
35
+ try:
36
+ model_infos = api.list_models(author=author, pipeline_tag="text-to-image",
37
+ tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit * 5)
38
+ except Exception as e:
39
+ print(f"Error: Failed to list models.")
40
+ print(e)
41
+ return models
42
+ for model in model_infos:
43
+ if not model.private and not model.gated:
44
+ if not_tag and not_tag in model.tags: continue
45
+ models.append(model.id)
46
+ if len(models) == limit: break
47
+ return models
48
+
49
+
50
+ def get_t2i_model_info_dict(repo_id: str):
51
+ from huggingface_hub import HfApi
52
+ api = HfApi()
53
+ info = {"md": "None"}
54
+ try:
55
+ if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
56
+ model = api.model_info(repo_id=repo_id)
57
+ except Exception as e:
58
+ print(f"Error: Failed to get {repo_id}'s info.")
59
+ print(e)
60
+ return info
61
+ if model.private or model.gated: return info
62
+ try:
63
+ tags = model.tags
64
+ except Exception as e:
65
+ print(e)
66
+ return info
67
+ if not 'diffusers' in model.tags: return info
68
+ if 'diffusers:StableDiffusionXLPipeline' in tags: info["ver"] = "SDXL"
69
+ elif 'diffusers:StableDiffusionPipeline' in tags: info["ver"] = "SD1.5"
70
+ elif 'diffusers:StableDiffusion3Pipeline' in tags: info["ver"] = "SD3"
71
+ else: info["ver"] = "Other"
72
+ info["url"] = f"https://huggingface.co/{repo_id}/"
73
+ if model.card_data and model.card_data.tags:
74
+ info["tags"] = model.card_data.tags
75
+ info["downloads"] = model.downloads
76
+ info["likes"] = model.likes
77
+ info["last_modified"] = model.last_modified.strftime("lastmod: %Y-%m-%d")
78
+ un_tags = ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']
79
+ descs = [info["ver"]] + list_sub(info["tags"], un_tags) + [f'DLs: {info["downloads"]}'] + [f'❀: {info["likes"]}'] + [info["last_modified"]]
80
+ info["md"] = f'Model Info: {", ".join(descs)} [Model Repo]({info["url"]})'
81
+ return info
82
+
83
+
84
+ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
85
+ from datetime import datetime, timezone, timedelta
86
+ progress(0, desc="Updating gallery...")
87
+ dt_now = datetime.now(timezone(timedelta(hours=9)))
88
+ basename = dt_now.strftime('%Y%m%d_%H%M%S_')
89
+ i = 1
90
+ if not images: return images
91
+ output_images = []
92
+ output_paths = []
93
+ for image in images:
94
+ filename = f'{image[1]}_{basename}{str(i)}.png'
95
+ i += 1
96
+ oldpath = Path(image[0])
97
+ newpath = oldpath
98
+ try:
99
+ if oldpath.stem == "image" and oldpath.exists():
100
+ newpath = oldpath.resolve().rename(Path(filename).resolve())
101
+ except Exception as e:
102
+ print(e)
103
+ pass
104
+ finally:
105
+ output_paths.append(str(newpath))
106
+ output_images.append((str(newpath), str(filename)))
107
+ progress(1, desc="Gallery updated.")
108
+ return gr.update(value=output_images), gr.update(value=output_paths)
109
+
110
+
111
+ def load_model(model_name: str):
112
+ global loaded_models
113
+ global model_info_dict
114
+ if model_name in loaded_models.keys(): return loaded_models[model_name]
115
+ try:
116
+ with lock:
117
+ loaded_models[model_name] = gr.load(f'models/{model_name}')
118
+ print(f"Loaded: {model_name}")
119
+ except Exception as e:
120
+ with lock:
121
+ if model_name in loaded_models.keys(): del loaded_models[model_name]
122
+ print(f"Failed to load: {model_name}")
123
+ print(e)
124
+ return None
125
+ try:
126
+ with lock:
127
+ model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
128
+ except Exception as e:
129
+ with lock:
130
+ if model_name in model_info_dict.keys(): del model_info_dict[model_name]
131
+ print(e)
132
+ return loaded_models[model_name]
133
+
134
+
135
+ async def async_load_models(models: list, limit: int=5, wait=10):
136
+ sem = asyncio.Semaphore(limit)
137
+ async def async_load_model(model: str):
138
+ async with sem:
139
+ try:
140
+ return await asyncio.to_thread(load_model, model)
141
+ except Exception as e:
142
+ print(e)
143
+ tasks = [asyncio.create_task(async_load_model(model)) for model in models]
144
+ return await asyncio.gather(*tasks, return_exceptions=True)
145
+
146
+
147
+ def load_models(models: list, limit: int=5):
148
+ loop = asyncio.new_event_loop()
149
+ try:
150
+ loop.run_until_complete(async_load_models(models, limit))
151
+ except Exception as e:
152
+ print(e)
153
+ pass
154
+ finally:
155
+ loop.close()
156
+
157
+
158
+ positive_prefix = {
159
+ "Pony": to_list("score_9, score_8_up, score_7_up"),
160
+ "Pony Anime": to_list("source_anime, anime, score_9, score_8_up, score_7_up"),
161
+ }
162
+ positive_suffix = {
163
+ "Common": to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres"),
164
+ "Anime": to_list("anime artwork, anime style, studio anime, highly detailed"),
165
+ }
166
+ negative_prefix = {
167
+ "Pony": to_list("score_6, score_5, score_4"),
168
+ "Pony Anime": to_list("score_6, score_5, score_4, source_pony, source_furry, source_cartoon"),
169
+ "Pony Real": to_list("score_6, score_5, score_4, source_anime, source_pony, source_furry, source_cartoon"),
170
+ }
171
+ negative_suffix = {
172
+ "Common": to_list("lowres, (bad), bad hands, bad feet, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"),
173
+ "Pony Anime": to_list("busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends"),
174
+ "Pony Real": to_list("ugly, airbrushed, simple background, cgi, cartoon, anime"),
175
+ }
176
+ positive_all = negative_all = []
177
+ for k, v in (positive_prefix | positive_suffix).items():
178
+ positive_all = positive_all + v + [s.replace("_", " ") for s in v]
179
+ positive_all = list_uniq(positive_all)
180
+ for k, v in (negative_prefix | negative_suffix).items():
181
+ negative_all = negative_all + v + [s.replace("_", " ") for s in v]
182
+ positive_all = list_uniq(positive_all)
183
+
184
+
185
+ def recom_prompt(prompt: str = "", neg_prompt: str = "", pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
186
+ def flatten(src):
187
+ return [item for row in src for item in row]
188
+ prompts = to_list(prompt)
189
+ neg_prompts = to_list(neg_prompt)
190
+ prompts = list_sub(prompts, positive_all)
191
+ neg_prompts = list_sub(neg_prompts, negative_all)
192
+ last_empty_p = [""] if not prompts and type != "None" else []
193
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
194
+ prefix_ps = flatten([positive_prefix.get(s, []) for s in pos_pre])
195
+ suffix_ps = flatten([positive_suffix.get(s, []) for s in pos_suf])
196
+ prefix_nps = flatten([negative_prefix.get(s, []) for s in neg_pre])
197
+ suffix_nps = flatten([negative_suffix.get(s, []) for s in neg_suf])
198
+ prompt = ", ".join(list_uniq(prefix_ps + prompts + suffix_ps) + last_empty_p)
199
+ neg_prompt = ", ".join(list_uniq(prefix_nps + neg_prompts + suffix_nps) + last_empty_np)
200
+ return prompt, neg_prompt
201
+
202
+
203
+ recom_prompt_type = {
204
+ "None": ([], [], [], []),
205
+ "Auto": ([], [], [], []),
206
+ "Common": ([], ["Common"], [], ["Common"]),
207
+ "Animagine": ([], ["Common", "Anime"], [], ["Common"]),
208
+ "Pony": (["Pony"], ["Common"], ["Pony"], ["Common"]),
209
+ "Pony Anime": (["Pony", "Pony Anime"], ["Common", "Anime"], ["Pony", "Pony Anime"], ["Common", "Pony Anime"]),
210
+ "Pony Real": (["Pony"], ["Common"], ["Pony", "Pony Real"], ["Common", "Pony Real"]),
211
+ }
212
+
213
+
214
+ enable_auto_recom_prompt = False
215
+ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
216
+ global enable_auto_recom_prompt
217
+ if type == "Auto": enable_auto_recom_prompt = True
218
+ else: enable_auto_recom_prompt = False
219
+ pos_pre, pos_suf, neg_pre, neg_suf = recom_prompt_type.get(type, ([], [], [], []))
220
+ return recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
221
+
222
+
223
+ def set_recom_prompt_preset(type: str = "None"):
224
+ pos_pre, pos_suf, neg_pre, neg_suf = recom_prompt_type.get(type, ([], [], [], []))
225
+ return pos_pre, pos_suf, neg_pre, neg_suf
226
+
227
+
228
+ def get_recom_prompt_type():
229
+ type = list(recom_prompt_type.keys())
230
+ type.remove("Auto")
231
+ return type
232
+
233
+
234
+ def get_positive_prefix():
235
+ return list(positive_prefix.keys())
236
+
237
+
238
+ def get_positive_suffix():
239
+ return list(positive_suffix.keys())
240
+
241
+
242
+ def get_negative_prefix():
243
+ return list(negative_prefix.keys())
244
+
245
+
246
+ def get_negative_suffix():
247
+ return list(negative_suffix.keys())
248
+
249
+
250
+ def get_model_info_md(model_name: str):
251
+ if model_name in model_info_dict.keys(): return model_info_dict[model_name].get("md", "")
252
+
253
+
254
+ def change_model(model_name: str):
255
+ load_model(model_name)
256
+ return get_model_info_md(model_name)
257
+
258
+
259
+ def infer(prompt: str, neg_prompt: str, model_name: str):
260
+ from PIL import Image
261
+ import random
262
+ seed = ""
263
+ rand = random.randint(1, 500)
264
+ for i in range(rand):
265
+ seed += " "
266
+ caption = model_name.split("/")[-1]
267
+ try:
268
+ model = load_model(model_name)
269
+ if not model: return (Image.Image(), None)
270
+ image_path = model(prompt + seed)
271
+ image = Image.open(image_path).convert('RGBA')
272
+ except Exception as e:
273
+ print(e)
274
+ return (Image.Image(), None)
275
+ return (image, caption)
276
+
277
+
278
+ async def infer_multi(prompt: str, neg_prompt: str, results: list, image_num: float, model_name: str,
279
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
280
+ #from tqdm.asyncio import tqdm_asyncio
281
+ image_num = int(image_num)
282
+ images = results if results else []
283
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
284
+ tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for i in range(image_num)]
285
+ results = await asyncio.gather(*tasks, return_exceptions=True)
286
+ #results = await tqdm_asyncio.gather(*tasks)
287
+ if not results: results = []
288
+ for result in results:
289
+ with lock:
290
+ if result and result[1]: images.append(result)
291
+ yield images
292
+
293
+
294
+ async def infer_multi_random(prompt: str, neg_prompt: str, results: list, image_num: float,
295
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
296
+ #from tqdm.asyncio import tqdm_asyncio
297
+ import random
298
+ image_num = int(image_num)
299
+ images = results if results else []
300
+ random.seed()
301
+ model_names = random.choices(list(loaded_models.keys()), k = image_num)
302
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
303
+ tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for model_name in model_names]
304
+ results = await asyncio.gather(*tasks, return_exceptions=True)
305
+ #await tqdm_asyncio.gather(*tasks)
306
+ if not results: results = []
307
+ for result in results:
308
+ with lock:
309
+ if result and result[1]: images.append(result)
310
+ yield images
311
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ torch
3
+ torchvision
4
+ accelerate
5
+ transformers
6
+ optimum[onnxruntime]
7
+ spaces
8
+ dartrs
9
+ httpx==0.13.3
10
+ httpcore
11
+ googletrans==4.0.0rc1
12
+ timm
tagger/character_series_dict.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/danbooru_e621.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/fl2sd3longcap.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForCausalLM
2
+ import spaces
3
+ import re
4
+ from PIL import Image
5
+
6
+ import subprocess
7
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
8
+
9
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).eval()
10
+ fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
11
+
12
+
13
+ def fl_modify_caption(caption: str) -> str:
14
+ """
15
+ Removes specific prefixes from captions if present, otherwise returns the original caption.
16
+ Args:
17
+ caption (str): A string containing a caption.
18
+ Returns:
19
+ str: The caption with the prefix removed if it was present, or the original caption.
20
+ """
21
+ # Define the prefixes to remove
22
+ prefix_substrings = [
23
+ ('captured from ', ''),
24
+ ('captured at ', '')
25
+ ]
26
+
27
+ # Create a regex pattern to match any of the prefixes
28
+ pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
29
+ replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
30
+
31
+ # Function to replace matched prefix with its corresponding replacement
32
+ def replace_fn(match):
33
+ return replacers[match.group(0).lower()]
34
+
35
+ # Apply the regex to the caption
36
+ modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
37
+
38
+ # If the caption was modified, return the modified version; otherwise, return the original
39
+ return modified_caption if modified_caption != caption else caption
40
+
41
+
42
+ @spaces.GPU
43
+ def fl_run_example(image):
44
+ task_prompt = "<DESCRIPTION>"
45
+ prompt = task_prompt + "Describe this image in great detail."
46
+
47
+ # Ensure the image is in RGB mode
48
+ if image.mode != "RGB":
49
+ image = image.convert("RGB")
50
+
51
+ inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
52
+ generated_ids = fl_model.generate(
53
+ input_ids=inputs["input_ids"],
54
+ pixel_values=inputs["pixel_values"],
55
+ max_new_tokens=1024,
56
+ num_beams=3
57
+ )
58
+ generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
59
+ parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
60
+ return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
61
+
62
+
63
+ def predict_tags_fl2_sd3(image: Image.Image, input_tags: str, algo: list[str]):
64
+ def to_list(s):
65
+ return [x.strip() for x in s.split(",") if not s == ""]
66
+
67
+ def list_uniq(l):
68
+ return sorted(set(l), key=l.index)
69
+
70
+ if not "Use Florence-2-SD3-Long-Captioner" in algo:
71
+ return input_tags
72
+ tag_list = list_uniq(to_list(input_tags) + to_list(fl_run_example(image) + ", "))
73
+ tag_list.remove("")
74
+ return ", ".join(tag_list)
tagger/output.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class UpsamplingOutput:
6
+ upsampled_tags: str
7
+
8
+ copyright_tags: str
9
+ character_tags: str
10
+ general_tags: str
11
+ rating_tag: str
12
+ aspect_ratio_tag: str
13
+ length_tag: str
14
+ identity_tag: str
15
+
16
+ elapsed_time: float = 0.0
tagger/tag_group.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/tagger.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import gradio as gr
4
+ import spaces
5
+ from transformers import (
6
+ AutoImageProcessor,
7
+ AutoModelForImageClassification,
8
+ )
9
+ from pathlib import Path
10
+
11
+
12
+ WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
+ WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
+
15
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
16
+ wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
17
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
18
+
19
+
20
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
21
+ return (
22
+ [f"1{noun}"]
23
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
24
+ + [f"{maximum+1}+{noun}s"]
25
+ )
26
+
27
+
28
+ PEOPLE_TAGS = (
29
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
30
+ )
31
+
32
+
33
+ RATING_MAP = {
34
+ "general": "safe",
35
+ "sensitive": "sensitive",
36
+ "questionable": "nsfw",
37
+ "explicit": "explicit, nsfw",
38
+ }
39
+ DANBOORU_TO_E621_RATING_MAP = {
40
+ "safe": "rating_safe",
41
+ "sensitive": "rating_safe",
42
+ "nsfw": "rating_explicit",
43
+ "explicit, nsfw": "rating_explicit",
44
+ "explicit": "rating_explicit",
45
+ "rating:safe": "rating_safe",
46
+ "rating:general": "rating_safe",
47
+ "rating:sensitive": "rating_safe",
48
+ "rating:questionable, nsfw": "rating_explicit",
49
+ "rating:explicit, nsfw": "rating_explicit",
50
+ }
51
+
52
+
53
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
54
+ kaomojis = [
55
+ "0_0",
56
+ "(o)_(o)",
57
+ "+_+",
58
+ "+_-",
59
+ "._.",
60
+ "<o>_<o>",
61
+ "<|>_<|>",
62
+ "=_=",
63
+ ">_<",
64
+ "3_3",
65
+ "6_9",
66
+ ">_o",
67
+ "@_@",
68
+ "^_^",
69
+ "o_o",
70
+ "u_u",
71
+ "x_x",
72
+ "|_|",
73
+ "||_||",
74
+ ]
75
+
76
+
77
+ def replace_underline(x: str):
78
+ return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
79
+
80
+
81
+ def to_list(s):
82
+ return [x.strip() for x in s.split(",") if not s == ""]
83
+
84
+
85
+ def list_sub(a, b):
86
+ return [e for e in a if e not in b]
87
+
88
+
89
+ def list_uniq(l):
90
+ return sorted(set(l), key=l.index)
91
+
92
+
93
+ def load_dict_from_csv(filename):
94
+ dict = {}
95
+ if not Path(filename).exists():
96
+ if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
97
+ else: return dict
98
+ try:
99
+ with open(filename, 'r', encoding="utf-8") as f:
100
+ lines = f.readlines()
101
+ except Exception:
102
+ print(f"Failed to open dictionary file: {filename}")
103
+ return dict
104
+ for line in lines:
105
+ parts = line.strip().split(',')
106
+ dict[parts[0]] = parts[1]
107
+ return dict
108
+
109
+
110
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
111
+
112
+
113
+ def character_list_to_series_list(character_list):
114
+ output_series_tag = []
115
+ series_tag = ""
116
+ series_dict = anime_series_dict
117
+ for tag in character_list:
118
+ series_tag = series_dict.get(tag, "")
119
+ if tag.endswith(")"):
120
+ tags = tag.split("(")
121
+ character_tag = "(".join(tags[:-1])
122
+ if character_tag.endswith(" "):
123
+ character_tag = character_tag[:-1]
124
+ series_tag = tags[-1].replace(")", "")
125
+
126
+ if series_tag:
127
+ output_series_tag.append(series_tag)
128
+
129
+ return output_series_tag
130
+
131
+
132
+ def select_random_character(series: str, character: str):
133
+ from random import seed, randrange
134
+ seed()
135
+ character_list = list(anime_series_dict.keys())
136
+ character = character_list[randrange(len(character_list) - 1)]
137
+ series = anime_series_dict.get(character.split(",")[0].strip(), "")
138
+ return series, character
139
+
140
+
141
+ def danbooru_to_e621(dtag, e621_dict):
142
+ def d_to_e(match, e621_dict):
143
+ dtag = match.group(0)
144
+ etag = e621_dict.get(replace_underline(dtag), "")
145
+ if etag:
146
+ return etag
147
+ else:
148
+ return dtag
149
+
150
+ import re
151
+ tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
152
+ return tag
153
+
154
+
155
+ danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
156
+
157
+
158
+ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
159
+ if prompt_type == "danbooru": return input_prompt
160
+ tags = input_prompt.split(",") if input_prompt else []
161
+ people_tags: list[str] = []
162
+ other_tags: list[str] = []
163
+ rating_tags: list[str] = []
164
+
165
+ e621_dict = danbooru_to_e621_dict
166
+ for tag in tags:
167
+ tag = replace_underline(tag)
168
+ tag = danbooru_to_e621(tag, e621_dict)
169
+ if tag in PEOPLE_TAGS:
170
+ people_tags.append(tag)
171
+ elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
172
+ rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag.replace(" ",""), ""))
173
+ else:
174
+ other_tags.append(tag)
175
+
176
+ rating_tags = sorted(set(rating_tags), key=rating_tags.index)
177
+ rating_tags = [rating_tags[0]] if rating_tags else []
178
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
179
+
180
+ output_prompt = ", ".join(people_tags + other_tags + rating_tags)
181
+
182
+ return output_prompt
183
+
184
+
185
+ def translate_prompt(prompt: str = ""):
186
+ def translate_to_english(prompt):
187
+ import httpcore
188
+ setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
189
+ from googletrans import Translator
190
+ translator = Translator()
191
+ try:
192
+ translated_prompt = translator.translate(prompt, src='auto', dest='en').text
193
+ return translated_prompt
194
+ except Exception as e:
195
+ print(e)
196
+ return prompt
197
+
198
+ def is_japanese(s):
199
+ import unicodedata
200
+ for ch in s:
201
+ name = unicodedata.name(ch, "")
202
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
203
+ return True
204
+ return False
205
+
206
+ def to_list(s):
207
+ return [x.strip() for x in s.split(",")]
208
+
209
+ prompts = to_list(prompt)
210
+ outputs = []
211
+ for p in prompts:
212
+ p = translate_to_english(p) if is_japanese(p) else p
213
+ outputs.append(p)
214
+
215
+ return ", ".join(outputs)
216
+
217
+
218
+ def translate_prompt_to_ja(prompt: str = ""):
219
+ def translate_to_japanese(prompt):
220
+ import httpcore
221
+ setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
222
+ from googletrans import Translator
223
+ translator = Translator()
224
+ try:
225
+ translated_prompt = translator.translate(prompt, src='en', dest='ja').text
226
+ return translated_prompt
227
+ except Exception as e:
228
+ print(e)
229
+ return prompt
230
+
231
+ def is_japanese(s):
232
+ import unicodedata
233
+ for ch in s:
234
+ name = unicodedata.name(ch, "")
235
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
236
+ return True
237
+ return False
238
+
239
+ def to_list(s):
240
+ return [x.strip() for x in s.split(",")]
241
+
242
+ prompts = to_list(prompt)
243
+ outputs = []
244
+ for p in prompts:
245
+ p = translate_to_japanese(p) if not is_japanese(p) else p
246
+ outputs.append(p)
247
+
248
+ return ", ".join(outputs)
249
+
250
+
251
+ def tags_to_ja(itag, dict):
252
+ def t_to_j(match, dict):
253
+ tag = match.group(0)
254
+ ja = dict.get(replace_underline(tag), "")
255
+ if ja:
256
+ return ja
257
+ else:
258
+ return tag
259
+
260
+ import re
261
+ tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dict), itag, 2)
262
+
263
+ return tag
264
+
265
+
266
+ def convert_tags_to_ja(input_prompt: str = ""):
267
+ tags = input_prompt.split(",") if input_prompt else []
268
+ out_tags = []
269
+
270
+ tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
271
+ dict = tags_to_ja_dict
272
+ for tag in tags:
273
+ tag = replace_underline(tag)
274
+ tag = tags_to_ja(tag, dict)
275
+ out_tags.append(tag)
276
+
277
+ return ", ".join(out_tags)
278
+
279
+
280
+ enable_auto_recom_prompt = True
281
+
282
+
283
+ animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
284
+ animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
285
+ pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
286
+ pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
287
+ other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
288
+ other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
289
+ default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
290
+ default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
291
+ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
292
+ global enable_auto_recom_prompt
293
+ prompts = to_list(prompt)
294
+ neg_prompts = to_list(neg_prompt)
295
+
296
+ prompts = list_sub(prompts, animagine_ps + pony_ps)
297
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
298
+
299
+ last_empty_p = [""] if not prompts and type != "None" else []
300
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
301
+
302
+ if type == "Auto":
303
+ enable_auto_recom_prompt = True
304
+ else:
305
+ enable_auto_recom_prompt = False
306
+ if type == "Animagine":
307
+ prompts = prompts + animagine_ps
308
+ neg_prompts = neg_prompts + animagine_nps
309
+ elif type == "Pony":
310
+ prompts = prompts + pony_ps
311
+ neg_prompts = neg_prompts + pony_nps
312
+
313
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
314
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
315
+
316
+ return prompt, neg_prompt
317
+
318
+
319
+ def load_model_prompt_dict():
320
+ import json
321
+ dict = {}
322
+ path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
323
+ try:
324
+ with open('model_dict.json', encoding='utf-8') as f:
325
+ dict = json.load(f)
326
+ except Exception:
327
+ pass
328
+ return dict
329
+
330
+
331
+ model_prompt_dict = load_model_prompt_dict()
332
+
333
+
334
+ def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
335
+ if not model_name or not enable_auto_recom_prompt: return prompt, neg_prompt
336
+ prompts = to_list(prompt)
337
+ neg_prompts = to_list(neg_prompt)
338
+ prompts = list_sub(prompts, animagine_ps + pony_ps + other_ps)
339
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + other_nps)
340
+ last_empty_p = [""] if not prompts and type != "None" else []
341
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
342
+ ps = []
343
+ nps = []
344
+ if model_name in model_prompt_dict.keys():
345
+ ps = to_list(model_prompt_dict[model_name]["prompt"])
346
+ nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
347
+ else:
348
+ ps = default_ps
349
+ nps = default_nps
350
+ prompts = prompts + ps
351
+ neg_prompts = neg_prompts + nps
352
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
353
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
354
+ return prompt, neg_prompt
355
+
356
+
357
+ tag_group_dict = load_dict_from_csv('tag_group.csv')
358
+
359
+
360
+ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
361
+ def is_dressed(tag):
362
+ import re
363
+ p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
364
+ return p.search(tag)
365
+
366
+ def is_background(tag):
367
+ import re
368
+ p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
369
+ return p.search(tag)
370
+
371
+ un_tags = ['solo']
372
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
373
+ keep_group_dict = {
374
+ "body": ['groups', 'body_parts'],
375
+ "dress": ['groups', 'body_parts', 'attire'],
376
+ "all": group_list,
377
+ }
378
+
379
+ def is_necessary(tag, keep_tags, group_dict):
380
+ if keep_tags == "all":
381
+ return True
382
+ elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
383
+ return False
384
+ elif keep_tags == "body" and is_dressed(tag):
385
+ return False
386
+ elif is_background(tag):
387
+ return False
388
+ else:
389
+ return True
390
+
391
+ if keep_tags == "all": return input_prompt
392
+ keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
393
+ explicit_group = list(set(group_list) ^ set(keep_group))
394
+
395
+ tags = input_prompt.split(",") if input_prompt else []
396
+ people_tags: list[str] = []
397
+ other_tags: list[str] = []
398
+
399
+ group_dict = tag_group_dict
400
+ for tag in tags:
401
+ tag = replace_underline(tag)
402
+ if tag in PEOPLE_TAGS:
403
+ people_tags.append(tag)
404
+ elif is_necessary(tag, keep_tags, group_dict):
405
+ other_tags.append(tag)
406
+
407
+ output_prompt = ", ".join(people_tags + other_tags)
408
+
409
+ return output_prompt
410
+
411
+
412
+ def sort_taglist(tags: list[str]):
413
+ if not tags: return []
414
+ character_tags: list[str] = []
415
+ series_tags: list[str] = []
416
+ people_tags: list[str] = []
417
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
418
+ group_tags = {}
419
+ other_tags: list[str] = []
420
+ rating_tags: list[str] = []
421
+
422
+ group_dict = tag_group_dict
423
+ group_set = set(group_dict.keys())
424
+ character_set = set(anime_series_dict.keys())
425
+ series_set = set(anime_series_dict.values())
426
+ rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
427
+
428
+ for tag in tags:
429
+ tag = replace_underline(tag)
430
+ if tag in PEOPLE_TAGS:
431
+ people_tags.append(tag)
432
+ elif tag in rating_set:
433
+ rating_tags.append(tag)
434
+ elif tag in group_set:
435
+ elem = group_dict[tag]
436
+ group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
437
+ elif tag in character_set:
438
+ character_tags.append(tag)
439
+ elif tag in series_set:
440
+ series_tags.append(tag)
441
+ else:
442
+ other_tags.append(tag)
443
+
444
+ output_group_tags: list[str] = []
445
+ for k in group_list:
446
+ output_group_tags.extend(group_tags.get(k, []))
447
+
448
+ rating_tags = [rating_tags[0]] if rating_tags else []
449
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
450
+
451
+ output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
452
+
453
+ return output_tags
454
+
455
+
456
+ def sort_tags(tags: str):
457
+ if not tags: return ""
458
+ taglist: list[str] = []
459
+ for tag in tags.split(","):
460
+ taglist.append(tag.strip())
461
+ taglist = list(filter(lambda x: x != "", taglist))
462
+ return ", ".join(sort_taglist(taglist))
463
+
464
+
465
+ def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
466
+ results = {
467
+ k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
468
+ }
469
+
470
+ rating = {}
471
+ character = {}
472
+ general = {}
473
+
474
+ for k, v in results.items():
475
+ if k.startswith("rating:"):
476
+ rating[k.replace("rating:", "")] = v
477
+ continue
478
+ elif k.startswith("character:"):
479
+ character[k.replace("character:", "")] = v
480
+ continue
481
+
482
+ general[k] = v
483
+
484
+ character = {k: v for k, v in character.items() if v >= character_threshold}
485
+ general = {k: v for k, v in general.items() if v >= general_threshold}
486
+
487
+ return rating, character, general
488
+
489
+
490
+ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
491
+ people_tags: list[str] = []
492
+ other_tags: list[str] = []
493
+ rating_tag = RATING_MAP[rating[0]]
494
+
495
+ for tag in general:
496
+ if tag in PEOPLE_TAGS:
497
+ people_tags.append(tag)
498
+ else:
499
+ other_tags.append(tag)
500
+
501
+ all_tags = people_tags + other_tags
502
+
503
+ return ", ".join(all_tags)
504
+
505
+
506
+ @spaces.GPU()
507
+ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
508
+ inputs = wd_processor.preprocess(image, return_tensors="pt")
509
+
510
+ outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
511
+ logits = torch.sigmoid(outputs.logits[0]) # take the first logits
512
+
513
+ # get probabilities
514
+ results = {
515
+ wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
516
+ }
517
+ # rating, character, general
518
+ rating, character, general = postprocess_results(
519
+ results, general_threshold, character_threshold
520
+ )
521
+ prompt = gen_prompt(
522
+ list(rating.keys()), list(character.keys()), list(general.keys())
523
+ )
524
+ output_series_tag = ""
525
+ output_series_list = character_list_to_series_list(character.keys())
526
+ if output_series_list:
527
+ output_series_tag = output_series_list[0]
528
+ else:
529
+ output_series_tag = ""
530
+ return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
531
+
532
+
533
+ def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
534
+ character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
535
+ if not "Use WD Tagger" in algo and len(algo) != 0:
536
+ return input_series, input_character, input_tags, gr.update(interactive=True)
537
+ return predict_tags(image, general_threshold, character_threshold)
538
+
539
+
540
+ def compose_prompt_to_copy(character: str, series: str, general: str):
541
+ characters = character.split(",") if character else []
542
+ serieses = series.split(",") if series else []
543
+ generals = general.split(",") if general else []
544
+ tags = characters + serieses + generals
545
+ cprompt = ",".join(tags) if tags else ""
546
+ return cprompt
tagger/utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
3
+
4
+
5
+ V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
6
+ "ultra_wide",
7
+ "wide",
8
+ "square",
9
+ "tall",
10
+ "ultra_tall",
11
+ ]
12
+ V2_RATING_OPTIONS: list[RatingTag] = [
13
+ "sfw",
14
+ "general",
15
+ "sensitive",
16
+ "nsfw",
17
+ "questionable",
18
+ "explicit",
19
+ ]
20
+ V2_LENGTH_OPTIONS: list[LengthTag] = [
21
+ "very_short",
22
+ "short",
23
+ "medium",
24
+ "long",
25
+ "very_long",
26
+ ]
27
+ V2_IDENTITY_OPTIONS: list[IdentityTag] = [
28
+ "none",
29
+ "lax",
30
+ "strict",
31
+ ]
32
+
33
+
34
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
35
+ def gradio_copy_text(_text: None):
36
+ gr.Info("Copied!")
37
+
38
+
39
+ COPY_ACTION_JS = """\
40
+ (inputs, _outputs) => {
41
+ // inputs is the string value of the input_text
42
+ if (inputs.trim() !== "") {
43
+ navigator.clipboard.writeText(inputs);
44
+ }
45
+ }"""
tagger/v2.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ from typing import Callable
4
+ from pathlib import Path
5
+
6
+ from dartrs.v2 import (
7
+ V2Model,
8
+ MixtralModel,
9
+ MistralModel,
10
+ compose_prompt,
11
+ LengthTag,
12
+ AspectRatioTag,
13
+ RatingTag,
14
+ IdentityTag,
15
+ )
16
+ from dartrs.dartrs import DartTokenizer
17
+ from dartrs.utils import get_generation_config
18
+
19
+
20
+ import gradio as gr
21
+ from gradio.components import Component
22
+
23
+
24
+ try:
25
+ from output import UpsamplingOutput
26
+ except:
27
+ from .output import UpsamplingOutput
28
+
29
+
30
+ V2_ALL_MODELS = {
31
+ "dart-v2-moe-sft": {
32
+ "repo": "p1atdev/dart-v2-moe-sft",
33
+ "type": "sft",
34
+ "class": MixtralModel,
35
+ },
36
+ "dart-v2-sft": {
37
+ "repo": "p1atdev/dart-v2-sft",
38
+ "type": "sft",
39
+ "class": MistralModel,
40
+ },
41
+ }
42
+
43
+
44
+ def prepare_models(model_config: dict):
45
+ model_name = model_config["repo"]
46
+ tokenizer = DartTokenizer.from_pretrained(model_name)
47
+ model = model_config["class"].from_pretrained(model_name)
48
+
49
+ return {
50
+ "tokenizer": tokenizer,
51
+ "model": model,
52
+ }
53
+
54
+
55
+ def normalize_tags(tokenizer: DartTokenizer, tags: str):
56
+ """Just remove unk tokens."""
57
+ return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
58
+
59
+
60
+ @torch.no_grad()
61
+ def generate_tags(
62
+ model: V2Model,
63
+ tokenizer: DartTokenizer,
64
+ prompt: str,
65
+ ban_token_ids: list[int],
66
+ ):
67
+ output = model.generate(
68
+ get_generation_config(
69
+ prompt,
70
+ tokenizer=tokenizer,
71
+ temperature=1,
72
+ top_p=0.9,
73
+ top_k=100,
74
+ max_new_tokens=256,
75
+ ban_token_ids=ban_token_ids,
76
+ ),
77
+ )
78
+
79
+ return output
80
+
81
+
82
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
83
+ return (
84
+ [f"1{noun}"]
85
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
86
+ + [f"{maximum+1}+{noun}s"]
87
+ )
88
+
89
+
90
+ PEOPLE_TAGS = (
91
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
92
+ )
93
+
94
+
95
+ def gen_prompt_text(output: UpsamplingOutput):
96
+ # separate people tags (e.g. 1girl)
97
+ people_tags = []
98
+ other_general_tags = []
99
+
100
+ for tag in output.general_tags.split(","):
101
+ tag = tag.strip()
102
+ if tag in PEOPLE_TAGS:
103
+ people_tags.append(tag)
104
+ else:
105
+ other_general_tags.append(tag)
106
+
107
+ return ", ".join(
108
+ [
109
+ part.strip()
110
+ for part in [
111
+ *people_tags,
112
+ output.character_tags,
113
+ output.copyright_tags,
114
+ *other_general_tags,
115
+ output.upsampled_tags,
116
+ output.rating_tag,
117
+ ]
118
+ if part.strip() != ""
119
+ ]
120
+ )
121
+
122
+
123
+ def elapsed_time_format(elapsed_time: float) -> str:
124
+ return f"Elapsed: {elapsed_time:.2f} seconds"
125
+
126
+
127
+ def parse_upsampling_output(
128
+ upsampler: Callable[..., UpsamplingOutput],
129
+ ):
130
+ def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
131
+ output = upsampler(*args)
132
+
133
+ return (
134
+ gen_prompt_text(output),
135
+ elapsed_time_format(output.elapsed_time),
136
+ gr.update(interactive=True),
137
+ gr.update(interactive=True),
138
+ )
139
+
140
+ return _parse_upsampling_output
141
+
142
+
143
+ class V2UI:
144
+ model_name: str | None = None
145
+ model: V2Model
146
+ tokenizer: DartTokenizer
147
+
148
+ input_components: list[Component] = []
149
+ generate_btn: gr.Button
150
+
151
+ def on_generate(
152
+ self,
153
+ model_name: str,
154
+ copyright_tags: str,
155
+ character_tags: str,
156
+ general_tags: str,
157
+ rating_tag: RatingTag,
158
+ aspect_ratio_tag: AspectRatioTag,
159
+ length_tag: LengthTag,
160
+ identity_tag: IdentityTag,
161
+ ban_tags: str,
162
+ *args,
163
+ ) -> UpsamplingOutput:
164
+ if self.model_name is None or self.model_name != model_name:
165
+ models = prepare_models(V2_ALL_MODELS[model_name])
166
+ self.model = models["model"]
167
+ self.tokenizer = models["tokenizer"]
168
+ self.model_name = model_name
169
+
170
+ # normalize tags
171
+ # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
172
+ # character_tags = normalize_tags(self.tokenizer, character_tags)
173
+ # general_tags = normalize_tags(self.tokenizer, general_tags)
174
+
175
+ ban_token_ids = self.tokenizer.encode(ban_tags.strip())
176
+
177
+ prompt = compose_prompt(
178
+ prompt=general_tags,
179
+ copyright=copyright_tags,
180
+ character=character_tags,
181
+ rating=rating_tag,
182
+ aspect_ratio=aspect_ratio_tag,
183
+ length=length_tag,
184
+ identity=identity_tag,
185
+ )
186
+
187
+ start = time.time()
188
+ upsampled_tags = generate_tags(
189
+ self.model,
190
+ self.tokenizer,
191
+ prompt,
192
+ ban_token_ids,
193
+ )
194
+ elapsed_time = time.time() - start
195
+
196
+ return UpsamplingOutput(
197
+ upsampled_tags=upsampled_tags,
198
+ copyright_tags=copyright_tags,
199
+ character_tags=character_tags,
200
+ general_tags=general_tags,
201
+ rating_tag=rating_tag,
202
+ aspect_ratio_tag=aspect_ratio_tag,
203
+ length_tag=length_tag,
204
+ identity_tag=identity_tag,
205
+ elapsed_time=elapsed_time,
206
+ )
207
+
208
+
209
+ def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
210
+ return gen_prompt_text(upsampler)
211
+
212
+
213
+ v2 = V2UI()
214
+
215
+
216
+ def v2_upsampling_prompt(model: str = "dart-v2-moe-sft", copyright: str = "", character: str = "",
217
+ general_tags: str = "", rating: str = "nsfw", aspect_ratio: str = "square",
218
+ length: str = "very_long", identity: str = "lax", ban_tags: str = "censored"):
219
+ raw_prompt = parse_upsampling_output_simple(v2.on_generate(model, copyright, character, general_tags,
220
+ rating, aspect_ratio, length, identity, ban_tags))
221
+ return raw_prompt
222
+
223
+
224
+ def load_dict_from_csv(filename):
225
+ dict = {}
226
+ if not Path(filename).exists():
227
+ if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
228
+ else: return dict
229
+ try:
230
+ with open(filename, 'r', encoding="utf-8") as f:
231
+ lines = f.readlines()
232
+ except Exception:
233
+ print(f"Failed to open dictionary file: {filename}")
234
+ return dict
235
+ for line in lines:
236
+ parts = line.strip().split(',')
237
+ dict[parts[0]] = parts[1]
238
+ return dict
239
+
240
+
241
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
242
+
243
+
244
+ def select_random_character(series: str, character: str):
245
+ from random import seed, randrange
246
+ seed()
247
+ character_list = list(anime_series_dict.keys())
248
+ character = character_list[randrange(len(character_list) - 1)]
249
+ series = anime_series_dict.get(character.split(",")[0].strip(), "")
250
+ return series, character
251
+
252
+
253
+ def v2_random_prompt(general_tags: str = "", copyright: str = "", character: str = "", rating: str = "nsfw",
254
+ aspect_ratio: str = "square", length: str = "very_long", identity: str = "lax",
255
+ ban_tags: str = "censored", model: str = "dart-v2-moe-sft"):
256
+ if copyright == "" and character == "":
257
+ copyright, character = select_random_character("", "")
258
+ raw_prompt = v2_upsampling_prompt(model, copyright, character, general_tags, rating,
259
+ aspect_ratio, length, identity, ban_tags)
260
+ return raw_prompt, copyright, character