Liam Dyer committed on
Commit
8583cf1
1 Parent(s): 999407a

feat: smooth and combine token output (#936)

Browse files

* feat: smooth and combine token output

* fix: stop generating button not triggering message updates abort

src/lib/utils/messageUpdates.ts ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { MessageUpdate, TextStreamUpdate } from "$lib/types/MessageUpdate";
2
+
3
/** Options sent to the conversation endpoint when requesting message updates. */
type MessageUpdateRequestOptions = {
  /** URL prefix the `/conversation/<id>` path is appended to */
  base: string;
  /** Sent as `inputs` in the request body */
  inputs?: string;
  /** Sent as `id` in the request body */
  messageId?: string;
  /** Sent as `is_retry` in the request body */
  isRetry: boolean;
  /** Sent as `is_continue` in the request body */
  isContinue: boolean;
  /** Sent as `web_search` in the request body */
  webSearch: boolean;
  /** Sent as `files` in the request body */
  files?: string[];
};

/**
 * POSTs to the conversation endpoint and returns the streamed message updates,
 * with stream tokens combined into words and their emission smoothed over time.
 *
 * @param conversationId Conversation the request is posted to
 * @param opts Request options, see {@link MessageUpdateRequestOptions}
 * @param abortSignal Aborting this signal aborts the request and the stream reading
 * @returns An async generator of message updates
 * @throws Error when the response status is not OK (using the `message` of the JSON
 * error body when it has one) or when the response has no body
 */
export async function fetchMessageUpdates(
  conversationId: string,
  opts: MessageUpdateRequestOptions,
  abortSignal: AbortSignal
): Promise<AsyncGenerator<MessageUpdate>> {
  // Internal controller: it is aborted by the caller's signal, and also by the
  // stream reader once the stream is finished
  const abortController = new AbortController();
  if (abortSignal.aborted) {
    // An "abort" listener would never fire for a signal that is already aborted
    abortController.abort();
  } else {
    abortSignal.addEventListener("abort", () => abortController.abort(), { once: true });
  }

  const response = await fetch(`${opts.base}/conversation/${conversationId}`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      inputs: opts.inputs,
      id: opts.messageId,
      is_retry: opts.isRetry,
      is_continue: opts.isContinue,
      web_search: opts.webSearch,
      files: opts.files,
    }),
    signal: abortController.signal,
  });

  if (!response.ok) {
    let errorMessage = `Request failed with status code ${response.status}: ${response.statusText}`;
    try {
      const payload: unknown = await response.json();
      const message = (payload as { message?: unknown } | null)?.message;
      // Only trust a non-empty string, otherwise keep the status based message
      if (typeof message === "string" && message.length > 0) errorMessage = message;
    } catch {
      // Error body is not JSON: keep the status based message
    }
    throw Error(errorMessage);
  }
  if (!response.body) {
    throw Error("Body not defined");
  }
  return smoothAsyncIterator(
    streamMessageUpdatesToFullWords(endpointStreamToIterator(response, abortController))
  );
}
48
+
49
/**
 * Reads the response body as text and yields every message update found in it.
 * The body is expected to contain one JSON encoded update per line.
 *
 * The abort controller works in both directions: aborting it cancels the reader,
 * and it is aborted once reading stops (stream finished, read error, or the
 * consumer stopped iterating).
 *
 * @param response Response whose body is read
 * @param abortController Controller used to stop reading
 * @throws Error when the response has no body
 */
