|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline |
|
import sentencepiece |
|
import torch |
|
import plotly.graph_objects as go |
|
import streamlit as st |
|
|
|
text_1 = """Bilim insanları Botsvana’da Covid-19’un şu ana kadar en çok mutasyona uğramış varyantını tespit etti. \ |
|
Resmi olarak B.1.1.529 koduyla bilinen bu varyantı ise “Nu varyantı” adı verildi. Uzmanlar bu varyant içerisinde \ |
|
tam 32 farklı mutasyon tespit edildiğini açıklarken, bu virüsün corona virüsü aşılarına karşı daha dirençli olabileceğini duyurdu.""" |
|
|
|
text_2 = """Şampiyonlar Ligi’nde 5. hafta oynanan karşılaşmaların ardından sona erdi. Real Madrid, Inter ve Sporting \ |
|
oynadıkları mücadeleler sonrasında Son 16 turuna yükselmeyi başardı. Gecenin dev mücadelesinde ise Manchester City, \ |
|
PSG’yi yenerek liderliği garantiledi.""" |
|
|
|
@st.cache(allow_output_mutation=True) |
|
def list2text(label_list): |
|
labels = "" |
|
for label in label_list: |
|
labels = labels + label + "," |
|
labels = labels[:-1] |
|
return labels |
|
|
|
label_list_1 = ["dünya", "ekonomi", "kültür", "sağlık", "siyaset", "spor", "teknoloji"] |
|
label_list_2 = ["positive", "negative", "neutral"] |
|
|
|
st.title("Turkish Zero-Shot Text Classification \ |
|
with Multilingual XLM-RoBERTa and mDeBERTa Models") |
|
|
|
model_list = ['vicgalle/xlm-roberta-large-xnli-anli', |
|
'joeddav/xlm-roberta-large-xnli', |
|
'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7'] |
|
|
|
st.sidebar.header("Select Model") |
|
model_checkpoint = st.sidebar.radio("", model_list) |
|
|
|
st.sidebar.write("For details of models:") |
|
st.sidebar.write("https://huggingface.co/vicgalle") |
|
st.sidebar.write("https://huggingface.co/joeddav") |
|
st.sidebar.write("https://huggingface.co/MoritzLaurer") |
|
|
|
st.sidebar.write("For XNLI Dataset:") |
|
st.sidebar.write("https://huggingface.co/datasets/xnli") |
|
|
|
st.subheader("Select Text and Label List") |
|
st.text_area("Text #1", text_1, height=128) |
|
st.text_area("Text #2", text_2, height=128) |
|
st.write(f"Label List #1: {list2text(label_list_1)}") |
|
st.write(f"Label List #2: {list2text(label_list_2)}") |
|
|
|
text = st.radio("Select Text", ("Text #1", "Text #2", "New Text")) |
|
labels = st.radio("Select Label List", ("Label List #1", "Label List #2", "New Label List")) |
|
|
|
if text == "Text #1": selected_text = text_1 |
|
elif text == "Text #2": selected_text = text_2 |
|
elif text == "New Text": |
|
selected_text = st.text_area("New Text", value="", height=128) |
|
|
|
if labels == "Label List #1": selected_labels = label_list_1 |
|
elif labels == "Label List #2": selected_labels = label_list_2 |
|
elif labels == "New Label List": |
|
selected_labels = st.text_area("New Label List (Pls Input as comma-separated)", value="", height=16).split(",") |
|
|
|
@st.cache(allow_output_mutation=True) |
|
def setModel(model_checkpoint): |
|
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) |
|
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) |
|
return pipeline("zero-shot-classification", model=model, tokenizer=tokenizer) |
|
|
|
Run_Button = st.button("Run", key=None) |
|
if Run_Button == True: |
|
|
|
with st.spinner('Model is running...'): |
|
zstc_pipeline = setModel(model_checkpoint) |
|
output = zstc_pipeline(sequences=selected_text, candidate_labels=selected_labels) |
|
output_labels = output["labels"] |
|
output_scores = output["scores"] |
|
|
|
st.header("Result") |
|
fig = go.Figure([go.Bar(x=output_labels, y=output_scores)]) |
|
st.plotly_chart(fig, use_container_width=False, sharing="streamlit") |
|
st.success('Done!') |
|
|