heaversm commited on
Commit
7f3e695
1 Parent(s): 9222bcf

vertex-imagen simple generation

Browse files
Files changed (2) hide show
  1. app.py +21 -224
  2. env-sample +0 -4
app.py CHANGED
@@ -1,12 +1,6 @@
1
  # file stuff
2
  import os
3
- import sys
4
- import zipfile
5
- import requests
6
- import tempfile
7
  from io import BytesIO
8
- import random
9
- import string
10
 
11
  #image generation stuff
12
  from PIL import Image
@@ -15,18 +9,9 @@ from PIL import Image
15
  import gradio as gr
16
  from dotenv import load_dotenv
17
 
18
- # stats stuff
19
- from pymongo.mongo_client import MongoClient
20
- from pymongo.server_api import ServerApi
21
- import time
22
-
23
- # countdown stuff
24
- from datetime import datetime, timedelta
25
-
26
 
27
  from google.cloud import aiplatform
28
  import vertexai
29
- # from vertexai.preview.generative_models import GenerativeModel
30
  from vertexai.preview.vision_models import ImageGenerationModel
31
  from vertexai import preview
32
  import uuid #for generating unique filenames
@@ -40,239 +25,51 @@ import google.auth
40
 
41
  load_dotenv()
42
 
43
-
44
- pw_key = os.getenv("PW")
45
-
46
- if pw_key == "<YOUR_PW>":
47
- pw_key = ""
48
-
49
- if pw_key == "":
50
- sys.exit("Please Provide A Password in the Environment Variables")
51
-
52
-
53
- # Connect to MongoDB
54
- uri = os.getenv("MONGO_URI")
55
- mongo_client = MongoClient(uri, server_api=ServerApi('1'))
56
-
57
- mongo_db = mongo_client.pdr
58
- mongo_collection = mongo_db["images"]
59
-
60
- image_labels_global = []
61
- image_paths_global = []
62
-
63
- #load challenges
64
- challenges = []
65
- with open('challenges.txt', 'r') as file:
66
- for line in file:
67
- challenges.append(line.strip())
68
-
69
- #get GCP credentials
70
- def get_credentials():
71
- creds_json_str = os.getenv("IMAGEN") # get json credentials stored as a string
72
- # create a temporary file
73
- with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as temp:
74
- temp.write(creds_json_str) # write in json format
75
- temp_filename = temp.name
76
- return temp_filename
77
- def get_creds_json():
78
- creds_json_str = os.getenv("IMAGEN")
79
- return json.loads(creds_json_str)
80
-
81
- # pass
82
- # service_acct_json = open(pdr-imagen-encoded.json')
83
  service_account_json = pybase64.b64decode(os.getenv("IMAGEN"))
84
  service_account_info = json.loads(service_account_json)
85
  credentials = service_account.Credentials.from_service_account_info(service_account_info)
86
- #os.environ["GOOGLE_APPLICATION_CREDENTIALS"]= get_creds_json()
87
  project="pdr-imagen"
88
  aiplatform.init(project=project, credentials=credentials)
89
 
 
 
 
 
 
 
 
90
 
91
- # pick a random challenge
92
- def get_challenge():
93
- global challenge
94
- challenge = random.choice(challenges)
95
- return challenge
96
-
97
- # set initial challenge
98
- challenge = get_challenge()
99
-
100
- def update_labels(show_labels):
101
- updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
102
- return updated_gallery
103
-
104
- def generate_images_wrapper(prompts, pw, show_labels,model):
105
- global image_paths_global, image_labels_global
106
- image_paths, image_labels = generate_images(prompts, pw,model)
107
- image_paths_global = image_paths
108
-
109
- # store this as a global so we can handle toggle state
110
- image_labels_global = image_labels
111
- image_data = [(path, label if show_labels else "") for path, label in zip(image_paths, image_labels)]
112
-
113
- return image_data
114
-
115
- def download_image(url):
116
- response = requests.get(url)
117
- if response.status_code == 200:
118
- return response.content
119
- else:
120
- raise Exception(f"Failed to download image from URL: {url}")
121
-
122
- def zip_images(image_paths_and_labels):
123
- zip_file_path = tempfile.NamedTemporaryFile(delete=False, suffix='.zip').name
124
- with zipfile.ZipFile(zip_file_path, 'w') as zipf:
125
- for image_url, _ in image_paths_and_labels:
126
- # image_content = download_image(image_url)
127
- image_content = open(image_url, "rb").read()
128
- random_filename = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + ".png"
129
- # Write the image content to the zip file with the random filename
130
- zipf.writestr(image_url, image_content)
131
- return zip_file_path
132
-
133
-
134
- def download_all_images():
135
- global image_paths_global, image_labels_global
136
- if not image_paths_global:
137
- raise gr.Error("No images to download.")
138
- image_paths_and_labels = list(zip(image_paths_global, image_labels_global))
139
- zip_path = zip_images(image_paths_and_labels)
140
- image_paths_global = [] # Reset the global variable
141
- image_labels_global = [] # Reset the global variable
142
-
143
- # delete all local images
144
- for image_path, _ in image_paths_and_labels:
145
- os.remove(image_path)
146
-
147
-
148
- return zip_path
149
-
150
- def generate_images(prompts, pw,model_name):
151
- # Check for a valid password
152
 
