Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,265 Bytes
84a6c36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from huggingface_hub import snapshot_download
# Hugging Face access token taken from the environment (None if the variable is unset).
use_auth_token = os.getenv("YOUR_AUTH_TOKEN")
repo_id = "m-a-p/qwen2.5-7b-ins-v3"
# Download into a local directory named after the repo (the part after the owner).
local_dir = repo_id.rsplit("/", 1)[-1]
# `token` is the current name of the auth argument; `use_auth_token` is deprecated.
# NOTE(review): recent huggingface_hub versions also deprecate `resume_download`
# (downloads resume by default); kept here for compatibility with older versions.
snapshot_download(
    repo_id=repo_id,
    local_dir=local_dir,
    token=use_auth_token,
    resume_download=True,
)
# Derive the checkpoint path from local_dir so it cannot drift out of sync with repo_id.
model_path = os.path.join(local_dir, "checkpoint-1000")
# Tokenizer is used only to render chat templates into prompt strings.
tokenizer = AutoTokenizer.from_pretrained(model_path)
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.8,
    repetition_penalty=1.05,
    max_tokens=8192,
)
llm = LLM(model=model_path)
def api_call_batch(batch_messages):
    """Generate one completion per conversation using the module-level vLLM engine.

    Args:
        batch_messages: A sequence of conversations. Each conversation is a list
            of chat message dicts (e.g. {"role": "user", "content": ...}) that
            the tokenizer's chat template accepts.

    Returns:
        A list of generated strings, one per input conversation, each taken
        from the first candidate of the corresponding request output.
    """
    if not batch_messages:
        # Nothing to generate; do not hand an empty prompt list to the engine.
        return []
    # Render each conversation to a prompt string. With tokenize=False the
    # template returns text, so tensor options such as return_tensors are moot.
    prompts = [
        tokenizer.apply_chat_template(
            conversation=messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        for messages in batch_messages
    ]
    outputs = llm.generate(prompts, sampling_params)
    return [output.outputs[0].text for output in outputs]
def api_call(messages):
    """Return the generated completion for a single conversation.

    Convenience wrapper: puts `messages` into a one-element batch, runs
    api_call_batch on it, and unwraps the sole result.
    """
    single_batch = [messages]
    completions = api_call_batch(single_batch)
    return completions[0]
def call_gpt(history, prompt):
    """Add `prompt` as a new user turn after `history` and return the model's reply.

    `history` itself is left unmodified; a new message list is built for the
    request. Despite the name, the reply comes from the local model via api_call.
    """
    user_turn = {"role": "user", "content": prompt}
    conversation = history + [user_turn]
    return api_call(conversation)
if __name__ == "__main__":
    # Smoke test: send the same single-turn prompt ("Who are you?") four times
    # as one batch and print the list of completions.
    messages = [{"role": "user", "content": "你是谁?"}]
    print(api_call_batch([messages] * 4))
|