File size: 2,045 Bytes
e5ffa90
98b648b
e5ffa90
 
 
 
 
e438cdf
e5ffa90
5f726f0
 
98b648b
 
aff6f4b
 
 
8cfc5e5
 
e5ffa90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98b648b
e5ffa90
aff6f4b
98b648b
 
aff6f4b
e5ffa90
8cfc5e5
e5ffa90
adc3723
e5ffa90
98b648b
8cfc5e5
98b648b
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from utils import label_full_decoder
import sys
import dataset
import engine
from model import BERTBaseUncased
from tokenizer import tokenizer
import config

import gradio as gr

# Target device ("cpu"/"cuda") taken from the project-wide config module.
DEVICE = config.device

# NOTE(review): the global MODEL is commented out, yet sentence_prediction()
# below passes MODEL to engine.predict_fn — that call will raise NameError
# at inference time. Either restore this block or use the local model there.
# MODEL = BERTBaseUncased()
# MODEL.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(DEVICE)))
# MODEL.eval()



# Tweet-aware tokenizer: keeps @handles and #hashtags as single tokens,
# lowercases everything (preserve_case=False), and drops URLs.
T = tokenizer.TweetTokenizer(
    preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)

def preprocess(text):
    """Tokenize *text* and anonymize @-mentions.

    Each run of consecutive mention tokens (anything containing "@") is
    collapsed into a single "mention_0" placeholder; all other tokens pass
    through unchanged. Returns the tokens re-joined with single spaces.
    """
    raw_tokens = T.tokenize(text)
    print(raw_tokens, file=sys.stderr)

    cleaned = []
    for idx, tok in enumerate(raw_tokens):
        if "@" not in tok:
            cleaned.append(tok)
            continue
        # Only emit the placeholder when the previous token was not itself
        # a mention — consecutive mentions collapse into one placeholder.
        if idx == 0 or "@" not in raw_tokens[idx - 1]:
            cleaned.append("mention_0")

    print(cleaned, file=sys.stderr)
    return " ".join(cleaned)


def sentence_prediction(sentence):
    """Predict the sentiment label for a single raw sentence.

    The sentence is preprocessed (mention anonymization), wrapped in a
    one-element dataset/dataloader so the shared engine.predict_fn path can
    be reused unchanged, and run through the BERT model.

    Returns a dict {"label": <predicted label>} matching the Gradio
    "label" output component.
    """
    sentence = preprocess(sentence)
    model_path = config.MODEL_PATH

    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0]
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3
    )

    # BUG FIX: this assignment was commented out, leaving `device` undefined
    # (NameError on the model.to(device) and predict_fn calls below).
    device = config.device

    model = BERTBaseUncased()
    # NOTE(review): the state_dict load is commented out, so predictions run
    # on randomly initialized weights — confirm whether loading is intended.
    # model.load_state_dict(torch.load(
    #     model_path, map_location=torch.device(device)))
    model.to(device)

    # BUG FIX: pass the local `model` (the global MODEL is commented out at
    # module level and would raise NameError), and use `_` as the throwaway
    # target instead of the obscure `outputs, [] = ...` form.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs)
    return {"label": outputs[0]}

if __name__ == "__main__":
    # Wire the prediction function into a simple Gradio UI and serve it.
    interface = gr.Interface(
        fn=sentence_prediction,
        inputs=gr.Textbox(placeholder="Enter a  sentence here..."),
        outputs="label",
        # interpretation="default",
        examples=[["!"]],
    )

    interface.launch(
        debug=True,
        enable_queue=True,
        show_error=True,
    )