Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .env +19 -0
- .gitattributes +1 -0
- .github/workflows/branch.yml +60 -0
- .github/workflows/release.yml +30 -0
- .gitignore +10 -0
- CONTRIBUTING.md +90 -0
- LICENSE +21 -0
- README.md +379 -8
- app.py +433 -0
- benchmark.py +145 -0
- code_completion.py +216 -0
- colab/Llama_2_7b_Chat_GPTQ.ipynb +0 -0
- colab/ggmlv3_q4_0.ipynb +109 -0
- colab/webui_CodeLlama_7B_Instruct_GPTQ.ipynb +514 -0
- docs/issues.md +0 -0
- docs/news.md +38 -0
- docs/performance.md +32 -0
- docs/pypi.md +187 -0
- env_examples/.env.13b_example +13 -0
- env_examples/.env.7b_8bit_example +13 -0
- env_examples/.env.7b_ggmlv3_q4_0_example +18 -0
- env_examples/.env.7b_gptq_example +18 -0
- llama2_cu_python/Makefile +9 -0
- llama2_cu_python/__init__.py +3 -0
- llama2_cu_python/libllama2.so +3 -0
- llama2_cu_python/llama2.cu +1394 -0
- llama2_cu_python/llama2.h +23 -0
- llama2_cu_python/llama2_cu.py +151 -0
- llama2_wrapper/__init__.py +1 -0
- llama2_wrapper/download/__init__.py +0 -0
- llama2_wrapper/download/__main__.py +59 -0
- llama2_wrapper/model.py +839 -0
- llama2_wrapper/server/__init__.py +0 -0
- llama2_wrapper/server/__main__.py +46 -0
- llama2_wrapper/server/app.py +526 -0
- llama2_wrapper/types.py +115 -0
- poetry.lock +0 -0
- prompts/prompts_en.csv +0 -0
- prompts/utils.py +48 -0
- pyproject.toml +47 -0
- requirements.txt +21 -0
- static/screenshot.png +0 -0
- tests/__init__.py +0 -0
- tests/test_get_prompt.py +59 -0
.env
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL_PATH = ""
|
2 |
+
# if MODEL_PATH is "", default llama.cpp/gptq models
|
3 |
+
# will be downloaded to: ./models
|
4 |
+
|
5 |
+
# Example ggml path:
|
6 |
+
MODEL_PATH = "models/llama2_7b_chat.bin"
|
7 |
+
|
8 |
+
# options: llama.cpp, gptq, transformers
|
9 |
+
#BACKEND_TYPE = "llama.cpp"
|
10 |
+
BACKEND_TYPE = "llama2.cu"
|
11 |
+
|
12 |
+
# only for transformers bitsandbytes 8 bit
|
13 |
+
LOAD_IN_8BIT = False
|
14 |
+
|
15 |
+
MAX_MAX_NEW_TOKENS = 2048
|
16 |
+
DEFAULT_MAX_NEW_TOKENS = 1024
|
17 |
+
MAX_INPUT_TOKEN_LENGTH = 4000
|
18 |
+
|
19 |
+
DEFAULT_SYSTEM_PROMPT = ""
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
llama2_cu_python/libllama2.so filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/branch.yml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Push
|
2 |
+
on: [push]
|
3 |
+
|
4 |
+
jobs:
|
5 |
+
test:
|
6 |
+
strategy:
|
7 |
+
fail-fast: false
|
8 |
+
matrix:
|
9 |
+
python-version: ['3.10']
|
10 |
+
poetry-version: ['1.5.1']
|
11 |
+
os: [ubuntu-latest]
|
12 |
+
runs-on: ${{ matrix.os }}
|
13 |
+
steps:
|
14 |
+
- uses: actions/checkout@v3
|
15 |
+
- uses: actions/setup-python@v3
|
16 |
+
with:
|
17 |
+
python-version: ${{ matrix.python-version }}
|
18 |
+
- name: Run image
|
19 |
+
uses: abatilo/actions-poetry@v2.1.4
|
20 |
+
with:
|
21 |
+
poetry-version: ${{ matrix.poetry-version }}
|
22 |
+
- name: Install dependencies
|
23 |
+
run: poetry install
|
24 |
+
- name: Run tests
|
25 |
+
run: poetry run pytest
|
26 |
+
- name: Upload coverage reports to Codecov
|
27 |
+
uses: codecov/codecov-action@v3
|
28 |
+
env:
|
29 |
+
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
30 |
+
# - name: Upload coverage to Codecov
|
31 |
+
# uses: codecov/codecov-action@v2
|
32 |
+
code-quality:
|
33 |
+
strategy:
|
34 |
+
fail-fast: false
|
35 |
+
matrix:
|
36 |
+
python-version: ['3.10']
|
37 |
+
poetry-version: ['1.5.1']
|
38 |
+
os: [ubuntu-latest]
|
39 |
+
runs-on: ${{ matrix.os }}
|
40 |
+
steps:
|
41 |
+
- uses: actions/checkout@v3
|
42 |
+
- uses: actions/setup-python@v3
|
43 |
+
with:
|
44 |
+
python-version: ${{ matrix.python-version }}
|
45 |
+
- name: Python Poetry Action
|
46 |
+
uses: abatilo/actions-poetry@v2.1.6
|
47 |
+
with:
|
48 |
+
poetry-version: ${{ matrix.poetry-version }}
|
49 |
+
- name: Install dependencies
|
50 |
+
run: poetry install
|
51 |
+
- name: Run black
|
52 |
+
run: poetry run black . --check
|
53 |
+
# - name: Run isort
|
54 |
+
# run: poetry run isort . --check-only --profile black
|
55 |
+
# - name: Run flake8
|
56 |
+
# run: poetry run flake8 .
|
57 |
+
# - name: Run bandit
|
58 |
+
# run: poetry run bandit .
|
59 |
+
# - name: Run saftey
|
60 |
+
# run: poetry run safety check
|
.github/workflows/release.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release
|
2 |
+
on:
|
3 |
+
release:
|
4 |
+
types:
|
5 |
+
- created
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
publish:
|
9 |
+
strategy:
|
10 |
+
fail-fast: false
|
11 |
+
matrix:
|
12 |
+
python-version: ['3.10']
|
13 |
+
poetry-version: ['1.5.1']
|
14 |
+
os: [ubuntu-latest]
|
15 |
+
runs-on: ${{ matrix.os }}
|
16 |
+
steps:
|
17 |
+
- uses: actions/checkout@v3
|
18 |
+
- uses: actions/setup-python@v3
|
19 |
+
with:
|
20 |
+
python-version: ${{ matrix.python-version }}
|
21 |
+
- name: Run image
|
22 |
+
uses: abatilo/actions-poetry@v2.1.4
|
23 |
+
with:
|
24 |
+
poetry-version: ${{ matrix.poetry-version }}
|
25 |
+
- name: Publish
|
26 |
+
env:
|
27 |
+
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
28 |
+
run: |
|
29 |
+
poetry config pypi-token.pypi $PYPI_TOKEN
|
30 |
+
poetry publish --build
|
.gitignore
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
models
|
2 |
+
dist
|
3 |
+
|
4 |
+
.DS_Store
|
5 |
+
.vscode
|
6 |
+
|
7 |
+
__pycache__
|
8 |
+
gradio_cached_examples
|
9 |
+
|
10 |
+
.pytest_cache
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to [llama2-webui](https://github.com/liltom-eth/llama2-webui)
|
2 |
+
|
3 |
+
We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's:
|
4 |
+
|
5 |
+
- Reporting a bug
|
6 |
+
- Proposing new features
|
7 |
+
- Discussing the current state of the code
|
8 |
+
- Update README.md
|
9 |
+
- Submitting a PR
|
10 |
+
|
11 |
+
## Using GitHub's [issues](https://github.com/liltom-eth/llama2-webui/issues)
|
12 |
+
|
13 |
+
We use GitHub issues to track public bugs. Report a bug by [opening a new issue](https://github.com/liltom-eth/llama2-webui/issues). It's that easy!
|
14 |
+
|
15 |
+
Thanks for **[jlb1504](https://github.com/jlb1504)** for reporting the [first issue](https://github.com/liltom-eth/llama2-webui/issues/1)!
|
16 |
+
|
17 |
+
**Great Bug Reports** tend to have:
|
18 |
+
|
19 |
+
- A quick summary and/or background
|
20 |
+
- Steps to reproduce
|
21 |
+
- Be specific!
|
22 |
+
- Give a sample code if you can.
|
23 |
+
- What you expected would happen
|
24 |
+
- What actually happens
|
25 |
+
- Notes (possibly including why you think this might be happening, or stuff you tried that didn't work)
|
26 |
+
|
27 |
+
Proposing new features are also welcome.
|
28 |
+
|
29 |
+
## Pull Request
|
30 |
+
|
31 |
+
All pull requests are welcome. For example, you update the `README.md` to help users to better understand the usage.
|
32 |
+
|
33 |
+
### Clone the repository
|
34 |
+
|
35 |
+
1. Create a user account on GitHub if you do not already have one.
|
36 |
+
|
37 |
+
2. Fork the project [repository](https://github.com/liltom-eth/llama2-webui): click on the *Fork* button near the top of the page. This creates a copy of the code under your account on GitHub.
|
38 |
+
|
39 |
+
3. Clone this copy to your local disk:
|
40 |
+
|
41 |
+
```
|
42 |
+
git clone git@github.com:liltom-eth/llama2-webui.git
|
43 |
+
cd llama2-webui
|
44 |
+
```
|
45 |
+
|
46 |
+
### Implement your changes
|
47 |
+
|
48 |
+
1. Create a branch to hold your changes:
|
49 |
+
|
50 |
+
```
|
51 |
+
git checkout -b my-feature
|
52 |
+
```
|
53 |
+
|
54 |
+
and start making changes. Never work on the main branch!
|
55 |
+
|
56 |
+
2. Start your work on this branch.
|
57 |
+
|
58 |
+
3. When you’re done editing, do:
|
59 |
+
|
60 |
+
```
|
61 |
+
git add <MODIFIED FILES>
|
62 |
+
git commit
|
63 |
+
```
|
64 |
+
|
65 |
+
to record your changes in [git](https://git-scm.com/).
|
66 |
+
|
67 |
+
### Submit your contribution
|
68 |
+
|
69 |
+
1. If everything works fine, push your local branch to the remote server with:
|
70 |
+
|
71 |
+
```
|
72 |
+
git push -u origin my-feature
|
73 |
+
```
|
74 |
+
|
75 |
+
2. Go to the web page of your fork and click "Create pull request" to send your changes for review.
|
76 |
+
|
77 |
+
```{todo}
|
78 |
+
Find more detailed information in [creating a PR]. You might also want to open
|
79 |
+
the PR as a draft first and mark it as ready for review after the feedbacks
|
80 |
+
from the continuous integration (CI) system or any required fixes.
|
81 |
+
```
|
82 |
+
|
83 |
+
## License
|
84 |
+
|
85 |
+
By contributing, you agree that your contributions will be licensed under its MIT License.
|
86 |
+
|
87 |
+
## Questions?
|
88 |
+
|
89 |
+
Email us at [liltom.eth@gmail.com](mailto:liltom.eth@gmail.com)
|
90 |
+
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Tom
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,383 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji: 😻
|
4 |
-
colorFrom: pink
|
5 |
-
colorTo: indigo
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 4.11.0
|
8 |
app_file: app.py
|
9 |
-
|
|
|
10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: llama2-webui
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 3.37.0
|
6 |
---
|
7 |
+
# llama2-webui
|
8 |
+
|
9 |
+
Running Llama 2 with gradio web UI on GPU or CPU from anywhere (Linux/Windows/Mac).
|
10 |
+
- Supporting all Llama 2 models (7B, 13B, 70B, GPTQ, GGML, GGUF, [CodeLlama](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ)) with 8-bit, 4-bit mode.
|
11 |
+
- Use [llama2-wrapper](https://pypi.org/project/llama2-wrapper/) as your local llama2 backend for Generative Agents/Apps; [colab example](./colab/Llama_2_7b_Chat_GPTQ.ipynb).
|
12 |
+
- [Run OpenAI Compatible API](#start-openai-compatible-api) on Llama2 models.
|
13 |
+
|
14 |
+
![screenshot](./static/screenshot.png)
|
15 |
+
|
16 |
+
![code_llama_playground](https://i.imgur.com/FgMUiT6.gif)
|
17 |
+
|
18 |
+
## Features
|
19 |
+
|
20 |
+
- Supporting models: [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)/[13b](https://huggingface.co/llamaste/Llama-2-13b-chat-hf)/[70b](https://huggingface.co/llamaste/Llama-2-70b-chat-hf), [Llama-2-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ), [Llama-2-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML), [Llama-2-GGUF](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF), [CodeLlama](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) ...
|
21 |
+
- Supporting model backends: [tranformers](https://github.com/huggingface/transformers), [bitsandbytes(8-bit inference)](https://github.com/TimDettmers/bitsandbytes), [AutoGPTQ(4-bit inference)](https://github.com/PanQiWei/AutoGPTQ), [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
22 |
+
- Demos: [Run Llama2 on MacBook Air](https://twitter.com/liltom_eth/status/1682791729207070720?s=20); [Run Llama2 on free Colab T4 GPU](./colab/Llama_2_7b_Chat_GPTQ.ipynb)
|
23 |
+
- Use [llama2-wrapper](https://pypi.org/project/llama2-wrapper/) as your local llama2 backend for Generative Agents/Apps; [colab example](./colab/Llama_2_7b_Chat_GPTQ.ipynb).
|
24 |
+
- [Run OpenAI Compatible API](#start-openai-compatible-api) on Llama2 models.
|
25 |
+
- [News](./docs/news.md), [Benchmark](./docs/performance.md), [Issue Solutions](./docs/issues.md)
|
26 |
+
|
27 |
+
## Contents
|
28 |
+
|
29 |
+
- [Install](#install)
|
30 |
+
- [Usage](#usage)
|
31 |
+
- [Start Chat UI](#start-chat-ui)
|
32 |
+
- [Start Code Llama UI](#start-code-llama-ui)
|
33 |
+
- [Use llama2-wrapper for Your App](#use-llama2-wrapper-for-your-app)
|
34 |
+
- [Start OpenAI Compatible API](#start-openai-compatible-api)
|
35 |
+
- [Benchmark](#benchmark)
|
36 |
+
- [Download Llama-2 Models](#download-llama-2-models)
|
37 |
+
- [Model List](#model-list)
|
38 |
+
- [Download Script](#download-script)
|
39 |
+
- [Tips](#tips)
|
40 |
+
- [Env Examples](#env-examples)
|
41 |
+
- [Run on Nvidia GPU](#run-on-nvidia-gpu)
|
42 |
+
- [Run bitsandbytes 8 bit](#run-bitsandbytes-8-bit)
|
43 |
+
- [Run GPTQ 4 bit](#run-gptq-4-bit)
|
44 |
+
- [Run on CPU](#run-on-cpu)
|
45 |
+
- [Mac Metal Acceleration](#mac-metal-acceleration)
|
46 |
+
- [AMD/Nvidia GPU Acceleration](#amdnvidia-gpu-acceleration)
|
47 |
+
- [License](#license)
|
48 |
+
- [Contributing](#contributing)
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
## Install
|
53 |
+
### Method 1: From [PyPI](https://pypi.org/project/llama2-wrapper/)
|
54 |
+
```
|
55 |
+
pip install llama2-wrapper
|
56 |
+
```
|
57 |
+
The newest `llama2-wrapper>=0.1.14` supports llama.cpp's `gguf` models.
|
58 |
+
|
59 |
+
If you would like to use old `ggml` models, install `llama2-wrapper<=0.1.13` or manually install `llama-cpp-python==0.1.77`.
|
60 |
+
|
61 |
+
### Method 2: From Source:
|
62 |
+
|
63 |
+
```
|
64 |
+
git clone https://github.com/liltom-eth/llama2-webui.git
|
65 |
+
cd llama2-webui
|
66 |
+
pip install -r requirements.txt
|
67 |
+
```
|
68 |
+
### Install Issues:
|
69 |
+
`bitsandbytes >= 0.39` may not work on older NVIDIA GPUs. In that case, to use `LOAD_IN_8BIT`, you may have to downgrade like this:
|
70 |
+
|
71 |
+
- `pip install bitsandbytes==0.38.1`
|
72 |
+
|
73 |
+
`bitsandbytes` also need a special install for Windows:
|
74 |
+
|
75 |
+
```
|
76 |
+
pip uninstall bitsandbytes
|
77 |
+
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.0-py3-none-win_amd64.whl
|
78 |
+
```
|
79 |
+
|
80 |
+
## Usage
|
81 |
+
|
82 |
+
### Start Chat UI
|
83 |
+
|
84 |
+
Run chatbot simply with web UI:
|
85 |
+
|
86 |
+
```bash
|
87 |
+
python app.py
|
88 |
+
```
|
89 |
+
|
90 |
+
`app.py` will load the default config `.env` which uses `llama.cpp` as the backend to run `llama-2-7b-chat.ggmlv3.q4_0.bin` model for inference. The model `llama-2-7b-chat.ggmlv3.q4_0.bin` will be automatically downloaded.
|
91 |
+
|
92 |
+
```bash
|
93 |
+
Running on backend llama.cpp.
|
94 |
+
Use default model path: ./models/llama-2-7b-chat.Q4_0.gguf
|
95 |
+
Start downloading model to: ./models/llama-2-7b-chat.Q4_0.gguf
|
96 |
+
```
|
97 |
+
|
98 |
+
You can also customize your `MODEL_PATH`, `BACKEND_TYPE,` and model configs in `.env` file to run different llama2 models on different backends (llama.cpp, transformers, gptq).
|
99 |
+
|
100 |
+
### Start Code Llama UI
|
101 |
+
|
102 |
+
We provide a code completion / filling UI for Code Llama.
|
103 |
+
|
104 |
+
Base model **Code Llama** and extend model **Code Llama — Python** are not fine-tuned to follow instructions. They should be prompted so that the expected answer is the natural continuation of the prompt. That means these two models focus on code filling and code completion.
|
105 |
+
|
106 |
+
Here is an example run CodeLlama code completion on llama.cpp backend:
|
107 |
+
|
108 |
+
```
|
109 |
+
python code_completion.py --model_path ./models/codellama-7b.Q4_0.gguf
|
110 |
+
```
|
111 |
+
|
112 |
+
![code_llama_playground](https://i.imgur.com/FgMUiT6.gif)
|
113 |
+
|
114 |
+
`codellama-7b.Q4_0.gguf` can be downloaded from [TheBloke/CodeLlama-7B-GGUF](https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/blob/main/codellama-7b.Q4_0.gguf).
|
115 |
+
|
116 |
+
**Code Llama — Instruct** trained with “natural language instruction” inputs paired with anticipated outputs. This strategic methodology enhances the model’s capacity to grasp human expectations in prompts. That means instruct models can be used in a chatbot-like app.
|
117 |
+
|
118 |
+
Example run CodeLlama chat on gptq backend:
|
119 |
+
|
120 |
+
```
|
121 |
+
python app.py --backend_type gptq --model_path ./models/CodeLlama-7B-Instruct-GPTQ/ --share True
|
122 |
+
```
|
123 |
+
|
124 |
+
![code_llama_chat](https://i.imgur.com/lQLfemB.gif)
|
125 |
+
|
126 |
+
`CodeLlama-7B-Instruct-GPTQ` can be downloaded from [TheBloke/CodeLlama-7B-Instruct-GPTQ](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ)
|
127 |
+
|
128 |
+
### Use llama2-wrapper for Your App
|
129 |
+
|
130 |
+
🔥 For developers, we released `llama2-wrapper` as a llama2 backend wrapper in [PYPI](https://pypi.org/project/llama2-wrapper/).
|
131 |
+
|
132 |
+
Use `llama2-wrapper` as your local llama2 backend to answer questions and more, [colab example](./colab/ggmlv3_q4_0.ipynb):
|
133 |
+
|
134 |
+
```python
|
135 |
+
# pip install llama2-wrapper
|
136 |
+
from llama2_wrapper import LLAMA2_WRAPPER, get_prompt
|
137 |
+
llama2_wrapper = LLAMA2_WRAPPER()
|
138 |
+
# Default running on backend llama.cpp.
|
139 |
+
# Automatically downloading model to: ./models/llama-2-7b-chat.ggmlv3.q4_0.bin
|
140 |
+
prompt = "Do you know Pytorch"
|
141 |
+
answer = llama2_wrapper(get_prompt(prompt), temperature=0.9)
|
142 |
+
```
|
143 |
+
|
144 |
+
Run gptq llama2 model on Nvidia GPU, [colab example](./colab/Llama_2_7b_Chat_GPTQ.ipynb):
|
145 |
+
|
146 |
+
```python
|
147 |
+
from llama2_wrapper import LLAMA2_WRAPPER
|
148 |
+
llama2_wrapper = LLAMA2_WRAPPER(backend_type="gptq")
|
149 |
+
# Automatically downloading model to: ./models/Llama-2-7b-Chat-GPTQ
|
150 |
+
```
|
151 |
+
|
152 |
+
Run llama2 7b with bitsandbytes 8 bit with a `model_path`:
|
153 |
+
|
154 |
+
```python
|
155 |
+
from llama2_wrapper import LLAMA2_WRAPPER
|
156 |
+
llama2_wrapper = LLAMA2_WRAPPER(
|
157 |
+
model_path = "./models/Llama-2-7b-chat-hf",
|
158 |
+
backend_type = "transformers",
|
159 |
+
load_in_8bit = True
|
160 |
+
)
|
161 |
+
```
|
162 |
+
Check [API Document](https://pypi.org/project/llama2-wrapper/) for more usages.
|
163 |
+
|
164 |
+
### Start OpenAI Compatible API
|
165 |
+
|
166 |
+
`llama2-wrapper` offers a web server that acts as a drop-in replacement for the OpenAI API. This allows you to use Llama2 models with any OpenAI compatible clients, libraries or services, etc.
|
167 |
+
|
168 |
+
Start Fast API:
|
169 |
+
|
170 |
+
```
|
171 |
+
python -m llama2_wrapper.server
|
172 |
+
```
|
173 |
+
|
174 |
+
it will use `llama.cpp` as the backend by default to run `llama-2-7b-chat.ggmlv3.q4_0.bin` model.
|
175 |
+
|
176 |
+
Start Fast API for `gptq` backend:
|
177 |
+
|
178 |
+
```
|
179 |
+
python -m llama2_wrapper.server --backend_type gptq
|
180 |
+
```
|
181 |
+
|
182 |
+
Navigate to http://localhost:8000/docs to see the OpenAPI documentation.
|
183 |
+
|
184 |
+
#### Basic settings
|
185 |
+
|
186 |
+
| Flag | Description |
|
187 |
+
| ---------------- | ------------------------------------------------------------ |
|
188 |
+
| `-h`, `--help` | Show this help message. |
|
189 |
+
| `--model_path` | The path to the model to use for generating completions. |
|
190 |
+
| `--backend_type` | Backend for llama2, options: llama.cpp, gptq, transformers |
|
191 |
+
| `--max_tokens` | Maximum context size. |
|
192 |
+
| `--load_in_8bit` | Whether to use bitsandbytes to run model in 8 bit mode (only for transformers models). |
|
193 |
+
| `--verbose` | Whether to print verbose output to stderr. |
|
194 |
+
| `--host` | API address |
|
195 |
+
| `--port` | API port |
|
196 |
+
|
197 |
+
## Benchmark
|
198 |
+
|
199 |
+
Run benchmark script to compute performance on your device, `benchmark.py` will load the same `.env` as `app.py`.:
|
200 |
+
|
201 |
+
```bash
|
202 |
+
python benchmark.py
|
203 |
+
```
|
204 |
+
|
205 |
+
You can also select the `iter`, `backend_type` and `model_path` the benchmark will be run (overwrite .env args) :
|
206 |
+
|
207 |
+
```bash
|
208 |
+
python benchmark.py --iter NB_OF_ITERATIONS --backend_type gptq
|
209 |
+
```
|
210 |
+
|
211 |
+
By default, the number of iterations is 5, but if you want a faster result or a more accurate one
|
212 |
+
you can set it to whatever value you want, but please only report results with at least 5 iterations.
|
213 |
+
|
214 |
+
This [colab example](./colab/Llama_2_7b_Chat_GPTQ.ipynb) also show you how to benchmark gptq model on free Google Colab T4 GPU.
|
215 |
+
|
216 |
+
Some benchmark performance:
|
217 |
+
|
218 |
+
| Model | Precision | Device | RAM / GPU VRAM | Speed (tokens/sec) | load time (s) |
|
219 |
+
| --------------------------- | --------- | ------------------ | -------------- | ------------------ | ------------- |
|
220 |
+
| Llama-2-7b-chat-hf | 8 bit | NVIDIA RTX 2080 Ti | 7.7 GB VRAM | 3.76 | 641.36 |
|
221 |
+
| Llama-2-7b-Chat-GPTQ | 4 bit | NVIDIA RTX 2080 Ti | 5.8 GB VRAM | 18.85 | 192.91 |
|
222 |
+
| Llama-2-7b-Chat-GPTQ | 4 bit | Google Colab T4 | 5.8 GB VRAM | 18.19 | 37.44 |
|
223 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Apple M1 Pro CPU | 5.4 GB RAM | 17.90 | 0.18 |
|
224 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Apple M2 CPU | 5.4 GB RAM | 13.70 | 0.13 |
|
225 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Apple M2 Metal | 5.4 GB RAM | 12.60 | 0.10 |
|
226 |
+
| llama-2-7b-chat.ggmlv3.q2_K | 2 bit | Intel i7-8700 | 4.5 GB RAM | 7.88 | 31.90 |
|
227 |
+
|
228 |
+
Check/contribute the performance of your device in the full [performance doc](./docs/performance.md).
|
229 |
+
|
230 |
+
## Download Llama-2 Models
|
231 |
+
|
232 |
+
Llama 2 is a collection of pre-trained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters.
|
233 |
+
|
234 |
+
Llama-2-7b-Chat-GPTQ is the GPTQ model files for [Meta's Llama 2 7b Chat](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). GPTQ 4-bit Llama-2 model require less GPU VRAM to run it.
|
235 |
+
|
236 |
+
### Model List
|
237 |
+
|
238 |
+
| Model Name | set MODEL_PATH in .env | Download URL |
|
239 |
+
| ----------------------------------- | ---------------------------------------- | ------------------------------------------------------------ |
|
240 |
+
| meta-llama/Llama-2-7b-chat-hf | /path-to/Llama-2-7b-chat-hf | [Link](https://huggingface.co/llamaste/Llama-2-7b-chat-hf) |
|
241 |
+
| meta-llama/Llama-2-13b-chat-hf | /path-to/Llama-2-13b-chat-hf | [Link](https://huggingface.co/llamaste/Llama-2-13b-chat-hf) |
|
242 |
+
| meta-llama/Llama-2-70b-chat-hf | /path-to/Llama-2-70b-chat-hf | [Link](https://huggingface.co/llamaste/Llama-2-70b-chat-hf) |
|
243 |
+
| meta-llama/Llama-2-7b-hf | /path-to/Llama-2-7b-hf | [Link](https://huggingface.co/meta-llama/Llama-2-7b-hf) |
|
244 |
+
| meta-llama/Llama-2-13b-hf | /path-to/Llama-2-13b-hf | [Link](https://huggingface.co/meta-llama/Llama-2-13b-hf) |
|
245 |
+
| meta-llama/Llama-2-70b-hf | /path-to/Llama-2-70b-hf | [Link](https://huggingface.co/meta-llama/Llama-2-70b-hf) |
|
246 |
+
| TheBloke/Llama-2-7b-Chat-GPTQ | /path-to/Llama-2-7b-Chat-GPTQ | [Link](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ) |
|
247 |
+
| TheBloke/Llama-2-7b-Chat-GGUF | /path-to/llama-2-7b-chat.Q4_0.gguf | [Link](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/blob/main/llama-2-7b-chat.Q4_0.gguf) |
|
248 |
+
| TheBloke/Llama-2-7B-Chat-GGML | /path-to/llama-2-7b-chat.ggmlv3.q4_0.bin | [Link](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML) |
|
249 |
+
| TheBloke/CodeLlama-7B-Instruct-GPTQ | TheBloke/CodeLlama-7B-Instruct-GPTQ | [Link](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) |
|
250 |
+
| ... | ... | ... |
|
251 |
+
|
252 |
+
Running 4-bit model `Llama-2-7b-Chat-GPTQ` needs GPU with 6GB VRAM.
|
253 |
+
|
254 |
+
Running 4-bit model `llama-2-7b-chat.ggmlv3.q4_0.bin` needs CPU with 6GB RAM. There is also a list of other 2, 3, 4, 5, 6, 8-bit GGML models that can be used from [TheBloke/Llama-2-7B-Chat-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML).
|
255 |
+
|
256 |
+
### Download Script
|
257 |
+
|
258 |
+
These models can be downloaded through:
|
259 |
+
|
260 |
+
```bash
|
261 |
+
python -m llama2_wrapper.download --repo_id TheBloke/CodeLlama-7B-Python-GPTQ
|
262 |
+
|
263 |
+
python -m llama2_wrapper.download --repo_id TheBloke/Llama-2-7b-Chat-GGUF --filename llama-2-7b-chat.Q4_0.gguf --save_dir ./models
|
264 |
+
```
|
265 |
+
|
266 |
+
Or use CMD like:
|
267 |
+
|
268 |
+
```bash
|
269 |
+
# Make sure you have git-lfs installed (https://git-lfs.com)
|
270 |
+
git lfs install
|
271 |
+
git clone git@hf.co:meta-llama/Llama-2-7b-chat-hf
|
272 |
+
```
|
273 |
+
|
274 |
+
To download Llama 2 models, you need to request access from [https://ai.meta.com/llama/](https://ai.meta.com/llama/) and also enable access on repos like [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main). Requests will be processed in hours.
|
275 |
+
|
276 |
+
For GPTQ models like [TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ), you can directly download without requesting access.
|
277 |
+
|
278 |
+
For GGML models like [TheBloke/Llama-2-7B-Chat-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML), you can directly download without requesting access.
|
279 |
+
|
280 |
+
## Tips
|
281 |
+
|
282 |
+
### Env Examples
|
283 |
+
|
284 |
+
There are some examples in `./env_examples/` folder.
|
285 |
+
|
286 |
+
| Model Setup | Example .env |
|
287 |
+
| ------------------------------------------------------ | --------------------------- |
|
288 |
+
| Llama-2-7b-chat-hf 8-bit (transformers backend) | .env.7b_8bit_example |
|
289 |
+
| Llama-2-7b-Chat-GPTQ 4-bit (gptq transformers backend) | .env.7b_gptq_example |
|
290 |
+
| Llama-2-7B-Chat-GGML 4bit (llama.cpp backend) | .env.7b_ggmlv3_q4_0_example |
|
291 |
+
| Llama-2-13b-chat-hf (transformers backend) | .env.13b_example |
|
292 |
+
| ... | ... |
|
293 |
+
|
294 |
+
### Run on Nvidia GPU
|
295 |
+
|
296 |
+
The running requires around 14GB of GPU VRAM for Llama-2-7b and 28GB of GPU VRAM for Llama-2-13b.
|
297 |
+
|
298 |
+
If you are running on multiple GPUs, the model will be loaded automatically on GPUs and split the VRAM usage. That allows you to run Llama-2-7b (requires 14GB of GPU VRAM) on a setup like 2 GPUs (11GB VRAM each).
|
299 |
+
|
300 |
+
#### Run bitsandbytes 8 bit
|
301 |
+
|
302 |
+
If you do not have enough memory, you can set up your `LOAD_IN_8BIT` as `True` in `.env`. This can reduce memory usage by around half with slightly degraded model quality. It is compatible with the CPU, GPU, and Metal backend.
|
303 |
+
|
304 |
+
Llama-2-7b with 8-bit compression can run on a single GPU with 8 GB of VRAM, like an Nvidia RTX 2080Ti, RTX 4080, T4, V100 (16GB).
|
305 |
+
|
306 |
+
#### Run GPTQ 4 bit
|
307 |
+
|
308 |
+
If you want to run 4 bit Llama-2 model like `Llama-2-7b-Chat-GPTQ`, you can set up your `BACKEND_TYPE` as `gptq` in `.env` like example `.env.7b_gptq_example`.
|
309 |
+
|
310 |
+
Make sure you have downloaded the 4-bit model from `Llama-2-7b-Chat-GPTQ` and set the `MODEL_PATH` and arguments in `.env` file.
|
311 |
+
|
312 |
+
`Llama-2-7b-Chat-GPTQ` can run on a single GPU with 6 GB of VRAM.
|
313 |
+
|
314 |
+
If you encounter issue like `NameError: name 'autogptq_cuda_256' is not defined`, please refer to [here](https://huggingface.co/TheBloke/open-llama-13b-open-instruct-GPTQ/discussions/1)
|
315 |
+
> pip install https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.3.0/auto_gptq-0.3.0+cu117-cp310-cp310-linux_x86_64.whl
|
316 |
+
|
317 |
+
### Run on CPU
|
318 |
+
|
319 |
+
Run Llama-2 model on CPU requires [llama.cpp](https://github.com/ggerganov/llama.cpp) dependency and [llama.cpp Python Bindings](https://github.com/abetlen/llama-cpp-python), which are already installed.
|
320 |
+
|
321 |
+
|
322 |
+
Download GGML models like `llama-2-7b-chat.ggmlv3.q4_0.bin` following [Download Llama-2 Models](#download-llama-2-models) section. `llama-2-7b-chat.ggmlv3.q4_0.bin` model requires at least 6 GB RAM to run on CPU.
|
323 |
+
|
324 |
+
Set up configs like `.env.7b_ggmlv3_q4_0_example` from `env_examples` as `.env`.
|
325 |
+
|
326 |
+
Run web UI `python app.py` .
|
327 |
+
|
328 |
+
#### Mac Metal Acceleration
|
329 |
+
|
330 |
+
For Mac users, you can also set up Mac Metal for acceleration, try install this dependencies:
|
331 |
+
|
332 |
+
```bash
|
333 |
+
pip uninstall llama-cpp-python -y
|
334 |
+
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install -U llama-cpp-python --no-cache-dir
|
335 |
+
pip install 'llama-cpp-python[server]'
|
336 |
+
```
|
337 |
+
|
338 |
+
or check details:
|
339 |
+
|
340 |
+
- [MacOS Install with Metal GPU](https://github.com/abetlen/llama-cpp-python/blob/main/docs/install/macos.md)
|
341 |
+
|
342 |
+
#### AMD/Nvidia GPU Acceleration
|
343 |
+
|
344 |
+
If you would like to use AMD/Nvidia GPU for acceleration, check this:
|
345 |
+
|
346 |
+
- [Installation with OpenBLAS / cuBLAS / CLBlast / Metal](https://github.com/abetlen/llama-cpp-python#installation-with-openblas--cublas--clblast--metal)
|
347 |
+
|
348 |
+
|
349 |
+
|
350 |
+
|
351 |
+
|
352 |
+
## License
|
353 |
+
|
354 |
+
MIT - see [MIT License](LICENSE)
|
355 |
+
|
356 |
+
This project enables users to adapt it freely for proprietary purposes without any restrictions.
|
357 |
+
|
358 |
+
## Contributing
|
359 |
+
|
360 |
+
Kindly read our [Contributing Guide](CONTRIBUTING.md) to learn and understand our development process.
|
361 |
+
|
362 |
+
### All Contributors
|
363 |
+
|
364 |
+
<a href="https://github.com/liltom-eth/llama2-webui/graphs/contributors">
|
365 |
+
<img src="https://contrib.rocks/image?repo=liltom-eth/llama2-webui" />
|
366 |
+
</a>
|
367 |
+
|
368 |
+
### Review
|
369 |
+
<a href='https://github.com/repo-reviews/repo-reviews.github.io/blob/main/create.md' target="_blank"><img alt='Github' src='https://img.shields.io/badge/review-100000?style=flat&logo=Github&logoColor=white&labelColor=888888&color=555555'/></a>
|
370 |
+
|
371 |
+
### Star History
|
372 |
+
|
373 |
+
[![Star History Chart](https://api.star-history.com/svg?repos=liltom-eth/llama2-webui&type=Date)](https://star-history.com/#liltom-eth/llama2-webui&Date)
|
374 |
+
|
375 |
+
## Credits
|
376 |
|
377 |
+
- https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
378 |
+
- https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat
|
379 |
+
- https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ
|
380 |
+
- [https://github.com/ggerganov/llama.cpp](https://github.com/ggerganov/llama.cpp)
|
381 |
+
- [https://github.com/TimDettmers/bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
|
382 |
+
- [https://github.com/PanQiWei/AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)
|
383 |
+
- [https://github.com/abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python)
|
app.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import argparse
|
4 |
+
from typing import Iterator
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from distutils.util import strtobool
|
9 |
+
|
10 |
+
from llama2_wrapper import LLAMA2_WRAPPER
|
11 |
+
|
12 |
+
import logging
|
13 |
+
|
14 |
+
from prompts.utils import PromtsContainer
|
15 |
+
|
16 |
+
def main():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--model_path", type=str, default="", help="model path")
|
19 |
+
parser.add_argument(
|
20 |
+
"--backend_type",
|
21 |
+
type=str,
|
22 |
+
default="",
|
23 |
+
help="Backend options: llama.cpp, gptq, transformers, llama2.cu",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--load_in_8bit",
|
27 |
+
type=bool,
|
28 |
+
default=False,
|
29 |
+
help="Whether to use bitsandbytes 8 bit.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--share",
|
33 |
+
type=bool,
|
34 |
+
default=False,
|
35 |
+
help="Whether to share public for gradio.",
|
36 |
+
)
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
load_dotenv()
|
40 |
+
|
41 |
+
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
|
42 |
+
MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048))
|
43 |
+
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
|
44 |
+
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))
|
45 |
+
|
46 |
+
MODEL_PATH = os.getenv("MODEL_PATH")
|
47 |
+
assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}"
|
48 |
+
BACKEND_TYPE = os.getenv("BACKEND_TYPE")
|
49 |
+
assert BACKEND_TYPE is not None, f"BACKEND_TYPE is required, got: {BACKEND_TYPE}"
|
50 |
+
|
51 |
+
LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))
|
52 |
+
|
53 |
+
if args.model_path != "":
|
54 |
+
MODEL_PATH = args.model_path
|
55 |
+
if args.backend_type != "":
|
56 |
+
BACKEND_TYPE = args.backend_type
|
57 |
+
if args.load_in_8bit:
|
58 |
+
LOAD_IN_8BIT = True
|
59 |
+
|
60 |
+
llama2_wrapper = LLAMA2_WRAPPER(
|
61 |
+
model_path=MODEL_PATH,
|
62 |
+
backend_type=BACKEND_TYPE,
|
63 |
+
max_tokens=MAX_INPUT_TOKEN_LENGTH,
|
64 |
+
load_in_8bit=LOAD_IN_8BIT,
|
65 |
+
verbose=True,
|
66 |
+
)
|
67 |
+
|
68 |
+
DESCRIPTION = """
|
69 |
+
# llama2-webui
|
70 |
+
"""
|
71 |
+
DESCRIPTION2 = """
|
72 |
+
- Supporting models: [Llama-2-7b](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML)/[13b](https://huggingface.co/llamaste/Llama-2-13b-chat-hf)/[70b](https://huggingface.co/llamaste/Llama-2-70b-chat-hf), [Llama-2-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ), [Llama-2-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML), [CodeLlama](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) ...
|
73 |
+
- Supporting model backends: [tranformers](https://github.com/huggingface/transformers), [bitsandbytes(8-bit inference)](https://github.com/TimDettmers/bitsandbytes), [AutoGPTQ(4-bit inference)](https://github.com/PanQiWei/AutoGPTQ), [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
74 |
+
"""
|
75 |
+
|
76 |
+
def clear_and_save_textbox(message: str) -> tuple[str, str]:
|
77 |
+
return "", message
|
78 |
+
|
79 |
+
def save_textbox_for_prompt(message: str) -> str:
|
80 |
+
logging.info("start save_textbox_from_prompt")
|
81 |
+
message = convert_summary_to_prompt(message)
|
82 |
+
return message
|
83 |
+
|
84 |
+
def display_input(
|
85 |
+
message: str, history: list[tuple[str, str]]
|
86 |
+
) -> list[tuple[str, str]]:
|
87 |
+
history.append((message, ""))
|
88 |
+
return history
|
89 |
+
|
90 |
+
def delete_prev_fn(
|
91 |
+
history: list[tuple[str, str]]
|
92 |
+
) -> tuple[list[tuple[str, str]], str]:
|
93 |
+
try:
|
94 |
+
message, _ = history.pop()
|
95 |
+
except IndexError:
|
96 |
+
message = ""
|
97 |
+
return history, message or ""
|
98 |
+
|
99 |
+
def generate(
|
100 |
+
message: str,
|
101 |
+
history_with_input: list[tuple[str, str]],
|
102 |
+
system_prompt: str,
|
103 |
+
max_new_tokens: int,
|
104 |
+
temperature: float,
|
105 |
+
top_p: float,
|
106 |
+
top_k: int,
|
107 |
+
platform: str,
|
108 |
+
) -> tuple[Iterator[list[tuple[str, str]]], str]:
|
109 |
+
if max_new_tokens > MAX_MAX_NEW_TOKENS:
|
110 |
+
raise ValueError
|
111 |
+
try:
|
112 |
+
history = history_with_input[:-1]
|
113 |
+
yield history + [(message, "")], "## processing prompt"
|
114 |
+
generator = llama2_wrapper.run(
|
115 |
+
message,
|
116 |
+
history,
|
117 |
+
system_prompt,
|
118 |
+
max_new_tokens,
|
119 |
+
temperature,
|
120 |
+
top_p,
|
121 |
+
top_k,
|
122 |
+
)
|
123 |
+
t = -time.perf_counter()
|
124 |
+
try:
|
125 |
+
first_response = next(generator)
|
126 |
+
t += time.perf_counter()
|
127 |
+
yield history + [(message, first_response)], "## generating"
|
128 |
+
t -= time.perf_counter()
|
129 |
+
except StopIteration:
|
130 |
+
yield history + [(message, "")], "## terminated"
|
131 |
+
num_tokens = 1
|
132 |
+
t = -time.perf_counter()
|
133 |
+
for response in generator:
|
134 |
+
num_tokens += 1
|
135 |
+
t += time.perf_counter()
|
136 |
+
yield history + [(message, response)], "## generating"
|
137 |
+
t -= time.perf_counter()
|
138 |
+
t += time.perf_counter()
|
139 |
+
if platform == None:
|
140 |
+
platform = "CUDA by default"
|
141 |
+
yield history + [(message, response)], f"### num tok: {num_tokens}<br>time(sec): {t:.2f}<br>tok/sec: {num_tokens / t:.2f}<br>{BACKEND_TYPE}({platform})"
|
142 |
+
except Exception as e:
|
143 |
+
logging.exception(e)
|
144 |
+
|
145 |
+
def check_input_token_length(
|
146 |
+
message: str, chat_history: list[tuple[str, str]], system_prompt: str
|
147 |
+
) -> None:
|
148 |
+
input_token_length = llama2_wrapper.get_input_token_length(
|
149 |
+
message, chat_history, system_prompt
|
150 |
+
)
|
151 |
+
if input_token_length > MAX_INPUT_TOKEN_LENGTH:
|
152 |
+
raise gr.Error(
|
153 |
+
f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
|
154 |
+
)
|
155 |
+
|
156 |
+
prompts_container = PromtsContainer()
|
157 |
+
prompts = prompts_container.get_prompts_tab_dict()
|
158 |
+
default_prompts_checkbox = False
|
159 |
+
default_advanced_checkbox = False
|
160 |
+
|
161 |
+
def convert_summary_to_prompt(summary):
|
162 |
+
return prompts_container.get_prompt_by_summary(summary)
|
163 |
+
|
164 |
+
def tab_list(tab_data, chatbot, perf, platform):
|
165 |
+
for item in tab_data:
|
166 |
+
with gr.Group():
|
167 |
+
gr.HTML(
|
168 |
+
f'<p style="color: black; font-weight: bold;">{item["act"]}</p>'
|
169 |
+
)
|
170 |
+
prompt_text = gr.Button(
|
171 |
+
value=f"{item['summary']}",
|
172 |
+
size="sm",
|
173 |
+
elem_classes="text-left-aligned",
|
174 |
+
)
|
175 |
+
prompt_text.click(
|
176 |
+
fn=save_textbox_for_prompt,
|
177 |
+
inputs=prompt_text,
|
178 |
+
outputs=saved_input,
|
179 |
+
api_name=False,
|
180 |
+
queue=True,
|
181 |
+
).then(
|
182 |
+
fn=display_input,
|
183 |
+
inputs=[saved_input, chatbot],
|
184 |
+
outputs=chatbot,
|
185 |
+
api_name=False,
|
186 |
+
queue=True,
|
187 |
+
).then(
|
188 |
+
fn=check_input_token_length,
|
189 |
+
inputs=[saved_input, chatbot, system_prompt],
|
190 |
+
api_name=False,
|
191 |
+
queue=False,
|
192 |
+
).success(
|
193 |
+
fn=generate,
|
194 |
+
inputs=[
|
195 |
+
saved_input,
|
196 |
+
chatbot,
|
197 |
+
system_prompt,
|
198 |
+
max_new_tokens,
|
199 |
+
temperature,
|
200 |
+
top_p,
|
201 |
+
top_k,
|
202 |
+
platform,
|
203 |
+
],
|
204 |
+
outputs=[
|
205 |
+
chatbot,
|
206 |
+
perf
|
207 |
+
],
|
208 |
+
api_name=False,
|
209 |
+
)
|
210 |
+
|
211 |
+
CSS = """
|
212 |
+
.contain { display: flex; flex-direction: column;}
|
213 |
+
.text-left-aligned {text-align: left !important; font-size: 16px;}
|
214 |
+
"""
|
215 |
+
with gr.Blocks(css=CSS, title="Gradio") as demo:
|
216 |
+
with gr.Row():
|
217 |
+
with gr.Column(visible=default_advanced_checkbox, variant="combat") as advanced_column:
|
218 |
+
system_prompt = gr.Textbox(
|
219 |
+
label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6
|
220 |
+
)
|
221 |
+
max_new_tokens = gr.Slider(
|
222 |
+
label="Max new tokens",
|
223 |
+
minimum=1,
|
224 |
+
maximum=MAX_MAX_NEW_TOKENS,
|
225 |
+
step=1,
|
226 |
+
value=DEFAULT_MAX_NEW_TOKENS,
|
227 |
+
)
|
228 |
+
temperature = gr.Slider(
|
229 |
+
label="Temperature",
|
230 |
+
minimum=0.1,
|
231 |
+
maximum=4.0,
|
232 |
+
step=0.1,
|
233 |
+
value=1.0,
|
234 |
+
)
|
235 |
+
top_p = gr.Slider(
|
236 |
+
label="Top-p (nucleus sampling)",
|
237 |
+
minimum=0.05,
|
238 |
+
maximum=1.0,
|
239 |
+
step=0.05,
|
240 |
+
value=0.95,
|
241 |
+
)
|
242 |
+
top_k = gr.Slider(
|
243 |
+
label="Top-k",
|
244 |
+
minimum=1,
|
245 |
+
maximum=1000,
|
246 |
+
step=1,
|
247 |
+
value=50,
|
248 |
+
)
|
249 |
+
with gr.Column(scale=2):
|
250 |
+
with gr.Row():
|
251 |
+
gr.Markdown("# llama2-webui")
|
252 |
+
perf = gr.Markdown(value=f"## performance<br>Current Backend: {BACKEND_TYPE}", rtl=True)
|
253 |
+
with gr.Group():
|
254 |
+
chatbot = gr.Chatbot(label="Chatbot")
|
255 |
+
with gr.Row():
|
256 |
+
textbox = gr.Textbox(
|
257 |
+
container=False,
|
258 |
+
show_label=False,
|
259 |
+
placeholder="Type a message...",
|
260 |
+
scale=10,
|
261 |
+
)
|
262 |
+
submit_button = gr.Button(
|
263 |
+
"Submit", variant="primary",
|
264 |
+
)
|
265 |
+
with gr.Row():
|
266 |
+
retry_button = gr.Button("🔄 Retry", variant="secondary")
|
267 |
+
undo_button = gr.Button("↩️ Undo", variant="secondary")
|
268 |
+
clear_button = gr.Button("🗑️ Clear", variant="secondary")
|
269 |
+
|
270 |
+
saved_input = gr.State()
|
271 |
+
with gr.Row():
|
272 |
+
advanced_checkbox = gr.Checkbox(
|
273 |
+
label="Advanced",
|
274 |
+
value=default_advanced_checkbox,
|
275 |
+
container=False,
|
276 |
+
elem_classes="min_check",
|
277 |
+
)
|
278 |
+
prompts_checkbox = gr.Checkbox(
|
279 |
+
label="Prompts",
|
280 |
+
value=default_prompts_checkbox,
|
281 |
+
container=False,
|
282 |
+
elem_classes="min_check",
|
283 |
+
)
|
284 |
+
with gr.Row():
|
285 |
+
platform = gr.Radio(["CUDA", "platform2"], label="Choose hardware platform", info="CUDA by default if no choosen")
|
286 |
+
with gr.Column(visible=default_prompts_checkbox) as prompt_column:
|
287 |
+
for k, v in prompts.items():
|
288 |
+
with gr.Tab(k):
|
289 |
+
tab_list(v, chatbot, perf, platform)
|
290 |
+
|
291 |
+
prompts_checkbox.change(
|
292 |
+
lambda x: gr.update(visible=x),
|
293 |
+
prompts_checkbox,
|
294 |
+
prompt_column,
|
295 |
+
queue=False,
|
296 |
+
)
|
297 |
+
|
298 |
+
advanced_checkbox.change(
|
299 |
+
lambda x: gr.update(visible=x),
|
300 |
+
advanced_checkbox,
|
301 |
+
advanced_column,
|
302 |
+
queue=False,
|
303 |
+
)
|
304 |
+
|
305 |
+
textbox.submit(
|
306 |
+
fn=clear_and_save_textbox,
|
307 |
+
inputs=textbox,
|
308 |
+
outputs=[textbox, saved_input],
|
309 |
+
api_name=False,
|
310 |
+
queue=False,
|
311 |
+
).then(
|
312 |
+
fn=display_input,
|
313 |
+
inputs=[saved_input, chatbot],
|
314 |
+
outputs=chatbot,
|
315 |
+
api_name=False,
|
316 |
+
queue=False,
|
317 |
+
).then(
|
318 |
+
fn=check_input_token_length,
|
319 |
+
inputs=[saved_input, chatbot, system_prompt],
|
320 |
+
api_name=False,
|
321 |
+
queue=False,
|
322 |
+
).success(
|
323 |
+
fn=generate,
|
324 |
+
inputs=[
|
325 |
+
saved_input,
|
326 |
+
chatbot,
|
327 |
+
system_prompt,
|
328 |
+
max_new_tokens,
|
329 |
+
temperature,
|
330 |
+
top_p,
|
331 |
+
top_k,
|
332 |
+
platform,
|
333 |
+
],
|
334 |
+
outputs=[
|
335 |
+
chatbot,
|
336 |
+
perf
|
337 |
+
],
|
338 |
+
api_name=False,
|
339 |
+
)
|
340 |
+
|
341 |
+
submit_button.click(
|
342 |
+
fn=clear_and_save_textbox,
|
343 |
+
inputs=textbox,
|
344 |
+
outputs=[textbox, saved_input],
|
345 |
+
api_name=False,
|
346 |
+
queue=False,
|
347 |
+
).then(
|
348 |
+
fn=display_input,
|
349 |
+
inputs=[saved_input, chatbot],
|
350 |
+
outputs=chatbot,
|
351 |
+
api_name=False,
|
352 |
+
queue=False,
|
353 |
+
).then(
|
354 |
+
fn=check_input_token_length,
|
355 |
+
inputs=[saved_input, chatbot, system_prompt],
|
356 |
+
api_name=False,
|
357 |
+
queue=False,
|
358 |
+
).success(
|
359 |
+
fn=generate,
|
360 |
+
inputs=[
|
361 |
+
saved_input,
|
362 |
+
chatbot,
|
363 |
+
system_prompt,
|
364 |
+
max_new_tokens,
|
365 |
+
temperature,
|
366 |
+
top_p,
|
367 |
+
top_k,
|
368 |
+
platform,
|
369 |
+
],
|
370 |
+
outputs=[
|
371 |
+
chatbot,
|
372 |
+
perf
|
373 |
+
],
|
374 |
+
api_name=False,
|
375 |
+
)
|
376 |
+
|
377 |
+
retry_button.click(
|
378 |
+
fn=delete_prev_fn,
|
379 |
+
inputs=chatbot,
|
380 |
+
outputs=[chatbot, saved_input],
|
381 |
+
api_name=False,
|
382 |
+
queue=False,
|
383 |
+
).then(
|
384 |
+
fn=display_input,
|
385 |
+
inputs=[saved_input, chatbot],
|
386 |
+
outputs=chatbot,
|
387 |
+
api_name=False,
|
388 |
+
queue=False,
|
389 |
+
).then(
|
390 |
+
fn=generate,
|
391 |
+
inputs=[
|
392 |
+
saved_input,
|
393 |
+
chatbot,
|
394 |
+
system_prompt,
|
395 |
+
max_new_tokens,
|
396 |
+
temperature,
|
397 |
+
top_p,
|
398 |
+
top_k,
|
399 |
+
platform,
|
400 |
+
],
|
401 |
+
outputs=[
|
402 |
+
chatbot,
|
403 |
+
perf
|
404 |
+
],
|
405 |
+
api_name=False,
|
406 |
+
)
|
407 |
+
|
408 |
+
undo_button.click(
|
409 |
+
fn=delete_prev_fn,
|
410 |
+
inputs=chatbot,
|
411 |
+
outputs=[chatbot, saved_input],
|
412 |
+
api_name=False,
|
413 |
+
queue=False,
|
414 |
+
).then(
|
415 |
+
fn=lambda x: x,
|
416 |
+
inputs=[saved_input],
|
417 |
+
outputs=textbox,
|
418 |
+
api_name=False,
|
419 |
+
queue=False,
|
420 |
+
)
|
421 |
+
|
422 |
+
clear_button.click(
|
423 |
+
fn=lambda: ([], ""),
|
424 |
+
outputs=[chatbot, saved_input],
|
425 |
+
queue=False,
|
426 |
+
api_name=False,
|
427 |
+
)
|
428 |
+
|
429 |
+
demo.queue(max_size=20).launch(share=args.share)
|
430 |
+
|
431 |
+
|
432 |
+
if __name__ == "__main__":
|
433 |
+
main()
|
benchmark.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from distutils.util import strtobool
|
7 |
+
from memory_profiler import memory_usage
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from llama2_wrapper import LLAMA2_WRAPPER
|
11 |
+
|
12 |
+
|
13 |
+
def run_iteration(
|
14 |
+
llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
|
15 |
+
):
|
16 |
+
def generation():
|
17 |
+
generator = llama2_wrapper.run(
|
18 |
+
prompt_example,
|
19 |
+
[],
|
20 |
+
DEFAULT_SYSTEM_PROMPT,
|
21 |
+
DEFAULT_MAX_NEW_TOKENS,
|
22 |
+
1,
|
23 |
+
0.95,
|
24 |
+
50,
|
25 |
+
)
|
26 |
+
model_response = None
|
27 |
+
try:
|
28 |
+
first_model_response = next(generator)
|
29 |
+
except StopIteration:
|
30 |
+
pass
|
31 |
+
for model_response in generator:
|
32 |
+
pass
|
33 |
+
return llama2_wrapper.get_token_length(model_response), model_response
|
34 |
+
|
35 |
+
tic = time.perf_counter()
|
36 |
+
mem_usage, (output_token_length, model_response) = memory_usage(
|
37 |
+
(generation,), max_usage=True, retval=True
|
38 |
+
)
|
39 |
+
toc = time.perf_counter()
|
40 |
+
|
41 |
+
generation_time = toc - tic
|
42 |
+
tokens_per_second = output_token_length / generation_time
|
43 |
+
|
44 |
+
return generation_time, tokens_per_second, mem_usage, model_response
|
45 |
+
|
46 |
+
|
47 |
+
def main():
|
48 |
+
parser = argparse.ArgumentParser()
|
49 |
+
parser.add_argument("--iter", type=int, default=5, help="Number of iterations")
|
50 |
+
parser.add_argument("--model_path", type=str, default="", help="model path")
|
51 |
+
parser.add_argument(
|
52 |
+
"--backend_type",
|
53 |
+
type=str,
|
54 |
+
default="",
|
55 |
+
help="Backend options: llama.cpp, gptq, transformers",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"--load_in_8bit",
|
59 |
+
type=bool,
|
60 |
+
default=False,
|
61 |
+
help="Whether to use bitsandbytes 8 bit.",
|
62 |
+
)
|
63 |
+
|
64 |
+
args = parser.parse_args()
|
65 |
+
|
66 |
+
load_dotenv()
|
67 |
+
|
68 |
+
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
|
69 |
+
MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048))
|
70 |
+
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
|
71 |
+
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))
|
72 |
+
|
73 |
+
MODEL_PATH = os.getenv("MODEL_PATH")
|
74 |
+
assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}"
|
75 |
+
BACKEND_TYPE = os.getenv("BACKEND_TYPE")
|
76 |
+
assert BACKEND_TYPE is not None, f"BACKEND_TYPE is required, got: {BACKEND_TYPE}"
|
77 |
+
|
78 |
+
LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))
|
79 |
+
|
80 |
+
if args.model_path != "":
|
81 |
+
MODEL_PATH = args.model_path
|
82 |
+
if args.backend_type != "":
|
83 |
+
BACKEND_TYPE = args.backend_type
|
84 |
+
if args.load_in_8bit:
|
85 |
+
LOAD_IN_8BIT = True
|
86 |
+
|
87 |
+
# Initialization
|
88 |
+
init_tic = time.perf_counter()
|
89 |
+
llama2_wrapper = LLAMA2_WRAPPER(
|
90 |
+
model_path=MODEL_PATH,
|
91 |
+
backend_type=BACKEND_TYPE,
|
92 |
+
max_tokens=MAX_INPUT_TOKEN_LENGTH,
|
93 |
+
load_in_8bit=LOAD_IN_8BIT,
|
94 |
+
# verbose=True,
|
95 |
+
)
|
96 |
+
|
97 |
+
init_toc = time.perf_counter()
|
98 |
+
initialization_time = init_toc - init_tic
|
99 |
+
|
100 |
+
total_time = 0
|
101 |
+
total_tokens_per_second = 0
|
102 |
+
total_memory_gen = 0
|
103 |
+
|
104 |
+
prompt_example = (
|
105 |
+
"Can you explain briefly to me what is the Python programming language?"
|
106 |
+
)
|
107 |
+
|
108 |
+
# Cold run
|
109 |
+
print("Performing cold run...")
|
110 |
+
run_iteration(
|
111 |
+
llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
|
112 |
+
)
|
113 |
+
|
114 |
+
# Timed runs
|
115 |
+
print(f"Performing {args.iter} timed runs...")
|
116 |
+
for i in tqdm(range(args.iter)):
|
117 |
+
try:
|
118 |
+
gen_time, tokens_per_sec, mem_gen, model_response = run_iteration(
|
119 |
+
llama2_wrapper,
|
120 |
+
prompt_example,
|
121 |
+
DEFAULT_SYSTEM_PROMPT,
|
122 |
+
DEFAULT_MAX_NEW_TOKENS,
|
123 |
+
)
|
124 |
+
total_time += gen_time
|
125 |
+
total_tokens_per_second += tokens_per_sec
|
126 |
+
total_memory_gen += mem_gen
|
127 |
+
except:
|
128 |
+
break
|
129 |
+
avg_time = total_time / (i + 1)
|
130 |
+
avg_tokens_per_second = total_tokens_per_second / (i + 1)
|
131 |
+
avg_memory_gen = total_memory_gen / (i + 1)
|
132 |
+
|
133 |
+
print(f"Last model response: {model_response}")
|
134 |
+
print(f"Initialization time: {initialization_time:0.4f} seconds.")
|
135 |
+
print(
|
136 |
+
f"Average generation time over {(i + 1)} iterations: {avg_time:0.4f} seconds."
|
137 |
+
)
|
138 |
+
print(
|
139 |
+
f"Average speed over {(i + 1)} iterations: {avg_tokens_per_second:0.4f} tokens/sec."
|
140 |
+
)
|
141 |
+
print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB")
|
142 |
+
|
143 |
+
|
144 |
+
if __name__ == "__main__":
|
145 |
+
main()
|
code_completion.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from llama2_wrapper import LLAMA2_WRAPPER
|
5 |
+
|
6 |
+
FIM_PREFIX = "<PRE> "
|
7 |
+
FIM_MIDDLE = " <MID>"
|
8 |
+
FIM_SUFFIX = " <SUF>"
|
9 |
+
|
10 |
+
FIM_INDICATOR = "<FILL_ME>"
|
11 |
+
|
12 |
+
EOS_STRING = "</s>"
|
13 |
+
EOT_STRING = "<EOT>"
|
14 |
+
|
15 |
+
|
16 |
+
def main():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument(
|
19 |
+
"--model_path",
|
20 |
+
type=str,
|
21 |
+
default="./models/codellama-7b-instruct.ggmlv3.Q4_0.bin",
|
22 |
+
help="model path",
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"--backend_type",
|
26 |
+
type=str,
|
27 |
+
default="llama.cpp",
|
28 |
+
help="Backend options: llama.cpp, gptq, transformers",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--max_tokens",
|
32 |
+
type=int,
|
33 |
+
default=4000,
|
34 |
+
help="Maximum context size.",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--load_in_8bit",
|
38 |
+
type=bool,
|
39 |
+
default=False,
|
40 |
+
help="Whether to use bitsandbytes 8 bit.",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--share",
|
44 |
+
type=bool,
|
45 |
+
default=False,
|
46 |
+
help="Whether to share public for gradio.",
|
47 |
+
)
|
48 |
+
args = parser.parse_args()
|
49 |
+
|
50 |
+
llama2_wrapper = LLAMA2_WRAPPER(
|
51 |
+
model_path=args.model_path,
|
52 |
+
backend_type=args.backend_type,
|
53 |
+
max_tokens=args.max_tokens,
|
54 |
+
load_in_8bit=args.load_in_8bit,
|
55 |
+
)
|
56 |
+
|
57 |
+
def generate(
|
58 |
+
prompt,
|
59 |
+
temperature=0.9,
|
60 |
+
max_new_tokens=256,
|
61 |
+
top_p=0.95,
|
62 |
+
repetition_penalty=1.0,
|
63 |
+
):
|
64 |
+
temperature = float(temperature)
|
65 |
+
if temperature < 1e-2:
|
66 |
+
temperature = 1e-2
|
67 |
+
top_p = float(top_p)
|
68 |
+
fim_mode = False
|
69 |
+
|
70 |
+
generate_kwargs = dict(
|
71 |
+
temperature=temperature,
|
72 |
+
max_new_tokens=max_new_tokens,
|
73 |
+
top_p=top_p,
|
74 |
+
repetition_penalty=repetition_penalty,
|
75 |
+
stream=True,
|
76 |
+
)
|
77 |
+
|
78 |
+
if FIM_INDICATOR in prompt:
|
79 |
+
fim_mode = True
|
80 |
+
try:
|
81 |
+
prefix, suffix = prompt.split(FIM_INDICATOR)
|
82 |
+
except:
|
83 |
+
raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
|
84 |
+
prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
|
85 |
+
|
86 |
+
stream = llama2_wrapper.__call__(prompt, **generate_kwargs)
|
87 |
+
|
88 |
+
if fim_mode:
|
89 |
+
output = prefix
|
90 |
+
else:
|
91 |
+
output = prompt
|
92 |
+
|
93 |
+
# for response in stream:
|
94 |
+
# output += response
|
95 |
+
# yield output
|
96 |
+
# return output
|
97 |
+
|
98 |
+
previous_token = ""
|
99 |
+
for response in stream:
|
100 |
+
if any([end_token in response for end_token in [EOS_STRING, EOT_STRING]]):
|
101 |
+
if fim_mode:
|
102 |
+
output += suffix
|
103 |
+
yield output
|
104 |
+
return output
|
105 |
+
print("output", output)
|
106 |
+
else:
|
107 |
+
return output
|
108 |
+
else:
|
109 |
+
output += response
|
110 |
+
previous_token = response
|
111 |
+
yield output
|
112 |
+
return output
|
113 |
+
|
114 |
+
examples = [
|
115 |
+
'def remove_non_ascii(s: str) -> str:\n """ <FILL_ME>\nprint(remove_non_ascii(\'afkdj$$(\'))',
|
116 |
+
"X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
|
117 |
+
"// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
|
118 |
+
"Poor English: She no went to the market. Corrected English:",
|
119 |
+
"def alternating(list1, list2):\n results = []\n for i in range(min(len(list1), len(list2))):\n results.append(list1[i])\n results.append(list2[i])\n if len(list1) > len(list2):\n <FILL_ME>\n else:\n results.extend(list2[i+1:])\n return results",
|
120 |
+
]
|
121 |
+
|
122 |
+
def process_example(args):
|
123 |
+
for x in generate(args):
|
124 |
+
pass
|
125 |
+
return x
|
126 |
+
|
127 |
+
description = """
|
128 |
+
<div style="text-align: center;">
|
129 |
+
<h1>Code Llama Playground</h1>
|
130 |
+
|
131 |
+
</div>
|
132 |
+
<div style="text-align: center;">
|
133 |
+
<p>This is a demo to complete code with Code Llama. For instruction purposes, please use llama2-webui app.py with CodeLlama-Instruct models. </p>
|
134 |
+
</div>
|
135 |
+
"""
|
136 |
+
with gr.Blocks() as demo:
|
137 |
+
with gr.Column():
|
138 |
+
gr.Markdown(description)
|
139 |
+
with gr.Row():
|
140 |
+
with gr.Column():
|
141 |
+
instruction = gr.Textbox(
|
142 |
+
placeholder="Enter your code here",
|
143 |
+
lines=5,
|
144 |
+
label="Input",
|
145 |
+
elem_id="q-input",
|
146 |
+
)
|
147 |
+
submit = gr.Button("Generate", variant="primary")
|
148 |
+
output = gr.Code(elem_id="q-output", lines=30, label="Output")
|
149 |
+
with gr.Row():
|
150 |
+
with gr.Column():
|
151 |
+
with gr.Accordion("Advanced settings", open=False):
|
152 |
+
with gr.Row():
|
153 |
+
column_1, column_2 = gr.Column(), gr.Column()
|
154 |
+
with column_1:
|
155 |
+
temperature = gr.Slider(
|
156 |
+
label="Temperature",
|
157 |
+
value=0.1,
|
158 |
+
minimum=0.0,
|
159 |
+
maximum=1.0,
|
160 |
+
step=0.05,
|
161 |
+
interactive=True,
|
162 |
+
info="Higher values produce more diverse outputs",
|
163 |
+
)
|
164 |
+
max_new_tokens = gr.Slider(
|
165 |
+
label="Max new tokens",
|
166 |
+
value=256,
|
167 |
+
minimum=0,
|
168 |
+
maximum=8192,
|
169 |
+
step=64,
|
170 |
+
interactive=True,
|
171 |
+
info="The maximum numbers of new tokens",
|
172 |
+
)
|
173 |
+
with column_2:
|
174 |
+
top_p = gr.Slider(
|
175 |
+
label="Top-p (nucleus sampling)",
|
176 |
+
value=0.90,
|
177 |
+
minimum=0.0,
|
178 |
+
maximum=1,
|
179 |
+
step=0.05,
|
180 |
+
interactive=True,
|
181 |
+
info="Higher values sample more low-probability tokens",
|
182 |
+
)
|
183 |
+
repetition_penalty = gr.Slider(
|
184 |
+
label="Repetition penalty",
|
185 |
+
value=1.05,
|
186 |
+
minimum=1.0,
|
187 |
+
maximum=2.0,
|
188 |
+
step=0.05,
|
189 |
+
interactive=True,
|
190 |
+
info="Penalize repeated tokens",
|
191 |
+
)
|
192 |
+
|
193 |
+
gr.Examples(
|
194 |
+
examples=examples,
|
195 |
+
inputs=[instruction],
|
196 |
+
cache_examples=False,
|
197 |
+
fn=process_example,
|
198 |
+
outputs=[output],
|
199 |
+
)
|
200 |
+
|
201 |
+
submit.click(
|
202 |
+
generate,
|
203 |
+
inputs=[
|
204 |
+
instruction,
|
205 |
+
temperature,
|
206 |
+
max_new_tokens,
|
207 |
+
top_p,
|
208 |
+
repetition_penalty,
|
209 |
+
],
|
210 |
+
outputs=[output],
|
211 |
+
)
|
212 |
+
demo.queue(concurrency_count=16).launch(share=args.share)
|
213 |
+
|
214 |
+
|
215 |
+
if __name__ == "__main__":
|
216 |
+
main()
|
colab/Llama_2_7b_Chat_GPTQ.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
colab/ggmlv3_q4_0.ipynb
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"toc_visible": true,
|
8 |
+
"authorship_tag": "ABX9TyM9WbudQYrVFksXUrt4Opt3",
|
9 |
+
"include_colab_link": true
|
10 |
+
},
|
11 |
+
"kernelspec": {
|
12 |
+
"name": "python3",
|
13 |
+
"display_name": "Python 3"
|
14 |
+
},
|
15 |
+
"language_info": {
|
16 |
+
"name": "python"
|
17 |
+
}
|
18 |
+
},
|
19 |
+
"cells": [
|
20 |
+
{
|
21 |
+
"cell_type": "markdown",
|
22 |
+
"metadata": {
|
23 |
+
"id": "view-in-github",
|
24 |
+
"colab_type": "text"
|
25 |
+
},
|
26 |
+
"source": [
|
27 |
+
"<a href=\"https://colab.research.google.com/github/liltom-eth/llama2-webui/blob/main/colab/ggmlv3_q4_0.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "code",
|
32 |
+
"execution_count": null,
|
33 |
+
"metadata": {
|
34 |
+
"id": "7O5JSosg5-rx"
|
35 |
+
},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"%cd /content\n",
|
39 |
+
"!pip install llama2-wrapper\n"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"source": [
|
45 |
+
"from llama2_wrapper import LLAMA2_WRAPPER, get_prompt\n",
|
46 |
+
"\n",
|
47 |
+
"llama2_wrapper = LLAMA2_WRAPPER()"
|
48 |
+
],
|
49 |
+
"metadata": {
|
50 |
+
"colab": {
|
51 |
+
"base_uri": "https://localhost:8080/"
|
52 |
+
},
|
53 |
+
"id": "8rgb1ckl72wC",
|
54 |
+
"outputId": "d9ca2e20-26a5-490b-86f2-1a182e533b20"
|
55 |
+
},
|
56 |
+
"execution_count": 5,
|
57 |
+
"outputs": [
|
58 |
+
{
|
59 |
+
"output_type": "stream",
|
60 |
+
"name": "stdout",
|
61 |
+
"text": [
|
62 |
+
"Running on backend llama.cpp.\n",
|
63 |
+
"Use default model path: ./models/llama-2-7b-chat.ggmlv3.q4_0.bin\n",
|
64 |
+
"Start downloading model to: ./models/llama-2-7b-chat.ggmlv3.q4_0.bin\n"
|
65 |
+
]
|
66 |
+
}
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"source": [
|
72 |
+
"prompt = get_prompt(\"Hi do you know Pytorch?\")\n",
|
73 |
+
"print(llama2_wrapper(prompt))"
|
74 |
+
],
|
75 |
+
"metadata": {
|
76 |
+
"id": "Qz2xAqozTIf6",
|
77 |
+
"colab": {
|
78 |
+
"base_uri": "https://localhost:8080/"
|
79 |
+
},
|
80 |
+
"outputId": "1380fa52-3d4a-4ac5-ed02-7faefe7ec2f6"
|
81 |
+
},
|
82 |
+
"execution_count": 3,
|
83 |
+
"outputs": [
|
84 |
+
{
|
85 |
+
"output_type": "stream",
|
86 |
+
"name": "stdout",
|
87 |
+
"text": [
|
88 |
+
" Yes, I'm familiar with PyTorch! PyTorch is an open-source deep learning framework that is widely used for building and training neural networks. It was originally developed by Facebook and is now maintained by the PyTorch Foundation.\n",
|
89 |
+
"\n",
|
90 |
+
"Here are some key features and capabilities of PyTorch:\n",
|
91 |
+
"\n",
|
92 |
+
"1. **Tensor Computation**: PyTorch provides a powerful tensor computation engine that allows for complex mathematical operations on large datasets.\n",
|
93 |
+
"2. **Autograd**: PyTorch's autograd system automatically computes gradients, which can save a lot of time and effort during training.\n",
|
94 |
+
"3. **Dynamic Compute**: PyTorch's dynamic compute system allows for more efficient computation by only computing the necessary computations at runtime.\n",
|
95 |
+
"4. **Memory-efficient**: PyTorch is designed to be memory-efficient, which is important for training large models that require a lot of memory.\n",
|
96 |
+
"5. **Accelerators**: PyTorch supports a wide range of accelerators, including GPUs, TPUs, and FPGAs, which can significantly speed up training times.\n",
|
97 |
+
"6. **Modules**: PyTorch provides a wide range of pre-built modules for common tasks, such as convolutional layers, recurrent neural networks, and more.\n",
|
98 |
+
"7. **Extensive Community**: PyTorch has a large and active community of developers and users, which can be helpful for getting support and staying up-to-date with the latest developments.\n",
|
99 |
+
"8. **Easy Integration**: PyTorch can be easily integrated with other popular deep learning frameworks, such as TensorFlow and Keras.\n",
|
100 |
+
"9. **Pythonic**: PyTorch is written in Python, which is a popular and easy-to-learn programming language.\n",
|
101 |
+
"10. **Flexible**: PyTorch allows for a wide range of customization options, which can be useful for building and training unique models.\n",
|
102 |
+
"\n",
|
103 |
+
"Overall, PyTorch is a powerful and flexible deep learning framework that can be used for a wide range of applications, including computer vision, natural language processing, and more.\n"
|
104 |
+
]
|
105 |
+
}
|
106 |
+
]
|
107 |
+
}
|
108 |
+
]
|
109 |
+
}
|
colab/webui_CodeLlama_7B_Instruct_GPTQ.ipynb
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"gpuType": "T4",
|
8 |
+
"authorship_tag": "ABX9TyOZhPcZe61RhDjhEFQv0vrl",
|
9 |
+
"include_colab_link": true
|
10 |
+
},
|
11 |
+
"kernelspec": {
|
12 |
+
"name": "python3",
|
13 |
+
"display_name": "Python 3"
|
14 |
+
},
|
15 |
+
"language_info": {
|
16 |
+
"name": "python"
|
17 |
+
},
|
18 |
+
"accelerator": "GPU"
|
19 |
+
},
|
20 |
+
"cells": [
|
21 |
+
{
|
22 |
+
"cell_type": "markdown",
|
23 |
+
"metadata": {
|
24 |
+
"id": "view-in-github",
|
25 |
+
"colab_type": "text"
|
26 |
+
},
|
27 |
+
"source": [
|
28 |
+
"<a href=\"https://colab.research.google.com/github/liltom-eth/llama2-webui/blob/main/colab/webui_CodeLlama_7B_Instruct_GPTQ.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": null,
|
34 |
+
"metadata": {
|
35 |
+
"id": "7O5JSosg5-rx"
|
36 |
+
},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"!pip install -U llama2-wrapper==0.1.12"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"source": [
|
45 |
+
"%cd /content\n",
|
46 |
+
"!git clone https://github.com/liltom-eth/llama2-webui\n",
|
47 |
+
"\n",
|
48 |
+
"%cd /content/llama2-webui\n",
|
49 |
+
"!python -m llama2_wrapper.download --repo_id TheBloke/CodeLlama-7B-Instruct-GPTQ\n",
|
50 |
+
"\n",
|
51 |
+
"%cd /content/llama2-webui\n",
|
52 |
+
"!python app.py --backend_type gptq --model_path ./models/CodeLlama-7B-Instruct-GPTQ/ --share True"
|
53 |
+
],
|
54 |
+
"metadata": {
|
55 |
+
"colab": {
|
56 |
+
"base_uri": "https://localhost:8080/"
|
57 |
+
},
|
58 |
+
"id": "Y6A7bJdkmzY8",
|
59 |
+
"outputId": "0d702a7d-68ab-4747-f012-246d4dee3718"
|
60 |
+
},
|
61 |
+
"execution_count": 4,
|
62 |
+
"outputs": [
|
63 |
+
{
|
64 |
+
"output_type": "stream",
|
65 |
+
"name": "stdout",
|
66 |
+
"text": [
|
67 |
+
"/content\n",
|
68 |
+
"fatal: destination path 'llama2-webui' already exists and is not an empty directory.\n",
|
69 |
+
"/content/llama2-webui\n",
|
70 |
+
"Start downloading model TheBloke/CodeLlama-7B-Instruct-GPTQ to: ./models/CodeLlama-7B-Instruct-GPTQ\n",
|
71 |
+
"Fetching 15 files: 0% 0/15 [00:00<?, ?it/s]\n",
|
72 |
+
"Downloading (…)d0d05/.gitattributes: 100% 1.52k/1.52k [00:00<00:00, 7.94MB/s]\n",
|
73 |
+
"Fetching 15 files: 7% 1/15 [00:01<00:16, 1.15s/it]\n",
|
74 |
+
"Downloading (…)478d0d05/LICENSE.txt: 100% 7.02k/7.02k [00:00<00:00, 31.6MB/s]\n",
|
75 |
+
"\n",
|
76 |
+
"Downloading (…)478d0d05/config.json: 100% 1.25k/1.25k [00:00<00:00, 7.95MB/s]\n",
|
77 |
+
"\n",
|
78 |
+
"Downloading (…)nfiguration_llama.py: 100% 8.56k/8.56k [00:00<00:00, 41.7MB/s]\n",
|
79 |
+
"\n",
|
80 |
+
"Downloading (…)81b84478d0d05/Notice: 100% 112/112 [00:00<00:00, 750kB/s]\n",
|
81 |
+
"\n",
|
82 |
+
"Downloading (…)neration_config.json: 100% 132/132 [00:00<00:00, 836kB/s]\n",
|
83 |
+
"\n",
|
84 |
+
"Downloading (…)8d0d05/USE_POLICY.md: 100% 105/105 [00:00<00:00, 686kB/s]\n",
|
85 |
+
"\n",
|
86 |
+
"Downloading (…)84478d0d05/README.md: 100% 22.0k/22.0k [00:00<00:00, 59.5MB/s]\n",
|
87 |
+
"\n",
|
88 |
+
"Downloading (…)05/modeling_llama.py: 100% 45.9k/45.9k [00:00<00:00, 27.5MB/s]\n",
|
89 |
+
"\n",
|
90 |
+
"Downloading (…)quantize_config.json: 100% 187/187 [00:00<00:00, 1.34MB/s]\n",
|
91 |
+
"\n",
|
92 |
+
"Downloading (…)cial_tokens_map.json: 100% 411/411 [00:00<00:00, 2.82MB/s]\n",
|
93 |
+
"\n",
|
94 |
+
"Downloading (…)d0d05/tokenizer.json: 0% 0.00/1.84M [00:00<?, ?B/s]\u001b[A\n",
|
95 |
+
"\n",
|
96 |
+
"Downloading (…)okenizer_config.json: 100% 824/824 [00:00<00:00, 5.75MB/s]\n",
|
97 |
+
"\n",
|
98 |
+
"\n",
|
99 |
+
"Downloading model.safetensors: 0% 0.00/3.90G [00:00<?, ?B/s]\u001b[A\u001b[A\n",
|
100 |
+
"\n",
|
101 |
+
"\n",
|
102 |
+
"Downloading tokenizer.model: 100% 500k/500k [00:00<00:00, 16.3MB/s]\n",
|
103 |
+
"\n",
|
104 |
+
"Downloading (…)d0d05/tokenizer.json: 100% 1.84M/1.84M [00:00<00:00, 5.47MB/s]\n",
|
105 |
+
"\n",
|
106 |
+
"\n",
|
107 |
+
"Downloading model.safetensors: 0% 10.5M/3.90G [00:00<01:08, 56.4MB/s]\u001b[A\u001b[A\n",
|
108 |
+
"\n",
|
109 |
+
"Downloading model.safetensors: 1% 21.0M/3.90G [00:00<00:57, 67.1MB/s]\u001b[A\u001b[A\n",
|
110 |
+
"\n",
|
111 |
+
"Downloading model.safetensors: 1% 31.5M/3.90G [00:00<00:51, 75.5MB/s]\u001b[A\u001b[A\n",
|
112 |
+
"\n",
|
113 |
+
"Downloading model.safetensors: 1% 52.4M/3.90G [00:00<00:40, 94.5MB/s]\u001b[A\u001b[A\n",
|
114 |
+
"\n",
|
115 |
+
"Downloading model.safetensors: 2% 73.4M/3.90G [00:00<00:33, 113MB/s] \u001b[A\u001b[A\n",
|
116 |
+
"\n",
|
117 |
+
"Downloading model.safetensors: 2% 94.4M/3.90G [00:00<00:28, 133MB/s]\u001b[A\u001b[A\n",
|
118 |
+
"\n",
|
119 |
+
"Downloading model.safetensors: 3% 115M/3.90G [00:00<00:25, 148MB/s] \u001b[A\u001b[A\n",
|
120 |
+
"\n",
|
121 |
+
"Downloading model.safetensors: 3% 136M/3.90G [00:01<00:24, 156MB/s]\u001b[A\u001b[A\n",
|
122 |
+
"\n",
|
123 |
+
"Downloading model.safetensors: 4% 157M/3.90G [00:01<00:22, 167MB/s]\u001b[A\u001b[A\n",
|
124 |
+
"\n",
|
125 |
+
"Downloading model.safetensors: 5% 178M/3.90G [00:01<00:22, 168MB/s]\u001b[A\u001b[A\n",
|
126 |
+
"\n",
|
127 |
+
"Downloading model.safetensors: 5% 199M/3.90G [00:01<00:21, 169MB/s]\u001b[A\u001b[A\n",
|
128 |
+
"\n",
|
129 |
+
"Downloading model.safetensors: 6% 220M/3.90G [00:01<00:21, 170MB/s]\u001b[A\u001b[A\n",
|
130 |
+
"\n",
|
131 |
+
"Downloading model.safetensors: 6% 241M/3.90G [00:01<00:21, 174MB/s]\u001b[A\u001b[A\n",
|
132 |
+
"\n",
|
133 |
+
"Downloading model.safetensors: 7% 262M/3.90G [00:01<00:20, 177MB/s]\u001b[A\u001b[A\n",
|
134 |
+
"\n",
|
135 |
+
"Downloading model.safetensors: 7% 283M/3.90G [00:02<01:08, 52.9MB/s]\u001b[A\u001b[A\n",
|
136 |
+
"\n",
|
137 |
+
"Downloading model.safetensors: 8% 315M/3.90G [00:02<00:47, 75.6MB/s]\u001b[A\u001b[A\n",
|
138 |
+
"\n",
|
139 |
+
"Downloading model.safetensors: 9% 346M/3.90G [00:03<00:36, 97.8MB/s]\u001b[A\u001b[A\n",
|
140 |
+
"\n",
|
141 |
+
"Downloading model.safetensors: 9% 367M/3.90G [00:03<00:31, 111MB/s] \u001b[A\u001b[A\n",
|
142 |
+
"\n",
|
143 |
+
"Downloading model.safetensors: 10% 388M/3.90G [00:03<00:28, 122MB/s]\u001b[A\u001b[A\n",
|
144 |
+
"\n",
|
145 |
+
"Downloading model.safetensors: 10% 409M/3.90G [00:03<00:26, 134MB/s]\u001b[A\u001b[A\n",
|
146 |
+
"\n",
|
147 |
+
"Downloading model.safetensors: 11% 430M/3.90G [00:03<00:24, 141MB/s]\u001b[A\u001b[A\n",
|
148 |
+
"\n",
|
149 |
+
"Downloading model.safetensors: 12% 461M/3.90G [00:03<00:21, 160MB/s]\u001b[A\u001b[A\n",
|
150 |
+
"\n",
|
151 |
+
"Downloading model.safetensors: 12% 482M/3.90G [00:03<00:20, 165MB/s]\u001b[A\u001b[A\n",
|
152 |
+
"\n",
|
153 |
+
"Downloading model.safetensors: 13% 503M/3.90G [00:04<00:20, 166MB/s]\u001b[A\u001b[A\n",
|
154 |
+
"\n",
|
155 |
+
"Downloading model.safetensors: 13% 524M/3.90G [00:04<00:19, 170MB/s]\u001b[A\u001b[A\n",
|
156 |
+
"\n",
|
157 |
+
"Downloading model.safetensors: 14% 556M/3.90G [00:04<00:18, 181MB/s]\u001b[A\u001b[A\n",
|
158 |
+
"\n",
|
159 |
+
"Downloading model.safetensors: 15% 577M/3.90G [00:04<00:18, 182MB/s]\u001b[A\u001b[A\n",
|
160 |
+
"\n",
|
161 |
+
"Downloading model.safetensors: 15% 598M/3.90G [00:04<00:18, 183MB/s]\u001b[A\u001b[A\n",
|
162 |
+
"\n",
|
163 |
+
"Downloading model.safetensors: 16% 619M/3.90G [00:04<00:17, 184MB/s]\u001b[A\u001b[A\n",
|
164 |
+
"\n",
|
165 |
+
"Downloading model.safetensors: 16% 640M/3.90G [00:04<00:17, 184MB/s]\u001b[A\u001b[A\n",
|
166 |
+
"\n",
|
167 |
+
"Downloading model.safetensors: 17% 661M/3.90G [00:04<00:18, 178MB/s]\u001b[A\u001b[A\n",
|
168 |
+
"\n",
|
169 |
+
"Downloading model.safetensors: 17% 682M/3.90G [00:04<00:17, 180MB/s]\u001b[A\u001b[A\n",
|
170 |
+
"\n",
|
171 |
+
"Downloading model.safetensors: 18% 703M/3.90G [00:05<00:17, 180MB/s]\u001b[A\u001b[A\n",
|
172 |
+
"\n",
|
173 |
+
"Downloading model.safetensors: 19% 724M/3.90G [00:05<00:17, 181MB/s]\u001b[A\u001b[A\n",
|
174 |
+
"\n",
|
175 |
+
"Downloading model.safetensors: 19% 744M/3.90G [00:05<00:18, 171MB/s]\u001b[A\u001b[A\n",
|
176 |
+
"\n",
|
177 |
+
"Downloading model.safetensors: 20% 765M/3.90G [00:05<00:18, 173MB/s]\u001b[A\u001b[A\n",
|
178 |
+
"\n",
|
179 |
+
"Downloading model.safetensors: 20% 786M/3.90G [00:05<00:17, 175MB/s]\u001b[A\u001b[A\n",
|
180 |
+
"\n",
|
181 |
+
"Downloading model.safetensors: 21% 807M/3.90G [00:05<00:17, 178MB/s]\u001b[A\u001b[A\n",
|
182 |
+
"\n",
|
183 |
+
"Downloading model.safetensors: 21% 828M/3.90G [00:05<00:17, 180MB/s]\u001b[A\u001b[A\n",
|
184 |
+
"\n",
|
185 |
+
"Downloading model.safetensors: 22% 849M/3.90G [00:05<00:16, 182MB/s]\u001b[A\u001b[A\n",
|
186 |
+
"\n",
|
187 |
+
"Downloading model.safetensors: 22% 870M/3.90G [00:07<01:37, 30.9MB/s]\u001b[A\u001b[A\n",
|
188 |
+
"\n",
|
189 |
+
"Downloading model.safetensors: 23% 891M/3.90G [00:08<01:13, 40.8MB/s]\u001b[A\u001b[A\n",
|
190 |
+
"\n",
|
191 |
+
"Downloading model.safetensors: 24% 923M/3.90G [00:08<00:50, 59.3MB/s]\u001b[A\u001b[A\n",
|
192 |
+
"\n",
|
193 |
+
"Downloading model.safetensors: 24% 944M/3.90G [00:08<00:42, 70.2MB/s]\u001b[A\u001b[A\n",
|
194 |
+
"\n",
|
195 |
+
"Downloading model.safetensors: 25% 975M/3.90G [00:08<00:30, 94.3MB/s]\u001b[A\u001b[A\n",
|
196 |
+
"\n",
|
197 |
+
"Downloading model.safetensors: 26% 996M/3.90G [00:08<00:27, 107MB/s] \u001b[A\u001b[A\n",
|
198 |
+
"\n",
|
199 |
+
"Downloading model.safetensors: 26% 1.02G/3.90G [00:08<00:23, 121MB/s]\u001b[A\u001b[A\n",
|
200 |
+
"\n",
|
201 |
+
"Downloading model.safetensors: 27% 1.04G/3.90G [00:08<00:21, 134MB/s]\u001b[A\u001b[A\n",
|
202 |
+
"\n",
|
203 |
+
"Downloading model.safetensors: 27% 1.06G/3.90G [00:08<00:20, 141MB/s]\u001b[A\u001b[A\n",
|
204 |
+
"\n",
|
205 |
+
"Downloading model.safetensors: 28% 1.08G/3.90G [00:09<00:18, 151MB/s]\u001b[A\u001b[A\n",
|
206 |
+
"\n",
|
207 |
+
"Downloading model.safetensors: 28% 1.10G/3.90G [00:09<00:17, 160MB/s]\u001b[A\u001b[A\n",
|
208 |
+
"\n",
|
209 |
+
"Downloading model.safetensors: 29% 1.12G/3.90G [00:09<00:16, 166MB/s]\u001b[A\u001b[A\n",
|
210 |
+
"\n",
|
211 |
+
"Downloading model.safetensors: 29% 1.14G/3.90G [00:09<00:16, 171MB/s]\u001b[A\u001b[A\n",
|
212 |
+
"\n",
|
213 |
+
"Downloading model.safetensors: 30% 1.16G/3.90G [00:09<00:15, 175MB/s]\u001b[A\u001b[A\n",
|
214 |
+
"\n",
|
215 |
+
"Downloading model.safetensors: 30% 1.18G/3.90G [00:09<00:15, 178MB/s]\u001b[A\u001b[A\n",
|
216 |
+
"\n",
|
217 |
+
"Downloading model.safetensors: 31% 1.21G/3.90G [00:09<00:15, 179MB/s]\u001b[A\u001b[A\n",
|
218 |
+
"\n",
|
219 |
+
"Downloading model.safetensors: 31% 1.23G/3.90G [00:09<00:14, 181MB/s]\u001b[A\u001b[A\n",
|
220 |
+
"\n",
|
221 |
+
"Downloading model.safetensors: 32% 1.25G/3.90G [00:09<00:14, 182MB/s]\u001b[A\u001b[A\n",
|
222 |
+
"\n",
|
223 |
+
"Downloading model.safetensors: 33% 1.27G/3.90G [00:10<00:23, 113MB/s]\u001b[A\u001b[A\n",
|
224 |
+
"\n",
|
225 |
+
"Downloading model.safetensors: 33% 1.29G/3.90G [00:10<00:20, 128MB/s]\u001b[A\u001b[A\n",
|
226 |
+
"\n",
|
227 |
+
"Downloading model.safetensors: 34% 1.31G/3.90G [00:10<00:18, 139MB/s]\u001b[A\u001b[A\n",
|
228 |
+
"\n",
|
229 |
+
"Downloading model.safetensors: 34% 1.33G/3.90G [00:10<00:17, 150MB/s]\u001b[A\u001b[A\n",
|
230 |
+
"\n",
|
231 |
+
"Downloading model.safetensors: 35% 1.35G/3.90G [00:10<00:16, 158MB/s]\u001b[A\u001b[A\n",
|
232 |
+
"\n",
|
233 |
+
"Downloading model.safetensors: 35% 1.37G/3.90G [00:12<01:24, 29.9MB/s]\u001b[A\u001b[A\n",
|
234 |
+
"\n",
|
235 |
+
"Downloading model.safetensors: 36% 1.41G/3.90G [00:12<00:55, 45.3MB/s]\u001b[A\u001b[A\n",
|
236 |
+
"\n",
|
237 |
+
"Downloading model.safetensors: 37% 1.44G/3.90G [00:13<00:39, 63.0MB/s]\u001b[A\u001b[A\n",
|
238 |
+
"\n",
|
239 |
+
"Downloading model.safetensors: 37% 1.46G/3.90G [00:13<00:33, 72.6MB/s]\u001b[A\u001b[A\n",
|
240 |
+
"\n",
|
241 |
+
"Downloading model.safetensors: 38% 1.48G/3.90G [00:13<00:29, 82.0MB/s]\u001b[A\u001b[A\n",
|
242 |
+
"\n",
|
243 |
+
"Downloading model.safetensors: 38% 1.50G/3.90G [00:13<00:24, 98.6MB/s]\u001b[A\u001b[A\n",
|
244 |
+
"\n",
|
245 |
+
"Downloading model.safetensors: 39% 1.53G/3.90G [00:13<00:19, 124MB/s] \u001b[A\u001b[A\n",
|
246 |
+
"\n",
|
247 |
+
"Downloading model.safetensors: 40% 1.55G/3.90G [00:13<00:17, 132MB/s]\u001b[A\u001b[A\n",
|
248 |
+
"\n",
|
249 |
+
"Downloading model.safetensors: 40% 1.57G/3.90G [00:13<00:16, 143MB/s]\u001b[A\u001b[A\n",
|
250 |
+
"\n",
|
251 |
+
"Downloading model.safetensors: 41% 1.59G/3.90G [00:14<00:15, 153MB/s]\u001b[A\u001b[A\n",
|
252 |
+
"\n",
|
253 |
+
"Downloading model.safetensors: 41% 1.61G/3.90G [00:14<00:14, 160MB/s]\u001b[A\u001b[A\n",
|
254 |
+
"\n",
|
255 |
+
"Downloading model.safetensors: 42% 1.64G/3.90G [00:14<00:13, 167MB/s]\u001b[A\u001b[A\n",
|
256 |
+
"\n",
|
257 |
+
"Downloading model.safetensors: 43% 1.66G/3.90G [00:14<00:13, 171MB/s]\u001b[A\u001b[A\n",
|
258 |
+
"\n",
|
259 |
+
"Downloading model.safetensors: 43% 1.68G/3.90G [00:14<00:12, 177MB/s]\u001b[A\u001b[A\n",
|
260 |
+
"\n",
|
261 |
+
"Downloading model.safetensors: 44% 1.70G/3.90G [00:14<00:12, 174MB/s]\u001b[A\u001b[A\n",
|
262 |
+
"\n",
|
263 |
+
"Downloading model.safetensors: 44% 1.72G/3.90G [00:14<00:12, 173MB/s]\u001b[A\u001b[A\n",
|
264 |
+
"\n",
|
265 |
+
"Downloading model.safetensors: 45% 1.74G/3.90G [00:14<00:12, 175MB/s]\u001b[A\u001b[A\n",
|
266 |
+
"\n",
|
267 |
+
"Downloading model.safetensors: 45% 1.76G/3.90G [00:14<00:11, 179MB/s]\u001b[A\u001b[A\n",
|
268 |
+
"\n",
|
269 |
+
"Downloading model.safetensors: 46% 1.78G/3.90G [00:15<00:12, 172MB/s]\u001b[A\u001b[A\n",
|
270 |
+
"\n",
|
271 |
+
"Downloading model.safetensors: 46% 1.80G/3.90G [00:15<00:12, 174MB/s]\u001b[A\u001b[A\n",
|
272 |
+
"\n",
|
273 |
+
"Downloading model.safetensors: 47% 1.82G/3.90G [00:15<00:11, 177MB/s]\u001b[A\u001b[A\n",
|
274 |
+
"\n",
|
275 |
+
"Downloading model.safetensors: 47% 1.85G/3.90G [00:16<00:28, 71.9MB/s]\u001b[A\u001b[A\n",
|
276 |
+
"\n",
|
277 |
+
"Downloading model.safetensors: 48% 1.87G/3.90G [00:16<00:23, 87.4MB/s]\u001b[A\u001b[A\n",
|
278 |
+
"\n",
|
279 |
+
"Downloading model.safetensors: 49% 1.90G/3.90G [00:16<00:16, 118MB/s] \u001b[A\u001b[A\n",
|
280 |
+
"\n",
|
281 |
+
"Downloading model.safetensors: 49% 1.92G/3.90G [00:16<00:14, 132MB/s]\u001b[A\u001b[A\n",
|
282 |
+
"\n",
|
283 |
+
"Downloading model.safetensors: 50% 1.94G/3.90G [00:16<00:13, 143MB/s]\u001b[A\u001b[A\n",
|
284 |
+
"\n",
|
285 |
+
"Downloading model.safetensors: 50% 1.96G/3.90G [00:16<00:12, 152MB/s]\u001b[A\u001b[A\n",
|
286 |
+
"\n",
|
287 |
+
"Downloading model.safetensors: 51% 1.98G/3.90G [00:16<00:13, 142MB/s]\u001b[A\u001b[A\n",
|
288 |
+
"\n",
|
289 |
+
"Downloading model.safetensors: 51% 2.00G/3.90G [00:16<00:13, 144MB/s]\u001b[A\u001b[A\n",
|
290 |
+
"\n",
|
291 |
+
"Downloading model.safetensors: 52% 2.02G/3.90G [00:17<00:12, 144MB/s]\u001b[A\u001b[A\n",
|
292 |
+
"\n",
|
293 |
+
"Downloading model.safetensors: 52% 2.04G/3.90G [00:17<00:12, 148MB/s]\u001b[A\u001b[A\n",
|
294 |
+
"\n",
|
295 |
+
"Downloading model.safetensors: 53% 2.07G/3.90G [00:17<00:12, 152MB/s]\u001b[A\u001b[A\n",
|
296 |
+
"\n",
|
297 |
+
"Downloading model.safetensors: 54% 2.09G/3.90G [00:17<00:22, 81.2MB/s]\u001b[A\u001b[A\n",
|
298 |
+
"\n",
|
299 |
+
"Downloading model.safetensors: 54% 2.12G/3.90G [00:18<00:16, 107MB/s] \u001b[A\u001b[A\n",
|
300 |
+
"\n",
|
301 |
+
"Downloading model.safetensors: 55% 2.14G/3.90G [00:18<00:14, 119MB/s]\u001b[A\u001b[A\n",
|
302 |
+
"\n",
|
303 |
+
"Downloading model.safetensors: 55% 2.16G/3.90G [00:18<00:14, 123MB/s]\u001b[A\u001b[A\n",
|
304 |
+
"\n",
|
305 |
+
"Downloading model.safetensors: 56% 2.18G/3.90G [00:18<00:13, 131MB/s]\u001b[A\u001b[A\n",
|
306 |
+
"\n",
|
307 |
+
"Downloading model.safetensors: 57% 2.21G/3.90G [00:18<00:10, 156MB/s]\u001b[A\u001b[A\n",
|
308 |
+
"\n",
|
309 |
+
"Downloading model.safetensors: 57% 2.23G/3.90G [00:18<00:10, 162MB/s]\u001b[A\u001b[A\n",
|
310 |
+
"\n",
|
311 |
+
"Downloading model.safetensors: 58% 2.25G/3.90G [00:18<00:10, 160MB/s]\u001b[A\u001b[A\n",
|
312 |
+
"\n",
|
313 |
+
"Downloading model.safetensors: 59% 2.29G/3.90G [00:18<00:09, 174MB/s]\u001b[A\u001b[A\n",
|
314 |
+
"\n",
|
315 |
+
"Downloading model.safetensors: 59% 2.31G/3.90G [00:19<00:08, 178MB/s]\u001b[A\u001b[A\n",
|
316 |
+
"\n",
|
317 |
+
"Downloading model.safetensors: 60% 2.33G/3.90G [00:19<00:08, 180MB/s]\u001b[A\u001b[A\n",
|
318 |
+
"\n",
|
319 |
+
"Downloading model.safetensors: 60% 2.35G/3.90G [00:19<00:08, 181MB/s]\u001b[A\u001b[A\n",
|
320 |
+
"\n",
|
321 |
+
"Downloading model.safetensors: 61% 2.37G/3.90G [00:19<00:08, 181MB/s]\u001b[A\u001b[A\n",
|
322 |
+
"\n",
|
323 |
+
"Downloading model.safetensors: 61% 2.39G/3.90G [00:19<00:08, 181MB/s]\u001b[A\u001b[A\n",
|
324 |
+
"\n",
|
325 |
+
"Downloading model.safetensors: 62% 2.41G/3.90G [00:19<00:08, 182MB/s]\u001b[A\u001b[A\n",
|
326 |
+
"\n",
|
327 |
+
"Downloading model.safetensors: 62% 2.43G/3.90G [00:19<00:08, 182MB/s]\u001b[A\u001b[A\n",
|
328 |
+
"\n",
|
329 |
+
"Downloading model.safetensors: 63% 2.45G/3.90G [00:19<00:08, 177MB/s]\u001b[A\u001b[A\n",
|
330 |
+
"\n",
|
331 |
+
"Downloading model.safetensors: 64% 2.47G/3.90G [00:20<00:11, 124MB/s]\u001b[A\u001b[A\n",
|
332 |
+
"\n",
|
333 |
+
"Downloading model.safetensors: 64% 2.51G/3.90G [00:20<00:09, 149MB/s]\u001b[A\u001b[A\n",
|
334 |
+
"\n",
|
335 |
+
"Downloading model.safetensors: 65% 2.53G/3.90G [00:22<00:40, 34.2MB/s]\u001b[A\u001b[A\n",
|
336 |
+
"\n",
|
337 |
+
"Downloading model.safetensors: 66% 2.56G/3.90G [00:22<00:26, 50.1MB/s]\u001b[A\u001b[A\n",
|
338 |
+
"\n",
|
339 |
+
"Downloading model.safetensors: 66% 2.58G/3.90G [00:22<00:21, 60.1MB/s]\u001b[A\u001b[A\n",
|
340 |
+
"\n",
|
341 |
+
"Downloading model.safetensors: 67% 2.60G/3.90G [00:22<00:18, 69.4MB/s]\u001b[A\u001b[A\n",
|
342 |
+
"\n",
|
343 |
+
"Downloading model.safetensors: 67% 2.62G/3.90G [00:22<00:15, 84.0MB/s]\u001b[A\u001b[A\n",
|
344 |
+
"\n",
|
345 |
+
"Downloading model.safetensors: 68% 2.64G/3.90G [00:22<00:12, 99.4MB/s]\u001b[A\u001b[A\n",
|
346 |
+
"\n",
|
347 |
+
"Downloading model.safetensors: 68% 2.66G/3.90G [00:23<00:12, 96.0MB/s]\u001b[A\u001b[A\n",
|
348 |
+
"\n",
|
349 |
+
"Downloading model.safetensors: 69% 2.68G/3.90G [00:23<00:12, 95.4MB/s]\u001b[A\u001b[A\n",
|
350 |
+
"\n",
|
351 |
+
"Downloading model.safetensors: 69% 2.71G/3.90G [00:23<00:14, 84.2MB/s]\u001b[A\u001b[A\n",
|
352 |
+
"\n",
|
353 |
+
"Downloading model.safetensors: 70% 2.73G/3.90G [00:23<00:14, 82.0MB/s]\u001b[A\u001b[A\n",
|
354 |
+
"\n",
|
355 |
+
"Downloading model.safetensors: 70% 2.74G/3.90G [00:24<00:14, 80.9MB/s]\u001b[A\u001b[A\n",
|
356 |
+
"\n",
|
357 |
+
"Downloading model.safetensors: 70% 2.75G/3.90G [00:24<00:15, 75.8MB/s]\u001b[A\u001b[A\n",
|
358 |
+
"\n",
|
359 |
+
"Downloading model.safetensors: 71% 2.76G/3.90G [00:24<00:15, 75.3MB/s]\u001b[A\u001b[A\n",
|
360 |
+
"\n",
|
361 |
+
"Downloading model.safetensors: 71% 2.77G/3.90G [00:24<00:15, 72.2MB/s]\u001b[A\u001b[A\n",
|
362 |
+
"\n",
|
363 |
+
"Downloading model.safetensors: 71% 2.78G/3.90G [00:24<00:14, 74.9MB/s]\u001b[A\u001b[A\n",
|
364 |
+
"\n",
|
365 |
+
"Downloading model.safetensors: 72% 2.79G/3.90G [00:24<00:14, 74.7MB/s]\u001b[A\u001b[A\n",
|
366 |
+
"\n",
|
367 |
+
"Downloading model.safetensors: 72% 2.80G/3.90G [00:25<00:15, 69.4MB/s]\u001b[A\u001b[A\n",
|
368 |
+
"\n",
|
369 |
+
"Downloading model.safetensors: 72% 2.81G/3.90G [00:25<00:15, 71.3MB/s]\u001b[A\u001b[A\n",
|
370 |
+
"\n",
|
371 |
+
"Downloading model.safetensors: 72% 2.82G/3.90G [00:25<00:13, 77.5MB/s]\u001b[A\u001b[A\n",
|
372 |
+
"\n",
|
373 |
+
"Downloading model.safetensors: 73% 2.84G/3.90G [00:25<00:12, 84.6MB/s]\u001b[A\u001b[A\n",
|
374 |
+
"\n",
|
375 |
+
"Downloading model.safetensors: 73% 2.85G/3.90G [00:25<00:12, 83.8MB/s]\u001b[A\u001b[A\n",
|
376 |
+
"\n",
|
377 |
+
"Downloading model.safetensors: 73% 2.86G/3.90G [00:25<00:12, 81.6MB/s]\u001b[A\u001b[A\n",
|
378 |
+
"\n",
|
379 |
+
"Downloading model.safetensors: 74% 2.88G/3.90G [00:25<00:10, 97.2MB/s]\u001b[A\u001b[A\n",
|
380 |
+
"\n",
|
381 |
+
"Downloading model.safetensors: 75% 2.90G/3.90G [00:26<00:08, 118MB/s] \u001b[A\u001b[A\n",
|
382 |
+
"\n",
|
383 |
+
"Downloading model.safetensors: 75% 2.93G/3.90G [00:26<00:07, 134MB/s]\u001b[A\u001b[A\n",
|
384 |
+
"\n",
|
385 |
+
"Downloading model.safetensors: 76% 2.95G/3.90G [00:26<00:06, 149MB/s]\u001b[A\u001b[A\n",
|
386 |
+
"\n",
|
387 |
+
"Downloading model.safetensors: 76% 2.97G/3.90G [00:26<00:05, 159MB/s]\u001b[A\u001b[A\n",
|
388 |
+
"\n",
|
389 |
+
"Downloading model.safetensors: 77% 2.99G/3.90G [00:27<00:23, 37.9MB/s]\u001b[A\u001b[A\n",
|
390 |
+
"\n",
|
391 |
+
"Downloading model.safetensors: 77% 3.02G/3.90G [00:27<00:15, 57.4MB/s]\u001b[A\u001b[A\n",
|
392 |
+
"\n",
|
393 |
+
"Downloading model.safetensors: 78% 3.04G/3.90G [00:28<00:12, 67.9MB/s]\u001b[A\u001b[A\n",
|
394 |
+
"\n",
|
395 |
+
"Downloading model.safetensors: 79% 3.06G/3.90G [00:28<00:10, 78.8MB/s]\u001b[A\u001b[A\n",
|
396 |
+
"\n",
|
397 |
+
"Downloading model.safetensors: 79% 3.08G/3.90G [00:28<00:08, 92.9MB/s]\u001b[A\u001b[A\n",
|
398 |
+
"\n",
|
399 |
+
"Downloading model.safetensors: 80% 3.10G/3.90G [00:28<00:07, 109MB/s] \u001b[A\u001b[A\n",
|
400 |
+
"\n",
|
401 |
+
"Downloading model.safetensors: 80% 3.14G/3.90G [00:28<00:05, 138MB/s]\u001b[A\u001b[A\n",
|
402 |
+
"\n",
|
403 |
+
"Downloading model.safetensors: 81% 3.16G/3.90G [00:28<00:05, 146MB/s]\u001b[A\u001b[A\n",
|
404 |
+
"\n",
|
405 |
+
"Downloading model.safetensors: 82% 3.18G/3.90G [00:28<00:04, 152MB/s]\u001b[A\u001b[A\n",
|
406 |
+
"\n",
|
407 |
+
"Downloading model.safetensors: 82% 3.20G/3.90G [00:29<00:04, 161MB/s]\u001b[A\u001b[A\n",
|
408 |
+
"\n",
|
409 |
+
"Downloading model.safetensors: 83% 3.22G/3.90G [00:29<00:03, 170MB/s]\u001b[A\u001b[A\n",
|
410 |
+
"\n",
|
411 |
+
"Downloading model.safetensors: 83% 3.24G/3.90G [00:29<00:04, 158MB/s]\u001b[A\u001b[A\n",
|
412 |
+
"\n",
|
413 |
+
"Downloading model.safetensors: 84% 3.26G/3.90G [00:29<00:04, 156MB/s]\u001b[A\u001b[A\n",
|
414 |
+
"\n",
|
415 |
+
"Downloading model.safetensors: 84% 3.28G/3.90G [00:29<00:03, 160MB/s]\u001b[A\u001b[A\n",
|
416 |
+
"\n",
|
417 |
+
"Downloading model.safetensors: 85% 3.30G/3.90G [00:29<00:03, 162MB/s]\u001b[A\u001b[A\n",
|
418 |
+
"\n",
|
419 |
+
"Downloading model.safetensors: 85% 3.32G/3.90G [00:29<00:03, 160MB/s]\u001b[A\u001b[A\n",
|
420 |
+
"\n",
|
421 |
+
"Downloading model.safetensors: 86% 3.34G/3.90G [00:29<00:03, 171MB/s]\u001b[A\u001b[A\n",
|
422 |
+
"\n",
|
423 |
+
"Downloading model.safetensors: 87% 3.38G/3.90G [00:30<00:02, 191MB/s]\u001b[A\u001b[A\n",
|
424 |
+
"\n",
|
425 |
+
"Downloading model.safetensors: 87% 3.40G/3.90G [00:30<00:02, 188MB/s]\u001b[A\u001b[A\n",
|
426 |
+
"\n",
|
427 |
+
"Downloading model.safetensors: 88% 3.42G/3.90G [00:30<00:02, 187MB/s]\u001b[A\u001b[A\n",
|
428 |
+
"\n",
|
429 |
+
"Downloading model.safetensors: 88% 3.44G/3.90G [00:30<00:02, 182MB/s]\u001b[A\u001b[A\n",
|
430 |
+
"\n",
|
431 |
+
"Downloading model.safetensors: 89% 3.46G/3.90G [00:30<00:02, 183MB/s]\u001b[A\u001b[A\n",
|
432 |
+
"\n",
|
433 |
+
"Downloading model.safetensors: 89% 3.48G/3.90G [00:30<00:02, 183MB/s]\u001b[A\u001b[A\n",
|
434 |
+
"\n",
|
435 |
+
"Downloading model.safetensors: 90% 3.50G/3.90G [00:30<00:02, 184MB/s]\u001b[A\u001b[A\n",
|
436 |
+
"\n",
|
437 |
+
"Downloading model.safetensors: 90% 3.52G/3.90G [00:30<00:02, 185MB/s]\u001b[A\u001b[A\n",
|
438 |
+
"\n",
|
439 |
+
"Downloading model.safetensors: 91% 3.54G/3.90G [00:30<00:01, 183MB/s]\u001b[A\u001b[A\n",
|
440 |
+
"\n",
|
441 |
+
"Downloading model.safetensors: 91% 3.57G/3.90G [00:31<00:05, 55.5MB/s]\u001b[A\u001b[A\n",
|
442 |
+
"\n",
|
443 |
+
"Downloading model.safetensors: 92% 3.59G/3.90G [00:32<00:08, 38.3MB/s]\u001b[A\u001b[A\n",
|
444 |
+
"\n",
|
445 |
+
"Downloading model.safetensors: 93% 3.61G/3.90G [00:32<00:05, 50.7MB/s]\u001b[A\u001b[A\n",
|
446 |
+
"\n",
|
447 |
+
"Downloading model.safetensors: 93% 3.63G/3.90G [00:33<00:04, 65.0MB/s]\u001b[A\u001b[A\n",
|
448 |
+
"\n",
|
449 |
+
"Downloading model.safetensors: 94% 3.65G/3.90G [00:33<00:03, 80.3MB/s]\u001b[A\u001b[A\n",
|
450 |
+
"\n",
|
451 |
+
"Downloading model.safetensors: 94% 3.67G/3.90G [00:33<00:02, 97.3MB/s]\u001b[A\u001b[A\n",
|
452 |
+
"\n",
|
453 |
+
"Downloading model.safetensors: 95% 3.69G/3.90G [00:33<00:01, 113MB/s] \u001b[A\u001b[A\n",
|
454 |
+
"\n",
|
455 |
+
"Downloading model.safetensors: 95% 3.71G/3.90G [00:33<00:01, 128MB/s]\u001b[A\u001b[A\n",
|
456 |
+
"\n",
|
457 |
+
"Downloading model.safetensors: 96% 3.73G/3.90G [00:33<00:01, 139MB/s]\u001b[A\u001b[A\n",
|
458 |
+
"\n",
|
459 |
+
"Downloading model.safetensors: 96% 3.75G/3.90G [00:33<00:00, 153MB/s]\u001b[A\u001b[A\n",
|
460 |
+
"\n",
|
461 |
+
"Downloading model.safetensors: 97% 3.77G/3.90G [00:33<00:00, 158MB/s]\u001b[A\u001b[A\n",
|
462 |
+
"\n",
|
463 |
+
"Downloading model.safetensors: 97% 3.80G/3.90G [00:34<00:00, 165MB/s]\u001b[A\u001b[A\n",
|
464 |
+
"\n",
|
465 |
+
"Downloading model.safetensors: 98% 3.82G/3.90G [00:34<00:00, 167MB/s]\u001b[A\u001b[A\n",
|
466 |
+
"\n",
|
467 |
+
"Downloading model.safetensors: 98% 3.84G/3.90G [00:34<00:00, 169MB/s]\u001b[A\u001b[A\n",
|
468 |
+
"\n",
|
469 |
+
"Downloading model.safetensors: 99% 3.86G/3.90G [00:34<00:00, 174MB/s]\u001b[A\u001b[A\n",
|
470 |
+
"\n",
|
471 |
+
"Downloading model.safetensors: 100% 3.90G/3.90G [00:34<00:00, 113MB/s]\n",
|
472 |
+
"Fetching 15 files: 100% 15/15 [00:36<00:00, 2.41s/it]\n",
|
473 |
+
"/content/llama2-webui\n",
|
474 |
+
"Running on GPU with backend torch transformers.\n",
|
475 |
+
"2023-08-26 07:14:25.222792: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
476 |
+
"skip module injection for FusedLlamaMLPForQuantizedModel not support integrate without triton yet.\n",
|
477 |
+
"Caching examples at: '/content/llama2-webui/gradio_cached_examples/19'\n",
|
478 |
+
"Caching example 1/5\n",
|
479 |
+
"Caching example 2/5\n",
|
480 |
+
"Caching example 3/5\n",
|
481 |
+
"Caching example 4/5\n",
|
482 |
+
"Caching example 5/5\n",
|
483 |
+
"Caching complete\n",
|
484 |
+
"\n",
|
485 |
+
"Running on local URL: http://127.0.0.1:7860\n",
|
486 |
+
"Running on public URL: https://71c3606942c440e7dd.gradio.live\n",
|
487 |
+
"\n",
|
488 |
+
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n",
|
489 |
+
"Keyboard interruption in main thread... closing server.\n",
|
490 |
+
"Traceback (most recent call last):\n",
|
491 |
+
" File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 2130, in block_thread\n",
|
492 |
+
" time.sleep(0.1)\n",
|
493 |
+
"KeyboardInterrupt\n",
|
494 |
+
"\n",
|
495 |
+
"During handling of the above exception, another exception occurred:\n",
|
496 |
+
"\n",
|
497 |
+
"Traceback (most recent call last):\n",
|
498 |
+
" File \"/content/llama2-webui/app.py\", line 322, in <module>\n",
|
499 |
+
" main()\n",
|
500 |
+
" File \"/content/llama2-webui/app.py\", line 318, in main\n",
|
501 |
+
" demo.queue(max_size=20).launch(share=args.share)\n",
|
502 |
+
" File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 2046, in launch\n",
|
503 |
+
" self.block_thread()\n",
|
504 |
+
" File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 2132, in block_thread\n",
|
505 |
+
" print(\"Keyboard interruption in main thread... closing server.\")\n",
|
506 |
+
"KeyboardInterrupt\n",
|
507 |
+
"Killing tunnel 127.0.0.1:7860 <> https://71c3606942c440e7dd.gradio.live\n",
|
508 |
+
"terminate called without an active exception\n"
|
509 |
+
]
|
510 |
+
}
|
511 |
+
]
|
512 |
+
}
|
513 |
+
]
|
514 |
+
}
|
docs/issues.md
ADDED
File without changes
|
docs/news.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# News
|
2 |
+
- [2023/09] The newest `llama2-wrapper>=0.1.14` supports llama.cpp's `gguf` models.
|
3 |
+
|
4 |
+
- [2023/08] 🔥 For developers, we offer a web server that acts as a drop-in replacement for the OpenAI API.
|
5 |
+
|
6 |
+
- Usage:
|
7 |
+
|
8 |
+
```
|
9 |
+
python3 -m llama2_wrapper.server
|
10 |
+
```
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
- [2023/08] 🔥 For developers, we released `llama2-wrapper` as a llama2 backend wrapper in [PYPI](https://pypi.org/project/llama2-wrapper/).
|
15 |
+
|
16 |
+
- Install: `pip install llama2-wrapper`
|
17 |
+
|
18 |
+
- Usage:
|
19 |
+
|
20 |
+
```python
|
21 |
+
from llama2_wrapper import LLAMA2_WRAPPER, get_prompt
|
22 |
+
llama2_wrapper = LLAMA2_WRAPPER(
|
23 |
+
model_path="./models/Llama-2-7B-Chat-GGML/llama-2-7b-chat.ggmlv3.q4_0.bin",
|
24 |
+
backend_type="llama.cpp", #options: llama.cpp, transformers, gptq
|
25 |
+
)
|
26 |
+
prompt = "Do you know Pytorch"
|
27 |
+
llama2_promt = get_prompt(prompt)
|
28 |
+
answer = llama2_wrapper(llama2_promt, temperature=0.9)
|
29 |
+
```
|
30 |
+
|
31 |
+
- [2023/08] 🔥 We added `benchmark.py` for users to benchmark llama2 models on their local devices.
|
32 |
+
|
33 |
+
- Check/contribute the performance of your device in the full [performance doc](https://github.com/liltom-eth/llama2-webui/blob/main/docs/performance.md).
|
34 |
+
|
35 |
+
- [2023/07] We released **[llama2-webui](https://github.com/liltom-eth/llama2-webui)**, a gradio web UI to run Llama 2 on GPU or CPU from anywhere (Linux/Windows/Mac).
|
36 |
+
|
37 |
+
- Supporting models: [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)/[13b](https://huggingface.co/llamaste/Llama-2-13b-chat-hf)/[70b](https://huggingface.co/llamaste/Llama-2-70b-chat-hf), all [Llama-2-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ), all [Llama-2-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML) ...
|
38 |
+
- Supporting model backends: [tranformers](https://github.com/huggingface/transformers), [bitsandbytes(8-bit inference)](https://github.com/TimDettmers/bitsandbytes), [AutoGPTQ(4-bit inference)](https://github.com/PanQiWei/AutoGPTQ), [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
docs/performance.md
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Benchmark Performance
|
2 |
+
|
3 |
+
## Performance on Nvidia GPU
|
4 |
+
|
5 |
+
| Model | Precision | Device | GPU VRAM | Speed (tokens/sec) | load time (s) |
|
6 |
+
| --------------------------------- | --------- | ---------- | ---------------------- | ---------------- | ---------------- |
|
7 |
+
| Llama-2-7b-chat-hf | 16 bit | | | | |
|
8 |
+
| Llama-2-7b-chat-hf | 8bit | NVIDIA RTX 2080 Ti | 7.7 GB VRAM | 3.76 | 641.36 |
|
9 |
+
| Llama-2-7b-Chat-GPTQ | 4bit | NVIDIA RTX 2080 Ti | 5.8 GB VRAM | 18.85 | 192.91 |
|
10 |
+
| Llama-2-7b-Chat-GPTQ | 4bit | NVIDIA GTX 1660 Super | 4.8 GB VRAM | 8.5 | 262.74 |
|
11 |
+
| Llama-2-7b-Chat-GPTQ | 4 bit | Google Colab T4 | 5.8 GB VRAM | 18.19 | 37.44 |
|
12 |
+
| Llama-2-13b-chat-hf | 16 bit | | | | |
|
13 |
+
| | | | | | |
|
14 |
+
|
15 |
+
## Performance on CPU / OpenBLAS / cuBLAS / CLBlast / Metal
|
16 |
+
|
17 |
+
| Model | Precision | Device | RAM / GPU VRAM | Speed (tokens/sec) | load time (s) |
|
18 |
+
| --------------------------------- | --------- | ---------- | ---------------------- | ---------------- | ---------------- |
|
19 |
+
| llama-2-7b-chat.ggmlv3.q2_K | 2 bit | Intel i7-8700 | 4.5 GB RAM | 7.88 | 31.90 |
|
20 |
+
| llama-2-7b-chat.ggmlv3.q2_K | 2 bit | Apple M2 CPU | 4.5 GB RAM | 11.10 | 0.10 |
|
21 |
+
| llama-2-7b-chat.ggmlv3.q2_K | 2 bit | Apple M2 Metal | 4.5 GB RAM | 12.10 | 0.12 |
|
22 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Intel i7-8700 | 5.4 GB RAM | 6.27 | 173.15 |
|
23 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Intel i7-9700 | 4.8 GB RAM | 4.2 | 87.9 |
|
24 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Apple M1 Pro CPU | 5.4 GB RAM | 17.90 | 0.18 |
|
25 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Apple M2 CPU | 5.4 GB RAM | 13.70 | 0.13 |
|
26 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Apple M2 Metal | 5.4 GB RAM | 12.60 | 0.10 |
|
27 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | AMD Ryzen 9 5900HS | 4.1 GB RAM | 6.01 | 0.15 |
|
28 |
+
| llama-2-7b-chat.ggmlv3.q4_0 | 4 bit | Intel vServer 4 threads, eth services | 8 GB RAM | 1.31 | 0.5|
|
29 |
+
| llama-2-7b-chat.ggmlv3.q8_0 | 8 bit | Intel i7-8700 | 8.6 GB RAM | 2.63 | 336.57 |
|
30 |
+
| llama-2-7b-chat.ggmlv3.q8_0 | 8 bit | Intel i7-9700 | 7.6 GB RAM | 2.05 | 302.9 |
|
31 |
+
| | | | | | |
|
32 |
+
|
docs/pypi.md
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# llama2-wrapper
|
2 |
+
|
3 |
+
- Use [llama2-wrapper](https://pypi.org/project/llama2-wrapper/) as your local llama2 backend for Generative Agents/Apps, [colab example](https://github.com/liltom-eth/llama2-webui/blob/main/colab/Llama_2_7b_Chat_GPTQ.ipynb).
|
4 |
+
|
5 |
+
- [Run OpenAI Compatible API](https://github.com/liltom-eth/llama2-webui#start-openai-compatible-api) on Llama2 models.
|
6 |
+
|
7 |
+
## Features
|
8 |
+
|
9 |
+
- Supporting models: [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)/[13b](https://huggingface.co/llamaste/Llama-2-13b-chat-hf)/[70b](https://huggingface.co/llamaste/Llama-2-70b-chat-hf), [Llama-2-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ), [Llama-2-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML), [CodeLlama](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ)...
|
10 |
+
- Supporting model backends: [tranformers](https://github.com/huggingface/transformers), [bitsandbytes(8-bit inference)](https://github.com/TimDettmers/bitsandbytes), [AutoGPTQ(4-bit inference)](https://github.com/PanQiWei/AutoGPTQ), [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
11 |
+
- Demos: [Run Llama2 on MacBook Air](https://twitter.com/liltom_eth/status/1682791729207070720?s=20); [Run Llama2 on Colab T4 GPU](https://github.com/liltom-eth/llama2-webui/blob/main/colab/Llama_2_7b_Chat_GPTQ.ipynb)
|
12 |
+
- Use [llama2-wrapper](https://pypi.org/project/llama2-wrapper/) as your local llama2 backend for Generative Agents/Apps; [colab example](./colab/Llama_2_7b_Chat_GPTQ.ipynb).
|
13 |
+
- [Run OpenAI Compatible API](https://github.com/liltom-eth/llama2-webui#start-openai-compatible-api) on Llama2 models.
|
14 |
+
- [News](https://github.com/liltom-eth/llama2-webui/blob/main/docs/news.md), [Benchmark](https://github.com/liltom-eth/llama2-webui/blob/main/docs/performance.md), [Issue Solutions](https://github.com/liltom-eth/llama2-webui/blob/main/docs/issues.md)
|
15 |
+
|
16 |
+
[llama2-wrapper](https://pypi.org/project/llama2-wrapper/) is the backend and part of [llama2-webui](https://github.com/liltom-eth/llama2-webui), which can run any Llama 2 locally with gradio UI on GPU or CPU from anywhere (Linux/Windows/Mac).
|
17 |
+
|
18 |
+
## Install
|
19 |
+
|
20 |
+
```bash
|
21 |
+
pip install llama2-wrapper
|
22 |
+
```
|
23 |
+
|
24 |
+
## Start OpenAI Compatible API
|
25 |
+
|
26 |
+
```
|
27 |
+
python -m llama2_wrapper.server
|
28 |
+
```
|
29 |
+
|
30 |
+
it will use `llama.cpp` as the backend by default to run `llama-2-7b-chat.ggmlv3.q4_0.bin` model.
|
31 |
+
|
32 |
+
Start Fast API for `gptq` backend:
|
33 |
+
|
34 |
+
```
|
35 |
+
python -m llama2_wrapper.server --backend_type gptq
|
36 |
+
```
|
37 |
+
|
38 |
+
Navigate to http://localhost:8000/docs to see the OpenAPI documentation.
|
39 |
+
|
40 |
+
## API Usage
|
41 |
+
|
42 |
+
### `__call__`
|
43 |
+
|
44 |
+
`__call__()` is the function to generate text from a prompt.
|
45 |
+
|
46 |
+
For example, run ggml llama2 model on CPU, [colab example](https://github.com/liltom-eth/llama2-webui/blob/main/colab/ggmlv3_q4_0.ipynb):
|
47 |
+
|
48 |
+
```python
|
49 |
+
from llama2_wrapper import LLAMA2_WRAPPER, get_prompt
|
50 |
+
llama2_wrapper = LLAMA2_WRAPPER()
|
51 |
+
# Default running on backend llama.cpp.
|
52 |
+
# Automatically downloading model to: ./models/llama-2-7b-chat.ggmlv3.q4_0.bin
|
53 |
+
prompt = "Do you know Pytorch"
|
54 |
+
# llama2_wrapper() will run __call__()
|
55 |
+
answer = llama2_wrapper(get_prompt(prompt), temperature=0.9)
|
56 |
+
```
|
57 |
+
|
58 |
+
Run gptq llama2 model on Nvidia GPU, [colab example](https://github.com/liltom-eth/llama2-webui/blob/main/colab/Llama_2_7b_Chat_GPTQ.ipynb):
|
59 |
+
|
60 |
+
```python
|
61 |
+
from llama2_wrapper import LLAMA2_WRAPPER
|
62 |
+
llama2_wrapper = LLAMA2_WRAPPER(backend_type="gptq")
|
63 |
+
# Automatically downloading model to: ./models/Llama-2-7b-Chat-GPTQ
|
64 |
+
```
|
65 |
+
|
66 |
+
Run llama2 7b with bitsandbytes 8 bit with a `model_path`:
|
67 |
+
|
68 |
+
```python
|
69 |
+
from llama2_wrapper import LLAMA2_WRAPPER
|
70 |
+
llama2_wrapper = LLAMA2_WRAPPER(
|
71 |
+
model_path = "./models/Llama-2-7b-chat-hf",
|
72 |
+
backend_type = "transformers",
|
73 |
+
load_in_8bit = True
|
74 |
+
)
|
75 |
+
```
|
76 |
+
|
77 |
+
### completion
|
78 |
+
|
79 |
+
`completion()` is the function to generate text from a prompt for OpenAI compatible API `/v1/completions`.
|
80 |
+
|
81 |
+
```python
|
82 |
+
llama2_wrapper = LLAMA2_WRAPPER()
|
83 |
+
prompt = get_prompt("Hi do you know Pytorch?")
|
84 |
+
print(llm.completion(prompt))
|
85 |
+
```
|
86 |
+
|
87 |
+
### chat_completion
|
88 |
+
|
89 |
+
`chat_completion()` is the function to generate text from a dialog (chat history) for OpenAI compatible API `/v1/chat/completions`.
|
90 |
+
|
91 |
+
```python
|
92 |
+
llama2_wrapper = LLAMA2_WRAPPER()
|
93 |
+
dialog = [
|
94 |
+
{
|
95 |
+
"role":"system",
|
96 |
+
"content":"You are a helpful, respectful and honest assistant. "
|
97 |
+
},{
|
98 |
+
"role":"user",
|
99 |
+
"content":"Hi do you know Pytorch?",
|
100 |
+
},
|
101 |
+
]
|
102 |
+
print(llm.chat_completion(dialog))
|
103 |
+
```
|
104 |
+
|
105 |
+
### generate
|
106 |
+
|
107 |
+
`generate()` is the function to create a generator of response from a prompt.
|
108 |
+
|
109 |
+
This is useful when you want to stream the output like typing in the chatbot.
|
110 |
+
|
111 |
+
```python
|
112 |
+
llama2_wrapper = LLAMA2_WRAPPER()
|
113 |
+
prompt = get_prompt("Hi do you know Pytorch?")
|
114 |
+
for response in llama2_wrapper.generate(prompt):
|
115 |
+
print(response)
|
116 |
+
|
117 |
+
```
|
118 |
+
|
119 |
+
The response will be like:
|
120 |
+
|
121 |
+
```
|
122 |
+
Yes,
|
123 |
+
Yes, I'm
|
124 |
+
Yes, I'm familiar
|
125 |
+
Yes, I'm familiar with
|
126 |
+
Yes, I'm familiar with PyTorch!
|
127 |
+
...
|
128 |
+
```
|
129 |
+
|
130 |
+
### run
|
131 |
+
|
132 |
+
`run()` is similar to `generate()`, but `run()`can also accept `chat_history`and `system_prompt` from the users.
|
133 |
+
|
134 |
+
It will process the input message to llama2 prompt template with `chat_history` and `system_prompt` for a chatbot-like app.
|
135 |
+
|
136 |
+
### get_prompt
|
137 |
+
|
138 |
+
`get_prompt()` will process the input message to llama2 prompt with `chat_history` and `system_prompt`for chatbot.
|
139 |
+
|
140 |
+
By default, `chat_history` and `system_prompt` are empty and `get_prompt()` will add llama2 prompt template to your message:
|
141 |
+
|
142 |
+
```python
|
143 |
+
prompt = get_prompt("Hi do you know Pytorch?")
|
144 |
+
```
|
145 |
+
|
146 |
+
prompt will be:
|
147 |
+
|
148 |
+
```
|
149 |
+
[INST] <<SYS>>
|
150 |
+
|
151 |
+
<</SYS>>
|
152 |
+
|
153 |
+
Hi do you know Pytorch? [/INST]
|
154 |
+
```
|
155 |
+
|
156 |
+
If use `get_prompt("Hi do you know Pytorch?", system_prompt="You are a helpful...")`:
|
157 |
+
|
158 |
+
```
|
159 |
+
[INST] <<SYS>>
|
160 |
+
You are a helpful, respectful and honest assistant.
|
161 |
+
<</SYS>>
|
162 |
+
|
163 |
+
Hi do you know Pytorch? [/INST]
|
164 |
+
```
|
165 |
+
|
166 |
+
### get_prompt_for_dialog
|
167 |
+
|
168 |
+
`get_prompt_for_dialog()` will process dialog (chat history) to llama2 prompt for OpenAI compatible API `/v1/chat/completions`.
|
169 |
+
|
170 |
+
```python
|
171 |
+
dialog = [
|
172 |
+
{
|
173 |
+
"role":"system",
|
174 |
+
"content":"You are a helpful, respectful and honest assistant. "
|
175 |
+
},{
|
176 |
+
"role":"user",
|
177 |
+
"content":"Hi do you know Pytorch?",
|
178 |
+
},
|
179 |
+
]
|
180 |
+
prompt = get_prompt_for_dialog("Hi do you know Pytorch?")
|
181 |
+
# [INST] <<SYS>>
|
182 |
+
# You are a helpful, respectful and honest assistant.
|
183 |
+
# <</SYS>>
|
184 |
+
#
|
185 |
+
# Hi do you know Pytorch? [/INST]
|
186 |
+
```
|
187 |
+
|
env_examples/.env.13b_example
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL_PATH = "./models/Llama-2-13b-chat-hf"
|
2 |
+
|
3 |
+
# options: llama.cpp, gptq, transformers
|
4 |
+
BACKEND_TYPE = "transformers"
|
5 |
+
|
6 |
+
# only for transformers bitsandbytes 8 bit
|
7 |
+
LOAD_IN_8BIT = True
|
8 |
+
|
9 |
+
MAX_MAX_NEW_TOKENS = 2048
|
10 |
+
DEFAULT_MAX_NEW_TOKENS = 1024
|
11 |
+
MAX_INPUT_TOKEN_LENGTH = 4000
|
12 |
+
|
13 |
+
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
env_examples/.env.7b_8bit_example
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL_PATH = "./models/Llama-2-7b-chat-hf"
|
2 |
+
|
3 |
+
# options: llama.cpp, gptq, transformers
|
4 |
+
BACKEND_TYPE = "transformers"
|
5 |
+
|
6 |
+
# only for transformers bitsandbytes 8 bit
|
7 |
+
LOAD_IN_8BIT = True
|
8 |
+
|
9 |
+
MAX_MAX_NEW_TOKENS = 2048
|
10 |
+
DEFAULT_MAX_NEW_TOKENS = 1024
|
11 |
+
MAX_INPUT_TOKEN_LENGTH = 4000
|
12 |
+
|
13 |
+
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
env_examples/.env.7b_ggmlv3_q4_0_example
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL_PATH = ""
|
2 |
+
# if MODEL_PATH is "", default llama.cpp/gptq models
|
3 |
+
# will be downloaded to: ./models
|
4 |
+
|
5 |
+
# Example ggml path:
|
6 |
+
# MODEL_PATH = "./models/llama-2-7b-chat.ggmlv3.q4_0.bin"
|
7 |
+
|
8 |
+
# options: llama.cpp, gptq, transformers
|
9 |
+
BACKEND_TYPE = "llama.cpp"
|
10 |
+
|
11 |
+
# only for transformers bitsandbytes 8 bit
|
12 |
+
LOAD_IN_8BIT = False
|
13 |
+
|
14 |
+
MAX_MAX_NEW_TOKENS = 2048
|
15 |
+
DEFAULT_MAX_NEW_TOKENS = 1024
|
16 |
+
MAX_INPUT_TOKEN_LENGTH = 4000
|
17 |
+
|
18 |
+
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
env_examples/.env.7b_gptq_example
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL_PATH = "./models/Llama-2-7b-Chat-GPTQ"
|
2 |
+
# if MODEL_PATH is "", default llama.cpp/gptq models
|
3 |
+
# will be downloaded to: ./models
|
4 |
+
|
5 |
+
# Example gptq path:
|
6 |
+
# MODEL_PATH = "./models/Llama-2-7b-Chat-GPTQ"
|
7 |
+
|
8 |
+
# options: llama.cpp, gptq, transformers
|
9 |
+
BACKEND_TYPE = "gptq"
|
10 |
+
|
11 |
+
# only for transformers bitsandbytes 8 bit
|
12 |
+
LOAD_IN_8BIT = False
|
13 |
+
|
14 |
+
MAX_MAX_NEW_TOKENS = 2048
|
15 |
+
DEFAULT_MAX_NEW_TOKENS = 1024
|
16 |
+
MAX_INPUT_TOKEN_LENGTH = 4000
|
17 |
+
|
18 |
+
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
llama2_cu_python/Makefile
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NVCC = nvcc
|
2 |
+
|
3 |
+
.PHONY: libllama2
|
4 |
+
libllama2: llama2.cu
|
5 |
+
$(NVCC) -DUSE_CUDA --shared -O3 -lcublas -lm -o libllama2.so llama2.cu --compiler-options '-fPIC'
|
6 |
+
|
7 |
+
.PHONY: clean
|
8 |
+
clean:
|
9 |
+
rm -f libllama2.so
|
llama2_cu_python/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .llama2_cu import *
|
2 |
+
|
3 |
+
__version__ = "0.1"
|
llama2_cu_python/libllama2.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:722e0a53e9d9afbd37491c3eff84e7fcd1c1e2331a575c83d283ba4ff62e269f
|
3 |
+
size 1038952
|
llama2_cu_python/llama2.cu
ADDED
@@ -0,0 +1,1394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Inference for Llama-2 Transformer model in pure C
|
2 |
+
* With added CUDA support initially drawing from
|
3 |
+
* https://github.com/ankan-ban/llama2.cu/blob/master/llama2.cu
|
4 |
+
* and structured in a way that hopefully makes keeping it
|
5 |
+
* up-to-date straightforward.
|
6 |
+
*/
|
7 |
+
|
8 |
+
#include <stdio.h>
|
9 |
+
#include <stdlib.h>
|
10 |
+
#include <ctype.h>
|
11 |
+
#include <time.h>
|
12 |
+
#include <math.h>
|
13 |
+
#include <string.h>
|
14 |
+
#include <fcntl.h>
|
15 |
+
#include <assert.h>
|
16 |
+
#include <future>
|
17 |
+
#if defined _WIN32
|
18 |
+
#include "win.h"
|
19 |
+
#else
|
20 |
+
#include <unistd.h>
|
21 |
+
#include <sys/mman.h>
|
22 |
+
#endif
|
23 |
+
#include "llama2.h"
|
24 |
+
|
25 |
+
#ifdef USE_CUDA
|
26 |
+
#include <cuda_runtime.h>
|
27 |
+
#include <cub/cub.cuh>
|
28 |
+
#include <cublas_v2.h>
|
29 |
+
|
30 |
+
// Each CUDA function call should be checked for errors.
|
31 |
+
#define CUCHK(err) cuda_check((err), __FILE__, __LINE__)
|
32 |
+
inline void cuda_check(cudaError_t error_code, const char *file, int line)
|
33 |
+
{
|
34 |
+
if (error_code != cudaSuccess)
|
35 |
+
{
|
36 |
+
fprintf(stderr, "CUDA Error %d: %s. In file '%s' on line %d\n", error_code, cudaGetErrorString(error_code), file, line);
|
37 |
+
fflush(stderr);
|
38 |
+
exit(error_code);
|
39 |
+
}
|
40 |
+
}
|
41 |
+
|
42 |
+
// cublasHandle_t g_cublas_handle = nullptr;
|
43 |
+
|
44 |
+
// void create_cublas_handle() {
|
45 |
+
// cublasStatus_t stat = cublasCreate(&g_cublas_handle); // FIXME cublasDestroy
|
46 |
+
// if (stat != CUBLAS_STATUS_SUCCESS) {
|
47 |
+
// printf ("CUBLAS initialization failed\n");
|
48 |
+
// exit(EXIT_FAILURE);
|
49 |
+
// }
|
50 |
+
// }
|
51 |
+
// void destroy_cublas_handle() {
|
52 |
+
// cublasStatus_t stat = cublasDestroy(g_cublas_handle);
|
53 |
+
// if (stat != CUBLAS_STATUS_SUCCESS) {
|
54 |
+
// printf ("CUBLAS initialization failed\n");
|
55 |
+
// exit(EXIT_FAILURE);
|
56 |
+
// }
|
57 |
+
// }
|
58 |
+
#endif
|
59 |
+
|
60 |
+
// ----------------------------------------------------------------------------
|
61 |
+
// Transformer model
|
62 |
+
|
63 |
+
typedef struct {
|
64 |
+
int dim; // transformer dimension
|
65 |
+
int hidden_dim; // for ffn layers
|
66 |
+
int n_layers; // number of layers
|
67 |
+
int n_heads; // number of query heads
|
68 |
+
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
|
69 |
+
int vocab_size; // vocabulary size, usually 256 (byte-level)
|
70 |
+
int seq_len; // max sequence length
|
71 |
+
} Config;
|
72 |
+
|
73 |
+
// CUDA NOTE: The TransformerWeights structure will be stored on the host,
|
74 |
+
// but all of the pointers in the structure will point to data on the GPU.
|
75 |
+
// The checkpoint file is mmap-ed to the host and the weights portion
|
76 |
+
// is allocated on and copied to the GPU. Then, memory_map_weights() updates
|
77 |
+
// these structure pointers to point to the proper location. Happily, this
|
78 |
+
// function is the same for both C and CUDA.
|
79 |
+
typedef struct {
|
80 |
+
// token embedding table
|
81 |
+
float* token_embedding_table; // (vocab_size, dim)
|
82 |
+
// weights for rmsnorms
|
83 |
+
float* rms_att_weight; // (layer, dim) rmsnorm weights
|
84 |
+
float* rms_ffn_weight; // (layer, dim)
|
85 |
+
// weights for matmuls. note dim == n_heads * head_size
|
86 |
+
float* wq; // (layer, dim, n_heads * head_size)
|
87 |
+
float* wk; // (layer, dim, n_kv_heads * head_size)
|
88 |
+
float* wv; // (layer, dim, n_kv_heads * head_size)
|
89 |
+
float* wo; // (layer, n_heads * head_size, dim)
|
90 |
+
// weights for ffn
|
91 |
+
float* w1; // (layer, hidden_dim, dim)
|
92 |
+
float* w2; // (layer, dim, hidden_dim)
|
93 |
+
float* w3; // (layer, hidden_dim, dim)
|
94 |
+
// final rmsnorm
|
95 |
+
float* rms_final_weight; // (dim,)
|
96 |
+
// (optional) classifier weights for the logits, on the last layer
|
97 |
+
float* wcls;
|
98 |
+
} TransformerWeights;
|
99 |
+
|
100 |
+
// CUDA NOTE: The RunState structure will be stored on the host, but all of the
|
101 |
+
// pointers in the structure will point to data on the GPU, created via
|
102 |
+
// cudaMalloc. The exception is logits which is the final result of the
|
103 |
+
// transformer & is copied from the GPU as the last step in the transformer
|
104 |
+
// and is used by the host.
|
105 |
+
typedef struct {
|
106 |
+
// current wave of activations
|
107 |
+
float *x; // activation at current time stamp (dim,)
|
108 |
+
float *xb; // same, but inside a residual branch (dim,)
|
109 |
+
float *xb2; // an additional buffer just for convenience (dim,)
|
110 |
+
float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
|
111 |
+
float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
|
112 |
+
float *q; // query (dim,)
|
113 |
+
float *k; // key (dim,)
|
114 |
+
float *v; // value (dim,)
|
115 |
+
float *att; // buffer for scores/attention values (n_heads, seq_len)
|
116 |
+
#ifdef USE_CUDA
|
117 |
+
float *logits_gpu; // output logits in GPU
|
118 |
+
#endif
|
119 |
+
float *logits; // output logits in CPU
|
120 |
+
// kv cache
|
121 |
+
float* key_cache; // (layer, seq_len, dim)
|
122 |
+
float* value_cache; // (layer, seq_len, dim)
|
123 |
+
} RunState;
|
124 |
+
|
125 |
+
typedef struct {
|
126 |
+
Config config; // the hyperparameters of the architecture (the blueprint)
|
127 |
+
TransformerWeights weights; // the weights of the model
|
128 |
+
RunState state; // buffers for the "wave" of activations in the forward pass
|
129 |
+
// some more state needed to properly clean up the memory mapping (sigh)
|
130 |
+
int fd; // file descriptor for memory mapping
|
131 |
+
float* data; // memory mapped data pointer
|
132 |
+
ssize_t file_size; // size of the checkpoint file in bytes
|
133 |
+
} Transformer;
|
134 |
+
|
135 |
+
#ifdef USE_CUDA
|
136 |
+
void malloc_run_state(RunState* s, Config* p) {
|
137 |
+
// we calloc instead of malloc to keep valgrind happy
|
138 |
+
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
139 |
+
CUCHK(cudaMalloc((void**)&s->x, p->dim * sizeof(float)));
|
140 |
+
CUCHK(cudaMalloc((void**)&s->xb, p->dim * sizeof(float)));
|
141 |
+
CUCHK(cudaMalloc((void**)&s->xb2, p->dim * sizeof(float)));
|
142 |
+
CUCHK(cudaMalloc((void**)&s->hb, p->hidden_dim * sizeof(float)));
|
143 |
+
CUCHK(cudaMalloc((void**)&s->hb2, p->hidden_dim * sizeof(float)));
|
144 |
+
CUCHK(cudaMalloc((void**)&s->q, p->dim * sizeof(float)));
|
145 |
+
CUCHK(cudaMalloc((void**)&s->key_cache, p->n_layers * p->seq_len * kv_dim * sizeof(float)));
|
146 |
+
CUCHK(cudaMalloc((void**)&s->value_cache, p->n_layers * p->seq_len * kv_dim * sizeof(float)));
|
147 |
+
CUCHK(cudaMalloc((void**)&s->att, p->n_heads * p->seq_len * sizeof(float)));
|
148 |
+
CUCHK(cudaMalloc((void**)&s->logits_gpu, p->vocab_size * sizeof(float)));
|
149 |
+
s->logits = (float *)calloc(p->vocab_size, sizeof(float));
|
150 |
+
// ensure all mallocs went fine
|
151 |
+
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
152 |
+
|| !s->key_cache || !s->value_cache || !s->att || !s->logits_gpu || !s->logits) {
|
153 |
+
fprintf(stderr, "malloc failed!\n");
|
154 |
+
exit(EXIT_FAILURE);
|
155 |
+
}
|
156 |
+
}
|
157 |
+
#else
|
158 |
+
void malloc_run_state(RunState* s, Config* p) {
|
159 |
+
// we calloc instead of malloc to keep valgrind happy
|
160 |
+
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
161 |
+
s->x = (float *)calloc(p->dim, sizeof(float));
|
162 |
+
s->xb = (float *)calloc(p->dim, sizeof(float));
|
163 |
+
s->xb2 = (float *)calloc(p->dim, sizeof(float));
|
164 |
+
s->hb = (float *)calloc(p->hidden_dim, sizeof(float));
|
165 |
+
s->hb2 = (float *)calloc(p->hidden_dim, sizeof(float));
|
166 |
+
s->q = (float *)calloc(p->dim, sizeof(float));
|
167 |
+
s->key_cache = (float *)calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
168 |
+
s->value_cache = (float *)calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
169 |
+
s->att = (float *)calloc(p->n_heads * p->seq_len, sizeof(float));
|
170 |
+
s->logits = (float *)calloc(p->vocab_size, sizeof(float));
|
171 |
+
// ensure all mallocs went fine
|
172 |
+
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
173 |
+
|| !s->key_cache || !s->value_cache || !s->att || !s->logits) {
|
174 |
+
fprintf(stderr, "malloc failed!\n");
|
175 |
+
exit(EXIT_FAILURE);
|
176 |
+
}
|
177 |
+
}
|
178 |
+
#endif
|
179 |
+
|
180 |
+
#ifdef USE_CUDA
|
181 |
+
void free_run_state(RunState* s) {
|
182 |
+
CUCHK(cudaFree(s->x));
|
183 |
+
CUCHK(cudaFree(s->xb));
|
184 |
+
CUCHK(cudaFree(s->xb2));
|
185 |
+
CUCHK(cudaFree(s->hb));
|
186 |
+
CUCHK(cudaFree(s->hb2));
|
187 |
+
CUCHK(cudaFree(s->q));
|
188 |
+
CUCHK(cudaFree(s->att));
|
189 |
+
CUCHK(cudaFree(s->logits_gpu));
|
190 |
+
free(s->logits);
|
191 |
+
CUCHK(cudaFree(s->key_cache));
|
192 |
+
CUCHK(cudaFree(s->value_cache));
|
193 |
+
}
|
194 |
+
#else
|
195 |
+
void free_run_state(RunState* s) {
|
196 |
+
free(s->x);
|
197 |
+
free(s->xb);
|
198 |
+
free(s->xb2);
|
199 |
+
free(s->hb);
|
200 |
+
free(s->hb2);
|
201 |
+
free(s->q);
|
202 |
+
free(s->att);
|
203 |
+
free(s->logits);
|
204 |
+
free(s->key_cache);
|
205 |
+
free(s->value_cache);
|
206 |
+
}
|
207 |
+
#endif
|
208 |
+
|
209 |
+
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
210 |
+
int head_size = p->dim / p->n_heads;
|
211 |
+
// make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
|
212 |
+
unsigned long long n_layers = p->n_layers;
|
213 |
+
w->token_embedding_table = ptr;
|
214 |
+
ptr += p->vocab_size * p->dim;
|
215 |
+
w->rms_att_weight = ptr;
|
216 |
+
ptr += n_layers * p->dim;
|
217 |
+
w->wq = ptr;
|
218 |
+
ptr += n_layers * p->dim * (p->n_heads * head_size);
|
219 |
+
w->wk = ptr;
|
220 |
+
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
|
221 |
+
w->wv = ptr;
|
222 |
+
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
|
223 |
+
w->wo = ptr;
|
224 |
+
ptr += n_layers * (p->n_heads * head_size) * p->dim;
|
225 |
+
w->rms_ffn_weight = ptr;
|
226 |
+
ptr += n_layers * p->dim;
|
227 |
+
w->w1 = ptr;
|
228 |
+
ptr += n_layers * p->dim * p->hidden_dim;
|
229 |
+
w->w2 = ptr;
|
230 |
+
ptr += n_layers * p->hidden_dim * p->dim;
|
231 |
+
w->w3 = ptr;
|
232 |
+
ptr += n_layers * p->dim * p->hidden_dim;
|
233 |
+
w->rms_final_weight = ptr;
|
234 |
+
ptr += p->dim;
|
235 |
+
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
|
236 |
+
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
|
237 |
+
w->wcls = shared_weights ? w->token_embedding_table : ptr;
|
238 |
+
}
|
239 |
+
|
240 |
+
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
|
241 |
+
int* fd, float** data, ssize_t* file_size) {
|
242 |
+
FILE *file = fopen(checkpoint, "rb");
|
243 |
+
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
|
244 |
+
// read in the config header
|
245 |
+
if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
|
246 |
+
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
247 |
+
int shared_weights = config->vocab_size > 0 ? 1 : 0;
|
248 |
+
config->vocab_size = abs(config->vocab_size);
|
249 |
+
// figure out the file size
|
250 |
+
fseek(file, 0, SEEK_END); // move file pointer to end of file
|
251 |
+
*file_size = ftell(file); // get the file size, in bytes
|
252 |
+
fclose(file);
|
253 |
+
// memory map the Transformer weights into the data pointer
|
254 |
+
*fd = open(checkpoint, O_RDONLY); // open in read only mode
|
255 |
+
if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
|
256 |
+
*data = (float *)mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
|
257 |
+
if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
|
258 |
+
#ifdef USE_CUDA
|
259 |
+
// allocate & copy mmap data to the gpu first
|
260 |
+
// TODO: allocate & copy just a portion to the GPU if the weights are too big
|
261 |
+
// to fit in the GPU, then copy the data only as needed while running.
|
262 |
+
float* weights_ptr;
|
263 |
+
size_t weights_size = *file_size - sizeof(Config);
|
264 |
+
CUCHK(cudaMalloc((void**)&weights_ptr, weights_size));
|
265 |
+
CUCHK(cudaMemcpy(weights_ptr, *data + sizeof(Config)/sizeof(float), weights_size, cudaMemcpyHostToDevice));
|
266 |
+
#else
|
267 |
+
float* weights_ptr = *data + sizeof(Config)/sizeof(float);
|
268 |
+
#endif
|
269 |
+
memory_map_weights(weights, config, weights_ptr, shared_weights);
|
270 |
+
}
|
271 |
+
|
272 |
+
void build_transformer(Transformer *t, char* checkpoint_path) {
|
273 |
+
// read in the Config and the Weights from the checkpoint
|
274 |
+
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
|
275 |
+
// allocate the RunState buffers
|
276 |
+
malloc_run_state(&t->state, &t->config);
|
277 |
+
}
|
278 |
+
|
279 |
+
void free_transformer(Transformer* t) {
|
280 |
+
// close the memory mapping
|
281 |
+
if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
|
282 |
+
if (t->fd != -1) { close(t->fd); }
|
283 |
+
#ifdef USE_CUDA
|
284 |
+
// we cudaMalloc a region of memory, then hand the address to
|
285 |
+
// the token_embedding_table field. Free it here.
|
286 |
+
CUCHK(cudaFree(t->weights.token_embedding_table));
|
287 |
+
#endif
|
288 |
+
// free the RunState buffers
|
289 |
+
free_run_state(&t->state);
|
290 |
+
}
|
291 |
+
|
292 |
+
// ----------------------------------------------------------------------------
|
293 |
+
// neural net blocks; the dynamics of the Transformer
|
294 |
+
|
295 |
+
#ifdef USE_CUDA
|
296 |
+
// Utility routine to divide a into ceiling of b parts
|
297 |
+
int divUp(int a, int b) {
|
298 |
+
return (a - 1) / b + 1;
|
299 |
+
}
|
300 |
+
|
301 |
+
const int num_threads_lrg = 1024;
|
302 |
+
const int num_threads_med = 256;
|
303 |
+
|
304 |
+
__global__ void rmsnorm_kernel(float* o, float* x, float* weight, int size, int elementsPerThread) {
|
305 |
+
// parallel reduction of sum of squares via CUB
|
306 |
+
float ss = 0.0f;
|
307 |
+
for (int i = 0; i < elementsPerThread; i++) {
|
308 |
+
int j = threadIdx.x + i * num_threads_lrg;
|
309 |
+
if (j < size)
|
310 |
+
ss += x[j] * x[j];
|
311 |
+
}
|
312 |
+
using BlockReduce = cub::BlockReduce<float, num_threads_lrg>;
|
313 |
+
__shared__ typename BlockReduce::TempStorage temp;
|
314 |
+
ss = BlockReduce(temp).Sum(ss);
|
315 |
+
|
316 |
+
// serialization point to calculate normalization factor
|
317 |
+
__shared__ float shared_ss;
|
318 |
+
if (threadIdx.x == 0) {
|
319 |
+
ss /= size;
|
320 |
+
ss += 1e-5f;
|
321 |
+
ss = 1.0f / sqrtf(ss);
|
322 |
+
shared_ss = ss;
|
323 |
+
}
|
324 |
+
__syncthreads();
|
325 |
+
ss = shared_ss;
|
326 |
+
|
327 |
+
// normalize and scale
|
328 |
+
for (int i = 0; i < elementsPerThread; i++) {
|
329 |
+
int j = threadIdx.x + i * num_threads_lrg;
|
330 |
+
if (j < size) {
|
331 |
+
o[j] = weight[j] * (ss * x[j]);
|
332 |
+
}
|
333 |
+
}
|
334 |
+
}
|
335 |
+
void rmsnorm(float* o, float* x, float* weight, int size) {
|
336 |
+
int elementsPerThread = divUp(size, num_threads_lrg);
|
337 |
+
rmsnorm_kernel <<<1, num_threads_lrg >>> (o, x, weight, size, elementsPerThread);
|
338 |
+
}
|
339 |
+
#else
|
340 |
+
void rmsnorm(float* o, float* x, float* weight, int size) {
|
341 |
+
// calculate sum of squares
|
342 |
+
float ss = 0.0f;
|
343 |
+
for (int j = 0; j < size; j++) {
|
344 |
+
ss += x[j] * x[j];
|
345 |
+
}
|
346 |
+
ss /= size;
|
347 |
+
ss += 1e-5f;
|
348 |
+
ss = 1.0f / sqrtf(ss);
|
349 |
+
// normalize and scale
|
350 |
+
for (int j = 0; j < size; j++) {
|
351 |
+
o[j] = weight[j] * (ss * x[j]);
|
352 |
+
}
|
353 |
+
}
|
354 |
+
#endif
|
355 |
+
|
356 |
+
#ifdef USE_CUDA
|
357 |
+
__device__ void softmax_gpu(float* __restrict__ x, int size) {
|
358 |
+
int tid = threadIdx.x;
|
359 |
+
int step = blockDim.x;
|
360 |
+
|
361 |
+
// find max value (for numerical stability)
|
362 |
+
float max_val = tid < size ? x[tid] : 0;
|
363 |
+
for (int i = tid + step; i < size; i += step) {
|
364 |
+
if (x[i] > max_val) {
|
365 |
+
max_val = x[i];
|
366 |
+
}
|
367 |
+
}
|
368 |
+
using BlockReduce = cub::BlockReduce<float, num_threads_lrg>;
|
369 |
+
__shared__ typename BlockReduce::TempStorage temp;
|
370 |
+
__shared__ float shared_val;
|
371 |
+
max_val = BlockReduce(temp).Reduce(max_val, cub::Max());
|
372 |
+
if (threadIdx.x == 0) {
|
373 |
+
shared_val = max_val;
|
374 |
+
}
|
375 |
+
__syncthreads();
|
376 |
+
max_val = shared_val;
|
377 |
+
|
378 |
+
// exp and sum
|
379 |
+
float sum = 0.0f;
|
380 |
+
for (int i = tid; i < size; i += step) {
|
381 |
+
x[i] = expf(x[i] - max_val);
|
382 |
+
sum += x[i];
|
383 |
+
}
|
384 |
+
sum = BlockReduce(temp).Sum(sum);
|
385 |
+
if (threadIdx.x == 0) {
|
386 |
+
shared_val = sum;
|
387 |
+
}
|
388 |
+
__syncthreads();
|
389 |
+
sum = shared_val;
|
390 |
+
|
391 |
+
// normalize
|
392 |
+
for (int i = tid; i < size; i += step) {
|
393 |
+
x[i] /= sum;
|
394 |
+
}
|
395 |
+
}
|
396 |
+
#endif
|
397 |
+
void softmax(float* x, int size) {
|
398 |
+
// find max value (for numerical stability)
|
399 |
+
float max_val = x[0];
|
400 |
+
for (int i = 1; i < size; i++) {
|
401 |
+
if (x[i] > max_val) {
|
402 |
+
max_val = x[i];
|
403 |
+
}
|
404 |
+
}
|
405 |
+
// exp and sum
|
406 |
+
float sum = 0.0f;
|
407 |
+
for (int i = 0; i < size; i++) {
|
408 |
+
x[i] = expf(x[i] - max_val);
|
409 |
+
sum += x[i];
|
410 |
+
}
|
411 |
+
// normalize
|
412 |
+
for (int i = 0; i < size; i++) {
|
413 |
+
x[i] /= sum;
|
414 |
+
}
|
415 |
+
}
|
416 |
+
|
417 |
+
#ifdef USE_CUDA
|
418 |
+
// Use cuBLAS for matmul to leverage this included, high-performance library.
|
419 |
+
void matmul(cublasHandle_t handle, float* xout, float* x, float* w, int n, int d) {
|
420 |
+
// W (d,n) @ x (n,) -> xout (d,)
|
421 |
+
// W is stored in this order: (n=0,d=0), (n=1,d=0), (n=2,d=0), ...
|
422 |
+
// so W is n x d in cublas terms & we'll need to transpose.
|
423 |
+
// Sgemv does y = alpha * op(A) * x + beta * y (modifying y)
|
424 |
+
// where op can transpose the matrix A
|
425 |
+
// Translating to our local vars, that is
|
426 |
+
// xout = 1.0*op(w)*x + 0.0*xout
|
427 |
+
float alpha = 1.0f;
|
428 |
+
float beta = 0.0f; // when this is 0, xout will not be used for input
|
429 |
+
cublasSgemv(handle, CUBLAS_OP_T, n, d, &alpha, w, n, x, 1, &beta, xout, 1);
|
430 |
+
}
|
431 |
+
#else
|
432 |
+
void matmul(float* xout, float* x, float* w, int n, int d) {
|
433 |
+
// W (d,n) @ x (n,) -> xout (d,)
|
434 |
+
// by far the most amount of time is spent inside this little function
|
435 |
+
int i;
|
436 |
+
#pragma omp parallel for private(i)
|
437 |
+
for (i = 0; i < d; i++) {
|
438 |
+
float val = 0.0f;
|
439 |
+
for (int j = 0; j < n; j++) {
|
440 |
+
val += w[i * n + j] * x[j];
|
441 |
+
}
|
442 |
+
xout[i] = val;
|
443 |
+
}
|
444 |
+
}
|
445 |
+
#endif
|
446 |
+
|
447 |
+
// Additional neural net blocks (brought out from transformer function)
|
448 |
+
#ifdef USE_CUDA
|
449 |
+
__global__ void RoPe_rotation_kernel(int pos, float *sq, float *sk, int kv_dim, int head_size) {
|
450 |
+
int i = threadIdx.x * 2 + blockIdx.x * head_size;
|
451 |
+
int head_dim = i % head_size;
|
452 |
+
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
|
453 |
+
float val = pos * freq;
|
454 |
+
float fcr = cosf(val);
|
455 |
+
float fci = sinf(val);
|
456 |
+
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
|
457 |
+
for (int v = 0; v < rotn; v++) {
|
458 |
+
float* vec = v == 0 ? sq : sk; // the vector to rotate (query or key)
|
459 |
+
float v0 = vec[i];
|
460 |
+
float v1 = vec[i+1];
|
461 |
+
vec[i] = v0 * fcr - v1 * fci;
|
462 |
+
vec[i+1] = v0 * fci + v1 * fcr;
|
463 |
+
}
|
464 |
+
}
|
465 |
+
void RoPe_rotation(int pos, RunState* s, int dim, int kv_dim, int head_size) {
|
466 |
+
RoPe_rotation_kernel <<<dim/head_size, head_size/2 >>> (pos, s->q, s->k, kv_dim, head_size);
|
467 |
+
}
|
468 |
+
#else
|
469 |
+
void RoPe_rotation(int pos, RunState* s, int dim, int kv_dim, int head_size) { //s->q, s->k, freq_cis_real_row, freq_cis_imag_row, p->n_heads, head_size) {
|
470 |
+
for (int i = 0; i < dim; i+=2) {
|
471 |
+
int head_dim = i % head_size;
|
472 |
+
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
|
473 |
+
float val = pos * freq;
|
474 |
+
float fcr = cosf(val);
|
475 |
+
float fci = sinf(val);
|
476 |
+
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
|
477 |
+
for (int v = 0; v < rotn; v++) {
|
478 |
+
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
|
479 |
+
float v0 = vec[i];
|
480 |
+
float v1 = vec[i+1];
|
481 |
+
vec[i] = v0 * fcr - v1 * fci;
|
482 |
+
vec[i+1] = v0 * fci + v1 * fcr;
|
483 |
+
}
|
484 |
+
}
|
485 |
+
}
|
486 |
+
#endif
|
487 |
+
|
488 |
+
#ifdef USE_CUDA
|
489 |
+
// TODO refactor vs C code
|
490 |
+
__global__ void multi_head_attention_kernel(int pos, int seq_len, float *sq, float *satt, float *sxb, float *key_cache, float *value_cache, int kv_dim, int kv_mul, int head_size, int loff) {
|
491 |
+
int h = blockIdx.x;
|
492 |
+
// get the query vector for this head
|
493 |
+
float* q = sq + h * head_size;
|
494 |
+
// attention scores for this head
|
495 |
+
float* att = satt + h * seq_len;
|
496 |
+
// iterate over all timesteps, including the current one
|
497 |
+
// In CUDA, each thread does a small portion of the calc
|
498 |
+
for (int t = threadIdx.x; t <= pos; t += blockDim.x) {
|
499 |
+
// get the key vector for this head and at this timestep
|
500 |
+
float* k = key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
501 |
+
// calculate the attention score as the dot product of q and k
|
502 |
+
float score = 0.0f;
|
503 |
+
for (int i = 0; i < head_size; i++) {
|
504 |
+
score += q[i] * k[i];
|
505 |
+
}
|
506 |
+
score /= sqrtf(head_size);
|
507 |
+
// save the score to the attention buffer
|
508 |
+
att[t] = score;
|
509 |
+
}
|
510 |
+
// above was this threads portion of the iteration. wait for all threads to finish
|
511 |
+
__syncthreads();
|
512 |
+
|
513 |
+
// softmax the scores to get attention weights, from 0..pos inclusively
|
514 |
+
softmax_gpu(att, pos + 1);
|
515 |
+
__syncthreads();
|
516 |
+
|
517 |
+
// weighted sum of the values, store back into xb
|
518 |
+
// NOTE: by swapping the order of the for loops (vs. C) a simpler
|
519 |
+
// version of the code accomplishes the same task and fits more
|
520 |
+
// naturally with the CUDA way of subdividing the problem.
|
521 |
+
float* xb = sxb + h * head_size;
|
522 |
+
for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
|
523 |
+
float val = 0.0f;
|
524 |
+
for (int t = 0; t <= pos; t++) {
|
525 |
+
// get the value vector for this head and at this timestep
|
526 |
+
float* v = value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
527 |
+
// get the attention weight for this timestep
|
528 |
+
float a = att[t];
|
529 |
+
val += a * v[i];
|
530 |
+
}
|
531 |
+
xb[i] = val;
|
532 |
+
}
|
533 |
+
}
|
534 |
+
void multi_head_attention(int pos, Config* p, RunState* s, int kv_dim, int kv_mul, int head_size, int loff) {
|
535 |
+
multi_head_attention_kernel <<<p->n_heads, num_threads_lrg>>> (pos, p->seq_len, s->q, s->att, s->xb, s->key_cache, s->value_cache, kv_dim, kv_mul, head_size, loff);
|
536 |
+
}
|
537 |
+
#else
|
538 |
+
void multi_head_attention(int pos, Config* p, RunState* s, int kv_dim, int kv_mul, int head_size, int loff) {
|
539 |
+
int h;
|
540 |
+
#pragma omp parallel for private(h)
|
541 |
+
for (h = 0; h < p->n_heads; h++) {
|
542 |
+
// get the query vector for this head
|
543 |
+
float* q = s->q + h * head_size;
|
544 |
+
// attention scores for this head
|
545 |
+
float* att = s->att + h * p->seq_len;
|
546 |
+
// iterate over all timesteps, including the current one
|
547 |
+
for (int t = 0; t <= pos; t++) {
|
548 |
+
// get the key vector for this head and at this timestep
|
549 |
+
float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
550 |
+
// calculate the attention score as the dot product of q and k
|
551 |
+
float score = 0.0f;
|
552 |
+
for (int i = 0; i < head_size; i++) {
|
553 |
+
score += q[i] * k[i];
|
554 |
+
}
|
555 |
+
score /= sqrtf(head_size);
|
556 |
+
// save the score to the attention buffer
|
557 |
+
att[t] = score;
|
558 |
+
}
|
559 |
+
|
560 |
+
// softmax the scores to get attention weights, from 0..pos inclusively
|
561 |
+
softmax(att, pos + 1);
|
562 |
+
|
563 |
+
// weighted sum of the values, store back into xb
|
564 |
+
float* xb = s->xb + h * head_size;
|
565 |
+
memset(xb, 0, head_size * sizeof(float));
|
566 |
+
for (int t = 0; t <= pos; t++) {
|
567 |
+
// get the value vector for this head and at this timestep
|
568 |
+
float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
569 |
+
// get the attention weight for this timestep
|
570 |
+
float a = att[t];
|
571 |
+
// accumulate the weighted value into xb
|
572 |
+
for (int i = 0; i < head_size; i++) {
|
573 |
+
xb[i] += a * v[i];
|
574 |
+
}
|
575 |
+
}
|
576 |
+
}
|
577 |
+
}
|
578 |
+
#endif
|
579 |
+
|
580 |
+
#ifdef USE_CUDA
|
581 |
+
__global__ void f_silu_elementwise_mul_w3_kernel(float *shb, float *shb2, int hidden_dim) {
|
582 |
+
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
583 |
+
if (i < hidden_dim) {
|
584 |
+
float val = shb[i];
|
585 |
+
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
|
586 |
+
val *= (1.0f / (1.0f + expf(-val)));
|
587 |
+
// elementwise multiply with w3(x)
|
588 |
+
val *= shb2[i];
|
589 |
+
shb[i] = val;
|
590 |
+
}
|
591 |
+
}
|
592 |
+
void f_silu_elementwise_mul_w3(RunState *s, int hidden_dim) {
|
593 |
+
f_silu_elementwise_mul_w3_kernel<<<divUp(hidden_dim, num_threads_med), num_threads_med>>>(s->hb, s->hb2, hidden_dim);
|
594 |
+
}
|
595 |
+
#else
|
596 |
+
void f_silu_elementwise_mul_w3(RunState *s, int hidden_dim) {
|
597 |
+
for (int i = 0; i < hidden_dim; i++) {
|
598 |
+
float val = s->hb[i];
|
599 |
+
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
|
600 |
+
val *= (1.0f / (1.0f + expf(-val)));
|
601 |
+
// elementwise multiply with w3(x)
|
602 |
+
val *= s->hb2[i];
|
603 |
+
s->hb[i] = val;
|
604 |
+
}
|
605 |
+
}
|
606 |
+
#endif
|
607 |
+
|
608 |
+
#ifdef USE_CUDA
|
609 |
+
__global__ void accum_kernel(float* a, float* b, int size) {
|
610 |
+
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
611 |
+
if (i < size) {
|
612 |
+
a[i] += b[i];
|
613 |
+
}
|
614 |
+
}
|
615 |
+
void accum(float *a, float *b, int size) {
|
616 |
+
accum_kernel<<<divUp(size, num_threads_med), num_threads_med>>>(a,b,size);
|
617 |
+
}
|
618 |
+
#else
|
619 |
+
void accum(float *a, float *b, int size) {
|
620 |
+
for (int i = 0; i < size; i++) {
|
621 |
+
a[i] += b[i];
|
622 |
+
}
|
623 |
+
}
|
624 |
+
#endif
|
625 |
+
|
626 |
+
#ifdef USE_CUDA
|
627 |
+
float* forward(Transformer* transformer, int token, int pos, cublasHandle_t handle) {
|
628 |
+
#else
|
629 |
+
float* forward(Transformer* transformer, int token, int pos) {
|
630 |
+
#endif
|
631 |
+
// a few convenience variables
|
632 |
+
Config* p = &transformer->config;
|
633 |
+
TransformerWeights* w = &transformer->weights;
|
634 |
+
RunState* s = &transformer->state;
|
635 |
+
float *x = s->x;
|
636 |
+
int dim = p->dim;
|
637 |
+
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
638 |
+
int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
|
639 |
+
int hidden_dim = p->hidden_dim;
|
640 |
+
int head_size = dim / p->n_heads;
|
641 |
+
|
642 |
+
// copy the token embedding into x
|
643 |
+
float* content_row = w->token_embedding_table + token * dim;
|
644 |
+
#ifdef USE_CUDA
|
645 |
+
CUCHK(cudaMemcpy(x, content_row, dim*sizeof(*x), cudaMemcpyDeviceToDevice));
|
646 |
+
#else
|
647 |
+
memcpy(x, content_row, dim*sizeof(*x));
|
648 |
+
#endif
|
649 |
+
|
650 |
+
// forward all the layers
|
651 |
+
for(unsigned long long l = 0; l < p->n_layers; l++) {
|
652 |
+
|
653 |
+
// attention rmsnorm
|
654 |
+
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
|
655 |
+
|
656 |
+
// key and value point to the kv cache
|
657 |
+
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
|
658 |
+
s->k = s->key_cache + loff + pos * kv_dim;
|
659 |
+
s->v = s->value_cache + loff + pos * kv_dim;
|
660 |
+
|
661 |
+
// qkv matmuls for this position
|
662 |
+
#ifdef USE_CUDA
|
663 |
+
matmul(handle, s->q, s->xb, w->wq + l*dim*dim, dim, dim);
|
664 |
+
matmul(handle, s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
665 |
+
matmul(handle, s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
666 |
+
#else
|
667 |
+
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
|
668 |
+
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
669 |
+
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
670 |
+
#endif
|
671 |
+
// RoPE relative positional encoding: complex-valued rotate q and k in each head
|
672 |
+
RoPe_rotation(pos, s, dim, kv_dim, head_size);
|
673 |
+
|
674 |
+
// multihead attention. iterate over all heads
|
675 |
+
multi_head_attention(pos, p, s, kv_dim, kv_mul, head_size, loff);
|
676 |
+
|
677 |
+
// final matmul to get the output of the attention
|
678 |
+
#ifdef USE_CUDA
|
679 |
+
matmul(handle, s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
|
680 |
+
#else
|
681 |
+
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
|
682 |
+
#endif
|
683 |
+
|
684 |
+
// residual connection back into x
|
685 |
+
accum(x, s->xb2, dim);
|
686 |
+
|
687 |
+
// ffn rmsnorm
|
688 |
+
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
|
689 |
+
|
690 |
+
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
|
691 |
+
// first calculate self.w1(x) and self.w3(x)
|
692 |
+
#ifdef USE_CUDA
|
693 |
+
matmul(handle, s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
|
694 |
+
matmul(handle, s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
|
695 |
+
#else
|
696 |
+
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
|
697 |
+
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
|
698 |
+
#endif
|
699 |
+
|
700 |
+
// SwiGLU non-linearity
|
701 |
+
f_silu_elementwise_mul_w3(s, hidden_dim);
|
702 |
+
|
703 |
+
// final matmul to get the output of the ffn
|
704 |
+
#ifdef USE_CUDA
|
705 |
+
matmul(handle, s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
|
706 |
+
#else
|
707 |
+
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
|
708 |
+
#endif
|
709 |
+
|
710 |
+
// residual connection
|
711 |
+
accum(x, s->xb, dim);
|
712 |
+
}
|
713 |
+
|
714 |
+
// final rmsnorm
|
715 |
+
rmsnorm(x, x, w->rms_final_weight, dim);
|
716 |
+
|
717 |
+
// classifier into logits
|
718 |
+
#ifdef USE_CUDA
|
719 |
+
matmul(handle, s->logits_gpu, x, w->wcls, p->dim, p->vocab_size);
|
720 |
+
CUCHK(cudaMemcpy(s->logits, s->logits_gpu, p->vocab_size * sizeof(float), cudaMemcpyDeviceToHost));
|
721 |
+
#else
|
722 |
+
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
723 |
+
#endif
|
724 |
+
return s->logits;
|
725 |
+
}
|
726 |
+
|
727 |
+
// ----------------------------------------------------------------------------
|
728 |
+
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
|
729 |
+
|
730 |
+
typedef struct {
|
731 |
+
char *str;
|
732 |
+
int id;
|
733 |
+
} TokenIndex;
|
734 |
+
|
735 |
+
typedef struct {
|
736 |
+
char** vocab;
|
737 |
+
float* vocab_scores;
|
738 |
+
TokenIndex *sorted_vocab;
|
739 |
+
int vocab_size;
|
740 |
+
unsigned int max_token_length;
|
741 |
+
unsigned char byte_pieces[512]; // stores all single-byte strings
|
742 |
+
} Tokenizer;
|
743 |
+
|
744 |
+
int compare_tokens(const void *a, const void *b) {
|
745 |
+
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
|
746 |
+
}
|
747 |
+
|
748 |
+
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
749 |
+
// i should have written the vocab_size into the tokenizer file... sigh
|
750 |
+
t->vocab_size = vocab_size;
|
751 |
+
// malloc space to hold the scores and the strings
|
752 |
+
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
753 |
+
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
754 |
+
t->sorted_vocab = NULL; // initialized lazily
|
755 |
+
for (int i = 0; i < 256; i++) {
|
756 |
+
t->byte_pieces[i * 2] = (unsigned char)i;
|
757 |
+
t->byte_pieces[i * 2 + 1] = '\0';
|
758 |
+
}
|
759 |
+
// read in the file
|
760 |
+
FILE *file = fopen(tokenizer_path, "rb");
|
761 |
+
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
|
762 |
+
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
763 |
+
int len;
|
764 |
+
for (int i = 0; i < vocab_size; i++) {
|
765 |
+
if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
|
766 |
+
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
767 |
+
t->vocab[i] = (char *)malloc(len + 1);
|
768 |
+
if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
769 |
+
t->vocab[i][len] = '\0'; // add the string terminating token
|
770 |
+
}
|
771 |
+
fclose(file);
|
772 |
+
}
|
773 |
+
|
774 |
+
void free_tokenizer(Tokenizer* t) {
|
775 |
+
for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
|
776 |
+
free(t->vocab);
|
777 |
+
free(t->vocab_scores);
|
778 |
+
free(t->sorted_vocab);
|
779 |
+
}
|
780 |
+
|
781 |
+
char* decode(Tokenizer* t, int prev_token, int token) {
|
782 |
+
char *piece = t->vocab[token];
|
783 |
+
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
784 |
+
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
785 |
+
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
786 |
+
// parse this and convert and return the actual byte
|
787 |
+
unsigned char byte_val;
|
788 |
+
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
|
789 |
+
piece = (char*)t->byte_pieces + byte_val * 2;
|
790 |
+
}
|
791 |
+
return piece;
|
792 |
+
}
|
793 |
+
|
794 |
+
void safe_printf(char *piece) {
|
795 |
+
// piece might be a raw byte token, and we only want to print printable chars or whitespace
|
796 |
+
// because some of the other bytes can be various control codes, backspace, etc.
|
797 |
+
if (piece == NULL) { return; }
|
798 |
+
if (piece[0] == '\0') { return; }
|
799 |
+
if (piece[1] == '\0') {
|
800 |
+
unsigned char byte_val = piece[0];
|
801 |
+
if (!(isprint(byte_val) || isspace(byte_val))) {
|
802 |
+
return; // bad byte, don't print it
|
803 |
+
}
|
804 |
+
}
|
805 |
+
printf("%s", piece);
|
806 |
+
}
|
807 |
+
|
808 |
+
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
809 |
+
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
810 |
+
#if defined USE_CUDA && defined _WIN32
|
811 |
+
// CUDA on Windows was not capable of handling the syntax below
|
812 |
+
TokenIndex tok;
|
813 |
+
tok.str = str;
|
814 |
+
#else
|
815 |
+
TokenIndex tok = { .str = str }; // acts as the key to search for
|
816 |
+
#endif
|
817 |
+
TokenIndex *res = (TokenIndex *)bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
818 |
+
return res != NULL ? res->id : -1;
|
819 |
+
}
|
820 |
+
|
821 |
+
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
|
822 |
+
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
823 |
+
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
|
824 |
+
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
|
825 |
+
|
826 |
+
if (t->sorted_vocab == NULL) {
|
827 |
+
// lazily malloc and sort the vocabulary
|
828 |
+
t->sorted_vocab = (TokenIndex *)malloc(t->vocab_size * sizeof(TokenIndex));
|
829 |
+
for (int i = 0; i < t->vocab_size; i++) {
|
830 |
+
t->sorted_vocab[i].str = t->vocab[i];
|
831 |
+
t->sorted_vocab[i].id = i;
|
832 |
+
}
|
833 |
+
qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
|
834 |
+
}
|
835 |
+
|
836 |
+
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
837 |
+
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
|
838 |
+
char* str_buffer = (char *)malloc((t->max_token_length*2 +1 +2) * sizeof(char));
|
839 |
+
size_t str_len = 0;
|
840 |
+
|
841 |
+
// start at 0 tokens
|
842 |
+
*n_tokens = 0;
|
843 |
+
|
844 |
+
// add optional BOS (=1) token, if desired
|
845 |
+
if (bos) tokens[(*n_tokens)++] = 1;
|
846 |
+
|
847 |
+
// add_dummy_prefix is true by default
|
848 |
+
// so prepend a dummy prefix token to the input string, but only if text != ""
|
849 |
+
// TODO: pretty sure this isn't correct in the general case but I don't have the
|
850 |
+
// energy to read more of the sentencepiece code to figure out what it's doing
|
851 |
+
if (text[0] != '\0') {
|
852 |
+
int dummy_prefix = str_lookup((char *)" ", t->sorted_vocab, t->vocab_size);
|
853 |
+
tokens[(*n_tokens)++] = dummy_prefix;
|
854 |
+
}
|
855 |
+
|
856 |
+
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
857 |
+
// Code point ↔ UTF-8 conversion
|
858 |
+
// First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
|
859 |
+
// U+0000 U+007F 0xxxxxxx
|
860 |
+
// U+0080 U+07FF 110xxxxx 10xxxxxx
|
861 |
+
// U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
|
862 |
+
// U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
863 |
+
|
864 |
+
// process the raw (UTF-8) byte sequence of the input string
|
865 |
+
for (char *c = text; *c != '\0'; c++) {
|
866 |
+
|
867 |
+
// reset buffer if the current byte is ASCII or a leading byte
|
868 |
+
// 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
|
869 |
+
// 0x80 is 10000000
|
870 |
+
// in UTF-8, all continuation bytes start with "10" in first two bits
|
871 |
+
// so in English this is: "if this byte is not a continuation byte"
|
872 |
+
if ((*c & 0xC0) != 0x80) {
|
873 |
+
// this byte must be either a leading byte (11...) or an ASCII char (0x...)
|
874 |
+
// => reset our location, as we're starting a new UTF-8 codepoint
|
875 |
+
str_len = 0;
|
876 |
+
}
|
877 |
+
|
878 |
+
// append the current byte to the buffer
|
879 |
+
str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
|
880 |
+
str_buffer[str_len] = '\0';
|
881 |
+
|
882 |
+
// while the next character is a continuation byte, continue appending
|
883 |
+
// but if there are too many of them, just stop to avoid overruning str_buffer size.
|
884 |
+
if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
|
885 |
+
continue;
|
886 |
+
}
|
887 |
+
|
888 |
+
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
889 |
+
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
890 |
+
|
891 |
+
if (id != -1) {
|
892 |
+
// we found this codepoint in vocab, add it as a token
|
893 |
+
tokens[(*n_tokens)++] = id;
|
894 |
+
} else {
|
895 |
+
// byte_fallback encoding: just encode each byte as a token
|
896 |
+
// +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
|
897 |
+
// so the individual bytes only start at index 3
|
898 |
+
for (int i=0; i < str_len; i++) {
|
899 |
+
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
|
900 |
+
}
|
901 |
+
}
|
902 |
+
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
|
903 |
+
}
|
904 |
+
|
905 |
+
// merge the best consecutive pair each iteration, according the scores in vocab_scores
|
906 |
+
while (1) {
|
907 |
+
float best_score = -1e10;
|
908 |
+
int best_id = -1;
|
909 |
+
int best_idx = -1;
|
910 |
+
|
911 |
+
for (int i=0; i < (*n_tokens-1); i++) {
|
912 |
+
// check if we can merge the pair (tokens[i], tokens[i+1])
|
913 |
+
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
|
914 |
+
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
915 |
+
if (id != -1 && t->vocab_scores[id] > best_score) {
|
916 |
+
// this merge pair exists in vocab! record its score and position
|
917 |
+
best_score = t->vocab_scores[id];
|
918 |
+
best_id = id;
|
919 |
+
best_idx = i;
|
920 |
+
}
|
921 |
+
}
|
922 |
+
|
923 |
+
if (best_idx == -1) {
|
924 |
+
break; // we couldn't find any more pairs to merge, so we're done
|
925 |
+
}
|
926 |
+
|
927 |
+
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
|
928 |
+
tokens[best_idx] = best_id;
|
929 |
+
// delete token at position best_idx+1, shift the entire sequence back 1
|
930 |
+
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
|
931 |
+
tokens[i] = tokens[i+1];
|
932 |
+
}
|
933 |
+
(*n_tokens)--; // token length decreased
|
934 |
+
}
|
935 |
+
|
936 |
+
// add optional EOS (=2) token, if desired
|
937 |
+
if (eos) tokens[(*n_tokens)++] = 2;
|
938 |
+
|
939 |
+
free(str_buffer);
|
940 |
+
}
|
941 |
+
|
942 |
+
// ----------------------------------------------------------------------------
|
943 |
+
// The Sampler, which takes logits and returns a sampled token
|
944 |
+
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
945 |
+
|
946 |
+
typedef struct {
|
947 |
+
float prob;
|
948 |
+
int index;
|
949 |
+
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
950 |
+
|
951 |
+
typedef struct {
|
952 |
+
int vocab_size;
|
953 |
+
ProbIndex* probindex; // buffer used in top-p sampling
|
954 |
+
float temperature;
|
955 |
+
float topp;
|
956 |
+
unsigned long long rng_state;
|
957 |
+
} Sampler;
|
958 |
+
|
959 |
+
int sample_argmax(float* probabilities, int n) {
|
960 |
+
// return the index that has the highest probability
|
961 |
+
int max_i = 0;
|
962 |
+
float max_p = probabilities[0];
|
963 |
+
for (int i = 1; i < n; i++) {
|
964 |
+
if (probabilities[i] > max_p) {
|
965 |
+
max_i = i;
|
966 |
+
max_p = probabilities[i];
|
967 |
+
}
|
968 |
+
}
|
969 |
+
return max_i;
|
970 |
+
}
|
971 |
+
|
972 |
+
int sample_mult(float* probabilities, int n, float coin) {
|
973 |
+
// sample index from probabilities (they must sum to 1!)
|
974 |
+
// coin is a random number in [0, 1), usually from random_f32()
|
975 |
+
float cdf = 0.0f;
|
976 |
+
for (int i = 0; i < n; i++) {
|
977 |
+
cdf += probabilities[i];
|
978 |
+
if (coin < cdf) {
|
979 |
+
return i;
|
980 |
+
}
|
981 |
+
}
|
982 |
+
return n - 1; // in case of rounding errors
|
983 |
+
}
|
984 |
+
|
985 |
+
int compare(const void* a, const void* b) {
|
986 |
+
ProbIndex* a_ = (ProbIndex*) a;
|
987 |
+
ProbIndex* b_ = (ProbIndex*) b;
|
988 |
+
if (a_->prob > b_->prob) return -1;
|
989 |
+
if (a_->prob < b_->prob) return 1;
|
990 |
+
return 0;
|
991 |
+
}
|
992 |
+
|
993 |
+
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
|
994 |
+
// top-p sampling (or "nucleus sampling") samples from the smallest set of
|
995 |
+
// tokens that exceed probability topp. This way we never sample tokens that
|
996 |
+
// have very low probabilities and are less likely to go "off the rails".
|
997 |
+
// coin is a random number in [0, 1), usually from random_f32()
|
998 |
+
|
999 |
+
int n0 = 0;
|
1000 |
+
// quicksort indices in descending order of probabilities
|
1001 |
+
// values smaller than (1 - topp) / (n - 1) cannot be part of the result
|
1002 |
+
// so for efficiency we crop these out as candidates before sorting
|
1003 |
+
const float cutoff = (1.0f - topp) / (n - 1);
|
1004 |
+
for (int i = 0; i < n; i++) {
|
1005 |
+
if (probabilities[i] >= cutoff) {
|
1006 |
+
probindex[n0].index = i;
|
1007 |
+
probindex[n0].prob = probabilities[i];
|
1008 |
+
n0++;
|
1009 |
+
}
|
1010 |
+
}
|
1011 |
+
qsort(probindex, n0, sizeof(ProbIndex), compare);
|
1012 |
+
|
1013 |
+
// truncate the list where cumulative probability exceeds topp
|
1014 |
+
float cumulative_prob = 0.0f;
|
1015 |
+
int last_idx = n0 - 1; // in case of rounding errors consider all elements
|
1016 |
+
for (int i = 0; i < n0; i++) {
|
1017 |
+
cumulative_prob += probindex[i].prob;
|
1018 |
+
if (cumulative_prob > topp) {
|
1019 |
+
last_idx = i;
|
1020 |
+
break; // we've exceeded topp by including last_idx
|
1021 |
+
}
|
1022 |
+
}
|
1023 |
+
|
1024 |
+
// sample from the truncated list
|
1025 |
+
float r = coin * cumulative_prob;
|
1026 |
+
float cdf = 0.0f;
|
1027 |
+
for (int i = 0; i <= last_idx; i++) {
|
1028 |
+
cdf += probindex[i].prob;
|
1029 |
+
if (r < cdf) {
|
1030 |
+
return probindex[i].index;
|
1031 |
+
}
|
1032 |
+
}
|
1033 |
+
return probindex[last_idx].index; // in case of rounding errors
|
1034 |
+
}
|
1035 |
+
|
1036 |
+
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
|
1037 |
+
sampler->vocab_size = vocab_size;
|
1038 |
+
sampler->temperature = temperature;
|
1039 |
+
sampler->topp = topp;
|
1040 |
+
sampler->rng_state = rng_seed;
|
1041 |
+
// buffer only used with nucleus sampling; may not need but it's ~small
|
1042 |
+
sampler->probindex = (ProbIndex *)malloc(sampler->vocab_size * sizeof(ProbIndex));
|
1043 |
+
}
|
1044 |
+
|
1045 |
+
void free_sampler(Sampler* sampler) {
|
1046 |
+
free(sampler->probindex);
|
1047 |
+
sampler->probindex = NULL;
|
1048 |
+
}
|
1049 |
+
|
1050 |
+
unsigned int random_u32(unsigned long long *state) {
|
1051 |
+
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
1052 |
+
*state ^= *state >> 12;
|
1053 |
+
*state ^= *state << 25;
|
1054 |
+
*state ^= *state >> 27;
|
1055 |
+
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
|
1056 |
+
}
|
1057 |
+
float random_f32(unsigned long long *state) { // random float32 in [0,1)
|
1058 |
+
return (random_u32(state) >> 8) / 16777216.0f;
|
1059 |
+
}
|
1060 |
+
|
1061 |
+
int sample(Sampler* sampler, float* logits) {
|
1062 |
+
// sample the token given the logits and some hyperparameters
|
1063 |
+
int next;
|
1064 |
+
if (sampler->temperature == 0.0f) {
|
1065 |
+
// greedy argmax sampling: take the token with the highest probability
|
1066 |
+
next = sample_argmax(logits, sampler->vocab_size);
|
1067 |
+
} else {
|
1068 |
+
// apply the temperature to the logits
|
1069 |
+
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
|
1070 |
+
// apply softmax to the logits to get the probabilities for next token
|
1071 |
+
softmax(logits, sampler->vocab_size);
|
1072 |
+
// flip a (float) coin (this is our source of entropy for sampling)
|
1073 |
+
float coin = random_f32(&sampler->rng_state);
|
1074 |
+
// we sample from this distribution to get the next token
|
1075 |
+
if (sampler->topp <= 0 || sampler->topp >= 1) {
|
1076 |
+
// simply sample from the predicted probability distribution
|
1077 |
+
next = sample_mult(logits, sampler->vocab_size, coin);
|
1078 |
+
} else {
|
1079 |
+
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
1080 |
+
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
|
1081 |
+
}
|
1082 |
+
}
|
1083 |
+
return next;
|
1084 |
+
}
|
1085 |
+
|
1086 |
+
// ----------------------------------------------------------------------------
|
1087 |
+
// utilities: time
|
1088 |
+
|
1089 |
+
long time_in_ms() {
|
1090 |
+
// return time in milliseconds, for benchmarking the model speed
|
1091 |
+
struct timespec time;
|
1092 |
+
clock_gettime(CLOCK_REALTIME, &time);
|
1093 |
+
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
1094 |
+
}
|
1095 |
+
|
1096 |
+
// ----------------------------------------------------------------------------
|
1097 |
+
// generation loop
|
1098 |
+
|
1099 |
+
// void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
|
1100 |
+
// char *empty_prompt = (char *)"";
|
1101 |
+
// if (prompt == NULL) { prompt = empty_prompt; }
|
1102 |
+
|
1103 |
+
// // encode the (string) prompt into tokens sequence
|
1104 |
+
// int num_prompt_tokens = 0;
|
1105 |
+
// int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
|
1106 |
+
// encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
1107 |
+
// if (num_prompt_tokens < 1) {
|
1108 |
+
// fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
|
1109 |
+
// exit(EXIT_FAILURE);
|
1110 |
+
// }
|
1111 |
+
|
1112 |
+
// // start the main loop
|
1113 |
+
// long start = 0; // used to time our code, only initialized after first iteration
|
1114 |
+
// int next; // will store the next token in the sequence
|
1115 |
+
// int token = prompt_tokens[0]; // kick off with the first token in the prompt
|
1116 |
+
// int pos = 0; // position in the sequence
|
1117 |
+
// while (pos < steps) {
|
1118 |
+
|
1119 |
+
// // forward the transformer to get logits for the next token
|
1120 |
+
// float* logits = forward(transformer, token, pos);
|
1121 |
+
|
1122 |
+
// // advance the state machine
|
1123 |
+
// if (pos < num_prompt_tokens - 1) {
|
1124 |
+
// // if we are still processing the input prompt, force the next prompt token
|
1125 |
+
// next = prompt_tokens[pos + 1];
|
1126 |
+
// } else {
|
1127 |
+
// // otherwise sample the next token from the logits
|
1128 |
+
// next = sample(sampler, logits);
|
1129 |
+
// }
|
1130 |
+
// pos++;
|
1131 |
+
|
1132 |
+
// // data-dependent terminating condition: the BOS (=1) token delimits sequences
|
1133 |
+
// if (next == 1) { break; }
|
1134 |
+
|
1135 |
+
// // print the token as string, decode it with the Tokenizer object
|
1136 |
+
// char* piece = decode(tokenizer, token, next);
|
1137 |
+
// safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
|
1138 |
+
// fflush(stdout);
|
1139 |
+
// token = next;
|
1140 |
+
|
1141 |
+
// // init the timer here because the first iteration can be slower
|
1142 |
+
// if (start == 0) { start = time_in_ms(); }
|
1143 |
+
// }
|
1144 |
+
// printf("\n");
|
1145 |
+
|
1146 |
+
// // report achieved tok/s (pos-1 because the timer starts after first iteration)
|
1147 |
+
// if (pos > 1) {
|
1148 |
+
// long end = time_in_ms();
|
1149 |
+
// fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
|
1150 |
+
// }
|
1151 |
+
|
1152 |
+
// free(prompt_tokens);
|
1153 |
+
// }
|
1154 |
+
|
1155 |
+
// void read_stdin(const char* guide, char* buffer, size_t bufsize) {
|
1156 |
+
// // read a line from stdin, up to but not including \n
|
1157 |
+
// printf("%s", guide);
|
1158 |
+
// if (fgets(buffer, bufsize, stdin) != NULL) {
|
1159 |
+
// size_t len = strlen(buffer);
|
1160 |
+
// if (len > 0 && buffer[len - 1] == '\n') {
|
1161 |
+
// buffer[len - 1] = '\0'; // strip newline
|
1162 |
+
// }
|
1163 |
+
// }
|
1164 |
+
// }
|
1165 |
+
|
1166 |
+
// // ----------------------------------------------------------------------------
|
1167 |
+
// // chat loop
|
1168 |
+
// // I manually inspected the tokens for a few chat conversations compared to
|
1169 |
+
// // python reference and that seemed ok, but this was not thoroughly tested and
|
1170 |
+
// // is not safely implemented, it's more a proof of concept atm.
|
1171 |
+
|
1172 |
+
// void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
1173 |
+
// char *cli_user_prompt, char *cli_system_prompt, int steps) {
|
1174 |
+
|
1175 |
+
// // buffers for reading the system prompt and user prompt from stdin
|
1176 |
+
// // you'll notice they are soomewhat haphazardly and unsafely set atm
|
1177 |
+
// char system_prompt[512];
|
1178 |
+
// char user_prompt[512];
|
1179 |
+
// char rendered_prompt[1152];
|
1180 |
+
// int num_prompt_tokens = 0;
|
1181 |
+
// int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
|
1182 |
+
// int user_idx;
|
1183 |
+
|
1184 |
+
// // start the main loop
|
1185 |
+
// int8_t user_turn = 1; // user starts
|
1186 |
+
// int next; // will store the next token in the sequence
|
1187 |
+
// int token; // stores the current token to feed into the transformer
|
1188 |
+
// int prev_token;
|
1189 |
+
// int pos = 0; // position in the sequence
|
1190 |
+
// while (pos < steps) {
|
1191 |
+
|
1192 |
+
// // when it is the user's turn to contribute tokens to the dialog...
|
1193 |
+
// if (user_turn) {
|
1194 |
+
// // get the (optional) system prompt at position 0
|
1195 |
+
// if (pos == 0) {
|
1196 |
+
// // at position 0, the user can also contribute a system prompt
|
1197 |
+
// if (cli_system_prompt == NULL) {
|
1198 |
+
// // system prompt was not passed in, attempt to get it from stdin
|
1199 |
+
// read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
|
1200 |
+
// } else {
|
1201 |
+
// // system prompt was passed in, use it
|
1202 |
+
// strcpy(system_prompt, cli_system_prompt);
|
1203 |
+
// }
|
1204 |
+
// }
|
1205 |
+
// // get the user prompt
|
1206 |
+
// if (pos == 0 && cli_user_prompt != NULL) {
|
1207 |
+
// // user prompt for position 0 was passed in, use it
|
1208 |
+
// strcpy(user_prompt, cli_user_prompt);
|
1209 |
+
// } else {
|
1210 |
+
// // otherwise get user prompt from stdin
|
1211 |
+
// read_stdin("User: ", user_prompt, sizeof(user_prompt));
|
1212 |
+
// }
|
1213 |
+
// // render user/system prompts into the Llama 2 Chat schema
|
1214 |
+
// if (pos == 0 && system_prompt[0] != '\0') {
|
1215 |
+
// char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
|
1216 |
+
// sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
|
1217 |
+
// } else {
|
1218 |
+
// char user_template[] = "[INST] %s [/INST]";
|
1219 |
+
// sprintf(rendered_prompt, user_template, user_prompt);
|
1220 |
+
// }
|
1221 |
+
// // encode the rendered prompt into tokens
|
1222 |
+
// encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
1223 |
+
// user_idx = 0; // reset the user index
|
1224 |
+
// user_turn = 0;
|
1225 |
+
// printf("Assistant: ");
|
1226 |
+
// }
|
1227 |
+
|
1228 |
+
// // determine the token to pass into the transformer next
|
1229 |
+
// if (user_idx < num_prompt_tokens) {
|
1230 |
+
// // if we are still processing the input prompt, force the next prompt token
|
1231 |
+
// token = prompt_tokens[user_idx++];
|
1232 |
+
// } else {
|
1233 |
+
// // otherwise use the next token sampled from previous turn
|
1234 |
+
// token = next;
|
1235 |
+
// }
|
1236 |
+
// // EOS (=2) token ends the Assistant turn
|
1237 |
+
// if (token == 2) { user_turn = 1; }
|
1238 |
+
|
1239 |
+
// // forward the transformer to get logits for the next token
|
1240 |
+
// float* logits = forward(transformer, token, pos);
|
1241 |
+
// next = sample(sampler, logits);
|
1242 |
+
// pos++;
|
1243 |
+
|
1244 |
+
// if (user_idx >= num_prompt_tokens && next != 2) {
|
1245 |
+
// // the Assistant is responding, so print its output
|
1246 |
+
// char* piece = decode(tokenizer, token, next);
|
1247 |
+
// safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
|
1248 |
+
// fflush(stdout);
|
1249 |
+
// }
|
1250 |
+
// if (next == 2) { printf("\n"); }
|
1251 |
+
// }
|
1252 |
+
// printf("\n");
|
1253 |
+
// free(prompt_tokens);
|
1254 |
+
// }
|
1255 |
+
|
1256 |
+
typedef struct {
|
1257 |
+
Transformer transformer;
|
1258 |
+
Tokenizer tokenizer;
|
1259 |
+
Sampler sampler;
|
1260 |
+
int *output; // buffer to store the output tokens(max_tokens + 1)
|
1261 |
+
int output_idx; // current index in the output buffer(0 ... max_tokens - 1)
|
1262 |
+
int gen_idx; // generated tokens(0 ... max_tokens)
|
1263 |
+
int finished;
|
1264 |
+
#ifdef USE_CUDA
|
1265 |
+
cublasHandle_t g_cublas_handle;
|
1266 |
+
#endif
|
1267 |
+
} llama2_ctx;
|
1268 |
+
|
1269 |
+
void *llama2_init(char *model_path, char *tokenizer_path) {
|
1270 |
+
llama2_ctx *ctx = (llama2_ctx *)malloc(sizeof(llama2_ctx));
|
1271 |
+
build_transformer(&ctx->transformer, model_path);
|
1272 |
+
build_tokenizer(&ctx->tokenizer, tokenizer_path, ctx->transformer.config.vocab_size);
|
1273 |
+
ctx->output = NULL;
|
1274 |
+
#ifdef USE_CUDA
|
1275 |
+
cublasStatus_t stat = cublasCreate(&ctx->g_cublas_handle); // FIXME cublasDestroy
|
1276 |
+
if (stat != CUBLAS_STATUS_SUCCESS) {
|
1277 |
+
printf ("CUBLAS initialization failed\n");
|
1278 |
+
exit(EXIT_FAILURE);
|
1279 |
+
}
|
1280 |
+
#endif
|
1281 |
+
return ctx;
|
1282 |
+
}
|
1283 |
+
|
1284 |
+
void llama2_free(void *ctx) {
|
1285 |
+
llama2_ctx *c = (llama2_ctx *)ctx;
|
1286 |
+
free_transformer(&c->transformer);
|
1287 |
+
free_tokenizer(&c->tokenizer);
|
1288 |
+
if (c->sampler.probindex != NULL)
|
1289 |
+
free_sampler(&c->sampler);
|
1290 |
+
#ifdef USE_CUDA
|
1291 |
+
cublasStatus_t stat = cublasDestroy(c->g_cublas_handle);
|
1292 |
+
if (stat != CUBLAS_STATUS_SUCCESS) {
|
1293 |
+
printf ("CUBLAS destroy failed\n");
|
1294 |
+
exit(EXIT_FAILURE);
|
1295 |
+
}
|
1296 |
+
#endif
|
1297 |
+
if (c->output != NULL)
|
1298 |
+
free(c->output);
|
1299 |
+
}
|
1300 |
+
|
1301 |
+
void llama2_generate_loop(llama2_ctx *ctx, int *prompt_tokens, int num_prompt_tokens, int steps, int *output_tokens) {
|
1302 |
+
// printf("generate loop started\n");
|
1303 |
+
// start the main loop
|
1304 |
+
// long start = 0; // used to time our code, only initialized after first iteration
|
1305 |
+
int next; // will store the next token in the sequence
|
1306 |
+
int token = prompt_tokens[0]; // kick off with the first token in the prompt
|
1307 |
+
int pos = 0; // position in the sequence
|
1308 |
+
while (pos < steps) {
|
1309 |
+
|
1310 |
+
// forward the transformer to get logits for the next token
|
1311 |
+
#ifdef USE_CUDA
|
1312 |
+
float* logits = forward(&ctx->transformer, token, pos, ctx->g_cublas_handle);
|
1313 |
+
#else
|
1314 |
+
float* logits = forward(&ctx->transformer, token, pos);
|
1315 |
+
#endif
|
1316 |
+
// advance the state machine
|
1317 |
+
if (pos < num_prompt_tokens - 1) {
|
1318 |
+
// if we are still processing the input prompt, force the next prompt token
|
1319 |
+
next = prompt_tokens[pos + 1];
|
1320 |
+
} else {
|
1321 |
+
// otherwise sample the next token from the logits
|
1322 |
+
next = sample(&ctx->sampler, logits);
|
1323 |
+
}
|
1324 |
+
// printf("current gen idx: %d, %d\n", ctx->gen_idx, next);
|
1325 |
+
if (pos == num_prompt_tokens - 1)
|
1326 |
+
output_tokens[ctx->gen_idx] = token;
|
1327 |
+
if (pos >= num_prompt_tokens - 1)
|
1328 |
+
output_tokens[ctx->gen_idx++ + 1] = next;
|
1329 |
+
pos++;
|
1330 |
+
token = next;
|
1331 |
+
|
1332 |
+
// EOS (=2) token ends the Assistant turn
|
1333 |
+
if (next == 2)
|
1334 |
+
break;
|
1335 |
+
}
|
1336 |
+
// report achieved tok/s (pos-1 because the timer starts after first iteration)
|
1337 |
+
// if (pos > 1) {
|
1338 |
+
// long end = time_in_ms();
|
1339 |
+
// fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
|
1340 |
+
// }
|
1341 |
+
ctx->finished = 1;
|
1342 |
+
free(prompt_tokens);
|
1343 |
+
free_sampler(&ctx->sampler);
|
1344 |
+
// printf("generate loop finished\n");
|
1345 |
+
}
|
1346 |
+
|
1347 |
+
int llama2_generate(void *ctx, char *prompt, int steps, float temperature, float topp, int seed) {
|
1348 |
+
llama2_ctx *c = (llama2_ctx *)ctx;
|
1349 |
+
build_sampler(&c->sampler, c->transformer.config.vocab_size, temperature, topp, seed);
|
1350 |
+
char *empty_prompt = (char *)"";
|
1351 |
+
if (prompt == NULL) { prompt = empty_prompt; }
|
1352 |
+
// encode the (string) prompt into tokens sequence
|
1353 |
+
int num_prompt_tokens = 0;
|
1354 |
+
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
|
1355 |
+
encode(&c->tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
1356 |
+
if (num_prompt_tokens < 1) {
|
1357 |
+
fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
|
1358 |
+
return 1;
|
1359 |
+
}
|
1360 |
+
if (num_prompt_tokens >= steps) {
|
1361 |
+
fprintf(stderr, "prompt tokens exceeds max token length\n");
|
1362 |
+
return 1;
|
1363 |
+
}
|
1364 |
+
c->output = (int *)malloc((steps + 1) * sizeof(int));
|
1365 |
+
c->gen_idx = 0;
|
1366 |
+
c->output_idx = 0;
|
1367 |
+
c->finished = 0;
|
1368 |
+
std::thread t(llama2_generate_loop, c, prompt_tokens, num_prompt_tokens, steps, c->output);
|
1369 |
+
t.detach();
|
1370 |
+
return 0;
|
1371 |
+
}
|
1372 |
+
|
1373 |
+
char *llama2_get_last(void *ctx) {
|
1374 |
+
llama2_ctx *c = (llama2_ctx *)ctx;
|
1375 |
+
assert(c->output != NULL); // shouldn't be called again after finished
|
1376 |
+
while(!c->finished && c->output_idx >= c->gen_idx) {
|
1377 |
+
// printf("current idx: %d, %d\n", c->output_idx, c->gen_idx);
|
1378 |
+
usleep(100000);
|
1379 |
+
} // wait for next token
|
1380 |
+
if (c->finished && c->output_idx >= c->gen_idx) {
|
1381 |
+
free(c->output);
|
1382 |
+
c->output = NULL;
|
1383 |
+
return NULL;
|
1384 |
+
}
|
1385 |
+
// printf("current idx: %d, %d, finished:%d\n", c->output_idx, c->gen_idx, c->finished);
|
1386 |
+
char *piece = decode(&c->tokenizer, c->output[c->output_idx], c->output[c->output_idx + 1]);
|
1387 |
+
c->output_idx++;
|
1388 |
+
return piece;
|
1389 |
+
}
|
1390 |
+
|
1391 |
+
void llama2_tokenize(void *ctx, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
|
1392 |
+
llama2_ctx *c = (llama2_ctx *)ctx;
|
1393 |
+
encode(&c->tokenizer, text, bos, eos, tokens, n_tokens);
|
1394 |
+
}
|
llama2_cu_python/llama2.h
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ifndef __LLAMA2_H__
|
2 |
+
#define __LLAMA2_H__
|
3 |
+
#include <stdint.h>
|
4 |
+
|
5 |
+
#ifdef __cplusplus
|
6 |
+
extern "C" {
|
7 |
+
#endif
|
8 |
+
|
9 |
+
void *llama2_init(char *model_path, char *tokenizer_path);
|
10 |
+
|
11 |
+
void llama2_free(void *ctx);
|
12 |
+
|
13 |
+
int llama2_generate(void *ctx, char *prompt, int steps, float temperature, float topp, int seed);
|
14 |
+
|
15 |
+
char *llama2_get_last(void *ctx);
|
16 |
+
|
17 |
+
void llama2_tokenize(void *ctx, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens);
|
18 |
+
|
19 |
+
#ifdef __cplusplus
|
20 |
+
}
|
21 |
+
#endif // __cplusplus
|
22 |
+
|
23 |
+
#endif // __LLAMA2_H__
|
llama2_cu_python/llama2_cu.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Generator, Iterator, List, Optional, Union
|
2 |
+
import ctypes
|
3 |
+
from ctypes import (
|
4 |
+
c_bool,
|
5 |
+
c_char_p,
|
6 |
+
c_int,
|
7 |
+
c_int8,
|
8 |
+
c_int32,
|
9 |
+
c_uint8,
|
10 |
+
c_uint32,
|
11 |
+
c_size_t,
|
12 |
+
c_float,
|
13 |
+
c_double,
|
14 |
+
c_void_p,
|
15 |
+
POINTER,
|
16 |
+
_Pointer, # type: ignore
|
17 |
+
Structure,
|
18 |
+
Array,
|
19 |
+
)
|
20 |
+
import pathlib
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
|
24 |
+
# Load the library
|
25 |
+
def _load_shared_library(lib_base_name: str):
|
26 |
+
# Construct the paths to the possible shared library names
|
27 |
+
_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
|
28 |
+
# Searching for the library in the current directory under the name "libllama2" (default name
|
29 |
+
# for llama2.cu) and "llama" (default name for this repo)
|
30 |
+
_lib_paths: List[pathlib.Path] = []
|
31 |
+
# Determine the file extension based on the platform
|
32 |
+
if sys.platform.startswith("linux"):
|
33 |
+
_lib_paths += [
|
34 |
+
_base_path / f"lib{lib_base_name}.so",
|
35 |
+
]
|
36 |
+
else:
|
37 |
+
raise RuntimeError("Unsupported platform")
|
38 |
+
|
39 |
+
if "LLAMA2_CU_LIB" in os.environ:
|
40 |
+
lib_base_name = os.environ["LLAMA2_CU_LIB"]
|
41 |
+
_lib = pathlib.Path(lib_base_name)
|
42 |
+
_base_path = _lib.parent.resolve()
|
43 |
+
_lib_paths = [_lib.resolve()]
|
44 |
+
|
45 |
+
cdll_args = dict() # type: ignore
|
46 |
+
# Add the library directory to the DLL search path on Windows (if needed)
|
47 |
+
|
48 |
+
# Try to load the shared library, handling potential errors
|
49 |
+
for _lib_path in _lib_paths:
|
50 |
+
if _lib_path.exists():
|
51 |
+
try:
|
52 |
+
return ctypes.CDLL(str(_lib_path), **cdll_args)
|
53 |
+
except Exception as e:
|
54 |
+
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
55 |
+
|
56 |
+
raise FileNotFoundError(
|
57 |
+
f"Shared library with base name '{lib_base_name}' not found"
|
58 |
+
)
|
59 |
+
|
60 |
+
# Specify the base name of the shared library to load
|
61 |
+
_lib_base_name = "llama2"
|
62 |
+
|
63 |
+
# Load the library
|
64 |
+
_lib = _load_shared_library(_lib_base_name)
|
65 |
+
|
66 |
+
|
67 |
+
def llama2_init(model_path: str, tokenizer_path: str) -> c_void_p:
|
68 |
+
return _lib.llama2_init(model_path.encode('utf-8'), tokenizer_path.encode('utf-8'))
|
69 |
+
|
70 |
+
_lib.llama2_init.argtypes = [c_char_p, c_char_p]
|
71 |
+
_lib.llama2_init.restype = c_void_p
|
72 |
+
|
73 |
+
def llama2_free(ctx: c_void_p) -> None:
|
74 |
+
_lib.llama2_free(ctx)
|
75 |
+
|
76 |
+
_lib.llama2_free.argtypes = [c_void_p]
|
77 |
+
_lib.llama2_free.restype = None
|
78 |
+
|
79 |
+
def llama2_generate(ctx: c_void_p, prompt: str, max_tokens: int, temperature: float, top_p: float, seed: int) -> int:
|
80 |
+
return _lib.llama2_generate(ctx, prompt.encode('utf-8'), max_tokens, temperature, top_p, seed)
|
81 |
+
|
82 |
+
_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]
|
83 |
+
_lib.llama2_generate.restype = c_int
|
84 |
+
|
85 |
+
def llama2_get_last(ctx: c_void_p) -> bytes:
|
86 |
+
return _lib.llama2_get_last(ctx) # bytes or None
|
87 |
+
|
88 |
+
_lib.llama2_get_last.argtypes = [c_void_p]
|
89 |
+
_lib.llama2_get_last.restype = c_char_p
|
90 |
+
|
91 |
+
def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
|
92 |
+
tokens = (c_int * (len(text) + 3))()
|
93 |
+
n_tokens = (c_int * 1)()
|
94 |
+
_lib.llama2_tokenize(ctx, text.encode('utf-8'), add_bos, add_eos, tokens, n_tokens)
|
95 |
+
return tokens[:n_tokens[0]]
|
96 |
+
|
97 |
+
_lib.llama2_tokenize.argtypes = [c_void_p, c_char_p, c_int8, c_int8, POINTER(c_int), POINTER(c_int)]
|
98 |
+
_lib.llama2_tokenize.restype = None
|
99 |
+
|
100 |
+
class Llama2:
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
model_path: str,
|
104 |
+
tokenizer_path: str='tokenizer.bin',
|
105 |
+
n_ctx: int = 0,
|
106 |
+
n_batch: int = 0) -> None:
|
107 |
+
self.n_ctx = n_ctx
|
108 |
+
self.n_batch = n_batch
|
109 |
+
self.llama2_ctx = llama2_init(model_path, tokenizer_path)
|
110 |
+
|
111 |
+
def tokenize(
|
112 |
+
self, text: str, add_bos: bool = True, add_eos: bool = False
|
113 |
+
) -> List[int]:
|
114 |
+
return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)
|
115 |
+
|
116 |
+
def __call__(
|
117 |
+
self,
|
118 |
+
prompt: str,
|
119 |
+
max_tokens: int = 128,
|
120 |
+
temperature: float = 0.8,
|
121 |
+
top_p: float = 0.95,
|
122 |
+
min_p: float = 0.05,
|
123 |
+
typical_p: float = 1.0,
|
124 |
+
logprobs: Optional[int] = None,
|
125 |
+
frequency_penalty: float = 0.0,
|
126 |
+
presence_penalty: float = 0.0,
|
127 |
+
repeat_penalty: float = 1.1,
|
128 |
+
top_k: int = 40,
|
129 |
+
stream: bool = False,
|
130 |
+
seed: Optional[int] = None,
|
131 |
+
) -> Iterator[str]:
|
132 |
+
if seed is None:
|
133 |
+
seed = 42
|
134 |
+
ret = llama2_generate(self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed)
|
135 |
+
if ret != 0:
|
136 |
+
raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")
|
137 |
+
bytes_buffer = b'' # store generated bytes until decoded (in case of multi-byte characters)
|
138 |
+
while True:
|
139 |
+
result = llama2_get_last(self.llama2_ctx)
|
140 |
+
if result is None:
|
141 |
+
break
|
142 |
+
bytes_buffer += result
|
143 |
+
try:
|
144 |
+
string = bytes_buffer.decode('utf-8')
|
145 |
+
except UnicodeDecodeError:
|
146 |
+
pass
|
147 |
+
else:
|
148 |
+
bytes_buffer = b''
|
149 |
+
yield string
|
150 |
+
|
151 |
+
|
llama2_wrapper/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import LLAMA2_WRAPPER, get_prompt, get_prompt_for_dialog
|
llama2_wrapper/download/__init__.py
ADDED
File without changes
|
llama2_wrapper/download/__main__.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
|
5 |
+
def main():
|
6 |
+
parser = argparse.ArgumentParser()
|
7 |
+
parser.add_argument(
|
8 |
+
"--repo_id",
|
9 |
+
type=str,
|
10 |
+
default="",
|
11 |
+
required=True,
|
12 |
+
help="Repo ID like 'TheBloke/Llama-2-7B-Chat-GGML' ",
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"--filename",
|
16 |
+
type=str,
|
17 |
+
default=None,
|
18 |
+
help="Filename like llama-2-7b-chat.ggmlv3.q4_0.bin",
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"--save_dir", type=str, default="./models", help="Directory to save models"
|
22 |
+
)
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
repo_id = args.repo_id
|
27 |
+
save_dir = args.save_dir
|
28 |
+
|
29 |
+
if not os.path.exists(save_dir):
|
30 |
+
os.makedirs(save_dir)
|
31 |
+
|
32 |
+
if args.filename:
|
33 |
+
filename = args.filename
|
34 |
+
from huggingface_hub import hf_hub_download
|
35 |
+
|
36 |
+
print(f"Start downloading model {repo_id} {filename} to: {save_dir}")
|
37 |
+
|
38 |
+
hf_hub_download(
|
39 |
+
repo_id=repo_id,
|
40 |
+
filename=filename,
|
41 |
+
local_dir=save_dir,
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
repo_name = repo_id.split("/")[1]
|
45 |
+
save_path = os.path.join(save_dir, repo_name)
|
46 |
+
if not os.path.exists(save_path):
|
47 |
+
os.makedirs(save_path)
|
48 |
+
print(f"Start downloading model {repo_id} to: {save_path}")
|
49 |
+
|
50 |
+
from huggingface_hub import snapshot_download
|
51 |
+
|
52 |
+
snapshot_download(
|
53 |
+
repo_id=repo_id,
|
54 |
+
local_dir=save_path,
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
main()
|
llama2_wrapper/model.py
ADDED
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import uuid
|
4 |
+
from enum import Enum
|
5 |
+
from threading import Thread
|
6 |
+
from typing import Any, Iterator, Union, List
|
7 |
+
from llama2_wrapper.types import (
|
8 |
+
Completion,
|
9 |
+
CompletionChunk,
|
10 |
+
ChatCompletion,
|
11 |
+
ChatCompletionChunk,
|
12 |
+
# ChatCompletionMessage,
|
13 |
+
Message,
|
14 |
+
B_INST,
|
15 |
+
E_INST,
|
16 |
+
B_SYS,
|
17 |
+
E_SYS,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
class LLAMA2_WRAPPER:
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
model_path: str = "",
|
25 |
+
tokenizer_path: str = "",
|
26 |
+
backend_type: str = "llama.cpp",
|
27 |
+
max_tokens: int = 4000,
|
28 |
+
load_in_8bit: bool = True,
|
29 |
+
verbose: bool = False,
|
30 |
+
):
|
31 |
+
"""Load a llama2 model from `model_path`.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
model_path: Path to the model.
|
35 |
+
backend_type: Backend for llama2, options: llama.cpp, gptq, transformers
|
36 |
+
max_tokens: Maximum context size.
|
37 |
+
load_in_8bit: Use bitsandbytes to run model in 8 bit mode (only for transformers models).
|
38 |
+
verbose: Print verbose output to stderr.
|
39 |
+
|
40 |
+
Raises:
|
41 |
+
ValueError: If the model path does not exist.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
A LLAMA2_WRAPPER instance.
|
45 |
+
"""
|
46 |
+
self.model_path = model_path
|
47 |
+
self.tokenizer_path = tokenizer_path
|
48 |
+
self.backend_type = BackendType.get_type(backend_type)
|
49 |
+
self.max_tokens = max_tokens
|
50 |
+
self.load_in_8bit = load_in_8bit
|
51 |
+
|
52 |
+
self.model = None
|
53 |
+
self.tokenizer = None
|
54 |
+
|
55 |
+
self.verbose = verbose
|
56 |
+
|
57 |
+
if self.backend_type is BackendType.LLAMA_CPP:
|
58 |
+
print("Running on backend llama.cpp.")
|
59 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
60 |
+
print("Running on backend llama2.cu.")
|
61 |
+
else:
|
62 |
+
import torch
|
63 |
+
|
64 |
+
if torch.cuda.is_available():
|
65 |
+
print("Running on GPU with backend torch transformers.")
|
66 |
+
else:
|
67 |
+
print("GPU CUDA not found.")
|
68 |
+
|
69 |
+
self.default_llamacpp_path = "./models/llama-2-7b-chat.Q4_0.gguf"
|
70 |
+
self.default_gptq_path = "./models/Llama-2-7b-Chat-GPTQ"
|
71 |
+
self.default_llama2cu_path = "./models/llama2_7b.bin"
|
72 |
+
# Download default ggml/gptq model
|
73 |
+
if self.model_path == "":
|
74 |
+
print("Model path is empty.")
|
75 |
+
if self.backend_type is BackendType.LLAMA_CPP:
|
76 |
+
print("Use default llama.cpp model path: " + self.default_llamacpp_path)
|
77 |
+
if not os.path.exists(self.default_llamacpp_path):
|
78 |
+
print("Start downloading model to: " + self.default_llamacpp_path)
|
79 |
+
from huggingface_hub import hf_hub_download
|
80 |
+
|
81 |
+
hf_hub_download(
|
82 |
+
repo_id="TheBloke/Llama-2-7b-Chat-GGUF",
|
83 |
+
filename="llama-2-7b-chat.Q4_0.gguf",
|
84 |
+
local_dir="./models/",
|
85 |
+
)
|
86 |
+
else:
|
87 |
+
print("Model exists in ./models/llama-2-7b-chat.Q4_0.gguf.")
|
88 |
+
self.model_path = self.default_llamacpp_path
|
89 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
90 |
+
if not os.path.exists(self.default_llama2cu_path):
|
91 |
+
print("Default model not found in " + self.default_llama2cu_path)
|
92 |
+
exit(1)
|
93 |
+
else:
|
94 |
+
print("Model exists in " + self.default_llama2cu_path)
|
95 |
+
self.model_path = self.default_llama2cu_path
|
96 |
+
elif self.backend_type is BackendType.GPTQ:
|
97 |
+
print("Use default gptq model path: " + self.default_gptq_path)
|
98 |
+
if not os.path.exists(self.default_gptq_path):
|
99 |
+
print("Start downloading model to: " + self.default_gptq_path)
|
100 |
+
from huggingface_hub import snapshot_download
|
101 |
+
|
102 |
+
snapshot_download(
|
103 |
+
"TheBloke/Llama-2-7b-Chat-GPTQ",
|
104 |
+
local_dir=self.default_gptq_path,
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
print("Model exists in " + self.default_gptq_path)
|
108 |
+
self.model_path = self.default_gptq_path
|
109 |
+
|
110 |
+
self.init_tokenizer()
|
111 |
+
self.init_model()
|
112 |
+
|
113 |
+
def init_model(self):
|
114 |
+
if self.model is None:
|
115 |
+
self.model = LLAMA2_WRAPPER.create_llama2_model(
|
116 |
+
self.model_path,
|
117 |
+
self.backend_type,
|
118 |
+
self.max_tokens,
|
119 |
+
self.load_in_8bit,
|
120 |
+
self.verbose,
|
121 |
+
self.tokenizer_path,
|
122 |
+
)
|
123 |
+
if self.backend_type not in [BackendType.LLAMA_CPP, BackendType.LLAMA2_CU]:
|
124 |
+
self.model.eval()
|
125 |
+
|
126 |
+
def init_tokenizer(self):
|
127 |
+
if self.backend_type not in [BackendType.LLAMA_CPP, BackendType.LLAMA2_CU]:
|
128 |
+
if self.tokenizer is None:
|
129 |
+
self.tokenizer = LLAMA2_WRAPPER.create_llama2_tokenizer(self.model_path)
|
130 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
131 |
+
self.default_llama2cu_tokenizer = "./models/tokenizer.bin"
|
132 |
+
if not os.path.exists(self.default_llama2cu_tokenizer):
|
133 |
+
print("Default tokenizer not found in " + self.default_llama2cu_tokenizer)
|
134 |
+
exit(1)
|
135 |
+
else:
|
136 |
+
print("Tokenizer exists in " + self.default_llama2cu_tokenizer)
|
137 |
+
self.tokenizer_path = self.default_llama2cu_tokenizer
|
138 |
+
|
139 |
+
@classmethod
|
140 |
+
def create_llama2_model(
|
141 |
+
cls, model_path, backend_type, max_tokens, load_in_8bit, verbose, tokenizer_path
|
142 |
+
):
|
143 |
+
if backend_type is BackendType.LLAMA_CPP:
|
144 |
+
from llama_cpp import Llama
|
145 |
+
|
146 |
+
model = Llama(
|
147 |
+
model_path=model_path,
|
148 |
+
n_ctx=max_tokens,
|
149 |
+
n_batch=max_tokens,
|
150 |
+
verbose=verbose,
|
151 |
+
)
|
152 |
+
elif backend_type is BackendType.LLAMA2_CU:
|
153 |
+
from llama2_cu_python import Llama2
|
154 |
+
|
155 |
+
model = Llama2(model_path=model_path, tokenizer_path=tokenizer_path, n_ctx=max_tokens, n_batch=max_tokens)
|
156 |
+
elif backend_type is BackendType.GPTQ:
|
157 |
+
from auto_gptq import AutoGPTQForCausalLM
|
158 |
+
|
159 |
+
model = AutoGPTQForCausalLM.from_quantized(
|
160 |
+
model_path,
|
161 |
+
use_safetensors=True,
|
162 |
+
trust_remote_code=True,
|
163 |
+
device="cuda:0",
|
164 |
+
use_triton=False,
|
165 |
+
quantize_config=None,
|
166 |
+
)
|
167 |
+
elif backend_type is BackendType.TRANSFORMERS:
|
168 |
+
import torch
|
169 |
+
from transformers import AutoModelForCausalLM
|
170 |
+
|
171 |
+
model = AutoModelForCausalLM.from_pretrained(
|
172 |
+
model_path,
|
173 |
+
device_map="auto",
|
174 |
+
torch_dtype=torch.float16,
|
175 |
+
load_in_8bit=load_in_8bit,
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
print(backend_type + "not implemented.")
|
179 |
+
return model
|
180 |
+
|
181 |
+
@classmethod
|
182 |
+
def create_llama2_tokenizer(cls, model_path):
|
183 |
+
from transformers import AutoTokenizer
|
184 |
+
|
185 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
186 |
+
return tokenizer
|
187 |
+
|
188 |
+
def get_token_length(
|
189 |
+
self,
|
190 |
+
prompt: str,
|
191 |
+
) -> int:
|
192 |
+
if self.backend_type is BackendType.LLAMA_CPP:
|
193 |
+
input_ids = self.model.tokenize(bytes(prompt, "utf-8"))
|
194 |
+
return len(input_ids)
|
195 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
196 |
+
input_ids = self.model.tokenize(prompt)
|
197 |
+
return len(input_ids)
|
198 |
+
else:
|
199 |
+
input_ids = self.tokenizer([prompt], return_tensors="np")["input_ids"]
|
200 |
+
return input_ids.shape[-1]
|
201 |
+
|
202 |
+
def get_input_token_length(
|
203 |
+
self,
|
204 |
+
message: str,
|
205 |
+
chat_history: list[tuple[str, str]] = [],
|
206 |
+
system_prompt: str = "",
|
207 |
+
) -> int:
|
208 |
+
prompt = get_prompt(message, chat_history, system_prompt)
|
209 |
+
|
210 |
+
return self.get_token_length(prompt)
|
211 |
+
|
212 |
+
def generate(
|
213 |
+
self,
|
214 |
+
prompt: str,
|
215 |
+
max_new_tokens: int = 1000,
|
216 |
+
temperature: float = 0.9,
|
217 |
+
top_p: float = 1.0,
|
218 |
+
top_k: int = 40,
|
219 |
+
repetition_penalty: float = 1.0,
|
220 |
+
**kwargs: Any,
|
221 |
+
) -> Iterator[str]:
|
222 |
+
"""Create a generator of response from a prompt.
|
223 |
+
|
224 |
+
Examples:
|
225 |
+
>>> llama2_wrapper = LLAMA2_WRAPPER()
|
226 |
+
>>> prompt = get_prompt("Hi do you know Pytorch?")
|
227 |
+
>>> for response in llama2_wrapper.generate(prompt):
|
228 |
+
... print(response)
|
229 |
+
|
230 |
+
Args:
|
231 |
+
prompt: The prompt to generate text from.
|
232 |
+
max_new_tokens: The maximum number of tokens to generate.
|
233 |
+
temperature: The temperature to use for sampling.
|
234 |
+
top_p: The top-p value to use for sampling.
|
235 |
+
top_k: The top-k value to use for sampling.
|
236 |
+
repetition_penalty: The penalty to apply to repeated tokens.
|
237 |
+
kwargs: all other arguments.
|
238 |
+
|
239 |
+
Yields:
|
240 |
+
The generated text.
|
241 |
+
"""
|
242 |
+
if self.backend_type is BackendType.LLAMA_CPP:
|
243 |
+
result = self.model(
|
244 |
+
prompt=prompt,
|
245 |
+
stream=True,
|
246 |
+
max_tokens=max_new_tokens,
|
247 |
+
top_k=top_k,
|
248 |
+
top_p=top_p,
|
249 |
+
temperature=temperature,
|
250 |
+
repeat_penalty=repetition_penalty,
|
251 |
+
**kwargs,
|
252 |
+
)
|
253 |
+
outputs = []
|
254 |
+
for part in result:
|
255 |
+
text = part["choices"][0]["text"]
|
256 |
+
outputs.append(text)
|
257 |
+
yield "".join(outputs)
|
258 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
259 |
+
result = self.model(
|
260 |
+
prompt=prompt,
|
261 |
+
stream=True,
|
262 |
+
max_tokens=max_new_tokens,
|
263 |
+
top_k=top_k,
|
264 |
+
top_p=top_p,
|
265 |
+
temperature=temperature,
|
266 |
+
repeat_penalty=repetition_penalty,
|
267 |
+
**kwargs,
|
268 |
+
)
|
269 |
+
outputs = []
|
270 |
+
for part in result:
|
271 |
+
outputs.append(part)
|
272 |
+
yield "".join(outputs)
|
273 |
+
else:
|
274 |
+
from transformers import TextIteratorStreamer
|
275 |
+
|
276 |
+
inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
|
277 |
+
|
278 |
+
streamer = TextIteratorStreamer(
|
279 |
+
self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
280 |
+
)
|
281 |
+
generate_kwargs = dict(
|
282 |
+
inputs,
|
283 |
+
streamer=streamer,
|
284 |
+
max_new_tokens=max_new_tokens,
|
285 |
+
temperature=temperature,
|
286 |
+
top_p=top_p,
|
287 |
+
top_k=top_k,
|
288 |
+
repetition_penalty=repetition_penalty,
|
289 |
+
# num_beams=1,
|
290 |
+
)
|
291 |
+
generate_kwargs = (
|
292 |
+
generate_kwargs if kwargs is None else {**generate_kwargs, **kwargs}
|
293 |
+
)
|
294 |
+
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
|
295 |
+
t.start()
|
296 |
+
|
297 |
+
outputs = []
|
298 |
+
for text in streamer:
|
299 |
+
outputs.append(text)
|
300 |
+
yield "".join(outputs)
|
301 |
+
|
302 |
+
def run(
|
303 |
+
self,
|
304 |
+
message: str,
|
305 |
+
chat_history: list[tuple[str, str]] = [],
|
306 |
+
system_prompt: str = "",
|
307 |
+
max_new_tokens: int = 1000,
|
308 |
+
temperature: float = 0.9,
|
309 |
+
top_p: float = 1.0,
|
310 |
+
top_k: int = 40,
|
311 |
+
repetition_penalty: float = 1.0,
|
312 |
+
) -> Iterator[str]:
|
313 |
+
"""Create a generator of response from a chat message.
|
314 |
+
Process message to llama2 prompt with chat history
|
315 |
+
and system_prompt for chatbot.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
message: The origianl chat message to generate text from.
|
319 |
+
chat_history: Chat history list from chatbot.
|
320 |
+
system_prompt: System prompt for chatbot.
|
321 |
+
max_new_tokens: The maximum number of tokens to generate.
|
322 |
+
temperature: The temperature to use for sampling.
|
323 |
+
top_p: The top-p value to use for sampling.
|
324 |
+
top_k: The top-k value to use for sampling.
|
325 |
+
repetition_penalty: The penalty to apply to repeated tokens.
|
326 |
+
kwargs: all other arguments.
|
327 |
+
|
328 |
+
Yields:
|
329 |
+
The generated text.
|
330 |
+
"""
|
331 |
+
prompt = get_prompt(message, chat_history, system_prompt)
|
332 |
+
return self.generate(
|
333 |
+
prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty
|
334 |
+
)
|
335 |
+
|
336 |
+
def __call__(
|
337 |
+
self,
|
338 |
+
prompt: str,
|
339 |
+
stream: bool = False,
|
340 |
+
max_new_tokens: int = 1000,
|
341 |
+
temperature: float = 0.9,
|
342 |
+
top_p: float = 1.0,
|
343 |
+
top_k: int = 40,
|
344 |
+
repetition_penalty: float = 1.0,
|
345 |
+
**kwargs: Any,
|
346 |
+
) -> Union[str, Iterator[str]]:
|
347 |
+
"""Generate text from a prompt.
|
348 |
+
|
349 |
+
Examples:
|
350 |
+
>>> llama2_wrapper = LLAMA2_WRAPPER()
|
351 |
+
>>> prompt = get_prompt("Hi do you know Pytorch?")
|
352 |
+
>>> print(llama2_wrapper(prompt))
|
353 |
+
|
354 |
+
Args:
|
355 |
+
prompt: The prompt to generate text from.
|
356 |
+
stream: Whether to stream the results.
|
357 |
+
max_new_tokens: The maximum number of tokens to generate.
|
358 |
+
temperature: The temperature to use for sampling.
|
359 |
+
top_p: The top-p value to use for sampling.
|
360 |
+
top_k: The top-k value to use for sampling.
|
361 |
+
repetition_penalty: The penalty to apply to repeated tokens.
|
362 |
+
kwargs: all other arguments.
|
363 |
+
|
364 |
+
Raises:
|
365 |
+
ValueError: If the requested tokens exceed the context window.
|
366 |
+
RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
|
367 |
+
|
368 |
+
Returns:
|
369 |
+
Generated text.
|
370 |
+
"""
|
371 |
+
if self.backend_type is BackendType.LLAMA_CPP:
|
372 |
+
completion_or_chunks = self.model.__call__(
|
373 |
+
prompt,
|
374 |
+
stream=stream,
|
375 |
+
max_tokens=max_new_tokens,
|
376 |
+
temperature=temperature,
|
377 |
+
top_p=top_p,
|
378 |
+
top_k=top_k,
|
379 |
+
repeat_penalty=repetition_penalty,
|
380 |
+
**kwargs,
|
381 |
+
)
|
382 |
+
if stream:
|
383 |
+
|
384 |
+
def chunk_generator(chunks):
|
385 |
+
for part in chunks:
|
386 |
+
chunk = part["choices"][0]["text"]
|
387 |
+
yield chunk
|
388 |
+
|
389 |
+
chunks: Iterator[str] = chunk_generator(completion_or_chunks)
|
390 |
+
return chunks
|
391 |
+
return completion_or_chunks["choices"][0]["text"]
|
392 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
393 |
+
pass # TODO
|
394 |
+
else:
|
395 |
+
inputs = self.tokenizer([prompt], return_tensors="pt").input_ids
|
396 |
+
prompt_tokens_len = len(inputs[0])
|
397 |
+
inputs = inputs.to("cuda")
|
398 |
+
generate_kwargs = dict(
|
399 |
+
inputs=inputs,
|
400 |
+
max_new_tokens=max_new_tokens,
|
401 |
+
temperature=temperature,
|
402 |
+
top_p=top_p,
|
403 |
+
top_k=top_k,
|
404 |
+
repetition_penalty=repetition_penalty,
|
405 |
+
# num_beams=1,
|
406 |
+
)
|
407 |
+
generate_kwargs = (
|
408 |
+
generate_kwargs if kwargs is None else {**generate_kwargs, **kwargs}
|
409 |
+
)
|
410 |
+
if stream:
|
411 |
+
from transformers import TextIteratorStreamer
|
412 |
+
|
413 |
+
streamer = TextIteratorStreamer(
|
414 |
+
self.tokenizer,
|
415 |
+
timeout=10.0,
|
416 |
+
skip_prompt=True,
|
417 |
+
skip_special_tokens=True,
|
418 |
+
)
|
419 |
+
generate_kwargs["streamer"] = streamer
|
420 |
+
|
421 |
+
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
|
422 |
+
t.start()
|
423 |
+
return streamer
|
424 |
+
else:
|
425 |
+
output_ids = self.model.generate(
|
426 |
+
**generate_kwargs,
|
427 |
+
)
|
428 |
+
# skip prompt, skip special tokens
|
429 |
+
output = self.tokenizer.decode(
|
430 |
+
output_ids[0][prompt_tokens_len:], skip_special_tokens=True
|
431 |
+
)
|
432 |
+
return output
|
433 |
+
|
434 |
+
def completion(
|
435 |
+
self,
|
436 |
+
prompt: str,
|
437 |
+
stream: bool = False,
|
438 |
+
max_new_tokens: int = 1000,
|
439 |
+
temperature: float = 0.9,
|
440 |
+
top_p: float = 1.0,
|
441 |
+
top_k: int = 40,
|
442 |
+
repetition_penalty: float = 1.0,
|
443 |
+
**kwargs: Any,
|
444 |
+
) -> Union[Completion, Iterator[CompletionChunk]]:
|
445 |
+
"""For OpenAI compatible API /v1/completions
|
446 |
+
Generate text from a prompt.
|
447 |
+
|
448 |
+
Examples:
|
449 |
+
>>> llama2_wrapper = LLAMA2_WRAPPER()
|
450 |
+
>>> prompt = get_prompt("Hi do you know Pytorch?")
|
451 |
+
>>> print(llm.completion(prompt))
|
452 |
+
|
453 |
+
Args:
|
454 |
+
prompt: The prompt to generate text from.
|
455 |
+
stream: Whether to stream the results.
|
456 |
+
max_new_tokens: The maximum number of tokens to generate.
|
457 |
+
temperature: The temperature to use for sampling.
|
458 |
+
top_p: The top-p value to use for sampling.
|
459 |
+
top_k: The top-k value to use for sampling.
|
460 |
+
repetition_penalty: The penalty to apply to repeated tokens.
|
461 |
+
kwargs: all other arguments.
|
462 |
+
|
463 |
+
Raises:
|
464 |
+
ValueError: If the requested tokens exceed the context window.
|
465 |
+
RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
|
466 |
+
|
467 |
+
Returns:
|
468 |
+
Response object containing the generated text.
|
469 |
+
"""
|
470 |
+
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
471 |
+
created: int = int(time.time())
|
472 |
+
model_name: str = (
|
473 |
+
self.backend_type + " default model"
|
474 |
+
if self.model_path == ""
|
475 |
+
else self.model_path
|
476 |
+
)
|
477 |
+
if self.backend_type is BackendType.LLAMA_CPP:
|
478 |
+
completion_or_chunks = self.model.__call__(
|
479 |
+
prompt,
|
480 |
+
stream=stream,
|
481 |
+
max_tokens=max_new_tokens,
|
482 |
+
temperature=temperature,
|
483 |
+
top_p=top_p,
|
484 |
+
top_k=top_k,
|
485 |
+
repeat_penalty=repetition_penalty,
|
486 |
+
**kwargs,
|
487 |
+
)
|
488 |
+
if stream:
|
489 |
+
chunks: Iterator[CompletionChunk] = completion_or_chunks
|
490 |
+
return chunks
|
491 |
+
return completion_or_chunks
|
492 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
493 |
+
pass # TODO
|
494 |
+
else:
|
495 |
+
inputs = self.tokenizer([prompt], return_tensors="pt").input_ids
|
496 |
+
prompt_tokens_len = len(inputs[0])
|
497 |
+
inputs = inputs.to("cuda")
|
498 |
+
generate_kwargs = dict(
|
499 |
+
inputs=inputs,
|
500 |
+
max_new_tokens=max_new_tokens,
|
501 |
+
temperature=temperature,
|
502 |
+
top_p=top_p,
|
503 |
+
top_k=top_k,
|
504 |
+
repetition_penalty=repetition_penalty,
|
505 |
+
# num_beams=1,
|
506 |
+
)
|
507 |
+
generate_kwargs = (
|
508 |
+
generate_kwargs if kwargs is None else {**generate_kwargs, **kwargs}
|
509 |
+
)
|
510 |
+
if stream:
|
511 |
+
from transformers import TextIteratorStreamer
|
512 |
+
|
513 |
+
streamer = TextIteratorStreamer(
|
514 |
+
self.tokenizer,
|
515 |
+
timeout=10.0,
|
516 |
+
skip_prompt=True,
|
517 |
+
skip_special_tokens=True,
|
518 |
+
)
|
519 |
+
generate_kwargs["streamer"] = streamer
|
520 |
+
|
521 |
+
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
|
522 |
+
t.start()
|
523 |
+
|
524 |
+
def chunk_generator(chunks):
|
525 |
+
for part in chunks:
|
526 |
+
yield {
|
527 |
+
"id": completion_id,
|
528 |
+
"object": "text_completion",
|
529 |
+
"created": created,
|
530 |
+
"model": model_name,
|
531 |
+
"choices": [
|
532 |
+
{
|
533 |
+
"text": part,
|
534 |
+
"index": 0,
|
535 |
+
"logprobs": None,
|
536 |
+
"finish_reason": None,
|
537 |
+
}
|
538 |
+
],
|
539 |
+
}
|
540 |
+
|
541 |
+
chunks: Iterator[CompletionChunk] = chunk_generator(streamer)
|
542 |
+
return chunks
|
543 |
+
|
544 |
+
else:
|
545 |
+
output_ids = self.model.generate(
|
546 |
+
**generate_kwargs,
|
547 |
+
)
|
548 |
+
total_tokens_len = len(output_ids[0])
|
549 |
+
output = self.tokenizer.decode(
|
550 |
+
output_ids[0][prompt_tokens_len:], skip_special_tokens=True
|
551 |
+
)
|
552 |
+
completion: Completion = {
|
553 |
+
"id": completion_id,
|
554 |
+
"object": "text_completion",
|
555 |
+
"created": created,
|
556 |
+
"model": model_name,
|
557 |
+
"choices": [
|
558 |
+
{
|
559 |
+
"text": output,
|
560 |
+
"index": 0,
|
561 |
+
"logprobs": None,
|
562 |
+
"finish_reason": None,
|
563 |
+
}
|
564 |
+
],
|
565 |
+
"usage": {
|
566 |
+
"prompt_tokens": prompt_tokens_len,
|
567 |
+
"completion_tokens": total_tokens_len - prompt_tokens_len,
|
568 |
+
"total_tokens": total_tokens_len,
|
569 |
+
},
|
570 |
+
}
|
571 |
+
return completion
|
572 |
+
|
573 |
+
def chat_completion(
|
574 |
+
self,
|
575 |
+
messages: List[Message],
|
576 |
+
stream: bool = False,
|
577 |
+
max_new_tokens: int = 1000,
|
578 |
+
temperature: float = 0.9,
|
579 |
+
top_p: float = 1.0,
|
580 |
+
top_k: int = 40,
|
581 |
+
repetition_penalty: float = 1.0,
|
582 |
+
**kwargs: Any,
|
583 |
+
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
584 |
+
"""For OpenAI compatible API /v1/chat/completions
|
585 |
+
Generate text from a dialog (chat history).
|
586 |
+
|
587 |
+
Examples:
|
588 |
+
>>> llama2_wrapper = LLAMA2_WRAPPER()
|
589 |
+
>>> dialog = [
|
590 |
+
{
|
591 |
+
"role":"system",
|
592 |
+
"content":"You are a helpful, respectful and honest assistant. "
|
593 |
+
},{
|
594 |
+
"role":"user",
|
595 |
+
"content":"Hi do you know Pytorch?",
|
596 |
+
},
|
597 |
+
]
|
598 |
+
>>> print(llm.chat_completion(dialog))
|
599 |
+
|
600 |
+
Args:
|
601 |
+
dialog: The dialog (chat history) to generate text from.
|
602 |
+
stream: Whether to stream the results.
|
603 |
+
max_new_tokens: The maximum number of tokens to generate.
|
604 |
+
temperature: The temperature to use for sampling.
|
605 |
+
top_p: The top-p value to use for sampling.
|
606 |
+
top_k: The top-k value to use for sampling.
|
607 |
+
repetition_penalty: The penalty to apply to repeated tokens.
|
608 |
+
kwargs: all other arguments.
|
609 |
+
|
610 |
+
Raises:
|
611 |
+
ValueError: If the requested tokens exceed the context window.
|
612 |
+
RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
|
613 |
+
|
614 |
+
Returns:
|
615 |
+
Response object containing the generated text.
|
616 |
+
"""
|
617 |
+
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
618 |
+
created: int = int(time.time())
|
619 |
+
model_name: str = (
|
620 |
+
self.backend_type + " default model"
|
621 |
+
if self.model_path == ""
|
622 |
+
else self.model_path
|
623 |
+
)
|
624 |
+
if self.backend_type is BackendType.LLAMA_CPP:
|
625 |
+
completion_or_chunks = self.model.create_chat_completion(
|
626 |
+
messages,
|
627 |
+
stream=stream,
|
628 |
+
max_tokens=max_new_tokens,
|
629 |
+
temperature=temperature,
|
630 |
+
top_p=top_p,
|
631 |
+
top_k=top_k,
|
632 |
+
repeat_penalty=repetition_penalty,
|
633 |
+
**kwargs,
|
634 |
+
)
|
635 |
+
if stream:
|
636 |
+
chunks: Iterator[ChatCompletionChunk] = completion_or_chunks
|
637 |
+
return chunks
|
638 |
+
return completion_or_chunks
|
639 |
+
elif self.backend_type is BackendType.LLAMA2_CU:
|
640 |
+
pass # TODO
|
641 |
+
else:
|
642 |
+
prompt = get_prompt_for_dialog(messages)
|
643 |
+
inputs = self.tokenizer([prompt], return_tensors="pt").input_ids
|
644 |
+
prompt_tokens_len = len(inputs[0])
|
645 |
+
inputs = inputs.to("cuda")
|
646 |
+
generate_kwargs = dict(
|
647 |
+
inputs=inputs,
|
648 |
+
max_new_tokens=max_new_tokens,
|
649 |
+
temperature=temperature,
|
650 |
+
top_p=top_p,
|
651 |
+
top_k=top_k,
|
652 |
+
repetition_penalty=repetition_penalty,
|
653 |
+
# num_beams=1,
|
654 |
+
)
|
655 |
+
generate_kwargs = (
|
656 |
+
generate_kwargs if kwargs is None else {**generate_kwargs, **kwargs}
|
657 |
+
)
|
658 |
+
if stream:
|
659 |
+
from transformers import TextIteratorStreamer
|
660 |
+
|
661 |
+
streamer = TextIteratorStreamer(
|
662 |
+
self.tokenizer,
|
663 |
+
timeout=10.0,
|
664 |
+
skip_prompt=True,
|
665 |
+
skip_special_tokens=True,
|
666 |
+
)
|
667 |
+
generate_kwargs["streamer"] = streamer
|
668 |
+
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
|
669 |
+
t.start()
|
670 |
+
|
671 |
+
def chunk_generator(chunks):
|
672 |
+
yield {
|
673 |
+
"id": "chat" + completion_id,
|
674 |
+
"model": model_name,
|
675 |
+
"created": created,
|
676 |
+
"object": "chat.completion.chunk",
|
677 |
+
"choices": [
|
678 |
+
{
|
679 |
+
"index": 0,
|
680 |
+
"delta": {
|
681 |
+
"role": "assistant",
|
682 |
+
},
|
683 |
+
"finish_reason": None,
|
684 |
+
}
|
685 |
+
],
|
686 |
+
}
|
687 |
+
for part in enumerate(chunks):
|
688 |
+
yield {
|
689 |
+
"id": "chat" + completion_id,
|
690 |
+
"model": model_name,
|
691 |
+
"created": created,
|
692 |
+
"object": "chat.completion.chunk",
|
693 |
+
"choices": [
|
694 |
+
{
|
695 |
+
"index": 0,
|
696 |
+
"delta": {
|
697 |
+
"content": part,
|
698 |
+
},
|
699 |
+
"finish_reason": None,
|
700 |
+
}
|
701 |
+
],
|
702 |
+
}
|
703 |
+
|
704 |
+
chunks: Iterator[ChatCompletionChunk] = chunk_generator(streamer)
|
705 |
+
return chunks
|
706 |
+
|
707 |
+
else:
|
708 |
+
output_ids = self.model.generate(
|
709 |
+
**generate_kwargs,
|
710 |
+
)
|
711 |
+
total_tokens_len = len(output_ids[0])
|
712 |
+
output = self.tokenizer.decode(
|
713 |
+
output_ids[0][prompt_tokens_len:], skip_special_tokens=True
|
714 |
+
)
|
715 |
+
chatcompletion: ChatCompletion = {
|
716 |
+
"id": "chat" + completion_id,
|
717 |
+
"object": "chat.completion",
|
718 |
+
"created": created,
|
719 |
+
"model": model_name,
|
720 |
+
"choices": [
|
721 |
+
{
|
722 |
+
"index": 0,
|
723 |
+
"message": {
|
724 |
+
"role": "assistant",
|
725 |
+
"content": output,
|
726 |
+
},
|
727 |
+
"finish_reason": None,
|
728 |
+
}
|
729 |
+
],
|
730 |
+
"usage": {
|
731 |
+
"prompt_tokens": prompt_tokens_len,
|
732 |
+
"completion_tokens": total_tokens_len - prompt_tokens_len,
|
733 |
+
"total_tokens": total_tokens_len,
|
734 |
+
},
|
735 |
+
}
|
736 |
+
return chatcompletion
|
737 |
+
|
738 |
+
|
739 |
+
def get_prompt_for_dialog(dialog: List[Message]) -> str:
|
740 |
+
"""Process dialog (chat history) to llama2 prompt for
|
741 |
+
OpenAI compatible API /v1/chat/completions.
|
742 |
+
|
743 |
+
Examples:
|
744 |
+
>>> dialog = [
|
745 |
+
{
|
746 |
+
"role":"system",
|
747 |
+
"content":"You are a helpful, respectful and honest assistant. "
|
748 |
+
},{
|
749 |
+
"role":"user",
|
750 |
+
"content":"Hi do you know Pytorch?",
|
751 |
+
},
|
752 |
+
]
|
753 |
+
>>> prompt = get_prompt_for_dialog("Hi do you know Pytorch?")
|
754 |
+
|
755 |
+
Args:
|
756 |
+
dialog: The dialog (chat history) to generate text from.
|
757 |
+
|
758 |
+
Yields:
|
759 |
+
prompt string.
|
760 |
+
"""
|
761 |
+
# add "<<SYS>>\n{system_prompt}\n<</SYS>>\n\n" in first dialog
|
762 |
+
if dialog[0]["role"] == "system":
|
763 |
+
dialog = [
|
764 |
+
{
|
765 |
+
"role": dialog[1]["role"],
|
766 |
+
"content": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"],
|
767 |
+
}
|
768 |
+
] + dialog[2:]
|
769 |
+
# check roles
|
770 |
+
assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
|
771 |
+
[msg["role"] == "assistant" for msg in dialog[1::2]]
|
772 |
+
), (
|
773 |
+
"model only supports 'system', 'user' and 'assistant' roles, "
|
774 |
+
"starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
|
775 |
+
)
|
776 |
+
# add chat history
|
777 |
+
texts = []
|
778 |
+
for prompt, answer in zip(
|
779 |
+
dialog[::2],
|
780 |
+
dialog[1::2],
|
781 |
+
):
|
782 |
+
texts.append(
|
783 |
+
f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} "
|
784 |
+
)
|
785 |
+
# check last message if role is user, then add it to prompt text
|
786 |
+
assert (
|
787 |
+
dialog[-1]["role"] == "user"
|
788 |
+
), f"Last message must be from user, got {dialog[-1]['role']}"
|
789 |
+
texts.append(f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}")
|
790 |
+
return "".join(texts)
|
791 |
+
|
792 |
+
|
793 |
+
def get_prompt(
|
794 |
+
message: str, chat_history: list[tuple[str, str]] = [], system_prompt: str = ""
|
795 |
+
) -> str:
|
796 |
+
"""Process message to llama2 prompt with chat history
|
797 |
+
and system_prompt for chatbot.
|
798 |
+
|
799 |
+
Examples:
|
800 |
+
>>> prompt = get_prompt("Hi do you know Pytorch?")
|
801 |
+
|
802 |
+
Args:
|
803 |
+
message: The origianl chat message to generate text from.
|
804 |
+
chat_history: Chat history list from chatbot.
|
805 |
+
system_prompt: System prompt for chatbot.
|
806 |
+
|
807 |
+
Yields:
|
808 |
+
prompt string.
|
809 |
+
"""
|
810 |
+
texts = [f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
|
811 |
+
for user_input, response in chat_history:
|
812 |
+
texts.append(f"{user_input.strip()} [/INST] {response.strip()} </s><s> [INST] ")
|
813 |
+
texts.append(f"{message.strip()} [/INST]")
|
814 |
+
return "".join(texts)
|
815 |
+
|
816 |
+
|
817 |
+
class BackendType(Enum):
|
818 |
+
UNKNOWN = 0
|
819 |
+
TRANSFORMERS = 1
|
820 |
+
GPTQ = 2
|
821 |
+
LLAMA_CPP = 3
|
822 |
+
LLAMA2_CU = 4
|
823 |
+
|
824 |
+
@classmethod
|
825 |
+
def get_type(cls, backend_name: str):
|
826 |
+
backend_type = None
|
827 |
+
backend_name_lower = backend_name.lower()
|
828 |
+
if "transformers" in backend_name_lower:
|
829 |
+
backend_type = BackendType.TRANSFORMERS
|
830 |
+
elif "gptq" in backend_name_lower:
|
831 |
+
backend_type = BackendType.GPTQ
|
832 |
+
elif "cpp" in backend_name_lower:
|
833 |
+
backend_type = BackendType.LLAMA_CPP
|
834 |
+
elif "cu" in backend_name_lower:
|
835 |
+
backend_type = BackendType.LLAMA2_CU
|
836 |
+
else:
|
837 |
+
raise Exception("Unknown backend: " + backend_name)
|
838 |
+
# backend_type = BackendType.UNKNOWN
|
839 |
+
return backend_type
|
llama2_wrapper/server/__init__.py
ADDED
File without changes
|
llama2_wrapper/server/__main__.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Example FastAPI server for llama2_wrapper.
|
2 |
+
|
3 |
+
To run this example:
|
4 |
+
|
5 |
+
```
|
6 |
+
python3 -m llama2_wrapper.server
|
7 |
+
```
|
8 |
+
|
9 |
+
or
|
10 |
+
|
11 |
+
```
|
12 |
+
uvicorn llama2_wrapper.server.app:app --reload
|
13 |
+
```
|
14 |
+
|
15 |
+
Then visit http://localhost:8000/docs to see the interactive API docs.
|
16 |
+
|
17 |
+
"""
|
18 |
+
import os
|
19 |
+
import argparse
|
20 |
+
|
21 |
+
import uvicorn
|
22 |
+
|
23 |
+
from llama2_wrapper.server.app import create_app, Settings
|
24 |
+
|
25 |
+
if __name__ == "__main__":
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
for name, field in Settings.model_fields.items():
|
28 |
+
description = field.description
|
29 |
+
if field.default is not None and description is not None:
|
30 |
+
description += f" (default: {field.default})"
|
31 |
+
parser.add_argument(
|
32 |
+
f"--{name}",
|
33 |
+
dest=name,
|
34 |
+
type=field.annotation if field.annotation is not None else str,
|
35 |
+
help=description,
|
36 |
+
)
|
37 |
+
|
38 |
+
args = parser.parse_args()
|
39 |
+
settings = Settings(**{k: v for k, v in vars(args).items() if v is not None})
|
40 |
+
app = create_app(settings=settings)
|
41 |
+
|
42 |
+
uvicorn.run(
|
43 |
+
app,
|
44 |
+
host=os.getenv("HOST", settings.host),
|
45 |
+
port=int(os.getenv("PORT", settings.port)),
|
46 |
+
)
|
llama2_wrapper/server/app.py
ADDED
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import multiprocessing
|
3 |
+
from re import compile, Match, Pattern
|
4 |
+
from threading import Lock
|
5 |
+
from functools import partial
|
6 |
+
from typing import Callable, Coroutine, Iterator, List, Optional, Tuple, Union, Dict
|
7 |
+
from typing_extensions import TypedDict, Literal
|
8 |
+
|
9 |
+
import anyio
|
10 |
+
from anyio.streams.memory import MemoryObjectSendStream
|
11 |
+
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
12 |
+
from fastapi import Depends, FastAPI, APIRouter, Request, Response
|
13 |
+
from fastapi.middleware.cors import CORSMiddleware
|
14 |
+
from fastapi.responses import JSONResponse
|
15 |
+
from fastapi.routing import APIRoute
|
16 |
+
from pydantic import BaseModel, Field
|
17 |
+
from pydantic_settings import BaseSettings
|
18 |
+
from sse_starlette.sse import EventSourceResponse
|
19 |
+
|
20 |
+
from llama2_wrapper.model import LLAMA2_WRAPPER
|
21 |
+
from llama2_wrapper.types import (
|
22 |
+
Completion,
|
23 |
+
CompletionChunk,
|
24 |
+
ChatCompletion,
|
25 |
+
ChatCompletionChunk,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class Settings(BaseSettings):
|
30 |
+
model_path: str = Field(
|
31 |
+
default="",
|
32 |
+
description="The path to the model to use for generating completions.",
|
33 |
+
)
|
34 |
+
backend_type: str = Field(
|
35 |
+
default="llama.cpp",
|
36 |
+
description="Backend for llama2, options: llama.cpp, gptq, transformers",
|
37 |
+
)
|
38 |
+
max_tokens: int = Field(default=4000, ge=1, description="Maximum context size.")
|
39 |
+
load_in_8bit: bool = Field(
|
40 |
+
default=False,
|
41 |
+
description="`Whether to use bitsandbytes to run model in 8 bit mode (only for transformers models).",
|
42 |
+
)
|
43 |
+
verbose: bool = Field(
|
44 |
+
default=False,
|
45 |
+
description="Whether to print verbose output to stderr.",
|
46 |
+
)
|
47 |
+
host: str = Field(default="localhost", description="API address")
|
48 |
+
port: int = Field(default=8000, description="API port")
|
49 |
+
interrupt_requests: bool = Field(
|
50 |
+
default=True,
|
51 |
+
description="Whether to interrupt requests when a new request is received.",
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
class ErrorResponse(TypedDict):
|
56 |
+
"""OpenAI style error response"""
|
57 |
+
|
58 |
+
message: str
|
59 |
+
type: str
|
60 |
+
param: Optional[str]
|
61 |
+
code: Optional[str]
|
62 |
+
|
63 |
+
|
64 |
+
class ErrorResponseFormatters:
|
65 |
+
"""Collection of formatters for error responses.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
|
69 |
+
Request body
|
70 |
+
match (Match[str]): Match object from regex pattern
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
Tuple[int, ErrorResponse]: Status code and error response
|
74 |
+
"""
|
75 |
+
|
76 |
+
@staticmethod
|
77 |
+
def context_length_exceeded(
|
78 |
+
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
79 |
+
match, # type: Match[str] # type: ignore
|
80 |
+
) -> Tuple[int, ErrorResponse]:
|
81 |
+
"""Formatter for context length exceeded error"""
|
82 |
+
|
83 |
+
context_window = int(match.group(2))
|
84 |
+
prompt_tokens = int(match.group(1))
|
85 |
+
completion_tokens = request.max_new_tokens
|
86 |
+
if hasattr(request, "messages"):
|
87 |
+
# Chat completion
|
88 |
+
message = (
|
89 |
+
"This model's maximum context length is {} tokens. "
|
90 |
+
"However, you requested {} tokens "
|
91 |
+
"({} in the messages, {} in the completion). "
|
92 |
+
"Please reduce the length of the messages or completion."
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
# Text completion
|
96 |
+
message = (
|
97 |
+
"This model's maximum context length is {} tokens, "
|
98 |
+
"however you requested {} tokens "
|
99 |
+
"({} in your prompt; {} for the completion). "
|
100 |
+
"Please reduce your prompt; or completion length."
|
101 |
+
)
|
102 |
+
return 400, ErrorResponse(
|
103 |
+
message=message.format(
|
104 |
+
context_window,
|
105 |
+
completion_tokens + prompt_tokens,
|
106 |
+
prompt_tokens,
|
107 |
+
completion_tokens,
|
108 |
+
),
|
109 |
+
type="invalid_request_error",
|
110 |
+
param="messages",
|
111 |
+
code="context_length_exceeded",
|
112 |
+
)
|
113 |
+
|
114 |
+
@staticmethod
|
115 |
+
def model_not_found(
|
116 |
+
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
117 |
+
match, # type: Match[str] # type: ignore
|
118 |
+
) -> Tuple[int, ErrorResponse]:
|
119 |
+
"""Formatter for model_not_found error"""
|
120 |
+
|
121 |
+
model_path = str(match.group(1))
|
122 |
+
message = f"The model `{model_path}` does not exist"
|
123 |
+
return 400, ErrorResponse(
|
124 |
+
message=message,
|
125 |
+
type="invalid_request_error",
|
126 |
+
param=None,
|
127 |
+
code="model_not_found",
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
class RouteErrorHandler(APIRoute):
|
132 |
+
"""Custom APIRoute that handles application errors and exceptions"""
|
133 |
+
|
134 |
+
# key: regex pattern for original error message from llama_cpp
|
135 |
+
# value: formatter function
|
136 |
+
pattern_and_formatters: Dict[
|
137 |
+
"Pattern",
|
138 |
+
Callable[
|
139 |
+
[
|
140 |
+
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
141 |
+
"Match[str]",
|
142 |
+
],
|
143 |
+
Tuple[int, ErrorResponse],
|
144 |
+
],
|
145 |
+
] = {
|
146 |
+
compile(
|
147 |
+
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
|
148 |
+
): ErrorResponseFormatters.context_length_exceeded,
|
149 |
+
compile(
|
150 |
+
r"Model path does not exist: (.+)"
|
151 |
+
): ErrorResponseFormatters.model_not_found,
|
152 |
+
}
|
153 |
+
|
154 |
+
def error_message_wrapper(
|
155 |
+
self,
|
156 |
+
error: Exception,
|
157 |
+
body: Optional[
|
158 |
+
Union[
|
159 |
+
"CreateChatCompletionRequest",
|
160 |
+
"CreateCompletionRequest",
|
161 |
+
]
|
162 |
+
] = None,
|
163 |
+
) -> Tuple[int, ErrorResponse]:
|
164 |
+
"""Wraps error message in OpenAI style error response"""
|
165 |
+
|
166 |
+
if body is not None and isinstance(
|
167 |
+
body,
|
168 |
+
(
|
169 |
+
CreateCompletionRequest,
|
170 |
+
CreateChatCompletionRequest,
|
171 |
+
),
|
172 |
+
):
|
173 |
+
# When text completion or chat completion
|
174 |
+
for pattern, callback in self.pattern_and_formatters.items():
|
175 |
+
match = pattern.search(str(error))
|
176 |
+
if match is not None:
|
177 |
+
return callback(body, match)
|
178 |
+
|
179 |
+
# Wrap other errors as internal server error
|
180 |
+
return 500, ErrorResponse(
|
181 |
+
message=str(error),
|
182 |
+
type="internal_server_error",
|
183 |
+
param=None,
|
184 |
+
code=None,
|
185 |
+
)
|
186 |
+
|
187 |
+
def get_route_handler(
|
188 |
+
self,
|
189 |
+
) -> Callable[[Request], Coroutine[None, None, Response]]:
|
190 |
+
"""Defines custom route handler that catches exceptions and formats
|
191 |
+
in OpenAI style error response"""
|
192 |
+
|
193 |
+
original_route_handler = super().get_route_handler()
|
194 |
+
|
195 |
+
async def custom_route_handler(request: Request) -> Response:
|
196 |
+
try:
|
197 |
+
return await original_route_handler(request)
|
198 |
+
except Exception as exc:
|
199 |
+
json_body = await request.json()
|
200 |
+
try:
|
201 |
+
if "messages" in json_body:
|
202 |
+
# Chat completion
|
203 |
+
body: Optional[
|
204 |
+
Union[
|
205 |
+
CreateChatCompletionRequest,
|
206 |
+
CreateCompletionRequest,
|
207 |
+
]
|
208 |
+
] = CreateChatCompletionRequest(**json_body)
|
209 |
+
elif "prompt" in json_body:
|
210 |
+
# Text completion
|
211 |
+
body = CreateCompletionRequest(**json_body)
|
212 |
+
# else:
|
213 |
+
# # Embedding
|
214 |
+
# body = CreateEmbeddingRequest(**json_body)
|
215 |
+
except Exception:
|
216 |
+
# Invalid request body
|
217 |
+
body = None
|
218 |
+
|
219 |
+
# Get proper error message from the exception
|
220 |
+
(
|
221 |
+
status_code,
|
222 |
+
error_message,
|
223 |
+
) = self.error_message_wrapper(error=exc, body=body)
|
224 |
+
return JSONResponse(
|
225 |
+
{"error": error_message},
|
226 |
+
status_code=status_code,
|
227 |
+
)
|
228 |
+
|
229 |
+
return custom_route_handler
|
230 |
+
|
231 |
+
|
232 |
+
router = APIRouter(route_class=RouteErrorHandler)
|
233 |
+
|
234 |
+
settings: Optional[Settings] = None
|
235 |
+
llama2: Optional[LLAMA2_WRAPPER] = None
|
236 |
+
|
237 |
+
|
238 |
+
def create_app(settings: Optional[Settings] = None):
|
239 |
+
if settings is None:
|
240 |
+
settings = Settings()
|
241 |
+
app = FastAPI(
|
242 |
+
title="llama2-wrapper Fast API",
|
243 |
+
version="0.0.1",
|
244 |
+
)
|
245 |
+
app.add_middleware(
|
246 |
+
CORSMiddleware,
|
247 |
+
allow_origins=["*"],
|
248 |
+
allow_credentials=True,
|
249 |
+
allow_methods=["*"],
|
250 |
+
allow_headers=["*"],
|
251 |
+
)
|
252 |
+
app.include_router(router)
|
253 |
+
global llama2
|
254 |
+
llama2 = LLAMA2_WRAPPER(
|
255 |
+
model_path=settings.model_path,
|
256 |
+
backend_type=settings.backend_type,
|
257 |
+
max_tokens=settings.max_tokens,
|
258 |
+
load_in_8bit=settings.load_in_8bit,
|
259 |
+
verbose=settings.load_in_8bit,
|
260 |
+
)
|
261 |
+
|
262 |
+
def set_settings(_settings: Settings):
|
263 |
+
global settings
|
264 |
+
settings = _settings
|
265 |
+
|
266 |
+
set_settings(settings)
|
267 |
+
return app
|
268 |
+
|
269 |
+
|
270 |
+
llama_outer_lock = Lock()
|
271 |
+
llama_inner_lock = Lock()
|
272 |
+
|
273 |
+
|
274 |
+
def get_llama():
|
275 |
+
# NOTE: This double lock allows the currently streaming llama model to
|
276 |
+
# check if any other requests are pending in the same thread and cancel
|
277 |
+
# the stream if so.
|
278 |
+
llama_outer_lock.acquire()
|
279 |
+
release_outer_lock = True
|
280 |
+
try:
|
281 |
+
llama_inner_lock.acquire()
|
282 |
+
try:
|
283 |
+
llama_outer_lock.release()
|
284 |
+
release_outer_lock = False
|
285 |
+
yield llama2
|
286 |
+
finally:
|
287 |
+
llama_inner_lock.release()
|
288 |
+
finally:
|
289 |
+
if release_outer_lock:
|
290 |
+
llama_outer_lock.release()
|
291 |
+
|
292 |
+
|
293 |
+
def get_settings():
|
294 |
+
yield settings
|
295 |
+
|
296 |
+
|
297 |
+
async def get_event_publisher(
|
298 |
+
request: Request,
|
299 |
+
inner_send_chan: MemoryObjectSendStream,
|
300 |
+
iterator: Iterator,
|
301 |
+
):
|
302 |
+
async with inner_send_chan:
|
303 |
+
try:
|
304 |
+
async for chunk in iterate_in_threadpool(iterator):
|
305 |
+
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
306 |
+
if await request.is_disconnected():
|
307 |
+
raise anyio.get_cancelled_exc_class()()
|
308 |
+
if settings.interrupt_requests and llama_outer_lock.locked():
|
309 |
+
await inner_send_chan.send(dict(data="[DONE]"))
|
310 |
+
raise anyio.get_cancelled_exc_class()()
|
311 |
+
await inner_send_chan.send(dict(data="[DONE]"))
|
312 |
+
except anyio.get_cancelled_exc_class() as e:
|
313 |
+
print("disconnected")
|
314 |
+
with anyio.move_on_after(1, shield=True):
|
315 |
+
print(f"Disconnected from client (via refresh/close) {request.client}")
|
316 |
+
raise e
|
317 |
+
|
318 |
+
|
319 |
+
stream_field = Field(
|
320 |
+
default=False,
|
321 |
+
description="Whether to stream the results as they are generated. Useful for chatbots.",
|
322 |
+
)
|
323 |
+
max_new_tokens_field = Field(
|
324 |
+
default=1000, ge=1, description="The maximum number of tokens to generate."
|
325 |
+
)
|
326 |
+
|
327 |
+
temperature_field = Field(
|
328 |
+
default=0.9,
|
329 |
+
ge=0.0,
|
330 |
+
le=2.0,
|
331 |
+
description="The temperature to use for sampling.",
|
332 |
+
)
|
333 |
+
|
334 |
+
top_p_field = Field(
|
335 |
+
default=1.0,
|
336 |
+
ge=0.0,
|
337 |
+
le=1.0,
|
338 |
+
description="The top-p value to use for sampling.",
|
339 |
+
)
|
340 |
+
top_k_field = Field(
|
341 |
+
default=40,
|
342 |
+
ge=0,
|
343 |
+
description="The top-k value to use for sampling.",
|
344 |
+
)
|
345 |
+
repetition_penalty_field = Field(
|
346 |
+
default=1.0,
|
347 |
+
ge=0.0,
|
348 |
+
description="The penalty to apply to repeated tokens.",
|
349 |
+
)
|
350 |
+
# stop_field = Field(
|
351 |
+
# default=None,
|
352 |
+
# description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
|
353 |
+
# )
|
354 |
+
|
355 |
+
|
356 |
+
class CreateCompletionRequest(BaseModel):
|
357 |
+
prompt: Union[str, List[str]] = Field(
|
358 |
+
default="", description="The prompt to generate text from."
|
359 |
+
)
|
360 |
+
stream: bool = stream_field
|
361 |
+
max_new_tokens: int = max_new_tokens_field
|
362 |
+
temperature: float = temperature_field
|
363 |
+
top_p: float = top_p_field
|
364 |
+
top_k: int = top_k_field
|
365 |
+
repetition_penalty: float = repetition_penalty_field
|
366 |
+
# stop: Optional[Union[str, List[str]]] = stop_field
|
367 |
+
|
368 |
+
model_config = {
|
369 |
+
"json_schema_extra": {
|
370 |
+
"examples": [
|
371 |
+
{
|
372 |
+
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
|
373 |
+
# "stop": ["\n", "###"],
|
374 |
+
}
|
375 |
+
]
|
376 |
+
}
|
377 |
+
}
|
378 |
+
|
379 |
+
|
380 |
+
@router.post(
|
381 |
+
"/v1/completions",
|
382 |
+
)
|
383 |
+
async def create_completion(
|
384 |
+
request: Request,
|
385 |
+
body: CreateCompletionRequest,
|
386 |
+
llama2: LLAMA2_WRAPPER = Depends(get_llama),
|
387 |
+
) -> Completion:
|
388 |
+
if isinstance(body.prompt, list):
|
389 |
+
assert len(body.prompt) <= 1
|
390 |
+
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
|
391 |
+
|
392 |
+
kwargs = body.model_dump()
|
393 |
+
|
394 |
+
iterator_or_completion: Union[
|
395 |
+
Completion, Iterator[CompletionChunk]
|
396 |
+
] = await run_in_threadpool(llama2.completion, **kwargs)
|
397 |
+
|
398 |
+
if isinstance(iterator_or_completion, Iterator):
|
399 |
+
first_response = await run_in_threadpool(next, iterator_or_completion)
|
400 |
+
|
401 |
+
# If no exception was raised from first_response, we can assume that
|
402 |
+
# the iterator is valid and we can use it to stream the response.
|
403 |
+
def iterator() -> Iterator[CompletionChunk]:
|
404 |
+
yield first_response
|
405 |
+
yield from iterator_or_completion
|
406 |
+
|
407 |
+
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
408 |
+
return EventSourceResponse(
|
409 |
+
recv_chan,
|
410 |
+
data_sender_callable=partial( # type: ignore
|
411 |
+
get_event_publisher,
|
412 |
+
request=request,
|
413 |
+
inner_send_chan=send_chan,
|
414 |
+
iterator=iterator(),
|
415 |
+
),
|
416 |
+
)
|
417 |
+
else:
|
418 |
+
return iterator_or_completion
|
419 |
+
|
420 |
+
|
421 |
+
class ChatCompletionRequestMessage(BaseModel):
|
422 |
+
role: Literal["system", "user", "assistant"] = Field(
|
423 |
+
default="user", description="The role of the message."
|
424 |
+
)
|
425 |
+
content: str = Field(default="", description="The content of the message.")
|
426 |
+
|
427 |
+
|
428 |
+
class CreateChatCompletionRequest(BaseModel):
|
429 |
+
messages: List[ChatCompletionRequestMessage] = Field(
|
430 |
+
default=[], description="A list of messages to generate completions for."
|
431 |
+
)
|
432 |
+
stream: bool = stream_field
|
433 |
+
max_new_tokens: int = max_new_tokens_field
|
434 |
+
temperature: float = temperature_field
|
435 |
+
top_p: float = top_p_field
|
436 |
+
top_k: int = top_k_field
|
437 |
+
repetition_penalty: float = repetition_penalty_field
|
438 |
+
# stop: Optional[List[str]] = stop_field
|
439 |
+
|
440 |
+
model_config = {
|
441 |
+
"json_schema_extra": {
|
442 |
+
"examples": [
|
443 |
+
{
|
444 |
+
"messages": [
|
445 |
+
ChatCompletionRequestMessage(
|
446 |
+
role="system", content="You are a helpful assistant."
|
447 |
+
).model_dump(),
|
448 |
+
ChatCompletionRequestMessage(
|
449 |
+
role="user", content="What is the capital of France?"
|
450 |
+
).model_dump(),
|
451 |
+
]
|
452 |
+
}
|
453 |
+
]
|
454 |
+
}
|
455 |
+
}
|
456 |
+
|
457 |
+
|
458 |
+
@router.post(
|
459 |
+
"/v1/chat/completions",
|
460 |
+
)
|
461 |
+
async def create_chat_completion(
|
462 |
+
request: Request,
|
463 |
+
body: CreateChatCompletionRequest,
|
464 |
+
llama2: LLAMA2_WRAPPER = Depends(get_llama),
|
465 |
+
settings: Settings = Depends(get_settings),
|
466 |
+
) -> ChatCompletion:
|
467 |
+
kwargs = body.model_dump()
|
468 |
+
|
469 |
+
iterator_or_completion: Union[
|
470 |
+
ChatCompletion, Iterator[ChatCompletionChunk]
|
471 |
+
] = await run_in_threadpool(llama2.chat_completion, **kwargs)
|
472 |
+
|
473 |
+
if isinstance(iterator_or_completion, Iterator):
|
474 |
+
first_response = await run_in_threadpool(next, iterator_or_completion)
|
475 |
+
|
476 |
+
# If no exception was raised from first_response, we can assume that
|
477 |
+
# the iterator is valid and we can use it to stream the response.
|
478 |
+
def iterator() -> Iterator[ChatCompletionChunk]:
|
479 |
+
yield first_response
|
480 |
+
yield from iterator_or_completion
|
481 |
+
|
482 |
+
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
483 |
+
return EventSourceResponse(
|
484 |
+
recv_chan,
|
485 |
+
data_sender_callable=partial( # type: ignore
|
486 |
+
get_event_publisher,
|
487 |
+
request=request,
|
488 |
+
inner_send_chan=send_chan,
|
489 |
+
iterator=iterator(),
|
490 |
+
),
|
491 |
+
)
|
492 |
+
else:
|
493 |
+
return iterator_or_completion
|
494 |
+
|
495 |
+
|
496 |
+
class ModelData(TypedDict):
|
497 |
+
id: str
|
498 |
+
object: Literal["model"]
|
499 |
+
owned_by: str
|
500 |
+
permissions: List[str]
|
501 |
+
|
502 |
+
|
503 |
+
class ModelList(TypedDict):
|
504 |
+
object: Literal["list"]
|
505 |
+
data: List[ModelData]
|
506 |
+
|
507 |
+
|
508 |
+
@router.get("/v1/models")
|
509 |
+
async def get_models(
|
510 |
+
settings: Settings = Depends(get_settings),
|
511 |
+
) -> ModelList:
|
512 |
+
assert llama2 is not None
|
513 |
+
|
514 |
+
return {
|
515 |
+
"object": "list",
|
516 |
+
"data": [
|
517 |
+
{
|
518 |
+
"id": settings.backend_type + " default model"
|
519 |
+
if settings.model_path == ""
|
520 |
+
else settings.model_path,
|
521 |
+
"object": "model",
|
522 |
+
"owned_by": "me",
|
523 |
+
"permissions": [],
|
524 |
+
}
|
525 |
+
],
|
526 |
+
}
|
llama2_wrapper/types.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Optional, Dict, Union
|
2 |
+
from typing_extensions import TypedDict, NotRequired, Literal
|
3 |
+
|
4 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
5 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
6 |
+
|
7 |
+
|
8 |
+
# Role = Literal["system", "user", "assistant"]
|
9 |
+
# class Message(TypedDict):
|
10 |
+
# role: Role
|
11 |
+
# content: str
|
12 |
+
|
13 |
+
|
14 |
+
class ChatCompletionMessage(TypedDict):
|
15 |
+
role: Literal["assistant", "user", "system"]
|
16 |
+
content: str
|
17 |
+
user: NotRequired[str]
|
18 |
+
|
19 |
+
|
20 |
+
# transformers: Message; llama.cpp: ChatCompletionMessage
|
21 |
+
Message = ChatCompletionMessage
|
22 |
+
Dialog = List[Message]
|
23 |
+
|
24 |
+
|
25 |
+
class EmbeddingUsage(TypedDict):
|
26 |
+
prompt_tokens: int
|
27 |
+
total_tokens: int
|
28 |
+
|
29 |
+
|
30 |
+
class EmbeddingData(TypedDict):
|
31 |
+
index: int
|
32 |
+
object: str
|
33 |
+
embedding: List[float]
|
34 |
+
|
35 |
+
|
36 |
+
class Embedding(TypedDict):
|
37 |
+
object: Literal["list"]
|
38 |
+
model: str
|
39 |
+
data: List[EmbeddingData]
|
40 |
+
usage: EmbeddingUsage
|
41 |
+
|
42 |
+
|
43 |
+
class CompletionLogprobs(TypedDict):
|
44 |
+
text_offset: List[int]
|
45 |
+
token_logprobs: List[Optional[float]]
|
46 |
+
tokens: List[str]
|
47 |
+
top_logprobs: List[Optional[Dict[str, float]]]
|
48 |
+
|
49 |
+
|
50 |
+
class CompletionChoice(TypedDict):
|
51 |
+
text: str
|
52 |
+
index: int
|
53 |
+
logprobs: Optional[CompletionLogprobs]
|
54 |
+
finish_reason: Optional[str]
|
55 |
+
|
56 |
+
|
57 |
+
class CompletionUsage(TypedDict):
|
58 |
+
prompt_tokens: int
|
59 |
+
completion_tokens: int
|
60 |
+
total_tokens: int
|
61 |
+
|
62 |
+
|
63 |
+
class CompletionChunk(TypedDict):
|
64 |
+
id: str
|
65 |
+
object: Literal["text_completion"]
|
66 |
+
created: int
|
67 |
+
model: str
|
68 |
+
choices: List[CompletionChoice]
|
69 |
+
|
70 |
+
|
71 |
+
class Completion(TypedDict):
|
72 |
+
id: str
|
73 |
+
object: Literal["text_completion"]
|
74 |
+
created: int
|
75 |
+
model: str
|
76 |
+
choices: List[CompletionChoice]
|
77 |
+
usage: CompletionUsage
|
78 |
+
|
79 |
+
|
80 |
+
class ChatCompletionChoice(TypedDict):
|
81 |
+
index: int
|
82 |
+
message: ChatCompletionMessage
|
83 |
+
finish_reason: Optional[str]
|
84 |
+
|
85 |
+
|
86 |
+
class ChatCompletion(TypedDict):
|
87 |
+
id: str
|
88 |
+
object: Literal["chat.completion"]
|
89 |
+
created: int
|
90 |
+
model: str
|
91 |
+
choices: List[ChatCompletionChoice]
|
92 |
+
usage: CompletionUsage
|
93 |
+
|
94 |
+
|
95 |
+
class ChatCompletionChunkDeltaEmpty(TypedDict):
|
96 |
+
pass
|
97 |
+
|
98 |
+
|
99 |
+
class ChatCompletionChunkDelta(TypedDict):
|
100 |
+
role: NotRequired[Literal["assistant"]]
|
101 |
+
content: NotRequired[str]
|
102 |
+
|
103 |
+
|
104 |
+
class ChatCompletionChunkChoice(TypedDict):
|
105 |
+
index: int
|
106 |
+
delta: Union[ChatCompletionChunkDelta, ChatCompletionChunkDeltaEmpty]
|
107 |
+
finish_reason: Optional[str]
|
108 |
+
|
109 |
+
|
110 |
+
class ChatCompletionChunk(TypedDict):
|
111 |
+
id: str
|
112 |
+
model: str
|
113 |
+
object: Literal["chat.completion.chunk"]
|
114 |
+
created: int
|
115 |
+
choices: List[ChatCompletionChunkChoice]
|
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
prompts/prompts_en.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
prompts/utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import os
|
3 |
+
from hashlib import md5
|
4 |
+
|
5 |
+
|
6 |
+
def read_csv_to_dict_list(file_path):
|
7 |
+
with open(file_path, mode="r", encoding="utf-8") as file:
|
8 |
+
reader = csv.DictReader(file)
|
9 |
+
list_of_dicts = [row for row in reader]
|
10 |
+
return list_of_dicts
|
11 |
+
|
12 |
+
|
13 |
+
def split_list_with_key(lst, dict_key):
|
14 |
+
result = {}
|
15 |
+
for row in lst:
|
16 |
+
if row.get(dict_key) not in result:
|
17 |
+
result[row.get(dict_key)] = []
|
18 |
+
result[row.get(dict_key)].append(row)
|
19 |
+
return result
|
20 |
+
|
21 |
+
|
22 |
+
def read_csv_to_type_dict(file_path, type_key):
|
23 |
+
lst = read_csv_to_dict_list(file_path=file_path)
|
24 |
+
return split_list_with_key(lst=lst, dict_key=type_key)
|
25 |
+
|
26 |
+
|
27 |
+
def md5_str(str):
|
28 |
+
return md5(str.encode("utf8")).hexdigest()
|
29 |
+
|
30 |
+
|
31 |
+
current_dir = os.path.dirname(__file__)
|
32 |
+
|
33 |
+
|
34 |
+
class PromtsContainer(object):
|
35 |
+
def __init__(self) -> None:
|
36 |
+
prompts_path = os.path.join(current_dir, "prompts_en.csv")
|
37 |
+
self.data = read_csv_to_type_dict(prompts_path, "type")
|
38 |
+
self.summary_dict = {
|
39 |
+
md5_str(row.get("summary")): row.get("prompt")
|
40 |
+
for chunk in self.data.values()
|
41 |
+
for row in chunk
|
42 |
+
}
|
43 |
+
|
44 |
+
def get_prompts_tab_dict(self):
|
45 |
+
return self.data
|
46 |
+
|
47 |
+
def get_prompt_by_summary(self, summary):
|
48 |
+
return self.summary_dict.get(md5_str(summary), summary)
|
pyproject.toml
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "llama2-wrapper"
|
3 |
+
version = "0.1.14"
|
4 |
+
description = "Use llama2-wrapper as your local llama2 backend for Generative Agents / Apps"
|
5 |
+
authors = ["liltom-eth <liltom.eth@gmail.com>"]
|
6 |
+
license = "MIT"
|
7 |
+
homepage = "https://github.com/liltom-eth/llama2-webui"
|
8 |
+
repository = "https://github.com/liltom-eth/llama2-webui"
|
9 |
+
readme = "./docs/pypi.md"
|
10 |
+
|
11 |
+
packages = [{include = "llama2_wrapper"}]
|
12 |
+
|
13 |
+
[tool.poetry.dependencies]
|
14 |
+
python = ">=3.10,<3.13"
|
15 |
+
accelerate = "^0.21.0"
|
16 |
+
auto-gptq = "0.3.0"
|
17 |
+
gradio = "3.37.0"
|
18 |
+
protobuf = "3.20.3"
|
19 |
+
scipy = "1.11.1"
|
20 |
+
sentencepiece = "0.1.99"
|
21 |
+
torch = "2.0.1"
|
22 |
+
transformers = "4.31.0"
|
23 |
+
tqdm = "4.65.0"
|
24 |
+
python-dotenv = "1.0.0"
|
25 |
+
llama-cpp-python = "0.2.11"
|
26 |
+
bitsandbytes = [
|
27 |
+
{platform = 'linux', version = "0.40.2"},
|
28 |
+
{platform = 'darwin', version = "0.40.2"},
|
29 |
+
]
|
30 |
+
memory-profiler = "0.61.0"
|
31 |
+
huggingface-hub = "0.16.4"
|
32 |
+
fastapi = "0.100.0"
|
33 |
+
uvicorn = "0.23.1"
|
34 |
+
sse-starlette = "1.6.5"
|
35 |
+
pydantic = "2.2.1"
|
36 |
+
pydantic-settings = "2.0.3"
|
37 |
+
pytest = "7.4.0"
|
38 |
+
black = "23.7.0"
|
39 |
+
|
40 |
+
|
41 |
+
[build-system]
|
42 |
+
requires = ["poetry-core"]
|
43 |
+
build-backend = "poetry.core.masonry.api"
|
44 |
+
|
45 |
+
[virtualenvs]
|
46 |
+
create = true
|
47 |
+
in-project = true
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.21.0
|
2 |
+
auto-gptq==0.3.0
|
3 |
+
bitsandbytes==0.40.2
|
4 |
+
gradio==3.37.0
|
5 |
+
protobuf==3.20.3
|
6 |
+
scipy==1.11.1
|
7 |
+
sentencepiece==0.1.99
|
8 |
+
torch==2.0.1
|
9 |
+
transformers==4.31.0
|
10 |
+
tqdm==4.65.0
|
11 |
+
python-dotenv==1.0.0
|
12 |
+
llama-cpp-python==0.2.11
|
13 |
+
memory-profiler==0.61.0
|
14 |
+
huggingface-hub==0.16.4
|
15 |
+
fastapi==0.100.0
|
16 |
+
uvicorn==0.23.1
|
17 |
+
sse-starlette==1.6.5
|
18 |
+
pydantic==2.2.1
|
19 |
+
pydantic-settings==2.0.3
|
20 |
+
pytest==7.4.0
|
21 |
+
black==23.7.0
|
static/screenshot.png
ADDED
tests/__init__.py
ADDED
File without changes
|
tests/test_get_prompt.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
from llama2_wrapper.model import get_prompt_for_dialog
|
3 |
+
|
4 |
+
|
5 |
+
class TestClassGetPromptForDialog:
|
6 |
+
from llama2_wrapper.types import Message
|
7 |
+
|
8 |
+
dialog = []
|
9 |
+
message1 = Message(
|
10 |
+
role="system",
|
11 |
+
content="You are a helpful, respectful and honest assistant. ",
|
12 |
+
)
|
13 |
+
message2 = Message(
|
14 |
+
role="user",
|
15 |
+
content="Hi do you know Pytorch?",
|
16 |
+
)
|
17 |
+
dialog.append(message1)
|
18 |
+
dialog.append(message2)
|
19 |
+
|
20 |
+
dialog2 = []
|
21 |
+
dialog2.append(message1)
|
22 |
+
dialog2.append(message2)
|
23 |
+
message3 = Message(
|
24 |
+
role="assistant",
|
25 |
+
content="Yes I know Pytorch. ",
|
26 |
+
)
|
27 |
+
message4 = Message(
|
28 |
+
role="user",
|
29 |
+
content="Can you write a CNN in Pytorch?",
|
30 |
+
)
|
31 |
+
dialog2.append(message3)
|
32 |
+
dialog2.append(message4)
|
33 |
+
|
34 |
+
dialog3 = []
|
35 |
+
dialog3.append(message3)
|
36 |
+
dialog3.append(message4)
|
37 |
+
dialog3.append(message3)
|
38 |
+
dialog3.append(message4)
|
39 |
+
message5 = Message(
|
40 |
+
role="assistant",
|
41 |
+
content="Yes I can write a CNN in Pytorch.",
|
42 |
+
)
|
43 |
+
dialog3.append(message5)
|
44 |
+
|
45 |
+
def test_dialog1(self):
|
46 |
+
prompt = get_prompt_for_dialog(self.dialog)
|
47 |
+
# print(prompt)
|
48 |
+
result = """[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. \n<</SYS>>\n\nHi do you know Pytorch? [/INST]"""
|
49 |
+
assert prompt == result
|
50 |
+
|
51 |
+
def test_dialog2(self):
|
52 |
+
prompt = get_prompt_for_dialog(self.dialog2)
|
53 |
+
# print(prompt)
|
54 |
+
result = """[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. \n<</SYS>>\n\nHi do you know Pytorch? [/INST] Yes I know Pytorch. [INST] Can you write a CNN in Pytorch? [/INST]"""
|
55 |
+
assert prompt == result
|
56 |
+
|
57 |
+
def test_dialog3(self):
|
58 |
+
with pytest.raises(AssertionError):
|
59 |
+
prompt = get_prompt_for_dialog(self.dialog3)
|