import gradio as gr
from threading import Thread
import transformers
import spaces
import torch
import unicodedata
import regex as re
# Model checkpoint used for generation
model_name = "OpenLLM-France/Claire-7B-0.1"

# UI title and description (French)
title = "Conversation avec Claire"
description = """\
Simulation de conversation en Français avec [OpenLLM-France/Claire-7B](https://huggingface.co/OpenLLM-France/Claire-7B-0.1).
Claire n'est pas un assistant personnel, elle a tendance à comprendre et répondre un langage parlé, \
peut faire preuve d'humour, et ne vous dira pas (forcément) des vérités.
"""

# Default sampling parameters
default_max_new_tokens = 200
default_temperature = 1.0
default_repetition_penalty = 1.5
default_top_k = 10
default_top_p = 0.99

# Order must match the additional_inputs sliders defined further below
default_parameters = [
    default_max_new_tokens,
    default_temperature,
    default_repetition_penalty,
    default_top_k,
    default_top_p,
]


def _example(user_message, several_turns, bot_message_start):
    """Build one ChatInterface example row: message, the two extra flags,
    then the default sampling parameters."""
    return [user_message, several_turns, bot_message_start, *default_parameters]


# Example inputs shown under the chat box
examples = [
    _example("Bonjour Claire. Quel est votre sport préféré?", False, ""),
    _example("Bonjour. Je vous propose de faire un tour de table.", True, ""),
    _example("Que vas-tu nous cuisiner aujourd'hui?", False, "Alors, nous allons voir la recette"),
]
# Override the default (English) gradio ChatInterface buttons with French labels
gradio_buttons = dict(
    submit_btn=gr.Button("Envoyer"),  # "Submit"
    retry_btn=gr.Button("🔄 Générer une autre réponse"),  # "🔄 Retry"
    undo_btn=gr.Button("↩️ Annuler"),  # "↩️ Undo"
    clear_btn=gr.Button("🗑️ Effacer la conversation"),  # "🗑️ Clear"
    # stop_btn= None,
    stop_btn=gr.Button("Arrêter"),  # "Stop"
)
# French label for the "Additional inputs" accordion.
# NOTE(review): appears unused in this file (the accordion below hardcodes the
# same label) — verify before removing.
additional_inputs_name="Paramètres" # "Additional inputs"
# Main message input box
textbox=gr.Textbox(
    container=False,
    show_label=False,
    label="Message",
    placeholder="Votre message (laissez vide pour que le Bot continue seul)...",
    scale=7,
    lines=2,
    autofocus=False,
)
chatbot_label="Conversation"  # Label of the chat history panel
# Extra controls passed to generate() after (message, history).
# Their order must match generate()'s parameter order:
#   generate_several_turns, bot_message_start,
#   max_new_tokens, temperature, repetition_penalty, top_k, top_p
additional_inputs = [
    # -> generate_several_turns
    gr.Checkbox(
        False,
        label="Plus qu'un tour de parole",
        info="Générer plusieurs tours de parole (et donc comment vous pourriez continuer la conversation)",
    ),
    # -> bot_message_start (forced beginning of the bot's reply)
    gr.Textbox(
        "",
        label="Début de réponse",
        info="Vous pouvez taper ici ce que commence à vous répondre le Bot (pensez à actualiser entre chaque génération)",
        type="text",
    ),
    # gr.Textbox(
    #     "",
    #     label="Votre prénom",
    #     info="Prénom de vous en tant qu'interlocuteur (si vous vous nommez, le bot s'appellera Claire)",
    # ),
    # -> max_new_tokens
    gr.Slider(
        label="Longueur max",
        info="Longueur maximale du texte généré (en nombre de 'tokens' ~ mots et ponctuations)",
        value=default_max_new_tokens,
        minimum=25,
        maximum=1000,
        step=25,
        interactive=True,
    ),
    # -> temperature
    gr.Slider(
        label="Température",
        info="Une valeur élevée augmente la diversité du texte généré, mais peut aussi produire des résultats incohérents",
        value=default_temperature,
        minimum=0.1,
        maximum=1.9,
        step=0.1,
        interactive=True,
    ),
    # -> repetition_penalty
    gr.Slider(
        label="Pénalité de répétition",
        info="Pénalisation des répétitions",
        value=default_repetition_penalty,
        minimum=1.0,
        maximum=1.95,
        step=0.05,
        interactive=True,
    ),
    # -> top_k
    gr.Slider(
        label="Top-k",
        info="Une valeur élevée permet d'explorer plus d'alternatives",
        value=default_top_k,
        minimum=1,
        maximum=50,
        step=1,
        interactive=True,
    ),
    # -> top_p
    gr.Slider(
        label="Top-p",
        info="Une valeur élevée permet d'explorer plus d'alternatives",
        value=default_top_p,
        minimum=0.9,
        maximum=1.0,
        step=0.01,
        interactive=True,
    ),
]
STREAMING = True  # stream tokens as they are generated (TextIteratorStreamer)
print("Loading model...")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,  # 4-bit quantization (presumably via bitsandbytes — TODO confirm installed)
)
print("Optimizing model...")
import optimum
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)
print("Setup chat...")
# Token ids used to control where generation stops:
# - eos_token_id: the model's real end-of-text token
# - newspk_token_id: the id of "[", the character that opens a "[Speaker:]" tag
eos_token_id = tokenizer.eos_token_id
newspk_token_id = tokenizer.encode("[")
assert len(newspk_token_id) == 1  # "[" must encode to exactly one token
newspk_token_id = newspk_token_id[0]
# Register "[" as the eos token so that decoding with skip_special_tokens=True
# drops it from single-turn outputs (see generate())
tokenizer.add_special_tokens({"eos_token": "["})
# Speaker tags used internally when building the prompt
user_internal_tag = "[Intervenant 1:]"
bot_internal_tag = "[Intervenant 2:]"
device = "cuda" if torch.cuda.is_available() else "cpu"
@spaces.GPU
def generate(
    user_message,
    conversation_history=None,  # was a mutable default ([]); None behaves the same below
    generate_several_turns=False,
    bot_message_start="",
    # user_surname="",
    max_new_tokens=default_max_new_tokens,
    temperature=default_temperature,
    repetition_penalty=default_repetition_penalty,
    top_k=default_top_k,
    top_p=default_top_p,
    user_surname="",  # Experimental (TODO)
    remove_unfinished_sentence=True,
):
    """Yield the bot's reply incrementally (generator used by gr.ChatInterface).

    Args:
        user_message: latest user utterance (may be empty to let the bot go on).
        conversation_history: list of (user, bot) message pairs as provided by
            gr.ChatInterface; None or empty for a fresh conversation.
        generate_several_turns: if True, keep generating past the next speaker
            change (several dialogue turns instead of a single reply).
        bot_message_start: optional forced beginning of the bot's reply.
        max_new_tokens, temperature, repetition_penalty, top_k, top_p:
            sampling parameters forwarded to model.generate().
        user_surname: experimental; names the user's speaker tag (the bot then
            becomes "[Claire:]").
        remove_unfinished_sentence: in multi-turn mode, also trim a trailing
            unfinished sentence (otherwise only a trailing unfinished turn).

    Yields:
        str: progressively longer reply text, in display format.
    """
    user_message = claire_text_preproc_message(user_message)
    bot_message_start = claire_text_preproc_message(bot_message_start)
    if user_surname:
        # Keep only letters, hyphens, dots and apostrophes, then capitalize
        user_surname = capitalize(collapse_whitespaces(re.sub(r"[^\p{L}\-\.']", " ", user_surname))).strip()
    if user_surname:
        user_tag = f"[{user_surname}:]"
        bot_tag = "[Claire:]"  # fix: was an f-string with no placeholder
    else:
        user_tag = user_internal_tag
        bot_tag = bot_internal_tag
    if conversation_history:
        conversation_history = "\n".join(
            [
                f"{user_tag} {claire_text_preproc_message(user)}\n{bot_tag} {claire_text_preproc_message(bot) if bot else ''}"
                for user, bot in conversation_history
            ]
        )
        conversation_history = from_display_to_internal(conversation_history)
        conversation_history = conversation_history.rstrip()
        if conversation_history:
            conversation_history += "\n"
    else:
        conversation_history = ""
    # Combine the history and the new user message into a single prompt
    conversation = f"{conversation_history}{user_tag} {user_message}\n{bot_tag} {bot_message_start}".strip()
    conversation = remove_empty_turns(conversation)
    # Encode the conversation using the tokenizer
    input_ids = tokenizer.encode(
        conversation, return_tensors="pt", add_special_tokens=True
    )
    input_ids = input_ids.to(device)
    # "[" is registered as the eos token, so skipping special tokens strips the
    # speaker tags; keep them in multi-turn mode (tags are part of the output)
    skip_special_tokens = not generate_several_turns
    if STREAMING:
        streamer = transformers.TextIteratorStreamer(
            tokenizer,
            timeout=10.0,
            skip_prompt=True,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=False,
        )
    else:
        streamer = None
    # Generation parameters: stop at end-of-text (multi-turn) or at the "["
    # that opens the next speaker tag (single turn)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        eos_token_id=eos_token_id if generate_several_turns else newspk_token_id,
        pad_token_id=eos_token_id,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
        top_p=top_p,
        num_beams=1,
        # use_cache=False,
        # early_stopping=False,
    )
    if STREAMING:
        # Run generation in a background thread and stream tokens as they come
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        outputs = []
        # fix: initialize text so the final multi-turn cleanup below cannot hit
        # an unbound variable when the streamer yields no token at all
        text = bot_message_start
        if bot_message_start.strip():
            yield bot_message_start
        for token in streamer:
            # Ignore bare line breaks in single-turn mode
            if not generate_several_turns and re.match(r"\s*\n$", token):
                continue
            outputs.append(token)
            text = bot_message_start + from_internal_to_display("".join(outputs))
            yield text
    else:
        output_ids = model.generate(**generate_kwargs)
        output_ids = output_ids[0][len(input_ids[0]) :]  # drop the prompt tokens
        text = tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
        if bot_message_start.strip():
            bot_message_start = bot_message_start.strip() + " "
        # A trailing "[" is the start of the next speaker tag: drop it
        text = bot_message_start + from_internal_to_display(text.rstrip("[").strip())
        yield text
    if generate_several_turns:
        # Final cleanup: do not leave an unfinished sentence/turn hanging
        if remove_unfinished_sentence:
            yield remove_last_unfinished_sentence(text)
        else:
            yield remove_last_unfinished_turn(text)[0]
