|
''' |
|
|
|
Contributed by SagsMug. Thank you SagsMug. |
|
https://github.com/oobabooga/text-generation-webui/pull/175 |
|
|
|
''' |
|
|
|
import asyncio |
|
import json |
|
import random |
|
import string |
|
|
|
import websockets |
|
|
|
|
|
def random_hash():
    """Return a random 9-character string of lowercase ASCII letters and digits.

    `run` uses it as the websocket ``session_hash``. Characters are drawn
    from the module-level `random` generator, so the value is not suitable
    as a secret.
    """
    alphabet = string.ascii_lowercase + string.digits
    picks = [random.choice(alphabet) for _ in range(9)]
    return "".join(picks)
|
|
|
async def run(context): |
|
server = "127.0.0.1" |
|
params = { |
|
'max_new_tokens': 200, |
|
'do_sample': True, |
|
'temperature': 0.5, |
|
'top_p': 0.9, |
|
'typical_p': 1, |
|
'repetition_penalty': 1.05, |
|
'top_k': 0, |
|
'min_length': 0, |
|
'no_repeat_ngram_size': 0, |
|
'num_beams': 1, |
|
'penalty_alpha': 0, |
|
'length_penalty': 1, |
|
'early_stopping': False, |
|
} |
|
session = random_hash() |
|
|
|
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: |
|
while content := json.loads(await websocket.recv()): |
|
|
|
match content["msg"]: |
|
case "send_hash": |
|
await websocket.send(json.dumps({ |
|
"session_hash": session, |
|
"fn_index": 7 |
|
})) |
|
case "estimation": |
|
pass |
|
case "send_data": |
|
await websocket.send(json.dumps({ |
|
"session_hash": session, |
|
"fn_index": 7, |
|
"data": [ |
|
context, |
|
params['max_new_tokens'], |
|
params['do_sample'], |
|
params['temperature'], |
|
params['top_p'], |
|
params['typical_p'], |
|
params['repetition_penalty'], |
|
params['top_k'], |
|
params['min_length'], |
|
params['no_repeat_ngram_size'], |
|
params['num_beams'], |
|
params['penalty_alpha'], |
|
params['length_penalty'], |
|
params['early_stopping'], |
|
] |
|
})) |
|
case "process_starts": |
|
pass |
|
case "process_generating" | "process_completed": |
|
yield content["output"]["data"][0] |
|
|
|
|
|
if (content["msg"] == "process_completed"): |
|
break |
|
|
|
# Default prompt streamed by get_result().
prompt = 'What I would like to say is the following: '
|
|
|
async def get_result(context=None):
    """Stream a completion from `run` and print it.

    Every response is printed as it arrives. The last response is then
    printed once more as the final result.

    Args:
        context: Prompt to send. Defaults to the module-level ``prompt``,
            which is looked up at call time.
    """
    if context is None:
        context = prompt
    nothing = object()
    response = nothing
    async for response in run(context):
        # Print intermediate steps.
        print(response)

    # Print the final result, but only if `run` yielded anything; otherwise
    # there is no response to print.
    if response is not nothing:
        print(response)
|
|
|
if __name__ == "__main__":
    # Contact the server only when this file is executed as a script, so
    # importing the module (e.g. to reuse `run`) has no network side effects.
    asyncio.run(get_result())
|
|