File size: 6,807 Bytes
234c908
 
 
 
 
 
 
 
 
 
 
39deac0
234c908
 
 
 
5b4c5c1
234c908
 
 
 
5b4c5c1
234c908
 
 
 
 
5b4c5c1
234c908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd751dd
234c908
 
4dc2373
5b4c5c1
234c908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
import openai
import requests
import os
import fileinput
from dotenv import load_dotenv
import io
from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

title="Sorytelling-AI-1-test"
inputs_label="あなたが入力に応じてストーリーを生成します"
outputs_label="AIが生成したストーリー"
visual_outputs_label="AIが生成したビジュアルイメージ"
description="""
- ストーリーの元になるアイデアを入力してください。エラーが発生した場合や、出力された内容が気に入らない場合は、再度送信するか、違う内容を入力して送信してください。
"""

article = """
<ul>
    <li style="font-size: small;">出力されたタイトルと、選択肢ABのいずれかの出力をコピーして、次のステップに進みます→<a href="https://huggingface.co/spaces/Masa-digital-art/Storytelling-AI-2-test">https://huggingface.co/spaces/Masa-digital-art/Storytelling-AI-2-test</a></li>
</ul>

<h5>リリースノート</h5>
<ul>
    <li style="font-size: small;">2023-08-31 v1.0</li>
    <li style="font-size: small;">2023-09-09 v1.2</li>  
</ul>

<h5>注意事項</h5>
<ul>
    <li style="font-size: small;">当サービスでは、2023/3/14にリリースされたOpenAI社のChatGPT APIのgpt-4と、2022/4/13にリリースされたSability AI社のStable Diffusion XL 'sAPIを使用しております。</li>
    <li style="font-size: small;">当サービスで生成されたテキストは、OpenAI が提供する人工知能によるものであり、当サービスやOpenAI がその正確性や信頼性を保証するものではありません。</li>
    <li style="font-size: small;">当サービスで生成されたイメージは、Stability AI が提供する人工知能によるものであり、当サービスやStabiliy AI がその信頼性を保証するものではありません。</li>
    <li style="font-size: small;"><a href="https://platform.openai.com/docs/usage-policies">OpenAI の利用規約</a>に従い、データ保持しない方針です(ただし諸般の事情によっては変更する可能性はございます)。
    <li style="font-size: small;">当サービスで生成されたコンテンツは事実確認をした上で、コンテンツ生成者およびコンテンツ利用者の責任において利用してください。</li>
    <li style="font-size: small;">当サービスでの使用により発生したいかなる損害についても、当社は一切の責任を負いません。</li>
    <li style="font-size: small;">当サービスはβ版のため、予告なくサービスを終了する場合がございます。</li>
</ul>

"""

load_dotenv()
openai.api_key = os.getenv('OPENAI_API_KEY')
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
stability_api = client.StabilityInference(
    key=os.getenv('STABILITY_KEY'),
    engine="stable-diffusion-xl-1024-v1-0", 
    verbose=True,
)
MODEL = "gpt-4"

def get_filetext(filename, cache={}):
    if filename in cache:
        # キャッシュに保存されている場合は、キャッシュからファイル内容を取得する
        return cache[filename]
    else:
        if not os.path.exists(filename):
            raise ValueError(f"ファイル '{filename}' が見つかりませんでした")
        with open(filename, "r") as f:
            text = f.read()
        # ファイル内容をキャッシュする
        cache[filename] = text
        return text

class OpenAI:
    
    @classmethod
    def chat_completion(cls, prompt, start_with=""):
        constraints = get_filetext(filename = "constraints.md")
        template = get_filetext(filename = "template.md")
        
        # ChatCompletion APIに渡すデータを定義する
        data = {
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": constraints}
                ,{"role": "system", "content": template}
                ,{"role": "assistant", "content": "Sure!"}
                ,{"role": "user", "content": prompt}
                ,{"role": "assistant", "content": start_with}
                ],
        }

        # ChatCompletion APIを呼び出す
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}"
            },
            json=data
        )

        # ChatCompletion APIから返された結果を取得する
        result = response.json()
        print(result)
        
        content = result["choices"][0]["message"]["content"].strip()
        
        visualize_prompt = content.split("## Prompt for Visual Expression\n\n")[1]

        answers = stability_api.generate(
            prompt=("high quality illustlation, Stunning detail, crisp images, high-contrast images, cinematic lighting, sharp focus, imaginative concept art, fantastic colors, impressive shading, establishing shot, image board, wide shot, image of the beginning of the story" + visualize_prompt),
            steps=50,
            width=768,
            height=512,
        )


        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    img = Image.open(io.BytesIO(artifact.binary))
        return [content, img]
                    
class MasasanAI:
    
    @classmethod
    def generate_vision_prompt(cls, user_message):
        template = get_filetext(filename="template.md")
        prompt = f"""
        {user_message}
        ---
        上記を元に、下記テンプレートを埋めてください。
        ---
        {template}
        """
        return prompt

    @classmethod
    def generate_vision(cls, user_message):
        prompt = MasasanAI.generate_vision_prompt(user_message);
        start_with = ""
        result = OpenAI.chat_completion(prompt=prompt, start_with=start_with)
        return result

def main():
    iface = gr.Interface(fn=MasasanAI.generate_vision,
                         inputs=gr.Textbox(label=inputs_label),
                         outputs=[gr.Textbox(label=inputs_label),
                                  gr.Image(label=visual_outputs_label)],
                         title=title,
                         description=description,
                         article=article,
                         allow_flagging='never'
                        )

    iface.launch()

if __name__ == '__main__':
    main()