# Dokdo / app (28).py
# Author: ginipick — commit 5e60b44 "Upload 2 files" (verified), 14.6 kB
import spaces
import argparse
import os
import time
from os import path
import shutil
from datetime import datetime
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from PIL import Image
from transformers import pipeline
import replicate
import logging
import requests
from pathlib import Path
import cv2
import numpy as np
import sys
import io
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup and initialization code
# Model/cache files live in a "models" directory next to this script.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")

# API settings
# The catbox.moe user hash may be supplied via the environment so it does not
# have to live in source control; the previously hard-coded value is kept as
# the fallback for backward compatibility.
CATBOX_USER_HASH = os.environ.get("CATBOX_USER_HASH", "e7a96fc68dd4c7d2954040cd5")
# The Replicate token is read from the API_KEY environment variable.
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Cache-location environment variables.
# NOTE(review): transformers / huggingface_hub are imported at the top of the
# file and may read these variables at import time, so setting them here might
# not change where models are cached — consider setting them before the imports.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# Create the cache directory up front, before any model is loaded.
# exist_ok=True already makes this a no-op when it exists.
os.makedirs(cache_path, exist_ok=True)

# CUDA settings: allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# Korean -> English translator used to normalize user prompts.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
def check_api_key():
    """Validate the Replicate token and export it for the replicate client.

    Returns:
        bool: True when the module-level ``REPLICATE_API_TOKEN`` is non-empty,
        after copying it into ``os.environ["REPLICATE_API_TOKEN"]``;
        False (with an error logged) otherwise.
    """
    token = REPLICATE_API_TOKEN
    if token:
        os.environ["REPLICATE_API_TOKEN"] = token
        logger.info("Replicate API token set successfully")
        return True
    logger.error("Replicate API key not found")
    return False
def translate_if_korean(text):
    """Return *text* translated to English when it contains Hangul.

    Only precomposed Hangul syllables (U+AC00..U+D7A3) trigger translation;
    any other input is returned unchanged without calling the translator.
    """
    contains_hangul = any(0xAC00 <= ord(ch) <= 0xD7A3 for ch in text)
    if not contains_hangul:
        return text
    result = translator(text)
    return result[0]['translation_text']
def filter_prompt(prompt):
    """Check *prompt* against a keyword blocklist.

    Matching is case-insensitive and anchored to the start of a word. A keyword
    counts only when it is not immediately preceded by an ASCII letter. That
    still catches inflections and glued tokens such as "killing" or "hot_sex".
    It no longer rejects innocent words that merely contain a keyword, e.g.
    "skill" (kill), "whatever" (hate) or "Essex" (sex).

    Args:
        prompt: The (already translated) prompt text.

    Returns:
        ``(True, prompt)`` if no keyword matches. Otherwise
        ``(False, "부적절한 내용이 포함된 프롬프트입니다.")``, which is Korean
        for "the prompt contains inappropriate content".
    """
    import re

    inappropriate_keywords = [
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
        "erotic", "sensual", "seductive", "provocative", "intimate",
        "violence", "gore", "blood", "death", "kill", "murder", "torture",
        "drug", "suicide", "abuse", "hate", "discrimination"
    ]
    # The negative lookbehind replaces a plain substring test, which flagged
    # words like "skill" or "whatever" because they contain "kill" / "hate".
    blocked = re.compile(
        r"(?<![a-z])(?:"
        + "|".join(re.escape(keyword) for keyword in inappropriate_keywords)
        + ")"
    )
    if blocked.search(prompt.lower()):
        return False, "부적절한 내용이 포함된 프롬프트입니다."
    return True, prompt
def process_prompt(prompt):
    """Translate *prompt* to English if needed, then apply the keyword filter.

    Returns the ``(is_safe, text)`` pair produced by ``filter_prompt`` for the
    (possibly translated) prompt.
    """
    english_prompt = translate_if_korean(prompt)
    verdict, text = filter_prompt(english_prompt)
    return verdict, text
