|
""" |
|
Conversation prompt templates. |
|
|
|
We kindly request that you import fastchat instead of copying this file if you wish to use it. |
|
If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. |
|
""" |
|
|
|
import dataclasses |
|
from enum import IntEnum, auto |
|
from typing import Any, Dict, List, Tuple, Union |
|
|
|
|
|
class SeparatorStyle(IntEnum): |
|
"""Separator styles.""" |
|
|
|
ADD_COLON_SINGLE = auto() |
|
ADD_COLON_TWO = auto() |
|
ADD_COLON_SPACE_SINGLE = auto() |
|
NO_COLON_SINGLE = auto() |
|
NO_COLON_TWO = auto() |
|
ADD_NEW_LINE_SINGLE = auto() |
|
LLAMA2 = auto() |
|
CHATGLM = auto() |
|
CHATML = auto() |
|
CHATINTERN = auto() |
|
DOLLY = auto() |
|
RWKV = auto() |
|
PHOENIX = auto() |
|
ROBIN = auto() |
|
FALCON_CHAT = auto() |
|
CHATGLM3 = auto() |
|
INTERNVL_ZH = auto() |
|
MPT = auto() |
|
|
|
|
|
@dataclasses.dataclass |
|
class Conversation: |
|
"""A class that manages prompt templates and keeps all conversation history.""" |
|
|
|
|
|
name: str |
|
|
|
system_template: str = '{system_message}' |
|
|
|
system_message: str = '' |
|
|
|
roles: Tuple[str] = ('USER', 'ASSISTANT') |
|
|
|
messages: List[List[str]] = () |
|
|
|
offset: int = 0 |
|
|
|
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE |
|
sep: str = '\n' |
|
sep2: str = None |
|
|
|
stop_str: Union[str, List[str]] = None |
|
|
|
stop_token_ids: List[int] = None |
|
|
|
def get_prompt(self) -> str: |
|
"""Get the prompt for generation.""" |
|
system_prompt = self.system_template.format(system_message=self.system_message) |
|
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: |
|
ret = system_prompt + self.sep |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + ': ' + message + self.sep |
|
else: |
|
ret += role + ':' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: |
|
seps = [self.sep, self.sep2] |
|
ret = system_prompt + seps[0] |
|
for i, (role, message) in enumerate(self.messages): |
|
if message: |
|
ret += role + ': ' + message + seps[i % 2] |
|
else: |
|
ret += role + ':' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: |
|
ret = system_prompt + self.sep |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + ': ' + message + self.sep |
|
else: |
|
ret += role + ': ' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: |
|
ret = '' if system_prompt == '' else system_prompt + self.sep |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + '\n' + message + self.sep |
|
else: |
|
ret += role + '\n' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: |
|
ret = system_prompt |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + message + self.sep |
|
else: |
|
ret += role |
|
return ret |
|
elif self.sep_style == SeparatorStyle.NO_COLON_TWO: |
|
seps = [self.sep, self.sep2] |
|
ret = system_prompt |
|
for i, (role, message) in enumerate(self.messages): |
|
if message: |
|
ret += role + message + seps[i % 2] |
|
else: |
|
ret += role |
|
return ret |
|
elif self.sep_style == SeparatorStyle.RWKV: |
|
ret = system_prompt |
|
for i, (role, message) in enumerate(self.messages): |
|
if message: |
|
ret += ( |
|
role |
|
+ ': ' |
|
+ message.replace('\r\n', '\n').replace('\n\n', '\n') |
|
) |
|
ret += '\n\n' |
|
else: |
|
ret += role + ':' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.LLAMA2: |
|
seps = [self.sep, self.sep2] |
|
if self.system_message: |
|
ret = system_prompt |
|
else: |
|
ret = '[INST] ' |
|
for i, (role, message) in enumerate(self.messages): |
|
tag = self.roles[i % 2] |
|
if message: |
|
if i == 0: |
|
ret += message + ' ' |
|
else: |
|
ret += tag + ' ' + message + seps[i % 2] |
|
else: |
|
ret += tag |
|
return ret |
|
elif self.sep_style == SeparatorStyle.CHATGLM: |
|
|
|
|
|
round_add_n = 1 if self.name == 'chatglm2' else 0 |
|
if system_prompt: |
|
ret = system_prompt + self.sep |
|
else: |
|
ret = '' |
|
|
|
for i, (role, message) in enumerate(self.messages): |
|
if i % 2 == 0: |
|
ret += f'[Round {i//2 + round_add_n}]{self.sep}' |
|
|
|
if message: |
|
ret += f'{role}:{message}{self.sep}' |
|
else: |
|
ret += f'{role}:' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.CHATML: |
|
ret = '' if system_prompt == '' else system_prompt + self.sep + '\n' |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + '\n' + message + self.sep + '\n' |
|
else: |
|
ret += role + '\n' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.CHATGLM3: |
|
ret = '' |
|
if self.system_message: |
|
ret += system_prompt |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + '\n' + ' ' + message |
|
else: |
|
ret += role |
|
return ret |
|
elif self.sep_style == SeparatorStyle.CHATINTERN: |
|
|
|
seps = [self.sep, self.sep2] |
|
ret = system_prompt |
|
for i, (role, message) in enumerate(self.messages): |
|
|
|
|
|
if message: |
|
ret += role + ':' + message + seps[i % 2] + '\n' |
|
else: |
|
ret += role + ':' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.DOLLY: |
|
seps = [self.sep, self.sep2] |
|
ret = system_prompt |
|
for i, (role, message) in enumerate(self.messages): |
|
if message: |
|
ret += role + ':\n' + message + seps[i % 2] |
|
if i % 2 == 1: |
|
ret += '\n\n' |
|
else: |
|
ret += role + ':\n' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.PHOENIX: |
|
ret = system_prompt |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + ': ' + '<s>' + message + '</s>' |
|
else: |
|
ret += role + ': ' + '<s>' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.ROBIN: |
|
ret = system_prompt + self.sep |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + ':\n' + message + self.sep |
|
else: |
|
ret += role + ':\n' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.FALCON_CHAT: |
|
ret = '' |
|
if self.system_message: |
|
ret += system_prompt + self.sep |
|
for role, message in self.messages: |
|
if message: |
|
ret += role + ': ' + message + self.sep |
|
else: |
|
ret += role + ':' |
|
|
|
return ret |
|
elif self.sep_style == SeparatorStyle.INTERNVL_ZH: |
|
seps = [self.sep, self.sep2] |
|
ret = self.system_message + seps[0] |
|
for i, (role, message) in enumerate(self.messages): |
|
if message: |
|
ret += role + ': ' + message + seps[i % 2] |
|
else: |
|
ret += role + ':' |
|
return ret |
|
elif self.sep_style == SeparatorStyle.MPT: |
|
ret = system_prompt + self.sep |
|
for role, message in self.messages: |
|
if message: |
|
if type(message) is tuple: |
|
message, _, _ = message |
|
ret += role + message + self.sep |
|
else: |
|
ret += role |
|
return ret |
|
else: |
|
raise ValueError(f'Invalid style: {self.sep_style}') |
|
|
|
def set_system_message(self, system_message: str): |
|
"""Set the system message.""" |
|
self.system_message = system_message |
|
|
|
def append_message(self, role: str, message: str): |
|
"""Append a new message.""" |
|
self.messages.append([role, message]) |
|
|
|
def update_last_message(self, message: str): |
|
"""Update the last output. |
|
|
|
The last message is typically set to be None when constructing the prompt, |
|
so we need to update it in-place after getting the response from a model. |
|
""" |
|
self.messages[-1][1] = message |
|
|
|
def to_gradio_chatbot(self): |
|
"""Convert the conversation to gradio chatbot format.""" |
|
ret = [] |
|
for i, (role, msg) in enumerate(self.messages[self.offset :]): |
|
if i % 2 == 0: |
|
ret.append([msg, None]) |
|
else: |
|
ret[-1][-1] = msg |
|
return ret |
|
|
|
def to_openai_api_messages(self): |
|
"""Convert the conversation to OpenAI chat completion format.""" |
|
ret = [{'role': 'system', 'content': self.system_message}] |
|
|
|
for i, (_, msg) in enumerate(self.messages[self.offset :]): |
|
if i % 2 == 0: |
|
ret.append({'role': 'user', 'content': msg}) |
|
else: |
|
if msg is not None: |
|
ret.append({'role': 'assistant', 'content': msg}) |
|
return ret |
|
|
|
def copy(self): |
|
return Conversation( |
|
name=self.name, |
|
system_template=self.system_template, |
|
system_message=self.system_message, |
|
roles=self.roles, |
|
messages=[[x, y] for x, y in self.messages], |
|
offset=self.offset, |
|
sep_style=self.sep_style, |
|
sep=self.sep, |
|
sep2=self.sep2, |
|
stop_str=self.stop_str, |
|
stop_token_ids=self.stop_token_ids, |
|
) |
|
|
|
def dict(self): |
|
return { |
|
'template_name': self.name, |
|
'system_message': self.system_message, |
|
'roles': self.roles, |
|
'messages': self.messages, |
|
'offset': self.offset, |
|
} |
|
|
|
|
|
|
|
conv_templates: Dict[str, Conversation] = {} |
|
|
|
|
|
def register_conv_template(template: Conversation, override: bool = False): |
|
"""Register a new conversation template.""" |
|
if not override: |
|
assert ( |
|
template.name not in conv_templates |
|
), f'{template.name} has been registered.' |
|
|
|
conv_templates[template.name] = template |
|
|
|
|
|
def get_conv_template(name: str) -> Conversation: |
|
"""Get a conversation template.""" |
|
return conv_templates[name].copy() |
|
|
|
|
|
|
|
|
|
|
|
|
|
register_conv_template( |
|
Conversation( |
|
name='Hermes-2', |
|
system_template='<|im_start|>system\n{system_message}', |
|
|
|
|
|
|
|
system_message='Bạn là một mô hình trí tuệ nhân tạo đa phương thức Tiếng Việt có tên gọi là Vintern, được phát triển bởi người Việt. Bạn là một trợ lý trí tuệ nhân tạo hữu ích và không gây hại.', |
|
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), |
|
sep_style=SeparatorStyle.MPT, |
|
sep='<|im_end|>', |
|
stop_token_ids=[ |
|
2, |
|
6, |
|
7, |
|
8, |
|
], |
|
stop_str='<|endoftext|>', |
|
) |
|
) |
|
|
|
|
|
register_conv_template( |
|
Conversation( |
|
name='internlm2-chat', |
|
system_template='<|im_start|>system\n{system_message}', |
|
|
|
|
|
|
|
system_message='Bạn là một mô hình trí tuệ nhân tạo đa phương thức Tiếng Việt có tên gọi là Vintern, được phát triển bởi người Việt. Bạn là một trợ lý trí tuệ nhân tạo hữu ích và không gây hại.', |
|
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), |
|
sep_style=SeparatorStyle.MPT, |
|
sep='<|im_end|>', |
|
stop_token_ids=[ |
|
2, |
|
92543, |
|
92542 |
|
] |
|
) |
|
) |
|
|
|
|
|
register_conv_template( |
|
Conversation( |
|
name='phi3-chat', |
|
system_template='<|system|>\n{system_message}', |
|
|
|
|
|
|
|
system_message='Bạn là một mô hình trí tuệ nhân tạo đa phương thức Tiếng Việt có tên gọi là Vintern, được phát triển bởi người Việt. Bạn là một trợ lý trí tuệ nhân tạo hữu ích và không gây hại.', |
|
roles=('<|user|>\n', '<|assistant|>\n'), |
|
sep_style=SeparatorStyle.MPT, |
|
sep='<|end|>', |
|
stop_token_ids=[ |
|
2, |
|
32000, |
|
32007 |
|
] |
|
) |
|
) |
|
|