pipeline_tag: text-generation
tags:
- PyTorch
- Transformers
- gpt2
license: unlicense
language: ru
widget:
- text: >-
- У Джульетты было 7 пончиков, а потом она 3 съела. Сколько у нее осталось
пончиков? -
- text: >-
- Поглажено 4 манула. Осталось погладить 6. Сколько всего манулов надо
погладить? -
- text: '- Для начала скажи, чему равно пятью девять? -'
- text: '- ты чё такой борзый? -'
- text: '- Привет! Как ваше ничего? -'
Russian Chit-chat, Deductive and Common Sense reasoning model
Модель является ядром прототипа диалоговой системы с двумя основными функциями.
Первая функция - генерация реплик чит-чата. В качестве затравки подается история диалога (предшествующие несколько реплик, от 1 до 10).
- Привет, как дела?
- Привет, так себе.
- <<< эту реплику ожидаем от модели >>>
Вторая функция модели - вывод ответа на заданный вопрос, опираясь на дополнительные факты или на "здравый смысл". Предполагается, что релевантные факты извлекаются из стороннего хранилища (базы знаний) с помощью другой модели, например sbert_pq. Используя указанный факт(ы) и текст вопроса, модель построит грамматичный и максимально краткий ответ, как это сделал бы человек в подобной коммуникативной ситуации. Релевантные факты следует указывать перед текстом заданного вопроса так, будто сам собеседник сказал их:
- Сегодня 15 сентября. Какой сейчас у нас месяц?
- Сентябрь
Модель не ожидает, что все найденные и добавленные в контекст диалога факты действительно имеют отношение к заданному вопросу. Поэтому модель, извлекающая из базы знаний информацию, может жертвовать точностью в пользу полноте и добавлять что-то лишнее. Модель читчата в этом случае сама выберет среди добавленных в контекст фактов необходимую фактуру и проигнорирует лишнее. Текущая версия модели допускает до 5 фактов перед вопросом. Например:
- Стасу 16 лет. Стас живет в Подольске. У Стаса нет своей машины. Где живет Стас?
- в Подольске
В некоторых случаях модель может выполнять силлогический вывод ответа, опираясь на 2 предпосылки, связанные друг с другом. Выводимое из двух предпосылок следствие не фигурирует явно, а как бы используется для вывода ответа:
- Смертен ли Аристофан, если он был греческим философом, а все философы смертны?
- Да
Как можно видеть из приведенных примеров, формат подаваемой на вход модели фактической информации для выполнения вывода предельно естественный и свободный.
Кроме логического вывода, модель также умеет решать простые арифметические задачи в рамках 1-2 классов начальной школы, с двумя числовыми аргументами:
- Чему равно 2+8?
- 10
Варианты модели и метрики
Выложенная на данный момент модель имеет 760 млн. параметров, т.е. уровня sberbank-ai/rugpt3large_based_on_gpt2. Далее приводится результат замера точности решения арифметических задач на отложенном тестовом наборе сэмплов:
base model | arith. accuracy |
---|---|
sberbank-ai/rugpt3large_based_on_gpt2 | 0.91 |
sberbank-ai/rugpt3medium_based_on_gpt2 | 0.70 |
sberbank-ai/rugpt3small_based_on_gpt2 | 0.58 |
tinkoff-ai/ruDialoGPT-small | 0.44 |
tinkoff-ai/ruDialoGPT-medium | 0.69 |
Цифра 0.91 в столбце "arith. accuracy" означает, что 91% тестовых задач решено полностью верно. Любое отклонение сгенерированного ответа от эталонного рассматривается как ошибка. Например, выдача ответа "120" вместо "119" тоже фиксируется как ошибка.
Пример использования
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "inkoziev/rugpt_chitchat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
model.eval()
# На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-"
input_text = """<s>- Привет! Что делаешь?
- Привет :) В такси еду
-"""
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)
output_sequences = model.generate(input_ids=encoded_prompt, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
text = text[: text.find('</s>')]
print(text)
Контакты
Если у Вас есть какие-то вопросы по использованию этой модели, или предложения по ее улучшению - пишите мне mentalcomputing@gmail.com
Citation:
@MISC{rugpt_chitchat,
author = {Ilya Koziev},
title = {Russian Chit-chat with Common sence Reasoning},
url = {https://huggingface.co/inkoziev/rugpt_chitchat},
year = 2022
}