tonic
commited on
Commit
•
7ce9ee7
1
Parent(s):
fa05361
initial commit
Browse files- README.md +2 -2
- app.py +76 -0
- requirements.txt +4 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
+
title: TigerAI-StructLM
|
3 |
+
emoji: 🐯📏
|
4 |
colorFrom: purple
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import torch
|
3 |
+
import sys
|
4 |
+
import html
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
6 |
+
from threading import Thread
|
7 |
+
import gradio as gr
|
8 |
+
from gradio_rich_textbox import RichTextbox
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
title = """# 🙋🏻♂️Welcome to🌟Tonic's🐯📏TigerAI-StructLM-7B
|
13 |
+
StructLM, is a series of open-source large language models (LLMs) finetuned for structured knowledge grounding (SKG) tasks. You can build with this endpoint using 🐯📏TigerAI-StructLM available here : [TIGER-Lab/StructLM-7B](https://huggingface.co/TIGER-Lab/StructLM-7B).
|
14 |
+
You can also use 🐯📏TigerAI-StructLM by cloning this space. Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/TigerLM?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></h3>
|
15 |
+
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) Math with [introspector](https://huggingface.co/introspector) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [SciTonic](https://github.com/Tonic-AI/scitonic)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
|
16 |
+
"""
|
17 |
+
assistant_message = """Use the information in the following table to solve the problem, choose between the choices if they are provided. table:"""
|
18 |
+
system_message = "You are an AI assistant that specializes in analyzing and reasoning over structured information. You will be given a task, optionally with some structured knowledge input. Your answer must strictly adhere to the output format, if specified."
|
19 |
+
tabular_data = "col : day | kilometers row 1 : tuesday | 0 row 2 : wednesday | 0 row 3 : thursday | 4 row 4 : friday | 0 row 5 : saturday | 0"
|
20 |
+
user_message = "Allie kept track of how many kilometers she walked during the past 5 days. What is the range of the numbers?"
|
21 |
+
model_name = 'TIGER-Lab/StructLM-7B'
|
22 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
23 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
|
24 |
+
# model.generation_config = GenerationConfig.from_pretrained(model_name)
|
25 |
+
# model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
26 |
+
|
27 |
+
@torch.inference_mode()
|
28 |
+
@spaces.GPU
|
29 |
+
def predict_math_bot(user_message, system_message="", assistant_message = "", tabular_data = "", max_new_tokens=125, temperature=0.1, top_p=0.9, repetition_penalty=1.9, do_sample=False):
|
30 |
+
prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{assistant_message}\n\n{tabular_data}\n\n\nQuestion:\n\n{user_message}[/INST]"
|
31 |
+
inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
|
32 |
+
input_ids = inputs["input_ids"].to(model.device)
|
33 |
+
|
34 |
+
output_ids = model.generate(
|
35 |
+
input_ids,
|
36 |
+
max_length=input_ids.shape[1] + max_new_tokens,
|
37 |
+
temperature=temperature,
|
38 |
+
top_p=top_p,
|
39 |
+
repetition_penalty=repetition_penalty,
|
40 |
+
pad_token_id=tokenizer.eos_token_id,
|
41 |
+
do_sample=do_sample
|
42 |
+
)
|
43 |
+
|
44 |
+
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
45 |
+
return response
|
46 |
+
|
47 |
+
def main():
|
48 |
+
with gr.Blocks() as demo:
|
49 |
+
gr.Markdown(title)
|
50 |
+
with gr.Row():
|
51 |
+
system_message = gr.Textbox(label="📉System Prompt", placeholder=system_message)
|
52 |
+
assistant_message = gr.Textbox(label="Assistant Message", placeholder=assistant_message)
|
53 |
+
tabular_data = gr.Textbox(label="Tabular Data", placeholder=tabular_data)
|
54 |
+
user_message = gr.Textbox(label="🫡Enter your query here...", placeholder=user_message)
|
55 |
+
|
56 |
+
|
57 |
+
with gr.Accordion("Advanced Settings"):
|
58 |
+
with gr.Row():
|
59 |
+
max_new_tokens = gr.Slider(label="Max new tokens", value=125, minimum=25, maximum=1250)
|
60 |
+
temperature = gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0)
|
61 |
+
top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99)
|
62 |
+
repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0)
|
63 |
+
do_sample = gr.Checkbox(label="Do sample", value=False)
|
64 |
+
|
65 |
+
output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B", interactive=True)
|
66 |
+
|
67 |
+
gr.Button("Try🫡📉MetaMath").click(
|
68 |
+
predict_math_bot,
|
69 |
+
inputs=[user_message, system_message, assistant_message, tabular_data, max_new_tokens, temperature, top_p, repetition_penalty, do_sample],
|
70 |
+
outputs=output_text
|
71 |
+
)
|
72 |
+
|
73 |
+
demo.launch()
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
accelerate
|
4 |
+
bitsandbytes
|