File size: 7,634 Bytes
c59606b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
from dotenv import load_dotenv
import streamlit as st
import requests
from PIL import Image, ImageDraw, ImageFont
import io
import base64
import easyocr
import numpy as np
import cv2

# Load environment variables (e.g. HF_TOKEN) from a .env file, if one is present.
load_dotenv()

# Set up logging
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Hugging Face API setup
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"

HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Previously an unset token produced the header "Bearer None", i.e. the literal
    # string "None" was sent as a credential. Warn instead and send no auth header.
    logger.warning("HF_TOKEN is not set; Hugging Face API requests will be unauthenticated.")
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}


@st.cache_resource
def _load_ocr_reader():
    """Create the English EasyOCR reader once and reuse it across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction; without caching,
    the OCR models would be reloaded each time.
    """
    return easyocr.Reader(['en'])


# Initialize EasyOCR reader (shared, cached instance)
reader = _load_ocr_reader()

def query(payload, timeout=120):
    """POST *payload* to the Hugging Face inference endpoint and return the result.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds to wait for the server. Without a timeout ``requests``
            can block indefinitely on an unresponsive server, freezing the app.

    Returns:
        The decoded JSON for ``application/json`` responses, the raw bytes for
        ``image/*`` responses, or ``None`` when the request fails or the content
        type is unexpected. Failures are logged and shown in the UI via ``st.error``.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()

        logger.debug("API response status code: %s", response.status_code)
        logger.debug("API response headers: %s", response.headers)

        content_type = response.headers.get('Content-Type', '')
        if 'application/json' in content_type:
            return response.json()
        if 'image' in content_type:
            return response.content

        logger.error("Unexpected content type: %s", content_type)
        st.error(f"Unexpected content type: {content_type}")
        return None
    except requests.exceptions.RequestException as e:
        # For HTTP errors the server's response body may explain the failure;
        # keep a bounded excerpt of it in the log instead of discarding it.
        error_body = e.response.text[:500] if e.response is not None else ""
        logger.error("Request failed: %s %s", e, error_body)
        st.error(f"Request failed: {str(e)}")
        return None

def increase_image_quality(image, scale_factor):
    """Upscale *image* by *scale_factor* using Lanczos resampling.

    Despite the name, this only resamples to a larger size; it does not add detail.

    Args:
        image: PIL image to resize.
        scale_factor: Positive multiplier for both width and height; may be
            fractional (e.g. ``1.5``).

    Returns:
        A new resized image.

    Raises:
        ValueError: If *scale_factor* is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")
    width, height = image.size
    # PIL needs integer pixel dimensions: round so fractional factors work, and keep
    # each side at least 1 px so small factors cannot produce an empty image.
    new_size = (max(1, round(width * scale_factor)), max(1, round(height * scale_factor)))
    return image.resize(new_size, Image.LANCZOS)

def extract_text_from_image(image):
    """Return every piece of text EasyOCR finds in *image*, joined by single spaces.

    Detections are concatenated in the order the module-level ``reader`` reports
    them; an image with no detections yields an empty string.
    """
    detections = reader.readtext(np.array(image))
    return " ".join(detection[1] for detection in detections)

