Jeronymous committed on
Commit
0dcddb0
1 Parent(s): 30af681

initial commit

Browse files
Files changed (3) hide show
  1. README.md +4 -4
  2. app.py +418 -0
  3. requirements.txt +11 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: Claire Chat
3
- emoji: 📉
4
  colorFrom: blue
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.4.0
8
  app_file: app.py
9
- pinned: false
10
  license: cc-by-nc-sa-4.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Claire Chat
3
+ emoji: 🎙💬
4
  colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.4.0
8
  app_file: app.py
9
+ pinned: true
10
  license: cc-by-nc-sa-4.0
11
  ---
12
 
13
+ Démo de conversations en Français générées par [Claire](https://huggingface.co/OpenLLM-France/Claire-7B-0.1)
app.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import unicodedata
import re

# Default generation hyper-parameters (also used as initial values in the UI)
default_max_new_tokens = 100
default_temperature = 1.0
default_top_k = 10
default_top_p = 0.99
default_repetition_penalty = 1.0

model_name = "OpenLLM-France/Claire-7B-0.1"

print("Loading model...")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    # 4-bit quantization (bitsandbytes) so the 7B model fits on a small GPU
    load_in_4bit=True,
)

print("Optimizing model...")

import optimum
from optimum.bettertransformer import BetterTransformer

model = BetterTransformer.transform(model)

print("Setup chat...")

