Hritik commited on
Commit
3a496ae
1 Parent(s): 7862e49

edit code for nle inference

Browse files
Files changed (3) hide show
  1. app.py +55 -5
  2. data_utils/xgpt3_dataset.py +7 -12
  3. entailment_inference.py +1 -72
app.py CHANGED
@@ -1,13 +1,63 @@
1
- import gradio as gr
 
 
2
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  print(f"Is CUDA available: {torch.cuda.is_available()}")
4
  # True
5
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
6
  # Tesla T4
7
 
8
- # def greet(name):
9
- # return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
12
- # iface.launch()
 
 
13
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import json
4
  import torch
5
+ import argparse
6
+ import pandas as pd
7
+ import torch.nn as nn
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+ from transformers.models.llama.tokenization_llama import LlamaTokenizer
11
+ from torch.utils.data import DataLoader
12
+ from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
13
+ from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
14
+ from peft import LoraConfig, get_peft_model
15
+ from data_utils.xgpt3_dataset import MultiModalDataset
16
+ from utils import batchify
17
+
18
+ import gradio as gr
19
+ from entailment_inference import get_scores
20
+
21
  print(f"Is CUDA available: {torch.cuda.is_available()}")
22
  # True
23
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
24
  # Tesla T4
25
 
26
+ tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
27
+ image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
28
+ processor = MplugOwlProcessor(image_processor, tokenizer)
29
+
30
+
31
+ # Instantiate model
32
+ model = MplugOwlForConditionalGeneration.from_pretrained(
33
+ pretrained_ckpt,
34
+ torch_dtype=torch.bfloat16,
35
+ device_map={'':0}
36
+ )
37
+
38
+ for name, param in model.named_parameters():
39
+ param.requires_grad = False
40
+ peft_config = LoraConfig(
41
+ target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
42
+ inference_mode=True,
43
+ r=32,
44
+ lora_alpha=16,
45
+ lora_dropout=0.05
46
+ )
47
+ model = get_peft_model(model, peft_config)
48
+ model.print_trainable_parameters()
49
+ with open(trained_ckpt, 'rb') as f:
50
+ ckpt = torch.load(f, map_location = torch.device(f"cuda:0"))
51
+ model.load_state_dict(ckpt)
52
+ model = model.to(torch.bfloat16)
53
+ print('Model Loaded')
54
 
55
+ PROMPT = """The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
56
+ Human: <|video|>
57
+ Human: Does this video entail the description: ""A basketball team walking off the field while the audience claps.""?
58
+ AI: """
59
 
60
+ valid_data = MultiModalDataset("examples/y5xuvHpDPZQ_000005_000015.mp4", PROMPT, tokenizer, processor, max_length = 256, loss_objective = 'sequential')
61
+ dataloader = DataLoader(valid_data, pin_memory=True, collate_fn=batchify)
62
+ score = get_scores(model, tokenizer, dataloader)
63
+ print(score)
data_utils/xgpt3_dataset.py CHANGED
@@ -36,37 +36,32 @@ def load_jsonl(filename):
36
  class MultiModalDataset(Dataset):
37
  """MultiModal dataset"""
38
 
39
- def __init__(self, input_file, tokenizer, processor,
40
  max_length=2048,
41
  media_tokens=['<image>', '<|video|>'], loss_objective = 'sequential'):
42
 
43
  args = get_args()
44
 
45
  self.loss_objective = loss_objective
46
- if 'sequential' in self.loss_objective:
47
- self.dataset = pd.read_csv(input_file)
48
- self.dataset = self.dataset.dropna()
49
- else:
50
  raise NotImplementedError('dataset loader not implemented for other loss objectives')
51
 
52
- self.dataset = pd.read_csv(input_file)
 
53
  self.tokenizer = tokenizer
54
  self.max_length = max_length
55
  self.processor = processor
56
  self.media_tokens = {k: -int(i+1) for i, k in enumerate(media_tokens)}
57
  self.media_lengths = {'<image>': 1+64,'<|video|>': 1+64}
58
  print("num_media_token: ", self.media_lengths)
59
- print(len(self.dataset))
60
  self.bucket = {}
61
 
62
  def __len__(self):
63
- return len(self.dataset)
64
 
65
  def __getitem__(self, index):
66
-
67
- data = self.dataset.iloc[index]
68
- videopath = data['videopath']
69
- caption = data['caption']
70
  video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt') # video_pixel_values
71
  text_input = self._extract_text_token_from_conversation(caption, self.max_length, index)
72
  item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
 
36
  class MultiModalDataset(Dataset):
37
  """MultiModal dataset"""
38
 
39
+ def __init__(self, videopath, text, tokenizer, processor,
40
  max_length=2048,
41
  media_tokens=['<image>', '<|video|>'], loss_objective = 'sequential'):
42
 
43
  args = get_args()
44
 
45
  self.loss_objective = loss_objective
46
+ if 'sequential' not in self.loss_objective:
 
 
 
47
  raise NotImplementedError('dataset loader not implemented for other loss objectives')
48
 
49
+ self.videopath = videopath
50
+ self.text = text
51
  self.tokenizer = tokenizer
52
  self.max_length = max_length
53
  self.processor = processor
54
  self.media_tokens = {k: -int(i+1) for i, k in enumerate(media_tokens)}
