"use server"
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
import type { ChatCompletionMessageParam } from "openai/resources/chat"
import { LLMEngine } from "@/types"
import OpenAI from "openai"
const hf = new HfInference(process.env.HF_API_TOKEN)
// note: the LLM backend is selected once, at module load time, via the LLM_ENGINE env var (see the switch below)
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
const inferenceEndpoint = `${process.env.HF_INFERENCE_ENDPOINT_URL || ""}`
const inferenceModel = `${process.env.HF_INFERENCE_API_MODEL || ""}`
const openaiApiKey = `${process.env.OPENAI_API_KEY || ""}`
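// illustrative .env configuration for the OpenAI path (all values below are placeholders,
// only the variable names come from this file):
//
//   LLM_ENGINE="OPENAI"
//   OPENAI_API_KEY="sk-..."
//   OPENAI_API_BASE_URL="https://api.openai.com/v1"  (optional)
//   OPENAI_API_MODEL="gpt-3.5-turbo"                 (optional)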
let hfie: HfInferenceEndpoint
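// the three supported values of LLM_ENGINE map to the cases below:
// - "INFERENCE_ENDPOINT": a dedicated HF Inference Endpoint (HF_INFERENCE_ENDPOINT_URL)
// - "INFERENCE_API": the hosted HF Inference API (HF_INFERENCE_API_MODEL)
// - "OPENAI": an OpenAI-compatible chat completions API (OPENAI_API_KEY)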
switch (llmEngine) {
  case "INFERENCE_ENDPOINT":
    if (inferenceEndpoint) {
      console.log("Using a custom HF Inference Endpoint")
      hfie = hf.endpoint(inferenceEndpoint)
    } else {
      const error = "No Inference Endpoint URL defined"
      console.error(error)
      throw new Error(error)
    }
    break;
  case "INFERENCE_API":
    if (inferenceModel) {
      console.log("Using an HF Inference API Model")
    } else {
      const error = "No Inference API model defined"
      console.error(error)
      throw new Error(error)
    }
    break;
  case "OPENAI":
    if (openaiApiKey) {
      console.log("Using an OpenAI API Key")
    } else {
      const error = "No OpenAI API key defined"
      console.error(error)
      throw new Error(error)
    }
    break;
  default: {
    const error = "No valid LLM engine defined (expected INFERENCE_ENDPOINT, INFERENCE_API or OPENAI)"
    console.error(error)
    throw new Error(error)
  }
}
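
/**
 * Generate text for the given prompt using the engine configured above.
 * The Hugging Face engines stream tokens and stop early once a known
 * special / chat-template token shows up; the OpenAI engine delegates
 * to predictWithOpenAI() below.
 */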
export async function predict(inputs: string) {
  console.log(`predict: `, inputs)

  if (llmEngine === "OPENAI") {
    return predictWithOpenAI(inputs)
  }

  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  let instructions = ""
  try {
    for await (const output of api.textGenerationStream({
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs,
      parameters: {
        do_sample: true,
        // we don't require a lot of tokens for our task
        // but to be safe, let's count ~110 tokens per panel
        max_new_tokens: 450, // 1150,
        return_full_text: false,
      }
    })) {
      instructions += output.token.text
      process.stdout.write(output.token.text)
      // stop as soon as the model starts emitting special / chat-template tokens
      if (
        instructions.includes("</s>") ||
        instructions.includes("<s>") ||
        instructions.includes("[INST]") ||
        instructions.includes("[/INST]") ||
        instructions.includes("<SYS>") ||
        instructions.includes("</SYS>") ||
        instructions.includes("<|end|>") ||
        instructions.includes("<|assistant|>")
      ) {
        break
      }
    }
  } catch (err) {
    console.error(`error during generation: ${err}`)
  }

  // clean up any special tokens the LLM might have given us
  return (
    instructions
      .replaceAll("<|end|>", "")
      .replaceAll("<s>", "")
      .replaceAll("</s>", "")
      .replaceAll("[INST]", "")
      .replaceAll("[/INST]", "")
      .replaceAll("<SYS>", "")
      .replaceAll("</SYS>", "")
      .replaceAll("<|assistant|>", "")
      .replaceAll('""', '"')
  )
}
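
/**
 * Non-streaming chat completion against an OpenAI-compatible API.
 * OPENAI_API_BASE_URL and OPENAI_API_MODEL are optional overrides
 * (they default to https://api.openai.com/v1 and gpt-3.5-turbo).
 */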
async function predictWithOpenAI(inputs: string) {
  const openaiApiBaseUrl = `${process.env.OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
  const openaiApiModel = `${process.env.OPENAI_API_MODEL || "gpt-3.5-turbo"}`

  const openai = new OpenAI({
    apiKey: openaiApiKey,
    baseURL: openaiApiBaseUrl,
  })

  const messages: ChatCompletionMessageParam[] = [
    { role: "system", content: inputs },
  ]

  try {
    const res = await openai.chat.completions.create({
      messages: messages,
      stream: false,
      model: openaiApiModel,
      temperature: 0.8
    })

    return res.choices[0].message.content || ""
  } catch (err) {
    console.error(`error during generation: ${err}`)
    return ""
  }
}
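
// Illustrative usage from another server-side module (the prompt shown here is
// just an assumption, not part of this file):
//
//   const instructions = await predict("Write 4 comic book panels about a cat astronaut.")
//   console.log(instructions)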