SEED-LLaMA / gradio_demo /conversation.py
sjzhao's picture
update demo
bd63939
raw
history blame
6.53 kB
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import io
import base64
import os
from PIL import Image
import copy
IMG_FLAG = '<image>'
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
def decode_image(encoded_image: str) -> Image:
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
buffer = io.BytesIO(decoded_bytes)
image = Image.open(buffer)
return image
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
with io.BytesIO() as buffer:
image.save(buffer, format=format)
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
return encoded_image
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[dict] # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str}
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
messages = copy.deepcopy(self.messages)
if self.sep_style == SeparatorStyle.SINGLE:
if self.system is None or self.system == '':
text = ''
else:
text = self.system + self.sep
images = []
for message in messages:
text += message['role'] + ": " + message['message']['text'] + self.sep
for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
if image_ids is not None:
images.append(image_ids)
else:
image = Image.open(image_path).resize((256, 256))
image_base64 = encode_image(image)
images.append(image_base64)
text += self.roles[1] + ":"
elif self.sep_style == SeparatorStyle.LLAMA_2:
b_token = "[INST] "
e_token = " [/INST]"
if self.system is None or self.system == '':
text = ''
else:
text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
images = []
for idx, message in enumerate(messages):
# text += message['role'] + ": " + message['message']['text'] + self.sep
if idx % 2 == 0:
text += b_token + message['message']['text'] + e_token + self.sep
else:
text += message['message']['text'] + self.sep
for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
if image_ids is not None:
images.append(image_ids)
else:
image = Image.open(image_path).resize((256, 256))
image_base64 = encode_image(image)
images.append(image_base64)
else:
raise NotImplementedError
return {'text': text, 'images': images}
def update_image_ids(self, images_ids):
image_count = 0
for message in self.messages:
for idx in range(len(message['message']['images_ids'])):
if message['message']["images_ids"][idx] is None:
message['message']["images_ids"][idx] = images_ids[image_count]
image_count += 1
assert len(images_ids) == image_count, print(len(images_ids), image_count)
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
dialog = []
for i, single_turn in enumerate(self.messages[self.offset:]):
single_turn = single_turn['message']
text_list = single_turn['text'].split(IMG_FLAG)
assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
message = ''
for image_idx in range(len(single_turn['images'])):
# image = single_turn['images'][image_idx]
# image_base64 = encode_image(image)
# image_str = f'<img src="data:image/png;base64,{image_base64}" alt="user upload image" />'
image_path = single_turn['images'][image_idx]
if image_path == '':
message += text_list[image_idx] + '<corrupt_image>'
else:
message += text_list[image_idx] + f'![](file={image_path})'
message += text_list[-1]
if i % 2 == 0:
dialog.append([message, None])
else:
dialog[-1][-1] = message
return dialog
def copy(self):
return Conversation(system=self.system,
roles=self.roles,
messages=copy.deepcopy(self.messages),
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)
def dict(self):
messages = copy.deepcopy(self.messages)
for message in messages:
if 'images_ids' in message:
message.pop('images_ids')
for i in range(len(message['message']['images'])):
message['message']['images'][i] = os.path.basename(message['message']['images'][i])
return {
"system": self.system,
"roles": self.roles,
"messages": messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_seed_vicuna = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep='\n',
)
conv_seed_vicuna_system = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. ",
roles=("USER", "ASSISTANT"),
version="v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep='\n',
)
conv_seed_llama2 = Conversation(
system="",
roles=("[INST]", "[/INST]"),
version="v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep='\n',
)