from metrics import calc_metrics
import gradio as gr
from openai import OpenAI
import os

from transformers import pipeline
# from dotenv import load_dotenv, find_dotenv
import huggingface_hub
import json
# from simcse import SimCSE # use for gpt
from evaluate_data import store_sample_data, get_metrics_trf

# Refresh the sample data file that is read back in below.
store_sample_data()

with open('./data/sample_data.json', 'r') as f:
    # sample_data = [
    #     {'id': "", 'text': "", 'orgs': ["", ""]}
    # ]
    sample_data = json.load(f)
    
# _ = load_dotenv(find_dotenv()) # read local .env file
hf_token = os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)

# Token-classification pipeline for the fine-tuned FiNER-ORD model.
pipe = pipeline("token-classification", model="elshehawy/finer-ord-transformers", aggregation_strategy="first")


llm_model = 'gpt-3.5-turbo-0125'
# openai.api_key = os.environ['OPENAI_API_KEY']

client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)


def get_completion(prompt, model=llm_model):
    """Send a single-turn prompt to the chat model and return the reply text."""
    messages = [{"role": "user", "content": prompt}]
    response = client.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,  # deterministic decoding for evaluation
    )
    return response.choices[0].message.content



def find_orgs_gpt(sentence):
    """Ask the LLM to extract organization names from a sentence and return them as a list."""
    prompt = f"""
    In the context of named entity recognition (NER), find all organizations in the text delimited by triple backticks.

    text:
    ```
    {sentence}
    ```
    You should output only a list of organizations and follow this output format exactly: ["org_1", "org_2", "org_3"]
    """

    sent_orgs_str = get_completion(prompt)
    # Assumes the model follows the format instruction and returns a valid JSON list.
    sent_orgs = json.loads(sent_orgs_str)

    return sent_orgs


    
# def find_orgs_trf(sentence):
#     org_list = []
#     for ent in pipe(sentence):
#         if ent['entity_group'] == 'ORG':
#             # message += f'\n- {ent["word"]} \t- score: {ent["score"]}'
#             # message += f'\n- {ent["word"]}'# \t- score: {ent["score"]}'
#             org_list.append(ent['word'])
#     return list(set(org_list))


# Gold labels and GPT predictions per sample sentence (used by the commented-out GPT metrics below).
true_orgs = [sent['orgs'] for sent in sample_data]

predicted_orgs_gpt = [find_orgs_gpt(sent['text']) for sent in sample_data]
# predicted_orgs_trf = [find_orgs_trf(sent['text']) for sent in sample_data]

all_metrics = {}

# sim_model = SimCSE('sentence-transformers/all-MiniLM-L6-v2')
# all_metrics['gpt'] = calc_metrics(true_orgs, predicted_orgs_gpt, sim_model)
print('Finding all trf metrics')
all_metrics['trf'] = get_metrics_trf()



# Example sentence shown in the demo UI.
example = """
My latest exclusive for The Hill : Conservative frustration over Republican efforts to force a House vote on reauthorizing the Export - Import Bank boiled over Wednesday during a contentious GOP meeting.
"""


def find_orgs(sentence, choice):
    # The demo currently ignores its inputs and just displays the precomputed metrics.
    return all_metrics


radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
textbox = gr.Textbox(label="Enter your text", placeholder=str(all_metrics), lines=8)

iface = gr.Interface(fn=find_orgs, inputs=[textbox, radio_btn], outputs="text", examples=[[example, 'iSemantics']])
iface.launch(share=True)