esun-choi's picture
Update ner.py
90881fe
raw
history blame contribute delete
No virus
3.43 kB
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from collections import defaultdict
import torch
# device = torch.device("cuda")
# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Token-classification (NER) model for Korean text; the checkpoint name
# suggests KoELECTRA-small fine-tuned on the MODU NER corpus — not verified here.
tokenizer = AutoTokenizer.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model = AutoModelForTokenClassification.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model.to(device)
# ์†ก์žฅ์ด๋ผ ์ถ”์ •๋˜๋Š”๋ถ€๋ถ„์„ craft์— ํ†ต๊ณผ์‹œํ‚ค๊ณ  text ๊ฐ€ ์žˆ๋Š”๋ถ€๋ถ„์„ ํฌ๋กญํ•ด์„œ trocr๋กœ text๋ฅผ ๊ทธ ์˜์—ญ์— ๋ฝ‘์•„๋‚ธ์ดํ›„ ํ”„๋กœ์„ธ์Šค์ž…๋‹ˆ๋‹ค.
# ๋ฝ‘ํžŒ text์— ๋Œ€ํ•œ class๋ฅผ ํŒ๋ณ„ํ•ฉ๋‹ˆ๋‹ค.
# text์— ๋Œ€ํ•œ class๊ฐ€ "์‚ฌ๋žŒ์ด๋ฆ„ PS", "๋„๋กœ/๊ฑด๋ฌผ ์ด๋ฆ„ AF", "์ฃผ์†Œ LC" ์— ์†ํ•˜๋ฉด 1์„ ๋ฐ˜ํ™˜ํ•˜์—ฌ ์ดํ›„ ๋ชจ์ž์ดํฌ ํ•˜๋„๋กํ•ฉ๋‹ˆ๋‹ค.
# ner ๋ชจ๋ธ์€ text๋ฅผ ์–ด์ ˆ ๋งˆ๋‹ค ์ชผ๊ฐœ์„œ ๊ฐ ๋‹จ์–ด์— ๋Œ€ํ•œ class๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
# ์ด ๋•Œ ๋ชจ๋“  ๋‹จ์–ด์— ๋Œ€ํ•œ class๋ฅผ ๊ณ ๋ คํ•˜๋‹ค๋ณด๋ฉด infer speed ๊ฐ€ ๋งค์šฐ๋Š๋ ค์„œ ์ตœ์†Œํ•œ ํ•˜๋‚˜๋ผ๋„ ps,af,lc ํด๋ž˜์Šค ํ•ด๋‹น ๋‹จ์–ด๊ฐ€ ์žˆ์œผ๋ฉด 1 ๋ฐ˜ํ™˜ํ•˜๋„๋กํ•ฉ๋‹ˆ๋‹ค.
def check_entity(entities):
    """Return 1 if any detected entity is privacy-sensitive, else 0.

    *entities* is the list of dicts emitted by the HF token-classification
    pipeline (each with an 'entity' label such as 'B-PS' or 'I-LC').
    An entity counts as sensitive when its label contains PS (person name),
    AF (road/building name), or LC (location/address).
    """
    sensitive_tags = ('LC', 'PS', 'AF')
    has_sensitive = any(
        tag in item.get('entity', '').upper()
        for item in entities
        for tag in sensitive_tags
    )
    return 1 if has_sensitive else 0
def ner(example):
    """Run Korean NER over *example* (a string) and return 1 if it contains
    a privacy-sensitive entity (PS / AF / LC), else 0.

    The original implementation rebuilt the HF pipeline on every call,
    which dominates inference latency; the pipeline is now constructed
    lazily once and cached on the function object. The local name is also
    no longer `ner`, which previously shadowed this function inside itself.
    """
    pipe = getattr(ner, "_pipeline", None)
    if pipe is None:
        # Build once; subsequent calls reuse the cached pipeline.
        pipe = pipeline("ner", model=model, tokenizer=tokenizer, device=device)
        ner._pipeline = pipe
    return check_entity(pipe(example))
# ํ•˜๋‚˜
# def find_longest_value_key(input_dict):
# max_length = 0
# max_length_keys = []
# for key, value in input_dict.items():
# current_length = len(value)
# if current_length > max_length:
# max_length = current_length
# max_length_keys = [key]
# elif current_length == max_length:
# max_length_keys.append(key)
# if len(max_length_keys) == 1:
# return 0
# else:
# return 1
# def find_longest_value_key2(input_dict):
# if not input_dict:
# return None
# max_key = max(input_dict, key=lambda k: len(input_dict[k]))
# return max_key
# def find_most_frequent_entity(entities):
# entity_counts = defaultdict(list)
# for item in entities:
# split_entity = item['entity'].split('-')
# entity_type = split_entity[1]
# entity_counts[entity_type].append(item['score'])
# number=find_longest_value_key(entity_counts)
# if number==1:
# max_entities = []
# max_score_average = -1
# for entity, scores in entity_counts.items():
# score_average = sum(scores) / len(scores)
# if score_average > max_score_average:
# max_entities = [entity]
# max_score_average = score_average
# elif score_average == max_score_average:
# max_entities.append(entity)
# if len(max_entities)>0:
# return max_entities if len(max_entities) > 1 else max_entities[0]
# else:
# return "Do not mosaik"
# else:
# A=find_longest_value_key2(entity_counts)
# return A
# ํ•˜๋‚˜๋ผ๋„ ps ๋‚˜ lc ๊ฐ€ ์žˆ์œผ๋ฉด ๋ฐ”๋กœ ps , lc ๊บผ๋‚ด๊ธฐ
# label=filtering(ner_results)
# if label.find("PS")>-1 or label.find("LC")>-1:
# return 1
# else:
# return 0
#print(ner("ํ™๊ธธ๋™"))
#label=check_label(example)