zyliu commited on
Commit
8b33d6d
1 Parent(s): 966d74c

update app.py

Browse files
Files changed (2) hide show
  1. app.py +8 -2
  2. model_worker.py +1 -2
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import fire
2
  import subprocess
3
  import os
@@ -9,7 +10,12 @@ import atexit
9
  try:
10
  import flash_attn
11
  except ImportError:
12
- os.system("pip install flash-attn==2.5.9.post1")
 
 
 
 
 
13
  import flash_attn
14
 
15
 
@@ -34,7 +40,7 @@ def main(
34
  run_worker=True,
35
  run_gradio=True,
36
  controller_port=10086,
37
- gradio_port=10087,
38
  worker_names=[
39
  "OpenGVLab/InternVL2-8B",
40
  ],
 
1
+ import spaces
2
  import fire
3
  import subprocess
4
  import os
 
10
  try:
11
  import flash_attn
12
  except ImportError:
13
+
14
+ @spaces.GPU
15
+ def install_flash_attn():
16
+ os.system("pip install flash-attn==2.5.9.post1")
17
+
18
+ install_flash_attn()
19
  import flash_attn
20
 
21
 
 
40
  run_worker=True,
41
  run_gradio=True,
42
  controller_port=10086,
43
+ gradio_port=7860,
44
  worker_names=[
45
  "OpenGVLab/InternVL2-8B",
46
  ],
model_worker.py CHANGED
@@ -7,6 +7,7 @@
7
  """
8
  A model worker executes the model.
9
  """
 
10
  import argparse
11
  import asyncio
12
 
@@ -17,7 +18,6 @@ import time
17
  import uuid
18
  import traceback
19
  from functools import partial
20
-
21
  from threading import Thread
22
 
23
  import requests
@@ -36,7 +36,6 @@ from utils import (
36
  server_error_msg,
37
  load_image_from_base64,
38
  )
39
- import spaces
40
 
41
  worker_id = str(uuid.uuid4())[:6]
42
  logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
 
7
  """
8
  A model worker executes the model.
9
  """
10
+ import spaces
11
  import argparse
12
  import asyncio
13
 
 
18
  import uuid
19
  import traceback
20
  from functools import partial
 
21
  from threading import Thread
22
 
23
  import requests
 
36
  server_error_msg,
37
  load_image_from_base64,
38
  )
 
39
 
40
  worker_id = str(uuid.uuid4())[:6]
41
  logger = build_logger("model_worker", f"model_worker_{worker_id}.log")