princepride committed on
Commit
c4604a1
1 Parent(s): 7f13f1f

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +25 -2
model.py CHANGED
@@ -69,7 +69,7 @@ class SperSignFilter(Filter):
69
  for i, input_str in enumerate(inputs):
70
  if '%s' in input_str:
71
  encoded_str = input_str.replace('%s', '*')
72
- self.code.append(i) # 将包含 's%' 的字符串的索引存储到 self.code 中
73
  else:
74
  encoded_str = input_str
75
  encoded_inputs.append(encoded_str)
@@ -80,6 +80,29 @@ class SperSignFilter(Filter):
80
  for i in self.code:
81
  decoded_inputs[i] = decoded_inputs[i].replace('*', '%s') # 使用 self.code 中的索引还原原始字符串
82
  return decoded_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  class ChevronsFilter(Filter):
85
  def __init__(self):
@@ -238,7 +261,7 @@ class Model():
238
  # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
239
 
240
  def generate(self, inputs, original_language, target_languages, max_batch_size):
241
- filter_list = [SpecialTokenFilter(), SperSignFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
242
  filter_pipeline = FilterPipeline(filter_list)
243
  def language_mapping(original_language):
244
  d = {
 
69
  for i, input_str in enumerate(inputs):
70
  if '%s' in input_str:
71
  encoded_str = input_str.replace('%s', '*')
72
+ self.code.append(i) # 将包含 '%s' 的字符串的索引存储到 self.code 中
73
  else:
74
  encoded_str = input_str
75
  encoded_inputs.append(encoded_str)
 
80
  for i in self.code:
81
  decoded_inputs[i] = decoded_inputs[i].replace('*', '%s') # 使用 self.code 中的索引还原原始字符串
82
  return decoded_inputs
83
+
84
class ParenSParenFilter(Filter):
    """Swap the literal text '(s)' for the placeholder '$' and back again.

    ``encoder`` rewrites each input that contains '(s)' and records its
    index in ``self.code``; ``decoder`` restores '(s)' on exactly those
    indices. The two calls are meant to be used as a pair on lists of the
    same length and order.
    """

    def __init__(self):
        self.name = 'Paren s paren filter'
        # Indices of the inputs rewritten by the most recent encoder() call.
        self.code = []

    def encoder(self, inputs):
        """Return a new list where '(s)' is replaced by '$'.

        Args:
            inputs: list of strings to encode.

        Returns:
            A list of the same length. Entries without '(s)' are returned
            unchanged.

        Strings that already contain a literal '$' are also left
        unchanged and are not recorded: decoder() turns every '$' of a
        recorded string into '(s)', so encoding such a string would
        corrupt its pre-existing '$' on the way back.
        """
        encoded_inputs = []
        self.code = []  # forget the indices of any previous batch
        for i, input_str in enumerate(inputs):
            if '(s)' in input_str and '$' not in input_str:
                encoded_inputs.append(input_str.replace('(s)', '$'))
                self.code.append(i)  # remember which entries need restoring
            else:
                encoded_inputs.append(input_str)
        return encoded_inputs

    def decoder(self, inputs):
        """Return a copy of *inputs* with '$' turned back into '(s)'.

        Only the positions recorded by the last encoder() call are
        touched; all other entries are copied through as they are.
        """
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('$', '(s)')
        return decoded_inputs
106
 
107
  class ChevronsFilter(Filter):
108
  def __init__(self):
 
261
  # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
262
 
263
  def generate(self, inputs, original_language, target_languages, max_batch_size):
264
+ filter_list = [SpecialTokenFilter(), SperSignFilter(), ParenSParenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
265
  filter_pipeline = FilterPipeline(filter_list)
266
  def language_mapping(original_language):
267
  d = {