Commit cd75218 · init commit
Parent(s): b17a523

Files changed:
- README.md (+7, -1)
- app.py (+202, -0)
- requirements.txt (+6, -0)
README.md CHANGED

````diff
@@ -10,4 +10,10 @@ pinned: false
 license: apache-2.0
 ---
 
-
+
+This demo space is for generating text from RDF triplets. Being able to generate good-quality text from RDF data would permit, e.g., making this data more accessible to lay users, enriching existing text with information drawn from knowledge bases such as DBpedia, or describing, comparing, and relating entities present in these knowledge bases. Along with simple RDF-to-text generation, the repository can be used to generate long-form answers to questions based on triplets as context.
+
+You can input triplets either with a question or without. The input triplets should be in the following format:
+```
+("subj", "pred", "obj");("subj", "pred", "obj");
+```
````
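For example, pairing this triplet format with a question (a hypothetical input assembled from the few-shot examples in app.py below; in the demo UI the question goes into its own field) could look like:

```
Question: Who owns the soccer club Steve Holland plays for?
Triplets: ("Steve Holland", "current club", "Chelsea F.C.");("Chelsea F.C.", "owner", "Roman Abramovich");
```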
app.py ADDED (+202 lines)
```python
import os
import re
from time import sleep
from typing import List, Optional, Tuple

import gradio as gr
import requests
import platform

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    BitsAndBytesConfig,
    GenerationConfig
)


# Load environment variables from a .env file when running locally
if platform.system() == "Windows" or platform.system() == "Darwin":
    from dotenv import load_dotenv
    load_dotenv()

# Load model in int4 (NF4 quantization with double quantization)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_name = "lmsys/vicuna-13b-v1.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, trust_remote_code=True, device_map="auto")
if torch.__version__ >= "2":
    model = torch.compile(model)
print(f"Successfully loaded the model {model_name} into memory")


# Define stopping criteria: generation halts as soon as the last emitted
# token matches one of the stop tokens below.
stop_tokens = ["###", "Human", "\n###"]
stop_token_ids = tokenizer.convert_tokens_to_ids(stop_tokens)


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


stop = StopOnTokens()


# Prompts
instruction_with_q = """
A chat between a curious human and an artificial intelligence assistant.
The assistant's job is to answer the given question using only the information provided in the RDF triplet format. The assistant's answer should be in a human-readable format, with proper sentences and grammar, and should be concise and short.
The RDF triplets are always in the (subject, relation, object) format and are separated by semicolons. The assistant should understand that if multiple triplets are provided, the answer to the question should use and aggregate all of the information from the triplets. The assistant MUST NOT add any information besides what is provided in the triplets.
The assistant should reply as briefly as possible, and perform counting or aggregation operations over the triplets by itself when necessary.
"""

instruction_wo_q = """
A chat between a curious human and an artificial intelligence assistant.
The assistant's job is to convert the provided input in RDF triplet format into human-readable text, with proper sentences and grammar. The triplets are always in the (subject, relation, object) format, and each triplet is separated by a semicolon. The assistant should understand that if multiple triplets are provided, the generated human-readable text should use all of the information from the input. The assistant MUST NOT add any information besides what is provided in the input.
"""


# Few-shot histories used to prime the model
history_with_q = [
    ("Human", "Question: Is Essex the Ceremonial County of West Tilbury? Triplets: ('West Tilbury', 'Ceremonial County', 'Essex');"),
    ("Assistant", "Essex is the Ceremonial County of West Tilbury."),
    ("Human", "Question: What nation is Hornito located in, where Jaime Bateman Cayón died too? Triplets: ('Jaime Bateman Cayón', 'death place', 'Panama'); ('Hornito, Chiriquí', 'country', 'Panama');"),
    ("Assistant", "Hornito, Chiriquí is located in Panama, where Jaime Bateman Cayón died."),
    ("Human", "Question: Who are the shareholders of the soccer club for which Steve Holland plays? Triplets: ('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
    ("Assistant", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
    ("Human", "Question: Who is the chancellor of Falmouth University? Triplets: ('Falmouth University', 'chancellor', 'Dawn French');"),
    ("Assistant", "The chancellor of Falmouth University is Dawn French.")
]


history_wo_q = [
    ("Human", "('West Tilbury', 'Ceremonial County', 'Essex');"),
    ("Assistant", "Essex is the Ceremonial County of West Tilbury."),
    ("Human", "('Jaime Bateman Cayón', 'death place', 'Panama'); ('Hornito, Chiriquí', 'country', 'Panama');"),
    ("Assistant", "Hornito, Chiriquí is located in Panama, where Jaime Bateman Cayón died."),
    ("Human", "('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
    ("Assistant", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
    ("Human", "('Falmouth University', 'chancellor', 'Dawn French');"),
    ("Assistant", "The chancellor of Falmouth University is Dawn French.")
]


# Helper functions to convert the input into the prompt format
def prepare_input(linearized_triplets: str, question: Optional[str] = None) -> str:
    if question and "List all" in question:
        question = question.replace("List all ", "Which are ")
    if question:
        input_text = f"Question: {question.strip()} Triplets: {linearized_triplets}"
    else:
        input_text = linearized_triplets
    return input_text


def make_prompt(
    curr_input: str,
    instruction: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> str:
    ret = f"{instruction}\n###"
    for role, message in history or []:
        ret += f"{role}: {message}\n###"
    ret += f"Human: {curr_input}\n###Assistant: \n"
    return ret


def generate_output(
    triplets: str,
    question: Optional[str] = None,
    temperature: float = 0.6,
    top_p: float = 0.5,
    top_k: int = 0,
    repetition_penalty: float = 1.08,
) -> str:
    curr_input = prepare_input(triplets, question)
    if question:
        instruction = make_prompt(curr_input, instruction_with_q, history_with_q)
    else:
        instruction = make_prompt(curr_input, instruction_wo_q, history_wo_q)

    input_ids = tokenizer(instruction, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=100,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    with torch.no_grad():
        outputs = model.generate(**generate_kwargs)

    # Decode the full sequence, then strip the prompt (with special tokens
    # removed) so only the newly generated answer remains.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    for tok in tokenizer.additional_special_tokens + [tokenizer.eos_token]:
        instruction = instruction.replace(tok, '')
    response = response[len(instruction):]
    return response


# Gradio UI code
with gr.Blocks(theme='gradio/soft') as demo:
    # Elements stack vertically by default; define them in the order you want them to appear
    header = gr.HTML("""
        <h1 style="text-align: center">RDF to text Vicuna Demo</h1>
        <h3 style="text-align: center"> Generate natural language verbalizations from RDF triplets </h3>
        <br>
        <p style="font-size: 12px; text-align: center">⚠️ Takes approximately 15-30s to generate.</p>
    """)

    output_box = gr.Textbox(lines=2, interactive=False)

    triplets = gr.Textbox(lines=3, placeholder="('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');", label='Triplets')
    question = gr.Textbox(lines=4, placeholder='Write a question here if you want to generate an answer based on it.', label='Question')

    with gr.Row():
        run_button = gr.Button("Generate", variant="primary")
        clear_button = gr.ClearButton(variant="secondary")

    with gr.Accordion("Options", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
        top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)

    info = gr.HTML("""
        <p>🌐 Leveraging the <a href='https://huggingface.co/lmsys/vicuna-13b-v1.3'><strong>Vicuna model</strong></a> with int4 quantization.</p>
    """)

    # Each example supplies [triplets, question]; the sampling sliders keep their
    # defaults. Example caching is skipped on local (Windows/macOS) runs.
    examples = gr.Examples([
        ['("Google Videos", "developer", "Google"), ("Google Web Toolkit", "author", "Google")', ""],
        ['("Katyayana", "religion", "Buddhism")', "What is the religious affiliation of Katyayana?"],
    ], inputs=[triplets, question], fn=generate_output, cache_examples=platform.system() not in ("Windows", "Darwin"), outputs=output_box)

    # readme_content = requests.get("https://huggingface.co/HF_MODEL_PATH/raw/main/README.md").text
    # readme_content = re.sub('---.*?---', '', readme_content, flags=re.DOTALL)  # Remove YAML front matter

    # with gr.Accordion("📖 Model Readme", open=True):
    #     readme = gr.Markdown(
    #         readme_content,
    #     )

    run_button.click(fn=generate_output, inputs=[triplets, question, temperature, top_p, top_k, repetition_penalty], outputs=output_box, api_name="rdf2text")
    clear_button.add([triplets, question, output_box])

demo.queue(concurrency_count=1, max_size=10).launch(debug=True)
```
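For reference, a minimal sketch of exercising the generation pipeline directly, without the Gradio UI. It assumes app.py's module-level model and tokenizer loaded successfully; the triplet and question are taken from the few-shot history above.

```python
# Minimal sketch: call the pipeline from app.py directly
# (assumes the module-level `model`/`tokenizer` above loaded successfully).
triplets = "('Falmouth University', 'chancellor', 'Dawn French');"
question = "Who is the chancellor of Falmouth University?"

# prepare_input() folds the question and triplets into a single string,
# make_prompt() wraps it with the instruction and few-shot history, and
# decoding stops once a stop token ("###", "Human") is emitted.
answer = generate_output(triplets, question, temperature=0.2, top_p=0.9)
print(answer)
```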
requirements.txt ADDED (+6 lines)
```
accelerate==0.20.3
bitsandbytes==0.39.0
sentencepiece==0.1.99
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
```