gokaygokay committed on
Commit
b27043c
1 Parent(s): 287902c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
3
+ import spaces
4
+ import torch
5
+
6
# Hub repository that holds both the fine-tuned weights and the processor config.
MODEL_ID = "gokaygokay/sd3-long-captioner"

# Load the captioning model once at import time and switch it to eval mode.
# `from_pretrained` does not take a `device` keyword (an unknown keyword is
# forwarded to the model constructor and rejected), so the model is moved to
# the GPU explicitly with `.to("cuda")` after loading.
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID).to("cuda").eval()

# Processor that turns the text prompt and image into model inputs and
# decodes generated token ids back into text.
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
8
+
9
@spaces.GPU
def create_captions_rich(image):
    """Generate an English caption for ``image`` with the module-level model.

    The image is paired with the fixed prompt ``"caption en"``, run through
    ``processor`` and decoded greedily (``do_sample=False``) for at most 256
    new tokens. Only the newly generated tokens are decoded; the prompt
    tokens at the start of the sequence are sliced off.

    Args:
        image: The picture to caption, as delivered by the Gradio image
            input (presumably a NumPy array, the ``gr.Image`` default;
            verify against the UI definition).

    Returns:
        The decoded caption string with special tokens removed.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None when the user submits without uploading a picture;
    # fail with a readable message instead of an opaque processor error.
    if image is None:
        raise gr.Error("Please provide an input picture.")

    prompt = "caption en"
    # Place the inputs on whatever device the model lives on rather than
    # hard-coding a device string here.
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    # Length of the prompt portion, used below to drop it from the output.
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)

    # Keep only the tokens produced after the prompt for the first (only) sample.
    new_tokens = generation[0][input_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)
20
+
21
# Custom stylesheet handed to gr.Blocks. It styles an element with id "mkd";
# NOTE(review): no component in this section sets elem_id="mkd" - confirm
# whether the rule is still needed.
css = """
#mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

# Build the single-tab captioning UI: image input and submit button on one
# side, caption text box on the other, followed by clickable examples.
with gr.Blocks(css=css) as demo:
    # Page heading; the closing tags must be </center></h1> for valid HTML.
    gr.HTML("<h1><center>PaliGemma Fine-tuned for Long Captioning for Stable Diffusion 3.</center></h1>")
    with gr.Tab(label="PaliGemma Rich Captions"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Caption")

        # Example images shipped in the repository's assets/ directory.
        gr.Examples(
            [["assets/image1.png"], ["assets/image2.PNG"], ["assets/image3.jpg"]],
            inputs=[input_img],
            outputs=[output],
            fn=create_captions_rich,
            label='Try captioning on examples',
        )

        # Run the captioner on the uploaded image when the button is clicked.
        submit_btn.click(create_captions_rich, [input_img], [output])

demo.launch(debug=True)