File size: 4,679 Bytes
bd81242
1cd7e2c
bd81242
88c3112
 
 
1cd7e2c
15b44ba
61b9df0
1cd7e2c
 
1e53245
b92287c
 
 
 
74a4765
1609134
 
9222bcf
 
 
1609134
1cd7e2c
 
9222bcf
 
 
 
 
 
81e44ac
 
 
 
 
 
 
dc1d5f9
 
81e44ac
dc1d5f9
 
7f3e695
 
 
 
 
 
1609134
7f3e695
 
7b994e6
7f3e695
 
81e44ac
7f3e695
dfa75b2
81e44ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593deb8
7f3e695
dc1d5f9
81e44ac
 
dc1d5f9
bae56f1
 
 
 
 
 
7f3e695
318ea3e
de6d38a
 
 
318ea3e
81e44ac
de6d38a
bae56f1
7f3e695
 
bae56f1
81e44ac
 
 
dc1d5f9
 
37e5f7e
81e44ac
 
 
bd81242
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# file stuff
import os
from io import BytesIO

#image generation stuff
from PIL import Image

# gradio / hf / image gen stuff
import gradio as gr
from dotenv import load_dotenv


from google.cloud import aiplatform
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel
from vertexai import preview

# GCP credentials stuff
import json
import pybase64
from google.oauth2 import service_account
import google.auth

load_dotenv()

# IMAGEN must hold a base64-encoded GCP service-account JSON key; it is decoded,
# parsed as JSON and turned into credentials for the Vertex AI client below.
_imagen_b64 = os.getenv("IMAGEN")
if not _imagen_b64:
    # Fail fast with an actionable message instead of an opaque TypeError
    # from b64decode(None).
    raise RuntimeError(
        "The IMAGEN environment variable is not set; it must contain the "
        "base64-encoded service-account JSON used to call Vertex AI."
    )
service_account_json = pybase64.b64decode(_imagen_b64)
service_account_info = json.loads(service_account_json)
credentials = service_account.Credentials.from_service_account_info(service_account_info)
project = "pdr-imagen"
aiplatform.init(project=project, credentials=credentials)

# Password enforcement switch: the password is enforced when DO_ENFORCE_PW is
# "true". The value is trimmed and lower-cased so "True"/"TRUE" behave the same
# way; it stays None when the variable is unset.
_raw_enforce_pw = os.getenv("DO_ENFORCE_PW")
DO_ENFORCE_PW = _raw_enforce_pw.strip().lower() if _raw_enforce_pw is not None else None


def trigger_max_gens():
    """Show a Gradio warning toast saying the generation cap has been hit.

    Bound to the hidden "trigger-max-gens-btn" button, which the page's custom
    JavaScript clicks programmatically once its client-side limit is reached.
    """
    message = "🖼️ Max Image Generations Reached! 🖼️"
    gr.Warning(message)

def generate_image(pw, prompt, model_name):
    """Generate one image for ``prompt`` with a Vertex AI Imagen model.

    Args:
        pw: Password typed into the UI. Only checked when ``DO_ENFORCE_PW`` is
            ``"true"``; it must then equal the ``PW`` environment variable.
        prompt: Text description of the image to create.
        model_name: Model version passed to
            ``ImageGenerationModel.from_pretrained``
            (e.g. ``"imagegeneration@005"``).

    Returns:
        PIL.Image.Image: The first generated image, decoded from its bytes.

    Raises:
        gr.Error: On a wrong password (when enforced), a blank prompt, an
            empty model response, or any failure while calling the model or
            decoding its output. Gradio shows the message to the user.
    """
    import hmac  # constant-time comparison for the password check

    if DO_ENFORCE_PW == "true":
        expected_pw = os.getenv("PW")
        supplied_pw = pw if isinstance(pw, str) else ""
        # compare_digest takes time independent of where the values differ,
        # so response timing cannot be used to guess the password. Compare
        # UTF-8 bytes because compare_digest rejects non-ASCII str input.
        # An unset PW never matches anything.
        if expected_pw is None or not hmac.compare_digest(
            supplied_pw.encode("utf-8"), expected_pw.encode("utf-8")
        ):
            raise gr.Error("Invalid password. Please try again.")

    # Fail fast instead of spending an API call on a blank prompt.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt describing the image you want to create.")

    try:
        model = ImageGenerationModel.from_pretrained(model_name)
        response = model.generate_images(
            prompt=prompt,
            number_of_images=1,
        )
        # _image_bytes is a private attribute of the SDK's image object; it is
        # the raw encoded image that PIL can open directly.
        image_bytes = response[0]._image_bytes
        image = Image.open(BytesIO(image_bytes))
    except IndexError as e:
        # response[0] failed: the model returned no image for this prompt
        # (NOTE(review): presumably filtered by the service — not verified here).
        print(e)
        raise gr.Error("No image was generated for this prompt. Please try rephrasing it.") from e
    except Exception as e:
        # Top-level boundary for a UI callback: log the real cause server-side,
        # show a generic message to the user, and keep the cause chained.
        print(e)
        raise gr.Error("An error occurred while generating the image") from e
    return image

