|
import { MESSAGES_BEFORE_LOGIN, RATE_LIMIT } from "$env/static/private"; |
|
import { authCondition, requiresUser } from "$lib/server/auth"; |
|
import { collections } from "$lib/server/database"; |
|
import { models } from "$lib/server/models"; |
|
import { ERROR_MESSAGES } from "$lib/stores/errors"; |
|
import type { Message } from "$lib/types/Message"; |
|
import { error } from "@sveltejs/kit"; |
|
import { ObjectId } from "mongodb"; |
|
import { z } from "zod"; |
|
import type { MessageUpdate } from "$lib/types/MessageUpdate"; |
|
import { runWebSearch } from "$lib/server/websearch/runWebSearch"; |
|
import { abortedGenerations } from "$lib/server/abortedGenerations"; |
|
import { summarize } from "$lib/server/summarize"; |
|
import { uploadFile } from "$lib/server/files/uploadFile"; |
|
import sizeof from "image-size"; |
|
import { convertLegacyConversation } from "$lib/utils/tree/convertLegacyConversation"; |
|
import { isMessageId } from "$lib/utils/tree/isMessageId"; |
|
import { buildSubtree } from "$lib/utils/tree/buildSubtree.js"; |
|
import { addChildren } from "$lib/utils/tree/addChildren.js"; |
|
import { addSibling } from "$lib/utils/tree/addSibling.js"; |
|
import { preprocessMessages } from "$lib/server/preprocessMessages.js"; |
|
|
|
/**
 * Sends a prompt (or a retry / continuation) to the conversation's model and streams
 * the generation back to the client as newline-delimited JSON `MessageUpdate` objects.
 *
 * Request body (validated with zod):
 * - `inputs`: the user prompt (trimmed, non-empty); optional.
 * - `id`: the message to branch from, retry, or continue.
 * - `is_retry`: regenerate an assistant answer, or edit a user message (requires `inputs`).
 * - `is_continue`: keep generating into the existing leaf message `id`.
 * - `web_search`: run a web search before generating.
 * - `files`: base64-encoded images attached to the user message.
 *
 * @throws 401 when there is neither a logged-in user nor a session id.
 * @throws 404 when the conversation, or the referenced message, cannot be found.
 * @throws 429 when the before-login message limit or the rate limit is exceeded.
 * @throws 410 when the conversation's model is no longer configured.
 * @throws 400 for an invalid continue/retry request.
 * @throws 413 / 415 for files that are too large or not recognised as images.
 */
