File size: 1,565 Bytes
b27043c
 
 
 
 
8eabd54
b27043c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import torch

# Hugging Face Hub repository that provides both the fine-tuned weights and
# the matching processor.
MODEL_ID = "gokaygokay/sd3-long-captioner"

# Loaded once at import time and shared by every request: weights are moved
# to the GPU and put in eval mode for inference.
model = (
    PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)
    .to("cuda")
    .eval()
)
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)

@spaces.GPU
def create_captions_rich(image):
    """Generate a long-form caption for ``image`` with the PaliGemma captioner.

    Args:
        image: The picture delivered by the ``gr.Image`` input component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The decoded caption text: only the newly generated tokens, with
        special tokens stripped.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a processor traceback.
    """
    # An empty image component yields None; fail early with a user-facing
    # message rather than letting the processor raise an opaque error.
    if image is None:
        raise gr.Error("Please upload an image to caption.")

    # PaliGemma-style task prefix used as the text prompt.
    prompt = "caption en"
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
    # generate() returns the prompt tokens followed by the new ones, so keep
    # the prompt length to slice off everything but the continuation.
    input_len = model_inputs["input_ids"].shape[-1]

    # Greedy decoding (do_sample=False) keeps captions deterministic.
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)

    generated_tokens = generation[0][input_len:]
    return processor.decode(generated_tokens, skip_special_tokens=True)

# Stylesheet passed to gr.Blocks(css=...): the element with id "mkd" is
# limited to 500px tall, scrolls on overflow, and gets a light grey border.
# NOTE(review): no component in the visible code sets elem_id="mkd", so this
# rule may currently be unused; confirm before relying on it.
css = (
    "\n"
    "  #mkd {\n"
    "    height: 500px; \n"
    "    overflow: auto; \n"
    "    border: 1px solid #ccc; \n"
    "  }\n"
)

with gr.Blocks(css=css) as demo:
    # Page title. The tags must be closed (</center></h1>), not reopened,
    # so the heading renders as one well-formed element.
    gr.HTML("<h1><center>PaliGemma Fine-tuned for Long Captioning for Stable Diffusion 3.</center></h1>")
    with gr.Tab(label="PaliGemma Rich Captions"):
        with gr.Row():
            # Left column: image upload and the button that triggers captioning.
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
            # Right side of the row: the generated caption.
            output = gr.Text(label="Caption")

        # Sample images from the repo's assets/ folder, wired to the same
        # captioning function and output box as the Submit button.
        gr.Examples(
            [["assets/image1.png"], ["assets/image2.PNG"], ["assets/image3.jpg"]],
            inputs=[input_img],
            outputs=[output],
            fn=create_captions_rich,
            label="Try captioning on examples",
        )

        submit_btn.click(create_captions_rich, [input_img], [output])

demo.launch(debug=True)