Continue generation feature (#707)
Browse files* Initial work on continue feature
* Move continue button
* Fix websearch with continue
* Make it work with every model
* Update src/routes/conversation/[id]/+server.ts
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
* fixes
* async all the things
* add reduce comment
* remove log
* Only show loading indicator if not continuing
---------
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
- .env.template +4 -2
- src/lib/buildPrompt.ts +18 -16
- src/lib/components/ContinueBtn.svelte +13 -0
- src/lib/components/chat/ChatMessage.svelte +1 -0
- src/lib/components/chat/ChatMessages.svelte +2 -1
- src/lib/components/chat/ChatWindow.svelte +17 -2
- src/lib/server/endpoints/endpoints.ts +1 -0
- src/lib/server/endpoints/tgi/endpointTgi.ts +13 -2
- src/lib/types/Message.ts +1 -0
- src/routes/conversation/[id]/+page.svelte +76 -25
- src/routes/conversation/[id]/+server.ts +65 -35
.env.template
CHANGED
@@ -57,7 +57,8 @@ MODELS=`[
|
|
57 |
"repetition_penalty": 1.2,
|
58 |
"top_k": 50,
|
59 |
"truncate": 3072,
|
60 |
-
"max_new_tokens": 1024
|
|
|
61 |
}
|
62 |
},
|
63 |
{
|
@@ -116,7 +117,8 @@ MODELS=`[
|
|
116 |
"repetition_penalty": 1.2,
|
117 |
"top_k": 50,
|
118 |
"truncate": 4096,
|
119 |
-
"max_new_tokens": 4096
|
|
|
120 |
}
|
121 |
},
|
122 |
{
|
|
|
57 |
"repetition_penalty": 1.2,
|
58 |
"top_k": 50,
|
59 |
"truncate": 3072,
|
60 |
+
"max_new_tokens": 1024,
|
61 |
+
"stop" : ["</s>", " </s><s>[INST] "]
|
62 |
}
|
63 |
},
|
64 |
{
|
|
|
117 |
"repetition_penalty": 1.2,
|
118 |
"top_k": 50,
|
119 |
"truncate": 4096,
|
120 |
+
"max_new_tokens": 4096,
|
121 |
+
"stop": [" </s><s>[INST] "]
|
122 |
}
|
123 |
},
|
124 |
{
|
src/lib/buildPrompt.ts
CHANGED
@@ -13,6 +13,7 @@ interface buildPromptOptions {
|
|
13 |
webSearch?: WebSearch;
|
14 |
preprompt?: string;
|
15 |
files?: File[];
|
|
|
16 |
}
|
17 |
|
18 |
export async function buildPrompt({
|
@@ -22,37 +23,38 @@ export async function buildPrompt({
|
|
22 |
preprompt,
|
23 |
id,
|
24 |
}: buildPromptOptions): Promise<string> {
|
|
|
|
|
25 |
if (webSearch && webSearch.context) {
|
26 |
-
|
27 |
-
const
|
28 |
-
const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
|
29 |
|
|
|
|
|
30 |
const previousQuestions =
|
31 |
previousUserMessages.length > 0
|
32 |
? `Previous questions: \n${previousUserMessages
|
33 |
.map(({ content }) => `- ${content}`)
|
34 |
.join("\n")}`
|
35 |
: "";
|
|
|
36 |
const currentDate = format(new Date(), "MMMM d, yyyy");
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
=====================
|
43 |
${webSearch.context}
|
44 |
=====================
|
45 |
${previousQuestions}
|
46 |
-
Answer the question: ${
|
47 |
-
|
48 |
-
},
|
49 |
-
];
|
50 |
}
|
51 |
-
|
52 |
// section to handle potential files input
|
53 |
if (model.multimodal) {
|
54 |
-
|
55 |
-
|
56 |
let content = el.content;
|
57 |
|
58 |
if (el.from === "user") {
|
@@ -83,7 +85,7 @@ export async function buildPrompt({
|
|
83 |
|
84 |
return (
|
85 |
model
|
86 |
-
.chatPromptRender({ messages, preprompt })
|
87 |
// Not super precise, but it's truncated in the model's backend anyway
|
88 |
.split(" ")
|
89 |
.slice(-(model.parameters?.truncate ?? 0))
|
|
|
13 |
webSearch?: WebSearch;
|
14 |
preprompt?: string;
|
15 |
files?: File[];
|
16 |
+
continue?: boolean;
|
17 |
}
|
18 |
|
19 |
export async function buildPrompt({
|
|
|
23 |
preprompt,
|
24 |
id,
|
25 |
}: buildPromptOptions): Promise<string> {
|
26 |
+
let modifiedMessages = [...messages];
|
27 |
+
|
28 |
if (webSearch && webSearch.context) {
|
29 |
+
// find index of the last user message
|
30 |
+
const lastUsrMsgIndex = modifiedMessages.map((el) => el.from).lastIndexOf("user");
|
|
|
31 |
|
32 |
+
// combine all the other previous questions into one string
|
33 |
+
const previousUserMessages = modifiedMessages.filter((el) => el.from === "user").slice(0, -1);
|
34 |
const previousQuestions =
|
35 |
previousUserMessages.length > 0
|
36 |
? `Previous questions: \n${previousUserMessages
|
37 |
.map(({ content }) => `- ${content}`)
|
38 |
.join("\n")}`
|
39 |
: "";
|
40 |
+
|
41 |
const currentDate = format(new Date(), "MMMM d, yyyy");
|
42 |
+
|
43 |
+
// update the last user message directly (that way if the last message is an assistant partial answer, we keep the beginning of that answer)
|
44 |
+
modifiedMessages[lastUsrMsgIndex] = {
|
45 |
+
from: "user",
|
46 |
+
content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
|
47 |
=====================
|
48 |
${webSearch.context}
|
49 |
=====================
|
50 |
${previousQuestions}
|
51 |
+
Answer the question: ${messages[lastUsrMsgIndex].content} `,
|
52 |
+
};
|
|
|
|
|
53 |
}
|
|
|
54 |
// section to handle potential files input
|
55 |
if (model.multimodal) {
|
56 |
+
modifiedMessages = await Promise.all(
|
57 |
+
modifiedMessages.map(async (el) => {
|
58 |
let content = el.content;
|
59 |
|
60 |
if (el.from === "user") {
|
|
|
85 |
|
86 |
return (
|
87 |
model
|
88 |
+
.chatPromptRender({ messages: modifiedMessages, preprompt })
|
89 |
// Not super precise, but it's truncated in the model's backend anyway
|
90 |
.split(" ")
|
91 |
.slice(-(model.parameters?.truncate ?? 0))
|
src/lib/components/ContinueBtn.svelte
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<script lang="ts">
|
2 |
+
import CarbonContinue from "~icons/carbon/continue";
|
3 |
+
|
4 |
+
export let classNames = "";
|
5 |
+
</script>
|
6 |
+
|
7 |
+
<button
|
8 |
+
type="button"
|
9 |
+
on:click
|
10 |
+
class="btn flex h-8 rounded-lg border bg-white px-3 py-1 text-gray-500 shadow-sm transition-all hover:bg-gray-100 dark:border-gray-600 dark:bg-gray-700 dark:text-gray-300 dark:hover:bg-gray-600 {classNames}"
|
11 |
+
>
|
12 |
+
<CarbonContinue class="mr-2 text-xs " /> Continue
|
13 |
+
</button>
|
src/lib/components/chat/ChatMessage.svelte
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
import CarbonDownload from "~icons/carbon/download";
|
14 |
import CarbonThumbsUp from "~icons/carbon/thumbs-up";
|
15 |
import CarbonThumbsDown from "~icons/carbon/thumbs-down";
|
|
|
16 |
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
|
17 |
import type { Model } from "$lib/types/Model";
|
18 |
|
|
|
13 |
import CarbonDownload from "~icons/carbon/download";
|
14 |
import CarbonThumbsUp from "~icons/carbon/thumbs-up";
|
15 |
import CarbonThumbsDown from "~icons/carbon/thumbs-down";
|
16 |
+
|
17 |
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
|
18 |
import type { Model } from "$lib/types/Model";
|
19 |
|
src/lib/components/chat/ChatMessages.svelte
CHANGED
@@ -54,11 +54,12 @@
|
|
54 |
webSearchMessages={i === messages.length - 1 ? webSearchMessages : []}
|
55 |
on:retry
|
56 |
on:vote
|
|
|
57 |
/>
|
58 |
{:else}
|
59 |
<ChatIntroduction {models} {currentModel} on:message />
|
60 |
{/each}
|
61 |
-
{#if pending}
|
62 |
<ChatMessage
|
63 |
message={{ from: "assistant", content: "", id: randomUUID() }}
|
64 |
model={currentModel}
|
|
|
54 |
webSearchMessages={i === messages.length - 1 ? webSearchMessages : []}
|
55 |
on:retry
|
56 |
on:vote
|
57 |
+
on:continue
|
58 |
/>
|
59 |
{:else}
|
60 |
<ChatIntroduction {models} {currentModel} on:message />
|
61 |
{/each}
|
62 |
+
{#if pending && messages[messages.length - 1]?.from === "user"}
|
63 |
<ChatMessage
|
64 |
message={{ from: "assistant", content: "", id: randomUUID() }}
|
65 |
model={currentModel}
|
src/lib/components/chat/ChatWindow.svelte
CHANGED
@@ -24,6 +24,7 @@
|
|
24 |
import UploadBtn from "../UploadBtn.svelte";
|
25 |
import file2base64 from "$lib/utils/file2base64";
|
26 |
import { useSettingsStore } from "$lib/stores/settings";
|
|
|
27 |
|
28 |
export let messages: Message[] = [];
|
29 |
export let loading = false;
|
@@ -48,6 +49,7 @@
|
|
48 |
share: void;
|
49 |
stop: void;
|
50 |
retry: { id: Message["id"]; content: string };
|
|
|
51 |
}>();
|
52 |
|
53 |
const handleSubmit = () => {
|
@@ -124,6 +126,7 @@
|
|
124 |
}
|
125 |
}}
|
126 |
on:vote
|
|
|
127 |
on:retry={(ev) => {
|
128 |
if (!loading) dispatch("retry", ev.detail);
|
129 |
}}
|
@@ -173,8 +176,20 @@
|
|
173 |
content: messages[messages.length - 1].content,
|
174 |
})}
|
175 |
/>
|
176 |
-
{:else
|
177 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
{/if}
|
179 |
</div>
|
180 |
<form
|
|
|
24 |
import UploadBtn from "../UploadBtn.svelte";
|
25 |
import file2base64 from "$lib/utils/file2base64";
|
26 |
import { useSettingsStore } from "$lib/stores/settings";
|
27 |
+
import ContinueBtn from "../ContinueBtn.svelte";
|
28 |
|
29 |
export let messages: Message[] = [];
|
30 |
export let loading = false;
|
|
|
49 |
share: void;
|
50 |
stop: void;
|
51 |
retry: { id: Message["id"]; content: string };
|
52 |
+
continue: { id: Message["id"] };
|
53 |
}>();
|
54 |
|
55 |
const handleSubmit = () => {
|
|
|
126 |
}
|
127 |
}}
|
128 |
on:vote
|
129 |
+
on:continue
|
130 |
on:retry={(ev) => {
|
131 |
if (!loading) dispatch("retry", ev.detail);
|
132 |
}}
|
|
|
176 |
content: messages[messages.length - 1].content,
|
177 |
})}
|
178 |
/>
|
179 |
+
{:else}
|
180 |
+
<div class="ml-auto gap-2">
|
181 |
+
{#if currentModel.multimodal}
|
182 |
+
<UploadBtn bind:files classNames="ml-auto" />
|
183 |
+
{/if}
|
184 |
+
{#if messages && messages[messages.length - 1]?.interrupted && !isReadOnly}
|
185 |
+
<ContinueBtn
|
186 |
+
on:click={() =>
|
187 |
+
dispatch("continue", {
|
188 |
+
id: messages[messages.length - 1].id,
|
189 |
+
})}
|
190 |
+
/>
|
191 |
+
{/if}
|
192 |
+
</div>
|
193 |
{/if}
|
194 |
</div>
|
195 |
<form
|
src/lib/server/endpoints/endpoints.ts
CHANGED
@@ -14,6 +14,7 @@ interface EndpointParameters {
|
|
14 |
preprompt?: Conversation["preprompt"];
|
15 |
_id?: Conversation["_id"];
|
16 |
};
|
|
|
17 |
}
|
18 |
|
19 |
interface CommonEndpoint {
|
|
|
14 |
preprompt?: Conversation["preprompt"];
|
15 |
_id?: Conversation["_id"];
|
16 |
};
|
17 |
+
continue?: boolean;
|
18 |
}
|
19 |
|
20 |
interface CommonEndpoint {
|
src/lib/server/endpoints/tgi/endpointTgi.ts
CHANGED
@@ -15,8 +15,9 @@ export const endpointTgiParametersSchema = z.object({
|
|
15 |
|
16 |
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
|
17 |
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
|
18 |
-
|
19 |
-
|
|
|
20 |
messages: conversation.messages,
|
21 |
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
|
22 |
preprompt: conversation.preprompt,
|
@@ -24,6 +25,16 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
|
|
24 |
id: conversation._id,
|
25 |
});
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
return textGenerationStream(
|
28 |
{
|
29 |
parameters: { ...model.parameters, return_full_text: false },
|
|
|
15 |
|
16 |
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
|
17 |
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
|
18 |
+
|
19 |
+
return async ({ conversation, continue: messageContinue }) => {
|
20 |
+
let prompt = await buildPrompt({
|
21 |
messages: conversation.messages,
|
22 |
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
|
23 |
preprompt: conversation.preprompt,
|
|
|
25 |
id: conversation._id,
|
26 |
});
|
27 |
|
28 |
+
if (messageContinue) {
|
29 |
+
// start with the full prompt, and for each stop token, try to remove it from the end of the prompt
|
30 |
+
prompt = model.parameters.stop.reduce((acc: string, curr: string) => {
|
31 |
+
if (acc.endsWith(curr)) {
|
32 |
+
return acc.slice(0, acc.length - curr.length);
|
33 |
+
}
|
34 |
+
return acc;
|
35 |
+
}, prompt.trimEnd());
|
36 |
+
}
|
37 |
+
|
38 |
return textGenerationStream(
|
39 |
{
|
40 |
parameters: { ...model.parameters, return_full_text: false },
|
src/lib/types/Message.ts
CHANGED
@@ -11,4 +11,5 @@ export type Message = Partial<Timestamps> & {
|
|
11 |
webSearch?: WebSearch;
|
12 |
score?: -1 | 0 | 1;
|
13 |
files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
|
|
|
14 |
};
|
|
|
11 |
webSearch?: WebSearch;
|
12 |
score?: -1 | 0 | 1;
|
13 |
files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
|
14 |
+
interrupted?: boolean;
|
15 |
};
|
src/routes/conversation/[id]/+page.svelte
CHANGED
@@ -64,9 +64,17 @@
|
|
64 |
}
|
65 |
}
|
66 |
// this function is used to send new message to the backends
|
67 |
-
async function writeMessage(
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
try {
|
71 |
$isAborted = false;
|
72 |
loading = true;
|
@@ -74,13 +82,21 @@
|
|
74 |
|
75 |
// first we check if the messageId already exists, indicating a retry
|
76 |
|
77 |
-
let
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
}
|
83 |
|
|
|
|
|
84 |
const module = await import("browser-image-resizer");
|
85 |
|
86 |
// currently, only IDEFICS is supported by TGI
|
@@ -99,15 +115,31 @@
|
|
99 |
);
|
100 |
|
101 |
// slice up to the point of the retry
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
files = [];
|
113 |
|
@@ -115,9 +147,10 @@
|
|
115 |
method: "POST",
|
116 |
headers: { "Content-Type": "application/json" },
|
117 |
body: JSON.stringify({
|
118 |
-
inputs:
|
119 |
id: messageId,
|
120 |
is_retry: isRetry,
|
|
|
121 |
web_search: $webSearchParameters.useSearch,
|
122 |
files: isRetry ? undefined : resizedImages,
|
123 |
}),
|
@@ -282,37 +315,54 @@
|
|
282 |
// only used in case of creating new conversations (from the parent POST endpoint)
|
283 |
if ($pendingMessage) {
|
284 |
files = $pendingMessage.files;
|
285 |
-
await writeMessage($pendingMessage.content);
|
286 |
$pendingMessage = undefined;
|
287 |
}
|
288 |
});
|
289 |
|
290 |
async function onMessage(event: CustomEvent<string>) {
|
291 |
if (!data.shared) {
|
292 |
-
writeMessage(event.detail);
|
293 |
} else {
|
294 |
-
convFromShared()
|
295 |
.then(async (convId) => {
|
296 |
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
|
297 |
})
|
298 |
-
.then(() => writeMessage(event.detail))
|
299 |
.finally(() => (loading = false));
|
300 |
}
|
301 |
}
|
302 |
|
303 |
async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) {
|
304 |
if (!data.shared) {
|
305 |
-
writeMessage(
|
|
|
|
|
|
|
|
|
306 |
} else {
|
307 |
-
convFromShared()
|
308 |
.then(async (convId) => {
|
309 |
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
|
310 |
})
|
311 |
-
.then(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
.finally(() => (loading = false));
|
313 |
}
|
314 |
}
|
315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
$: $page.params.id, (($isAborted = true), (loading = false));
|
317 |
$: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title;
|
318 |
</script>
|
@@ -337,6 +387,7 @@
|
|
337 |
bind:files
|
338 |
on:message={onMessage}
|
339 |
on:retry={onRetry}
|
|
|
340 |
on:vote={(event) => voteMessage(event.detail.score, event.detail.id)}
|
341 |
on:share={() => shareConversation($page.params.id, data.title)}
|
342 |
on:stop={() => (($isAborted = true), (loading = false))}
|
|
|
64 |
}
|
65 |
}
|
66 |
// this function is used to send new message to the backends
|
67 |
+
async function writeMessage({
|
68 |
+
prompt,
|
69 |
+
messageId = randomUUID(),
|
70 |
+
isRetry = false,
|
71 |
+
isContinue = false,
|
72 |
+
}: {
|
73 |
+
prompt?: string;
|
74 |
+
messageId?: ReturnType<typeof randomUUID>;
|
75 |
+
isRetry?: boolean;
|
76 |
+
isContinue?: boolean;
|
77 |
+
}): Promise<void> {
|
78 |
try {
|
79 |
$isAborted = false;
|
80 |
loading = true;
|
|
|
82 |
|
83 |
// first we check if the messageId already exists, indicating a retry
|
84 |
|
85 |
+
let msgIndex = messages.findIndex((msg) => msg.id === messageId);
|
86 |
+
|
87 |
+
if (msgIndex === -1) {
|
88 |
+
msgIndex = messages.length - 1;
|
89 |
+
}
|
90 |
+
if (isRetry && messages[msgIndex].from === "assistant") {
|
91 |
+
throw new Error("Trying to retry a message that is not from user");
|
92 |
+
}
|
93 |
+
|
94 |
+
if (isContinue && messages[msgIndex].from === "user") {
|
95 |
+
throw new Error("Trying to continue a message that is not from assistant");
|
96 |
}
|
97 |
|
98 |
+
// const isNewMessage = !isRetry && !isContinue;
|
99 |
+
|
100 |
const module = await import("browser-image-resizer");
|
101 |
|
102 |
// currently, only IDEFICS is supported by TGI
|
|
|
115 |
);
|
116 |
|
117 |
// slice up to the point of the retry
|
118 |
+
if (isRetry) {
|
119 |
+
messages = [
|
120 |
+
...messages.slice(0, msgIndex),
|
121 |
+
{
|
122 |
+
from: "user",
|
123 |
+
content: messages[msgIndex].content,
|
124 |
+
id: messageId,
|
125 |
+
files: messages[msgIndex].files,
|
126 |
+
},
|
127 |
+
];
|
128 |
+
} else if (!isContinue) {
|
129 |
+
// or add a new message if its not a continue request
|
130 |
+
if (!prompt) {
|
131 |
+
throw new Error("Prompt is undefined");
|
132 |
+
}
|
133 |
+
messages = [
|
134 |
+
...messages,
|
135 |
+
{
|
136 |
+
from: "user",
|
137 |
+
content: prompt ?? "",
|
138 |
+
id: messageId,
|
139 |
+
files: resizedImages,
|
140 |
+
},
|
141 |
+
];
|
142 |
+
}
|
143 |
|
144 |
files = [];
|
145 |
|
|
|
147 |
method: "POST",
|
148 |
headers: { "Content-Type": "application/json" },
|
149 |
body: JSON.stringify({
|
150 |
+
inputs: prompt,
|
151 |
id: messageId,
|
152 |
is_retry: isRetry,
|
153 |
+
is_continue: isContinue,
|
154 |
web_search: $webSearchParameters.useSearch,
|
155 |
files: isRetry ? undefined : resizedImages,
|
156 |
}),
|
|
|
315 |
// only used in case of creating new conversations (from the parent POST endpoint)
|
316 |
if ($pendingMessage) {
|
317 |
files = $pendingMessage.files;
|
318 |
+
await writeMessage({ prompt: $pendingMessage.content });
|
319 |
$pendingMessage = undefined;
|
320 |
}
|
321 |
});
|
322 |
|
323 |
async function onMessage(event: CustomEvent<string>) {
|
324 |
if (!data.shared) {
|
325 |
+
await writeMessage({ prompt: event.detail });
|
326 |
} else {
|
327 |
+
await convFromShared()
|
328 |
.then(async (convId) => {
|
329 |
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
|
330 |
})
|
331 |
+
.then(async () => await writeMessage({ prompt: event.detail }))
|
332 |
.finally(() => (loading = false));
|
333 |
}
|
334 |
}
|
335 |
|
336 |
async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) {
|
337 |
if (!data.shared) {
|
338 |
+
await writeMessage({
|
339 |
+
prompt: event.detail.content,
|
340 |
+
messageId: event.detail.id,
|
341 |
+
isRetry: true,
|
342 |
+
});
|
343 |
} else {
|
344 |
+
await convFromShared()
|
345 |
.then(async (convId) => {
|
346 |
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
|
347 |
})
|
348 |
+
.then(
|
349 |
+
async () =>
|
350 |
+
await writeMessage({
|
351 |
+
prompt: event.detail.content,
|
352 |
+
messageId: event.detail.id,
|
353 |
+
isRetry: true,
|
354 |
+
})
|
355 |
+
)
|
356 |
.finally(() => (loading = false));
|
357 |
}
|
358 |
}
|
359 |
|
360 |
+
async function onContinue(event: CustomEvent<{ id: Message["id"] }>) {
|
361 |
+
if (!data.shared) {
|
362 |
+
writeMessage({ messageId: event.detail.id, isContinue: true });
|
363 |
+
}
|
364 |
+
}
|
365 |
+
|
366 |
$: $page.params.id, (($isAborted = true), (loading = false));
|
367 |
$: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title;
|
368 |
</script>
|
|
|
387 |
bind:files
|
388 |
on:message={onMessage}
|
389 |
on:retry={onRetry}
|
390 |
+
on:continue={onContinue}
|
391 |
on:vote={(event) => voteMessage(event.detail.score, event.detail.id)}
|
392 |
on:share={() => shareConversation($page.params.id, data.title)}
|
393 |
on:stop={() => (($isAborted = true), (loading = false))}
|
src/routes/conversation/[id]/+server.ts
CHANGED
@@ -91,14 +91,16 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
91 |
const {
|
92 |
inputs: newPrompt,
|
93 |
id: messageId,
|
94 |
-
is_retry,
|
|
|
95 |
web_search: webSearch,
|
96 |
files: b64files,
|
97 |
} = z
|
98 |
.object({
|
99 |
-
inputs: z.string().trim().min(1),
|
100 |
id: z.optional(z.string().uuid()),
|
101 |
is_retry: z.optional(z.boolean()),
|
|
|
102 |
web_search: z.optional(z.boolean()),
|
103 |
files: z.optional(z.array(z.string())),
|
104 |
})
|
@@ -136,38 +138,50 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
136 |
hashes = await Promise.all(files.map(async (file) => await uploadFile(file, conv)));
|
137 |
}
|
138 |
|
|
|
|
|
|
|
|
|
|
|
139 |
// get the list of messages
|
140 |
// while checking for retries
|
141 |
let messages = (() => {
|
142 |
-
|
|
|
143 |
// if the message is a retry, replace the message and remove the messages after it
|
144 |
let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
|
|
|
145 |
if (retryMessageIdx === -1) {
|
146 |
retryMessageIdx = conv.messages.length;
|
147 |
}
|
|
|
148 |
return [
|
149 |
...conv.messages.slice(0, retryMessageIdx),
|
150 |
{
|
151 |
-
content:
|
152 |
from: "user",
|
153 |
id: messageId as Message["id"],
|
154 |
updatedAt: new Date(),
|
155 |
files: conv.messages[retryMessageIdx]?.files,
|
156 |
},
|
157 |
];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
} // else append the message at the bottom
|
159 |
-
|
160 |
-
return [
|
161 |
-
...conv.messages,
|
162 |
-
{
|
163 |
-
content: newPrompt,
|
164 |
-
from: "user",
|
165 |
-
id: (messageId as Message["id"]) || crypto.randomUUID(),
|
166 |
-
createdAt: new Date(),
|
167 |
-
updatedAt: new Date(),
|
168 |
-
files: hashes,
|
169 |
-
},
|
170 |
-
];
|
171 |
})() satisfies Message[];
|
172 |
|
173 |
await collections.conversations.updateOne(
|
@@ -183,10 +197,14 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
183 |
}
|
184 |
);
|
185 |
|
|
|
|
|
186 |
// we now build the stream
|
187 |
const stream = new ReadableStream({
|
188 |
async start(controller) {
|
189 |
-
const updates: MessageUpdate[] =
|
|
|
|
|
190 |
|
191 |
function update(newUpdate: MessageUpdate) {
|
192 |
if (newUpdate.type !== "stream") {
|
@@ -209,7 +227,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
209 |
const summarizeIfNeeded = (async () => {
|
210 |
if (conv.title === "New Chat" && messages.length === 1) {
|
211 |
try {
|
212 |
-
conv.title = (await summarize(
|
213 |
update({ type: "status", status: "title", message: conv.title });
|
214 |
} catch (e) {
|
215 |
console.error(e);
|
@@ -232,17 +250,22 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
232 |
|
233 |
let webSearchResults: WebSearch | undefined;
|
234 |
|
235 |
-
if (webSearch) {
|
236 |
-
webSearchResults = await runWebSearch(conv,
|
|
|
|
|
|
|
237 |
}
|
238 |
|
239 |
-
messages[messages.length - 1].webSearch = webSearchResults;
|
240 |
-
|
241 |
conv.messages = messages;
|
242 |
|
|
|
|
|
|
|
|
|
243 |
try {
|
244 |
const endpoint = await model.getEndpoint();
|
245 |
-
for await (const output of await endpoint({ conversation: conv })) {
|
246 |
// if not generated_text is here it means the generation is not done
|
247 |
if (!output.generated_text) {
|
248 |
// else we get the next token
|
@@ -292,7 +315,8 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
292 |
...messages.slice(0, -1),
|
293 |
{
|
294 |
...messages[messages.length - 1],
|
295 |
-
content: output.generated_text,
|
|
|
296 |
updates,
|
297 |
updatedAt: new Date(),
|
298 |
},
|
@@ -302,6 +326,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
302 |
} catch (e) {
|
303 |
update({ type: "status", status: "error", message: (e as Error).message });
|
304 |
}
|
|
|
305 |
await collections.conversations.updateOne(
|
306 |
{
|
307 |
_id: convId,
|
@@ -315,6 +340,9 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
315 |
}
|
316 |
);
|
317 |
|
|
|
|
|
|
|
318 |
update({
|
319 |
type: "finalAnswer",
|
320 |
text: messages[messages.length - 1].content,
|
@@ -324,18 +352,20 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
324 |
return;
|
325 |
},
|
326 |
async cancel() {
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
{
|
332 |
-
$set: {
|
333 |
-
messages,
|
334 |
-
title: conv.title,
|
335 |
-
updatedAt: new Date(),
|
336 |
},
|
337 |
-
|
338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
},
|
340 |
});
|
341 |
|
|
|
91 |
const {
|
92 |
inputs: newPrompt,
|
93 |
id: messageId,
|
94 |
+
is_retry: isRetry,
|
95 |
+
is_continue: isContinue,
|
96 |
web_search: webSearch,
|
97 |
files: b64files,
|
98 |
} = z
|
99 |
.object({
|
100 |
+
inputs: z.optional(z.string().trim().min(1)),
|
101 |
id: z.optional(z.string().uuid()),
|
102 |
is_retry: z.optional(z.boolean()),
|
103 |
+
is_continue: z.optional(z.boolean()),
|
104 |
web_search: z.optional(z.boolean()),
|
105 |
files: z.optional(z.array(z.string())),
|
106 |
})
|
|
|
138 |
hashes = await Promise.all(files.map(async (file) => await uploadFile(file, conv)));
|
139 |
}
|
140 |
|
141 |
+
// can only call isContinue on the last message id
|
142 |
+
if (isContinue && conv.messages[conv.messages.length - 1].id !== messageId) {
|
143 |
+
throw error(400, "Can only continue the last message");
|
144 |
+
}
|
145 |
+
|
146 |
// get the list of messages
|
147 |
// while checking for retries
|
148 |
let messages = (() => {
|
149 |
+
// for retries we slice and rewrite the last user message
|
150 |
+
if (isRetry && messageId) {
|
151 |
// if the message is a retry, replace the message and remove the messages after it
|
152 |
let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
|
153 |
+
|
154 |
if (retryMessageIdx === -1) {
|
155 |
retryMessageIdx = conv.messages.length;
|
156 |
}
|
157 |
+
|
158 |
return [
|
159 |
...conv.messages.slice(0, retryMessageIdx),
|
160 |
{
|
161 |
+
content: conv.messages[retryMessageIdx]?.content,
|
162 |
from: "user",
|
163 |
id: messageId as Message["id"],
|
164 |
updatedAt: new Date(),
|
165 |
files: conv.messages[retryMessageIdx]?.files,
|
166 |
},
|
167 |
];
|
168 |
+
} else if (isContinue && messageId) {
|
169 |
+
// for continue we do nothing and expand the last assistant message
|
170 |
+
return conv.messages;
|
171 |
+
} else {
|
172 |
+
// in normal conversation we add an extra user message
|
173 |
+
return [
|
174 |
+
...conv.messages,
|
175 |
+
{
|
176 |
+
content: newPrompt ?? "",
|
177 |
+
from: "user",
|
178 |
+
id: (messageId as Message["id"]) || crypto.randomUUID(),
|
179 |
+
createdAt: new Date(),
|
180 |
+
updatedAt: new Date(),
|
181 |
+
files: hashes,
|
182 |
+
},
|
183 |
+
];
|
184 |
} // else append the message at the bottom
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
})() satisfies Message[];
|
186 |
|
187 |
await collections.conversations.updateOne(
|
|
|
197 |
}
|
198 |
);
|
199 |
|
200 |
+
let doneStreaming = false;
|
201 |
+
|
202 |
// we now build the stream
|
203 |
const stream = new ReadableStream({
|
204 |
async start(controller) {
|
205 |
+
const updates: MessageUpdate[] = isContinue
|
206 |
+
? conv.messages[conv.messages.length - 1].updates ?? []
|
207 |
+
: [];
|
208 |
|
209 |
function update(newUpdate: MessageUpdate) {
|
210 |
if (newUpdate.type !== "stream") {
|
|
|
227 |
const summarizeIfNeeded = (async () => {
|
228 |
if (conv.title === "New Chat" && messages.length === 1) {
|
229 |
try {
|
230 |
+
conv.title = (await summarize(messages[0].content)) ?? conv.title;
|
231 |
update({ type: "status", status: "title", message: conv.title });
|
232 |
} catch (e) {
|
233 |
console.error(e);
|
|
|
250 |
|
251 |
let webSearchResults: WebSearch | undefined;
|
252 |
|
253 |
+
if (webSearch && !isContinue) {
|
254 |
+
webSearchResults = await runWebSearch(conv, messages[messages.length - 1].content, update);
|
255 |
+
messages[messages.length - 1].webSearch = webSearchResults;
|
256 |
+
} else if (isContinue) {
|
257 |
+
webSearchResults = messages[messages.length - 1].webSearch;
|
258 |
}
|
259 |
|
|
|
|
|
260 |
conv.messages = messages;
|
261 |
|
262 |
+
const previousContent = isContinue
|
263 |
+
? conv.messages.find((message) => message.id === messageId)?.content ?? ""
|
264 |
+
: "";
|
265 |
+
|
266 |
try {
|
267 |
const endpoint = await model.getEndpoint();
|
268 |
+
for await (const output of await endpoint({ conversation: conv, continue: isContinue })) {
|
269 |
// if not generated_text is here it means the generation is not done
|
270 |
if (!output.generated_text) {
|
271 |
// else we get the next token
|
|
|
315 |
...messages.slice(0, -1),
|
316 |
{
|
317 |
...messages[messages.length - 1],
|
318 |
+
content: previousContent + output.generated_text,
|
319 |
+
interrupted: !output.token.special, // if its a special token it finished on its own, else it was interrupted
|
320 |
updates,
|
321 |
updatedAt: new Date(),
|
322 |
},
|
|
|
326 |
} catch (e) {
|
327 |
update({ type: "status", status: "error", message: (e as Error).message });
|
328 |
}
|
329 |
+
|
330 |
await collections.conversations.updateOne(
|
331 |
{
|
332 |
_id: convId,
|
|
|
340 |
}
|
341 |
);
|
342 |
|
343 |
+
// used to detect if cancel() is called bc of interrupt or just because the connection closes
|
344 |
+
doneStreaming = true;
|
345 |
+
|
346 |
update({
|
347 |
type: "finalAnswer",
|
348 |
text: messages[messages.length - 1].content,
|
|
|
352 |
return;
|
353 |
},
|
354 |
async cancel() {
|
355 |
+
if (!doneStreaming) {
|
356 |
+
await collections.conversations.updateOne(
|
357 |
+
{
|
358 |
+
_id: convId,
|
|
|
|
|
|
|
|
|
|
|
359 |
},
|
360 |
+
{
|
361 |
+
$set: {
|
362 |
+
messages,
|
363 |
+
title: conv.title,
|
364 |
+
updatedAt: new Date(),
|
365 |
+
},
|
366 |
+
}
|
367 |
+
);
|
368 |
+
}
|
369 |
},
|
370 |
});
|
371 |
|