55
  self.media_lengths = {'<image>': 1+64,'<|video|>': 1+64}
56
  print("num_media_token: ", self.media_lengths)
 
57
  self.bucket = {}
58
 
59
  def __len__(self):
60
+ return 1
61
 
62
  def __getitem__(self, index):
63
+ videopath = self.videopath
64
+ caption = self.text
 
 
65
  video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt') # video_pixel_values
66
  text_input = self._extract_text_token_from_conversation(caption, self.max_length, index)
67
  item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
entailment_inference.py CHANGED
@@ -15,18 +15,7 @@ from peft import LoraConfig, get_peft_model
15
  from data_utils.xgpt3_dataset import MultiModalDataset
16
  from utils import batchify
17
 
18
- parser = argparse.ArgumentParser()
19
 
20
- parser.add_argument('--input_csv', type = str, required = True, help = 'input json file')
21
- parser.add_argument('--output_csv', type = str, help = 'output csv with scores')
22
- parser.add_argument('--pretrained_ckpt', type = str, required = True, help = 'pretrained ckpt')
23
- parser.add_argument('--trained_ckpt', type = str, help = 'trained ckpt')
24
- parser.add_argument('--lora_r', type = int, default = 32)
25
- parser.add_argument('--use_lora', action = 'store_true', help = 'lora model')
26
- parser.add_argument('--all-params', action = 'store_true', help = 'use all params of the model')
27
- parser.add_argument('--batch_size', type = int, default = 32)
28
-
29
- args = parser.parse_args()
30
  softmax = nn.Softmax(dim=2)
31
 
32
  def get_entail(logits, input_ids, tokenizer):
@@ -47,7 +36,6 @@ def get_entail(logits, input_ids, tokenizer):
47
  return entailment
48
 
49
  def get_scores(model, tokenizer, dataloader):
50
-
51
  with torch.no_grad():
52
  for index, inputs in tqdm(enumerate(dataloader)):
53
  for k, v in inputs.items():
@@ -60,63 +48,4 @@ def get_scores(model, tokenizer, dataloader):
60
  non_media_mask = inputs['non_media_mask'], prompt_mask = inputs['prompt_mask'])
61
  logits = outputs['logits']
62
  entail_scores = get_entail(logits, inputs['input_ids'], tokenizer)
63
- for m in range(len(entail_scores)):
64
- with open(args.output_csv, 'a') as f:
65
- writer = csv.writer(f)
66
- writer.writerow([inputs['videopaths'][m], inputs['captions'][m], entail_scores[m].item()])
67
- print(f"Batch {index} Done")
68
-
69
- def main():
70
-
71
- pretrained_ckpt = args.pretrained_ckpt
72
-
73
- # Processors
74
- tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
75
- image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
76
- processor = MplugOwlProcessor(image_processor, tokenizer)
77
-
78
- valid_data = MultiModalDataset(args.input_csv, tokenizer, processor, max_length = 256, loss_objective = 'sequential')
79
- dataloader = DataLoader(valid_data, batch_size=args.batch_size, pin_memory=True, collate_fn=batchify)
80
-
81
- # Instantiate model
82
- model = MplugOwlForConditionalGeneration.from_pretrained(
83
- pretrained_ckpt,
84
- torch_dtype=torch.bfloat16,
85
- device_map={'':0}
86
- )
87
-
88
- if args.use_lora:
89
- for name, param in model.named_parameters():
90
- param.requires_grad = False
91
- if args.all_params:
92
- peft_config = LoraConfig(
93
- target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
94
- inference_mode=True,
95
- r=args.lora_r,
96
- lora_alpha=16,
97
- lora_dropout=0.05
98
- )
99
- else:
100
- peft_config = LoraConfig(
101
- target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)',
102
- inference_mode=True,
103
- r=args.lora_r,
104
- lora_alpha=16,
105
- lora_dropout=0.05
106
- )
107
-
108
- model = get_peft_model(model, peft_config)
109
- model.print_trainable_parameters()
110
-
111
- with open(args.trained_ckpt, 'rb') as f:
112
- ckpt = torch.load(f, map_location = torch.device(f"cuda:0"))
113
- model.load_state_dict(ckpt)
114
- model = model.to(torch.bfloat16)
115
- print('Model Loaded')
116
-
117
- model.eval()
118
-
119
- get_scores(model, tokenizer, dataloader)
120
-
121
- if __name__ == "__main__":
122
- main()
 
15
  from data_utils.xgpt3_dataset import MultiModalDataset
16
  from utils import batchify
17
 
 
18
 
 
 
 
 
 
 
 
 
 
 
19
  softmax = nn.Softmax(dim=2)
20
 
21
  def get_entail(logits, input_ids, tokenizer):
 
36
  return entailment
37
 
38
  def get_scores(model, tokenizer, dataloader):
 
39
  with torch.no_grad():
40
  for index, inputs in tqdm(enumerate(dataloader)):
41
  for k, v in inputs.items():
 
48
  non_media_mask = inputs['non_media_mask'], prompt_mask = inputs['prompt_mask'])
49
  logits = outputs['logits']
50
  entail_scores = get_entail(logits, inputs['input_ids'], tokenizer)
51
+ return entail_scores[0].item()