Spaces:
Paused
Paused
lengyue233
commited on
Commit
•
1caffd8
1
Parent(s):
a4dfb48
Better init event waiting
Browse files- app.py +0 -3
- tools/llama/generate.py +4 -4
app.py
CHANGED
@@ -306,7 +306,6 @@ if __name__ == "__main__":
|
|
306 |
args.vqgan_config_name = "vqgan_pretrain"
|
307 |
|
308 |
logger.info("Loading Llama model...")
|
309 |
-
init_event = threading.Event()
|
310 |
llama_queue = launch_thread_safe_queue(
|
311 |
config_name=args.llama_config_name,
|
312 |
checkpoint_path=args.llama_checkpoint_path,
|
@@ -314,10 +313,8 @@ if __name__ == "__main__":
|
|
314 |
precision=args.precision,
|
315 |
max_length=args.max_length,
|
316 |
compile=args.compile,
|
317 |
-
init_event=init_event,
|
318 |
)
|
319 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
320 |
-
init_event.wait()
|
321 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
322 |
|
323 |
vqgan_model = load_vqgan_model(
|
|
|
306 |
args.vqgan_config_name = "vqgan_pretrain"
|
307 |
|
308 |
logger.info("Loading Llama model...")
|
|
|
309 |
llama_queue = launch_thread_safe_queue(
|
310 |
config_name=args.llama_config_name,
|
311 |
checkpoint_path=args.llama_checkpoint_path,
|
|
|
313 |
precision=args.precision,
|
314 |
max_length=args.max_length,
|
315 |
compile=args.compile,
|
|
|
316 |
)
|
317 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
|
|
318 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
319 |
|
320 |
vqgan_model = load_vqgan_model(
|
tools/llama/generate.py
CHANGED
@@ -600,6 +600,7 @@ def generate_long(
|
|
600 |
yield all_codes
|
601 |
|
602 |
|
|
|
603 |
def launch_thread_safe_queue(
|
604 |
config_name,
|
605 |
checkpoint_path,
|
@@ -607,17 +608,15 @@ def launch_thread_safe_queue(
|
|
607 |
precision,
|
608 |
max_length,
|
609 |
compile=False,
|
610 |
-
init_event=None,
|
611 |
):
|
612 |
input_queue = queue.Queue()
|
|
|
613 |
|
614 |
def worker():
|
615 |
model, decode_one_token = load_model(
|
616 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
617 |
)
|
618 |
-
|
619 |
-
if init_event is not None:
|
620 |
-
init_event.set()
|
621 |
|
622 |
while True:
|
623 |
item = input_queue.get()
|
@@ -641,6 +640,7 @@ def launch_thread_safe_queue(
|
|
641 |
event.set()
|
642 |
|
643 |
threading.Thread(target=worker, daemon=True).start()
|
|
|
644 |
|
645 |
return input_queue
|
646 |
|
|
|
600 |
yield all_codes
|
601 |
|
602 |
|
603 |
+
|
604 |
def launch_thread_safe_queue(
|
605 |
config_name,
|
606 |
checkpoint_path,
|
|
|
608 |
precision,
|
609 |
max_length,
|
610 |
compile=False,
|
|
|
611 |
):
|
612 |
input_queue = queue.Queue()
|
613 |
+
init_event = threading.Event()
|
614 |
|
615 |
def worker():
|
616 |
model, decode_one_token = load_model(
|
617 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
618 |
)
|
619 |
+
init_event.set()
|
|
|
|
|
620 |
|
621 |
while True:
|
622 |
item = input_queue.get()
|
|
|
640 |
event.set()
|
641 |
|
642 |
threading.Thread(target=worker, daemon=True).start()
|
643 |
+
init_event.wait()
|
644 |
|
645 |
return input_queue
|
646 |
|