Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import torch | |
import transformers | |
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Access token read from the environment (None when unset); forwarded to
# `from_pretrained(..., use_auth_token=...)` when loading checkpoints.
HF_TOKEN_DOWNLOAD = os.environ.get('HF_TOKEN_DOWNLOAD')
class Processor:
    """Answer multiple-choice questions with a Crystal checkpoint.

    The checkpoint first generates knowledge statements, then answers using them.

    Questions use the UnifiedQA format ``"{question} \\n (A) ... (B) ... (C) ..."``.
    The separator is a literal backslash followed by ``n``, not a newline
    character.
    """

    def __init__(self, model):
        """Load the tokenizer and seq2seq model for checkpoint ``model``.

        The module-level ``HF_TOKEN_DOWNLOAD`` is passed as ``use_auth_token``.
        Weights are placed with ``device_map='auto'`` and may be offloaded to
        the relative ``offload`` folder.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_auth_token=HF_TOKEN_DOWNLOAD)
        self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            model,
            use_auth_token=HF_TOKEN_DOWNLOAD,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
            offload_folder='offload',
        )
        self.model.eval()

    def parse_choices(self, s):
        """Split a serialized choice string into the list of choice texts.

        Args:
            s: serialized choices, e.g. ``'(A) foo (B) bar (C) baz'``. Lowercase
                markers (``'(a) ... (b) ...'``) are assumed when ``'(A)'`` is
                absent. The string is expected to start with the first marker.

        Returns:
            The choice texts in order, with surrounding spaces stripped.
        """
        choices = []
        # Use uppercase markers if present, otherwise fall back to lowercase.
        key = 'A' if s.find('(A)') != -1 else 'a'
        while True:
            # Locate the marker of the next choice, e.g. '(B)' while on '(A)'.
            pos = s.find(f'({chr(ord(key) + 1)})')
            if pos == -1:
                break
            # s[:3] is the current marker, e.g. '(A)'; drop it.
            choice = s[3:pos]
            s = s[pos:]
            choice = choice.strip(' ')
            choices.append(choice)
            key = chr(ord(key) + 1)
        # Whatever follows the last marker is the final choice.
        choice = s[3:]
        choice = choice.strip(' ')
        choices.append(choice)
        return choices

    @staticmethod
    def _unique_knowledges(knowledges):
        """Deduplicate knowledge strings, keeping first-seen order and dropping empty ones."""
        # dict.fromkeys preserves insertion order, unlike set(), whose string
        # iteration order changes between interpreter runs (hash randomization).
        return [knowledge for knowledge in dict.fromkeys(knowledges) if knowledge != '']

    def run(self, question, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
        """Answer ``question`` both without and with self-generated knowledge.

        Args:
            question: UnifiedQA-format question ``"{question} \\n (A) ... (B) ..."``.
            max_question_len: padding/truncation length for the question prompt.
            max_knowledge_len: maximum number of generated tokens per knowledge.
            max_answer_len: padding/truncation length when tokenizing the choice
                labels; only the first token of each label is scored.
            m: number of knowledge statements to sample (converted to int).
            top_p: nucleus-sampling threshold for knowledge generation.

        Returns:
            A dict with these keys:

            - ``question``
            - ``knowledges``: ``''`` first (the knowledge-free baseline), then
              the unique non-empty generated statements in generation order.
            - ``knowless_pred``: the choice letter predicted from the
              knowledge-free row only.
            - ``knowful_pred``: the choice letter with the highest probability
              under any row, including the knowledge-free one.
            - ``selected_knowledge``: the row whose most confident choice is the
              most confident overall.

        Raises:
            ValueError: if ``question`` has no ``\\n``-separated choice line.
        """
        parts = question.split('\\n')
        if len(parts) < 2:
            raise ValueError(
                'Question must follow the UnifiedQA format: "{question} \\n (A) ... (B) ... (C) ..."'
            )
        # Only the number of choices is used: each choice is scored through its
        # letter label ('A', 'B', ...), not through its text.
        num_choices = len(self.parse_choices(parts[1].strip(' ')))
        choices = [chr(ord('A') + i) for i in range(num_choices)]
        choices_ids = self.tokenizer(
            choices,
            return_tensors='pt',
            padding='max_length',
            truncation='longest_first',
            max_length=max_answer_len,
        ).input_ids.to(device)  # (C, AL)

        # Inference only: skip autograd bookkeeping for generation and scoring.
        with torch.no_grad():
            # Stage 1: sample up to m knowledge statements for the question.
            prompt = question + ' \\n Knowledge: '
            prompt_tok = self.tokenizer(
                prompt,
                return_tensors='pt',
                padding='max_length',
                truncation='longest_first',
                max_length=max_question_len,
            ).to(device)  # (1, QL)
            knowledges_ids = self.model.generate(
                input_ids=prompt_tok.input_ids,
                attention_mask=prompt_tok.attention_mask,
                max_length=max_knowledge_len + 1,  # +1 for the leading token stripped below
                min_length=3,
                do_sample=True,
                num_return_sequences=int(m),  # accept float-valued counts (e.g. from a slider)
                top_p=top_p,
            )  # (K, KL); begins with 0 ([BOS]); ends with 1 ([EOS])
            knowledges_ids = knowledges_ids[:, 1:].contiguous()  # no beginning; ends with 1 ([EOS])
            knowledges = self.tokenizer.batch_decode(
                knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True,
            )
            # Row 0 is the knowledge-free baseline; the rest are unique generated statements.
            knowledges = [''] + self._unique_knowledges(knowledges)

            # Stage 2: score the answer letters under each knowledge (and under none).
            prompts = [
                question + (f' \\n Knowledge: {knowledge} \\n Answer: ' if knowledge != '' else ' \\n Answer:')
                for knowledge in knowledges
            ]
            prompts_tok = self.tokenizer(
                prompts,
                return_tensors='pt',
                padding='max_length',
                truncation='longest_first',
                max_length=max_question_len + max_knowledge_len,
            ).to(device)  # (1+K, QL+KL)
            output = self.model(
                input_ids=prompts_tok.input_ids,
                attention_mask=prompts_tok.attention_mask,
                # Only output.logits[:, 0] is used below; the loss is ignored.
                labels=choices_ids[0].unsqueeze(0).repeat(len(knowledges), 1),
            )

        logitss = output.logits[:, 0, :]  # (1+K, V): first answer position
        choice_ids = choices_ids[:, 0]  # (C): first token of each letter label
        answer_logitss = logitss.gather(
            dim=1, index=choice_ids.unsqueeze(0).expand(len(knowledges), -1),
        )  # (1+K, C)
        # Softmax restricted to the choice letters, one distribution per row.
        answer_probss = answer_logitss.softmax(dim=1)  # (1+K, C)

        # Ensemble.
        # Without knowledge: use the baseline row only.
        knowless_pred = choices[answer_probss[0, :].argmax(dim=0).item()]
        # With knowledge: score each choice by its best probability over all rows.
        knowful_pred = choices[answer_probss.max(dim=0).values.argmax(dim=0).item()]
        # The row whose most confident choice is the most confident overall.
        selected_knowledge = knowledges[answer_probss.max(dim=1).values.argmax(dim=0).item()]
        return {
            'question': question,
            'knowledges': knowledges,
            'knowless_pred': knowless_pred,
            'knowful_pred': knowful_pred,
            'selected_knowledge': selected_knowledge,
        }
# Checkpoints offered in the model dropdown. Every listed checkpoint is loaded
# eagerly below; the larger variants are commented out and can be re-enabled.
MODELS = [
    'liujch1998/crystal-large',
    # 'liujch1998/crystal-3b',
    # 'liujch1998/crystal-11b',
]

# One Processor per checkpoint, keyed by the checkpoint name shown in the dropdown.
# A comprehension avoids leaking a loop variable named `model` into module scope.
processor_by_model = {name: Processor(name) for name in MODELS}
def predict(question, model, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
    """Gradio callback: run the selected checkpoint on ``question``.

    Returns a 4-tuple matching the interface outputs:

    1. the answer without knowledge;
    2. the answer with knowledge;
    3. all knowledge strings joined by newlines (the first line is empty,
       standing for the knowledge-free baseline);
    4. the selected knowledge.

    Raises:
        gr.Error: if no question was selected or entered. Gradio shows this
            message to the user instead of a generic failure.
    """
    # The question dropdown has no default value, so submitting without a
    # selection would otherwise crash deep inside the model code.
    if not question or not question.strip():
        raise gr.Error('Please select or enter a question.')
    result = processor_by_model[model].run(question, max_question_len, max_knowledge_len, max_answer_len, m, top_p)
    return (
        result['knowless_pred'],
        result['knowful_pred'],
        '\n'.join(result['knowledges']),
        result['selected_knowledge'],
    )
# Example questions in the UnifiedQA format; '\\n' is a literal backslash-n
# separator between the question and its choices, not a newline.
examples = [
    'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
    'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
    'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
    'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
    'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
    'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
]

# Inputs.
input_question = gr.Dropdown(
    choices=examples,
    label='Question:',
    info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
    # The description invites users to write their own question, so accept
    # free text in addition to the examples.
    allow_custom_value=True,
)
input_model = gr.Dropdown(label='Model:', value=MODELS[0], choices=MODELS)
input_max_question_len = gr.Number(label='Max number of tokens in question:', value=256, precision=0)
input_max_knowledge_len = gr.Number(label='Max number of tokens in knowledge:', value=32, precision=0)
input_max_answer_len = gr.Number(label='Max number of tokens in answer:', value=2, precision=0)
# `minimum` (previously misspelled `mininum`, which Gradio ignores or rejects):
# at least one knowledge must be generated, since generate() cannot return 0 sequences.
input_m = gr.Slider(
    label='Number of generated knowledges:',
    value=10,
    minimum=1,
    maximum=20,
    step=1,
    info='The actual number of generated knowledges may be less than this number due to possible duplicates.',
)
input_top_p = gr.Slider(label='top_p for knowledge generation:', value=0.5, minimum=0.0, maximum=1.0, step=0.05)

# Outputs (read-only).
output_knowless_answer = gr.Textbox(label='QA model answer without knowledge:', interactive=False)
output_knowful_answer = gr.Textbox(label='QA model answer with knowledge:', interactive=False)
output_all_knowledges = gr.Textbox(label='All generated knowledges:', interactive=False)
output_selected_knowledge = gr.Textbox(label='Knowledge selected to make the prediction:', interactive=False)
# Markdown shown under the demo title.
description = '''This is a demo for the paper, [*Crystal: Introspective Reasoners Reinforced with Self-Feedback*](https://arxiv.org/abs/2310.04921), presented at EMNLP 2023. [[Code](https://github.com/liujch1998/crystal)] [[Model](https://huggingface.co/liujch1998/crystal-11b)] This demo is made & maintained by [Jiacheng (Gary) Liu](https://liujch1998.github.io).
Crystal is an introspective reasoning model that answers commonsense questions by first generating knowledge and then using knowledge-grounded reasoning to reach a final prediction. To try this model, select an example question, or write your own commonsense question in the suggested format.'''

# Wire the inputs/outputs to `predict` and start the web server.
gr.Interface(
    fn=predict,
    inputs=[input_question, input_model, input_max_question_len, input_max_knowledge_len, input_max_answer_len, input_m, input_top_p],
    outputs=[output_knowless_answer, output_knowful_answer, output_all_knowledges, output_selected_knowledge],
    title="Crystal Demo",
    description=description,
).launch()