async function* endpointStreamToIterator(
  response: Response,
  abortController: AbortController
): AsyncGenerator<MessageUpdate> {
  const reader = response.body?.pipeThrough(new TextDecoderStream()).getReader();
  if (!reader) throw Error("Response for endpoint had no body");

  // Best effort cancellation, a failing cancel must not become an unhandled rejection
  const cancelReader = () => {
    reader.cancel().catch(() => undefined);
  };

  // Handle any cases where we must abort.
  // A rejection of `closed` is ignored here: the same failure rejects `read()` below
  reader.closed.then(
    () => abortController.abort(),
    () => undefined
  );

  // Handle logic for aborting
  if (abortController.signal.aborted) {
    // The listener would never fire for a signal that is already aborted
    cancelReader();
  } else {
    abortController.signal.addEventListener("abort", cancelReader, { once: true });
  }

  // ex) If the last chunk ends with => {"type": "stream", "token":
  // then the next chunk is appended to it => {"type": "stream", "token": "Hello"}
  let prevChunk = "";
  try {
    while (!abortController.signal.aborted) {
      const { done, value } = await reader.read();
      if (done) break;
      if (!value) continue;

      const { messageUpdates, remainingText } = parseMessageUpdates(prevChunk + value);
      prevChunk = remainingText;
      for (const messageUpdate of messageUpdates) yield messageUpdate;
    }
  } finally {
    // Runs on normal completion, on a read error and when the consumer returns early
    abortController.abort();
  }
}
78
+
79
/**
 * Parses a single line as a JSON encoded message update.
 * Returns undefined when the line is not valid JSON or is not a JSON object/array.
 */
function tryParseMessageUpdate(line: string): MessageUpdate | undefined {
  try {
    const parsed: unknown = JSON.parse(line);
    if (typeof parsed !== "object" || parsed === null) return undefined;
    return parsed as MessageUpdate;
  } catch {
    return undefined;
  }
}

/**
 * Parses newline separated JSON message updates.
 *
 * Only the text after the last newline can be an incomplete update, so it is the
 * only part that may be returned as `remainingText`. Blank or malformed lines
 * before it are skipped instead of stopping the parsing, so the complete updates
 * that follow them are not lost.
 *
 * @param value Text to parse, usually the previous remaining text plus a new chunk
 * @returns The parsed updates, and the trailing text that could not be parsed yet
 */
function parseMessageUpdates(value: string): {
  messageUpdates: MessageUpdate[];
  remainingText: string;
} {
  const inputs = value.split("\n");
  // `split` always returns at least one element, the fallback only satisfies the type
  const lastInput = inputs.pop() ?? "";

  const messageUpdates: MessageUpdate[] = [];
  for (const input of inputs) {
    const messageUpdate = tryParseMessageUpdate(input);
    if (messageUpdate !== undefined) messageUpdates.push(messageUpdate);
  }

  // Value ended with a newline: nothing is left over
  if (lastInput.trim() === "") return { messageUpdates, remainingText: "" };

  const lastMessageUpdate = tryParseMessageUpdate(lastInput);
  if (lastMessageUpdate === undefined) {
    // Incomplete update, it is completed by the next chunk
    return { messageUpdates, remainingText: lastInput };
  }
  messageUpdates.push(lastMessageUpdate);
  return { messageUpdates, remainingText: "" };
}
100
+
101
/**
 * Combines "stream" type message updates into full words before emitting them
 * Example: "what" " don" "'t" => "what" " don't"
 * Only supports latin languages, ignores others
 *
 * Updates that aren't "stream" type are emitted as soon as they arrive, right after
 * the stream tokens buffered so far. Flushing first keeps the updates in order, so a
 * consumer that stops on a non-stream update (e.g. the final answer) has already
 * received all the text that came before it.
 *
 * @param iterator Source of message updates
 */
