import torch
import gradio as gr
import logging
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw

# Logging Configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)


class T4OptimizedStorylineGenerator:
    def __init__(self):
        """
        Optimized initialization for both CPU and GPU environments
        """
        try:
            # Model Selection
            text_model_name = "distilgpt2"  # Lighter GPT-2 model
            image_model_name = "runwayml/stable-diffusion-v1-5"  # Stable Diffusion model

            # Device Configuration
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.dtype = torch.float16 if self.device == "cuda" else torch.float32

            logger.info(f"🖥️ Device: {self.device}")
            logger.info(f"📊 Precision: {self.dtype}")

            # Text Generation Model
            logger.info("🔄 Loading Text Model")
            self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
            self.text_model = AutoModelForCausalLM.from_pretrained(
                text_model_name,
                torch_dtype=self.dtype
            ).to(self.device)  # Removed device_map to avoid accelerate dependency for CPU

            # GPT-2 models have no pad token; reuse EOS so generate() can pad
            self.tokenizer.pad_token = self.tokenizer.eos_token

            # Image Generation Pipeline
            logger.info("🖌️ Loading Image Generation Pipeline")
            self.image_pipeline = StableDiffusionPipeline.from_pretrained(
                image_model_name,
                torch_dtype=self.dtype
            )

            if self.device == "cuda":
                self.image_pipeline.enable_attention_slicing()  # Memory optimization for GPU

            self.image_pipeline = self.image_pipeline.to(self.device)

        except Exception as e:
            logger.error(f"Initialization Error: {e}")
            logger.error(traceback.format_exc())
            raise

    def generate_story_text(self, topic, plot_points, max_length=200):
        """
        Optimized text generation for T4 with reduced complexity
        """
        try:
            prompt = f"Write a short story about {topic}. Scenes: {', '.join(plot_points)}"

            # Tokenization with memory efficiency
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=100
            ).to(self.device)

            # Generation with reduced complexity
            with torch.no_grad():
                outputs = self.text_model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    num_return_sequences=1,
                    do_sample=True,  # Required for temperature/top_k/top_p to take effect
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            logger.error(f"Text Generation Error: {e}")
            return f"Story generation failed: {e}"

    def generate_scene_images(self, scene_prompts, num_images=3):
        """
        T4-optimized image generation with reduced steps and memory usage
        """
        scene_images = []
        try:
            for prompt in scene_prompts[:num_images]:
                with torch.inference_mode():
                    # Reduced inference steps for T4
                    image = self.image_pipeline(
                        prompt,
                        num_inference_steps=25,  # Reduced from the standard 50
                        guidance_scale=6.0,
                        height=512,  # Standard resolution
                        width=512
                    ).images[0]
                    scene_images.append(image)
        except Exception as e:
            logger.error(f"Image Generation Error: {e}")
            # Fallback error image generation
            scene_images.append(self._create_error_image())

        return scene_images

    def _create_error_image(self, size=(600, 400)):
        """
        Create a standardized error visualization image
        """
        img = Image.new('RGB', size, color=(200, 50, 50))
        draw = ImageDraw.Draw(img)

        # Simple error message
        draw.text(
            (50, 180),
            "Image Generation Failed",
            fill=(255, 255, 255)
        )
        return img

    def generate_visual_storyline(self, topic, plot_points, scene_prompts):
        """
        Comprehensive storyline generation with T4 optimization
        """
        try:
            # Generate story text
            story_text = self.generate_story_text(topic, plot_points)

            # Generate scene images with T4 constraints
            scene_images = self.generate_scene_images(scene_prompts)

            return story_text, scene_images

        except Exception as e:
            logger.error(f"Visual Storyline Generation Error: {e}")
            error_text = f"Storyline generation failed: {e}"
            error_image = self._create_error_image()
            return error_text, [error_image]


def create_gradio_interface():
    """
    Create a user-friendly Gradio interface with T4 optimizations
    """
    generator = T4OptimizedStorylineGenerator()

    def storyline_wrapper(topic, plot_points, scene_prompts):
        # Split comma-separated input strings into lists
        plot_points_list = [p.strip() for p in plot_points.split(',')]
        scene_prompts_list = [p.strip() for p in scene_prompts.split(',')]

        return generator.generate_visual_storyline(
            topic,
            plot_points_list,
            scene_prompts_list
        )

    interface = gr.Interface(
        fn=storyline_wrapper,
        inputs=[
            gr.Textbox(label="Story Topic", placeholder="A magical adventure"),
            gr.Textbox(label="Plot Points", placeholder="Enter forest, Meet creatures, Find treasure"),
            gr.Textbox(label="Scene Prompts", placeholder="Misty enchanted forest, Magical creatures gathering, Hidden treasure cave")
        ],
        outputs=[
            gr.Textbox(label="Generated Story"),
            gr.Gallery(label="Scene Images")
        ],
        title="🌈 T4-Optimized AI Storyline Generator",
        description="Create magical stories with AI-powered text and image generation"
    )

    return interface


def main():
    """Main execution with error handling"""
    try:
        # Launch Gradio interface
        interface = create_gradio_interface()
        interface.launch(debug=True)
    except Exception as e:
        logger.critical(f"Critical Failure: {e}")
        logger.critical(traceback.format_exc())


if __name__ == "__main__":
    main()


# T4 GPU Optimization Notes
"""
Optimization Strategies:
1. Use distilgpt2 (lighter model)
2. Reduced inference steps (25 vs 50)
3. float16 precision when using GPU
4. Attention slicing when using GPU
5. Smaller image resolutions
"""
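
# Programmatic usage (a minimal sketch, not part of the Gradio app above):
# the generator class can also be driven directly, e.g. from a notebook or
# another script. Assumptions: this file is importable as a module, and the
# output filenames below are illustrative only.
#
#   generator = T4OptimizedStorylineGenerator()
#   story, images = generator.generate_visual_storyline(
#       "a magical adventure",
#       ["Enter forest", "Meet creatures", "Find treasure"],
#       ["Misty enchanted forest", "Magical creatures gathering", "Hidden treasure cave"],
#   )
#   print(story)
#   for i, img in enumerate(images):
#       img.save(f"scene_{i}.png")  # PIL Image.save; writes one PNG per scene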