# Hugging Face Spaces app. (Page-scrape residue removed; the hosted Space
# was reporting "Runtime error" at capture time.)
# Standard library
import logging
import traceback

# Third-party
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw
from transformers import AutoTokenizer, AutoModelForCausalLM

# Logging configuration: timestamped INFO-level module logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class T4OptimizedStorylineGenerator:
    """Story-plus-scene-image generator sized to run on a single T4 GPU.

    Uses distilgpt2 for text and Stable Diffusion v1.5 for images.
    Runs in float16 on CUDA and falls back to float32 on CPU.
    """

    def __init__(self):
        """Load the text and image models onto the best available device.

        Raises:
            Exception: re-raised (after logging) if either model fails to load.
        """
        try:
            # Model selection: lightweight GPT-2 variant + SD v1.5.
            text_model_name = "distilgpt2"
            image_model_name = "runwayml/stable-diffusion-v1-5"

            # Device / precision: half precision only makes sense on CUDA.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.dtype = torch.float16 if self.device == "cuda" else torch.float32
            logger.info("Device: %s", self.device)
            logger.info("Precision: %s", self.dtype)

            # Text generation model.
            logger.info("Loading text model")
            self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
            if self.tokenizer.pad_token is None:
                # GPT-2 family ships with no pad token; reuse EOS so that
                # tokenization/generation with padding does not warn or fail.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.text_model = AutoModelForCausalLM.from_pretrained(
                text_model_name,
                torch_dtype=self.dtype,
            ).to(self.device)  # explicit .to() avoids an accelerate dependency on CPU

            # Image generation pipeline.
            logger.info("Loading image generation pipeline")
            self.image_pipeline = StableDiffusionPipeline.from_pretrained(
                image_model_name,
                torch_dtype=self.dtype,
            )
            if self.device == "cuda":
                # Trade a little speed for much lower peak VRAM on the T4.
                self.image_pipeline.enable_attention_slicing()
            self.image_pipeline = self.image_pipeline.to(self.device)
        except Exception as e:
            logger.error(f"Initialization Error: {e}")
            logger.error(traceback.format_exc())
            raise

    def generate_story_text(self, topic, plot_points, max_length=200):
        """Generate a short story from a topic and a list of plot points.

        Args:
            topic: Story subject, interpolated into the prompt.
            plot_points: Iterable of scene descriptions, joined into the prompt.
            max_length: Total token budget (prompt + continuation) for generation.

        Returns:
            The decoded story text; on failure, an error-message string
            (this method never raises).
        """
        try:
            prompt = f"Write a short story about {topic}. Scenes: {', '.join(plot_points)}"
            # Cap the prompt at 100 tokens so the continuation budget stays useful.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=100,
            ).to(self.device)
            with torch.no_grad():
                outputs = self.text_model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    num_return_sequences=1,
                    # do_sample is required: without it, temperature/top_k/top_p
                    # are silently ignored and decoding is greedy.
                    do_sample=True,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Text Generation Error: {e}")
            return f"Story generation failed: {e}"

    def generate_scene_images(self, scene_prompts, num_images=3):
        """Render up to *num_images* scene images, one per prompt.

        Args:
            scene_prompts: List of text prompts; only the first *num_images*
                are used.
            num_images: Upper bound on how many images to generate.

        Returns:
            List of PIL images. A prompt that fails yields a placeholder
            error image instead of aborting the remaining prompts.
        """
        scene_images = []
        for prompt in scene_prompts[:num_images]:
            try:
                with torch.inference_mode():
                    image = self.image_pipeline(
                        prompt,
                        num_inference_steps=25,  # reduced from the usual 50 for T4 speed
                        guidance_scale=6.0,
                        height=512,
                        width=512,
                    ).images[0]
                scene_images.append(image)
            except Exception as e:
                logger.error(f"Image Generation Error: {e}")
                # Per-prompt fallback: keep going so one bad prompt does not
                # discard the rest of the gallery.
                scene_images.append(self._create_error_image())
        return scene_images

    def _create_error_image(self, size=(600, 400)):
        """Create a standardized red placeholder image with an error caption.

        Args:
            size: (width, height) of the placeholder.

        Returns:
            A PIL RGB image.
        """
        img = Image.new('RGB', size, color=(200, 50, 50))
        draw = ImageDraw.Draw(img)
        draw.text(
            (50, 180),
            "Image Generation Failed",
            fill=(255, 255, 255),
        )
        return img

    def generate_visual_storyline(self, topic, plot_points, scene_prompts):
        """Produce the full storyline: story text plus scene images.

        Args:
            topic: Story subject.
            plot_points: List of plot-point strings for the text prompt.
            scene_prompts: List of image prompts.

        Returns:
            (story_text, scene_images); on unexpected failure, an error
            message and a single placeholder image.
        """
        try:
            story_text = self.generate_story_text(topic, plot_points)
            scene_images = self.generate_scene_images(scene_prompts)
            return story_text, scene_images
        except Exception as e:
            logger.error(f"Visual Storyline Generation Error: {e}")
            return f"Storyline generation failed: {e}", [self._create_error_image()]
def create_gradio_interface():
    """Build the Gradio UI wired to a T4OptimizedStorylineGenerator.

    Returns:
        A configured gr.Interface (not yet launched).
    """
    generator = T4OptimizedStorylineGenerator()

    def storyline_wrapper(topic, plot_points, scene_prompts):
        # Comma-separated textboxes -> lists; drop empties so a blank field
        # does not produce a single empty-string prompt.
        plot_points_list = [p.strip() for p in plot_points.split(',') if p.strip()]
        scene_prompts_list = [p.strip() for p in scene_prompts.split(',') if p.strip()]
        return generator.generate_visual_storyline(
            topic,
            plot_points_list,
            scene_prompts_list,
        )

    interface = gr.Interface(
        fn=storyline_wrapper,
        inputs=[
            gr.Textbox(label="Story Topic", placeholder="A magical adventure"),
            gr.Textbox(label="Plot Points", placeholder="Enter forest, Meet creatures, Find treasure"),
            gr.Textbox(label="Scene Prompts", placeholder="Misty enchanted forest, Magical creatures gathering, Hidden treasure cave"),
        ],
        outputs=[
            gr.Textbox(label="Generated Story"),
            gr.Gallery(label="Scene Images"),
        ],
        title="T4-Optimized AI Storyline Generator",
        description="Create magical stories with AI-powered text and image generation",
    )
    return interface
def main():
    """Build and launch the Gradio app; log (don't re-raise) fatal errors."""
    try:
        interface = create_gradio_interface()
        interface.launch(debug=True)
    except Exception as e:
        # Top-level boundary: log everything so Spaces logs show the cause.
        logger.critical(f"Critical Failure: {e}")
        logger.critical(traceback.format_exc())


if __name__ == "__main__":
    main()
# T4 GPU Optimization Notes
"""
Optimization Strategies:
1. Use distilgpt2 (lighter model)
2. Reduced inference steps (25 vs 50)
3. float16 precision when using GPU
4. Attention slicing when using GPU
5. Smaller image resolutions
"""