def remove_text_from_image(image, text_to_remove):
    """Erase the first OCR-detected text region that contains *text_to_remove*.

    The module-level EasyOCR ``reader`` locates text in *image*. The first
    detection whose text contains *text_to_remove* (case-insensitive) is masked
    and filled in with OpenCV's Telea inpainting.

    Args:
        image: PIL image to search; it is converted to RGB for processing.
        text_to_remove: Substring to look for; surrounding whitespace is ignored.

    Returns:
        ``(new_image, top_left, size)`` where ``top_left`` is the ``(x, y)`` corner
        of the erased region and ``size`` is its ``(width, height)``. When nothing
        matches, or *text_to_remove* is empty, returns ``(image, None, None)``
        with the original image untouched.
    """
    needle = text_to_remove.strip().lower() if text_to_remove else ""
    if not needle:
        # An empty needle is a substring of everything and would erase an arbitrary
        # (the first detected) text region.
        logger.warning("No text to remove was given; image left unchanged.")
        return image, None, None

    # cv2.COLOR_RGB2BGR requires exactly 3 channels, so normalise the mode first.
    img_array = np.array(image.convert("RGB"))
    img_height, img_width = img_array.shape[:2]
    results = reader.readtext(img_array)

    for bbox, text, _prob in results:
        if needle not in text.lower():
            continue

        # EasyOCR boxes are four corner points that may describe rotated text, so
        # take the axis-aligned bounds over all of them (clipped to the image)
        # rather than assuming points 0 and 2 are top-left / bottom-right.
        xs = [int(point[0]) for point in bbox]
        ys = [int(point[1]) for point in bbox]
        left, right = max(min(xs), 0), min(max(xs), img_width - 1)
        top, bottom = max(min(ys), 0), min(max(ys), img_height - 1)
        top_left = (left, top)
        bottom_right = (right, bottom)

        # Convert image to OpenCV (BGR) format
        img_cv = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

        # Single-channel mask: 255 marks the pixels to inpaint (filled rectangle)
        mask = np.zeros(img_cv.shape[:2], dtype=np.uint8)
        cv2.rectangle(mask, top_left, bottom_right, 255, -1)

        # Inpaint with radius 3 using the Telea method
        inpainted = cv2.inpaint(img_cv, mask, 3, cv2.INPAINT_TELEA)

        # Convert back to a PIL RGB image
        cleaned = Image.fromarray(cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB))
        return cleaned, top_left, (right - left, bottom - top)

    logger.warning(f"Text '{text_to_remove}' not found in the image.")
    return image, None, None

