File size: 2,045 Bytes
e5ffa90 98b648b e5ffa90 e438cdf e5ffa90 5f726f0 98b648b aff6f4b 8cfc5e5 e5ffa90 98b648b e5ffa90 aff6f4b 98b648b aff6f4b e5ffa90 8cfc5e5 e5ffa90 adc3723 e5ffa90 98b648b 8cfc5e5 98b648b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import functools
import sys

import gradio as gr
import torch

import config
import dataset
import engine
from model import BERTBaseUncased
from tokenizer import tokenizer
from utils import label_full_decoder
# Device used for inference, read from the project configuration.
DEVICE = config.device

# Options for the tweet tokenizer used by ``preprocess``. Going by their names
# (confirm against the ``tokenizer`` package): keep @handles and #hashtags,
# do not preserve case, and do not preserve URLs.
_TOKENIZER_OPTIONS = {
    "preserve_handles": True,
    "preserve_hashes": True,
    "preserve_case": False,
    "preserve_url": False,
}
T = tokenizer.TweetTokenizer(**_TOKENIZER_OPTIONS)
def preprocess(text):
    """Tokenize *text* and collapse each run of mention tokens.

    Any token that contains "@" counts as a mention. Each maximal run of
    adjacent mention tokens is replaced by a single ``"mention_0"``
    placeholder. All other tokens are kept as-is. The raw token list and
    the processed token list are both printed to stderr for debugging.

    Args:
        text: The raw input string.

    Returns:
        The processed tokens joined by single spaces.
    """
    raw_tokens = T.tokenize(text)
    print(raw_tokens, file=sys.stderr)

    processed = []
    previous_was_mention = False
    for tok in raw_tokens:
        is_mention = "@" in tok
        if not is_mention:
            processed.append(tok)
        elif not previous_was_mention:
            # Only the first mention of a consecutive run emits a placeholder.
            processed.append("mention_0")
        previous_was_mention = is_mention

    print(processed, file=sys.stderr)
    return " ".join(processed)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier, load its trained weights, and cache the result.

    The model is built and its weights read from ``config.MODEL_PATH`` only
    once per process, not on every request. If loading raises, nothing is
    cached and the next call retries.

    Returns:
        A ``BERTBaseUncased`` instance on ``DEVICE``, switched to eval mode.
    """
    model = BERTBaseUncased()
    state_dict = torch.load(config.MODEL_PATH, map_location=torch.device(DEVICE))
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    # Inference mode (e.g. disables dropout) so predictions are deterministic.
    model.eval()
    return model


def sentence_prediction(sentence):
    """Run the fine-tuned BERT classifier on a single sentence.

    The sentence is normalised with ``preprocess``, wrapped in a one-item
    ``BERTDataset``, and scored with ``engine.predict_fn``.

    Args:
        sentence: Raw text entered in the Gradio textbox.

    Returns:
        ``{"label": outputs[0]}``, where ``outputs`` is the first value
        returned by ``engine.predict_fn``. NOTE(review): presumably a score
        for the single input; ``label_full_decoder`` is imported but unused,
        so confirm whether it was meant to decode this value.
    """
    sentence = preprocess(sentence)
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        # Placeholder target: the dataset API takes one target per review.
        target=[0],
    )
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        # A single example does not justify spawning worker processes
        # on every request.
        num_workers=0,
    )
    model = _load_model()
    # The original unpacked into `[]`, which only succeeds when the second
    # returned value is empty. `_` accepts whatever predict_fn returns there.
    outputs, _ = engine.predict_fn(test_data_loader, model, DEVICE)
    print(outputs, file=sys.stderr)
    return {"label": outputs[0]}
if __name__ == "__main__":
    demo = gr.Interface(
        fn=sentence_prediction,
        inputs=gr.Textbox(placeholder="Enter a sentence here..."),
        outputs="label",
        examples=[["!"]],
    )
    # Enable request queueing with Interface.queue(). The `enable_queue`
    # argument to launch() is deprecated in Gradio 3 and removed in Gradio 4,
    # where passing it raises TypeError.
    demo.queue()
    demo.launch(debug=True, show_error=True)
|