update model_worker.py
model_worker.py CHANGED (+2 -3)
@@ -228,7 +228,7 @@ class ModelWorker:
         )
         self.heart_beat_thread.start()
 
-    @spaces.GPU
+    @spaces.GPU(duration=120)
     def import_flash_attn(self):
         try:
             import flash_attn
@@ -325,10 +325,8 @@ class ModelWorker:
             "queue_length": self.get_queue_length(),
         }
 
-    @spaces.GPU
     @torch.inference_mode()
     def generate_stream(self, params):
-
         system_message = params["prompt"][0]["content"]
         send_messages = params["prompt"][1:]
         max_input_tiles = params["max_input_tiles"]
@@ -455,6 +453,7 @@ class ModelWorker:
         )
         self.model.system_message = old_system_message
 
+    @spaces.GPU(duration=120)
     def generate_stream_gate(self, params):
         try:
             for x in self.generate_stream(params):
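
In short, this commit retunes the ZeroGPU decorators: import_flash_attn keeps its @spaces.GPU decorator but gains an explicit duration=120 budget, and the GPU request moves from the inner generate_stream generator to the outer generate_stream_gate wrapper, so the whole guarded streaming loop (including its error handling) runs inside a single 120-second GPU allocation.

A minimal sketch of the resulting pattern, assuming the Hugging Face `spaces` package on a ZeroGPU Space; the method bodies below are hypothetical placeholders, not the actual model_worker.py code:

import spaces
import torch

class ModelWorker:
    @spaces.GPU(duration=120)          # hold a ZeroGPU slot for up to 120 s
    def import_flash_attn(self):
        try:
            import flash_attn          # CUDA extension; needs an attached GPU
            print("FlashAttention is available")
        except ImportError:
            print("FlashAttention is not available")

    @torch.inference_mode()            # no autograd bookkeeping while decoding
    def generate_stream(self, params):
        # placeholder: a real worker would stream decoded tokens here
        yield {"text": str(params), "error_code": 0}

    @spaces.GPU(duration=120)          # the gate now owns the GPU allocation
    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except Exception as e:
            yield {"error": str(e), "error_code": 1}

Decorating the outer gate rather than the inner generator means exceptions raised mid-stream are still caught while the GPU is held, and only one allocation is requested per request instead of one per decorated layer.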