File size: 3,051 Bytes
1673a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import torch
from PIL import Image

from model import BlipBaseModel, GitBaseCocoModel

# Registry of the selectable captioning models: display name -> model class
# (both imported from the local `model` module). The selected class is called
# with a device string ("cuda" or "cpu") when captions are generated.
# NOTE: entry order matters -- the interface picks its default model by
# position (index 1, i.e. "Blip Base").
MODELS = {
	"Git-Base-COCO": GitBaseCocoModel,
	"Blip Base": BlipBaseModel,
}

# examples = [["Image1.png"], ["Image2.png"], ["Image3.png"]]

# Model wrappers that have already been constructed, keyed by (model name, device).
# Building a wrapper is presumably expensive (pretrained weights -- TODO confirm
# against model.py), so each one is created once and reused by later requests.
_MODEL_CACHE = {}


def _get_model(model_name, device):
	"""Return the model wrapper for `model_name` on `device`, building it only once."""
	cache_key = (model_name, device)
	if cache_key not in _MODEL_CACHE:
		# An unknown model name still raises KeyError here, exactly as before.
		_MODEL_CACHE[cache_key] = MODELS[model_name](device)
	return _MODEL_CACHE[cache_key]


def generate_captions(
	image,
	num_captions: int,
	model_name: str,
	max_length: int,
	temperature: float,
	top_k: int,
	top_p: float,
	repetition_penalty: float,
	diversity_penalty: float,
	) -> str:
	"""
	Generates captions for the given image.

	-----
	Parameters:
	image: PIL.Image
		The image to generate captions for. Must not be None.
	num_captions: int
		The number of captions to generate.
	model_name: str
		Key into MODELS selecting which model to use.
	** Rest of the parameters are the same as in the model.generate method. **
	-----
	Returns:
	str
		All generated captions joined into one string, one caption per line.
	-----
	Raises:
	ValueError
		If no image was supplied.
	KeyError
		If model_name is not a key of MODELS.
	"""
	# Fail early with a clear message instead of letting the model choke on None
	# (e.g. when the form is submitted without an uploaded image).
	if image is None:
		raise ValueError("No image provided. Please upload an image before generating captions.")

	# Convert the numerical values to their corresponding types.
	# Gradio Slider returns floats, except for whole numbers, which arrive as ints.
	temperature = float(temperature)
	top_p = float(top_p)
	repetition_penalty = float(repetition_penalty)
	diversity_penalty = float(diversity_penalty)
	# Integer-valued settings are normalised too, so a whole-number float such
	# as 3.0 never reaches the model as a count.
	num_captions = int(num_captions)
	max_length = int(max_length)
	top_k = int(top_k)

	device = "cuda" if torch.cuda.is_available() else "cpu"

	# Reuse a previously built model rather than constructing one per request.
	model = _get_model(model_name, device)

	captions = model.generate(
		image=image,
		max_length=max_length,
		num_captions=num_captions,
		temperature=temperature,
		top_k=top_k,
		top_p=top_p,
		repetition_penalty=repetition_penalty,
		diversity_penalty=diversity_penalty,
	)

	# The output Textbox shows a single string: one caption per line.
	return "\n".join(captions)

title = "AI tool for generating captions for images"
description = "This tool uses pretrained models to generate captions for images."

# Materialise the model names once as a plain list: the Dropdown gets a real
# list (not a dict view) and the default is computed in one place.
_model_names = list(MODELS)
# Default to Blip Base (second entry of MODELS).
_default_model = _model_names[1]

# UI for generate_captions: the inputs are listed in the same order as the
# function's parameters, and the single Textbox shows the returned string.
interface = gr.Interface(
	fn=generate_captions,
	inputs=[
		gr.components.Image(type="pil", label="Image"),
		gr.components.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
		gr.components.Dropdown(_model_names, label="Model", value=_default_model),
		gr.components.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
		gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Temperature"),
		gr.components.Slider(minimum=1, maximum=100, step=1, value=50, label="Top K"),
		# Top-p is a cumulative probability, so values above 1.0 are not valid.
		gr.components.Slider(minimum=0.1, maximum=1.0, step=0.1, value=1.0, label="Top P"),
		gr.components.Slider(minimum=1.0, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"),
		gr.components.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Diversity Penalty"),
	],
	outputs=[
		gr.components.Textbox(label="Caption"),
	],
	# Set image examples to be displayed in the interface.
	# Every example uses the same settings as the slider defaults above.
	examples=[
		[example_image, 1, _default_model, 50, 1.0, 50, 1.0, 2.0, 2.0]
		for example_image in ("Image1.png", "Image2.png", "Image3.png")
	],
	title=title,
	description=description,
	allow_flagging="never",
)


if __name__ == "__main__":
	# Only start the web UI when this file is executed as a script.
	# Launch options: `enable_queue` and `debug` are both switched on.
	launch_options = {"enable_queue": True, "debug": True}
	interface.launch(**launch_options)