liujch1998 committed
Commit
7eae3e8
1 Parent(s): da832fe

Initial commit

Files changed (2)
  1. app.py +144 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ import os
+ import torch
+ import transformers
+
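+ # Masked reduction helpers: sum / mean of `value` over positions where `mask` is 1 (not used below).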
+ def reduce_sum(value, mask, axis=None):
+     if axis is None:
+         return torch.sum(value * mask)
+     return torch.sum(value * mask, axis)
+ def reduce_mean(value, mask, axis=None):
+     if axis is None:
+         return torch.sum(value * mask) / torch.sum(mask)
+     return reduce_sum(value, mask, axis) / torch.sum(mask, axis)
+
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ HF_TOKEN_DOWNLOAD = os.environ.get('HF_TOKEN_DOWNLOAD')
+
+ class Processor:
+     def __init__(self, model):
+         self.tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_auth_token=HF_TOKEN_DOWNLOAD)
+         self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto', offload_folder='offload')
+         self.model.eval()
+
+     def parse_choices(self, s):
+         '''
+         s: serialized_choices '(A) ... (B) ... (C) ...'
+         '''
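+         # e.g. parse_choices('(A) pants (B) record albums') -> ['pants', 'record albums']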
+         choices = []
+         key = 'A' if s.find('(A)') != -1 else 'a'
+         while True:
+             pos = s.find(f'({chr(ord(key) + 1)})')
+             if pos == -1:
+                 break
+             choice = s[3:pos]
+             s = s[pos:]
+             choice = choice.strip(' ')
+             choices.append(choice)
+             key = chr(ord(key) + 1)
+         choice = s[3:]
+         choice = choice.strip(' ')
+         choices.append(choice)
+         return choices
+
+     def run(self, question, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
+         choices = self.parse_choices(question.split('\\n')[1].strip(' '))
+         choices = [chr(ord('A') + i) for i, choice in enumerate(choices)]
+         choices_ids = self.tokenizer(choices, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_answer_len).input_ids.to(device) # (C, AL)
+
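+         # Stage 1: sample m knowledge statements conditioned on the question via nucleus sampling (top_p).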
+         prompt = question + ' \\n Knowledge: '
+         prompt_tok = self.tokenizer(prompt, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len).to(device) # (1, QL)
+         knowledges_ids = self.model.generate(
+             input_ids=prompt_tok.input_ids,
+             attention_mask=prompt_tok.attention_mask,
+             max_length=max_knowledge_len + 1,
+             min_length=3,
+             do_sample=True,
+             num_return_sequences=m,
+             top_p=top_p,
+         ) # (K, KL); begins with 0 ([BOS]); ends with 1 ([EOS])
+         knowledges_ids = knowledges_ids[:, 1:].contiguous() # no beginning; ends with 1 ([EOS])
+         knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         knowledges = list(set(knowledges))
+         knowledges = [''] + knowledges
+
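+         # Stage 2: for the plain question and each knowledge-augmented prompt, score the answer choices by the first decoder-step logits at each choice letter's first token.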
+         prompts = [question + (f' \\n Knowledge: {knowledge} \\n Answer: ' if knowledge != '' else ' \\n Answer:') for knowledge in knowledges]
+         prompts_tok = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len + max_knowledge_len).to(device) # (1+K, QL+KL)
+         output = self.model(
+             input_ids=prompts_tok.input_ids,
+             attention_mask=prompts_tok.attention_mask,
+             labels=choices_ids[0].unsqueeze(0).expand(len(knowledges), -1), # labels supply decoder inputs; only the logits are used
+         )
+         logitsss = output.logits # (1+K, AL, V)
+         logitss = logitsss[:, 0, :] # (1+K, V)
+         choice_ids = choices_ids[:, 0] # (C)
+         answer_logitss = logitss.gather(dim=1, index=choice_ids.unsqueeze(0).expand(len(knowledges), -1)) # (1+K, C)
+         answer_probss = answer_logitss.softmax(dim=1) # (1+K, C)
+
+         # Ensemble
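+         # knowless_pred: answer from the no-knowledge prompt (row 0).
+         # knowful_pred: argmax over per-choice max probabilities across all prompts; the selected knowledge is the prompt with the single most confident choice.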
+         knowless_pred = answer_probss[0, :].argmax(dim=0).item()
+         knowless_pred = choices[knowless_pred]
+
+         answer_probs = answer_probss.max(dim=0).values # (C)
+         knowful_pred = answer_probs.argmax(dim=0).item()
+         knowful_pred = choices[knowful_pred]
+         selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
+         selected_knowledge = knowledges[selected_knowledge_ix]
+
+         return {
+             'question': question,
+             'knowledges': knowledges,
+             'knowless_pred': knowless_pred,
+             'knowful_pred': knowful_pred,
+             'selected_knowledge': selected_knowledge,
+         }
+
+ MODELS = [
+     'liujch1998/crystal-large',
+     # 'liujch1998/crystal-3b',
+     # 'liujch1998/crystal-11b',
+ ]
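+ # Pre-load a Processor (tokenizer + model) for each available checkpoint.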
+ processor_by_model = {}
+ for model in MODELS:
+     processor_by_model[model] = Processor(model)
+
+ def predict(question, model, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
+     result = processor_by_model[model].run(question, max_question_len, max_knowledge_len, max_answer_len, m, top_p)
+     return result['knowless_pred'], result['knowful_pred'], '\n'.join(result['knowledges']), result['selected_knowledge']
+
+ examples = [
+     'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
+     'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
+     'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
+     'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
+     'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
+     'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
+ ]
119
+ input_question = gr.Dropdown(choices=examples, label='Question:',
120
+ info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
121
+ )
122
+ input_model = gr.DropDown(label='Model:', value=MODELS[0], choices=MODELS)
123
+ input_max_question_len = gr.Number(label='Max number of tokens in question:', value=256, precision=0)
124
+ input_max_knowledge_len = gr.Number(label='Max number of tokens in knowledge:', value=32, precision=0)
125
+ input_max_answer_len = gr.Number(label='Max number of tokens in answer:', value=2, precision=0)
126
+ input_m = gr.Slider(label='Number of generated knowledges:', value=10, mininum=1, maximum=20, step=1,
127
+ info='The actual number of generated knowledges may be less than this number due to possible duplicates.',
128
+ )
129
+ input_top_p = gr.Slider(label='top_p for knowledge generation:', value=0.5, mininum=0.0, maximum=1.0, step=0.05)
130
+ output_knowless_answer = gr.Textbox(label='QA model answer without knowledge:', interactive=False)
131
+ output_knowful_answer = gr.Textbox(label='QA model answer with knowledge:', interactive=False)
132
+ output_all_knowledges = gr.Textbox(label='All generated knowledges:', interactive=False)
133
+ output_selected_knowledge = gr.Textbox(label='Knowledge selected to make the prediction:', interactive=False)
134
+
+ description = '''This is a demo for the paper, [*Crystal: Introspective Reasoners Reinforced with Self-Feedback*](), presented at EMNLP 2023. [[Code](https://github.com/liujch1998/crystal)] [[Model](https://huggingface.co/liujch1998/crystal-large)] This demo is made & maintained by [Jiacheng (Gary) Liu](https://liujch1998.github.io).
+ Crystal is an introspective reasoning model that answers commonsense questions by first generating knowledge and then using knowledge-grounded reasoning to reach a final prediction. To try this model, select an example question, or write your own commonsense question in the suggested format.'''
+
+ gr.Interface(
+     fn=predict,
+     inputs=[input_question, input_model, input_max_question_len, input_max_knowledge_len, input_max_answer_len, input_m, input_top_p],
+     outputs=[output_knowless_answer, output_knowful_answer, output_all_knowledges, output_selected_knowledge],
+     title="Crystal Demo",
+     description=description,
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ tokenizers
+ sentencepiece