joint-demo / app.py
Shaltiel's picture
Reversed arrow direction
2cd4845
import base64
import html
import json
from io import BytesIO

import streamlit as st
from graphviz import Digraph
from transformers import AutoModel, AutoTokenizer
def display_tree(output):
    """Render a dependency tree with graphviz and show it as an inline SVG.

    Args:
        output: words in sentence order; each item is a dict with keys
            ``'word'`` (node label), ``'dep_head_idx'`` (index into ``output``
            of the word's head, or -1 for the root) and ``'dep_func'``
            (label of the head -> word edge).

    The SVG is produced in memory with ``Digraph.pipe`` and embedded as a
    base64 data URI, so nothing is written to disk. Concurrent sessions
    therefore do not share or clobber a file.
    """
    # Canvas width (graphviz "size", in inches) grows with the word count;
    # clamp to 1 so an empty tree does not request a zero-width drawing.
    size = f"{max(len(output), 1)},5"
    dot = Digraph(engine='dot', format='svg')
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi='300')

    for i, word_info in enumerate(output):
        dot.node(str(i), word_info['word'])
        # Invisible edge i -> i-1. With rankdir=LR an edge's tail is placed
        # left of its head, so later words land further left. The sentence
        # therefore reads right-to-left (Hebrew order).
        if i > 0:
            dot.edge(str(i), str(i - 1), style='invis')
        head_idx = word_info['dep_head_idx']
        if head_idx != -1:
            # constraint=False keeps dependency arcs from disturbing the word
            # ordering fixed by the invisible edges above.
            dot.edge(str(head_idx), str(i), label=word_info['dep_func'], constraint='False')

    # Render once, in memory (no dot.render(): its on-disk output was unused).
    svg_b64 = base64.b64encode(dot.pipe(format='svg')).decode()
    # Display the image in a scrollable container
    st.markdown(
        f"""
<div style="height:250px; width:75vw; overflow:auto; border:1px solid #ccc; margin-left:-15vw">
<img src="data:image/svg+xml;base64,{svg_b64}"
style="display: block; margin: auto; max-height: 240px;">
</div>
""", unsafe_allow_html=True)
def display_download(disp_string):
    """Offer ``disp_string`` to the user as a downloadable plain-text file.

    The text is encoded with ``str.encode()``'s default codec (UTF-8) into an
    in-memory buffer, which Streamlit serves as ``parsed_output.txt``.
    """
    payload = disp_string.encode()
    st.download_button(
        label="⬇️ Download text file",
        data=BytesIO(payload),
        file_name="parsed_output.txt",
        mime="text/plain",
    )
# Streamlit app title
st.title('DictaBERT-Joint Visualizer')

# The ten CoNLL-U (Universal Dependencies) columns of a token row.
_UD_COLUMNS = ("ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC")


@st.cache_resource  # requires Streamlit >= 1.18
def _load_model(token):
    """Load the DictaBERT-Joint tokenizer and model once and reuse them.

    Streamlit re-executes this script on every widget interaction, so without
    caching the model would be re-instantiated on every rerun.

    Returns:
        (tokenizer, model) with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token)
    mdl = AutoModel.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token, trust_remote_code=True)
    mdl.eval()
    return tok, mdl


def _ud_to_tree(ud_lines):
    """Convert UD output lines into the list format ``display_tree`` expects.

    ``ud_lines`` holds two header lines followed by tab-separated token rows.
    Multi-word token range rows (ID such as ``"3-4"``) are skipped, so list
    index i corresponds to UD word ID i + 1. A HEAD of 0 (root) becomes -1.
    """
    tree = []
    for line in ud_lines[2:]:
        parts = line.split('\t')
        if '-' in parts[0]:
            continue
        tree.append(dict(word=parts[1], dep_head_idx=int(parts[6]) - 1, dep_func=parts[7]))
    return tree


def _escape_ud_row(line):
    """Escape one UD row so it renders literally inside a Markdown table."""
    # HTML-escape first: the row derives from user input and is rendered with
    # unsafe_allow_html=True, so raw '<' / '&' must not reach the page.
    line = html.escape(line, quote=False)
    # Then neutralise Markdown/table syntax: '_' (emphasis), '|' (cell
    # separator, common in FEATS). ':' is replaced too, presumably so
    # Streamlit does not interpret ':…:' directives.
    return line.replace('_', '\\_').replace('|', '&#124;').replace(':', '&colon;')


def _ud_to_markdown_table(ud_lines):
    """Build an RTL Markdown table (wrapped in an HTML div) from UD output lines.

    The first two lines are emitted as '##'-prefixed header text. Each later
    line becomes one table row. Render the result with
    ``unsafe_allow_html=True``.
    """
    chunks = ["<div dir='rtl' style='text-align: right;'>\n\n"]  # Start with RTL div
    # Add the UD header lines
    chunks.extend("##" + html.escape(line, quote=False) + "\n" for line in ud_lines[:2])
    # Table header and alignment row
    chunks.append("| " + " | ".join(_UD_COLUMNS) + " |\n")
    chunks.append("| " + " | ".join(["---"] * len(_UD_COLUMNS)) + " |\n")
    for line in ud_lines[2:]:
        # Each UD line as a table row
        chunks.append("| " + " | ".join(_escape_ud_row(line).split('\t')) + " |\n")
    chunks.append("</div>")  # Close the RTL div
    return "".join(chunks)


# Load Hugging Face token
hf_token = st.secrets["HF_TOKEN"]  # Assuming you've set up the token in Streamlit secrets
# Authenticate and load model (cached across reruns)
tokenizer, model = _load_model(hf_token)

# Checkbox for the compute_mst parameter
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)
output_style = st.selectbox(
    'Output Style: ',
    ('JSON', 'UD', 'IAHLT_UD'), index=1).lower()

# User input
sentence = st.text_input('Enter a sentence to analyze:')

if sentence:
    # Display the input sentence
    st.text(sentence)
    # Model prediction
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst, output_style=output_style)[0]
    if output_style in ('ud', 'iahlt_ud'):
        # In the UD styles the prediction is a list of text lines.
        ud_output = output
        display_tree(_ud_to_tree(ud_output))
        display_download('\n'.join(ud_output))
        st.markdown(_ud_to_markdown_table(ud_output), unsafe_allow_html=True)
    else:
        # display the tree
        tree = [w['syntax'] for w in output['tokens']]
        display_tree(tree)
        json_output = json.dumps(output, ensure_ascii=False, indent=2)
        display_download(json_output)
        # and the full json
        st.markdown("```json\n" + json_output + "\n```")