import torch
import gradio as gr
import logging
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw
# Logging Configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class T4OptimizedStorylineGenerator:
    def __init__(self):
        """
        Optimized initialization for both CPU and GPU environments
        """
        try:
            # Model Selection
            text_model_name = "distilgpt2"  # Lighter GPT-2 model
            image_model_name = "runwayml/stable-diffusion-v1-5"  # Stable Diffusion model

            # Device Configuration
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.dtype = torch.float16 if self.device == "cuda" else torch.float32
            logger.info(f"🖥️ Device: {self.device}")
            logger.info(f"📊 Precision: {self.dtype}")

            # Text Generation Model
            logger.info("🔄 Loading Text Model")
            self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
            # GPT-2 tokenizers ship without a pad token; reuse EOS so
            # generate() can set pad_token_id without warnings
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.text_model = AutoModelForCausalLM.from_pretrained(
                text_model_name,
                torch_dtype=self.dtype
            ).to(self.device)  # Removed device_map to avoid accelerate dependency on CPU

            # Image Generation Pipeline
            logger.info("🖌️ Loading Image Generation Pipeline")
            self.image_pipeline = StableDiffusionPipeline.from_pretrained(
                image_model_name,
                torch_dtype=self.dtype
            )
            if self.device == "cuda":
                self.image_pipeline.enable_attention_slicing()  # Memory optimization for GPU
            self.image_pipeline = self.image_pipeline.to(self.device)
        except Exception as e:
            logger.error(f"Initialization Error: {e}")
            logger.error(traceback.format_exc())
            raise

    def generate_story_text(self, topic, plot_points, max_length=200):
        """
        Optimized text generation for T4 with reduced complexity
        """
        try:
            prompt = f"Write a short story about {topic}. Scenes: {', '.join(plot_points)}"

            # Tokenization with memory efficiency
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=100
            ).to(self.device)

            # Generation with reduced complexity; do_sample=True is required
            # for temperature/top_k/top_p to take effect
            with torch.no_grad():
                outputs = self.text_model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Text Generation Error: {e}")
            return f"Story generation failed: {e}"

    def generate_scene_images(self, scene_prompts, num_images=3):
        """
        T4-Optimized image generation with reduced steps and memory usage
        """
        scene_images = []
        try:
            for prompt in scene_prompts[:num_images]:
                with torch.inference_mode():
                    # Reduced inference steps for T4
                    image = self.image_pipeline(
                        prompt,
                        num_inference_steps=25,  # Reduced from the standard 50
                        guidance_scale=6.0,
                        height=512,  # Standard resolution
                        width=512
                    ).images[0]
                    scene_images.append(image)
        except Exception as e:
            logger.error(f"Image Generation Error: {e}")
            # Fallback error image generation
            scene_images.append(self._create_error_image())
        return scene_images

    def _create_error_image(self, size=(600, 400)):
        """
        Create a standardized error visualization image
        """
        img = Image.new('RGB', size, color=(200, 50, 50))
        draw = ImageDraw.Draw(img)
        # Simple error message
        draw.text(
            (50, 180),
            "Image Generation Failed",
            fill=(255, 255, 255)
        )
        return img

    def generate_visual_storyline(self, topic, plot_points, scene_prompts):
        """
        Comprehensive storyline generation with T4 optimization
        """
        try:
            # Generate story text
            story_text = self.generate_story_text(topic, plot_points)
            # Generate scene images with T4 constraints
            scene_images = self.generate_scene_images(scene_prompts)
            return story_text, scene_images
        except Exception as e:
            logger.error(f"Visual Storyline Generation Error: {e}")
            error_text = f"Storyline generation failed: {e}"
            error_image = self._create_error_image()
            return error_text, [error_image]

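# A minimal headless usage sketch (illustrative only; the topic, plot points,
# and scene prompts below are made-up examples, and this is not executed by
# the app):
#
#   generator = T4OptimizedStorylineGenerator()
#   story, images = generator.generate_visual_storyline(
#       "a lighthouse keeper",
#       ["storm rolls in", "a boat in distress", "the rescue"],
#       ["stormy sea at night", "small boat on rough waves", "lantern-lit rescue"]
#   )
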
def create_gradio_interface():
    """
    Create a user-friendly Gradio interface with T4 optimizations
    """
    generator = T4OptimizedStorylineGenerator()

    def storyline_wrapper(topic, plot_points, scene_prompts):
        # Split comma-separated inputs, dropping empty entries
        plot_points_list = [p.strip() for p in plot_points.split(',') if p.strip()]
        scene_prompts_list = [p.strip() for p in scene_prompts.split(',') if p.strip()]
        return generator.generate_visual_storyline(
            topic,
            plot_points_list,
            scene_prompts_list
        )

    interface = gr.Interface(
        fn=storyline_wrapper,
        inputs=[
            gr.Textbox(label="Story Topic", placeholder="A magical adventure"),
            gr.Textbox(label="Plot Points", placeholder="Enter forest, Meet creatures, Find treasure"),
            gr.Textbox(label="Scene Prompts", placeholder="Misty enchanted forest, Magical creatures gathering, Hidden treasure cave")
        ],
        outputs=[
            gr.Textbox(label="Generated Story"),
            gr.Gallery(label="Scene Images")
        ],
        title="🌈 T4-Optimized AI Storyline Generator",
        description="Create magical stories with AI-powered text and image generation"
    )
    return interface
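
# Optional: on shared hardware, queueing requests helps keep concurrent users
# from exhausting GPU memory. A possible tweak at launch time (the max_size
# parameter assumes a recent Gradio release):
#
#   interface.queue(max_size=4).launch(debug=True)
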
def main():
    """Main execution with error handling"""
    try:
        # Launch Gradio interface
        interface = create_gradio_interface()
        interface.launch(debug=True)
    except Exception as e:
        logger.critical(f"Critical Failure: {e}")
        logger.critical(traceback.format_exc())


if __name__ == "__main__":
    main()

# T4 GPU Optimization Notes
"""
Optimization Strategies:
1. Use distilgpt2 (lighter model)
2. Reduced inference steps (25 vs 50)
3. float16 precision when using GPU
4. Attention slicing when using GPU
5. Smaller image resolutions
"""