|
import { HF_ACCESS_TOKEN, MESSAGES_BEFORE_LOGIN, RATE_LIMIT } from "$env/static/private"; |
|
import { buildPrompt } from "$lib/buildPrompt"; |
|
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken"; |
|
import { authCondition, requiresUser } from "$lib/server/auth"; |
|
import { collections } from "$lib/server/database"; |
|
import { modelEndpoint } from "$lib/server/modelEndpoint"; |
|
import { models } from "$lib/server/models"; |
|
import { ERROR_MESSAGES } from "$lib/stores/errors"; |
|
import type { Message } from "$lib/types/Message"; |
|
import { trimPrefix } from "$lib/utils/trimPrefix"; |
|
import { trimSuffix } from "$lib/utils/trimSuffix"; |
|
import { textGenerationStream } from "@huggingface/inference"; |
|
import { error } from "@sveltejs/kit"; |
|
import { ObjectId } from "mongodb"; |
|
import { z } from "zod"; |
|
import { AwsClient } from "aws4fetch"; |
|
import type { MessageUpdate } from "$lib/types/MessageUpdate"; |
|
import { runWebSearch } from "$lib/server/websearch/runWebSearch"; |
|
import type { WebSearch } from "$lib/types/WebSearch"; |
|
import { abortedGenerations } from "$lib/server/abortedGenerations"; |
|
import { summarize } from "$lib/server/summarize"; |
|
|
|
/**
 * Handles a new user prompt for the conversation `params.id` and streams the
 * model's answer back as newline-delimited JSON `MessageUpdate` objects.
 *
 * Steps, in order:
 *  1. resolve the caller (`locals.user._id` or `locals.sessionId`) and load the
 *     conversation filtered by `authCondition(locals)`;
 *  2. record a message event, then enforce the "messages before login" limit
 *     and the `RATE_LIMIT` (per user id and per client address);
 *  3. validate the JSON body and build the message list (append, or replace
 *     from the retried message onwards when `is_retry` is set);
 *  4. summarize the prompt into a title while the title still starts with
 *     "Untitled";
 *  5. stream tokens from the model endpoint, persisting the conversation once
 *     the final text arrives or an abort is detected.
 *
 * @throws 401 when neither a user nor a session id is present.
 * @throws 404 when the id is malformed or the conversation is not found.
 * @throws 429 when the pre-login message limit or the rate limit is exceeded.
 * @throws 410 when the conversation's model is no longer in `models`.
 * @throws a zod validation error when the request body is invalid.
 */
