Commit 01bc423 • Zhiyu Wu committed • 1 Parent(s): aaadf66

add attention mask; fix stop_str length (#26)
Files changed:
- pegasus/benchmark.yaml +7 -1
- scripts/benchmark.py +25 -6
- scripts/sort.py +15 -0
- sharegpt/README.md +1 -0
- sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json +0 -0
pegasus/benchmark.yaml (CHANGED)
```diff
@@ -3,7 +3,7 @@
 # {{ gpu }} is defined in `hosts.yaml`, and will be filled in when Pegasus
 # determines the specific node and gpu the generated job command will run on.
 command:
-  - docker exec leaderboard{{ gpu }} python scripts/benchmark.py --input-file sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json --model-path {{ model }} --task {{ task }}
+  - docker exec leaderboard{{ gpu }} python scripts/benchmark.py --input-file sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json --model-path {{ model }} --task {{ task }} --batch-size {{ batch_size }}
 model:
   - /data/leaderboard/weights/metaai/llama-7B
   - /data/leaderboard/weights/metaai/llama-13B
@@ -31,3 +31,9 @@
   - chat-concise
   - instruct
   - instruct-concise
+batch_size:
+  - 1
+  - 2
+  - 4
+  - 8
+  - 16
```
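For context, each top-level list in this file is a sweep axis, so adding `batch_size` multiplies the job matrix by five. A minimal sketch of that cross-product expansion, assuming Pegasus-style behavior (the expansion code and the placeholder `gpu=0` below are illustrative, not Pegasus internals):

```python
# Illustrative sketch only: how a sweep file like this could expand into one
# command per (model, task, batch_size) combination. This mimics, but is not,
# Pegasus's actual implementation.
from itertools import product

models = ["/data/leaderboard/weights/metaai/llama-7B",
          "/data/leaderboard/weights/metaai/llama-13B"]
tasks = ["chat-concise", "instruct", "instruct-concise"]
batch_sizes = [1, 2, 4, 8, 16]

template = (
    "docker exec leaderboard{gpu} python scripts/benchmark.py "
    "--model-path {model} --task {task} --batch-size {batch_size}"
)

for model, task, batch_size in product(models, tasks, batch_sizes):
    # {gpu} is filled from hosts.yaml in the real tool; 0 is a placeholder.
    print(template.format(gpu=0, model=model, task=task, batch_size=batch_size))
```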
scripts/benchmark.py (CHANGED)
```diff
@@ -104,7 +104,10 @@ def run_inference(
         temperature, repetition_penalty, top_p, top_k
     )
 
-
+    prompts_encode = tokenizer(prompts, padding=True)
+    input_ids = prompts_encode.input_ids
+    attention_masks = prompts_encode.attention_mask
+
     output_ids = [[] for _ in range(batch_size)]
 
     if model.config.is_encoder_decoder:
```
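This hunk switches to a padded batch encoding so unequal-length prompts can share one batch. A minimal sketch of what `tokenizer(prompts, padding=True)` returns, using GPT-2 purely as a stand-in tokenizer (the benchmark loads its own model and tokenizer):

```python
# Minimal sketch: batch tokenization with padding and an attention mask.
# "gpt2" is a stand-in; not the tokenizer the benchmark actually uses.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token

prompts = ["Hi", "A noticeably longer prompt"]
enc = tokenizer(prompts, padding=True)

# Shorter prompts are padded to the batch maximum; attention_mask is 1 on
# real tokens and 0 on padding, so the model can ignore padded positions.
print(enc.input_ids)
print(enc.attention_mask)
```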
```diff
@@ -113,10 +116,12 @@ def run_inference(
         max_src_len = context_len - max_new_tokens - 1
 
     input_ids = [input_id[-max_src_len:] for input_id in input_ids]
+    attention_masks = torch.as_tensor([attention_mask[-max_src_len:] for attention_mask in attention_masks], device=device)
 
     if model.config.is_encoder_decoder:
         encoder_output = model.encoder(
-            input_ids=torch.as_tensor(input_ids, device=device)
+            input_ids=torch.as_tensor(input_ids, device=device),
+            attention_mask=attention_masks
         )[0]
         start_ids = torch.as_tensor(
             [[model.generation_config.decoder_start_token_id] for _ in range(batch_size)],
```
```diff
@@ -126,6 +131,12 @@ def run_inference(
 
     past_key_values = out = None
     stopped = np.array(batch_size*[False])
+
+    # stop string length
+    stop_str_length = np.zeros(batch_size, dtype=int)
+    if stop_str and isinstance(stop_str, str):
+        stop_str_length[:] = len(tokenizer(stop_str).input_ids)
+
     for i in range(max_new_tokens):
         if i == 0:  # prefill
             if model.config.is_encoder_decoder:
```
```diff
@@ -136,7 +147,7 @@ def run_inference(
                 )
                 logits = model.lm_head(out[0])
             else:
-                out = model(torch.as_tensor(input_ids, device=device), use_cache=True)
+                out = model(torch.as_tensor(input_ids, device=device), use_cache=True, attention_mask=attention_masks)
                 logits = out.logits
                 past_key_values = out.past_key_values
         else:  # decoding
```
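The prefill forward now passes the mask so padded positions are excluded from attention. A runnable sketch of such a prefill call against a stand-in Hugging Face model (GPT-2 here; the benchmark uses the models listed in `benchmark.yaml`):

```python
# Sketch: prefill forward pass with an explicit attention mask (stand-in model).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tokenizer(["Hi", "A longer prompt"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(enc.input_ids, attention_mask=enc.attention_mask, use_cache=True)

print(out.logits.shape)  # (batch, padded_seq_len, vocab_size)
# out.past_key_values caches keys/values for the decoding steps that follow.
```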
```diff
@@ -157,10 +168,17 @@ def run_inference(
                 ),
                 use_cache=True,
                 past_key_values=past_key_values,
+                attention_mask=attention_masks,
             )
             logits = out.logits
             past_key_values = out.past_key_values
 
+        # update attention mask
+        attention_masks = torch.cat(
+            [attention_masks, torch.ones((batch_size, 1), device=device)],
+            dim=1
+        )
+
         if logits_processor:
             if repetition_penalty > 1.0:
                 tmp_output_ids = torch.as_tensor(output_ids, device=logits.device)
```
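Why the mask grows: every generation step appends exactly one real token to each sequence, so the mask gains a column of ones per step. A standalone illustration with made-up values:

```python
import torch

batch_size = 2
# Prefill mask for two sequences; the first one has a single padding position.
attention_masks = torch.tensor([[0, 1, 1],
                                [1, 1, 1]])

# One generation step later, every sequence has one more real token,
# so a column of ones is appended.
attention_masks = torch.cat(
    [attention_masks, torch.ones((batch_size, 1), dtype=attention_masks.dtype)],
    dim=1,
)
print(attention_masks)
# tensor([[0, 1, 1, 1],
#         [1, 1, 1, 1]])
```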
```diff
@@ -213,14 +231,15 @@ def run_inference(
             for each_stop in stop_str:
                 pos_array = np.char.rfind(output_np, each_stop, rfind_start)
                 find_stop = pos_array != -1
+                # update stop_str_length with each stop_str_length for each request
+                stop_str_length[find_stop] = len(tokenizer(each_stop).input_ids)
         else:
             raise ValueError("Invalid stop field type.")
 
         stop_str_indices = np.where(find_stop & ~stopped)[0]
         if stop_str_indices.size > 0:
             for j in stop_str_indices:
-
-                result[j].response_length = i
+                result[j].response_length = i+1-stop_str_length[j]
                 result[j].output = output[j][:pos_array[j]]
             stopped[find_stop] = True
 
```
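The `response_length` fix subtracts the stop string's own token count: after step `i`, `i+1` tokens have been generated, and the last `stop_str_length[j]` of them belong to the stop string rather than the response. A toy numeric check (values hypothetical):

```python
import numpy as np

i = 9                            # stop string detected at decoding step 9,
                                 # i.e. i + 1 = 10 tokens generated so far
stop_str_length = np.array([2])  # suppose the stop string tokenizes to 2 tokens
response_length = i + 1 - stop_str_length[0]
print(response_length)           # 8 tokens actually belong to the response
```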
```diff
@@ -378,7 +397,7 @@ def main(
 
     for is_warmup, input_prompts in data_iter:
         # Construct the input prompt.
-        for i in range(
+        for i in range(len(input_prompts)):
             conv = copy.deepcopy(conv_base)
             conv.append_message(conv.roles[0], input_prompts[i])
             conv.append_message(conv.roles[1], "")
```
scripts/sort.py (ADDED)
```diff
@@ -0,0 +1,15 @@
+import json
+import tyro
+
+def main(data_dir:str, out_file:str) -> None:
+
+    with open(data_dir, "r") as f:
+        data = json.load(f)
+
+    sorted_data = sorted(data, key=lambda x: len(x['conversations'][0]['value']), reverse=True)
+
+    with open(out_file, "w") as f:
+        json.dump(sorted_data, f, indent=4)
+
+if __name__ == "__main__":
+    tyro.cli(main)
```
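`tyro.cli(main)` generates the script's command-line interface from the function signature, so the `data_dir` and `out_file` parameters surface as the `--data-dir` and `--out-file` flags used in the README invocation below.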
sharegpt/README.md (CHANGED)
````diff
@@ -27,6 +27,7 @@ python -m fastchat.data.sample --in sg_90k_part1_html_cleaned_lang_first.json --
 ```
 
 ## Sorted data
+We sort the requests by sequence length, placing the longest sequences first. This approach minimizes the amount of padding required and allows for early detection of out-of-memory errors.
 ```
 python sort.py --data-dir sg_90k_part1_html_cleaned_lang_first_sampled.json --out-file sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json
 ```
````
sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json (CHANGED)

The diff for this file is too large to render.