add chat function
modeling_orion.py CHANGED (+21 -0)
@@ -30,6 +30,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
+from .generation_utils import build_chat_input, TextIterStreamer
+from transformers.generation.utils import GenerationConfig
+from threading import Thread
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -951,6 +955,23 @@ class OrionForCausalLM(OrionPreTrainedModel):
             attentions=outputs.attentions,
         )
 
+    def chat(self, tokenizer, messages: List[dict], streaming=False, generation_config: Optional[GenerationConfig] = None):
+        generation_config = generation_config or self.generation_config
+        input_tokens = build_chat_input(tokenizer, messages)
+        input_ids = torch.LongTensor([input_tokens]).to(self.device)
+
+        if streaming:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
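For reference, a minimal usage sketch of the new chat() method (not part of the commit). The checkpoint name is illustrative, the {"role": ..., "content": ...} message schema is an assumption about what build_chat_input expects, and TextIterStreamer is assumed to be iterable in the manner of transformers' TextIteratorStreamer:

# Usage sketch under the assumptions stated above; not part of the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "OrionStarAI/Orion-14B-Chat"  # illustrative checkpoint name

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

messages = [{"role": "user", "content": "Hello, who are you?"}]  # assumed schema

# Blocking path: generate() runs to completion, then the new tokens
# (everything after the prompt) are decoded and returned as a string.
print(model.chat(tokenizer, messages))

# Streaming path: generate() runs on a background thread and chat()
# returns the streamer immediately; iterate it to consume text chunks.
for chunk in model.chat(tokenizer, messages, streaming=True):
    print(chunk, end="", flush=True)

Returning the streamer before generation finishes is what lets the caller print text as it is produced; the blocking branch instead waits, slices off the prompt tokens, and decodes only the response.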