# Token ids used as stopping criteria for generation:
# - EOS ends a full multi-turn generation;
# - "[" (the first character of a "[Intervenant N:]" speaker tag) is used to
#   stop right before a new speaker starts, i.e. after a single turn.
eos_token_id = tokenizer.eos_token_id
newspk_token_id = tokenizer.encode("[")
assert len(newspk_token_id) == 1
newspk_token_id = newspk_token_id[0]
40
+
41
+
42
# Class to encapsulate the Claire chatbot
class ClaireChatBot:
    """Wraps prompt formatting and text generation around the Claire model.

    The UI displays speaker names such as "VOUS:" / "CHATBOT:", while the
    model was trained on "[Intervenant N:]" tags; this class converts between
    the two representations before and after generation.
    """

    def __init__(
        self,
        # Chat will display...
        user_name="VOUS:",
        bot_name="CHATBOT:",
        other_name_regex_in=r"AUTRE (\d+):",
        other_name_regex_out=r"AUTRE \1:",
        # but Claire was trained on...
        user_internal_tag="[Intervenant 1:]",
        bot_internal_tag="[Intervenant 2:]",
        other_internal_tag_regex_in=r"\[Intervenant (\d+):\]",
        other_internal_tag_regex_out=r"\[Intervenant \1:\]",
    ):
        self.user_name = user_name
        self.bot_name = bot_name
        self.other_name_regex_in = other_name_regex_in
        self.other_name_regex_out = other_name_regex_out

        self.user_internal_tag = user_internal_tag
        self.bot_internal_tag = bot_internal_tag
        self.other_internal_tag_regex_in = other_internal_tag_regex_in
        self.other_internal_tag_regex_out = other_internal_tag_regex_out

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # State carried across calls; only used when reinject_history /
        # reshow_history are enabled (both off by default).
        self.has_started_bracket = False  # last generation stopped on an opening "["
        self.history = ""  # display-formatted conversation so far
        self.history_raw = ""  # model-formatted (internal-tag) conversation so far
        self.reinject_history = False
        self.reshow_history = False

    def predict(
        self,
        user_message,
        bot_message_start="",
        conversation_history="",
        generate_several_turns=False,
        max_new_tokens=default_max_new_tokens,
        temperature=default_temperature,
        top_k=default_top_k,
        top_p=default_top_p,
        repetition_penalty=default_repetition_penalty,
    ):
        """Generate the bot's reply (and possibly further turns) as display text.

        user_message: what the user says; bot_message_start: optional forced
        beginning of the bot's reply; conversation_history: prior turns in
        display format ("VOUS: ..."). Sampling parameters are forwarded to
        `model.generate`. Returns the full generated text with display names.
        """
        user_message = claire_text_preproc_message(user_message)
        bot_message_start = claire_text_preproc_message(bot_message_start)

        if conversation_history:
            # Format conversation history: display names -> internal tags
            for spk_in, spk_out in [
                (self.user_name, self.user_internal_tag),
                (self.bot_name, self.bot_internal_tag),
            ]:
                conversation_history = conversation_history.replace(spk_in, spk_out)
            conversation_history = re.sub(self.other_name_regex_in, self.other_internal_tag_regex_out, conversation_history)
            conversation_history = claire_text_preproc_conversation(conversation_history)
            conversation_history = conversation_history.rstrip() + "\n"
        else:
            conversation_history = self.history_raw

        # (Only relevant if self.reinject_history is True)
        # If the previous generation already ended on an opening "[", avoid
        # emitting it twice in the next speaker tag.
        user_internal_tag = self.user_internal_tag
        if self.has_started_bracket:
            user_internal_tag = user_internal_tag[1:]

        # Combine the user and bot messages into a conversation
        conversation = f"{conversation_history}{user_internal_tag} {user_message}\n{self.bot_internal_tag} {bot_message_start if bot_message_start else ''}".strip()

        # Encode the conversation using the tokenizer
        input_ids = tokenizer.encode(
            conversation, return_tensors="pt", add_special_tokens=False
        )
        input_ids = input_ids.to(self.device)

        # Generate a response using Claire
        response = model.generate(
            input_ids=input_ids,
            use_cache=False,
            early_stopping=False,
            temperature=temperature,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=eos_token_id,
            # Multi-turn: stop at end-of-text; single turn: stop as soon as a
            # new "[Intervenant ...]" tag begins.
            eos_token_id=eos_token_id if generate_several_turns else newspk_token_id,
        )

        # Decode the generated response to text (includes the prompt itself)
        response_text = tokenizer.decode(response[0], skip_special_tokens=True)

        # Remove last unfinished speech turn/sentence/phrase
        line_breaks = [u.span(0)[0] for u in re.finditer("\n", response_text)]
        remove_last_sentence = True
        if generate_several_turns:
            if len(line_breaks) >= 2:
                # Drop the last (likely truncated) speech turn entirely
                response_text = response_text[: line_breaks[-1]]
                line_breaks.pop(-1)
                remove_last_sentence = False
        if remove_last_sentence and len(line_breaks) == 1:
            # Cut at the last sentence end after the bot tag; failing that,
            # at the last phrase separator.
            sentence_ends = [
                u.span(0)[0] for u in re.finditer(r"[\.!?]", response_text)
            ]
            sentence_ends = [p for p in sentence_ends if p > line_breaks[-1]]
            if sentence_ends:
                response_text = response_text[: sentence_ends[-1] + 1]
            else:
                phrase_ends = [
                    u.span(0)[0] for u in re.finditer(r"[,;]", response_text)
                ]
                phrase_ends = [p for p in phrase_ends if p > line_breaks[-1]]
                if phrase_ends:
                    response_text = response_text[: phrase_ends[-1] + 1]

        ended_with_bracket = response_text.endswith("[")

        if self.reinject_history:
            # Remember model-formatted text to continue the conversation later
            self.history_raw = response_text
            self.has_started_bracket = ended_with_bracket

        if ended_with_bracket:
            response_text = response_text[:-1]

        # Internal tags -> display names for the UI
        for spk_in, spk_out in [
            (self.user_internal_tag, self.user_name),
            (self.user_internal_tag[1:], self.user_name),  # Starting bracket may be missing
            (self.bot_internal_tag, self.bot_name),
        ]:
            response_text = response_text.replace(spk_in, spk_out)
        response_text = re.sub(self.other_internal_tag_regex_in, self.other_name_regex_out, response_text)

        if self.reshow_history:
            previous_history = self.history
            self.history = previous_history + response_text + "\n"
        else:
            previous_history = ""

        return previous_history + response_text
182
+
183
+
184
def claire_text_preproc_conversation(text):
    """Normalize a whole conversation history: special characters, then whitespace."""
    return collapse_whitespaces_conversations(format_special_characters(text))
188
+
189
+
190
def claire_text_preproc_message(text):
    """Normalize a single message: special characters, whitespace, then brackets."""
    normalized = format_special_characters(text)
    normalized = collapse_whitespaces_message(normalized)
    return replace_brackets(normalized)
195
+
196
+
197
def collapse_whitespaces_conversations(text):
    """Collapse redundant whitespace in a conversation while keeping line breaks."""
    # Applied in order: squeeze runs of newlines, squeeze runs of spaces/tabs,
    # drop a space left right after a newline, glue '.' and ',' back to the
    # preceding word.
    for pattern, replacement in (
        (r"\n+", "\n"),
        (r"[ \t]+", " "),
        (r"\n ", "\n"),
        (r" ([\.,])", r"\1"),
    ):
        text = re.sub(pattern, replacement, text)
    # Strip all leading whitespace, but only trailing *spaces*
    # (a trailing newline is kept).
    return text.lstrip().rstrip(" ")