153
- if pw != os.getenv("PW"):
154
- raise gr.Error("Invalid password. Please try again.")
 
 
155
 
156
- image_paths = [] # holds urls of images
157
- image_labels = [] # shows the prompt in the gallery above the image
158
- users = [] # adds the user to the label
159
 
160
- # Split the prompts string into individual prompts based on semicolon separation
161
- prompts_list = [prompt for prompt in prompts.split(';') if prompt]
162
-
163
- # model = "claude-3-opus-20240229"
164
-
165
- for i, entry in enumerate(prompts_list):
166
- entry_parts = entry.split('-', 1) # Split by the first dash found
167
- if len(entry_parts) == 2:
168
- #raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
169
- user_initials, text = entry_parts[0].strip(), entry_parts[1].strip() # Extract user initials and the prompt
170
- else:
171
- text = entry.strip() # If no initials are provided, use the entire prompt as the text
172
- user_initials = ""
173
-
174
- users.append(user_initials) # Append user initials to the list
175
-
176
- prompt_w_challenge = f"{challenge}: {text}"
177
- print(prompt_w_challenge)
178
-
179
- start_time = time.time()
180
-
181
- try:
182
- #what model to use?
183
- model = ImageGenerationModel.from_pretrained(model_name)
184
- response = model.generate_images(
185
- prompt=prompt_w_challenge,
186
- number_of_images=1,
187
- )
188
-
189
- end_time = time.time()
190
- gen_time = end_time - start_time # total generation time
191
-
192
- image_bytes = response[0]._image_bytes
193
- image_url = Image.open(BytesIO(image_bytes))
194
-
195
- #generate random filename using uuid
196
- #filename = f"{uuid.uuid4()}.png"
197
-
198
- # Save the image to a temporary file, and return this
199
- #image_url = filename
200
- #response[0].save(filename)
201
- image_label = f"{i+1}: {text}"
202
-
203
- model_for_db = f"imagen-{model_name}"
204
-
205
- try:
206
- # Save the prompt, model, image URL, generation time and creation timestamp to the database
207
- mongo_collection.insert_one({"user": user_initials, "text": text, "model": model_for_db, "image_url": "bytes", "gen_time": gen_time, "timestamp": time.time(), "challenge": challenge})
208
- except Exception as e:
209
- print(e)
210
- raise gr.Error("An error occurred while saving the prompt to the database.")
211
-
212
- # Append the image URL and label to their respective lists
213
- image_paths.append(image_url)
214
- image_labels.append(image_label)
215
- except Exception as e:
216
- print(e)
217
- raise gr.Error(f"An error occurred while generating the image for: {entry}")
218
- return image_paths, image_labels
219
-
220
- #custom css
221
- css = """
222
- #gallery-images .caption-label {
223
- white-space: normal !important;
224
- }
225
- """
226
-
227
-
228
- with gr.Blocks(css=css) as demo:
229
-
230
- gr.Markdown("# <center>Prompt de Resistance Vertex Imagen</center>")
231
-
232
- pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service")
233
 
234
  #instructions
235
  with gr.Accordion("Instructions & Tips",label="instructions",open=False):
236
  with gr.Row():
237
- gr.Markdown("**Instructions**: To use this service, please enter the password. Then generate an image from the prompt field below in response to the challenge, then click the download arrow from the top right of the image to save it.")
238
  gr.Markdown("**Tips**: Use adjectives (size,color,mood), specify the visual style (realistic,cartoon,8-bit), explain the point of view (from above,first person,wide angle) ")
239
 
240
- #challenge
241
- challenge_display = gr.Textbox(label="Challenge", value=get_challenge())
242
- challenge_display.disabled = True
243
- regenerate_btn = gr.Button("New Challenge")
244
-
245
-
246
  #prompts
