Switch chat model back to mistral, use zephyr for small tasks (#515)
* Switch chat model back to mistral, use zephyr for small tasks
* typo
* fix tests
Files changed:
- .env.template (+42, -10)
- src/lib/server/models.ts (+69, -67)
.env.template

@@ -94,24 +94,24 @@ MODELS=`[
     ]
   },
   {
-    "name": "HuggingFaceH4/zephyr-7b-alpha",
-    "displayName": "HuggingFaceH4/zephyr-7b-alpha",
-    "description": "Zephyr 7B α is a fine-tune of Mistral 7B, released by the Hugging Face H4 RLHF team.",
-    "websiteUrl": "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha/",
+    "name": "mistralai/Mistral-7B-Instruct-v0.1",
+    "displayName": "mistralai/Mistral-7B-Instruct-v0.1",
+    "description": "Mistral 7B is a new Apache 2.0 model, released by Mistral AI that outperforms Llama2 13B in benchmarks.",
+    "websiteUrl": "https://mistral.ai/news/announcing-mistral-7b/",
     "preprompt": "",
-    "chatPromptTemplate" : "<|system|>\n{{preprompt}}</s>\n{{#each messages}}{{#ifUser}}<|user|>\n{{content}}</s>\n<|assistant|>\n{{/ifUser}}{{#ifAssistant}}{{content}}</s>\n{{/ifAssistant}}{{/each}}",
+    "chatPromptTemplate" : "<s>{{#each messages}}{{#ifUser}}[INST] {{#if @first}}{{#if @root.preprompt}}{{@root.preprompt}}\n{{/if}}{{/if}}{{content}} [/INST]{{/ifUser}}{{#ifAssistant}}{{content}}</s>{{/ifAssistant}}{{/each}}",
     "parameters": {
-      "temperature": 0.7,
+      "temperature": 0.1,
       "top_p": 0.95,
       "repetition_penalty": 1.2,
       "top_k": 50,
       "truncate": 1000,
       "max_new_tokens": 2048,
-      "stop": ["</s>", "<|>"]
+      "stop": ["</s>"]
     },
     "promptExamples": [
       {
-        "title": "Write an email from bullet list",
+        "title": "Write an email from bullet list",
         "prompt": "As a restaurant owner, write a professional email to the supplier to get these products every week: \n\n- Wine (x10)\n- Eggs (x24)\n- Bread (x12)"
       }, {
         "title": "Code a snake game",

@@ -124,8 +124,40 @@ MODELS=`[
   }
 ]`
 
-OLD_MODELS=`[{"name":"bigcode/starcoder"}, {"name":"OpenAssistant/oasst-sft-6-llama-30b-xor"}, {"name":"mistralai/Mistral-7B-Instruct-v0.1"}]`
-TASK_MODEL='HuggingFaceH4/zephyr-7b-alpha'
+OLD_MODELS=`[{"name":"bigcode/starcoder"}, {"name":"OpenAssistant/oasst-sft-6-llama-30b-xor"}, {"name":"HuggingFaceH4/zephyr-7b-alpha"}]`
+
+TASK_MODEL='
+{
+  "name": "HuggingFaceH4/zephyr-7b-alpha",
+  "displayName": "HuggingFaceH4/zephyr-7b-alpha",
+  "description": "Zephyr 7B α is a fine-tune of Mistral 7B, released by the Hugging Face H4 RLHF team.",
+  "websiteUrl": "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha/",
+  "preprompt": "",
+  "chatPromptTemplate" : "<|system|>\n{{preprompt}}</s>\n{{#each messages}}{{#ifUser}}<|user|>\n{{content}}</s>\n<|assistant|>\n{{/ifUser}}{{#ifAssistant}}{{content}}</s>\n{{/ifAssistant}}{{/each}}",
+  "parameters": {
+    "temperature": 0.7,
+    "top_p": 0.95,
+    "repetition_penalty": 1.2,
+    "top_k": 50,
+    "truncate": 1000,
+    "max_new_tokens": 2048,
+    "stop": ["</s>", "<|>"]
+  },
+  "promptExamples": [
+    {
+      "title": "Write an email from bullet list",
+      "prompt": "As a restaurant owner, write a professional email to the supplier to get these products every week: \n\n- Wine (x10)\n- Eggs (x24)\n- Bread (x12)"
+    }, {
+      "title": "Code a snake game",
+      "prompt": "Code a basic snake game in python, give explanations for each step."
+    }, {
+      "title": "Assist in a task",
+      "prompt": "How do I make a delicious lemon cheesecake?"
+    }
+  ]
+}
+'
+
 
 APP_BASE="/chat"
 PUBLIC_ORIGIN=https://huggingface.co
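A note on the templates above: `chatPromptTemplate` is a Handlebars string, and `ifUser` / `ifAssistant` are custom block helpers registered by chat-ui's template compiler (the `compileTemplate` call visible in src/lib/server/models.ts below). As a rough, self-contained sketch of how the Mistral template renders (assuming helpers that branch on a message's `from` field; the helper bodies here are illustrative, not chat-ui's actual code):

import Handlebars from "handlebars";

// Illustrative stand-ins for chat-ui's ifUser/ifAssistant block helpers;
// we assume each message is { from: "user" | "assistant", content: string }.
Handlebars.registerHelper("ifUser", function (this: { from: string }, options: Handlebars.HelperOptions) {
	return this.from === "user" ? options.fn(this) : "";
});
Handlebars.registerHelper("ifAssistant", function (this: { from: string }, options: Handlebars.HelperOptions) {
	return this.from === "assistant" ? options.fn(this) : "";
});

// The Mistral chatPromptTemplate from the config above, compiled verbatim.
const render = Handlebars.compile(
	"<s>{{#each messages}}{{#ifUser}}[INST] {{#if @first}}{{#if @root.preprompt}}{{@root.preprompt}}\n{{/if}}{{/if}}{{content}} [/INST]{{/ifUser}}{{#ifAssistant}}{{content}}</s>{{/ifAssistant}}{{/each}}",
	{ noEscape: true } // keep quotes and tags unescaped in the prompt text
);

console.log(
	render({
		preprompt: "",
		messages: [
			{ from: "user", content: "Hello!" },
			{ from: "assistant", content: "Hi there." },
			{ from: "user", content: "Write a haiku." },
		],
	})
);
// -> <s>[INST] Hello! [/INST]Hi there.</s>[INST] Write a haiku. [/INST]

The zephyr template in TASK_MODEL works the same way, just with <|system|>/<|user|>/<|assistant|> markers and a </s> after every turn.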
src/lib/server/models.ts

@@ -37,70 +37,68 @@ const combinedEndpoint = endpoint.transform((data) => {
 	}
 });
 
-const models = await Promise.all(
-	z
-		.array(
-			z.object({
-				/** Used as an identifier in DB */
-				id: z.string().optional(),
-				/** Used to link to the model page, and for inference */
-				name: z.string().min(1),
-				displayName: z.string().min(1).optional(),
-				description: z.string().min(1).optional(),
-				websiteUrl: z.string().url().optional(),
-				modelUrl: z.string().url().optional(),
-				datasetName: z.string().min(1).optional(),
-				datasetUrl: z.string().url().optional(),
-				userMessageToken: z.string().default(""),
-				userMessageEndToken: z.string().default(""),
-				assistantMessageToken: z.string().default(""),
-				assistantMessageEndToken: z.string().default(""),
-				messageEndToken: z.string().default(""),
-				preprompt: z.string().default(""),
-				prepromptUrl: z.string().url().optional(),
-				chatPromptTemplate: z
-					.string()
-					.default(
-						"{{preprompt}}" +
-							"{{#each messages}}" +
-							"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
-							"{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
-							"{{/each}}" +
-							"{{assistantMessageToken}}"
-					),
-				promptExamples: z
-					.array(
-						z.object({
-							title: z.string().min(1),
-							prompt: z.string().min(1),
-						})
-					)
-					.optional(),
-				endpoints: z.array(combinedEndpoint).optional(),
-				parameters: z
-					.object({
-						temperature: z.number().min(0).max(1),
-						truncate: z.number().int().positive(),
-						max_new_tokens: z.number().int().positive(),
-						stop: z.array(z.string()).optional(),
-					})
-					.passthrough()
-					.optional(),
-			})
-		)
-		.parse(JSON.parse(MODELS))
-		.map(async (m) => ({
-			...m,
-			userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
-			assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
-			chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
-			id: m.id || m.name,
-			displayName: m.displayName || m.name,
-			preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
-			parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
-		}))
-);
+const modelConfig = z.object({
+	/** Used as an identifier in DB */
+	id: z.string().optional(),
+	/** Used to link to the model page, and for inference */
+	name: z.string().min(1),
+	displayName: z.string().min(1).optional(),
+	description: z.string().min(1).optional(),
+	websiteUrl: z.string().url().optional(),
+	modelUrl: z.string().url().optional(),
+	datasetName: z.string().min(1).optional(),
+	datasetUrl: z.string().url().optional(),
+	userMessageToken: z.string().default(""),
+	userMessageEndToken: z.string().default(""),
+	assistantMessageToken: z.string().default(""),
+	assistantMessageEndToken: z.string().default(""),
+	messageEndToken: z.string().default(""),
+	preprompt: z.string().default(""),
+	prepromptUrl: z.string().url().optional(),
+	chatPromptTemplate: z
+		.string()
+		.default(
+			"{{preprompt}}" +
+				"{{#each messages}}" +
+				"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
+				"{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
+				"{{/each}}" +
+				"{{assistantMessageToken}}"
+		),
+	promptExamples: z
+		.array(
+			z.object({
+				title: z.string().min(1),
+				prompt: z.string().min(1),
+			})
+		)
+		.optional(),
+	endpoints: z.array(combinedEndpoint).optional(),
+	parameters: z
+		.object({
+			temperature: z.number().min(0).max(1),
+			truncate: z.number().int().positive(),
+			max_new_tokens: z.number().int().positive(),
+			stop: z.array(z.string()).optional(),
+		})
+		.passthrough()
+		.optional(),
+});
+
+const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS));
+
+const processModel = async (m: z.infer<typeof modelConfig>) => ({
+	...m,
+	userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
+	assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
+	chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
+	id: m.id || m.name,
+	displayName: m.displayName || m.name,
+	preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
+	parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
+});
+
+export const models = await Promise.all(modelsRaw.map(processModel));
 
 // Models that have been deprecated
 export const oldModels = OLD_MODELS
@@ -116,14 +114,18 @@ export const oldModels = OLD_MODELS
 			.map((m) => ({ ...m, id: m.id || m.name, displayName: m.displayName || m.name }))
 	: [];
 
-export type BackendModel = Optional<(typeof models)[0], "preprompt" | "parameters">;
-export type Endpoint = z.infer<typeof endpoint>;
-
 export const defaultModel = models[0];
 
-export const smallModel = models.find((m) => m.name === TASK_MODEL) || defaultModel;
-
 export const validateModel = (_models: BackendModel[]) => {
 	// Zod enum function requires 2 parameters
 	return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]);
 };
+
+// if `TASK_MODEL` is the name of a model we use it, else we try to parse `TASK_MODEL` as a model config itself
+export const smallModel = TASK_MODEL
+	? models.find((m) => m.name === TASK_MODEL) ||
+	  (await processModel(modelConfig.parse(JSON.parse(TASK_MODEL))))
+	: defaultModel;
+
+export type BackendModel = Optional<(typeof models)[0], "preprompt" | "parameters">;
+export type Endpoint = z.infer<typeof endpoint>;
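For a sense of what the models.ts change enables: `TASK_MODEL` previously had to name a model already in MODELS; now it can also carry a full standalone config, as the new .env.template does for zephyr. A minimal sketch of that name-or-config fallback, using a trimmed-down stand-in for `modelConfig` (names and values here are illustrative, not from the repo):

import { z } from "zod";

// Trimmed-down stand-in for the modelConfig schema in the diff above.
const modelConfig = z.object({
	name: z.string().min(1),
	parameters: z
		.object({
			temperature: z.number().min(0).max(1),
			stop: z.array(z.string()).optional(),
		})
		.passthrough() // keeps undeclared keys such as top_p or top_k
		.optional(),
});

// Stand-in for the parsed MODELS list (one chat model, as in the new .env.template).
const models = [modelConfig.parse({ name: "mistralai/Mistral-7B-Instruct-v0.1" })];
const defaultModel = models[0];

// Same branching as the new `smallModel` export: a TASK_MODEL matching a
// configured model name reuses that model; anything else is parsed as JSON.
function resolveSmallModel(taskModel: string | undefined) {
	if (!taskModel) return defaultModel;
	return models.find((m) => m.name === taskModel) ?? modelConfig.parse(JSON.parse(taskModel));
}

// Name form resolves to the existing entry:
console.log(resolveSmallModel("mistralai/Mistral-7B-Instruct-v0.1").name);

// Config form is parsed as its own model; top_p survives `.passthrough()`:
const zephyr = resolveSmallModel(
	'{"name":"HuggingFaceH4/zephyr-7b-alpha","parameters":{"temperature":0.7,"top_p":0.95,"stop":["</s>"]}}'
);
console.log(zephyr.parameters?.top_p); // 0.95

In the real code the parsed config is additionally passed through processModel, so a standalone task model gets the same chatPromptRender compilation and stop_sequences mapping as the models from MODELS.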