Spaces:
Paused
Paused
lengyue233
commited on
Commit
•
75e9ff1
1
Parent(s):
1caffd8
optimize compile by removing if branch
Browse files- app.py +16 -6
- tools/llama/generate.py +42 -45
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
from huggingface_hub import snapshot_download
|
3 |
import hydra
|
4 |
|
@@ -125,17 +126,26 @@ def inference(
|
|
125 |
)
|
126 |
|
127 |
payload = dict(
|
128 |
-
|
129 |
request=request,
|
130 |
)
|
131 |
llama_queue.put(payload)
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
137 |
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
# VQGAN Inference
|
141 |
feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
|
|
|
1 |
import os
|
2 |
+
import queue
|
3 |
from huggingface_hub import snapshot_download
|
4 |
import hydra
|
5 |
|
|
|
126 |
)
|
127 |
|
128 |
payload = dict(
|
129 |
+
response_queue=queue.Queue(),
|
130 |
request=request,
|
131 |
)
|
132 |
llama_queue.put(payload)
|
133 |
|
134 |
+
codes = []
|
135 |
+
while True:
|
136 |
+
result = payload["response_queue"].get()
|
137 |
+
if result == "next":
|
138 |
+
# TODO: handle next sentence
|
139 |
+
continue
|
140 |
|
141 |
+
if result == "done":
|
142 |
+
if payload["success"] is False:
|
143 |
+
raise payload["response"]
|
144 |
+
break
|
145 |
+
|
146 |
+
codes.append(result)
|
147 |
+
|
148 |
+
codes = torch.cat(codes, dim=1)
|
149 |
|
150 |
# VQGAN Inference
|
151 |
feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
|
tools/llama/generate.py
CHANGED
@@ -47,32 +47,32 @@ def logits_to_probs(
|
|
47 |
top_p: Optional[int] = None,
|
48 |
repetition_penalty: float = 1.0,
|
49 |
):
|
50 |
-
if previous_tokens is not None and repetition_penalty != 1.0:
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
if top_p is not None and top_p < 1.0:
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
|
70 |
logits = logits / max(temperature, 1e-5)
|
71 |
|
72 |
-
if top_k is not None:
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
|
77 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
78 |
return probs
|
@@ -470,16 +470,14 @@ def generate_long(
|
|
470 |
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
471 |
|
472 |
if use_prompt:
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
num_codebooks=model.config.num_codebooks,
|
482 |
-
)
|
483 |
)
|
484 |
|
485 |
for idx, text in enumerate(texts):
|
@@ -501,10 +499,6 @@ def generate_long(
|
|
501 |
all_codes = []
|
502 |
seg_idx = 0
|
503 |
|
504 |
-
if use_prompt:
|
505 |
-
seg_idx = 1
|
506 |
-
global_encoded.append(encoded[0])
|
507 |
-
|
508 |
while seg_idx < len(encoded):
|
509 |
logger.info(
|
510 |
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
@@ -531,6 +525,9 @@ def generate_long(
|
|
531 |
else:
|
532 |
partial_encoded = global_encoded
|
533 |
|
|
|
|
|
|
|
534 |
cat_encoded = torch.cat(partial_encoded, dim=1)
|
535 |
prompt_length = cat_encoded.size(1)
|
536 |
|
@@ -593,14 +590,13 @@ def generate_long(
|
|
593 |
|
594 |
if is_streaming:
|
595 |
# This indicates the end of the current sample
|
596 |
-
yield
|
597 |
else:
|
598 |
all_codes = torch.cat(all_codes, dim=1)
|
599 |
assert (all_codes >= 0).all(), f"Negative code found: {codes}"
|
600 |
yield all_codes
|
601 |
|
602 |
|
603 |
-
|
604 |
def launch_thread_safe_queue(
|
605 |
config_name,
|
606 |
checkpoint_path,
|
@@ -624,20 +620,21 @@ def launch_thread_safe_queue(
|
|
624 |
break
|
625 |
|
626 |
kwargs = item["request"]
|
627 |
-
|
628 |
|
629 |
try:
|
630 |
item["success"] = True
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
)
|
635 |
-
|
|
|
636 |
except Exception as e:
|
637 |
item["success"] = False
|
638 |
item["response"] = e
|
639 |
|
640 |
-
|
641 |
|
642 |
threading.Thread(target=worker, daemon=True).start()
|
643 |
init_event.wait()
|
|
|
47 |
top_p: Optional[int] = None,
|
48 |
repetition_penalty: float = 1.0,
|
49 |
):
|
50 |
+
# if previous_tokens is not None and repetition_penalty != 1.0:
|
51 |
+
previous_tokens = previous_tokens.long()
|
52 |
+
score = torch.gather(logits, dim=0, index=previous_tokens)
|
53 |
+
score = torch.where(
|
54 |
+
score < 0, score * repetition_penalty, score / repetition_penalty
|
55 |
+
)
|
56 |
+
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
57 |
|
58 |
+
# if top_p is not None and top_p < 1.0:
|
59 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
60 |
+
cum_probs = torch.cumsum(
|
61 |
+
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
62 |
+
)
|
63 |
+
sorted_indices_to_remove = cum_probs > top_p
|
64 |
+
sorted_indices_to_remove[0] = False # keep at least one option
|
65 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
66 |
+
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
67 |
+
)
|
68 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
69 |
|
70 |
logits = logits / max(temperature, 1e-5)
|
71 |
|
72 |
+
# if top_k is not None:
|
73 |
+
# v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
74 |
+
# pivot = v.select(-1, -1).unsqueeze(-1)
|
75 |
+
# logits = torch.where(logits < pivot, -float("Inf"), logits)
|
76 |
|
77 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
78 |
return probs
|
|
|
470 |
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
471 |
|
472 |
if use_prompt:
|
473 |
+
encoded_prompts = encode_tokens(
|
474 |
+
tokenizer,
|
475 |
+
prompt_text,
|
476 |
+
prompt_tokens=prompt_tokens,
|
477 |
+
bos=True,
|
478 |
+
device=device,
|
479 |
+
speaker=speaker,
|
480 |
+
num_codebooks=model.config.num_codebooks,
|
|
|
|
|
481 |
)
|
482 |
|
483 |
for idx, text in enumerate(texts):
|
|
|
499 |
all_codes = []
|
500 |
seg_idx = 0
|
501 |
|
|
|
|
|
|
|
|
|
502 |
while seg_idx < len(encoded):
|
503 |
logger.info(
|
504 |
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
|
|
525 |
else:
|
526 |
partial_encoded = global_encoded
|
527 |
|
528 |
+
if use_prompt:
|
529 |
+
partial_encoded = [encoded_prompts] + partial_encoded
|
530 |
+
|
531 |
cat_encoded = torch.cat(partial_encoded, dim=1)
|
532 |
prompt_length = cat_encoded.size(1)
|
533 |
|
|
|
590 |
|
591 |
if is_streaming:
|
592 |
# This indicates the end of the current sample
|
593 |
+
yield "next"
|
594 |
else:
|
595 |
all_codes = torch.cat(all_codes, dim=1)
|
596 |
assert (all_codes >= 0).all(), f"Negative code found: {codes}"
|
597 |
yield all_codes
|
598 |
|
599 |
|
|
|
600 |
def launch_thread_safe_queue(
|
601 |
config_name,
|
602 |
checkpoint_path,
|
|
|
620 |
break
|
621 |
|
622 |
kwargs = item["request"]
|
623 |
+
response_queue = item["response_queue"]
|
624 |
|
625 |
try:
|
626 |
item["success"] = True
|
627 |
+
for chunk in generate_long(
|
628 |
+
model=model, decode_one_token=decode_one_token, **kwargs
|
629 |
+
):
|
630 |
+
response_queue.put(chunk)
|
631 |
+
|
632 |
+
response_queue.put("done")
|
633 |
except Exception as e:
|
634 |
item["success"] = False
|
635 |
item["response"] = e
|
636 |
|
637 |
+
response_queue.put("done")
|
638 |
|
639 |
threading.Thread(target=worker, daemon=True).start()
|
640 |
init_event.wait()
|