arborvitae committed on
Commit dec5052
1 Parent(s): 22d90fb

Create model.py

Files changed (1)
  1. model.py +63 -0
model.py ADDED
@@ -0,0 +1,63 @@
+ from typing import Iterator
+ from llama_cpp import Llama
+ from huggingface_hub import hf_hub_download
+
+
+ def download_model():
+     # See https://github.com/OpenAccess-AI-Collective/ggml-webui/blob/main/tabbed.py
+     # https://huggingface.co/spaces/kat33/llama.cpp/blob/main/app.py
+     print(f"Downloading model: {model_repo}/{model_filename}")
+     file = hf_hub_download(
+         repo_id=model_repo, filename=model_filename
+     )
+     print("Downloaded " + file)
+     return file
+
+ model_repo = "TheBloke/CodeLlama-7B-Instruct-GGUF"
+ model_filename = "codellama-7b-instruct.Q4_K_S.gguf"
+
+ model_path = download_model()
+
+ # load Llama-2
+ llm = Llama(model_path=model_path, n_ctx=4000, verbose=False)
+
+
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
+                system_prompt: str) -> str:
+     texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+     for user_input, response in chat_history:
+         texts.append(f'{user_input.strip()} [/INST] {response.strip()} </s><s> [INST] ')
+     texts.append(f'{message.strip()} [/INST]')
+     return ''.join(texts)
+
+ def generate(prompt, max_new_tokens, temperature, top_p, top_k):
+     return llm(prompt,
+                max_tokens=max_new_tokens,
+                stop=["</s>"],
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                stream=False)
+
+
+ def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
+     prompt = get_prompt(message, chat_history, system_prompt)
+     input_ids = llm.tokenize(prompt.encode('utf-8'))
+     return len(input_ids)
+
+
+ def run(message: str,
+         chat_history: list[tuple[str, str]],
+         system_prompt: str,
+         max_new_tokens: int = 1024,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+         top_k: int = 50) -> Iterator[str]:
+     prompt = get_prompt(message, chat_history, system_prompt)
+     output = generate(prompt, max_new_tokens, temperature, top_p, top_k)
+     yield output['choices'][0]['text']
+
+     # outputs = []
+     # for resp in streamer:
+     #     outputs.append(resp['choices'][0]['text'])
+     # yield ''.join(outputs)
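
A minimal usage sketch (not part of this commit), assuming model.py is importable and the GGUF file downloads successfully; the system prompt, message, and token budget below are placeholders:

    # Hypothetical caller for the run() generator defined in model.py above.
    from model import run

    system_prompt = "You are a helpful coding assistant."  # placeholder prompt
    chat_history = []  # no previous (user, assistant) turns yet

    # With stream=False in generate(), run() yields the full completion once.
    for text in run("Write a Python function that reverses a string.",
                    chat_history, system_prompt, max_new_tokens=256):
        print(text)

run() is written as a generator, which lines up with the commented-out streaming branch at the end of the file; with stream=False it simply yields the completed text in a single step.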