203
+
204
+
205
def collapse_whitespaces_message(text):
    """Collapse every whitespace run (newlines included) of a message to one space."""
    squeezed = re.sub(r"\s+", " ", text)
    # Glue '.' and ',' back to the preceding word
    glued = re.sub(r" ([\.,])", r"\1", squeezed)
    return glued.lstrip().rstrip(" ")
209
+
210
+
211
def replace_brackets(text):
    """Replace square/curly brackets with parentheses.

    Square brackets are reserved for the "[Intervenant N:]" speaker tags, so
    they must never appear in a raw message.
    """
    return text.translate(str.maketrans("[{]}", "(())"))
215
+
216
+
217
def format_special_characters(text):
    """Normalize unicode punctuation to plain equivalents.

    The text is NFC-normalized first, then a fixed list of regex substitutions
    is applied in order (quote patterns run before the unbreakable-space rule,
    so whitespace stuck to a French quote is absorbed by the quote pattern).
    """
    substitutions = (
        ("…", "..."),
        (r"[«“][^\S\r\n]*", '"'),
        (r"[^\S\r\n]*[»”″„]", '"'),
        (r"(``|'')", '"'),
        (r"[’‘‛ʿ]", "'"),
        ("‚", ","),
        (r"–", "-"),
        ("[  ]", " "),  # unbreakable spaces
        (r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", ""),  # non-printable characters
        # ("·", "."),
        (r"ᵉʳ", "er"),
        (r"ᵉ", "e"),
    )
    normalized = unicodedata.normalize("NFC", text)
    for pattern, replacement in substitutions:
        normalized = re.sub(pattern, replacement, normalized)
    return normalized
236
+
237
+
238
# Create the Claire chatbot instance
chatbot = ClaireChatBot()

# Define the Gradio interface
title = "Démo de conversation avec Claire"
description = "Simulation de conversations en Français avec [Claire](https://huggingface.co/OpenLLM-France/Claire-7B-0.1), sans recherche de vérité, et avec potentiellement un peu d'humour."

# Sampling parameters shared by both examples below; order matches the last
# four positional parameters of ClaireChatBot.predict (and the sliders).
default_parameters = [
    default_temperature,
    default_top_k,
    default_top_p,
    default_repetition_penalty,
]

# Pre-filled examples; each inner list follows the positional signature of
# ClaireChatBot.predict.
examples = [
    [
        "Nous allons commencer cette interview avec une question un peu classique. Quel est votre sport préféré?",  # user_message
        "",  # bot_message_start
        "",  # conversation_history
        True,  # generate_several_turns
        200,  # max_new_tokens
        *default_parameters,
    ],
    [
        "Que vas-tu nous cuisiner aujourd'hui?",  # user_message
        "Alors, nous allons voir la recette de",  # bot_message_start
        "VOUS: Bonjour Claire.\nCHATBOT: Bonjour Dominique.",  # conversation_history
        False,  # generate_several_turns
        default_max_new_tokens,  # max_new_tokens
        *default_parameters,
    ],
]

# # Test
# chatbot.predict(*examples[0])
273
+
274
# One Gradio widget per positional parameter of ClaireChatBot.predict,
# in the same order as the example rows above.
inputs = [
    gr.Textbox(
        "",
        label="Prompt",
        info="Tapez ce que vous voulez dire au ChatBot",
        type="text",
        lines=2,
    ),
    gr.Textbox(
        "",
        label="Début de réponse",
        info="Vous pouvez taper ici ce que commence à vous répondre le ChatBot",
        type="text",
    ),
    gr.Textbox(
        "",
        label="Historique",
        info="Vous pouvez copier-coller (et modifier?) ici votre historique de conversation, pour continuer cette conversation",
        type="text",
        lines=3,
    ),
    gr.Checkbox(
        False,
        label="Plus qu'un tour de parole",
        info="Générer aussi comment pourrait continuer la conversation (plusieurs tours de parole incluant le vôtre)",
    ),
    gr.Slider(
        label="Longueur max",
        info="Longueur maximale du texte généré (en nombre de 'tokens' ~ mots et ponctuations)",
        value=default_max_new_tokens,
        minimum=25,
        maximum=1000,
        step=25,
        interactive=True,
    ),
    gr.Slider(
        label="Température",
        info="Une valeur élevée augmente la diversité du texte généré, mais peut aussi produire des résultats incohérents",
        value=default_temperature,
        minimum=0.1,
        maximum=1.9,
        step=0.1,
        interactive=True,
    ),
    gr.Slider(
        label="Top-k",
        info="Une valeur élevée permet d'explorer plus d'alternatives, mais augmente les temps de calcul",
        value=default_top_k,
        minimum=1,
        maximum=50,
        step=1,
        interactive=True,
    ),
    gr.Slider(
        label="Top-p",
        info="Une valeur élevée permet d'explorer des alternatives moins probables",
        value=default_top_p,
        minimum=0.9,
        maximum=1.0,
        step=0.01,
        interactive=True,
    ),
    gr.Slider(
        label="Pénalité de répétition",
        info="Pénalisation des répétitions",
        value=default_repetition_penalty,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
    ),
]
346
+
347
# Custom dark Gradio theme based on the Monochrome preset, with emerald/teal
# hues and explicit color overrides for every component family.
theme = gr.themes.Monochrome(
    secondary_hue="emerald",
    neutral_hue="teal",
).set(
    body_background_fill="*primary_950",
    body_background_fill_dark="*secondary_950",
    body_text_color="*primary_50",
    body_text_color_dark="*secondary_100",
    body_text_color_subdued="*primary_300",
    body_text_color_subdued_dark="*primary_300",
    background_fill_primary="*primary_600",
    background_fill_primary_dark="*primary_400",
    background_fill_secondary="*primary_950",
    background_fill_secondary_dark="*primary_950",
    border_color_accent="*secondary_600",
    border_color_primary="*secondary_50",
    border_color_primary_dark="*secondary_50",
    color_accent="*secondary_50",
    color_accent_soft="*primary_500",
    color_accent_soft_dark="*primary_500",
    link_text_color="*secondary_950",
    link_text_color_dark="*primary_50",
    link_text_color_active="*primary_50",
    link_text_color_active_dark="*primary_50",
    link_text_color_hover="*primary_50",
    link_text_color_hover_dark="*primary_50",
    link_text_color_visited="*primary_50",
    block_background_fill="*primary_950",
    block_background_fill_dark="*primary_950",
    block_border_color="*secondary_500",
    block_border_color_dark="*secondary_500",
    block_info_text_color="*primary_50",
    block_info_text_color_dark="*primary_50",
    block_label_background_fill="*primary_950",
    block_label_background_fill_dark="*secondary_950",
    block_label_border_color="*secondary_500",
    block_label_border_color_dark="*secondary_500",
    block_label_text_color="*secondary_500",
    block_label_text_color_dark="*secondary_500",
    block_title_background_fill="*primary_950",
    panel_background_fill="*primary_950",
    panel_border_color="*primary_950",
    checkbox_background_color="*primary_950",
    checkbox_background_color_dark="*primary_950",
    checkbox_background_color_focus="*primary_950",
    checkbox_border_color="*secondary_500",
    input_background_fill="*primary_800",
    input_background_fill_focus="*primary_950",
    input_background_fill_hover="*secondary_950",
    input_placeholder_color="*secondary_950",
    slider_color="*primary_950",
    slider_color_dark="*primary_950",
    table_even_background_fill="*primary_800",
    table_odd_background_fill="*primary_600",
    button_primary_background_fill="*primary_800",
    button_primary_background_fill_dark="*primary_800",
)
404
+
405
# Wire everything into a single-function Gradio Interface; output is the
# plain-text conversation returned by ClaireChatBot.predict.
iface = gr.Interface(
    fn=chatbot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=inputs,
    outputs="text",
    theme=theme,
)

print("Launching chat...")

# Launch the Gradio interface for the model
# (share=True also creates a public gradio.live tunnel URL)
iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ optimum
2
+ accelerate==0.24.1
3
+ bitsandbytes==0.41.1
4
+ gradio==4.1.1
5
+ protobuf==3.20.3
6
+ scipy==1.11.2
7
+ sentencepiece==0.1.99
8
+ spaces==0.18.0
9
+ torch==2.0.0
10
+ transformers==4.35.0
11
+ transformers_stream_generator==0.0.4