# NOTE(review): removed non-Python residue ("Spaces: / Sleeping") left over from
# the HuggingFace Spaces page this file was scraped from.
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from collections import defaultdict
import torch

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Korean NER model: KoELECTRA-small v3 fine-tuned on the MODU NER corpus.
tokenizer = AutoTokenizer.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model = AutoModelForTokenClassification.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model.to(device)
# Pipeline (reconstructed from garbled Korean comments — confirm with the author):
# regions suspected to contain a shipping label are passed through CRAFT, text
# regions are cropped, and TrOCR extracts the text from each region.
# Each extracted text is then classified with NER.
# If the text belongs to "person name (PS)", "road/building name (AF)", or
# "address (LC)", return 1 so the region is mosaicked afterwards.
# The NER model splits the text token-by-token and returns a class per token.
# Checking every token's class is very slow at inference time, so we return 1 as
# soon as at least one token falls in the PS/AF/LC classes.
def check_entity(entities):
    """Return 1 if any entity is a person (PS), artifact/building (AF), or
    location (LC); otherwise 0.

    entities: list of dicts as produced by a HF token-classification pipeline;
        each dict's 'entity' value is a BIO tag such as 'B-PS' or 'I-LC'.
    """
    for entity_info in entities:
        tag = entity_info.get('entity', '').upper()
        # One matching token is enough to trigger mosaicking — stop early.
        if any(label in tag for label in ('LC', 'PS', 'AF')):
            return 1
    return 0
def ner(example):
    """Run Korean NER on *example* (a string) and return 1 if it contains a
    PS/AF/LC entity, else 0.

    The HF pipeline is constructed once and cached on the function object;
    the original rebuilt it on every call, which dominated inference time.
    """
    if not hasattr(ner, "_pipeline"):
        ner._pipeline = pipeline("ner", model=model, tokenizer=tokenizer, device=device)
    results = ner._pipeline(example)
    return check_entity(results)
# Dead code kept for reference (alternative entity-selection strategies):
# def find_longest_value_key(input_dict):
#     max_length = 0
#     max_length_keys = []
#     for key, value in input_dict.items():
#         current_length = len(value)
#         if current_length > max_length:
#             max_length = current_length
#             max_length_keys = [key]
#         elif current_length == max_length:
#             max_length_keys.append(key)
#     if len(max_length_keys) == 1:
#         return 0
#     else:
#         return 1
# def find_longest_value_key2(input_dict):
#     if not input_dict:
#         return None
#     max_key = max(input_dict, key=lambda k: len(input_dict[k]))
#     return max_key
# def find_most_frequent_entity(entities):
#     entity_counts = defaultdict(list)
#     for item in entities:
#         split_entity = item['entity'].split('-')
#         entity_type = split_entity[1]
#         entity_counts[entity_type].append(item['score'])
#     number=find_longest_value_key(entity_counts)
#     if number==1:
#         max_entities = []
#         max_score_average = -1
#         for entity, scores in entity_counts.items():
#             score_average = sum(scores) / len(scores)
#             if score_average > max_score_average:
#                 max_entities = [entity]
#                 max_score_average = score_average
#             elif score_average == max_score_average:
#                 max_entities.append(entity)
#         if len(max_entities)>0:
#             return max_entities if len(max_entities) > 1 else max_entities[0]
#         else:
#             return "Do not mosaik"
#     else:
#         A=find_longest_value_key2(entity_counts)
#         return A
# If even one PS or LC is present, extract PS/LC directly:
# label=filtering(ner_results)
# if label.find("PS")>-1 or label.find("LC")>-1:
#     return 1
# else:
#     return 0
#print(ner("ํ๊ธธ๋"))
#label=check_label(example)