Spaces:
Runtime error
Runtime error
vertex-imagen simple generation
Browse files- app.py +21 -224
- env-sample +0 -4
app.py
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
# file stuff
|
2 |
import os
|
3 |
-
import sys
|
4 |
-
import zipfile
|
5 |
-
import requests
|
6 |
-
import tempfile
|
7 |
from io import BytesIO
|
8 |
-
import random
|
9 |
-
import string
|
10 |
|
11 |
#image generation stuff
|
12 |
from PIL import Image
|
@@ -15,18 +9,9 @@ from PIL import Image
|
|
15 |
import gradio as gr
|
16 |
from dotenv import load_dotenv
|
17 |
|
18 |
-
# stats stuff
|
19 |
-
from pymongo.mongo_client import MongoClient
|
20 |
-
from pymongo.server_api import ServerApi
|
21 |
-
import time
|
22 |
-
|
23 |
-
# countdown stuff
|
24 |
-
from datetime import datetime, timedelta
|
25 |
-
|
26 |
|
27 |
from google.cloud import aiplatform
|
28 |
import vertexai
|
29 |
-
# from vertexai.preview.generative_models import GenerativeModel
|
30 |
from vertexai.preview.vision_models import ImageGenerationModel
|
31 |
from vertexai import preview
|
32 |
import uuid #for generating unique filenames
|
@@ -40,239 +25,51 @@ import google.auth
|
|
40 |
|
41 |
load_dotenv()
|
42 |
|
43 |
-
|
44 |
-
pw_key = os.getenv("PW")
|
45 |
-
|
46 |
-
if pw_key == "<YOUR_PW>":
|
47 |
-
pw_key = ""
|
48 |
-
|
49 |
-
if pw_key == "":
|
50 |
-
sys.exit("Please Provide A Password in the Environment Variables")
|
51 |
-
|
52 |
-
|
53 |
-
# Connect to MongoDB
|
54 |
-
uri = os.getenv("MONGO_URI")
|
55 |
-
mongo_client = MongoClient(uri, server_api=ServerApi('1'))
|
56 |
-
|
57 |
-
mongo_db = mongo_client.pdr
|
58 |
-
mongo_collection = mongo_db["images"]
|
59 |
-
|
60 |
-
image_labels_global = []
|
61 |
-
image_paths_global = []
|
62 |
-
|
63 |
-
#load challenges
|
64 |
-
challenges = []
|
65 |
-
with open('challenges.txt', 'r') as file:
|
66 |
-
for line in file:
|
67 |
-
challenges.append(line.strip())
|
68 |
-
|
69 |
-
#get GCP credentials
|
70 |
-
def get_credentials():
|
71 |
-
creds_json_str = os.getenv("IMAGEN") # get json credentials stored as a string
|
72 |
-
# create a temporary file
|
73 |
-
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as temp:
|
74 |
-
temp.write(creds_json_str) # write in json format
|
75 |
-
temp_filename = temp.name
|
76 |
-
return temp_filename
|
77 |
-
def get_creds_json():
|
78 |
-
creds_json_str = os.getenv("IMAGEN")
|
79 |
-
return json.loads(creds_json_str)
|
80 |
-
|
81 |
-
# pass
|
82 |
-
# service_acct_json = open(pdr-imagen-encoded.json')
|
83 |
service_account_json = pybase64.b64decode(os.getenv("IMAGEN"))
|
84 |
service_account_info = json.loads(service_account_json)
|
85 |
credentials = service_account.Credentials.from_service_account_info(service_account_info)
|
86 |
-
#os.environ["GOOGLE_APPLICATION_CREDENTIALS"]= get_creds_json()
|
87 |
project="pdr-imagen"
|
88 |
aiplatform.init(project=project, credentials=credentials)
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
global challenge
|
94 |
-
challenge = random.choice(challenges)
|
95 |
-
return challenge
|
96 |
-
|
97 |
-
# set initial challenge
|
98 |
-
challenge = get_challenge()
|
99 |
-
|
100 |
-
def update_labels(show_labels):
|
101 |
-
updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
|
102 |
-
return updated_gallery
|
103 |
-
|
104 |
-
def generate_images_wrapper(prompts, pw, show_labels,model):
|
105 |
-
global image_paths_global, image_labels_global
|
106 |
-
image_paths, image_labels = generate_images(prompts, pw,model)
|
107 |
-
image_paths_global = image_paths
|
108 |
-
|
109 |
-
# store this as a global so we can handle toggle state
|
110 |
-
image_labels_global = image_labels
|
111 |
-
image_data = [(path, label if show_labels else "") for path, label in zip(image_paths, image_labels)]
|
112 |
-
|
113 |
-
return image_data
|
114 |
-
|
115 |
-
def download_image(url):
|
116 |
-
response = requests.get(url)
|
117 |
-
if response.status_code == 200:
|
118 |
-
return response.content
|
119 |
-
else:
|
120 |
-
raise Exception(f"Failed to download image from URL: {url}")
|
121 |
-
|
122 |
-
def zip_images(image_paths_and_labels):
|
123 |
-
zip_file_path = tempfile.NamedTemporaryFile(delete=False, suffix='.zip').name
|
124 |
-
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
125 |
-
for image_url, _ in image_paths_and_labels:
|
126 |
-
# image_content = download_image(image_url)
|
127 |
-
image_content = open(image_url, "rb").read()
|
128 |
-
random_filename = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + ".png"
|
129 |
-
# Write the image content to the zip file with the random filename
|
130 |
-
zipf.writestr(image_url, image_content)
|
131 |
-
return zip_file_path
|
132 |
-
|
133 |
-
|
134 |
-
def download_all_images():
|
135 |
-
global image_paths_global, image_labels_global
|
136 |
-
if not image_paths_global:
|
137 |
-
raise gr.Error("No images to download.")
|
138 |
-
image_paths_and_labels = list(zip(image_paths_global, image_labels_global))
|
139 |
-
zip_path = zip_images(image_paths_and_labels)
|
140 |
-
image_paths_global = [] # Reset the global variable
|
141 |
-
image_labels_global = [] # Reset the global variable
|
142 |
-
|
143 |
-
# delete all local images
|
144 |
-
for image_path, _ in image_paths_and_labels:
|
145 |
-
os.remove(image_path)
|
146 |
-
|
147 |
-
|
148 |
-
return zip_path
|
149 |
-
|
150 |
-
def generate_images(prompts, pw,model_name):
|
151 |
-
# Check for a valid password
|
152 |
|
153 |
-
|
154 |
-
|
|
|
|
|
155 |
|
156 |
-
|
157 |
-
image_labels = [] # shows the prompt in the gallery above the image
|
158 |
-
users = [] # adds the user to the label
|
159 |
|
160 |
-
#
|
161 |
-
prompts_list = [prompt for prompt in prompts.split(';') if prompt]
|
162 |
-
|
163 |
-
# model = "claude-3-opus-20240229"
|
164 |
-
|
165 |
-
for i, entry in enumerate(prompts_list):
|
166 |
-
entry_parts = entry.split('-', 1) # Split by the first dash found
|
167 |
-
if len(entry_parts) == 2:
|
168 |
-
#raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
|
169 |
-
user_initials, text = entry_parts[0].strip(), entry_parts[1].strip() # Extract user initials and the prompt
|
170 |
-
else:
|
171 |
-
text = entry.strip() # If no initials are provided, use the entire prompt as the text
|
172 |
-
user_initials = ""
|
173 |
-
|
174 |
-
users.append(user_initials) # Append user initials to the list
|
175 |
-
|
176 |
-
prompt_w_challenge = f"{challenge}: {text}"
|
177 |
-
print(prompt_w_challenge)
|
178 |
-
|
179 |
-
start_time = time.time()
|
180 |
-
|
181 |
-
try:
|
182 |
-
#what model to use?
|
183 |
-
model = ImageGenerationModel.from_pretrained(model_name)
|
184 |
-
response = model.generate_images(
|
185 |
-
prompt=prompt_w_challenge,
|
186 |
-
number_of_images=1,
|
187 |
-
)
|
188 |
-
|
189 |
-
end_time = time.time()
|
190 |
-
gen_time = end_time - start_time # total generation time
|
191 |
-
|
192 |
-
image_bytes = response[0]._image_bytes
|
193 |
-
image_url = Image.open(BytesIO(image_bytes))
|
194 |
-
|
195 |
-
#generate random filename using uuid
|
196 |
-
#filename = f"{uuid.uuid4()}.png"
|
197 |
-
|
198 |
-
# Save the image to a temporary file, and return this
|
199 |
-
#image_url = filename
|
200 |
-
#response[0].save(filename)
|
201 |
-
image_label = f"{i+1}: {text}"
|
202 |
-
|
203 |
-
model_for_db = f"imagen-{model_name}"
|
204 |
-
|
205 |
-
try:
|
206 |
-
# Save the prompt, model, image URL, generation time and creation timestamp to the database
|
207 |
-
mongo_collection.insert_one({"user": user_initials, "text": text, "model": model_for_db, "image_url": "bytes", "gen_time": gen_time, "timestamp": time.time(), "challenge": challenge})
|
208 |
-
except Exception as e:
|
209 |
-
print(e)
|
210 |
-
raise gr.Error("An error occurred while saving the prompt to the database.")
|
211 |
-
|
212 |
-
# Append the image URL and label to their respective lists
|
213 |
-
image_paths.append(image_url)
|
214 |
-
image_labels.append(image_label)
|
215 |
-
except Exception as e:
|
216 |
-
print(e)
|
217 |
-
raise gr.Error(f"An error occurred while generating the image for: {entry}")
|
218 |
-
return image_paths, image_labels
|
219 |
-
|
220 |
-
#custom css
|
221 |
-
css = """
|
222 |
-
#gallery-images .caption-label {
|
223 |
-
white-space: normal !important;
|
224 |
-
}
|
225 |
-
"""
|
226 |
-
|
227 |
-
|
228 |
-
with gr.Blocks(css=css) as demo:
|
229 |
-
|
230 |
-
gr.Markdown("# <center>Prompt de Resistance Vertex Imagen</center>")
|
231 |
-
|
232 |
-
pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service")
|
233 |
|
234 |
#instructions
|
235 |
with gr.Accordion("Instructions & Tips",label="instructions",open=False):
|
236 |
with gr.Row():
|
237 |
-
gr.Markdown("**Instructions**: To use this service, please enter the password. Then generate an image from the prompt field below in response to the challenge, then click the download arrow from the top right of the image to save it.")
|
238 |
gr.Markdown("**Tips**: Use adjectives (size,color,mood), specify the visual style (realistic,cartoon,8-bit), explain the point of view (from above,first person,wide angle) ")
|
239 |
|
240 |
-
#challenge
|
241 |
-
challenge_display = gr.Textbox(label="Challenge", value=get_challenge())
|
242 |
-
challenge_display.disabled = True
|
243 |
-
regenerate_btn = gr.Button("New Challenge")
|
244 |
-
|
245 |
-
|
246 |
#prompts
|
247 |
-
with gr.Accordion("
|
248 |
text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
|
249 |
|
250 |
model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
|
251 |
|
252 |
with gr.Row():
|
253 |
-
|
254 |
-
|
255 |
|
256 |
#output
|
257 |
-
with gr.Accordion("Image
|
258 |
-
|
259 |
-
show_labels = gr.Checkbox(label="Show Labels", value=False)
|
260 |
-
|
261 |
-
|
262 |
-
with gr.Accordion("Downloads",label="download",open=True):
|
263 |
-
download_all_btn = gr.Button("Download All")
|
264 |
-
download_link = gr.File(label="Download Zip")
|
265 |
-
|
266 |
-
# generate new challenge
|
267 |
-
regenerate_btn.click(fn=get_challenge, inputs=[], outputs=[challenge_display])
|
268 |
-
|
269 |
-
#submissions
|
270 |
-
#trigger generation either through hitting enter in the text field, or clicking the button.
|
271 |
-
btn.click(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model ], outputs=output_images, api_name=False)
|
272 |
-
text.submit(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
|
273 |
-
show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
|
274 |
|
275 |
-
|
276 |
-
|
277 |
|
278 |
demo.launch(share=False)
|
|
|
1 |
# file stuff
|
2 |
import os
|
|
|
|
|
|
|
|
|
3 |
from io import BytesIO
|
|
|
|
|
4 |
|
5 |
#image generation stuff
|
6 |
from PIL import Image
|
|
|
9 |
import gradio as gr
|
10 |
from dotenv import load_dotenv
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
from google.cloud import aiplatform
|
14 |
import vertexai
|
|
|
15 |
from vertexai.preview.vision_models import ImageGenerationModel
|
16 |
from vertexai import preview
|
17 |
import uuid #for generating unique filenames
|
|
|
25 |
|
26 |
load_dotenv()
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
service_account_json = pybase64.b64decode(os.getenv("IMAGEN"))
|
29 |
service_account_info = json.loads(service_account_json)
|
30 |
credentials = service_account.Credentials.from_service_account_info(service_account_info)
|
|
|
31 |
project="pdr-imagen"
|
32 |
aiplatform.init(project=project, credentials=credentials)
|
33 |
|
34 |
+
def generate_image(prompt,model_name):
|
35 |
+
try:
|
36 |
+
model = ImageGenerationModel.from_pretrained(model_name)
|
37 |
+
response = model.generate_images(
|
38 |
+
prompt=prompt,
|
39 |
+
number_of_images=1,
|
40 |
+
)
|
41 |
|
42 |
+
image_bytes = response[0]._image_bytes
|
43 |
+
image_url = Image.open(BytesIO(image_bytes))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
+
except Exception as e:
|
46 |
+
print(e)
|
47 |
+
raise gr.Error(f"An error occurred while generating the image for: {entry}")
|
48 |
+
return image_url
|
49 |
|
50 |
+
with gr.Blocks() as demo:
|
|
|
|
|
51 |
|
52 |
+
gr.Markdown("# <center>Google Vertex Imagen Generator</center>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
#instructions
|
55 |
with gr.Accordion("Instructions & Tips",label="instructions",open=False):
|
56 |
with gr.Row():
|
|
|
57 |
gr.Markdown("**Tips**: Use adjectives (size,color,mood), specify the visual style (realistic,cartoon,8-bit), explain the point of view (from above,first person,wide angle) ")
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
#prompts
|
60 |
+
with gr.Accordion("Prompt",label="Prompt",open=True):
|
61 |
text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
|
62 |
|
63 |
model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
|
64 |
|
65 |
with gr.Row():
|
66 |
+
btn = gr.Button("Generate Images")
|
|
|
67 |
|
68 |
#output
|
69 |
+
with gr.Accordion("Image Output",label="Image Output",open=True):
|
70 |
+
output_image = gr.Image(label="Image")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
+
btn.click(fn=generate_image, inputs=[text, model ], outputs=output_image, api_name=False)
|
73 |
+
text.submit(fn=generate_image, inputs=[text, model ], outputs=output_image, api_name="generate_image") # Generate an api endpoint in Gradio / HF
|
74 |
|
75 |
demo.launch(share=False)
|
env-sample
CHANGED
@@ -1,5 +1 @@
|
|
1 |
-
OPENAI_API_KEY = <YOUR_OPENAI_API_KEY>
|
2 |
-
PW = <YOUR_PW>
|
3 |
-
MONGO_URI=<YOUR_MONGO_URI>
|
4 |
-
MODE=dev
|
5 |
GOOGLE_APPLICATION_CREDENTIALS= <YOUR_GOOGLE_APPLICATION_CREDENTIALS>
|
|
|
|
|
|
|
|
|
|
|
1 |
GOOGLE_APPLICATION_CREDENTIALS= <YOUR_GOOGLE_APPLICATION_CREDENTIALS>
|