import streamlit as st
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer
)


# App title
st.set_page_config(page_title="😶‍🌫️ FuseChat Model")

root_path = "FuseAI"
model_name = "FuseChat-7B-VaRM"

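# Cache the model and tokenizer across Streamlit reruns so they are loaded
# once per process rather than on every user interaction.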
@st.cache_resource
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(
        f"{root_path}/{model_name}",
        trust_remote_code=True,
    )

    # Some tokenizers ship without a pad token; fall back to the EOS token
    # (or 0) so padding during generation does not fail.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0

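    # Load in 8-bit via bitsandbytes to fit the 7B model on a single GPU;
    # device_map="auto" lets accelerate place the weights automatically.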
    model = AutoModelForCausalLM.from_pretrained(
        f"{root_path}/{model_name}",
        device_map="auto",
        load_in_8bit=True,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    model.eval()
    return model, tokenizer


with st.sidebar:
    st.title('😶‍🌫️ FuseChat')
    st.write('This chatbot is powered by FuseChat, a model developed by FuseAI.')
    temperature = st.slider('temperature', min_value=0.01, max_value=5.0, value=0.1, step=0.01)
    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    top_k = st.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
    repetition_penalty = st.slider('repetition penalty', min_value=1.0, max_value=2.0, value=1.2, step=0.05)
    max_length = st.slider('max new tokens', min_value=32, max_value=2000, value=240, step=8)

with st.spinner('Loading model...'):
    model, tokenizer = load_model(model_name)

# Store LLM generated responses
if "messages" not in st.session_state.keys():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

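# Sidebar button to reset the conversation to the initial greeting.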
def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
st.sidebar.button('Clear Chat History', on_click=clear_chat_history)


def generate_fusechat_response():
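    # FuseChat-7B-VaRM follows the OpenChat-style "GPT4 Correct" chat
    # template: each turn is role-prefixed and ends with <|end_of_turn|>.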
    # string_dialogue = "You are a helpful and harmless assistant."
    string_dialogue = ""
    for dict_message in st.session_state.messages:
        if dict_message["role"] == "user":
            string_dialogue += "GPT4 Correct User: " + dict_message["content"] + "<|end_of_turn|>"
        else:
            string_dialogue += "GPT4 Correct Assistant: " + dict_message["content"] + "<|end_of_turn|>"

    input_ids = tokenizer(
        f"{string_dialogue}GPT4 Correct Assistant: ", return_tensors="pt"
    ).input_ids.to(model.device)
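    # TextIteratorStreamer yields decoded text as tokens are generated;
    # skip_prompt keeps the echoed prompt out of the stream.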
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
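    # Run generation in a background thread so the streamer can be consumed
    # on the main thread while tokens are still being produced.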
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield decoded chunks as they arrive instead of blocking until the
    # whole completion is ready; the caller streams them into the UI.
    for text in streamer:
        yield text

# User-provided prompt
if prompt := st.chat_input("Hello there! How are you doing?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Generate a new response if the last message is not from the assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            placeholder = st.empty()
            full_response = ''
            # Render the reply incrementally as chunks stream in.
            for chunk in generate_fusechat_response():
                full_response += chunk
                placeholder.markdown(full_response + "▌")
            placeholder.markdown(full_response)
    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)