export async function POST({ request, fetch, locals, params, getClientAddress }) {
	const id = z.string().parse(params.id);

	// A malformed id cannot match any conversation: answer 404 instead of letting
	// the ObjectId constructor throw (which would surface as a 500).
	if (!ObjectId.isValid(id)) {
		throw error(404, "Conversation not found");
	}

	const convId = new ObjectId(id);
	const promptedAt = new Date();

	const userId = locals.user?._id ?? locals.sessionId;

	if (!userId) {
		throw error(401, "Unauthorized");
	}

	// Only conversations matching the caller's auth condition are visible.
	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	const clientAddress = getClientAddress();

	// Record this request; the counts below are computed from these events.
	await collections.messageEvents.insertOne({
		userId: userId,
		createdAt: new Date(),
		ip: clientAddress,
	});

	// Callers without a logged-in user are capped at MESSAGES_BEFORE_LOGIN messages
	// (0 when unset) if a user is required.
	if (
		!locals.user?._id &&
		requiresUser &&
		conv.messages.length > (MESSAGES_BEFORE_LOGIN ? parseInt(MESSAGES_BEFORE_LOGIN, 10) : 0)
	) {
		throw error(429, "Exceeded number of messages before login");
	}

	// The two counts are independent, so fetch them concurrently and keep the larger.
	const [userEventCount, ipEventCount] = await Promise.all([
		collections.messageEvents.countDocuments({ userId }),
		collections.messageEvents.countDocuments({ ip: clientAddress }),
	]);
	const nEvents = Math.max(userEventCount, ipEventCount);

	// An empty RATE_LIMIT disables rate limiting.
	if (RATE_LIMIT !== "" && nEvents > parseInt(RATE_LIMIT, 10)) {
		throw error(429, ERROR_MESSAGES.rateLimited);
	}

	const model = models.find((m) => m.id === conv.model);
	const settings = await collections.settings.findOne(authCondition(locals));

	if (!model) {
		throw error(410, "Model not available anymore");
	}

	const json = await request.json();

	const {
		inputs: newPrompt,
		response_id: responseId,
		id: messageId,
		is_retry,
		web_search: webSearch,
	} = z
		.object({
			inputs: z.string().trim().min(1),
			id: z.optional(z.string().uuid()),
			response_id: z.optional(z.string().uuid()),
			is_retry: z.optional(z.boolean()),
			web_search: z.optional(z.boolean()),
		})
		.parse(json);

	// Build the message list that will be sent to the model.
	let messages = (() => {
		if (is_retry && messageId) {
			// Retry: drop the retried message and everything after it, then put the new
			// prompt in its place. An unknown id behaves like a plain append.
			let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
			if (retryMessageIdx === -1) {
				retryMessageIdx = conv.messages.length;
			}
			return [
				...conv.messages.slice(0, retryMessageIdx),
				{ content: newPrompt, from: "user", id: messageId as Message["id"], updatedAt: new Date() },
			];
		}

		// Regular case: append the prompt at the end of the conversation.
		return [
			...conv.messages,
			{
				content: newPrompt,
				from: "user",
				id: (messageId as Message["id"]) || crypto.randomUUID(),
				createdAt: new Date(),
				updatedAt: new Date(),
			},
		];
	})() satisfies Message[];

	// Title generation is best effort: a failure is logged and the old title kept.
	if (conv.title.startsWith("Untitled")) {
		try {
			conv.title = (await summarize(newPrompt)) ?? conv.title;
		} catch (e) {
			console.error(e);
		}
	}

	const stream = new ReadableStream({
		async start(controller) {
			// Every non-token update is kept so it can be stored on the assistant message.
			const updates: MessageUpdate[] = [];

			/** Sends one update to the client as a JSON line. */
			function update(newUpdate: MessageUpdate) {
				if (newUpdate.type !== "stream") {
					updates.push(newUpdate);
				}
				controller.enqueue(JSON.stringify(newUpdate) + "\n");
			}

			update({ type: "status", status: "started" });

			let webSearchResults: WebSearch | undefined;

			if (webSearch) {
				webSearchResults = await runWebSearch(conv, newPrompt, update);
			}

			const prompt = await buildPrompt({
				messages,
				model,
				webSearch: webSearchResults,
				// A per-model custom prompt from the user's settings wins over the model default.
				preprompt: settings?.customPrompts?.[model.id] ?? model.preprompt,
				locals: locals,
			});

			const randomEndpoint = modelEndpoint(model);

			let usedFetch = fetch;

			// SageMaker endpoints are called through an AWS-signing fetch.
			if (randomEndpoint.host === "sagemaker") {
				const aws = new AwsClient({
					accessKeyId: randomEndpoint.accessKey,
					secretAccessKey: randomEndpoint.secretKey,
					sessionToken: randomEndpoint.sessionToken,
					service: "sagemaker",
				});

				usedFetch = aws.fetch.bind(aws) as typeof fetch;
			}

			/**
			 * Cleans up `generatedText`, stores it as the content of the last message,
			 * persists the conversation and emits the `finalAnswer` update.
			 */
			async function saveLast(generatedText: string) {
				// Kept from the original: guards `conv` inside this nested function.
				if (!conv) {
					throw error(404, "Conversation not found");
				}

				const lastMessage = messages[messages.length - 1];

				if (!lastMessage) {
					return;
				}

				let text = generatedText;

				// Drop the prompt if the endpoint echoed it back.
				if (text.startsWith(prompt)) {
					text = text.slice(prompt.length);
				}

				text = trimSuffix(trimPrefix(text, "<|startoftext|>"), PUBLIC_SEP_TOKEN).trimEnd();

				// Strip a trailing stop sequence (model-specific ones plus "<|endoftext|>").
				for (const stop of [...(model?.parameters?.stop ?? []), "<|endoftext|>"]) {
					if (text.endsWith(stop)) {
						text = text.slice(0, -stop.length).trimEnd();
					}
				}

				lastMessage.content = text;

				await collections.conversations.updateOne(
					{
						_id: convId,
					},
					{
						$set: {
							messages,
							title: conv.title,
							updatedAt: new Date(),
						},
					}
				);

				update({
					type: "finalAnswer",
					text,
				});
			}

			const tokenStream = textGenerationStream(
				{
					parameters: {
						...model.parameters,
						return_full_text: false,
					},
					model: randomEndpoint.url,
					inputs: prompt,
					// No HF token is sent to SageMaker endpoints.
					accessToken: randomEndpoint.host === "sagemaker" ? undefined : HF_ACCESS_TOKEN,
				},
				{
					use_cache: false,
					fetch: usedFetch,
				}
			);

			for await (const output of tokenStream) {
				// A chunk carrying `generated_text` marks the end of the generation. Await the
				// save so a database failure surfaces here and `finalAnswer` is enqueued
				// before the stream is closed.
				if (output.generated_text) {
					await saveLast(output.generated_text);
					continue;
				}

				// Special tokens are neither streamed nor stored.
				if (output.token.special) {
					continue;
				}

				const lastMessage = messages[messages.length - 1];

				update({
					type: "stream",
					token: output.token.text,
				});

				// First visible token: create the assistant message (leading whitespace trimmed).
				if (lastMessage?.from !== "assistant") {
					messages = [
						...messages,
						{
							from: "assistant",
							content: output.token.text.trimStart(),
							webSearch: webSearchResults,
							updates: updates,
							id: (responseId as Message["id"]) || crypto.randomUUID(),
							createdAt: new Date(),
							updatedAt: new Date(),
						},
					];
					continue;
				}

				// An abort recorded after this request started: persist what has been
				// generated so far and stop consuming the model stream.
				const abortedAt = abortedGenerations.get(convId.toString());
				if (abortedAt && abortedAt > promptedAt) {
					await saveLast(lastMessage.content);
					break;
				}

				lastMessage.content += output.token.text;
			}

			// Signal end-of-stream so the HTTP response actually terminates.
			controller.close();
		},
		// Called when the consumer cancels the stream: persist the current state.
		async cancel() {
			await collections.conversations.updateOne(
				{
					_id: convId,
				},
				{
					$set: {
						messages,
						title: conv.title,
						updatedAt: new Date(),
					},
				}
			);
		},
	});

	return new Response(stream);
}
|
|
|
/**
 * Deletes the conversation `params.id`, provided it matches
 * `authCondition(locals)` for the caller.
 *
 * @returns an empty response on success.
 * @throws 404 when the id is malformed or no matching conversation exists.
 */