def claire_text_preproc_message(text):
    """Normalize a chat message before it is inserted in the prompt:
    French punctuation spacing, special-character cleanup, whitespace
    collapsing, then bracket/brace replacement."""
    for transform in (
        format_punctuations_for_french,
        format_special_characters,
        collapse_whitespaces,
        replace_brackets,
    ):
        text = transform(text)
    return text
def collapse_whitespaces(text):
    """Collapse every whitespace run to a single space, remove the space
    that may precede a period or comma, and trim surrounding spaces."""
    collapsed = re.sub(r"\s+", " ", text)
    collapsed = re.sub(r" ([\.,])", r"\1", collapsed)
    return collapsed.lstrip().rstrip(" ")
def replace_brackets(text):
    """Turn square brackets and curly braces into parentheses
    (presumably so user text cannot collide with the "[Name:]" speaker-tag
    syntax — confirm with prompt format)."""
    return text.translate(_BRACKET_TABLE)


# Single C-level pass instead of two regex substitutions
_BRACKET_TABLE = str.maketrans("[{]}", "(())")
def format_punctuations_for_french(text):
    """Apply French typography spacing rules to *text*: a space before
    double punctuation marks (? ! : ;), no space before simple ones (, .),
    and a space after punctuation (except before digits)."""
    for pattern, replacement in french_punctuation_rules:
        text = re.sub(pattern, replacement, text)
    return text


