"use server"
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
import type { ChatCompletionMessage } from "openai/resources/chat"
import { LLMEngine } from "@/types"
import OpenAI from "openai"
const hf = new HfInference(process.env.HF_API_TOKEN)
// note: the engine is selected at startup through the LLM_ENGINE environment variable
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
const inferenceEndpoint = `${process.env.HF_INFERENCE_ENDPOINT_URL || ""}`
const inferenceModel = `${process.env.HF_INFERENCE_API_MODEL || ""}`
const openaiApiKey = `${process.env.OPENAI_API_KEY || ""}`

let hfie: HfInferenceEndpoint
switch (llmEngine) {
  case "INFERENCE_ENDPOINT":
    if (inferenceEndpoint) {
      console.log("Using a custom HF Inference Endpoint")
      hfie = hf.endpoint(inferenceEndpoint)
    } else {
      const error = "No Inference Endpoint URL defined"
      console.error(error)
      throw new Error(error)
    }
    break

  case "INFERENCE_API":
    if (inferenceModel) {
      console.log("Using an HF Inference API Model")
    } else {
      const error = "No Inference API model defined"
      console.error(error)
      throw new Error(error)
    }
    break

  case "OPENAI":
    if (openaiApiKey) {
      console.log("Using an OpenAI API Key")
    } else {
      const error = "No OpenAI API key defined"
      console.error(error)
      throw new Error(error)
    }
    break

  default: {
    const error = "No Inference Endpoint URL, Inference API model, or OpenAI API key defined"
    console.error(error)
    throw new Error(error)
  }
}
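
// Example environment configuration (a minimal sketch: the variable names are the ones
// read above, but every value below is an illustrative placeholder, not a recommendation):
//
//   LLM_ENGINE="INFERENCE_API"
//   HF_API_TOKEN="hf_xxxxxxxxxxxxxxxx"
//   HF_INFERENCE_API_MODEL="HuggingFaceH4/zephyr-7b-beta"
//
//   # or, for a dedicated endpoint:
//   # LLM_ENGINE="INFERENCE_ENDPOINT"
//   # HF_INFERENCE_ENDPOINT_URL="https://xxxxx.endpoints.huggingface.cloud"
//
//   # or, for OpenAI:
//   # LLM_ENGINE="OPENAI"
//   # OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxx"
//   # OPENAI_API_MODEL="gpt-3.5-turbo"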
export async function predict(inputs: string) {
  console.log(`predict: `, inputs)

  if (llmEngine === "OPENAI") {
    return predictWithOpenAI(inputs)
  }

  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  let instructions = ""
  try {
    for await (const output of api.textGenerationStream({
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs,
      parameters: {
        do_sample: true,
        // we don't require a lot of tokens for our task,
        // but to be safe, let's count ~110 tokens per panel
        max_new_tokens: 450, // 1150,
        return_full_text: false,
      }
    })) {
      instructions += output.token.text
      process.stdout.write(output.token.text)
      // stop as soon as the model starts emitting chat-template or special tokens
      if (
        instructions.includes("</s>") ||
        instructions.includes("<s>") ||
        instructions.includes("[INST]") ||
        instructions.includes("[/INST]") ||
        instructions.includes("<SYS>") ||
        instructions.includes("</SYS>") ||
        instructions.includes("<|end|>") ||
        instructions.includes("<|assistant|>")
      ) {
        break
      }
    }
  } catch (err) {
    console.error(`error during generation: ${err}`)
  }

  // clean up any special tokens the LLM might have leaked into the output
  return (
    instructions
      .replaceAll("<|end|>", "")
      .replaceAll("<s>", "")
      .replaceAll("</s>", "")
      .replaceAll("[INST]", "")
      .replaceAll("[/INST]", "")
      .replaceAll("<SYS>", "")
      .replaceAll("</SYS>", "")
      .replaceAll("<|assistant|>", "")
      .replaceAll('""', '"')
  )
}
async function predictWithOpenAI(inputs: string) {
  const openaiApiBaseUrl = `${process.env.OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
  const openaiApiModel = `${process.env.OPENAI_API_MODEL || "gpt-3.5-turbo"}`

  const openai = new OpenAI({
    apiKey: openaiApiKey,
    baseURL: openaiApiBaseUrl,
  })

  const messages: ChatCompletionMessageParam[] = [
    { role: "system", content: inputs },
  ]

  try {
    const res = await openai.chat.completions.create({
      messages: messages,
      stream: false,
      model: openaiApiModel,
      temperature: 0.8
    })

    // the content can be null, so fall back to an empty string for a consistent return type
    return res.choices[0].message.content || ""
  } catch (err) {
    console.error(`error during generation: ${err}`)
    return ""
  }
}
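
// Example usage (a minimal sketch: assumes this module is imported as a server action
// from a Next.js component or route handler, and the prompt is purely illustrative):
//
//   import { predict } from "./predict"
//
//   const panels = await predict("You are a comic book author. Write 4 panel captions about a robot chef.")
//   console.log(panels)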