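// Apparent page/commit metadata captured together with this file (kept as a comment so the
// file remains valid TypeScript):
//   Author: nsarrazin (HF staff)
//   Commit: ac04347 (unverified) — "Save user message on request (#337) (#492)"
//   Size: 9.46 kB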
import { HF_ACCESS_TOKEN, MESSAGES_BEFORE_LOGIN, RATE_LIMIT } from "$env/static/private";
import { buildPrompt } from "$lib/buildPrompt";
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
import { authCondition, requiresUser } from "$lib/server/auth";
import { collections } from "$lib/server/database";
import { modelEndpoint } from "$lib/server/modelEndpoint";
import { models } from "$lib/server/models";
import { ERROR_MESSAGES } from "$lib/stores/errors";
import type { Message } from "$lib/types/Message";
import { trimPrefix } from "$lib/utils/trimPrefix";
import { trimSuffix } from "$lib/utils/trimSuffix";
import { textGenerationStream } from "@huggingface/inference";
import { error } from "@sveltejs/kit";
import { ObjectId } from "mongodb";
import { z } from "zod";
import { AwsClient } from "aws4fetch";
import type { MessageUpdate } from "$lib/types/MessageUpdate";
import { runWebSearch } from "$lib/server/websearch/runWebSearch";
import type { WebSearch } from "$lib/types/WebSearch";
import { abortedGenerations } from "$lib/server/abortedGenerations";
import { summarize } from "$lib/server/summarize";
/**
 * Posts a user prompt (or a retry of an earlier one) to the conversation `params.id` and
 * streams the model's answer back as newline-delimited JSON `MessageUpdate` objects.
 *
 * Flow: authenticate, load the conversation, record a rate-limit event, enforce the
 * anonymous-message and rate limits, validate the body, persist the user message (and a
 * summarized title for "Untitled" conversations), then stream tokens from the model endpoint
 * and persist the final assistant message.
 *
 * @throws 401 when there is neither a logged-in user nor a session id.
 * @throws 404 when the id is not a valid ObjectId or the conversation is not accessible.
 * @throws 429 when the anonymous-message limit or the rate limit is exceeded.
 * @throws 410 when the conversation's model is no longer configured.
 */
export async function POST({ request, fetch, locals, params, getClientAddress }) {
	const id = z.string().parse(params.id);
	// A malformed id would make `new ObjectId` throw (a 500); no conversation can match it.
	if (!ObjectId.isValid(id)) {
		throw error(404, "Conversation not found");
	}
	const convId = new ObjectId(id);
	const promptedAt = new Date();

	const userId = locals.user?._id ?? locals.sessionId;

	// check user
	if (!userId) {
		throw error(401, "Unauthorized");
	}

	// check if the user has access to the conversation
	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	const clientAddress = getClientAddress();

	// register the event for ratelimiting
	await collections.messageEvents.insertOne({
		userId: userId,
		createdAt: new Date(),
		ip: clientAddress,
	});

	// make sure an anonymous user can't post more than MESSAGES_BEFORE_LOGIN messages
	if (
		!locals.user?._id &&
		requiresUser &&
		conv.messages.length > (MESSAGES_BEFORE_LOGIN ? parseInt(MESSAGES_BEFORE_LOGIN) : 0)
	) {
		throw error(429, "Exceeded number of messages before login");
	}

	// check if the user is rate limited; both counts are independent, so run them together
	const [userEvents, ipEvents] = await Promise.all([
		collections.messageEvents.countDocuments({ userId }),
		collections.messageEvents.countDocuments({ ip: clientAddress }),
	]);
	const nEvents = Math.max(userEvents, ipEvents);

	if (RATE_LIMIT !== "" && nEvents > parseInt(RATE_LIMIT)) {
		throw error(429, ERROR_MESSAGES.rateLimited);
	}

	// fetch the model
	const model = models.find((m) => m.id === conv.model);

	if (!model) {
		throw error(410, "Model not available anymore");
	}

	// finally parse the content of the request
	const json = await request.json();

	const {
		inputs: newPrompt,
		response_id: responseId,
		id: messageId,
		is_retry,
		web_search: webSearch,
	} = z
		.object({
			inputs: z.string().trim().min(1),
			id: z.optional(z.string().uuid()),
			response_id: z.optional(z.string().uuid()),
			is_retry: z.optional(z.boolean()),
			web_search: z.optional(z.boolean()),
		})
		.parse(json);

	// get the list of messages
	// while checking for retries
	let messages = (() => {
		if (is_retry && messageId) {
			// if the message is a retry, replace the message and remove the messages after it
			let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
			if (retryMessageIdx === -1) {
				retryMessageIdx = conv.messages.length;
			}
			return [
				...conv.messages.slice(0, retryMessageIdx),
				{ content: newPrompt, from: "user", id: messageId as Message["id"], updatedAt: new Date() },
			];
		} // else append the message at the bottom

		return [
			...conv.messages,
			{
				content: newPrompt,
				from: "user",
				id: (messageId as Message["id"]) || crypto.randomUUID(),
				createdAt: new Date(),
				updatedAt: new Date(),
			},
		];
	})() satisfies Message[];

	if (conv.title.startsWith("Untitled")) {
		try {
			conv.title = (await summarize(newPrompt)) ?? conv.title;
		} catch (e) {
			// A failed summary is not fatal: keep the current title.
			console.error(e);
		}
	}

	// Persist the user message before generating, so it survives a failed generation.
	await collections.conversations.updateOne(
		{
			_id: convId,
		},
		{
			$set: {
				messages,
				title: conv.title,
				updatedAt: new Date(),
			},
		}
	);

	// we now build the stream
	const stream = new ReadableStream({
		async start(controller) {
			// Non-token updates are collected and attached to the assistant message.
			const updates: MessageUpdate[] = [];

			function update(newUpdate: MessageUpdate) {
				if (newUpdate.type !== "stream") {
					updates.push(newUpdate);
				}
				controller.enqueue(JSON.stringify(newUpdate) + "\n");
			}

			update({ type: "status", status: "started" });

			let webSearchResults: WebSearch | undefined;

			if (webSearch) {
				webSearchResults = await runWebSearch(conv, newPrompt, update);
			}

			// we can now build the prompt using the messages
			const prompt = await buildPrompt({
				messages,
				model,
				webSearch: webSearchResults,
				preprompt: conv.preprompt ?? model.preprompt,
				locals: locals,
			});

			// fetch the endpoint
			const randomEndpoint = modelEndpoint(model);

			// SageMaker endpoints need AWS-signed requests instead of the HF access token.
			let usedFetch = fetch;
			if (randomEndpoint.host === "sagemaker") {
				const aws = new AwsClient({
					accessKeyId: randomEndpoint.accessKey,
					secretAccessKey: randomEndpoint.secretKey,
					sessionToken: randomEndpoint.sessionToken,
					service: "sagemaker",
				});

				usedFetch = aws.fetch.bind(aws) as typeof fetch;
			}

			// Builds the assistant message that receives the generated text.
			// Its id is the client-provided response_id when present, otherwise a random UUID.
			const newAssistantMessage = (content: string): Message => ({
				from: "assistant",
				content,
				webSearch: webSearchResults,
				updates: updates,
				id: (responseId as Message["id"]) || crypto.randomUUID(),
				createdAt: new Date(),
				updatedAt: new Date(),
			});

			// Cleans up the raw model output, stores it as the assistant answer, persists the
			// conversation and emits the `finalAnswer` update.
			async function saveLast(generated_text: string) {
				// Re-checked so `conv` is narrowed inside this nested function declaration.
				if (!conv) {
					throw error(404, "Conversation not found");
				}

				// We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
				if (generated_text.startsWith(prompt)) {
					generated_text = generated_text.slice(prompt.length);
				}

				generated_text = trimSuffix(
					trimPrefix(generated_text, "<|startoftext|>"),
					PUBLIC_SEP_TOKEN
				).trimEnd();

				// remove the stop tokens
				for (const stop of [...(model?.parameters?.stop ?? []), "<|endoftext|>"]) {
					if (generated_text.endsWith(stop)) {
						generated_text = generated_text.slice(0, -stop.length).trimEnd();
					}
				}

				const lastMessage = messages[messages.length - 1];
				if (lastMessage?.from === "assistant") {
					lastMessage.content = generated_text;
				} else {
					// No visible token was streamed before the final output (e.g. a one-token
					// answer), so no assistant message exists yet: append one instead of
					// overwriting the user's prompt.
					messages = [...messages, newAssistantMessage(generated_text)];
				}

				await collections.conversations.updateOne(
					{
						_id: convId,
					},
					{
						$set: {
							messages,
							title: conv.title,
							updatedAt: new Date(),
						},
					}
				);

				update({
					type: "finalAnswer",
					text: generated_text,
				});
			}

			const tokenStream = textGenerationStream(
				{
					parameters: {
						...model.parameters,
						return_full_text: false,
					},
					model: randomEndpoint.url,
					inputs: prompt,
					accessToken: randomEndpoint.host === "sagemaker" ? undefined : HF_ACCESS_TOKEN,
				},
				{
					use_cache: false,
					fetch: usedFetch,
				}
			);

			for await (const output of tokenStream) {
				// `generated_text` is only set once the generation is done. A nullish check (not a
				// falsy one) lets an empty final answer still be saved.
				if (output.generated_text != null) {
					await saveLast(output.generated_text);
					continue;
				}

				// Special tokens are not part of the visible answer.
				if (output.token.special) {
					continue;
				}

				const lastMessage = messages[messages.length - 1];
				update({
					type: "stream",
					token: output.token.text,
				});

				// if the last message is not from assistant, it means this is the first token
				if (lastMessage?.from !== "assistant") {
					// First token has a space at the beginning, trim it.
					// The id doesn't have to match a backend id for assistant messages.
					messages = [...messages, newAssistantMessage(output.token.text.trimStart())];
					continue;
				}

				// An abortedGenerations entry newer than this request means this generation was
				// stopped: persist what we have so far and stop consuming tokens.
				const abortedAt = abortedGenerations.get(convId.toString());
				if (abortedAt && abortedAt > promptedAt) {
					await saveLast(lastMessage.content);
					break;
				}

				// otherwise we just concatenate tokens
				lastMessage.content += output.token.text;
			}

			// Signal the end of the NDJSON stream to the client. `close()` throws if the stream
			// was already cancelled by the client, which is safe to ignore here.
			try {
				controller.close();
			} catch {
				// stream already closed/cancelled
			}
		},
		// Called when the client cancels the response stream: persist whatever was generated so far.
		async cancel() {
			await collections.conversations.updateOne(
				{
					_id: convId,
				},
				{
					$set: {
						messages,
						title: conv.title,
						updatedAt: new Date(),
					},
				}
			);
		},
	});

	// Todo: maybe we should wait for the message to be saved before ending the response - in case of errors
	return new Response(stream);
}
/**
 * Deletes the conversation `params.id` if it is accessible to the current user/session.
 *
 * @throws 404 when the id is not a valid ObjectId or no accessible conversation matches it.
 */
export async function DELETE({ locals, params }) {
	// A malformed id would make `new ObjectId` throw (a 500); no conversation can match it.
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}

	// Filter on the id AND the access condition in one query: the delete can only touch a
	// conversation this user may access, without a separate find-then-delete round trip.
	const { deletedCount } = await collections.conversations.deleteOne({
		_id: new ObjectId(params.id),
		...authCondition(locals),
	});

	if (deletedCount === 0) {
		throw error(404, "Conversation not found");
	}

	return new Response();
}
/**
 * Renames the conversation `params.id`. The request body must be `{ title: string }`,
 * trimmed to 1-100 characters.
 *
 * @throws 404 when the id is not a valid ObjectId or no accessible conversation matches it.
 */
export async function PATCH({ request, locals, params }) {
	const { title } = z
		.object({ title: z.string().trim().min(1).max(100) })
		.parse(await request.json());

	// A malformed id would make `new ObjectId` throw (a 500); no conversation can match it.
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}

	// Filter on the id AND the access condition in one query instead of find-then-update,
	// so only a conversation this user may access is renamed.
	const { matchedCount } = await collections.conversations.updateOne(
		{
			_id: new ObjectId(params.id),
			...authCondition(locals),
		},
		{
			$set: {
				title,
			},
		}
	);

	if (matchedCount === 0) {
		throw error(404, "Conversation not found");
	}

	return new Response();
}