hayas committed on
Commit
ec4e41c
1 Parent(s): 2bb99ea
Files changed (6)
  1. .pre-commit-config.yaml +55 -0
  2. .vscode/settings.json +21 -0
  3. README.md +6 -4
  4. app.py +144 -0
  5. requirements.txt +9 -0
  6. style.css +16 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,55 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.6.1
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          ["types-python-slugify", "types-requests", "types-PyYAML"]
+  - repo: https://github.com/psf/black
+    rev: 23.10.1
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.6.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.7.0
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
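
Aside (illustrative, not part of the commit): these hooks are normally driven by the `pre-commit` CLI. A minimal sketch of exercising this config, assuming the `pre-commit` package is installed in the environment:

```python
# Illustrative only: register the git hook, then run every configured hook
# against the whole tree once. Assumes `pip install pre-commit` has been run.
import subprocess

subprocess.run(["pre-commit", "install"], check=True)  # set up the git pre-commit hook
subprocess.run(["pre-commit", "run", "--all-files"], check=True)  # run all hooks now
```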
.vscode/settings.json ADDED
@@ -0,0 +1,21 @@
+{
+  "[python]": {
+    "editor.defaultFormatter": "ms-python.black-formatter",
+    "editor.formatOnType": true,
+    "editor.codeActionsOnSave": {
+      "source.organizeImports": true
+    }
+  },
+  "black-formatter.args": [
+    "--line-length=119"
+  ],
+  "isort.args": ["--profile", "black"],
+  "flake8.args": [
+    "--max-line-length=119"
+  ],
+  "ruff.args": [
+    "--line-length=119"
+  ],
+  "editor.formatOnSave": true,
+  "files.insertFinalNewline": true
+}
README.md CHANGED
@@ -1,12 +1,14 @@
 ---
-title: Rinna Youri 7b
-emoji: 🦀
-colorFrom: indigo
-colorTo: gray
+title: rinna Youri-7B
+emoji:
+colorFrom: red
+colorTo: purple
 sdk: gradio
 sdk_version: 3.50.2
 app_file: app.py
 pinned: false
+license: mit
+suggested-hardware: t4-small
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+
+import os
+from threading import Thread
+from typing import Iterator
+
+import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+DESCRIPTION = "# Youri-7B"
+
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+if torch.cuda.is_available():
+    model_id = "rinna/youri-7b-chat"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+def apply_chat_template(conversation: list[dict[str, str]]) -> str:
+    prompt = "\n".join([f"{c['role']}: {c['content']}" for c in conversation])
+    prompt = f"{prompt}\nシステム: "
+    return prompt
+
+
+@spaces.GPU
+@torch.inference_mode()
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str = "",
+    max_new_tokens: int = 1024,
+    temperature: float = 0.7,
+    top_p: float = 0.95,
+    top_k: int = 50,
+    repetition_penalty: float = 1.0,
+) -> Iterator[str]:
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "設定", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "ユーザー", "content": user}, {"role": "システム", "content": assistant}])
+    conversation.append({"role": "ユーザー", "content": message})
+
+    prompt = apply_chat_template(conversation)
+    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, layout="panel", height=600),
+    additional_inputs_accordion_name="詳細設定",
+    additional_inputs=[
+        gr.Textbox(
+            label="System prompt",
+            lines=6,
+        ),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.7,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.95,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.0,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["東京の観光名所を教えて。"],
+        ["落武者って何?"],
+        ["暴れん坊将軍って誰のこと?"],
+        ["人がヘリを食べるのにかかる時間は?"],
+    ],
+)
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+    )
+    chat_interface.render()
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
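
Aside (illustrative, not part of the commit): `apply_chat_template` builds a plain "role: content" transcript and appends "システム: " so the model continues as the assistant. A minimal sketch of the resulting prompt, using one of the demo's example questions and a hypothetical system prompt:

```python
# Illustrative only: reproduces the string apply_chat_template builds.
# Role names follow app.py: 設定 (system), ユーザー (user), システム (assistant).
conversation = [
    {"role": "設定", "content": "あなたは親切なアシスタントです。"},  # hypothetical system prompt
    {"role": "ユーザー", "content": "東京の観光名所を教えて。"},  # example from the demo
]
prompt = "\n".join(f"{c['role']}: {c['content']}" for c in conversation) + "\nシステム: "
print(prompt)
# 設定: あなたは親切なアシスタントです。
# ユーザー: 東京の観光名所を教えて。
# システム:
```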
requirements.txt ADDED
@@ -0,0 +1,9 @@
+accelerate==0.24.1
+bitsandbytes==0.41.1
+gradio==3.50.2
+protobuf==3.20.3
+scipy==1.11.3
+sentencepiece==0.1.99
+spaces==0.18.0
+torch==2.0.0
+transformers==4.34.1
style.css ADDED
@@ -0,0 +1,16 @@
+h1 {
+  text-align: center;
+}
+
+#duplicate-button {
+  margin: auto;
+  color: white;
+  background: #1565c0;
+  border-radius: 100vh;
+}
+
+.contain {
+  max-width: 900px;
+  margin: auto;
+  padding-top: 1.5rem;
+}