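# Streamlit chat demo for FuseChat-3.0 (FuseAI).
# A minimal way to launch it, assuming this file is saved as app.py:
#     streamlit run app.py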
import streamlit as st
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)


# App title
st.set_page_config(page_title="😶‍🌫️ FuseChat Model")

root_path = "FuseAI"
model_name = "FuseChat-Qwen-2.5-7B-Instruct"

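# Cache the model and tokenizer so they load once per server process,
# not on every Streamlit rerun.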
@st.cache_resource
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(
        f"{root_path}/{model_name}",
        trust_remote_code=True,
    )

    # Fall back to the EOS token (or 0) when the tokenizer defines no pad token.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0

    # 4-bit quantization via bitsandbytes; the bare `load_in_4bit` kwarg is
    # deprecated in recent transformers releases in favor of a config object.
    model = AutoModelForCausalLM.from_pretrained(
        f"{root_path}/{model_name}",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    model.eval()
    return model, tokenizer


with st.sidebar:
    st.title('😶‍🌫️ FuseChat-3.0')
    st.write('This chatbot is built with FuseChat, a model developed by FuseAI.')
    # `st.` calls inside `with st.sidebar:` already render in the sidebar,
    # so the extra `st.sidebar.` prefix is unnecessary.
    temperature = st.slider('temperature', min_value=0.01, max_value=1.0, value=0.7, step=0.01)
    top_p = st.slider('top_p', min_value=0.1, max_value=1.0, value=0.8, step=0.05)
    top_k = st.slider('top_k', min_value=1, max_value=1000, value=20, step=1)
    repetition_penalty = st.slider('repetition penalty', min_value=1.0, max_value=2.0, value=1.05, step=0.05)
    max_length = st.slider('max_length', min_value=32, max_value=4096, value=2048, step=8)
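    # Note: `max_length` caps only the reply length; it is passed to
    # `generate` as `max_new_tokens` below.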

with st.spinner('Loading model...'):
    model, tokenizer = load_model(model_name)

# Store LLM-generated responses across reruns
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

def set_query(query):
    st.session_state.messages.append({"role": "user", "content": query})

# Candidate questions offered as clickable buttons in the sidebar.
candidate_questions = [
    "Is boiling water (100°C) an obtuse angle (larger than 90 degrees)?",
    "Write a quicksort code in Python.",
    # Chinese math puzzle: "A cage holds chickens and rabbits: 72 heads and
    # 200 legs in total. How many chickens and how many rabbits are there?"
    "笼子里有好几只鸡和兔子。笼子里有72个头,200只腿。里面有多少只鸡和兔子",
]
for question in candidate_questions:
    st.sidebar.button(label=question, on_click=set_query, args=[question])

def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

st.sidebar.button('Clear Chat History', on_click=clear_chat_history)


def generate_fusechat_response():
    # Rebuild the full conversation (system prompt + chat history) for the model.
    conversations = [{"role": "system", "content": "You are FuseChat-3.0, created by Sun Yat-sen University. You are a helpful assistant."}]
    for dict_message in st.session_state.messages:
        if dict_message["role"] == "user":
            conversations.append({"role": "user", "content": dict_message["content"]})
        else:
            conversations.append({"role": "assistant", "content": dict_message["content"]})
    string_dialogue = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=True)
    # `model.device` avoids hard-coding 'cuda' when device_map="auto" places the model.
    input_ids = tokenizer(string_dialogue, return_tensors="pt").input_ids.to(model.device)
    # A generous timeout so a slow first token does not raise queue.Empty.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
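    # model.generate blocks until finished, so it runs in a background thread
    # while the streamer yields decoded text as it is produced.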
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield chunks as they arrive instead of buffering the entire reply,
    # so the chat UI can render the response while it streams.
    for text in streamer:
        yield text

# User-provided prompt
if prompt := st.chat_input("Do androids dream of electric sheep?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Generate a new response if the last message is not from the assistant
if st.session_state.messages and st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            placeholder = st.empty()
            full_response = ''
            # Stream the reply chunk by chunk; the trailing "▌" acts as a
            # typing cursor until generation finishes.
            for chunk in generate_fusechat_response():
                full_response += chunk
                placeholder.markdown(full_response + "▌")
            placeholder.markdown(full_response)
    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)