export async function DELETE({ locals, params }) {
	// A malformed id cannot match any conversation: answer 404 instead of letting
	// the ObjectId constructor throw (which would surface as a 500).
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}

	const convId = new ObjectId(params.id);

	// Look the conversation up with the auth filter first so that only a
	// conversation visible to the caller can be deleted.
	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	await collections.conversations.deleteOne({ _id: conv._id });

	return new Response();
}
|
|
|
/**
 * Renames the conversation `params.id`, provided it matches
 * `authCondition(locals)` for the caller.
 *
 * The JSON body must contain a `title` that is, after trimming, between 1 and
 * 100 characters long.
 *
 * @returns an empty response on success.
 * @throws a zod validation error when the body is invalid.
 * @throws 404 when the id is malformed or no matching conversation exists.
 */
export async function PATCH({ request, locals, params }) {
	const { title } = z
		.object({ title: z.string().trim().min(1).max(100) })
		.parse(await request.json());

	// A malformed id cannot match any conversation: answer 404 instead of letting
	// the ObjectId constructor throw (which would surface as a 500).
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}

	const convId = new ObjectId(params.id);

	// Look the conversation up with the auth filter first so that only a
	// conversation visible to the caller can be renamed.
	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	await collections.conversations.updateOne(
		{
			_id: convId,
		},
		{
			$set: {
				title,
			},
		}
	);

	return new Response();
}
|
|