# text-to-image / app.py
# Source: Hugging Face Space "text-to-image", file app.py, uploaded by Pratap2002
# (commit c59606b, verified, 7.63 kB, no virus detected).
import os
from dotenv import load_dotenv
import streamlit as st
import requests
from PIL import Image, ImageDraw, ImageFont
import io
import base64
import easyocr
import numpy as np
import cv2
# Load environment variables (e.g. HF_TOKEN) via python-dotenv.
load_dotenv()

# Set up logging
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Hugging Face API setup
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
else:
    # Without this guard the request would carry the literal header "Bearer None".
    logger.warning("HF_TOKEN is not set; Hugging Face API requests will be unauthenticated")
    headers = {}


# Streamlit re-executes this whole script on every widget interaction, so the
# OCR reader is cached as a shared resource instead of being rebuilt each rerun.
@st.cache_resource
def _load_ocr_reader():
    """Create the English EasyOCR reader shared by the OCR helpers."""
    return easyocr.Reader(['en'])


# Initialize EasyOCR reader
reader = _load_ocr_reader()
def query(payload, timeout=120):
    """POST *payload* to the Hugging Face Inference API endpoint ``API_URL``.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": "<prompt>"}``.
        timeout: Passed to ``requests.post`` as its timeout in seconds. Without
            it a stalled connection would block the Streamlit script forever.

    Returns:
        The decoded JSON body when the response ``Content-Type`` is
        ``application/json``; the raw response bytes when the ``Content-Type``
        mentions ``image``; otherwise ``None``. Any ``requests`` error
        (connection, timeout, HTTP error status, JSON decoding) is logged,
        shown via ``st.error`` and also yields ``None``.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        logger.debug(f"API response status code: {response.status_code}")
        logger.debug(f"API response headers: {response.headers}")

        content_type = response.headers.get('Content-Type', '')
        if 'application/json' in content_type:
            return response.json()
        if 'image' in content_type:
            return response.content

        logger.error(f"Unexpected content type: {content_type}")
        st.error(f"Unexpected content type: {content_type}")
        return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Request failed: {str(e)}")
        st.error(f"Request failed: {str(e)}")
        return None
def increase_image_quality(image, scale_factor):
    """Return a copy of *image* resized by *scale_factor* with Lanczos resampling.

    This only enlarges the pixel grid; it cannot add detail that is not there.
    Integer factors behave exactly as before. Fractional factors are now also
    accepted: each dimension is rounded to the nearest pixel and kept >= 1,
    because ``Image.resize`` needs integer sizes.

    Raises:
        ValueError: if *scale_factor* is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")
    width, height = image.size
    new_size = (
        max(1, round(width * scale_factor)),
        max(1, round(height * scale_factor)),
    )
    return image.resize(new_size, Image.LANCZOS)
def extract_text_from_image(image):
    """OCR *image* and return the recognised text fragments joined by spaces.

    Fragments appear in the order the shared EasyOCR ``reader`` reports them.
    """
    pixels = np.array(image)
    # Each detection's second element is the recognised string.
    fragments = [detection[1] for detection in reader.readtext(pixels)]
    return " ".join(fragments)
def remove_text_from_image(image, text_to_remove):
    """Inpaint away the first OCR-detected text region containing *text_to_remove*.

    Matching is a case-insensitive substring test against each detection's
    text; surrounding whitespace in *text_to_remove* is ignored.

    Returns:
        ``(new_image, (left, top), (width, height))`` for the erased region when
        a match is found, so callers can draw replacement text in the same
        place. Otherwise returns ``(image, None, None)`` with *image* unchanged.
    """
    needle = text_to_remove.strip().lower()
    if not needle:
        # An empty needle would be "found" inside every detection.
        logger.warning("No text to remove was given.")
        return image, None, None

    # OpenCV's RGB->BGR conversion needs exactly three channels; normalise
    # RGBA / greyscale / palette images first. One array serves OCR and inpainting.
    rgb_array = np.array(image.convert("RGB"))
    img_height, img_width = rgb_array.shape[:2]

    for bbox, text, _prob in reader.readtext(rgb_array):
        if needle not in text.lower():
            continue

        # Use the extremes of all corner points so rotated/skewed detections
        # are fully covered, and clamp the box to the image bounds.
        xs = [int(x) for x, _ in bbox]
        ys = [int(y) for _, y in bbox]
        left, top = max(min(xs), 0), max(min(ys), 0)
        right = min(max(xs), img_width - 1)
        bottom = min(max(ys), img_height - 1)

        # Mask the text box and let OpenCV fill it from the surrounding pixels.
        mask = np.zeros((img_height, img_width), dtype=np.uint8)
        cv2.rectangle(mask, (left, top), (right, bottom), 255, -1)
        img_bgr = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
        inpainted = cv2.inpaint(img_bgr, mask, 3, cv2.INPAINT_TELEA)

        result = Image.fromarray(cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB))
        return result, (left, top), (right - left, bottom - top)

    logger.warning(f"Text '{text_to_remove}' not found in the image.")
    return image, None, None
def _load_font(font_size):
    """Load Roboto Bold at *font_size*, falling back to Pillow's default font.

    Returns ``(font, scalable)``. ``scalable`` is False for the fallback, so
    callers know not to re-load the (missing) TrueType file at another size.
    """
    try:
        return ImageFont.truetype("Roboto-Bold.ttf", font_size), True
    except IOError:
        logger.warning("Roboto-Bold font not found, using default font")
        return ImageFont.load_default(), False


def _ink_box(font, text):
    """Return ``(left, top, width, height)`` of *text*'s bounding box for *font*.

    ``left``/``top`` are offsets from the anchor point used by ``ImageDraw.text``.
    """
    left, top, right, bottom = font.getbbox(text)
    return left, top, right - left, bottom - top


