Gurgen-Blbulyan committed
Commit 08cc25a
1 Parent(s): 769849b

adding files for app

Files changed (3):
  1. app.py +19 -0
  2. inference.py +29 -0
  3. utils.py +42 -0
app.py ADDED
@@ -0,0 +1,19 @@
+
+import gradio as gr
+from inference import Inference
+
+
+encoder_model_name = 'google/vit-large-patch32-224-in21k'
+decoder_model_name = 'gpt2-large'
+inference = Inference(
+    decoder_model_name=decoder_model_name,
+)
+
+
+def generate_text(video):
+    generated_text = inference.generate_text(video, encoder_model_name)
+
+    return generated_text
+
+app = gr.Interface(fn=generate_text, inputs='video', outputs='text')
+app.launch(share=True)
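
The video input passes the uploaded clip to generate_text as a file path, which Inference.generate_text then converts through utils.video2image_from_path. A minimal sketch of the same call without the Gradio UI, assuming a hypothetical local clip sample.mp4 (the Inference class is used directly so that app.py's module-level launch is not triggered):

from inference import Inference

inference = Inference(decoder_model_name='gpt2-large')
summary = inference.generate_text('sample.mp4', 'google/vit-large-patch32-224-in21k')  # placeholder path
print(summary)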
inference.py ADDED
@@ -0,0 +1,29 @@
+import torch
+from transformers import AutoTokenizer, VisionEncoderDecoderModel
+
+import utils
+
+class Inference:
+    def __init__(self, decoder_model_name, max_length=32):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
+        self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained('armgabrielyan/video-summarization')
+        self.encoder_decoder_model.to(self.device)
+
+        self.max_length = max_length
+
+    def generate_text(self, video, encoder_model_name):
+        if isinstance(video, str):
+            pixel_values = utils.video2image_from_path(video, encoder_model_name)
+        else:
+            pixel_values = video
+
+        if not self.tokenizer.pad_token:
+            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+            self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))
+
+        generated_ids = self.encoder_decoder_model.generate(pixel_values.unsqueeze(0).to(self.device), max_length=self.max_length)
+        generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return generated_text
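
Besides a path, generate_text also accepts an already-extracted pixel tensor (the isinstance check only converts strings), so the frame extraction in utils can be run once and the tensor reused. A minimal sketch under that assumption, with sample.mp4 as a hypothetical local clip:

import utils
from inference import Inference

inference = Inference(decoder_model_name='gpt2-large', max_length=32)

# Preprocess once (sample.mp4 is a placeholder path), then reuse the tensor for generation.
pixel_values = utils.video2image_from_path('sample.mp4', 'google/vit-large-patch32-224-in21k')
print(inference.generate_text(pixel_values, 'google/vit-large-patch32-224-in21k'))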
utils.py ADDED
@@ -0,0 +1,42 @@
+from transformers import ViTFeatureExtractor
+import torchvision
+import torchvision.transforms.functional as fn
+import torch as th
+
+
+def video2image_from_path(video_path, feature_extractor_name):
+    video = torchvision.io.read_video(video_path)
+
+    return video2image(video[0], feature_extractor_name)
+
+
+def video2image(video, feature_extractor_name):
+    feature_extractor = ViTFeatureExtractor.from_pretrained(
+        feature_extractor_name
+    )
+
+    vid = th.permute(video, (3, 0, 1, 2))
+    samp = th.linspace(0, vid.shape[1] - 1, 49, dtype=th.long)
+    vid = vid[:, samp, :, :]
+
+    im_l = list()
+    for i in range(vid.shape[1]):
+        im_l.append(vid[:, i, :, :])
+
+    inputs = feature_extractor(im_l, return_tensors="pt")
+
+    inputs = inputs['pixel_values']
+
+    im_h = list()
+    for i in range(7):
+        im_v = th.cat((inputs[0 + i * 7, :, :, :],
+                       inputs[1 + i * 7, :, :, :],
+                       inputs[2 + i * 7, :, :, :],
+                       inputs[3 + i * 7, :, :, :],
+                       inputs[4 + i * 7, :, :, :],
+                       inputs[5 + i * 7, :, :, :],
+                       inputs[6 + i * 7, :, :, :]), 2)
+        im_h.append(im_v)
+    resize = fn.resize(th.cat(im_h, 1), size=[224])
+
+    return resize
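
For orientation: video2image samples 49 frames uniformly across the clip, preprocesses each with the ViT feature extractor, tiles the 49 frames into a 7x7 grid, and resizes the (square) grid down to 224x224, so the whole video is summarized as a single encoder-sized image. A quick shape check, assuming a hypothetical local clip sample.mp4:

import utils

pixel_values = utils.video2image_from_path('sample.mp4', 'google/vit-large-patch32-224-in21k')
print(pixel_values.shape)  # expected: torch.Size([3, 224, 224])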