Update code
Browse files- modeling_minicpm.py +32 -5
modeling_minicpm.py
CHANGED
@@ -20,7 +20,7 @@
|
|
20 |
""" PyTorch MiniCPM model."""
|
21 |
import math
|
22 |
import warnings
|
23 |
-
from typing import List, Optional, Tuple, Union
|
24 |
|
25 |
import torch
|
26 |
import torch.nn.functional as F
|
@@ -49,11 +49,13 @@ from transformers.utils import (
|
|
49 |
)
|
50 |
from transformers.utils.import_utils import is_torch_fx_available
|
51 |
from .configuration_minicpm import MiniCPMConfig
|
|
|
52 |
|
53 |
-
|
54 |
-
if is_flash_attn_2_available():
|
55 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
56 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
|
|
|
57 |
|
58 |
|
59 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
@@ -124,7 +126,7 @@ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
|
|
124 |
|
125 |
|
126 |
class MiniCPMRotaryEmbedding(nn.Module):
|
127 |
-
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=
|
128 |
super().__init__()
|
129 |
|
130 |
self.dim = dim
|
@@ -762,7 +764,6 @@ class MiniCPMDecoderLayer(nn.Module):
|
|
762 |
def __init__(self, config: MiniCPMConfig, layer_idx: int):
|
763 |
super().__init__()
|
764 |
self.hidden_size = config.hidden_size
|
765 |
-
|
766 |
self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
767 |
|
768 |
self.mlp = MiniCPMMLP(config)
|
@@ -1302,6 +1303,32 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
|
|
1302 |
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
1303 |
)
|
1304 |
return reordered_past
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1305 |
|
1306 |
|
1307 |
@add_start_docstrings(
|
|
|
20 |
""" PyTorch MiniCPM model."""
|
21 |
import math
|
22 |
import warnings
|
23 |
+
from typing import List, Optional, Tuple, Union, Dict
|
24 |
|
25 |
import torch
|
26 |
import torch.nn.functional as F
|
|
|
49 |
)
|
50 |
from transformers.utils.import_utils import is_torch_fx_available
|
51 |
from .configuration_minicpm import MiniCPMConfig
|
52 |
+
import re
|
53 |
|
54 |
+
try:
|
|
|
55 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
56 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
57 |
+
except:
|
58 |
+
pass
|
59 |
|
60 |
|
61 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
|
|
126 |
|
127 |
|
128 |
class MiniCPMRotaryEmbedding(nn.Module):
|
129 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
130 |
super().__init__()
|
131 |
|
132 |
self.dim = dim
|
|
|
764 |
def __init__(self, config: MiniCPMConfig, layer_idx: int):
|
765 |
super().__init__()
|
766 |
self.hidden_size = config.hidden_size
|
|
|
767 |
self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
768 |
|
769 |
self.mlp = MiniCPMMLP(config)
|
|
|
1303 |
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
1304 |
)
|
1305 |
return reordered_past
|
1306 |
+
|
1307 |
+
@torch.inference_mode()
|
1308 |
+
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1309 |
+
max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
|
1310 |
+
**kwargs):
|
1311 |
+
if history is None:
|
1312 |
+
history = []
|
1313 |
+
if logits_processor:
|
1314 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1315 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1316 |
+
else:
|
1317 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1318 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1319 |
+
|
1320 |
+
history.append({"role": role, "content": query})
|
1321 |
+
history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
|
1322 |
+
inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
|
1323 |
+
outputs = self.generate(**inputs, **gen_kwargs)
|
1324 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1325 |
+
response = tokenizer.decode(outputs)
|
1326 |
+
pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
|
1327 |
+
matches = pattern.findall(response)
|
1328 |
+
if len(matches) > 0:
|
1329 |
+
response = matches[0]
|
1330 |
+
history.append({"role": "assistant", "content": response})
|
1331 |
+
return response, history
|
1332 |
|
1333 |
|
1334 |
@add_start_docstrings(
|