princepride committed on
Commit
7f13f1f
1 Parent(s): 1916f39

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +33 -22
model.py CHANGED
@@ -148,13 +148,13 @@ class SimilarFilter(Filter):
148
  return decoded_inputs
149
 
150
  class ChineseFilter:
151
- def __init__(self, pinyin_lib_file='./pinyin.txt'):
152
  self.name = 'chinese filter'
153
  self.code = []
154
  self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)
155
 
156
  def load_pinyin_lib(self, file_path):
157
- with open(file_path, 'r', encoding='utf-8') as f:
158
  return set(line.strip().lower() for line in f)
159
 
160
  def is_valid_chinese(self, word):
@@ -407,27 +407,38 @@ class Model():
407
  print(batch)
408
  batch = filter_pipeline.batch_encoder(batch)
409
  print(batch)
410
- input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
411
  temp = []
412
- for target_language in target_languages:
413
- target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
414
- generated_tokens = self.model.generate(
415
- **input_ids,
416
- forced_bos_token_id=target_lang_code,
417
- )
418
- generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
419
-
420
- print(generated_translation)
421
- generated_translation = filter_pipeline.batch_decoder(generated_translation)
422
- print(generated_translation)
423
- print(len(generated_translation))
424
- # Append result to output
425
- temp.append({
426
- "target_language": target_language,
427
- "generated_translation": generated_translation,
428
- })
429
- input_ids.to('cpu')
430
- del input_ids
 
 
 
 
 
 
 
 
 
 
 
 
431
  temp_outputs.append(temp)
432
  processed_num += len(batch)
433
  if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
 
148
  return decoded_inputs
149
 
150
  class ChineseFilter:
151
def __init__(self, pinyin_lib_file='pinyin.txt'):
    """Initialize the Chinese filter and load its pinyin syllable table.

    Args:
        pinyin_lib_file: Name of the newline-separated pinyin syllable
            file; ``load_pinyin_lib`` resolves it against the module-level
            ``script_dir``, so the default works regardless of the CWD.
    """
    # Load the syllable set first — it is the only step that touches disk.
    self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)
    self.name = 'chinese filter'
    self.code = []
155
 
156
def load_pinyin_lib(self, file_path):
    """Read the pinyin syllable list into a lowercase lookup set.

    Args:
        file_path: Path of the syllable file, resolved relative to the
            module-level ``script_dir`` so lookup is independent of the
            process working directory.

    Returns:
        set[str]: One whitespace-stripped, lowercased syllable per line.
    """
    full_path = os.path.join(script_dir, file_path)
    with open(full_path, 'r', encoding='utf-8') as handle:
        return {entry.strip().lower() for entry in handle}
159
 
160
  def is_valid_chinese(self, word):
 
407
  print(batch)
408
  batch = filter_pipeline.batch_encoder(batch)
409
  print(batch)
 
410
  temp = []
411
+ if len(batch) > 0:
412
+ input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
413
+ for target_language in target_languages:
414
+ target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
415
+ generated_tokens = self.model.generate(
416
+ **input_ids,
417
+ forced_bos_token_id=target_lang_code,
418
+ )
419
+ generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
420
+
421
+ print(generated_translation)
422
+ generated_translation = filter_pipeline.batch_decoder(generated_translation)
423
+ print(generated_translation)
424
+ print(len(generated_translation))
425
+ # Append result to output
426
+ temp.append({
427
+ "target_language": target_language,
428
+ "generated_translation": generated_translation,
429
+ })
430
+ input_ids.to('cpu')
431
+ del input_ids
432
+ else:
433
+ for target_language in target_languages:
434
+ generated_translation = filter_pipeline.batch_decoder(batch)
435
+ print(generated_translation)
436
+ print(len(generated_translation))
437
+ # Append result to output
438
+ temp.append({
439
+ "target_language": target_language,
440
+ "generated_translation": generated_translation,
441
+ })
442
  temp_outputs.append(temp)
443
  processed_num += len(batch)
444
  if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1: