jdoexbox360 commited on
Commit
abcb0d0
0 Parent(s):

Duplicate from jdoexbox360/chagpt-2-convo-finally

Browse files
Files changed (4) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +100 -0
  4. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stablelm Tuned Alpha Chat
3
+ emoji: 👀
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: jdoexbox360/chagpt-2-convo-finally
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import tensorflow as tf
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, GPT2LMHeadModel, GPT2Tokenizer
5
+ import time
6
+ import numpy as np
7
+ from torch.nn import functional as F
8
+ import os
9
+ from threading import Thread
10
+
11
+ print(f"Starting to load the model to memory")
12
+
13
+ tok = GPT2Tokenizer.from_pretrained("ethzanalytics/ai-msgbot-gpt2-XL-dialogue")
14
+ m = GPT2LMHeadModel.from_pretrained("ethzanalytics/ai-msgbot-gpt2-XL-dialogue", pad_token_id=tok.eos_token_id)
15
+ generator = pipeline('text-generation', model=m, tokenizer=tok)
16
+ print(f"Sucessfully loaded the model to the memory")
17
+
18
+ start_message = """You are an AI called assistant."""
19
+
20
+
21
+ class StopOnTokens(StoppingCriteria):
22
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
23
+ stop_ids = [50278, 50279, 50277, 1, 0]
24
+ for stop_id in stop_ids:
25
+ if input_ids[0][-1] == stop_id:
26
+ return True
27
+ return False
28
+
29
+
30
+ def user(message, history):
31
+ # Append the user's message to the conversation history
32
+ return "", history + [[message, ""]]
33
+
34
+
35
+ def chat(curr_system_message, history):
36
+ # Initialize a StopOnTokens object
37
+ stop = StopOnTokens()
38
+
39
+ # Construct the input message string for the model by concatenating the current system message and conversation history
40
+ messages = curr_system_message + \
41
+ "".join(["".join(["\nperson alpha:"+item[0], "\nperson beta:"+item[1]])
42
+ for item in history])
43
+
44
+ # Tokenize the messages string
45
+ model_inputs = tok([messages], return_tensors="pt")
46
+ streamer = TextIteratorStreamer(
47
+ tok, skip_prompt=True, skip_special_tokens=True)
48
+ generate_kwargs = dict(
49
+ model_inputs,
50
+ streamer=streamer,
51
+ max_new_tokens=1024,
52
+ do_sample=True,
53
+ top_p=0.95,
54
+ top_k=1000,
55
+ temperature=1.0,
56
+ num_beams=1,
57
+ stopping_criteria=StoppingCriteriaList([stop])
58
+ )
59
+ t = Thread(target=m.generate, kwargs=generate_kwargs)
60
+ t.start()
61
+
62
+ # print(history)
63
+ # Initialize an empty string to store the generated text
64
+ partial_text = ""
65
+ for new_text in streamer:
66
+ # print(new_text)
67
+ partial_text += new_text
68
+ history[-1][1] = partial_text
69
+ # Yield an empty string to cleanup the message textbox and the updated conversation history
70
+ yield history
71
+ return partial_text
72
+
73
+
74
+ with gr.Blocks() as demo:
75
+ # history = gr.State([])
76
+ gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
77
+ gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
78
+ chatbot = gr.Chatbot().style(height=500)
79
+ with gr.Row():
80
+ with gr.Column():
81
+ msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
82
+ show_label=False).style(container=False)
83
+ with gr.Column():
84
+ with gr.Row():
85
+ submit = gr.Button("Submit")
86
+ stop = gr.Button("Stop")
87
+ clear = gr.Button("Clear")
88
+ system_msg = gr.Textbox(
89
+ start_message, label="System Message", interactive=False, visible=False)
90
+
91
+ submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
92
+ fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
93
+ submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
94
+ fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
95
+ stop.click(fn=None, inputs=None, outputs=None, cancels=[
96
+ submit_event, submit_click_event], queue=False)
97
+ clear.click(lambda: None, None, [chatbot], queue=False)
98
+
99
+ demo.queue(max_size=32, concurrency_count=2)
100
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ tensorflow
3
+ torch
4
+ torchvision
5
+ torchaudio
6
+ transformers
7
+ numpy
8
+