class timer:
    """Context manager that prints how long its ``with`` body took.

    Prints "<method_name> starts" on entry and "<method_name> took <seconds>s"
    (rounded to two decimals) on exit, including when the body raises.
    The measured duration is also available afterwards as ``elapsed``.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name
        self.start = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, so wall-clock adjustments cannot skew
        # the measured duration the way time.time() differences can.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Return self so `with timer(...) as t` binds the timer, not None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.perf_counter()
        self.elapsed = end - self.start
        print(f"{self.method} took {str(round(self.elapsed, 2))}s")
        # Returning None lets any exception raised in the body propagate.
        return None
# Model initialization
# Ensure the local model cache directory exists (exist_ok makes this safe to repeat).
if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)
# Load the FLUX.1-dev pipeline with bfloat16 weights.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Load ByteDance's Hyper-SD LoRA and fuse it into the base weights at scale 0.125.
# The LoRA file name indicates an 8-step variant, which matches the UI's
# default of 8 inference steps.
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
# Move the whole (fused) pipeline to the GPU in bfloat16.
pipe.to(device="cuda", dtype=torch.bfloat16)
# NOTE(review): this attaches a Stable Diffusion safety checker as a plain
# attribute. Nothing visible in this file calls it, and FluxPipeline does not
# appear to invoke a `safety_checker` component during generation — confirm.
# If so, this only loads an unused model and does no NSFW filtering of outputs.
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
def upload_to_catbox(image_path, timeout=120):
    """Upload a local image to catbox.moe and return its public URL.

    Args:
        image_path: Path to a .jpg/.jpeg/.png/.gif file.
        timeout: Seconds to wait for the upload request before giving up.

    Returns:
        The URL text returned by catbox.moe on success. None when the extension
        is unsupported or the upload fails; errors are logged, not raised.
    """
    # Map each accepted extension to its real MIME type (GIFs used to be sent
    # as image/png).
    mime_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
    }
    try:
        logger.info(f"Preparing to upload image: {image_path}")
        url = "https://catbox.moe/user/api.php"
        file_extension = Path(image_path).suffix.lower()
        if file_extension not in mime_types:
            logger.error(f"Unsupported file type: {file_extension}")
            return None
        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        # The with-block closes the file handle even if the request fails
        # (it was previously leaked).
        with open(image_path, 'rb') as image_file:
            files = {
                'fileToUpload': (
                    os.path.basename(image_path),
                    image_file,
                    mime_types[file_extension],
                )
            }
            response = requests.post(url, files=files, data=data, timeout=timeout)
        # The response body is treated as the URL; strip any trailing newline.
        image_url = response.text.strip()
        if response.status_code == 200 and image_url.startswith('http'):
            logger.info(f"Image uploaded successfully: {image_url}")
            return image_url
        raise RuntimeError(f"Upload failed: {response.text}")
    except Exception as e:
        logger.error(f"Image upload error: {str(e)}")
        return None
def add_watermark(video_path, output_path="watermarked_output.mp4"):
    """Draw a "GiniGEN.AI" text watermark in the bottom-right of every frame.

    Args:
        video_path: Path of the input video.
        output_path: Where to write the watermarked video (mp4v fourcc).
            Defaults to the previously hard-coded file name.

    Returns:
        ``output_path`` when at least one watermarked frame was written.
        Otherwise ``video_path``, so callers can still serve the unwatermarked
        video. That covers an input that cannot be opened, a writer that cannot
        be created, an empty input, or any exception; errors are logged.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            # Previously this fell through and returned output_path, which
            # could be a stale file from an earlier run.
            logger.error(f"Error adding watermark: cannot open video {video_path}")
            return video_path
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional frame rates (int() truncation altered playback speed).
        # Fall back to an assumed 30 fps when the container reports none.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30
        thickness = 2
        color = (255, 255, 255)
        (text_width, _text_height), _baseline = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        # Clamp so the text never starts left of the frame on narrow videos.
        x_pos = max(0, width - text_width - margin)
        y_pos = height - margin

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error(f"Error adding watermark: cannot open writer for {output_path}")
            return video_path

        frames_written = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)
            frames_written += 1
        if frames_written == 0:
            logger.error(f"Error adding watermark: no frames read from {video_path}")
            return video_path
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        # Release on every path so handles are not leaked and the output file
        # is finalized before the caller uses it.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def generate_video(image, prompt):
    """Generate a watermarked video from a first-frame image via Replicate.

    Steps: check the Replicate token, upload *image* to catbox.moe, run the
    ``minimax/video-01-live`` model, save the result to ``temp_output.mp4``,
    then watermark it.

    Args:
        image: Local file path of the first-frame image.
        prompt: Text prompt (already translated/filtered by the caller).

    Returns:
        The path of the resulting video on success. On failure, a human-readable
        error message string instead; this function does not raise.
    """
    logger.info("Starting video generation")
    try:
        if not check_api_key():
            return "Replicate API key not properly configured"
        if not image:
            logger.error("No image provided")
            return "Please upload an image"
        image_url = upload_to_catbox(image)
        if not image_url:
            return "Failed to upload image"
        input_data = {
            "prompt": prompt,
            "first_frame_image": image_url
        }
        try:
            # Use the client built with the token; the original created a
            # client and discarded it.
            client = replicate.Client(api_token=REPLICATE_API_TOKEN)
            output = client.run(
                "minimax/video-01-live",
                input=input_data
            )
            # Some model versions return a list of outputs; take the first.
            if isinstance(output, (list, tuple)):
                if not output:
                    raise ValueError("Replicate returned an empty output list")
                output = output[0]
            temp_file = "temp_output.mp4"
            if hasattr(output, 'read'):
                with open(temp_file, "wb") as file:
                    file.write(output.read())
            elif isinstance(output, str):
                response = requests.get(output, timeout=300)
                # Don't save an HTML error page as if it were the video.
                response.raise_for_status()
                with open(temp_file, "wb") as file:
                    file.write(response.content)
            else:
                # Previously an unknown type fell through to watermarking a
                # file that was never written (or a stale one).
                raise TypeError(f"Unexpected Replicate output type: {type(output).__name__}")
            final_video = add_watermark(temp_file)
            return final_video
        except Exception as api_error:
            logger.error(f"API call failed: {str(api_error)}")
            return f"API call failed: {str(api_error)}"
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Unexpected error: {str(e)}"
def save_image(image):
    """Save *image* as a PNG in the local ``temp`` directory.

    Args:
        image: A PIL image, or anything ``Image.fromarray`` accepts
            (e.g. a numpy array). Non-RGB images are converted to RGB.

    Returns:
        The path of the written file, or None (after logging) on failure.
    """
    try:
        # Save into a temporary directory (created on demand).
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        # Microseconds in the name stop two saves within the same second from
        # overwriting each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filepath = os.path.join(temp_dir, f"temp_{timestamp}.png")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # PNG is lossless; `quality` only applies to lossy formats, so it is omitted.
        image.save(filepath, format='PNG', optimize=True)
        return filepath
    except Exception as e:
        logger.error(f"Error in save_image: {str(e)}")
        return None
