TGI/vLLM benchmarking (#34)
- scripts/inference-server/.gitignore +2 -0
- scripts/inference-server/Dockerfile +16 -0
- scripts/inference-server/README.md +5 -0
- scripts/inference-server/benchmark.py +323 -0
- scripts/inference-server/local-tokenizers/README.md +79 -0
- scripts/inference-server/local-tokenizers/meta-llama/Llama-2-70b-chat-hf/tokenizer_config.json +36 -0
- scripts/inference-server/local-tokenizers/meta-llama/Llama-2-7b-chat-hf/tokenizer_config.json +22 -0
- scripts/inference-server/local-tokenizers/mistralai/Mistral-7B-Instruct-v0.2/chat_template.jinja +1 -0
- scripts/inference-server/local-tokenizers/mistralai/Mistral-7B-Instruct-v0.2/tokenizer_config.json +5 -0
- scripts/inference-server/requirements.txt +7 -0
- sharegpt/README.md +31 -6
- sharegpt/ShareGPT_V3_filtered_500.json +0 -0
- sharegpt/compare_distributions.py +62 -0
- sharegpt/filter_dataset.py +107 -0
scripts/inference-server/.gitignore
ADDED
@@ -0,0 +1,2 @@
test*
temp*
scripts/inference-server/Dockerfile
ADDED
@@ -0,0 +1,16 @@
# docker build -t benchmark:latest .

# Use an official Python runtime as a parent image
FROM python:3.9

# Set the working directory in the container
WORKDIR /benchmark

# Copy the current directory contents into the container at /benchmark
COPY . .

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Run benchmark.py when the container launches
ENTRYPOINT ["python", "benchmark.py"]
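Because of the `ENTRYPOINT`, any arguments given to `docker run` are forwarded straight to `benchmark.py`. A minimal sketch of building and running the image (the model name, dataset path, and the GPU/host-networking flags are assumptions about a typical setup, not part of this commit):
```
docker build -t benchmark:latest .
# --gpus all lets ZeusMonitor read GPU energy counters inside the container;
# --network host lets the script reach a server listening on localhost.
docker run --rm --gpus all --network host benchmark:latest \
  --backend vllm --model meta-llama/Llama-2-7b-chat-hf \
  --dataset ShareGPT_V3_filtered_500.json --out-name vllm-llama2-7b
```
Since the Dockerfile copies the whole build context, the dataset file would need to be present in this directory at build time, and the result JSON files are written inside the container unless a volume is mounted.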
scripts/inference-server/README.md
ADDED
@@ -0,0 +1,5 @@
# About

This directory contains a script for running benchmarks (including energy consumption) against models hosted on a dedicated inference server. The script is taken and modified from [vllm](https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/benchmarks/benchmark_serving.py).

The script currently supports TGI and vLLM. Before running the benchmark script, the inference server hosting the relevant model must already be running.
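As a hedged end-to-end sketch (the model name, port, dataset path, and request rate below are illustrative only, not prescribed by this repo): start the server first, then point `benchmark.py` at it.
```
# 1. Start the inference server (vLLM's OpenAI-compatible server shown here)
python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Llama-2-7b-chat-hf --port 8000

# 2. In another shell, run the benchmark against it
python benchmark.py \
  --backend vllm \
  --model meta-llama/Llama-2-7b-chat-hf \
  --host localhost --port 8000 \
  --dataset ../../sharegpt/ShareGPT_V3_filtered_500.json \
  --request-rate 2.0 --num-runs 3 \
  --out-name vllm-llama2-7b
```
All `benchmark.py` flags shown here come from its argument parser; see the script below for their defaults.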
scripts/inference-server/benchmark.py
ADDED
@@ -0,0 +1,323 @@
"""Taken and modified from vllm: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/benchmarks/benchmark_serving.py
"""

import argparse
import asyncio
import json
import random
import time
import torch
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from dataclasses import asdict, dataclass, field
from tqdm.asyncio import tqdm
from zeus.monitor import ZeusMonitor


SYSTEM_PROMPT = "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "


@dataclass
class Results:
    model: str
    backend: str
    request_rate: float
    num_failures: int = 0
    system_prompt: str = SYSTEM_PROMPT
    total_time: float = 0.0
    throughput: float = 0.0
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    avg_latency: float = 0.0
    avg_latency_per_token: float = 0.0
    avg_latency_per_output_token: float = 0.0
    server_total_energy: float = 0.0
    server_energy_per_request: float = 0.0
    server_energy_per_output_token: float = 0.0
    local_zeus_total_energy: float = 0.0
    local_zeus_energy_per_request: float = 0.0
    local_zeus_energy_per_output_token: float = 0.0
    results: list["Result"] = field(default_factory=list)


@dataclass
class Result:
    success: bool = True
    latency: float = 0.0
    prompt: str = ""
    response: str = ""
    num_prompt_tokens: int = 0
    num_completion_tokens: int = 0
    energy: float = 0.0


def get_requests(
    dataset_path: str,
) -> List[str]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Only keep the first turn of each conversation.
    dataset = [data["conversations"][0]["value"] for data in dataset]

    return dataset


async def get_request(
    input_requests: List[str],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for i, request in enumerate(input_requests):
        yield i, request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


async def send_request(
    result: Result,
    backend: str,
    model: str,
    api_url: str,
    prompt: str,
    pbar: tqdm,
) -> None:
    request_start_time = time.perf_counter()

    headers = {"Content-Type": "application/json"}
    # OpenAI Chat Completions API request format
    pload = {
        "model": model,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        "stream": False,
        "max_tokens": 1000,
    }

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(api_url, headers=headers, json=pload) as response:
            # Request failed
            if response.status // 100 != 2:
                print("request failed")
                print(f"response.status {response.status}")
                result.prompt = prompt
                result.success = False
                return
            chunks = []
            async for chunk, _ in response.content.iter_chunks():
                chunks.append(chunk)
    request_end_time = time.perf_counter()
    output = b"".join(chunks).decode("utf-8")
    output = json.loads(output)

    result.latency = request_end_time - request_start_time
    result.prompt = prompt
    result.response = output["choices"][0]["message"]["content"]
    result.num_prompt_tokens = output["usage"]["prompt_tokens"]
    result.num_completion_tokens = output["usage"]["completion_tokens"]
    result.energy = output["usage"]["energy"]

    pbar.update(1)


async def benchmark(
    results: Results,
    backend: str,
    model: str,
    api_url: str,
    input_requests: List[str],
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    pbar = tqdm(total=len(input_requests))
    async for i, request in get_request(input_requests, request_rate):
        prompt = request
        task = asyncio.create_task(
            # Ensures results has same ordering as the input dataset
            send_request(
                results.results[i],
                backend,
                model,
                api_url,
                prompt,
                pbar,
            )
        )
        tasks.append(task)
    await asyncio.gather(*tasks)
    pbar.close()


def run_benchmark(
    args: argparse.Namespace, api_url: str, input_requests: List[str], out_filename: str
):
    results = Results(
        model=args.model,
        backend=args.backend,
        request_rate=args.request_rate,
        results=[Result() for _ in input_requests],
    )

    zeus_monitor = ZeusMonitor()
    zeus_monitor.begin_window(out_filename)
    benchmark_start_time = time.perf_counter()
    asyncio.run(
        benchmark(
            results,
            args.backend,
            args.model,
            api_url,
            input_requests,
            args.request_rate,
        )
    )
    benchmark_end_time = time.perf_counter()
    measurements = zeus_monitor.end_window(out_filename)
    zeus_total_energy = measurements.total_energy

    # Store aggregated results
    total_prompt_tokens = 0
    total_completion_tokens = 0
    total_latency = 0
    total_latency_per_token = 0
    total_latency_per_output_token = 0
    server_total_energy = 0
    for result in results.results:
        if not result.success:
            results.num_failures += 1
            continue
        total_prompt_tokens += result.num_prompt_tokens
        total_completion_tokens += result.num_completion_tokens
        total_latency += result.latency
        total_latency_per_token += result.latency / (
            result.num_prompt_tokens + result.num_completion_tokens
        )
        total_latency_per_output_token += result.latency / result.num_completion_tokens
        server_total_energy += result.energy

    num_results = len(results.results) - results.num_failures
    if num_results == 0:
        print(f"{out_filename} not generated. All requests in this run failed.")
        return

    results.total_time = benchmark_end_time - benchmark_start_time
    results.throughput = num_results / results.total_time
    results.total_prompt_tokens = total_prompt_tokens
    results.total_completion_tokens = total_completion_tokens
    results.avg_latency = total_latency / num_results
    results.avg_latency_per_token = total_latency_per_token / num_results
    results.avg_latency_per_output_token = total_latency_per_output_token / num_results
    results.server_total_energy = server_total_energy
    results.server_energy_per_request = results.server_total_energy / num_results
    results.server_energy_per_output_token = (
        results.server_total_energy / results.total_completion_tokens
    )
    results.local_zeus_total_energy = zeus_total_energy
    results.local_zeus_energy_per_request = zeus_total_energy / num_results
    results.local_zeus_energy_per_output_token = (
        zeus_total_energy / results.total_completion_tokens
    )

    with open(out_filename, "w") as f:
        f.write(json.dumps(asdict(results), indent=2))

    if args.verbose:
        print("Benchmark results:")
        print(f"Model: {results.model}")
        print(f"Backend: {results.backend}")
        print(f"Request rate: {results.request_rate} requests/s")
        print()
        print(f"Total time: {results.total_time:.2f} s")
        print(f"Throughput: {results.throughput:.2f} requests/s")
        print(f"Average latency: {results.avg_latency:.2f} s")
        print(f"Average latency per token: {results.avg_latency_per_token:.2f} s")
        print(f"Average latency per output token: {results.avg_latency_per_output_token:.2f} s")
        print(f"(Zeus) Total energy: {results.local_zeus_total_energy:.2f} J")
        print(f"(Zeus) Energy per request: {results.local_zeus_energy_per_request:.2f} J")
        print(f"(Zeus) Energy per token: {results.local_zeus_energy_per_output_token:.2f} J")
        print(f"(Server) Total energy: {results.server_total_energy:.2f} J")
        print(f"(Server) Energy per request: {results.server_energy_per_request:.2f} J")
        print(f"(Server) Energy per token: {results.server_energy_per_output_token:.2f} J")

    print("Benchmark results written to", out_filename)


def main(args: argparse.Namespace):
    if args.backend not in ["tgi", "vllm"]:
        raise ValueError(f"Unknown backend: {args.backend}")

    arg_out_filename = f"{args.out_name}-args.json"
    with open(arg_out_filename, "w") as f:
        f.write(json.dumps(vars(args), indent=2))
    if args.verbose:
        print(args)
        print("Benchmark args written to", arg_out_filename)

    random.seed(args.seed)
    np.random.seed(args.seed)

    out_name = args.out_name
    api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
    input_requests = get_requests(args.dataset)

    # Note: output filenames are 1-indexed
    for i in range(1, args.num_runs + 1):
        run_benchmark(args, api_url, input_requests, out_name + f"-run{i}.json")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput."
    )
    parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "tgi"])
    parser.add_argument(
        "--protocol", type=str, default="http", choices=["http", "https"]
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--endpoint", type=str, default="/v1/chat/completions")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument(
        "--dataset", type=str, required=True, help="Path to the dataset."
    )
    parser.add_argument(
        "--num-runs",
        type=int,
        default=3,
        help="Runs the benchmark num-runs times, writing results to num-runs separate files.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use a Poisson process to synthesize "
        "the request arrival times.",
    )
    parser.add_argument(
        "--out-name",
        type=str,
        default="benchmark_result",
        help="Name of file to write benchmark results. Note: '-run{i}.json' will be appended for actual outputted files.",
    )
    parser.add_argument(
        "--verbose",
        type=bool,
        default=True,
        help="Set to true to print out benchmark results. Otherwise, only write to file.",
    )
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
scripts/inference-server/local-tokenizers/README.md
ADDED
@@ -0,0 +1,79 @@
# TGI
The local tokenizer config can be supplied to TGI through the flag `--tokenizer-config-path`, documented [here](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#tokenizerconfigpath).

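For illustration, a TGI launch might look like the following; the image tag, port mapping, volume mount, and model are assumed placeholders, and only `--tokenizer-config-path` is the flag discussed here:
```
docker run --gpus all --shm-size 1g -p 8080:80 \
  -v $PWD/local-tokenizers:/local-tokenizers \
  ghcr.io/huggingface/text-generation-inference:1.4 \
  --model-id meta-llama/Llama-2-7b-chat-hf \
  --tokenizer-config-path /local-tokenizers/meta-llama/Llama-2-7b-chat-hf/tokenizer_config.json
```
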
# vLLM
A local chat template can be supplied to vLLM through the flag `--chat-template`. It is not explicitly documented, but it is mentioned in GitHub issues on the topic.

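Similarly, a sketch of passing the bundled Mistral template to vLLM's OpenAI-compatible server (model and port are placeholders):
```
python -m vllm.entrypoints.openai.api_server \
  --model mistralai/Mistral-7B-Instruct-v0.2 \
  --chat-template local-tokenizers/mistralai/Mistral-7B-Instruct-v0.2/chat_template.jinja \
  --port 8000
```
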
# Llama-2 models on TGI
There is a [known bug with TGI](https://github.com/huggingface/text-generation-inference/issues/1534) in which the default `tokenizer_config.json` is not handled properly by TGI when applying chat templating. Until this is resolved, we are using a modified `tokenizer_config.json` that is compatible with TGI. Note that the chat template Jinja itself is the same, except for the removal of two calls to `.strip()`, which TGI reports errors on.

For reference, here is the original unmodified chat template:
```
{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% set system_message = false %}
{% endif %}
{% for message in loop_messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 and system_message != false %}
{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content.strip() + ' ' + eos_token }}
{% endif %}
{% endfor %}
```

We also note that the `eos_token` and `bos_token` are originally provided as maps, but the TGI implementation only accepts a string, so we also modify them to contain only the `content` string.

For reference, here is the original unmodified `tokenizer_config.json`:
```
{
  "add_bos_token": true,
  "add_eos_token": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": {
    "__type": "AddedToken",
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "legacy": false,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "padding_side": "right",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
```

# Mistral with chat templating
Mistral's chat models have not been explicitly trained with a distinct system prompt. Therefore, the default Mistral `tokenizer_config.json` explicitly assumes that the system role does not exist. To keep our benchmarks consistent across models, we re-engineered the original Mistral chat template to account for a system prompt: we simply prepend the system prompt to the first user prompt in a given conversation.
scripts/inference-server/local-tokenizers/meta-llama/Llama-2-70b-chat-hf/tokenizer_config.json
ADDED
@@ -0,0 +1,36 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": {
    "__type": "AddedToken",
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "legacy": false,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "padding_side": "right",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
scripts/inference-server/local-tokenizers/meta-llama/Llama-2-7b-chat-hf/tokenizer_config.json
ADDED
@@ -0,0 +1,22 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "bos_token": "</s>",
  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "legacy": false,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "padding_side": "right",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
scripts/inference-server/local-tokenizers/mistralai/Mistral-7B-Instruct-v0.2/chat_template.jinja
ADDED
@@ -0,0 +1 @@
{% if (messages[0]['role'] != 'system') %}{{ raise_exception('First role should be system!') }}{% elif (messages[1]['role'] != 'user') %}{{ raise_exception('Second role should be user!') }}{% endif %}{{ bos_token }}{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}{% for message in messages[2:] %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 1) %}{{ raise_exception('Conversation roles must alternate system/user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only system, user and assistant roles are supported!') }}{% endif %}{% endfor %}
scripts/inference-server/local-tokenizers/mistralai/Mistral-7B-Instruct-v0.2/tokenizer_config.json
ADDED
@@ -0,0 +1,5 @@
{
  "bos_token": "<s>",
  "chat_template": "{% if (messages[0]['role'] != 'system') %}{{ raise_exception('First role should be system!') }}{% elif (messages[1]['role'] != 'user') %}{{ raise_exception('Second role should be user!') }}{% endif %}{{ bos_token }}{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}{% for message in messages[2:] %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 1) %}{{ raise_exception('Conversation roles must alternate system/user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only system, user and assistant roles are supported!') }}{% endif %}{% endfor %}",
  "eos_token": "</s>"
}
scripts/inference-server/requirements.txt
ADDED
@@ -0,0 +1,7 @@
argparse
asyncio
aiohttp
numpy
torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118
tqdm
zeus-ml
sharegpt/README.md
CHANGED
@@ -1,33 +1,58 @@
# How we used ShareGPT to create our benchmark dataset

-##
+## sg_90k_part1_html_cleaned.json
+
+### Download ShareGPT dataset
```
https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/HTML_cleaned_raw_dataset/sg_90k_part1_html_cleaned.json
```

-
+### Install Fastchat
```
pip install fschat
```

-
+### Clean data:
```
pip install polyglot pyicu pycld2
python -m fastchat.data.optional_clean --in sg_90k_part1_html_cleaned.json --out sg_90k_part1_html_cleaned_lang.json --keep-lang en
```

-
+### Extract first prompt
```
python extract_first.py --in-file sg_90k_part1_html_cleaned_lang.json --out-file sg_90k_part1_html_cleaned_lang_first.json
```

-
+### Sample data
```
python -m fastchat.data.sample --in sg_90k_part1_html_cleaned_lang_first.json --out sg_90k_part1_html_cleaned_lang_first_sampled.json --end 10000 --max-length 10000
```

-
+### Sorted data
We sort the requests by sequence length, placing the longest sequences first. This approach minimizes the amount of padding required and allows for early detection of out-of-memory.
```
python sort.py --data-dir sg_90k_part1_html_cleaned_lang_first_sampled.json --out-file sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json
```
+
+## ShareGPT_V3_filtered.json
+
+### Download ShareGPT dataset
+```
+https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+```
+
+### Install Transformers
+```
+pip install transformers
+```
+
+### Filter conversations with too long prompts/responses, extract first turn, and randomly sample 500 prompts
+```
+python filter_dataset.py
+```
+
+### Compare the response length distribution of sampled dataset with respect to initial dataset
+```
+pip install matplotlib numpy
+python compare_distributions.py
+```
sharegpt/ShareGPT_V3_filtered_500.json
ADDED
The diff for this file is too large to render.
sharegpt/compare_distributions.py
ADDED
@@ -0,0 +1,62 @@
import json
import matplotlib.pyplot as plt
import numpy as np
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

# Open datasets
file_paths = ["ShareGPT_V3_filtered.json", "ShareGPT_V3_filtered_500.json"]

names = [file_path[:-5] for file_path in file_paths]

data_lists = []
for file_path in file_paths:
    with open(file_path, "r", encoding="utf-8") as file:
        data_list = json.load(file)
    data_lists.append(data_list)

for name, data_list in zip(names, data_lists):
    print(f"{name}: {len(data_list)}")

# Get prompt lengths using tokenizer
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
all_prompts = [
    [data["conversations"][0]["value"] for data in data_list]
    for data_list in data_lists
]
all_token_ids_per_prompts = [tokenizer(prompts).input_ids for prompts in all_prompts]
all_prompt_lens = [
    [len(token_ids) for token_ids in token_ids_per_prompt]
    for token_ids_per_prompt in all_token_ids_per_prompts
]

# Plotting the histograms
for name, prompt_lens in zip(names, all_prompt_lens):
    plt.hist(
        prompt_lens,
        bins=range(min(prompt_lens), max(prompt_lens) + 1),
        edgecolor="black",
    )
    plt.xlabel("Prompt Length (number of tokens)")
    plt.ylabel("Frequency")
    plt.title(f"Histogram of {name}")
    plt.savefig(f"{name}_distribution.png")
    plt.close()

# Plotting the CDF
for name, prompt_lens in zip(names, all_prompt_lens):
    values, counts = np.unique(prompt_lens, return_counts=True)
    relative_frequencies = counts / len(prompt_lens)
    sorted_data = np.sort(values)
    cumulative_frequencies = np.cumsum(relative_frequencies)
    plt.step(sorted_data, cumulative_frequencies, where="post", label=name)

plt.title("Cumulative Distribution Function (CDF) Overlaid")
plt.xlabel("Prompt Length (number of tokens)")
plt.ylabel("Cumulative Probability")
plt.savefig(f"{name}_cdf.png")
plt.close()
sharegpt/filter_dataset.py
ADDED
@@ -0,0 +1,107 @@
"""Taken and modified from vllm: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/benchmarks/benchmark_serving.py
Filter dataset to:
1. Remove entries that have too long prompts or completions
2. Only keep first human prompt for each conversation
"""

import json
import random
from typing import AsyncGenerator, List, Tuple

from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)


def filter_dataset_to_size(
    dataset_path: str,
    size: int,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)

    # randomly sample dataset
    return random.sample(dataset, size)


def filter_dataset(
    dataset_path: str,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (
            data["id"],
            data["conversations"][0]["value"],
            data["conversations"][1]["value"],
        )
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    conversation_ids = [conv_id for conv_id, _, _ in dataset]
    prompts = [prompt for _, prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append(
            (conversation_ids[i], prompts[i], prompt_token_ids[i], output_len)
        )

    # Filter out too long sequences.
    filtered_dataset_json = []
    for conv_id, prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        # making even shorter than 1024 to account for additional tokens introduced by chat completion wrapper
        if prompt_len > 800 or output_len > 800:
            # if prompt_len > 1024 or output_len > 1024:
            # Prune too long sequences.
            continue
        filtered_dataset_json.append(
            {
                "id": conv_id,
                "conversations": [
                    {
                        "from": "human",
                        "value": prompt,
                    }
                ],
            }
        )

    return filtered_dataset_json


def main():
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    # download: https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    filtered_dataset = filter_dataset(
        "ShareGPT_V3_unfiltered_cleaned_split.json", tokenizer
    )
    with open("ShareGPT_V3_filtered.json", "w") as f:
        json.dump(filtered_dataset, f)

    sampled_dataset = filter_dataset_to_size("ShareGPT_V3_filtered.json", 500)
    with open("ShareGPT_V3_filtered_500.json", "w") as f:
        json.dump(sampled_dataset, f)


if __name__ == "__main__":
    main()