O S I H commited on
Commit
6fcc1b4
1 Parent(s): d64d25a

uploadfiles

Browse files
Files changed (2) hide show
  1. app.py +116 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ TextIteratorStreamer,
7
+ )
8
+ import os
9
+ from threading import Thread
10
+ import spaces
11
+ import time
12
+ import subprocess
13
+
14
+ subprocess.run(
15
+ "pip install flash-attn --no-build-isolation",
16
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
17
+ shell=True,
18
+ )
19
+
20
+
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ "microsoft/Phi-3-medium-4k-instruct",
23
+ torch_dtype="auto",
24
+ trust_remote_code=True,
25
+ )
26
+ tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-4k-instruct",trust_remote_code=True,)
27
+ terminators = [
28
+ tok.eos_token_id,
29
+ ]
30
+
31
+ if torch.cuda.is_available():
32
+ device = torch.device("cuda")
33
+ print(f"Using GPU: {torch.cuda.get_device_name(device)}")
34
+ else:
35
+ device = torch.device("cpu")
36
+ print("Using CPU")
37
+
38
+ model = model.to(device)
39
+
40
+
41
+ @spaces.GPU(duration=60)
42
+ def chat(message, history,system_prompt, temperature, do_sample, max_tokens, top_k, repetition_penalty, top_p):
43
+ chat = [
44
+ {"role": "assistant", "content": system_prompt}
45
+ ]
46
+ for item in history:
47
+ chat.append({"role": "user", "content": item[0]})
48
+ if item[1] is not None:
49
+ chat.append({"role": "assistant", "content": item[1]})
50
+ chat.append({"role": "user", "content": message})
51
+ messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
52
+ model_inputs = tok([messages], return_tensors="pt").to(device)
53
+ streamer = TextIteratorStreamer(
54
+ tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True
55
+ )
56
+ generate_kwargs = dict(
57
+ model_inputs,
58
+ streamer=streamer,
59
+ max_new_tokens=max_tokens,
60
+ do_sample=True,
61
+ temperature=temperature,
62
+ eos_token_id=terminators,
63
+ top_k=top_k,
64
+ repetition_penalty=repetition_penalty,
65
+ top_p=top_p
66
+ )
67
+
68
+ if temperature == 0:
69
+ generate_kwargs["do_sample"] = False
70
+
71
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
72
+ t.start()
73
+
74
+ partial_text = ""
75
+ for new_text in streamer:
76
+ partial_text += new_text
77
+ yield partial_text
78
+
79
+ yield partial_text
80
+
81
+
82
+ demo = gr.ChatInterface(
83
+ fn=chat,
84
+ examples=[["Write me a poem about Machine Learning."],
85
+ ["write fibonacci sequence in python"],
86
+ ["who won the world cup in 2018?"],
87
+ ["when was the first computer invented?"],
88
+ ],
89
+ additional_inputs_accordion=gr.Accordion(
90
+ label="⚙️ Parameters", open=False, render=False
91
+ ),
92
+ additional_inputs=[
93
+ gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
94
+ gr.Slider(
95
+ minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
96
+ ),
97
+ gr.Checkbox(label="Sampling", value=True),
98
+ gr.Slider(
99
+ minimum=128,
100
+ maximum=4096,
101
+ step=1,
102
+ value=512,
103
+ label="Max new tokens",
104
+ render=False,
105
+ ),
106
+ gr.Slider(1, 80, 40, label="Top K sampling"),
107
+ gr.Slider(0, 2, 1.1, label="Repetition penalty"),
108
+ gr.Slider(0, 1, 0.95, label="Top P sampling"),
109
+ ],
110
+ stop_btn="Stop Generation",
111
+ title="Chat With Phi-3-medium-4k-instruct",
112
+ description="[microsoft/Phi-3-medium-4k-instruct](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct)",
113
+ css="footer {visibility: hidden}",
114
+ theme="NoCrypt/miku@1.2.1",
115
+ )
116
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ tiktoken
2
+ gradio
3
+ spaces
4
+ torch==2.2.0
5
+ git+https://github.com/huggingface/transformers/
6
+ optimum
7
+ accelerate
8
+ bitsandbytes
9
+ pytest