yuki-imajuku commited on
Commit
1da638d
1 Parent(s): 3fe6f5e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from PIL import Image
4
+ from controlnet_aux import CannyDetector
5
+ import gradio as gr
6
+ import numpy as np
7
+ import spaces
8
+ import torch
9
+
10
+ from evo_nishikie_v1 import load_evo_nishikie
11
+
12
+
13
+ DESCRIPTION = """# 🐟 Evo-Nishikie
14
+ 🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
15
+
16
+ [Evo-Nishikie](https://huggingface.co/SakanaAI/Evo-Nishikie-v1)は[Sakana AI](https://sakana.ai/)が教育目的で開発した浮世絵に特化した画像生成モデルです。
17
+ 入力した画像を日本語プロンプトに沿って浮世絵風に変換した画像を生成することができます。より詳しくは、上記のブログをご参照ください。
18
+ """
19
+ if not torch.cuda.is_available():
20
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
21
+
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+ NUM_IMAGES_PER_PROMPT = 1
27
+ SAFETY_CHECKER = True
28
+ if SAFETY_CHECKER:
29
+ from safety_checker import StableDiffusionSafetyChecker
30
+ from transformers import CLIPFeatureExtractor
31
+
32
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
33
+ "CompVis/stable-diffusion-safety-checker"
34
+ ).to(device)
35
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
36
+ "openai/clip-vit-base-patch32"
37
+ )
38
+
39
+ def check_nsfw_images(
40
+ images: list[Image.Image],
41
+ ) -> tuple[list[Image.Image], list[bool]]:
42
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
43
+ has_nsfw_concepts = safety_checker(
44
+ images=[images], clip_input=safety_checker_input.pixel_values.to(device)
45
+ )
46
+
47
+ return images, has_nsfw_concepts
48
+
49
+
50
+ pipe = load_evo_nishikie("cpu").to(device)
51
+ canny_detector = CannyDetector()
52
+
53
+
54
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
55
+ if randomize_seed:
56
+ seed = random.randint(0, MAX_SEED)
57
+ return seed
58
+
59
+
60
+ @spaces.GPU
61
+ @torch.inference_mode()
62
+ def generate(
63
+ prompt: str,
64
+ input_image: Image.Image,
65
+ seed: int = 0,
66
+ randomize_seed: bool = False,
67
+ progress=gr.Progress(track_tqdm=True),
68
+ ):
69
+ pipe.to(device)
70
+ canny_image = canny_detector(input_image, image_resolution=1024)
71
+ seed = int(randomize_seed_fn(seed, randomize_seed))
72
+ generator = torch.Generator().manual_seed(seed)
73
+
74
+ images = pipe(
75
+ prompt=prompt + "最高品質の輻の浮世絵。",
76
+ negative_prompt="暗い。",
77
+ image=canny_image,
78
+ guidance_scale=8.0,
79
+ controlnet_conditioning_scale=0.6,
80
+ num_inference_steps=50,
81
+ generator=generator,
82
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
83
+ output_type="pil",
84
+ ).images
85
+
86
+ if SAFETY_CHECKER:
87
+ images, has_nsfw_concepts = check_nsfw_images(images)
88
+ if any(has_nsfw_concepts):
89
+ gr.Warning("NSFW content detected.")
90
+ return Image.new("RGB", (512, 512), "WHITE"), seed
91
+ return images[0], seed
92
+
93
+
94
+ examples = [
95
+ ["銀杏が色づく。草木が生えた地面と青空の富士山。", "https://sakana.ai/assets/nedo-grant/nedo_grant.jpeg"],
96
+ ]
97
+
98
+ css = """
99
+ .gradio-container{max-width: 690px !important}
100
+ h1{text-align:center}
101
+ """
102
+ with gr.Blocks(css=css) as demo:
103
+ gr.Markdown(DESCRIPTION)
104
+ with gr.Group():
105
+ with gr.Row():
106
+ with gr.Column(scale=8.0):
107
+ prompt = gr.Textbox(placeholder="日本語でプロンプトを入力してください。", show_label=False)
108
+ input_image = gr.Image(image_mode="RGB", type="pil", show_label=False)
109
+ submit = gr.Button(scale=0)
110
+ result = gr.Image(label="Evo-Nishikieからの生成結果", type="pil", show_label=False)
111
+ with gr.Accordion("詳細設定", open=False):
112
+ seed = gr.Slider(label="シード値", minimum=0, maximum=MAX_SEED, step=1, value=0)
113
+ randomize_seed = gr.Checkbox(label="ランダムにシード値を決定", value=True)
114
+ gr.Examples(examples=examples, inputs=[prompt, input_image], outputs=[result, seed], fn=generate)
115
+ gr.on(
116
+ triggers=[
117
+ submit.click,
118
+ ],
119
+ fn=generate,
120
+ inputs=[
121
+ prompt,
122
+ input_image,
123
+ seed,
124
+ randomize_seed,
125
+ ],
126
+ outputs=[result, seed],
127
+ api_name="run",
128
+ )
129
+ gr.Markdown("""⚠️ 本モデルは実験段階のプロトタイプであり、教育および研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。
130
+ 本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。
131
+ Sakana AIは、本モデルの使用によっ��生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
132
+ 利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。
133
+ アップロードされた画像は画像生成のみに使用され、サーバー上に保存されることはありません。""")
134
+
135
+ demo.queue().launch()