"""Call API providers."""

from json import loads
import os
import random
import time

from fastchat.utils import build_logger
from fastchat.constants import WORKER_API_TIMEOUT

logger = build_logger("gradio_web_server", "gradio_web_server.log")
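
# Each *_stream_iter function below is a generator that yields dicts of the form
# {"text": <cumulative text so far>, "error_code": 0}; palm_api_stream_iter also
# yields "error_code": 1 with an error message if the request fails.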


def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    import openai

    # Azure deployments are selected by a substring match on the model name.
    is_azure = False
    if "azure" in model_name:
        is_azure = True
        openai.api_type = "azure"
        openai.api_version = "2023-07-01-preview"
    else:
        openai.api_type = "open_ai"
        openai.api_version = None

    openai.api_base = api_base or "https://api.openai.com/v1"
    openai.api_key = api_key or os.environ["OPENAI_API_KEY"]
    if model_name == "gpt-4-turbo":
        model_name = "gpt-4-1106-preview"

    # Log the request parameters.
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    if is_azure:
        # Azure OpenAI takes the deployment name via `engine`.
        res = openai.ChatCompletion.create(
            engine=model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            stream=True,
        )
    else:
        res = openai.ChatCompletion.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            stream=True,
        )

    # Accumulate streamed deltas and yield the cumulative text after each chunk.
    text = ""
    for chunk in res:
        if len(chunk["choices"]) > 0:
            text += chunk["choices"][0]["delta"].get("content", "")
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
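
# Example usage (sketch; assumes OPENAI_API_KEY is set and "gpt-3.5-turbo" is an
# available model name):
#   messages = [{"role": "user", "content": "Hello"}]
#   for chunk in openai_api_stream_iter("gpt-3.5-turbo", messages, 0.7, 1.0, 512):
#       text = chunk["text"]  # cumulative response so far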


def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    import anthropic

    c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Log the request parameters.
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # Stream the completion, stopping at the next human turn marker.
    res = c.completions.create(
        prompt=prompt,
        stop_sequences=[anthropic.HUMAN_PROMPT],
        max_tokens_to_sample=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )
    text = ""
    for chunk in res:
        text += chunk.completion
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data
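
# Example usage (sketch; assumes ANTHROPIC_API_KEY is set and "claude-2" is an
# available completions model; the prompt must already carry the HUMAN/AI turn markers):
#   prompt = f"{anthropic.HUMAN_PROMPT} Hello{anthropic.AI_PROMPT}"
#   for chunk in anthropic_api_stream_iter("claude-2", prompt, 0.7, 1.0, 512):
#       text = chunk["text"]  # cumulative completion so far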


def init_palm_chat(model_name):
    import vertexai
    from vertexai.preview.language_models import ChatModel
    from vertexai.preview.generative_models import GenerativeModel

    project_id = os.environ["GCP_PROJECT_ID"]
    location = "us-central1"
    vertexai.init(project=project_id, location=location)

    if model_name in ["palm-2"]:
        # On Vertex AI, PaLM 2 chat is served as chat-bison@001.
        model_name = "chat-bison@001"
        chat_model = ChatModel.from_pretrained(model_name)
        chat = chat_model.start_chat(examples=[])
    elif model_name in ["gemini-pro"]:
        model = GenerativeModel(model_name)
        chat = model.start_chat()
    return chat
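
# Example usage (sketch; assumes GCP_PROJECT_ID is set and Vertex AI is enabled
# for that project):
#   chat = init_palm_chat("gemini-pro")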


def palm_api_stream_iter(model_name, chat, message, temperature, top_p, max_new_tokens):
    if model_name in ["gemini-pro"]:
        # Give Gemini a larger output budget than the other models.
        max_new_tokens = max_new_tokens * 2
    parameters = {
        "temperature": temperature,
        "top_p": top_p,
        "max_output_tokens": max_new_tokens,
    }
    gen_params = {
        "model": model_name,
        "prompt": message,
    }
    gen_params.update(parameters)

    # PaLM 2 chat is called synchronously; Gemini is streamed.
    if model_name == "palm-2":
        response = chat.send_message(message, **parameters)
    else:
        response = chat.send_message(message, generation_config=parameters, stream=True)

    logger.info(f"==== request ====\n{gen_params}")

    try:
        text = ""
        for chunk in response:
            text += chunk.text
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}\nPlease try again or increase the number of max tokens.",
            "error_code": 1,
        }
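
# Example usage (sketch; continues the init_palm_chat example above):
#   for chunk in palm_api_stream_iter("gemini-pro", chat, "Hello", 0.7, 1.0, 512):
#       text = chunk["text"]  # cumulative response, or an error message on failure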


def ai2_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    from requests import post

    # Fall back to the environment / default endpoint when not provided explicitly.
    ai2_key = api_key or os.environ.get("AI2_API_KEY")
    api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
    model_id = "mod_01hhgcga70c91402r9ssyxekan"

    # Log the request parameters.
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # Greedy decoding (temperature 0.0) requires top_p to be exactly 1.
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    res = post(
        api_base,
        stream=True,
        headers={"Authorization": f"Bearer {ai2_key}"},
        json={
            "model_id": model_id,
            "input": {
                "messages": messages,
                "opts": {
                    "max_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "logprobs": 1,
                },
            },
        },
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

    # The response is streamed as newline-delimited JSON; accumulate the text pieces
    # and yield the cumulative text after each line.
    text = ""
    for line in res.iter_lines():
        if line:
            part = loads(line)
            if "result" in part and "output" in part["result"]:
                for t in part["result"]["output"]["text"]:
                    text += t
            else:
                logger.error(f"unexpected part: {part}")
                raise ValueError("empty result in InferD response")

            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
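
# Example usage (sketch; assumes AI2_API_KEY is set and the default InferD endpoint
# is reachable; model_name is only used for logging, since the request is routed by
# the hard-coded model_id above, so "tulu-2" below is purely illustrative):
#   messages = [{"role": "user", "content": "Hello"}]
#   for chunk in ai2_api_stream_iter("tulu-2", messages, 0.7, 1.0, 512):
#       text = chunk["text"]  # cumulative response so far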