update
app.py CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm, trange
 import skimage.io as io
 import PIL.Image
 import gradio as gr
+
+
 N = type(None)
 V = np.array
 ARRAY = np.ndarray
@@ -228,47 +230,47 @@ clip_model, preprocess = clip.load("ViT-B/16", device=device, jit=False)
 from transformers import AutoTokenizer
 tokenizer = AutoTokenizer.from_pretrained("imthanhlv/gpt2news")
 
-def inference(img, text, is_translate):
-    prefix_length = 10
-    model = ClipCaptionModel(prefix_length)
-    model_path = 'sat_019.pt'
-    model.load_state_dict(torch.load(model_path, map_location=CPU))
-    model = model.eval()
-    device = CUDA(0) if is_gpu else "cpu"
-    model = model.to(device)
-    use_beam_search = True
-    if is_translate:
-        # encode text
-        if text is None:
-            return "No text provided"
-        text = clip.tokenize([text]).to(device)
-        with torch.no_grad():
-            prefix = clip_model.encode_text(text).to(device, dtype=torch.float32)
-            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
-            generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+def inference(img, text, is_translation):
+    prefix_length = 10
+    model = ClipCaptionModel(prefix_length)
+    model_path = 'sat_019.pt'
+    model.load_state_dict(torch.load(model_path, map_location=CPU))
+    model = model.eval()
+    device = CUDA(0) if is_gpu else "cpu"
+    model = model.to(device)
+    if is_translation:
+        # encode text
+        if text is None:
+            return "No text provided"
+        text = clip.tokenize([text]).to(device)
+        with torch.no_grad():
+            prefix = clip_model.encode_text(text).to(device, dtype=torch.float32)
+            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+            generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+
+    else:
+        if img is None:
+            return "No image"
+        image = io.imread(img.name)
+        pil_image = PIL.Image.fromarray(image)
+        image = preprocess(pil_image).unsqueeze(0).to(device)
+
+        with torch.no_grad():
+            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+            generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed, prompt="Một bức ảnh về")[0]
+
+    return generated_text_prefix
 
 title = "CLIP Dual encoder"
-description = "You can translate English
+description = "You can translate English to Vietnamese or generate Vietnamese caption from image"
 examples=[["drug.jpg","", False], ["", "What is your name?", True]]
 
 inputs = [
-
-
-
+    gr.inputs.Image(type="file", label="Image to generate Vietnamese caption", optional=True),
+    gr.inputs.Textbox(lines=2, placeholder="English sentence for translation"),
+    gr.inputs.Checkbox()
 ]
 
 gr.Interface(
|