def add_text_to_image(image, text, font_size=40, font_color="#FFFFFF", position=None, size=None,
                      font_path="Roboto-Bold.ttf"):
    """Draw *text* onto *image* (in place) and return it.

    Args:
        image: PIL image to draw on; it is modified in place.
        text: Text to render.
        font_size: Starting font size for the TrueType font.
        font_color: Fill colour accepted by PIL, e.g. ``"#FFFFFF"``.
        position: ``(x, y)`` where the top-left of the rendered text should land.
            When ``None``, the text is centred in the image.
        size: Optional ``(width, height)`` box the text must fit in. The font is
            shrunk one point at a time (down to 1) until it fits. This only
            applies when the TrueType font could be loaded.
        font_path: TrueType font file to load. PIL's default font is used if it
            cannot be opened.

    Returns:
        The same ``image`` object, with the text drawn on it.
    """
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype(font_path, font_size)
        scalable = True
    except OSError:
        logger.warning("Font %r not found, using default font", font_path)
        font = ImageFont.load_default()
        # The fallback cannot be reloaded at another size via truetype(font_path),
        # so size fitting is skipped for it instead of crashing in the loop below.
        scalable = False

    if size is not None and scalable:
        max_width, max_height = size
        # Stop at 1 pt: truetype() rejects non-positive sizes.
        while font_size > 1:
            left, top, right, bottom = font.getbbox(text)
            if right - left <= max_width and bottom - top <= max_height:
                break
            font_size -= 1
            font = ImageFont.truetype(font_path, font_size)

    left, top, right, bottom = font.getbbox(text)
    if position is None:
        # Centre the rendered text in the image.
        img_width, img_height = image.size
        position = ((img_width - (right - left)) // 2, (img_height - (bottom - top)) // 2)

    # getbbox() is measured relative to the drawing origin and can have a non-zero
    # left/top offset; shift the origin so the visible text starts at `position`.
    origin = (position[0] - left, position[1] - top)
    logger.debug(f"Adding text at position: {position}")
    draw.text(origin, text, font=font, fill=font_color)
    return image

def _stored_image_indices():
    """Return the sorted indices ``i`` of every ``image_{i}`` entry in session state."""
    prefix = "image_"
    indices = []
    for key in list(st.session_state.keys()):
        if isinstance(key, str) and key.startswith(prefix) and key[len(prefix):].isdigit():
            indices.append(int(key[len(prefix):]))
    return sorted(indices)


def _render_generate_section():
    """Render the "Generate Poster" controls; store newly generated images as PNG bytes."""
    st.header("Generate Poster")
    poster_type = st.selectbox("Poster Type", ["Fashion", "Movie", "Event", "Advertisement", "Other"])
    prompt = st.text_area("Prompt")
    num_images = st.number_input("Number of Images", min_value=1, max_value=5, value=1)
    quality_factor = st.number_input("Quality Factor", min_value=1, max_value=4, value=1)

    if not st.button("Generate Images"):
        return

    if poster_type == "Other":
        full_prompt = f"A colorful poster with the following elements: {prompt}"
    else:
        full_prompt = f"A colorful {poster_type.lower()} poster with the following elements: {prompt}"

    generated_images = []
    for i in range(num_images):
        with st.spinner(f"Generating image {i+1}..."):
            logger.info(f"Generating image {i+1} with prompt: {full_prompt}")
            response = query({"inputs": full_prompt})

            image = None
            if isinstance(response, bytes):
                try:
                    image = Image.open(io.BytesIO(response))
                    image.load()  # force decoding now so corrupt bytes fail here
                except OSError as exc:
                    logger.error("Could not decode generated image: %s", exc)
                    image = None
            if image is None:
                st.error("Failed to generate image")
                continue

            if quality_factor > 1:
                image = increase_image_quality(image, quality_factor)
            generated_images.append(image)

    if generated_images:
        # Drop images from an earlier (possibly larger) batch so the editor
        # does not offer stale posters alongside the new ones.
        for stale_index in _stored_image_indices():
            del st.session_state[f"image_{stale_index}"]

    # Display generated images and save them to session state for editing
    for i, img in enumerate(generated_images):
        st.image(img, caption=f"Generated Poster {i+1}", use_column_width=True)
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        st.session_state[f'image_{i}'] = buffer.getvalue()


def _render_edit_section():
    """Render the "Edit Poster" controls for an image stored in session state."""
    st.header("Edit Poster")
    indices = _stored_image_indices()
    image_to_edit = st.selectbox("Select Image to Edit", [f"Generated Poster {i+1}" for i in indices])
    if not image_to_edit:
        # No stored images, so the selectbox has no options and nothing to edit.
        return

    image_index = int(image_to_edit.split()[-1]) - 1
    img_bytes = st.session_state[f'image_{image_index}']
    img = Image.open(io.BytesIO(img_bytes))
    st.image(img, caption="Current Image", use_column_width=True)

    text_to_remove = st.text_input("Text to Remove")
    new_text = st.text_input("New Text")
    font_size = st.number_input("Font Size", min_value=1, max_value=100, value=40)
    font_color = st.color_picker("Font Color", "#FFFFFF")

    if not st.button("Apply Changes"):
        return

    # Position/size of the erased region, reused to place the replacement text.
    position = None
    size = None
    if text_to_remove:
        img, position, size = remove_text_from_image(img, text_to_remove)

    if new_text:
        img = add_text_to_image(img, new_text, font_size, font_color, position, size)

    st.image(img, caption="Edited Image", use_column_width=True)

    # Encode the edited image as PNG so it can be offered for download.
    img_byte_arr = io.BytesIO()
    img.save(img_byte_arr, format='PNG')
    st.download_button(
        label="Download Edited Image",
        data=img_byte_arr.getvalue(),
        file_name="edited_poster.png",
        mime="image/png",
    )


def main():
    """Build the Streamlit UI: a poster generator followed by a text editor."""
    st.title("Poster Generator and Editor")
    _render_generate_section()
    _render_edit_section()

# Entry point: build the Streamlit UI when this file is executed as a script.
if __name__ == "__main__":
    main()