def add_text_to_image(image, text, font_size=40, font_color="#FFFFFF", position=None, size=None):
    """Draw *text* onto *image* in place and return the same image.

    Args:
        image: PIL image to draw on (mutated).
        text: Text to render.
        font_size: Starting font size in pixels. It is reduced until the text
            fits *size*, but never below 1.
        font_color: Any colour accepted by ``ImageDraw.text`` (e.g. ``"#FFFFFF"``).
        position: ``(x, y)`` where the text's top-left ink corner should land,
            e.g. the box returned by ``remove_text_from_image``.
        size: ``(width, height)`` box the text must fit into.

    If either *position* or *size* is missing, the text is centred at
    *font_size*.
    """
    draw = ImageDraw.Draw(image)
    font, scalable = _load_font(font_size)
    img_width, img_height = image.size

    if position is None or size is None:
        # Calculate the center position if no position is provided
        _, _, text_width, text_height = _ink_box(font, text)
        position = ((img_width - text_width) // 2, (img_height - text_height) // 2)
        size = (text_width, text_height)

    # Shrink the font until the text fits the target box. Stop at 1 px because
    # truetype rejects sizes < 1, and never re-load a non-scalable fallback font.
    left, top, text_width, text_height = _ink_box(font, text)
    while scalable and font_size > 1 and (text_width > size[0] or text_height > size[1]):
        font_size -= 1
        font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
        left, top, text_width, text_height = _ink_box(font, text)

    # draw.text places its anchor at xy and the ink starts (left, top) away from
    # it, so shift by that offset to put the ink box exactly at `position`.
    logger.debug(f"Adding text at position: {position}")
    draw.text((position[0] - left, position[1] - top), text, font=font, fill=font_color)
    return image
def _build_prompt(poster_type, prompt):
    """Wrap the user's *prompt* in the text-to-image prompt for *poster_type*."""
    if poster_type == "Other":
        return f"A colorful poster with the following elements: {prompt}"
    return f"A colorful {poster_type.lower()} poster with the following elements: {prompt}"


def _image_to_png_bytes(img):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    return buffer.getvalue()


def _stored_image_count():
    """Count the consecutive ``image_0``, ``image_1``, ... entries in session state."""
    count = 0
    while f'image_{count}' in st.session_state:
        count += 1
    return count


def _render_generation_section():
    """Render the generation form; on submit, generate, show and store the posters."""
    st.header("Generate Poster")
    poster_type = st.selectbox("Poster Type", ["Fashion", "Movie", "Event", "Advertisement", "Other"])
    prompt = st.text_area("Prompt")
    num_images = st.number_input("Number of Images", min_value=1, max_value=5, value=1)
    quality_factor = st.number_input("Quality Factor", min_value=1, max_value=4, value=1)
    if not st.button("Generate Images"):
        return

    full_prompt = _build_prompt(poster_type, prompt)
    generated_images = []
    for i in range(num_images):
        with st.spinner(f"Generating image {i+1}..."):
            logger.info(f"Generating image {i+1} with prompt: {full_prompt}")
            # The Inference API may serve identical payloads from its cache, which
            # would make every image in the batch the same; ask for a fresh run.
            response = query({"inputs": full_prompt, "options": {"use_cache": False}})
            if isinstance(response, bytes):
                image = Image.open(io.BytesIO(response))
                if quality_factor > 1:
                    image = increase_image_quality(image, quality_factor)
                generated_images.append(image)
            else:
                st.error("Failed to generate image")

    if generated_images:
        # Drop posters left over from an earlier, larger batch so the editor
        # only offers images from this batch.
        for stale_index in range(len(generated_images), _stored_image_count()):
            del st.session_state[f'image_{stale_index}']

    # Display generated images
    for i, img in enumerate(generated_images):
        st.image(img, caption=f"Generated Poster {i+1}", use_column_width=True)
        # Store PNG bytes so the editor can reload the poster on later reruns.
        st.session_state[f'image_{i}'] = _image_to_png_bytes(img)


def _render_editing_section():
    """Render the editor for a stored poster and apply text removal/replacement."""
    st.header("Edit Poster")
    image_index = st.selectbox(
        "Select Image to Edit",
        list(range(_stored_image_count())),
        format_func=lambda i: f"Generated Poster {i+1}",
    )
    # No stored posters -> no options -> None. Index 0 is valid, so test for None.
    if image_index is None:
        return

    img = Image.open(io.BytesIO(st.session_state[f'image_{image_index}']))
    st.image(img, caption="Current Image", use_column_width=True)
    text_to_remove = st.text_input("Text to Remove")
    new_text = st.text_input("New Text")
    font_size = st.number_input("Font Size", min_value=1, max_value=100, value=40)
    font_color = st.color_picker("Font Color", "#FFFFFF")
    if not st.button("Apply Changes"):
        return

    # When text is removed, the new text reuses that region's position and size.
    position = None
    size = None
    if text_to_remove:
        img, position, size = remove_text_from_image(img, text_to_remove)
    if new_text:
        img = add_text_to_image(img, new_text, font_size, font_color, position, size)
    st.image(img, caption="Edited Image", use_column_width=True)

    # Save edited image for download
    st.download_button(
        label="Download Edited Image",
        data=_image_to_png_bytes(img),
        file_name="edited_poster.png",
        mime="image/png",
    )


def main():
    """Streamlit entry point: the poster generator followed by the poster editor."""
    st.title("Poster Generator and Editor")
    _render_generation_section()
    _render_editing_section()


if __name__ == "__main__":
    main()