"use server"

import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"

import type { ChatCompletionMessageParam } from "openai/resources/chat"
import { LLMEngine } from "@/types"
import OpenAI from "openai"

const hf = new HfInference(process.env.HF_API_TOKEN)


// note: the engine is selected once, at module load time, through the LLM_ENGINE environment variable
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
const inferenceEndpoint = `${process.env.HF_INFERENCE_ENDPOINT_URL || ""}`
const inferenceModel = `${process.env.HF_INFERENCE_API_MODEL || ""}`
const openaiApiKey = `${process.env.OPENAI_API_KEY || ""}`
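
// for reference, a typical environment configuration might look like this
// (illustrative values only; the variable names match the reads above):
//
//   LLM_ENGINE="INFERENCE_API"
//   HF_API_TOKEN="hf_..."
//   HF_INFERENCE_API_MODEL="HuggingFaceH4/zephyr-7b-beta"
//   HF_INFERENCE_ENDPOINT_URL="https://<endpoint>.endpoints.huggingface.cloud"
//   OPENAI_API_KEY="sk-..."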

let hfie: HfInferenceEndpoint

switch (llmEngine) {
  case "INFERENCE_ENDPOINT":
    if (inferenceEndpoint) {
      console.log("Using a custom HF Inference Endpoint")
      hfie = hf.endpoint(inferenceEndpoint)
    } else {
      const error = "No Inference Endpoint URL defined"
      console.error(error)
      throw new Error(error)
    }
    break;
  
  case "INFERENCE_API":
    if (inferenceModel) {
      console.log("Using an HF Inference API Model")
    } else {
      const error = "No Inference API model defined"
      console.error(error)
      throw new Error(error)
    }
    break;

  case "OPENAI":
    if (openaiApiKey) {
      console.log("Using an OpenAI API Key")
    } else {
      const error = "No OpenAI API key defined"
      console.error(error)
      throw new Error(error)
    }
    break;
  
  default: {
    // reached when LLM_ENGINE is missing or has an unsupported value
    const error = `Unsupported LLM engine "${llmEngine}" (expected INFERENCE_ENDPOINT, INFERENCE_API or OPENAI)`
    console.error(error)
    throw new Error(error)
  }
}
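
// the LLMEngine type imported from "@/types" is assumed to be a string union
// along these lines (hypothetical sketch, not the project's actual definition):
//
//   export type LLMEngine = "INFERENCE_API" | "INFERENCE_ENDPOINT" | "OPENAI"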

export async function predict(inputs: string) {

  console.log(`predict: `, inputs)
  
  if (llmEngine === "OPENAI") {
    return predictWithOpenAI(inputs)
  }

  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  // special tokens that mark the end of the useful output for the models we target
  const stopTokens = [
    "</s>",
    "<s>",
    "[INST]",
    "[/INST]",
    "<SYS>",
    "</SYS>",
    "<|end|>",
    "<|assistant|>",
  ]

  let instructions = ""
  try {
    for await (const output of api.textGenerationStream({
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs,
      parameters: {
        do_sample: true,
        // we don't require a lot of tokens for our task,
        // but to be safe, let's count ~110 tokens per panel
        max_new_tokens: 450, // 1150,
        return_full_text: false,
      }
    })) {
      instructions += output.token.text
      process.stdout.write(output.token.text)
      // stop streaming as soon as the model starts emitting special tokens
      if (stopTokens.some(token => instructions.includes(token))) {
        break
      }
    }
  } catch (err) {
    console.error(`error during generation: ${err}`)
  }

  // clean up any special tokens the LLM may have left in the output
  return stopTokens
    .reduce((acc, token) => acc.replaceAll(token, ""), instructions)
    .replaceAll('""', '"')
}
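
// hypothetical usage from another server module (sketch only; the real callers
// elsewhere in the app build their own model-specific prompts):
//
//   import { predict } from "./predict"
//   const raw = await predict("<s>[INST] Write a short comic scenario [/INST]")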

async function predictWithOpenAI(inputs: string) {
  const openaiApiBaseUrl = `${process.env.OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
  const openaiApiModel = `${process.env.OPENAI_API_MODEL || "gpt-3.5-turbo"}`
  
  const openai = new OpenAI({
    apiKey: openaiApiKey,
    baseURL: openaiApiBaseUrl,
  })

  const messages: ChatCompletionMessageParam[] = [
    { role: "system", content: inputs },
  ]

  try {
    const res = await openai.chat.completions.create({
      messages: messages,
      stream: false,
      model: openaiApiModel,
      temperature: 0.8
    })

    // the content can be null in edge cases, so fall back to an empty string
    return res.choices[0]?.message?.content || ""
  } catch (err) {
    console.error(`error during generation: ${err}`)
    return ""
  }
}
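
// note: a streaming variant of the OpenAI call would look roughly like this
// (sketch only, not used above, which waits for the complete response):
//
//   const stream = await openai.chat.completions.create({
//     messages,
//     stream: true,
//     model: openaiApiModel,
//     temperature: 0.8,
//   })
//   for await (const chunk of stream) {
//     process.stdout.write(chunk.choices[0]?.delta?.content || "")
//   }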