export async function POST({ request, locals, params, getClientAddress }) {
	const id = z.string().parse(params.id);

	// A malformed id cannot match any conversation: answer 404 instead of letting
	// the ObjectId constructor error escape as a 500.
	if (!ObjectId.isValid(id)) {
		throw error(404, "Conversation not found");
	}
	const convId = new ObjectId(id);
	const promptedAt = new Date();

	// Anonymous visitors are identified by their session id.
	const userId = locals.user?._id ?? locals.sessionId;

	if (!userId) {
		throw error(401, "Unauthorized");
	}

	// Conversations stored without a rootMessageId are converted to the tree format in place.
	const convBeforeCheck = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (convBeforeCheck && !convBeforeCheck.rootMessageId) {
		const res = await collections.conversations.updateOne(
			{ _id: convId },
			{
				$set: {
					...convBeforeCheck,
					...convertLegacyConversation(convBeforeCheck),
				},
			}
		);

		if (!res.acknowledged) {
			throw error(500, "Failed to convert conversation");
		}
	}

	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	// Record this request; these events are counted by the rate limit below.
	await collections.messageEvents.insertOne({
		userId,
		createdAt: new Date(),
		ip: getClientAddress(),
	});

	const messagesBeforeLogin = MESSAGES_BEFORE_LOGIN ? parseInt(MESSAGES_BEFORE_LOGIN, 10) : 0;

	// Without a logged-in user (and when login is required), cap the number of assistant answers.
	if (!locals.user?._id && requiresUser && messagesBeforeLogin > 0) {
		const totalMessages =
			(
				await collections.conversations
					.aggregate([
						{ $match: authCondition(locals) },
						{ $project: { messages: 1 } },
						{ $unwind: "$messages" },
						{ $match: { "messages.from": "assistant" } },
						{ $count: "messages" },
					])
					.toArray()
			)[0]?.messages ?? 0;

		if (totalMessages > messagesBeforeLogin) {
			throw error(429, "Exceeded number of messages before login");
		}
	}

	// Rate-limit on whichever of user id / client IP has produced more events.
	const nEvents = Math.max(
		await collections.messageEvents.countDocuments({ userId }),
		await collections.messageEvents.countDocuments({ ip: getClientAddress() })
	);

	if (RATE_LIMIT !== "" && nEvents > parseInt(RATE_LIMIT, 10)) {
		throw error(429, ERROR_MESSAGES.rateLimited);
	}

	const model = models.find((m) => m.id === conv.model);

	if (!model) {
		throw error(410, "Model not available anymore");
	}

	const json = await request.json();

	const {
		inputs: newPrompt,
		id: messageId,
		is_retry: isRetry,
		is_continue: isContinue,
		web_search: webSearch,
		files: b64files,
	} = z
		.object({
			id: z.string().uuid().refine(isMessageId).optional(),
			inputs: z.optional(z.string().trim().min(1)),
			is_retry: z.optional(z.boolean()),
			is_continue: z.optional(z.boolean()),
			web_search: z.optional(z.boolean()),
			files: z.optional(z.array(z.string())),
		})
		.parse(json);

	// Decode the base64 payloads into File objects.
	const files = b64files?.map((file) => {
		const blob = Buffer.from(file, "base64");
		return new File([blob], "image.png");
	});

	if (files) {
		const filechecks = await Promise.all(
			files.map(async (file) => {
				const buffer = Buffer.from(await file.arrayBuffer());
				try {
					const dimensions = sizeof(buffer);
					return (
						file.size > 2 * 1024 * 1024 ||
						(dimensions.width ?? 0) > 224 ||
						(dimensions.height ?? 0) > 224
					);
				} catch {
					// sizeof throws when it cannot recognise the data as an image.
					throw error(415, "Unsupported file type, expected an image.");
				}
			})
		);

		if (filechecks.some((check) => check)) {
			throw error(413, "File too large, should be <2MB and 224x224 max.");
		}
	}

	let hashes: undefined | string[];

	if (files) {
		hashes = await Promise.all(files.map((file) => uploadFile(file, conv)));
	}

	// The message the model output is written into.
	let messageToWriteToId: Message["id"] | undefined = undefined;
	// The conversation path sent to the model.
	let messagesForPrompt: Message[] = [];

	if (isContinue && messageId) {
		// Continue generating into an existing message; only allowed on a leaf.
		const messageToContinue = conv.messages.find((msg) => msg.id === messageId);

		if (!messageToContinue) {
			throw error(404, "Message not found");
		}

		if ((messageToContinue.children?.length ?? 0) > 0) {
			throw error(400, "Can only continue the last message");
		}

		messageToWriteToId = messageId;
		messagesForPrompt = buildSubtree(conv, messageId);
	} else if (isRetry && messageId) {
		// Two cases:
		// - retrying a user message with a new prompt means the user edited that message;
		// - retrying an assistant message regenerates that answer.
		const messageToRetry = conv.messages.find((message) => message.id === messageId);

		if (!messageToRetry) {
			throw error(404, "Message not found");
		}

		if (messageToRetry.from === "user" && newPrompt) {
			// The edited prompt becomes a sibling of the original user message. It carries
			// the attached files, as in the regular branch, and gets an empty assistant
			// child to write the answer into.
			const newUserMessageId = addSibling(
				conv,
				{
					from: "user",
					content: newPrompt,
					files: hashes,
					createdAt: new Date(),
					updatedAt: new Date(),
				},
				messageId
			);
			messageToWriteToId = addChildren(
				conv,
				{ from: "assistant", content: "", createdAt: new Date(), updatedAt: new Date() },
				newUserMessageId
			);
			messagesForPrompt = buildSubtree(conv, newUserMessageId);
		} else if (messageToRetry.from === "assistant") {
			// A new assistant sibling receives the regenerated answer.
			messageToWriteToId = addSibling(
				conv,
				{ from: "assistant", content: "", createdAt: new Date(), updatedAt: new Date() },
				messageId
			);
			messagesForPrompt = buildSubtree(conv, messageId);
			// The answer being retried must not be part of the prompt.
			messagesForPrompt.pop();
		} else {
			throw error(400, "Can only retry an assistant message, or a user message with a new prompt");
		}
	} else {
		// Regular prompt: add the user message under `messageId` (may be undefined),
		// then an empty assistant reply under it to write into.
		const newUserMessageId = addChildren(
			conv,
			{
				from: "user",
				content: newPrompt ?? "",
				files: hashes,
				createdAt: new Date(),
				updatedAt: new Date(),
			},
			messageId
		);

		messageToWriteToId = addChildren(
			conv,
			{
				from: "assistant",
				content: "",
				createdAt: new Date(),
				updatedAt: new Date(),
			},
			newUserMessageId
		);

		messagesForPrompt = buildSubtree(conv, newUserMessageId);
	}

	const messageToWriteTo = conv.messages.find((message) => message.id === messageToWriteToId);
	if (!messageToWriteTo) {
		throw error(500, "Failed to create message");
	}
	if (messagesForPrompt.length === 0) {
		throw error(500, "Failed to create prompt");
	}

	// Persist the updated message tree before streaming begins.
	await collections.conversations.updateOne(
		{ _id: convId },
		{
			$set: {
				messages: conv.messages,
				title: conv.title,
				updatedAt: new Date(),
			},
		}
	);

	let doneStreaming = false;

	const stream = new ReadableStream({
		async start(controller) {
			messageToWriteTo.updates ??= [];

			// Emit an update to the client; non-stream updates are also stored on the message.
			function update(newUpdate: MessageUpdate) {
				if (newUpdate.type !== "stream") {
					messageToWriteTo?.updates?.push(newUpdate);
				}

				if (newUpdate.type === "stream" && newUpdate.token === "") {
					return;
				}
				controller.enqueue(JSON.stringify(newUpdate) + "\n");

				if (newUpdate.type === "finalAnswer") {
					// Padding after the final answer.
					// NOTE(review): presumably to flush intermediary buffers — confirm.
					controller.enqueue(" ".repeat(4096));
				}
			}

			update({ type: "status", status: "started" });

			// Title a conversation still named "New Chat" once it has exactly 3 messages,
			// using the content of messages[1].
			const summarizeIfNeeded = (async () => {
				if (conv.title === "New Chat" && conv.messages.length === 3) {
					try {
						conv.title = (await summarize(conv.messages[1].content)) ?? conv.title;
						update({ type: "status", status: "title", message: conv.title });
						await collections.conversations.updateOne(
							{ _id: convId },
							{
								$set: {
									title: conv.title,
									updatedAt: new Date(),
								},
							}
						);
					} catch (e) {
						console.error(e);
					}
				}
			})();

			await collections.conversations.updateOne(
				{ _id: convId },
				{
					$set: {
						title: conv.title,
						updatedAt: new Date(),
					},
				}
			);

			if (webSearch && !isContinue && !conv.assistantId) {
				messageToWriteTo.webSearch = await runWebSearch(conv, messagesForPrompt, update);
			}

			const processedMessages = await preprocessMessages(
				messagesForPrompt,
				model.multimodal,
				convId
			);

			// The final text is appended to whatever the message held before generating.
			const previousText = messageToWriteTo.content;

			try {
				const endpoint = await model.getEndpoint();
				for await (const output of await endpoint({
					messages: processedMessages,
					preprompt: conv.preprompt,
					continueMessage: isContinue,
				})) {
					if (!output.generated_text) {
						// Intermediate token: forward it unless it is a special token.
						if (!output.token.special) {
							update({
								type: "stream",
								token: output.token.text,
							});

							// Stop if this generation was aborted after the prompt was received.
							const date = abortedGenerations.get(convId.toString());
							if (date && date > promptedAt) {
								break;
							}

							messageToWriteTo.content += output.token.text;
						}
					} else {
						messageToWriteTo.interrupted = !output.token.special;

						// Strip one trailing stop sequence per configured stop string; finding
						// one marks the answer as not interrupted.
						const text = (model.parameters.stop ?? []).reduce((acc: string, curr: string) => {
							if (acc.endsWith(curr)) {
								messageToWriteTo.interrupted = false;
								return acc.slice(0, acc.length - curr.length);
							}
							return acc;
						}, output.generated_text.trimEnd());

						messageToWriteTo.content = previousText + text;
						messageToWriteTo.updatedAt = new Date();
					}
				}
			} catch (e) {
				update({
					type: "status",
					status: "error",
					message: e instanceof Error ? e.message : String(e),
				});
			}

			await collections.conversations.updateOne(
				{ _id: convId },
				{
					$set: {
						messages: conv.messages,
						title: conv.title,
						updatedAt: new Date(),
					},
				}
			);

			doneStreaming = true;

			update({
				type: "finalAnswer",
				text: messageToWriteTo.content,
			});

			await summarizeIfNeeded;
			controller.close();
		},
		async cancel() {
			// The consumer cancelled before generation finished: persist what exists so far.
			if (!doneStreaming) {
				await collections.conversations.updateOne(
					{ _id: convId },
					{
						$set: {
							messages: conv.messages,
							title: conv.title,
							updatedAt: new Date(),
						},
					}
				);
			}
		},
	});

	return new Response(stream, {
		headers: {
			"Content-Type": "text/event-stream",
		},
	});
}
|
|
|
/**
 * Deletes a conversation visible to the current user/session (per `authCondition`).
 *
 * @throws 404 when the id is malformed or no matching conversation is found.
 */
export async function DELETE({ locals, params }) {
	// A malformed id cannot match any conversation: answer 404 instead of letting
	// the ObjectId constructor error escape as a 500.
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}
	const convId = new ObjectId(params.id);

	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	await collections.conversations.deleteOne({ _id: conv._id });

	return new Response();
}
|
|
|
/**
 * Renames a conversation visible to the current user/session (per `authCondition`).
 *
 * Request body: `{ title }`, trimmed, 1 to 100 characters.
 *
 * @throws 404 when the id is malformed or no matching conversation is found.
 */
export async function PATCH({ request, locals, params }) {
	const { title } = z
		.object({ title: z.string().trim().min(1).max(100) })
		.parse(await request.json());

	// A malformed id cannot match any conversation: answer 404 instead of letting
	// the ObjectId constructor error escape as a 500.
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}
	const convId = new ObjectId(params.id);

	// Ownership check before updating.
	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	await collections.conversations.updateOne(
		{ _id: convId },
		{
			$set: {
				title,
			},
		}
	);

	return new Response();
}
|
|