Add configurable embedding models, from both transformers.js and TEI (#646)
* Add configurable embedding models, from both Xenova and TEI
* fix lint and format
* Fix bug in sentenceSimilarity
* Batches for TEI using /info route
* Fix web search disappearing when search finishes
* Fix lint and format
* Add more options for better embedding model usage
* Fixing CR issues
* Fix websearch disappearing in a later PR
* Fix lint
* Fix more minor code CR
* Validate `embeddingModelName` field in model config
* Add embeddingModel into shared conversation
* Fix lint and format
* Add default embedding model, and more readme explanation
* Fix minor embedding model readme details
* Update settings.json
* Update README.md
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
* Update README.md
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
* Apply suggestions from code review
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
* Resolved more issues
* lint
* Fix more issues
* Fix format
* fix small typo
* lint
* fix default model
* Rename `maxSequenceLength` -> `chunkCharLength`
* format
* add "authorization" example
* format
---------
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
- .env +12 -0
- .env.template +0 -1
- README.md +84 -4
- src/lib/components/OpenWebSearchResults.svelte +2 -2
- src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +65 -0
- src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts +46 -0
- src/lib/server/embeddingModels.ts +99 -0
- src/lib/server/models.ts +2 -0
- src/lib/server/{websearch/sentenceSimilarity.ts → sentenceSimilarity.ts} +18 -28
- src/lib/server/websearch/runWebSearch.ts +12 -6
- src/lib/types/Conversation.ts +1 -0
- src/lib/types/EmbeddingEndpoints.ts +41 -0
- src/lib/types/SharedConversation.ts +2 -0
- src/routes/conversation/+server.ts +6 -0
- src/routes/conversation/[id]/+page.svelte +2 -1
- src/routes/conversation/[id]/share/+server.ts +1 -0
- src/routes/login/callback/updateUser.spec.ts +2 -0
.env
@@ -46,6 +46,18 @@ CA_PATH=#
 CLIENT_KEY_PASSWORD=#
 REJECT_UNAUTHORIZED=true
 
+TEXT_EMBEDDING_MODELS = `[
+  {
+    "name": "Xenova/gte-small",
+    "displayName": "Xenova/gte-small",
+    "description": "Local embedding model running on the server.",
+    "chunkCharLength": 512,
+    "endpoints": [
+      { "type": "transformersjs" }
+    ]
+  }
+]`
+
 # 'name', 'userMessageToken', 'assistantMessageToken' are required
 MODELS=`[
   {
.env.template
@@ -204,7 +204,6 @@ TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2'
 # "stop": ["</s>"]
 # }}`
 
-
 APP_BASE="/chat"
 PUBLIC_ORIGIN=https://huggingface.co
 PUBLIC_SHARE_PREFIX=https://hf.co/chat
README.md
@@ -20,9 +20,10 @@ A chat interface using open source models, eg OpenAssistant or Llama. It is a Sv
 1. [Setup](#setup)
 2. [Launch](#launch)
 3. [Web Search](#web-search)
-4. [Extra parameters](#extra-parameters)
-5. [Deploying to a HF Space](#deploying-to-a-hf-space)
-6. [Building](#building)
+4. [Text Embedding Models](#text-embedding-models)
+5. [Extra parameters](#extra-parameters)
+6. [Deploying to a HF Space](#deploying-to-a-hf-space)
+7. [Building](#building)
 
 ## No Setup Deploy
 
@@ -78,10 +79,50 @@ Chat UI features a powerful Web Search feature. It works by:
 
 1. Generating an appropriate search query from the user prompt.
 2. Performing web search and extracting content from webpages.
-3. Creating embeddings from texts using [transformers.js](https://huggingface.co/docs/transformers.js). Specifically, using the [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.
+3. Creating embeddings from texts using a text embedding model.
 4. From these embeddings, find the ones that are closest to the user query using a vector similarity search. Specifically, we use `inner product` distance.
 5. Get the corresponding texts to those closest embeddings and perform [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expand user prompt by adding those texts so that an LLM can use this information).
 
+## Text Embedding Models
+
+By default (for backward compatibility), when the `TEXT_EMBEDDING_MODELS` environment variable is not defined, [transformers.js](https://huggingface.co/docs/transformers.js) embedding models will be used for embedding tasks, specifically the [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.
+
+You can customize the embedding model by setting `TEXT_EMBEDDING_MODELS` in your `.env.local` file. For example:
+
+```env
+TEXT_EMBEDDING_MODELS = `[
+  {
+    "name": "Xenova/gte-small",
+    "displayName": "Xenova/gte-small",
+    "description": "locally running embedding",
+    "chunkCharLength": 512,
+    "endpoints": [
+      { "type": "transformersjs" }
+    ]
+  },
+  {
+    "name": "intfloat/e5-base-v2",
+    "displayName": "intfloat/e5-base-v2",
+    "description": "hosted embedding model",
+    "chunkCharLength": 768,
+    "preQuery": "query: ", # See https://huggingface.co/intfloat/e5-base-v2#faq
+    "prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq
+    "endpoints": [
+      {
+        "type": "tei",
+        "url": "http://127.0.0.1:8080/",
+        "authorization": "TOKEN_TYPE TOKEN" // optional authorization field. Example: "Basic VVNFUjpQQVNT"
+      }
+    ]
+  }
+]`
+```
+
+The required fields are `name`, `chunkCharLength` and `endpoints`.
+Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a separate environment and are accessed through an API endpoint.
+
+When more than one embedding model is supplied in the `.env.local` file, the first one is used by default; the others are only used by LLMs whose `embeddingModel` field is set to the name of that embedding model.
+
 ## Extra parameters
 
 ### OpenID connect
@@ -425,6 +466,45 @@ If you're using a certificate signed by a private CA, you will also need to add
 
 If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation, and allow Chat UI to connect to your custom endpoint.
 
+#### Specific Embedding Model
+
+A model can use any of the embedding models defined in `.env.local` (currently only used for web search).
+By default it will use the first embedding model, but this can be changed with the `embeddingModel` field:
+
+```env
+TEXT_EMBEDDING_MODELS = `[
+  {
+    "name": "Xenova/gte-small",
+    "chunkCharLength": 512,
+    "endpoints": [
+      { "type": "transformersjs" }
+    ]
+  },
+  {
+    "name": "intfloat/e5-base-v2",
+    "chunkCharLength": 768,
+    "endpoints": [
+      { "type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT" },
+      { "type": "tei", "url": "http://127.0.0.1:8081/" }
+    ]
+  }
+]`
+
+MODELS=`[
+  {
+    "name": "Ollama Mistral",
+    "chatPromptTemplate": "...",
+    "embeddingModel": "intfloat/e5-base-v2",
+    "parameters": {
+      ...
+    },
+    "endpoints": [
+      ...
+    ]
+  }
+]`
+```
+
 ## Deploying to a HF Space
 
 Create a `DOTENV_LOCAL` secret to your HF space with the content of your .env.local, and they will be picked up automatically when you run.
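Note: for intuition on step 4 of the web-search flow above — with L2-normalized embeddings, the hnswlib-style `inner product` distance is `1 - dot(a, b)`, so sorting by ascending distance ranks passages by cosine similarity to the query. A minimal standalone sketch (illustrative names, not chat-ui code):

```ts
// Minimal sketch of the similarity step; types and names are illustrative.
type Embedding = number[];

// dot product of two equal-length vectors
const dot = (a: Embedding, b: Embedding): number => a.reduce((sum, v, i) => sum + v * b[i], 0);

// hnswlib-style inner-product *distance*: lower means more similar
const innerProductDistance = (a: Embedding, b: Embedding): number => 1.0 - dot(a, b);

// return the indices of the topK passages closest to the query
function topKClosest(query: Embedding, passages: Embedding[], topK = 5): number[] {
	return passages
		.map((passage, index) => ({ distance: innerProductDistance(query, passage), index }))
		.sort((a, b) => a.distance - b.distance)
		.slice(0, topK)
		.map(({ index }) => index);
}
```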
src/lib/components/OpenWebSearchResults.svelte
@@ -30,8 +30,8 @@
 {:else}
 	<CarbonCheckmark class="my-auto text-gray-500" />
 {/if}
-<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}
-	>Web search
+<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}>
+	Web search
 </span>
 <div class="my-auto transition-all" class:rotate-90={detailsOpen}>
 	<CarbonCaretRight />
src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts (new file)
@@ -0,0 +1,65 @@
+import { z } from "zod";
+import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints";
+import { chunk } from "$lib/utils/chunk";
+
+export const embeddingEndpointTeiParametersSchema = z.object({
+	weight: z.number().int().positive().default(1),
+	model: z.any(),
+	type: z.literal("tei"),
+	url: z.string().url(),
+	authorization: z.string().optional(),
+});
+
+const getModelInfoByUrl = async (url: string, authorization?: string) => {
+	const { origin } = new URL(url);
+
+	const response = await fetch(`${origin}/info`, {
+		headers: {
+			Accept: "application/json",
+			"Content-Type": "application/json",
+			...(authorization ? { Authorization: authorization } : {}),
+		},
+	});
+
+	const json = await response.json();
+	return json;
+};
+
+export async function embeddingEndpointTei(
+	input: z.input<typeof embeddingEndpointTeiParametersSchema>
+): Promise<EmbeddingEndpoint> {
+	const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);
+
+	const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url, authorization);
+	const maxBatchSize = Math.min(
+		max_client_batch_size,
+		Math.floor(max_batch_tokens / model.chunkCharLength)
+	);
+
+	return async ({ inputs }) => {
+		const { origin } = new URL(url);
+
+		const batchesInputs = chunk(inputs, maxBatchSize);
+
+		const batchesResults = await Promise.all(
+			batchesInputs.map(async (batchInputs) => {
+				const response = await fetch(`${origin}/embed`, {
+					method: "POST",
+					headers: {
+						Accept: "application/json",
+						"Content-Type": "application/json",
+						...(authorization ? { Authorization: authorization } : {}),
+					},
+					body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
+				});
+
+				const embeddings: Embedding[] = await response.json();
+				return embeddings;
+			})
+		);
+
+		const flatAllEmbeddings = batchesResults.flat();
+
+		return flatAllEmbeddings;
+	};
+}
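Note: a rough usage sketch for the TEI endpoint above, assuming a TEI server is running locally. The URL, token, and `chunkCharLength` are placeholder values, and `model` is passed loosely because the schema declares it `z.any()`:

```ts
// Hypothetical wiring of a TEI endpoint; all values below are placeholders.
import { embeddingEndpointTei } from "$lib/server/embeddingEndpoints/tei/embeddingEndpoints";

const endpoint = await embeddingEndpointTei({
	type: "tei",
	url: "http://127.0.0.1:8080/",
	authorization: "Basic VVNFUjpQQVNT", // optional
	model: { chunkCharLength: 768 }, // schema is z.any(); only chunkCharLength is read here
});

// Large input arrays are split into batches of up to
// min(max_client_batch_size, max_batch_tokens / chunkCharLength), as reported by GET /info,
// and embedded via parallel POST {origin}/embed calls.
const embeddings = await endpoint({
	inputs: ["query: what is TEI?", "passage: TEI serves embedding models over HTTP."],
});
```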
src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts (new file)
@@ -0,0 +1,46 @@
+import { z } from "zod";
+import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints";
+import type { Tensor, Pipeline } from "@xenova/transformers";
+import { pipeline } from "@xenova/transformers";
+
+export const embeddingEndpointTransformersJSParametersSchema = z.object({
+	weight: z.number().int().positive().default(1),
+	model: z.any(),
+	type: z.literal("transformersjs"),
+});
+
+// Use the Singleton pattern to enable lazy construction of the pipeline.
+class TransformersJSModelsSingleton {
+	static instances: Array<[string, Promise<Pipeline>]> = [];
+
+	static async getInstance(modelName: string): Promise<Pipeline> {
+		const modelPipelineInstance = this.instances.find(([name]) => name === modelName);
+
+		if (modelPipelineInstance) {
+			const [, modelPipeline] = modelPipelineInstance;
+			return modelPipeline;
+		}
+
+		const newModelPipeline = pipeline("feature-extraction", modelName);
+		this.instances.push([modelName, newModelPipeline]);
+
+		return newModelPipeline;
+	}
+}
+
+export async function calculateEmbedding(modelName: string, inputs: string[]) {
+	const extractor = await TransformersJSModelsSingleton.getInstance(modelName);
+	const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true });
+
+	return output.tolist();
+}
+
+export function embeddingEndpointTransformersJS(
+	input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
+): EmbeddingEndpoint {
+	const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);
+
+	return async ({ inputs }) => {
+		return calculateEmbedding(model.name, inputs);
+	};
+}
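Note: a sketch of the local path. The first call to `calculateEmbedding` downloads and caches the ONNX weights through the singleton, so only the initial request pays the load cost (the example strings are illustrative):

```ts
// Hypothetical usage of the transformers.js path; input strings are illustrative.
import { calculateEmbedding } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";

const [queryEmbedding, passageEmbedding] = await calculateEmbedding("Xenova/gte-small", [
	"what is a text embedding?",
	"An embedding maps a text to a vector of numbers.",
]);

// gte-small produces 384-dimensional, L2-normalized vectors
console.log(queryEmbedding.length, passageEmbedding.length);
```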
src/lib/server/embeddingModels.ts (new file)
@@ -0,0 +1,99 @@
+import { TEXT_EMBEDDING_MODELS } from "$env/static/private";
+
+import { z } from "zod";
+import { sum } from "$lib/utils/sum";
+import {
+	embeddingEndpoints,
+	embeddingEndpointSchema,
+	type EmbeddingEndpoint,
+} from "$lib/types/EmbeddingEndpoints";
+import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
+
+const modelConfig = z.object({
+	/** Used as an identifier in DB */
+	id: z.string().optional(),
+	/** Used to link to the model page, and for inference */
+	name: z.string().min(1),
+	displayName: z.string().min(1).optional(),
+	description: z.string().min(1).optional(),
+	websiteUrl: z.string().url().optional(),
+	modelUrl: z.string().url().optional(),
+	endpoints: z.array(embeddingEndpointSchema).nonempty(),
+	chunkCharLength: z.number().positive(),
+	preQuery: z.string().default(""),
+	prePassage: z.string().default(""),
+});
+
+// Default embedding model for backward compatibility
+const rawEmbeddingModelJSON =
+	TEXT_EMBEDDING_MODELS ||
+	`[
+	{
+		"name": "Xenova/gte-small",
+		"chunkCharLength": 512,
+		"endpoints": [
+			{ "type": "transformersjs" }
+		]
+	}
+]`;
+
+const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON));
+
+const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
+	...m,
+	id: m.id || m.name,
+});
+
+const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
+	...m,
+	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
+		if (!m.endpoints) {
+			return embeddingEndpointTransformersJS({
+				type: "transformersjs",
+				weight: 1,
+				model: m,
+			});
+		}
+
+		const totalWeight = sum(m.endpoints.map((e) => e.weight));
+
+		let random = Math.random() * totalWeight;
+
+		for (const endpoint of m.endpoints) {
+			if (random < endpoint.weight) {
+				const args = { ...endpoint, model: m };
+
+				switch (args.type) {
+					case "tei":
+						return embeddingEndpoints.tei(args);
+					case "transformersjs":
+						return embeddingEndpoints.transformersjs(args);
+				}
+			}
+
+			random -= endpoint.weight;
+		}
+
+		throw new Error(`Failed to select embedding endpoint`);
+	},
+});
+
+export const embeddingModels = await Promise.all(
+	embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
+);
+
+export const defaultEmbeddingModel = embeddingModels[0];
+
+const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
+	return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
+};
+
+export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
+	return validateEmbeddingModel(_models, "id");
+};
+
+export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
+	return validateEmbeddingModel(_models, "name");
+};
+
+export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
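Note: `getEndpoint()` above implements weighted random selection — draw `random` in `[0, totalWeight)`, then walk the endpoint list, subtracting each weight until the draw lands inside one. Each call may therefore pick a different endpoint. A hedged usage sketch (the model name is the README example, not a guaranteed entry in your config):

```ts
// Hypothetical: resolve a configured embedding model and embed one query.
import { embeddingModels, defaultEmbeddingModel } from "$lib/server/embeddingModels";

const model =
	embeddingModels.find((m) => m.name === "intfloat/e5-base-v2") ?? defaultEmbeddingModel;

// picks one of the model's endpoints at random, proportionally to its weight
const endpoint = await model.getEndpoint();
const [queryEmbedding] = await endpoint({ inputs: [`${model.preQuery}hello world`] });
```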
src/lib/server/models.ts
@@ -12,6 +12,7 @@ import { z } from "zod";
 import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
 import endpointTgi from "./endpoints/tgi/endpointTgi";
 import { sum } from "$lib/utils/sum";
+import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";
 
 import JSON5 from "json5";
 
@@ -68,6 +69,7 @@ const modelConfig = z.object({
 		.optional(),
 	multimodal: z.boolean().default(false),
 	unlisted: z.boolean().default(false),
+	embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
 });
 
 const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));
src/lib/server/{websearch/sentenceSimilarity.ts → sentenceSimilarity.ts}
@@ -1,43 +1,33 @@
-import type { Tensor, Pipeline } from "@xenova/transformers";
-import { pipeline, dot } from "@xenova/transformers";
+import { dot } from "@xenova/transformers";
+import type { EmbeddingBackendModel } from "$lib/server/embeddingModels";
+import type { Embedding } from "$lib/types/EmbeddingEndpoints";
 
 // see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
-function innerProduct(tensorA: Tensor, tensorB: Tensor) {
-	return 1.0 - dot(tensorA.data, tensorB.data);
+function innerProduct(embeddingA: Embedding, embeddingB: Embedding) {
+	return 1.0 - dot(embeddingA, embeddingB);
 }
 
-// Use the Singleton pattern to enable lazy construction of the pipeline.
-class PipelineSingleton {
-	static modelId = "Xenova/gte-small";
-	static instance: Promise<Pipeline> | null = null;
-	static async getInstance() {
-		if (this.instance === null) {
-			this.instance = pipeline("feature-extraction", this.modelId);
-		}
-		return this.instance;
-	}
-}
-
-// see https://huggingface.co/thenlper/gte-small/blob/d8e2604cadbeeda029847d19759d219e0ce2e6d8/README.md?code=true#L2625
-export const MAX_SEQ_LEN = 512 as const;
-
 export async function findSimilarSentences(
+	embeddingModel: EmbeddingBackendModel,
 	query: string,
 	sentences: string[],
 	{ topK = 5 }: { topK: number }
-) {
-	const input = [query, ...sentences];
+): Promise<Embedding> {
+	const inputs = [
+		`${embeddingModel.preQuery}${query}`,
+		...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`),
+	];
 
-	const extractor = await PipelineSingleton.getInstance();
-	const output: Tensor = await extractor(input, { pooling: "mean", normalize: true });
+	const embeddingEndpoint = await embeddingModel.getEndpoint();
+	const output = await embeddingEndpoint({ inputs });
 
-	const queryTensor: Tensor = output[0];
-	const sentencesTensor: Tensor = output.slice([1, input.length - 1]);
+	const queryEmbedding: Embedding = output[0];
+	const sentencesEmbeddings: Embedding[] = output.slice(1, inputs.length - 1);
 
-	const distancesFromQuery: { distance: number; index: number }[] = [...sentencesTensor].map(
-		(sentenceTensor: Tensor, index: number) => {
+	const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map(
+		(sentenceEmbedding: Embedding, index: number) => {
 			return {
-				distance: innerProduct(queryTensor, sentenceTensor),
+				distance: innerProduct(queryEmbedding, sentenceEmbedding),
 				index: index,
 			};
 		}
src/lib/server/websearch/runWebSearch.ts
@@ -4,13 +4,11 @@ import type { WebSearch, WebSearchSource } from "$lib/types/WebSearch";
 import { generateQuery } from "$lib/server/websearch/generateQuery";
 import { parseWeb } from "$lib/server/websearch/parseWeb";
 import { chunk } from "$lib/utils/chunk";
-import {
-	MAX_SEQ_LEN as CHUNK_CAR_LEN,
-	findSimilarSentences,
-} from "$lib/server/websearch/sentenceSimilarity";
+import { findSimilarSentences } from "$lib/server/sentenceSimilarity";
 import type { Conversation } from "$lib/types/Conversation";
 import type { MessageUpdate } from "$lib/types/MessageUpdate";
 import { getWebSearchProvider } from "./searchWeb";
+import { defaultEmbeddingModel, embeddingModels } from "$lib/server/embeddingModels";
 
 const MAX_N_PAGES_SCRAPE = 10 as const;
 const MAX_N_PAGES_EMBED = 5 as const;
@@ -63,6 +61,14 @@ export async function runWebSearch(
 			.filter(({ link }) => !DOMAIN_BLOCKLIST.some((el) => link.includes(el))) // filter out blocklist links
 			.slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only
 
+		// fetch the model
+		const embeddingModel =
+			embeddingModels.find((m) => m.id === conv.embeddingModel) ?? defaultEmbeddingModel;
+
+		if (!embeddingModel) {
+			throw new Error(`Embedding model ${conv.embeddingModel} not available anymore`);
+		}
+
 		let paragraphChunks: { source: WebSearchSource; text: string }[] = [];
 		if (webSearch.results.length > 0) {
 			appendUpdate("Browsing results");
@@ -78,7 +84,7 @@ export async function runWebSearch(
 				}
 			}
 			const MAX_N_CHUNKS = 100;
-			const texts = chunk(text, CHUNK_CAR_LEN).slice(0, MAX_N_CHUNKS);
+			const texts = chunk(text, embeddingModel.chunkCharLength).slice(0, MAX_N_CHUNKS);
 			return texts.map((t) => ({ source: result, text: t }));
 		});
 		const nestedParagraphChunks = (await Promise.all(promises)).slice(0, MAX_N_PAGES_EMBED);
@@ -93,7 +99,7 @@ export async function runWebSearch(
 		appendUpdate("Extracting relevant information");
 		const topKClosestParagraphs = 8;
 		const texts = paragraphChunks.map(({ text }) => text);
-		const indices = await findSimilarSentences(prompt, texts, {
+		const indices = await findSimilarSentences(embeddingModel, prompt, texts, {
 			topK: topKClosestParagraphs,
 		});
 		webSearch.context = indices.map((idx) => texts[idx]).join("");
src/lib/types/Conversation.ts
@@ -10,6 +10,7 @@ export interface Conversation extends Timestamps {
 	userId?: User["_id"];
 
 	model: string;
+	embeddingModel: string;
 
 	title: string;
 	messages: Message[];
src/lib/types/EmbeddingEndpoints.ts (new file)
@@ -0,0 +1,41 @@
+import { z } from "zod";
+import {
+	embeddingEndpointTei,
+	embeddingEndpointTeiParametersSchema,
+} from "$lib/server/embeddingEndpoints/tei/embeddingEndpoints";
+import {
+	embeddingEndpointTransformersJS,
+	embeddingEndpointTransformersJSParametersSchema,
+} from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
+
+// parameters passed when generating embeddings
+interface EmbeddingEndpointParameters {
+	inputs: string[];
+}
+
+export type Embedding = number[];
+
+// type signature for the endpoint
+export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise<Embedding[]>;
+
+export const embeddingEndpointSchema = z.discriminatedUnion("type", [
+	embeddingEndpointTeiParametersSchema,
+	embeddingEndpointTransformersJSParametersSchema,
+]);
+
+type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
+
+// generator function that takes the type discriminator value and returns the endpoint
+export type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
+	inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }>
+) => EmbeddingEndpoint | Promise<EmbeddingEndpoint>;
+
+// list of all endpoint generators
+export const embeddingEndpoints: {
+	[Key in EmbeddingEndpointTypeOptions]: EmbeddingEndpointGenerator<Key>;
+} = {
+	tei: embeddingEndpointTei,
+	transformersjs: embeddingEndpointTransformersJS,
+};
+
+export default embeddingEndpoints;
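Note: the mapped type above ties each `type` discriminator to a generator that accepts exactly that config variant, so the registry lookup stays type-safe. A small sketch of how a caller narrows it (hypothetical; it mirrors the switch in `embeddingModels.ts`, and the URL is a placeholder):

```ts
// Hypothetical: parse an arbitrary config and route it to the matching generator.
import { embeddingEndpointSchema, embeddingEndpoints } from "$lib/types/EmbeddingEndpoints";

const config = embeddingEndpointSchema.parse({ type: "tei", url: "http://127.0.0.1:8080/" });

// checking the discriminator narrows `config`, so each generator call type-checks
const endpoint =
	config.type === "tei"
		? await embeddingEndpoints.tei(config)
		: await embeddingEndpoints.transformersjs(config);
```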
src/lib/types/SharedConversation.ts
@@ -7,6 +7,8 @@ export interface SharedConversation extends Timestamps {
 	hash: string;
 
 	model: string;
+	embeddingModel: string;
+
 	title: string;
 	messages: Message[];
 	preprompt?: string;
src/routes/conversation/+server.ts
@@ -6,6 +6,7 @@ import { base } from "$app/paths";
 import { z } from "zod";
 import type { Message } from "$lib/types/Message";
 import { models, validateModel } from "$lib/server/models";
+import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
 
 export const POST: RequestHandler = async ({ locals, request }) => {
 	const body = await request.text();
@@ -22,6 +23,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
 		.parse(JSON.parse(body));
 
 	let preprompt = values.preprompt;
+	let embeddingModel: string;
 
 	if (values.fromShare) {
 		const conversation = await collections.sharedConversations.findOne({
@@ -35,6 +37,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
 		title = conversation.title;
 		messages = conversation.messages;
 		values.model = conversation.model;
+		embeddingModel = conversation.embeddingModel;
 		preprompt = conversation.preprompt;
 	}
 
@@ -44,6 +47,8 @@ export const POST: RequestHandler = async ({ locals, request }) => {
 		throw error(400, "Invalid model");
 	}
 
+	embeddingModel ??= model.embeddingModel ?? defaultEmbeddingModel.name;
+
 	if (model.unlisted) {
 		throw error(400, "Can't start a conversation with an unlisted model");
 	}
@@ -59,6 +64,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
 		preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt,
 		createdAt: new Date(),
 		updatedAt: new Date(),
+		embeddingModel: embeddingModel,
 		...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }),
 		...(values.fromShare ? { meta: { fromShareId: values.fromShare } } : {}),
 	});
src/routes/conversation/[id]/+page.svelte
@@ -173,6 +173,7 @@
 			inputs.forEach(async (el: string) => {
 				try {
 					const update = JSON.parse(el) as MessageUpdate;
+
 					if (update.type === "finalAnswer") {
 						finalAnswer = update.text;
 						reader.cancel();
@@ -225,7 +226,7 @@
 				});
 			}
 
-			// reset the websearch messages
+			// reset the websearchMessages
 			webSearchMessages = [];
 
 			await invalidate(UrlDependency.ConversationList);
|
@@ -38,6 +38,7 @@ export async function POST({ params, url, locals }) {
|
|
38 |
updatedAt: new Date(),
|
39 |
title: conversation.title,
|
40 |
model: conversation.model,
|
|
|
41 |
preprompt: conversation.preprompt,
|
42 |
};
|
43 |
|
|
|
38 |
updatedAt: new Date(),
|
39 |
title: conversation.title,
|
40 |
model: conversation.model,
|
41 |
+
embeddingModel: conversation.embeddingModel,
|
42 |
preprompt: conversation.preprompt,
|
43 |
};
|
44 |
|
src/routes/login/callback/updateUser.spec.ts
@@ -6,6 +6,7 @@ import { ObjectId } from "mongodb";
 import { DEFAULT_SETTINGS } from "$lib/types/Settings";
 import { defaultModel } from "$lib/server/models";
 import { findUser } from "$lib/server/auth";
+import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
 
 const userData = {
 	preferred_username: "new-username",
@@ -46,6 +47,7 @@ const insertRandomConversations = async (count: number) => {
 			title: "random title",
 			messages: [],
 			model: defaultModel.id,
+			embeddingModel: defaultEmbeddingModel.id,
 			createdAt: new Date(),
 			updatedAt: new Date(),
 			sessionId: locals.sessionId,