async function* streamMessageUpdatesToFullWords(
  iterator: AsyncGenerator<MessageUpdate>
): AsyncGenerator<MessageUpdate> {
  let bufferedStreamUpdates: TextStreamUpdate[] = [];

  const endAlphanumeric = /[a-zA-Z0-9À-ž'`]+$/;
  const beginningAlphanumeric = /^[a-zA-Z0-9À-ž'`]+/;
  // Upper bound on the number of tokens combined, so that very long words still stream
  const maxCombinedTokens = 5;

  const joinTokens = (updates: TextStreamUpdate[]): string =>
    updates.map((update) => update.token).join("");

  for await (const messageUpdate of iterator) {
    if (messageUpdate.type !== "stream") {
      // Emit the buffered text first so that it is not overtaken by this update
      if (bufferedStreamUpdates.length > 0) {
        yield { type: "stream", token: joinTokens(bufferedStreamUpdates) };
        bufferedStreamUpdates = [];
      }
      yield messageUpdate;
      continue;
    }
    bufferedStreamUpdates.push(messageUpdate);

    let lastIndexEmitted = 0;
    for (let i = 1; i < bufferedStreamUpdates.length; i++) {
      const prevEndsAlphanumeric = endAlphanumeric.test(bufferedStreamUpdates[i - 1].token);
      const currBeginsAlphanumeric = beginningAlphanumeric.test(bufferedStreamUpdates[i].token);
      const shouldCombine = prevEndsAlphanumeric && currBeginsAlphanumeric;
      const combinedTooMany = i - lastIndexEmitted >= maxCombinedTokens;
      if (shouldCombine && !combinedTooMany) continue;

      // Word boundary reached: combine the tokens before it and emit
      yield {
        type: "stream",
        token: joinTokens(bufferedStreamUpdates.slice(lastIndexEmitted, i)),
      };
      lastIndexEmitted = i;
    }
    // Keep the tokens of the word that may still be continued
    bufferedStreamUpdates = bufferedStreamUpdates.slice(lastIndexEmitted);
  }

  // Source is finished, emit what is left as one word
  if (bufferedStreamUpdates.length > 0) {
    yield { type: "stream", token: joinTokens(bufferedStreamUpdates) };
  }
}
144
+
145
/**
 * Attempts to smooth out the time between values emitted by an async iterator
 * by waiting for the average time between values to emit the next value
 *
 * Values are read from the source iterator in the background as fast as it
 * produces them and are buffered until they are emitted. An error thrown by the
 * source iterator is rethrown once the values buffered before it were emitted.
 *
 * @param iterator Source iterator
 */
async function* smoothAsyncIterator<T>(iterator: AsyncGenerator<T>): AsyncGenerator<T> {
  // Only consider the last X times between tokens
  const sampleSize = 30;
  // Gaps between values longer than this are not counted in the average
  const anomalyThresholdMS = 2000;
  const maxDelayMS = 200;
  const minDelayMS = 5;

  const eventTarget = new EventTarget();
  let done = false;
  let failure: { error: unknown } | undefined;
  const valuesBuffer: T[] = [];
  const valueTimesMS: number[] = [];

  // Reads the source in the background, signalling every result with a "next" event
  const next = async () => {
    try {
      const obj = await iterator.next();
      if (obj.done) {
        done = true;
      } else {
        valuesBuffer.push(obj.value);
        valueTimesMS.push(performance.now());
        // Older times are never sampled, drop them
        if (valueTimesMS.length > sampleSize) valueTimesMS.shift();
        void next();
      }
    } catch (error) {
      // Without this the loop below would wait forever for `done`
      failure = { error };
      done = true;
    }
    eventTarget.dispatchEvent(new Event("next"));
  };
  void next();

  let timeOfLastEmitMS = performance.now();
  while (!done || valuesBuffer.length > 0) {
    const sampledTimesMS = valueTimesMS.slice(-sampleSize);

    // An average needs at least two values, until then only the minimum delay applies
    let averageTimeMSBetweenValues = 0;
    if (sampledTimesMS.length >= 2) {
      // Get the total time spent in abnormal periods
      let anomalyDurationMS = 0;
      for (let i = 1; i < sampledTimesMS.length; i++) {
        const gapMS = sampledTimesMS[i] - sampledTimesMS[i - 1];
        if (gapMS > anomalyThresholdMS) anomalyDurationMS += gapMS;
      }

      const totalTimeMSBetweenValues =
        sampledTimesMS[sampledTimesMS.length - 1] - sampledTimesMS[0];
      const timeMSBetweenValues = totalTimeMSBetweenValues - anomalyDurationMS;

      averageTimeMSBetweenValues = Math.min(
        maxDelayMS,
        timeMSBetweenValues / (sampledTimesMS.length - 1)
      );
    }
    const timeSinceLastEmitMS = performance.now() - timeOfLastEmitMS;

    // Emit after waiting duration or cancel if "next" event is emitted
    const waitController = new AbortController();
    const gotNext = await Promise.race([
      sleep(
        Math.max(minDelayMS, averageTimeMSBetweenValues - timeSinceLastEmitMS),
        waitController.signal
      ),
      waitForEvent(eventTarget, "next", waitController.signal),
    ]);
    // Clean up the timer or the listener of the side that lost the race
    waitController.abort();

    // Go to next iteration so we can re-calculate when to emit
    if (gotNext) continue;

    // Nothing in buffer to emit
    if (valuesBuffer.length === 0) continue;

    // Emit
    timeOfLastEmitMS = performance.now();
    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
    yield valuesBuffer.shift()!;
  }

  if (failure) throw failure.error;
}

/**
 * Resolves with undefined after `ms` milliseconds.
 * When `signal` is aborted the timer is cleared and the promise resolves right away.
 */
const sleep = (ms: number, signal?: AbortSignal) =>
  new Promise((resolve) => {
    if (signal?.aborted) {
      resolve(undefined);
      return;
    }
    const timer = setTimeout(resolve, ms);
    signal?.addEventListener(
      "abort",
      () => {
        clearTimeout(timer);
        resolve(undefined);
      },
      { once: true }
    );
  });

/**
 * Resolves with true the next time `eventName` is dispatched on `eventTarget`.
 * When `signal` is aborted the listener is removed and the promise stays pending.
 */
const waitForEvent = (eventTarget: EventTarget, eventName: string, signal?: AbortSignal) =>
  new Promise<boolean>((resolve) =>
    eventTarget.addEventListener(eventName, () => resolve(true), { once: true, signal })
  );
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -16,6 +16,7 @@
16
  import file2base64 from "$lib/utils/file2base64";
17
  import { addChildren } from "$lib/utils/tree/addChildren";
18
  import { addSibling } from "$lib/utils/tree/addSibling";
 
19
  import { createConvTreeStore } from "$lib/stores/convTree";
20
  import type { v4 } from "uuid";
21
 
@@ -181,125 +182,71 @@
181
 
182
  messages = [...messages];
183
  const messageToWriteTo = messages.find((message) => message.id === messageToWriteToId);
184
-
185
  if (!messageToWriteTo) {
186
  throw new Error("Message to write to not found");
187
  }
 
188
  // disable websearch if assistant is present
189
  const hasAssistant = !!$page.data.assistant;
190
-
191
- const response = await fetch(`${base}/conversation/${$page.params.id}`, {
192
- method: "POST",
193
- headers: { "Content-Type": "application/json" },
194
- body: JSON.stringify({
195
  inputs: prompt,
196
- id: messageId,
197
- is_retry: isRetry,
198
- is_continue: isContinue,
199
- web_search: !hasAssistant && $webSearchParameters.useSearch,
200
  files: isRetry ? undefined : resizedImages,
201
- }),
 
 
 
202
  });
 
203
 
204
  files = [];
205
- if (!response.body) {
206
- throw new Error("Body not defined");
207
- }
208
-
209
- if (!response.ok) {
210
- error.set((await response.json())?.message);
211
- return;
212
- }
213
 
214
- // eslint-disable-next-line no-undef
215
- const encoder = new TextDecoderStream();
216
- const reader = response?.body?.pipeThrough(encoder).getReader();
217
- let finalAnswer = "";
218
  const messageUpdates: MessageUpdate[] = [];
219
-
220
- // set str queue
221
- // ex) if the last response is => {"type": "stream", "token":
222
- // It should be => {"type": "stream", "token": "Hello"} = prev_input_chunk + "Hello"}
223
- let prev_input_chunk = [""];
224
-
225
- // this is a bit ugly
226
- // we read the stream until we get the final answer
227
-
228
- let readerClosed = false;
229
-
230
- reader.closed.then(() => {
231
- readerClosed = true;
232
- });
233
-
234
- while (finalAnswer === "") {
235
- // check for abort
236
- if ($isAborted || $error || readerClosed) {
237
- reader?.cancel();
238
  break;
239
  }
240
 
241
- // if there is something to read
242
- await reader?.read().then(async ({ done, value }) => {
243
- // we read, if it's done we cancel
244
- if (done) {
245
- reader.cancel();
246
- }
247
-
248
- if (!value) {
249
- return;
250
- }
251
-
252
- value = prev_input_chunk.pop() + value;
253
-
254
- // if it's not done we parse the value, which contains all messages
255
- const inputs = value.split("\n");
256
- inputs.forEach(async (el: string) => {
257
- try {
258
- const update = JSON.parse(el) as MessageUpdate;
259
-
260
- if (update.type !== "stream") {
261
- messageUpdates.push(update);
262
- }
263
-
264
- if (update.type === "finalAnswer") {
265
- finalAnswer = update.text;
266
- loading = false;
267
- pending = false;
268
- } else if (update.type === "stream") {
269
- pending = false;
270
- messageToWriteTo.content += update.token;
271
- messages = [...messages];
272
- } else if (update.type === "webSearch") {
273
- messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
274
- messages = [...messages];
275
- } else if (update.type === "status") {
276
- if (update.status === "title" && update.message) {
277
- const convInData = data.conversations.find(({ id }) => id === $page.params.id);
278
- if (convInData) {
279
- convInData.title = update.message;
280
-
281
- $titleUpdate = {
282
- title: update.message,
283
- convId: $page.params.id,
284
- };
285
- }
286
- } else if (update.status === "error") {
287
- $error = update.message ?? "An error has occurred";
288
- }
289
- } else if (update.type === "error") {
290
- error.set(update.message);
291
- reader.cancel();
292
- }
293
- } catch (parseError) {
294
- // in case of parsing error we wait for the next message
295
-
296
- if (el === inputs[inputs.length - 1]) {
297
- prev_input_chunk.push(el);
298
- }
299
- return;
300
  }
301
- });
302
- });
 
 
 
 
 
303
  }
304
 
305
  messageToWriteTo.updates = messageUpdates;
 
16
  import file2base64 from "$lib/utils/file2base64";
17
  import { addChildren } from "$lib/utils/tree/addChildren";
18
  import { addSibling } from "$lib/utils/tree/addSibling";
19
+ import { fetchMessageUpdates } from "$lib/utils/messageUpdates";
20
  import { createConvTreeStore } from "$lib/stores/convTree";
21
  import type { v4 } from "uuid";
22
 
 
182
 
183
  messages = [...messages];
184
  const messageToWriteTo = messages.find((message) => message.id === messageToWriteToId);
 
185
  if (!messageToWriteTo) {
186
  throw new Error("Message to write to not found");
187
  }
188
+
189
  // disable websearch if assistant is present
190
  const hasAssistant = !!$page.data.assistant;
191
+ const messageUpdatesAbortController = new AbortController();
192
+ const messageUpdatesIterator = await fetchMessageUpdates(
193
+ $page.params.id,
194
+ {
195
+ base,
196
  inputs: prompt,
197
+ messageId,
198
+ isRetry,
199
+ isContinue,
200
+ webSearch: !hasAssistant && $webSearchParameters.useSearch,
201
  files: isRetry ? undefined : resizedImages,
202
+ },
203
+ messageUpdatesAbortController.signal
204
+ ).catch((err) => {
205
+ error.set(err.message);
206
  });
207
+ if (messageUpdatesIterator === undefined) return;
208
 
209
  files = [];
 
 
 
 
 
 
 
 
210
 
 
 
 
 
211
  const messageUpdates: MessageUpdate[] = [];
212
+ for await (const update of messageUpdatesIterator) {
213
+ if ($isAborted) {
214
+ messageUpdatesAbortController.abort();
215
+ return;
216
+ }
217
+ if (update.type === "finalAnswer") {
218
+ loading = false;
219
+ pending = false;
 
 
 
 
 
 
 
 
 
 
 
220
  break;
221
  }
222
 
223
+ messageUpdates.push(update);
224
+
225
+ if (update.type === "stream") {
226
+ pending = false;
227
+ messageToWriteTo.content += update.token;
228
+ messages = [...messages];
229
+ } else if (update.type === "webSearch") {
230
+ messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
231
+ messages = [...messages];
232
+ } else if (update.type === "status") {
233
+ if (update.status === "title" && update.message) {
234
+ const convInData = data.conversations.find(({ id }) => id === $page.params.id);
235
+ if (convInData) {
236
+ convInData.title = update.message;
237
+
238
+ $titleUpdate = {
239
+ title: update.message,
240
+ convId: $page.params.id,
241
+ };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  }
243
+ } else if (update.status === "error") {
244
+ $error = update.message ?? "An error has occurred";
245
+ }
246
+ } else if (update.type === "error") {
247
+ error.set(update.message);
248
+ messageUpdatesAbortController.abort();
249
+ }
250
  }
251
 
252
  messageToWriteTo.updates = messageUpdates;