File size: 3,723 Bytes
9db8ced cb29148 9db8ced a1afcb6 9db8ced cb29148 9db8ced cb29148 9db8ced |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import { buildPrompt } from "$lib/buildPrompt";
import { OPENAI_API_KEY } from "$env/static/private";
import type { Endpoint } from "../endpoints";
import { format } from "date-fns";
/**
 * Configuration schema for an OpenAI (or OpenAI-compatible) endpoint.
 *
 * - `weight`: positive integer, default 1 (not read by `endpointOai` itself).
 * - `model`: the model configuration object; accepted without validation (`z.any()`).
 * - `type`: must be the literal "openai".
 * - `baseURL`: API root URL, default "https://api.openai.com/v1".
 * - `apiKey`: defaults to OPENAI_API_KEY, or to the placeholder "sk-" when that is empty.
 * - `completion`: which API to call — "completions" or "chat_completions" (default).
 */
export const endpointOAIParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("openai"),
	baseURL: z.string().url().default("https://api.openai.com/v1"),
	// `||` rather than `??`: `$env/static/private` exposes strings, so a blank OPENAI_API_KEY
	// is "" (not undefined) and `??` would never fall back to the "sk-" placeholder.
	apiKey: z.string().default(OPENAI_API_KEY || "sk-"),
	completion: z
		.union([z.literal("completions"), z.literal("chat_completions")])
		.default("chat_completions"),
});
/**
 * Build a text-generation endpoint backed by the OpenAI SDK (or any OpenAI-compatible
 * server reachable at `baseURL`).
 *
 * `input` is validated with `endpointOAIParametersSchema`. `completion` selects the API:
 * - "completions": the conversation is rendered into a single prompt with `buildPrompt`.
 * - "chat_completions": the conversation is sent as role/content messages, with the
 *   preprompt (if any) prepended as a system message and web-search results (if any)
 *   folded into the final user turn.
 *
 * @param input - Raw endpoint configuration; missing fields take the schema defaults.
 * @returns An endpoint function that issues a streaming request for a conversation and
 *   adapts the response with the matching stream converter.
 * @throws Error if the `openai` package cannot be imported, or if `completion` is not a
 *   known mode; `parse` throws if `input` does not match the schema.
 */
export async function endpointOai(
	input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
	const { baseURL, apiKey, completion, model } = endpointOAIParametersSchema.parse(input);

	let OpenAI;
	try {
		// Dynamic import — presumably so `openai` is only loaded when this endpoint type is used.
		OpenAI = (await import("openai")).OpenAI;
	} catch (e) {
		throw new Error("Failed to import OpenAI", { cause: e });
	}

	// `apiKey` is always a string here: the schema supplies a default.
	const openai = new OpenAI({ apiKey, baseURL });

	// Sampling options shared by both APIs; read from the model config on every request.
	// NOTE(review): `repetition_penalty` is forwarded as OpenAI's `frequency_penalty`;
	// confirm the two values are on compatible scales.
	const samplingParams = () => ({
		max_tokens: model.parameters?.max_new_tokens,
		stop: model.parameters?.stop,
		temperature: model.parameters?.temperature,
		top_p: model.parameters?.top_p,
		frequency_penalty: model.parameters?.repetition_penalty,
	});

	if (completion === "completions") {
		return async ({ conversation }) => {
			// Guarded lookup: an empty conversation yields no web search instead of throwing.
			const lastMessage = conversation.messages[conversation.messages.length - 1];
			return openAICompletionToTextGenerationStream(
				await openai.completions.create({
					model: model.id ?? model.name,
					prompt: await buildPrompt({
						messages: conversation.messages,
						webSearch: lastMessage?.webSearch,
						preprompt: conversation.preprompt,
						model,
					}),
					stream: true,
					...samplingParams(),
				})
			);
		};
	}

	if (completion === "chat_completions") {
		return async ({ conversation }) => {
			let messages = conversation.messages;
			// Guarded lookup: an empty conversation skips the web-search rewrite instead of throwing.
			const lastMessage = messages[messages.length - 1];
			const webSearch = lastMessage?.webSearch;

			if (lastMessage && webSearch?.context) {
				// Replace the last message with a user turn that embeds the search results,
				// the earlier user questions, and the original question.
				const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
				const previousQuestions =
					previousUserMessages.length > 0
						? `Previous questions: \n${previousUserMessages
								.map(({ content }) => `- ${content}`)
								.join("\n")}`
						: "";
				const currentDate = format(new Date(), "MMMM d, yyyy");
				messages = [
					...messages.slice(0, -1),
					{
						from: "user",
						content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
=====================
${webSearch.context}
=====================
${previousQuestions}
Answer the question: ${lastMessage.content}
`,
					},
				];
			}

			const messagesOpenAI = messages.map((message) => ({
				role: message.from,
				content: message.content,
			}));

			return openAIChatToTextGenerationStream(
				await openai.chat.completions.create({
					model: model.id ?? model.name,
					messages: conversation.preprompt
						? [{ role: "system", content: conversation.preprompt }, ...messagesOpenAI]
						: messagesOpenAI,
					stream: true,
					...samplingParams(),
				})
			);
		};
	}

	// Unreachable after schema validation; kept as a defensive guard.
	throw new Error("Invalid completion type");
}
|