Anupam251272 commited on
Commit
f95f251
β€’
1 Parent(s): f747521

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -47
app.py CHANGED
@@ -1,73 +1,172 @@
1
  import torch
2
  import gradio as gr
3
- from PIL import Image, ImageDraw
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from diffusers import StableDiffusionPipeline
 
 
 
 
 
6
 
7
  class T4OptimizedStorylineGenerator:
8
  def __init__(self):
9
- # Model Selection
10
- text_model_name = "distilgpt2"
11
- image_model_name = "runwayml/stable-diffusion-v1-5"
12
-
13
- # Device Configuration
14
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
- self.dtype = torch.float16 if self.device == "cuda" else torch.float32
16
-
17
- # Load Text Generation Model
18
- self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
19
- self.text_model = AutoModelForCausalLM.from_pretrained(
20
- text_model_name,
21
- torch_dtype=self.dtype,
22
- device_map='auto'
23
- ).to(self.device)
24
-
25
- # Load Image Generation Pipeline
26
- self.image_pipeline = StableDiffusionPipeline.from_pretrained(
27
- image_model_name,
28
- torch_dtype=self.dtype
29
- )
30
- if self.device == "cuda":
31
- self.image_pipeline.enable_attention_slicing()
32
- self.image_pipeline = self.image_pipeline.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def generate_story_text(self, topic, plot_points, max_length=200):
35
- prompt = f"Write a short story about {topic}. Scenes: {', '.join(plot_points)}"
36
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=100).to(self.device)
37
- outputs = self.text_model.generate(
38
- inputs.input_ids,
39
- max_length=max_length,
40
- num_return_sequences=1,
41
- temperature=0.7,
42
- top_k=50,
43
- top_p=0.95
44
- )
45
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def generate_scene_images(self, scene_prompts, num_images=3):
 
 
 
48
  scene_images = []
49
- for prompt in scene_prompts[:num_images]:
50
- image = self.image_pipeline(prompt, num_inference_steps=25, guidance_scale=6.0, height=512, width=512).images[0]
51
- scene_images.append(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return scene_images
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def generate_visual_storyline(self, topic, plot_points, scene_prompts):
55
- story_text = self.generate_story_text(topic, plot_points)
56
- scene_images = self.generate_scene_images(scene_prompts)
57
- return story_text, scene_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def create_gradio_interface():
 
 
 
60
  generator = T4OptimizedStorylineGenerator()
61
 
62
  def storyline_wrapper(topic, plot_points, scene_prompts):
 
63
  plot_points_list = [p.strip() for p in plot_points.split(',')]
64
  scene_prompts_list = [p.strip() for p in scene_prompts.split(',')]
65
- return generator.generate_visual_storyline(topic, plot_points_list, scene_prompts_list)
 
 
 
 
 
66
 
67
  interface = gr.Interface(
68
  fn=storyline_wrapper,
69
  inputs=[
70
- gr.Textbox(label="Story Topic"),
71
  gr.Textbox(label="Plot Points", placeholder="Enter forest, Meet creatures, Find treasure"),
72
  gr.Textbox(label="Scene Prompts", placeholder="Misty enchanted forest, Magical creatures gathering, Hidden treasure cave")
73
  ],
@@ -75,11 +174,32 @@ def create_gradio_interface():
75
  gr.Textbox(label="Generated Story"),
76
  gr.Gallery(label="Scene Images")
77
  ],
78
- title="T4-Optimized AI Storyline Generator"
 
79
  )
80
 
81
  return interface
82
 
 
 
 
 
 
 
 
 
 
 
 
83
  if __name__ == "__main__":
84
- interface = create_gradio_interface()
85
- interface.launch()
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import gradio as gr
3
+ import logging
4
+ import traceback
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from diffusers import StableDiffusionPipeline
7
+ from PIL import Image, ImageDraw
8
+
9
+ # Logging Configuration
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s')
11
+ logger = logging.getLogger(__name__)
12
 
13
  class T4OptimizedStorylineGenerator:
14
  def __init__(self):
15
+ """
16
+ Optimized initialization for both CPU and GPU environments
17
+ """
18
+ try:
19
+ # Model Selection
20
+ text_model_name = "distilgpt2" # Lighter GPT-2 model
21
+ image_model_name = "runwayml/stable-diffusion-v1-5" # Stable Diffusion model
22
+
23
+ # Device Configuration
24
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ self.dtype = torch.float16 if self.device == "cuda" else torch.float32
26
+
27
+ logger.info(f"πŸ–₯️ Device: {self.device}")
28
+ logger.info(f"πŸ“Š Precision: {self.dtype}")
29
+
30
+ # Text Generation Model
31
+ logger.info("πŸ”„ Loading Text Model")
32
+ self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
33
+ self.text_model = AutoModelForCausalLM.from_pretrained(
34
+ text_model_name,
35
+ torch_dtype=self.dtype
36
+ ).to(self.device) # Removed device_map to avoid accelerate dependency for CPU
37
+
38
+ # Image Generation Pipeline
39
+ logger.info("πŸ–ŒοΈ Loading Image Generation Pipeline")
40
+ self.image_pipeline = StableDiffusionPipeline.from_pretrained(
41
+ image_model_name,
42
+ torch_dtype=self.dtype
43
+ )
44
+
45
+ if self.device == "cuda":
46
+ self.image_pipeline.enable_attention_slicing() # Memory optimization for GPU
47
+ self.image_pipeline = self.image_pipeline.to(self.device)
48
+
49
+ except Exception as e:
50
+ logger.error(f"Initialization Error: {e}")
51
+ logger.error(traceback.format_exc())
52
+ raise
53
 
