lengyue233 commited on
Commit
a4dfb48
1 Parent(s): f7a538e

wait for init

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. tools/llama/generate.py +4 -0
app.py CHANGED
@@ -306,7 +306,7 @@ if __name__ == "__main__":
306
  args.vqgan_config_name = "vqgan_pretrain"
307
 
308
  logger.info("Loading Llama model...")
309
- hydra.core.global_hydra.GlobalHydra.instance().clear()
310
  llama_queue = launch_thread_safe_queue(
311
  config_name=args.llama_config_name,
312
  checkpoint_path=args.llama_checkpoint_path,
@@ -314,11 +314,12 @@ if __name__ == "__main__":
314
  precision=args.precision,
315
  max_length=args.max_length,
316
  compile=args.compile,
 
317
  )
318
  llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
 
319
  logger.info("Llama model loaded, loading VQ-GAN model...")
320
 
321
- hydra.core.global_hydra.GlobalHydra.instance().clear()
322
  vqgan_model = load_vqgan_model(
323
  config_name=args.vqgan_config_name,
324
  checkpoint_path=args.vqgan_checkpoint_path,
 
306
  args.vqgan_config_name = "vqgan_pretrain"
307
 
308
  logger.info("Loading Llama model...")
309
+ init_event = threading.Event()
310
  llama_queue = launch_thread_safe_queue(
311
  config_name=args.llama_config_name,
312
  checkpoint_path=args.llama_checkpoint_path,
 
314
  precision=args.precision,
315
  max_length=args.max_length,
316
  compile=args.compile,
317
+ init_event=init_event,
318
  )
319
  llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
320
+ init_event.wait()
321
  logger.info("Llama model loaded, loading VQ-GAN model...")
322
 
 
323
  vqgan_model = load_vqgan_model(
324
  config_name=args.vqgan_config_name,
325
  checkpoint_path=args.vqgan_checkpoint_path,
tools/llama/generate.py CHANGED
@@ -607,6 +607,7 @@ def launch_thread_safe_queue(
607
  precision,
608
  max_length,
609
  compile=False,
 
610
  ):
611
  input_queue = queue.Queue()
612
 
@@ -615,6 +616,9 @@ def launch_thread_safe_queue(
615
  config_name, checkpoint_path, device, precision, max_length, compile=compile
616
  )
617
 
 
 
 
618
  while True:
619
  item = input_queue.get()
620
  if item is None:
 
607
  precision,
608
  max_length,
609
  compile=False,
610
+ init_event=None,
611
  ):
612
  input_queue = queue.Queue()
613
 
 
616
  config_name, checkpoint_path, device, precision, max_length, compile=compile
617
  )
618
 
619
+ if init_event is not None:
620
+ init_event.set()
621
+
622
  while True:
623
  item = input_queue.get()
624
  if item is None: