|
"use server" |
|
|
|
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"

import { LLMEngine } from "@/types"

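// Server action: streams a text completion for `inputs` from the configured
// Hugging Face backend (Inference API or a dedicated Inference Endpoint) and
// returns the generated text with chat-template / special tokens stripped out.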
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)

  // backend selection and its settings all come from environment variables
  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
  const inferenceEndpoint = `${process.env.LLM_HF_INFERENCE_ENDPOINT_URL || ""}`
  const inferenceModel = `${process.env.LLM_HF_INFERENCE_API_MODEL || ""}`

  // default to the shared Inference API client; swapped for a dedicated endpoint client below
  let hfie: HfInferenceEndpoint = hf

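  // configure the selected backend, failing fast if it is misconfigured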
  switch (llmEngine) {
    case "INFERENCE_ENDPOINT":
      if (inferenceEndpoint) {
        hfie = hf.endpoint(inferenceEndpoint)
      } else {
        const error = "No Inference Endpoint URL defined"
        console.error(error)
        throw new Error(error)
      }
      break

    case "INFERENCE_API":
      // nothing to configure here: the model name is passed to textGenerationStream below
      if (!inferenceModel) {
        const error = "No Inference API model defined"
        console.error(error)
        throw new Error(error)
      }
      break

    default: {
      const error = "Please check your Hugging Face Inference API or Inference Endpoint settings"
      console.error(error)
      throw new Error(error)
    }
  }

  // use the dedicated endpoint client when one is configured, otherwise the shared client
  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  let instructions = ""

  try {
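    // stream tokens one by one; with a dedicated endpoint the model is implied
    // by the endpoint URL, so `model` is only set for the Inference API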
    for await (const output of api.textGenerationStream({
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs,
      parameters: {
        do_sample: true,
        max_new_tokens: nbMaxNewTokens,
        return_full_text: false,
      }
    })) {
      instructions += output.token.text

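      // stop streaming as soon as a chat-template or special token shows up in the output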
      if (
        instructions.includes("</s>") ||
        instructions.includes("<s>") ||
        instructions.includes("/s>") ||
        instructions.includes("[INST]") ||
        instructions.includes("[/INST]") ||
        instructions.includes("<SYS>") ||
        instructions.includes("<<SYS>>") ||
        instructions.includes("</SYS>") ||
        instructions.includes("<</SYS>>") ||
        instructions.includes("<|user|>") ||
        instructions.includes("<|end|>") ||
        instructions.includes("<|system|>") ||
        instructions.includes("<|assistant|>")
      ) {
        break
      }
    }
  } catch (err) {
    // if the model is overloaded, drop the partial output and return an empty string;
    // any other error is swallowed and whatever text was already streamed is returned
    if (`${err}` === "Error: Model is overloaded") {
      instructions = ``
    }
  }

  // scrub any remaining chat-template / special tokens from the generated text
  // (longer tokens are removed before their substrings so no fragments are left behind)
  return (
    instructions
      .replaceAll("<|end|>", "")
      .replaceAll("<s>", "")
      .replaceAll("</s>", "")
      .replaceAll("/s>", "")
      .replaceAll("[INST]", "")
      .replaceAll("[/INST]", "")
      .replaceAll("<<SYS>>", "")
      .replaceAll("<SYS>", "")
      .replaceAll("<</SYS>>", "")
      .replaceAll("</SYS>", "")
      .replaceAll("<|system|>", "")
      .replaceAll("<|user|>", "")
      .replaceAll("<|all|>", "")
      .replaceAll("<|assistant|>", "")
      .replaceAll('""', '"')
  )
}