thobuiq commited on
Commit
d15855c
1 Parent(s): 0610b0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -63
app.py CHANGED
@@ -1,66 +1,35 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
5
- from threading import Thread
6
-
7
- # Loading the tokenizer and model from Hugging Face's model hub.
8
- tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
9
- model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
10
-
11
- # using CUDA for an optimal experience
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
- model = model.to(device)
14
-
15
-
16
- # Defining a custom stopping criteria class for the model's text generation.
17
- class StopOnTokens(StoppingCriteria):
18
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
19
- stop_ids = [2] # IDs of tokens where the generation should stop.
20
- for stop_id in stop_ids:
21
- if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
22
- return True
23
- return False
24
-
25
-
26
-
27
- # Function to generate model predictions.
28
- def predict(message, history):
29
- history_transformer_format = history + [[message, ""]]
30
- stop = StopOnTokens()
31
-
32
-
33
- # Formatting the input for the model.
34
- messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
35
- for item in history_transformer_format])
36
- model_inputs = tokenizer([messages], return_tensors="pt").to(device)
37
- streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
38
- generate_kwargs = dict(
39
- model_inputs,
40
- streamer=streamer,
41
- max_new_tokens=1024,
42
- do_sample=True,
43
- top_p=0.95,
44
- top_k=50,
45
- temperature=0.7,
46
- num_beams=1,
47
- stopping_criteria=StoppingCriteriaList([stop])
48
  )
49
- t = Thread(target=model.generate, kwargs=generate_kwargs)
50
- t.start() # Starting the generation in a separate thread.
51
- partial_message = ""
52
- for new_token in streamer:
53
- partial_message += new_token
54
- if '</s>' in partial_message: # Breaking the loop if the stop token is generated.
55
- break
56
- yield partial_message
57
-
58
-
59
 
 
 
 
60
 
61
- # Setting up the Gradio chat interface.
62
- gr.ChatInterface(predict,
63
- title="Tinyllama_chatBot",
64
- description="Ask Tiny llama any questions",
65
- examples=['How to cook a fish?', 'Who is the president of US now?']
66
- ).launch() # Launching the web interface.
 
1
+ import os
2
+ import chainlit as cl
3
+ from ctransformers import AutoModelForCausalLM
4
+
5
+ # Runs when the chat starts
6
+ @cl.on_chat_start
7
+ def main():
8
+ # Create the llm
9
+ llm = AutoModelForCausalLM.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
10
+ model_file="mistral-7b-instruct-v0.1.Q4_K_M.gguf",
11
+ model_type="mistral",
12
+ temperature=0.7,
13
+ gpu_layers=0,
14
+ stream=True,
15
+ threads=int(os.cpu_count() / 2),
16
+ max_new_tokens=10000)
17
+
18
+ # Store the llm in the user session
19
+ cl.user_session.set("llm", llm)
20
+
21
+ # Runs when a message is sent
22
+ @cl.on_message
23
+ async def main(message: cl.Message):
24
+ # Retrieve the chain from the user session
25
+ llm = cl.user_session.get("llm")
26
+
27
+ msg = cl.Message(
28
+ content="",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  )
 
 
 
 
 
 
 
 
 
 
30
 
31
+ prompt = f"[INST]{message.content}[/INST]"
32
+ for text in llm(prompt=prompt):
33
+ await msg.stream_token(text)
34
 
35
+ await msg.send()