247
- with gr.Accordion("Prompts",label="Prompts",open=True):
248
  text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
249
 
250
  model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
251
 
252
  with gr.Row():
253
- btn = gr.Button("Generate Images")
254
-
255
 
256
  #output
257
- with gr.Accordion("Image Outputs",label="Image Outputs",open=True):
258
- output_images = gr.Gallery(label="Image Outputs", elem_id="gallery-images", show_label=True, columns=[3], rows=[1], object_fit="contain", height="auto", allow_preview=False)
259
- show_labels = gr.Checkbox(label="Show Labels", value=False)
260
-
261
-
262
- with gr.Accordion("Downloads",label="download",open=True):
263
- download_all_btn = gr.Button("Download All")
264
- download_link = gr.File(label="Download Zip")
265
-
266
- # generate new challenge
267
- regenerate_btn.click(fn=get_challenge, inputs=[], outputs=[challenge_display])
268
-
269
- #submissions
270
- #trigger generation either through hitting enter in the text field, or clicking the button.
271
- btn.click(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model ], outputs=output_images, api_name=False)
272
- text.submit(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
273
- show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
274
 
275
- #downloads
276
- download_all_btn.click(fn=download_all_images, inputs=[], outputs=download_link)
277
 
278
  demo.launch(share=False)
 
1
  # file stuff
2
  import os
 
 
 
 
3
  from io import BytesIO
 
 
4
 
5
  #image generation stuff
6
  from PIL import Image
 
9
  import gradio as gr
10
  from dotenv import load_dotenv
11
 
 
 
 
 
 
 
 
 
12
 
13
  from google.cloud import aiplatform
14
  import vertexai
 
15
  from vertexai.preview.vision_models import ImageGenerationModel
16
  from vertexai import preview
17
  import uuid #for generating unique filenames
 
25
 
26
  load_dotenv()
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  service_account_json = pybase64.b64decode(os.getenv("IMAGEN"))
29
  service_account_info = json.loads(service_account_json)
30
  credentials = service_account.Credentials.from_service_account_info(service_account_info)
 
31
  project="pdr-imagen"
32
  aiplatform.init(project=project, credentials=credentials)
33
 
34
+ def generate_image(prompt,model_name):
35
+ try:
36
+ model = ImageGenerationModel.from_pretrained(model_name)
37
+ response = model.generate_images(
38
+ prompt=prompt,
39
+ number_of_images=1,
40
+ )
41
 
42
+ image_bytes = response[0]._image_bytes
43
+ image_url = Image.open(BytesIO(image_bytes))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ except Exception as e:
46
+ print(e)
47
+ raise gr.Error(f"An error occurred while generating the image for: {entry}")
48
+ return image_url
49
 
50
+ with gr.Blocks() as demo:
 
 
51
 
52
+ gr.Markdown("# <center>Google Vertex Imagen Generator</center>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  #instructions
55
  with gr.Accordion("Instructions & Tips",label="instructions",open=False):
56
  with gr.Row():
 
57
  gr.Markdown("**Tips**: Use adjectives (size,color,mood), specify the visual style (realistic,cartoon,8-bit), explain the point of view (from above,first person,wide angle) ")
58
 
 
 
 
 
 
 
59
  #prompts
60
+ with gr.Accordion("Prompt",label="Prompt",open=True):
61
  text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
62
 
63
  model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
64
 
65
  with gr.Row():
66
+ btn = gr.Button("Generate Images")
 
67
 
68
  #output
69
+ with gr.Accordion("Image Output",label="Image Output",open=True):
70
+ output_image = gr.Image(label="Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ btn.click(fn=generate_image, inputs=[text, model ], outputs=output_image, api_name=False)
73
+ text.submit(fn=generate_image, inputs=[text, model ], outputs=output_image, api_name="generate_image") # Generate an api endpoint in Gradio / HF
74
 
75
  demo.launch(share=False)
env-sample CHANGED
@@ -1,5 +1 @@
1
- OPENAI_API_KEY = <YOUR_OPENAI_API_KEY>
2
- PW = <YOUR_PW>
3
- MONGO_URI=<YOUR_MONGO_URI>
4
- MODE=dev
5
  GOOGLE_APPLICATION_CREDENTIALS= <YOUR_GOOGLE_APPLICATION_CREDENTIALS>
 
 
 
 
 
1
  GOOGLE_APPLICATION_CREDENTIALS= <YOUR_GOOGLE_APPLICATION_CREDENTIALS>