princepride
committed on
Commit
•
7f13f1f
1
Parent(s):
1916f39
Update model.py
Browse files
model.py
CHANGED
@@ -148,13 +148,13 @@ class SimilarFilter(Filter):
|
|
148 |
return decoded_inputs
|
149 |
|
150 |
class ChineseFilter:
|
151 |
-
def __init__(self, pinyin_lib_file='
|
152 |
self.name = 'chinese filter'
|
153 |
self.code = []
|
154 |
self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)
|
155 |
|
156 |
def load_pinyin_lib(self, file_path):
|
157 |
-
with open(file_path, 'r', encoding='utf-8') as f:
|
158 |
return set(line.strip().lower() for line in f)
|
159 |
|
160 |
def is_valid_chinese(self, word):
|
@@ -407,27 +407,38 @@ class Model():
|
|
407 |
print(batch)
|
408 |
batch = filter_pipeline.batch_encoder(batch)
|
409 |
print(batch)
|
410 |
-
input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
|
411 |
temp = []
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
temp_outputs.append(temp)
|
432 |
processed_num += len(batch)
|
433 |
if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
|
|
|
148 |
return decoded_inputs
|
149 |
|
150 |
class ChineseFilter:
|
151 |
+
def __init__(self, pinyin_lib_file='pinyin.txt'):
|
152 |
self.name = 'chinese filter'
|
153 |
self.code = []
|
154 |
self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)
|
155 |
|
156 |
def load_pinyin_lib(self, file_path):
|
157 |
+
with open(os.path.join(script_dir,file_path), 'r', encoding='utf-8') as f:
|
158 |
return set(line.strip().lower() for line in f)
|
159 |
|
160 |
def is_valid_chinese(self, word):
|
|
|
407 |
print(batch)
|
408 |
batch = filter_pipeline.batch_encoder(batch)
|
409 |
print(batch)
|
|
|
410 |
temp = []
|
411 |
+
if len(batch) > 0:
|
412 |
+
input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
|
413 |
+
for target_language in target_languages:
|
414 |
+
target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
|
415 |
+
generated_tokens = self.model.generate(
|
416 |
+
**input_ids,
|
417 |
+
forced_bos_token_id=target_lang_code,
|
418 |
+
)
|
419 |
+
generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
420 |
+
|
421 |
+
print(generated_translation)
|
422 |
+
generated_translation = filter_pipeline.batch_decoder(generated_translation)
|
423 |
+
print(generated_translation)
|
424 |
+
print(len(generated_translation))
|
425 |
+
# Append result to output
|
426 |
+
temp.append({
|
427 |
+
"target_language": target_language,
|
428 |
+
"generated_translation": generated_translation,
|
429 |
+
})
|
430 |
+
input_ids.to('cpu')
|
431 |
+
del input_ids
|
432 |
+
else:
|
433 |
+
for target_language in target_languages:
|
434 |
+
generated_translation = filter_pipeline.batch_decoder(batch)
|
435 |
+
print(generated_translation)
|
436 |
+
print(len(generated_translation))
|
437 |
+
# Append result to output
|
438 |
+
temp.append({
|
439 |
+
"target_language": target_language,
|
440 |
+
"generated_translation": generated_translation,
|
441 |
+
})
|
442 |
temp_outputs.append(temp)
|
443 |
processed_num += len(batch)
|
444 |
if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
|