custom_js = """
function customJS() {
    //Limit Image Generation
    const MAX_GENERATIONS = 10;
    const DO_ENFORCE_MAX_GENERATIONS = true;

    disableGenerateButton = function() {
        const btn = document.getElementById('btn_generate-images');
        btn.disabled = true;
        btn.classList.add('not-visible');
    }

    triggerMaxGenerationsToast = function() {
        const trigger_max_gens_btn = document.getElementById('trigger-max-gens-btn');
        trigger_max_gens_btn.click();
    }

    setCurrentGenerations = function() {
        if (!DO_ENFORCE_MAX_GENERATIONS) {
            return;
        }
        const curGenerations = localStorage.getItem('currentGenerations');
        console.log(`${curGenerations} / ${MAX_GENERATIONS}`)
        if (curGenerations) {
            if (curGenerations >= MAX_GENERATIONS) {
                triggerMaxGenerationsToast();
                disableGenerateButton();
            } else {
                localStorage.setItem('currentGenerations', parseInt(curGenerations) + 1);
            }
        } else {
            localStorage.setItem('currentGenerations', 1);
        }
    }

    setCurrentGenerations();

    document.getElementById('btn_generate-images').addEventListener('click', function() {
        setCurrentGenerations();
    });

}
"""

with gr.Blocks(js=custom_js) as demo:

    gr.Markdown("# <center>Google Vertex Imagen Generator</center>")

    # Password: shown only when generate_image() actually enforces it
    # (DO_ENFORCE_PW == "true"), so users are never asked for a password that
    # is then ignored.
    show_pw_field = DO_ENFORCE_PW == "true"
    pw = gr.Textbox(
        label="Password",
        type="password",
        placeholder="Enter the password to unlock the service",
        visible=show_pw_field,
    )
    gr.Markdown(
        "Need access? Send a DM to @HeaversMike on Twitter or send me an email / Slack msg.",
        visible=show_pw_field,
    )

    # Instructions (label given once, positionally)
    with gr.Accordion("Instructions & Tips", open=False):
        with gr.Row():
            gr.Markdown("**Tips**: Use adjectives (size,color,mood), specify the visual style (realistic,cartoon,8-bit), explain the point of view (from above,first person,wide angle) ")

    # Prompt; the placeholder names the button by its actual label.
    with gr.Accordion("Prompt", open=True):
        text = gr.Textbox(
            label="What do you want to create?",
            placeholder='Enter your text and then click on the "Generate Images" button',
        )

    model = gr.Dropdown(
        choices=["imagegeneration@002", "imagegeneration@005"],
        label="Model",
        value="imagegeneration@005",
    )

    with gr.Row():
        # custom_js looks this elem_id up to count clicks and disable the button.
        btn = gr.Button("Generate Images", variant="primary", elem_id="btn_generate-images")

    # Output
    with gr.Accordion("Image Output", open=True):
        output_image = gr.Image(label="Image")

    # Hidden helper button: custom_js clicks it to run trigger_max_gens (a toast).
    with gr.Row():
        trigger_max_gens_btn = gr.Button(value="Show Max Gens Reached", visible=False, elem_id="trigger-max-gens-btn")

    generate_inputs = [pw, text, model]
    btn.click(fn=generate_image, inputs=generate_inputs, outputs=output_image, api_name=False)
    # Pressing Enter in the prompt box also generates, and this is the only path
    # exposed as a named API endpoint in Gradio / HF. NOTE: the client-side limit
    # in custom_js only listens to the button, so this path is not counted.
    text.submit(fn=generate_image, inputs=generate_inputs, outputs=output_image, api_name="generate_image")

    # js-triggered functionality
    trigger_max_gens_btn.click(trigger_max_gens, None, None)

demo.launch(share=False)