54
  def generate_story_text(self, topic, plot_points, max_length=200):
55
+ """
56
+ Optimized text generation for T4 with reduced complexity
57
+ """
58
+ try:
59
+ prompt = f"Write a short story about {topic}. Scenes: {', '.join(plot_points)}"
60
+
61
+ # Tokenization with memory efficiency
62
+ inputs = self.tokenizer(
63
+ prompt,
64
+ return_tensors="pt",
65
+ truncation=True,
66
+ max_length=100
67
+ ).to(self.device)
68
+
69
+ # Generation with reduced complexity
70
+ with torch.no_grad():
71
+ outputs = self.text_model.generate(
72
+ inputs.input_ids,
73
+ max_length=max_length,
74
+ num_return_sequences=1,
75
+ temperature=0.7,
76
+ top_k=50,
77
+ top_p=0.95
78
+ )
79
+
80
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
81
+
82
+ except Exception as e:
83
+ logger.error(f"Text Generation Error: {e}")
84
+ return f"Story generation failed: {e}"
85
 
86
  def generate_scene_images(self, scene_prompts, num_images=3):
87
+ """
88
+ T4-Optimized image generation with reduced steps and memory usage
89
+ """
90
  scene_images = []
91
+
92
+ try:
93
+ for prompt in scene_prompts[:num_images]:
94
+ with torch.inference_mode():
95
+ # Reduced inference steps for T4
96
+ image = self.image_pipeline(
97
+ prompt,
98
+ num_inference_steps=25, # Reduced from standard 50
99
+ guidance_scale=6.0,
100
+ height=512, # Standard resolution
101
+ width=512
102
+ ).images[0]
103
+
104
+ scene_images.append(image)
105
+
106
+ except Exception as e:
107
+ logger.error(f"Image Generation Error: {e}")
108
+ # Fallback error image generation
109
+ scene_images.append(self._create_error_image())
110
+
111
  return scene_images
112
 
113
+ def _create_error_image(self, size=(600, 400)):
114
+ """
115
+ Create a standardized error visualization image
116
+ """
117
+ img = Image.new('RGB', size, color=(200, 50, 50))
118
+ draw = ImageDraw.Draw(img)
119
+
120
+ # Simple error message
121
+ draw.text(
122
+ (50, 180),
123
+ "Image Generation Failed",
124
+ fill=(255, 255, 255)
125
+ )
126
+
127
+ return img
128
+
129
  def generate_visual_storyline(self, topic, plot_points, scene_prompts):
130
+ """
131
+ Comprehensive storyline generation with T4 optimization
132
+ """
133
+ try:
134
+ # Generate story text
135
+ story_text = self.generate_story_text(topic, plot_points)
136
+
137
+ # Generate scene images with T4 constraints
138
+ scene_images = self.generate_scene_images(scene_prompts)
139
+
140
+ return story_text, scene_images
141
+
142
+ except Exception as e:
143
+ logger.error(f"Visual Storyline Generation Error: {e}")
144
+ error_text = f"Storyline generation failed: {e}"
145
+ error_image = self._create_error_image()
146
+
147
+ return error_text, [error_image]
148
 
149
  def create_gradio_interface():
150
+ """
151
+ Create a user-friendly Gradio interface with T4 optimizations
152
+ """
153
  generator = T4OptimizedStorylineGenerator()
154
 
155
  def storyline_wrapper(topic, plot_points, scene_prompts):
156
+ # Split input strings
157
  plot_points_list = [p.strip() for p in plot_points.split(',')]
158
  scene_prompts_list = [p.strip() for p in scene_prompts.split(',')]
159
+
160
+ return generator.generate_visual_storyline(
161
+ topic,
162
+ plot_points_list,
163
+ scene_prompts_list
164
+ )
165
 
166
  interface = gr.Interface(
167
  fn=storyline_wrapper,
168
  inputs=[
169
+ gr.Textbox(label="Story Topic", placeholder="A magical adventure"),
170
  gr.Textbox(label="Plot Points", placeholder="Enter forest, Meet creatures, Find treasure"),
171
  gr.Textbox(label="Scene Prompts", placeholder="Misty enchanted forest, Magical creatures gathering, Hidden treasure cave")
172
  ],
 
174
  gr.Textbox(label="Generated Story"),
175
  gr.Gallery(label="Scene Images")
176
  ],
177
+ title="🌈 T4-Optimized AI Storyline Generator",
178
+ description="Create magical stories with AI-powered text and image generation"
179
  )
180
 
181
  return interface
182
 
183
+ def main():
184
+ """Main execution with error handling"""
185
+ try:
186
+ # Launch Gradio interface
187
+ interface = create_gradio_interface()
188
+ interface.launch(debug=True)
189
+
190
+ except Exception as e:
191
+ logger.critical(f"Critical Failure: {e}")
192
+ logger.critical(traceback.format_exc())
193
+
194
  if __name__ == "__main__":
195
+ main()
196
+
197
+ # T4 GPU Optimization Notes
198
+ """
199
+ Optimization Strategies:
200
+ 1. Use distilgpt2 (lighter model)
201
+ 2. Reduced inference steps (25 vs 50)
202
+ 3. float16 precision when using GPU
203
+ 4. Attention slicing when using GPU
204
+ 5. Smaller image resolutions
205
+ """