import os
import re
from time import sleep
from typing import List, Tuple, Optional

import gradio as gr
import requests
import platform

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    BitsAndBytesConfig,
    GenerationConfig
)


# Load environment variables from a local .env file when running on Windows/macOS
if platform.system() in ("Windows", "Darwin"):
    from dotenv import load_dotenv
    load_dotenv()

# Load model in int4
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
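
# Note: with the config above, weights are stored in 4-bit NF4 with nested (double)
# quantization and matmuls run in bfloat16, which cuts the 7B checkpoint's weight
# memory to roughly a quarter of an fp16 load.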

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-7b1",
    quantization_config=bnb_config,
    device_map="auto",
)
print("Successfully loaded the model")


# Define stopping criteria. It is not used for the BLOOM model family, but it can be used for the LLaMA model family.
stop_tokens = ["\n###"]
stop_token_ids = tokenizer.convert_tokens_to_ids(stop_tokens)

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
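
# A minimal sketch (illustrative only; not wired up in this app since BLOOM does not
# need it) of how StopOnTokens could be attached to generation for LLaMA-family
# models, using the StoppingCriteriaList already imported above:
#
#     stopping_criteria = StoppingCriteriaList([StopOnTokens()])
#     model.generate(input_ids=input_ids, max_new_tokens=100,
#                    stopping_criteria=stopping_criteria)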


# Prompts 
instruction_with_q = """
A chat between a curious USER and an artificial intelligence assistant.
The assistant's job is to answer the given question using only the information provided in RDF triplet format. The assistant's answer should be in a USER-readable format, with proper sentences and grammar, and should be concise and short.
The RDF triplets are always in the (subject, relation, object) format and are separated by a semicolon. The assistant should understand that if multiple triplets are provided, the answer to the question should use and aggregate all of the information from the triplets. The assistant MUST NOT add any additional information besides the one provided in the triplets.
The assistant should try to reply as briefly as possible, and perform counting or aggregation operations over the triplets by itself when necessary.
"""

instruction_wo_q = """
A chat between a curious USER and an artificial intelligence assistant.
The assistant's job is to convert the provided input in RDF triplet format into USER-readable text, with proper sentences and grammar. The triplets are always in the (subject, relation, object) format, where each triplet is separated by a semicolon. The assistant should understand that if multiple triplets are provided, the generated USER-readable text should use all of the information from the input. The assistant MUST NOT add any additional information besides the one provided in the input.
"""


history_with_q = [
        ("USER", "Question: Is Essex the Ceremonial County of West Tilbury? Triplets: ('West Tilbury', 'Ceremonial County', 'Essex');"),
        ("ASSISTANT", "Essex is the Ceremonial County of West Tributary"),
        ("USER", "Question: What nation is Hornito located in, where Jamie Bateman Cayn died too? Triplets: ('Jaime Bateman Cay贸n', 'death place', 'Panama'); ('Hornito, Chiriqu铆', 'country', 'Panama');"),
        ("ASSISTANT", "Hornito, Chiriqu铆 is located in Panama, where Jaime Bateman Cay贸n died."),
        ("USER", "Question: Who are the shareholder of the soccer club for whom Steve Holland plays? Triplets: ('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
        ("ASSISTANT", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
        ("USER", "Question: Who is the chancellor of Falmouth University? Triplets: ('Falmouth University', 'chancellor', 'Dawn French');"),
        ("ASSISTANT", "The chancellor of the Falmouth University is Dawn French.")

    ]


history_wo_q = [
        ("USER", "('West Tilbury', 'Ceremonial County', 'Essex');"),
        ("ASSISTANT", "Essex is the Ceremonial County of West Tributary"),
        ("USER", "('Jaime Bateman Cay贸n', 'death place', 'Panama'); ('Hornito, Chiriqu铆', 'country', 'Panama');"),
        ("ASSISTANT", "Hornito, Chiriqu铆 is located in Panama, where Jaime Bateman Cay贸n died."),
        ("USER", "('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
        ("ASSISTANT", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
        ("USER", "('Falmouth University', 'chancellor', 'Dawn French');"),
        ("ASSISTANT", "The chancellor of the Falmouth University is Dawn French.")

    ]


# Helper functions to convert the input into the prompt format
def prepare_input(linearized_triplets: str, question: Optional[str] = None) -> str:
    if question and "List all" in question:
        question = question.replace("List all ", "Which are ")
    if question:
        input_text = f"Question: {question.strip()} Triplets: {linearized_triplets}"
    else:
        input_text = linearized_triplets
    return input_text
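
# For example, reusing one of the few-shot samples above,
#     prepare_input("('Falmouth University', 'chancellor', 'Dawn French');",
#                   "Who is the chancellor of Falmouth University?")
# returns:
#     "Question: Who is the chancellor of Falmouth University? Triplets: ('Falmouth University', 'chancellor', 'Dawn French');"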


def make_prompt(
    curr_input: str,
    instruction: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> str:
    ret = f"{instruction}\n"
    for role, message in (history or []):
        ret += f"{role}: {message}\n"
    ret += f"USER: {curr_input}\nASSISTANT: "
    return ret
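
# The assembled prompt therefore has the following layout (sketch):
#
#     <instruction>
#     USER: <few-shot input 1>
#     ASSISTANT: <few-shot answer 1>
#     ...
#     USER: <current input>
#     ASSISTANT: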


def generate_output(
    triplets: str,
    question: Optional[str] = None,
    temperature: float = 0.6,
    top_p: float = 0.5,
    top_k: int = 0,
    repetition_penalty: float = 1.08,
) -> str:
    curr_input = prepare_input(triplets, question)
    if question:
        instruction = make_prompt(curr_input, instruction_with_q, history_with_q)
    else:
        instruction = make_prompt(curr_input, instruction_wo_q, history_wo_q)

    stop = StopOnTokens()  # not passed to generate(); only needed for LLaMA-family models (see note above)
    input_ids = tokenizer(instruction, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=100,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    with torch.no_grad():
        outputs = model.generate(**generate_kwargs, return_dict_in_generate=True, output_scores=True)

    response = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Strip special tokens from the prompt so its length matches the decoded text,
    # then cut the prompt off the front so only the newly generated answer remains.
    for tok in tokenizer.additional_special_tokens + [tokenizer.eos_token]:
        instruction = instruction.replace(tok, '')
    response = response[len(instruction):]
    return response
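
# Example call (values taken from the few-shot history above; the sampled output may vary):
#     generate_output("('Falmouth University', 'chancellor', 'Dawn French');",
#                     question="Who is the chancellor of Falmouth University?")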


# Gradio UI Code
with gr.Blocks(theme='gradio/soft') as demo:
    # Elements stack vertically by default; just define them in the order you want them to appear
    header = gr.HTML("""
        <h1 style="text-align: center">RDF to text Vicuna Demo</h1>
        <h3 style="text-align: center"> Generate natural language verbalizations from RDF triplets </h3>
        <br>
        <p style="font-size: 12px; text-align: center">鈿狅笍 Takes approximately 15-30s to generate.</p>
    """)


    triplets = gr.Textbox(lines=3, placeholder="('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');", label='Triplets')
    question = gr.Textbox(lines=4, placeholder='Write a question here, if you want to generate answer based on question.', label='Question')
    
    with gr.Row():
        run_button = gr.Button("Generate", variant="primary")
        clear_button = gr.ClearButton(variant="secondary")
    
    output_box = gr.Textbox(lines=2, interactive=False, label="Generated Text")

    with gr.Accordion("Options", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
        top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
    
    info = gr.HTML("""
        <p>🌐 Leveraging the <a href='https://huggingface.co/bigscience/bloomz-7b1'><strong>Vicuna model</strong></a> with int4 quantization.</p>
    """)


    examples = gr.Examples([
        ['("Google Videos", "developer", "Google"), ("Google Web Toolkit", "author", "Google")', ""],
        ['("Katyayana", "religion", "Buddhism")', "What is the religious affiliation of Katyayana?"],
    ], inputs=[triplets, question, temperature, top_p, top_k, repetition_penalty], fn=generate_output, cache_examples=platform.system() not in ("Windows", "Darwin"), outputs=output_box)


    #readme_content = requests.get(f"https://huggingface.co/HF_MODEL_PATH/raw/main/README.md").text
    #readme_content = re.sub('---.*?---', '', readme_content, flags=re.DOTALL) #Remove YAML front matter

    #with gr.Accordion("馃摉 Model Readme", open=True):
    #    readme = gr.Markdown(
    #        readme_content,
    #    )
    
    run_button.click(fn=generate_output, inputs=[triplets, question, temperature, top_p, top_k, repetition_penalty], outputs=output_box, api_name="rdf2text")
    clear_button.add([triplets, question, output_box])

demo.queue(concurrency_count=1, max_size=10).launch(debug=True)