css = """
footer {
visibility: hidden;
}
"""
# Build the Gradio interface.
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    gr.HTML('<div class="title">AI Image & Video Generator</div>')
    with gr.Tabs():
        with gr.Tab("Image Generation"):
            with gr.Row():
                with gr.Column(scale=3):
                    img_prompt = gr.Textbox(
                        label="Image Description",
                        placeholder="이미지 설명을 입력하세요... (한글 입력 가능)",
                        lines=3
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1152,
                                step=64,
                                value=1024
                            )
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1152,
                                step=64,
                                value=1024
                            )
                        with gr.Row():
                            steps = gr.Slider(
                                label="Inference Steps",
                                minimum=6,
                                maximum=25,
                                step=1,
                                value=8
                            )
                            scales = gr.Slider(
                                label="Guidance Scale",
                                minimum=0.0,
                                maximum=5.0,
                                step=0.1,
                                value=3.5
                            )

                        def get_random_seed():
                            """Return a random integer seed in [0, 1000000)."""
                            return torch.randint(0, 1000000, (1,)).item()

                        # Pass the function itself (not its result) so Gradio
                        # draws a fresh seed on each page load. Calling it here
                        # fixed one seed at startup for every visitor.
                        seed = gr.Number(
                            label="Seed",
                            value=get_random_seed,
                            precision=0
                        )
                        randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])
                    generate_btn = gr.Button(
                        "✨ Generate Image",
                        elem_classes=["generate-btn"]
                    )
                with gr.Column(scale=4):
                    img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        format="png"
                    )
        with gr.Tab("Amazing Video Generation"):
            with gr.Row():
                with gr.Column(scale=3):
                    video_prompt = gr.Textbox(
                        label="Video Description",
                        placeholder="비디오 설명을 입력하세요... (한글 입력 가능)",
                        lines=3
                    )
                    upload_image = gr.Image(
                        type="filepath",
                        label="Upload First Frame Image"
                    )
                    video_generate_btn = gr.Button(
                        "🎬 Generate Video",
                        elem_classes=["generate-btn"]
                    )
                with gr.Column(scale=4):
                    video_output = gr.Video(label="Generated Video")

    @spaces.GPU
    def process_and_save_image(height, width, steps, scales, prompt, seed):
        """Generate one image with the FLUX pipeline for the Image tab.

        The prompt is translated (if Korean) and keyword-filtered first. A
        rejected prompt shows a warning and returns None. On a generation error
        the traceback is logged, a warning is shown, and None is returned.
        """
        is_safe, translated_prompt = process_prompt(prompt)
        if not is_safe:
            gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
            return None
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            try:
                generated_image = pipe(
                    prompt=[translated_prompt],
                    # Explicitly seeded generator so a given seed is reproducible.
                    generator=torch.Generator().manual_seed(int(seed)),
                    num_inference_steps=int(steps),
                    guidance_scale=float(scales),
                    height=int(height),
                    width=int(width),
                    max_sequence_length=256
                ).images[0]
                if not isinstance(generated_image, Image.Image):
                    generated_image = Image.fromarray(generated_image)
                if generated_image.mode != 'RGB':
                    generated_image = generated_image.convert('RGB')
                # Return the image directly. The former encode-to-PNG/re-open
                # round trip produced identical pixels, and the output component
                # already uses format="png".
                return generated_image
            except Exception as e:
                logger.exception(f"Error in image generation: {str(e)}")
                gr.Warning("Image generation failed. Please try again.")
                return None

    def process_and_generate_video(image, prompt):
        """Translate/filter the prompt, then generate a video for the Video tab.

        generate_video returns either a video file path or an error message.
        Only an existing file is passed to the Video component; an error
        message is shown as a warning instead of being treated as a file path.
        """
        is_safe, translated_prompt = process_prompt(prompt)
        if not is_safe:
            gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
            return None
        result = generate_video(image, translated_prompt)
        if isinstance(result, str) and os.path.isfile(result):
            return result
        gr.Warning(str(result))
        return None

    def update_seed():
        """Return a new random seed for the Seed field."""
        return get_random_seed()

    # Generate first, then draw a new seed for the next run. Chaining with
    # .then() gives a defined order, replacing a second click handler that
    # fired concurrently with generation.
    generate_btn.click(
        process_and_save_image,
        inputs=[height, width, steps, scales, img_prompt, seed],
        outputs=img_output
    ).then(
        update_seed,
        outputs=[seed]
    )
    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )
    randomize_seed.click(
        update_seed,
        outputs=[seed]
    )
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)