# fix: this was a *set* literal, so the rules were applied in nondeterministic
# order (string hash randomization varies across runs) even though they are
# order-dependent — e.g. the space-after rule relies on the space-before rule
# having already run. A tuple preserves the intended, listed order.
french_punctuation_rules = (
    # Add a space before double punctuation marks
    (r"([" + re.escape('?!:;') + r"])", r" \1"),
    # Remove space before simple punctuation marks
    (r"\s+([" + re.escape(',.') + r"])", r"\1"),
    # Add space after punctuation marks (not before a space, punctuation, or digit)
    (r"([" + re.escape('?!:;,') + r"]+)([^ " + re.escape('?!:;,') + r"\d])", r"\1 \2"),
    (r"([" + re.escape('.') + r"]+)([A-Z])", r"\1 \2"),
)
def format_special_characters(text):
    """Map typographic Unicode variants to plain equivalents after NFC
    normalization: curly quotes/guillemets -> straight quotes, the ellipsis
    character -> "...", unbreakable spaces -> regular spaces, and remove
    non-printable characters."""
    text = unicodedata.normalize("NFC", text)
    rules = (
        ("…", "..."),
        (r"[«“][^\S\r\n]*", '"'),
        (r"[^\S\r\n]*[»”″„]", '"'),
        (r"(``|'')", '"'),
        (r"[’‘‛ʿ]", "'"),
        ("‚", ","),
        (r"–", "-"),
        ("[ ]", " "),  # unbreakable spaces
        (r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", ""),  # non-printable characters
        # ("·", "."),
        (r"ᵉʳ", "er"),
        (r"ᵉ", "e"),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text
# Speaker tags shown to the user in the chat window
user_name = "[Vous:]"
bot_name = "[Bot:]"


def from_internal_to_display(text):
    """Rewrite the internal prompt speaker tags into the tags shown in the
    UI ([Vous:] / [Bot:])."""
    replacements = (
        (user_internal_tag, user_name),
        (bot_internal_tag, bot_name),
    )
    for internal_tag, display_tag in replacements:
        text = text.replace(internal_tag, display_tag)
    return text
def from_display_to_internal(text):
    """Inverse of from_internal_to_display: map the UI speaker tags back to
    the internal tags used in the model prompt."""
    replacements = (
        (user_name, user_internal_tag),
        (bot_name, bot_internal_tag),
    )
    for display_tag, internal_tag in replacements:
        text = text.replace(display_tag, internal_tag)
    return text
def remove_last_unfinished_sentence(text):
    """Trim trailing generated text so the output ends cleanly.

    First drops a trailing unfinished turn (an opening "[" with no closing
    "]").  If no turn was dropped: when the last line holds fewer than 15
    characters after its "]" (a barely-started turn), the whole last line is
    removed; otherwise the text is cut at the last sentence end (. ! ?)
    occurring after the last line break, falling back to the last phrase
    end (, ;).  Text with no line break is returned unchanged after the
    turn check.
    """
    text, removed_turn = remove_last_unfinished_turn(text)
    if removed_turn:
        return text
    # Start offsets of every line break in the (possibly trimmed) text
    line_breaks = [u.span(0)[0] for u in re.finditer("\n", text)]
    remove_last_sentence = True
    # Last line barely started (< 15 chars after the speaker tag): drop it entirely
    if len(line_breaks) >= 1 and len(text[line_breaks[-1]:].split("]")[-1]) < 15:
        text = text[: line_breaks[-1]]
        line_breaks.pop(-1)
        remove_last_sentence = False
    if remove_last_sentence and len(line_breaks) >= 1:
        # Cut at the last sentence-ending punctuation located on the last line
        sentence_ends = [u.span(0)[0] for u in re.finditer(r"[\.!?]", text)]
        sentence_ends = [p for p in sentence_ends if p > line_breaks[-1]]
        if sentence_ends:
            text = text[: sentence_ends[-1] + 1]
        else:
            # No sentence end on the last line: fall back to , or ;
            phrase_ends = [u.span(0)[0] for u in re.finditer(r"[,;]", text)]
            phrase_ends = [p for p in phrase_ends if p > line_breaks[-1]]
            if phrase_ends:
                text = text[: phrase_ends[-1] + 1]
    return text
def remove_last_unfinished_turn(text):
    """Cut off a trailing started-but-unclosed speaker tag (a final "["
    with no "]" after it).

    Returns:
        (trimmed_text, True) when a cut was made, otherwise
        (text with trailing whitespace stripped, False).
    """
    # Only the last "[" matters: rfind replaces the original scan of all "["
    last_open = text.rfind("[")
    if last_open != -1 and "]" not in text[last_open:]:
        return text[:last_open].rstrip(), True
    return text.rstrip(), False
def remove_empty_turns(text):
    """Remove turns whose content contains no letter, merging the
    surrounding speaker tags, and merge two consecutive turns by the same
    speaker.  Both fixups run only while an empty turn remains in the
    text."""
    while _empty_turn.search(text):
        # Collapse an empty turn into the tag that follows it
        text = _empty_turn.sub(r"\1", text)
        # Merge two consecutive turns opened by the same speaker tag
        text = _repeated_turn.sub(r"\1 \2", text)
    return text


# A "[Name:]" speaker tag
_speaker_regex = r"\[[^\]]+:\]"
# A tag followed only by non-letters (\p{L} needs the third-party `regex`
# module), then the next tag (captured as \1)
_empty_turn = re.compile(_speaker_regex + r"[^\p{L}]*" + "(" + _speaker_regex + ")")
# The same tag (backreference \1) opening two consecutive turns
_repeated_turn = re.compile(r"(" + _speaker_regex + r") ([^\[]*)\s\1")
def capitalize(text):
    """Capitalize each word of a name, keeping short all-caps words (<= 2
    chars, e.g. initials) untouched, and capitalizing every part around
    hyphens and apostrophes: michel JR claude-marie -> Michel JR
    Claude-Marie."""

    def _cap_parts(word, sep):
        # Capitalize each sep-separated part, keeping all-caps parts as-is
        return sep.join(p if p.isupper() else p.capitalize() for p in word.split(sep))

    words = []
    for word in text.split(" "):
        if not (word.isupper() and len(word) <= 2):
            word = word.capitalize()
        # NOTE: mirrors the original behavior — each separator re-splits the
        # same base word, so the last matching separator wins
        base = word
        for sep in ("-", "'"):
            if sep in base:
                word = _cap_parts(base, sep)
        words.append(word)
    return " ".join(words)
# # Test
# list(generate(*(examples[0][:1] + [[]] + examples[0][1:])))
# Assemble the chat UI; gr.ChatInterface calls generate() with
# (message, history, *additional_inputs)
chat_interface = gr.ChatInterface(
    fn=generate,
    title=title,
    description=description,
    chatbot=gr.Chatbot(label=chatbot_label),
    textbox=textbox,
    examples=examples,
    additional_inputs=additional_inputs,
    additional_inputs_accordion=gr.Accordion(
        label="Paramètres",
        open=True,
    ),
    autofocus=False,
    **gradio_buttons,  # French button labels defined above
)
if __name__ == "__main__":
    print("Launching chat...")
    # Wrap in Blocks to apply the custom stylesheet, then serve with a
    # bounded request queue
    with gr.Blocks(css="style.css") as demo:
        chat_interface.render()
    demo.queue(max_size=20).launch()