File size: 4,478 Bytes
15d2ecf
6614d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd4845
6614d86
 
 
 
 
 
 
 
 
 
 
 
 
 
15d2ecf
 
 
 
 
 
 
 
6614d86
8739835
6614d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15d2ecf
 
6614d86
 
 
 
 
 
 
 
 
 
 
 
1ae6640
6614d86
 
02a3276
6614d86
 
 
 
 
 
 
 
 
15d2ecf
 
 
6614d86
15d2ecf
6614d86
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import base64
import html
import json
from io import BytesIO

import streamlit as st
from graphviz import Digraph
from transformers import AutoModel, AutoTokenizer

def display_tree(output):
    """Render a dependency tree as an SVG image inside a scrollable Streamlit box.

    Args:
        output: sequence of dicts, one per word, each with keys
            ``'word'`` (node label), ``'dep_head_idx'`` (index into ``output``
            of the head word, or -1 for no incoming arc) and ``'dep_func'``
            (arc label).

    The graph is rendered in memory via ``Digraph.pipe``; nothing is written
    to disk.
    """
    # Width scales with the word count; max(..., 1) avoids a zero width for
    # an empty tree.
    size = f'{max(len(output), 1)},5'
    dpi = '300'
    image_format = 'svg'

    dot = Digraph(engine='dot', format=image_format)
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi=dpi)

    for i, word_info in enumerate(output):
        node_id = str(i)
        dot.node(node_id, word_info['word'])

        # Invisible edge from this word to the PREVIOUS one: with rankdir=LR
        # this lays earlier words out to the right, i.e. right-to-left order.
        if i > 0:
            dot.edge(node_id, str(i - 1), style='invis')

        head_idx = word_info['dep_head_idx']
        if head_idx != -1:
            # constraint=false: dependency arcs must not disturb the word order
            # fixed by the invisible edges above.
            dot.edge(str(head_idx), node_id, label=word_info['dep_func'], constraint='False')

    svg_b64 = base64.b64encode(dot.pipe(format=image_format)).decode()
    st.markdown(
        f"""
            <div style="height:250px; width:75vw; overflow:auto; border:1px solid #ccc; margin-left:-15vw">
                <img src="data:image/svg+xml;base64,{svg_b64}" 
                    style="display: block; margin: auto; max-height: 240px;">
            </div>
        """, unsafe_allow_html=True)

def display_download(disp_string):
    """Show a button that downloads ``disp_string`` as a plain-text file.

    The text is encoded with ``str.encode()``'s default (UTF-8) and served
    from an in-memory buffer named ``parsed_output.txt``.
    """
    payload = BytesIO(disp_string.encode())
    st.download_button(
        label="⬇️ Download text file",
        data=payload,
        file_name="parsed_output.txt",
        mime="text/plain",
    )

# Streamlit app title
st.title('DictaBERT-Joint Visualizer')


@st.cache_resource
def _load_model_and_tokenizer(token):
    """Load the DictaBERT-Joint tokenizer and model once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without this cache the tokenizer and model would be fetched and
    instantiated again on every rerun.

    Args:
        token: Hugging Face access token passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)`` with the model already in eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token)
    loaded_model = AutoModel.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token, trust_remote_code=True)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


# Load Hugging Face token
hf_token = st.secrets["HF_TOKEN"]  # Assuming you've set up the token in Streamlit secrets

# Authenticate and load model (cached across reruns)
tokenizer, model = _load_model_and_tokenizer(hf_token)

# Whether model.predict should decode a maximum spanning tree (on by default).
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)

# Output formats offered to the user; index=1 preselects 'UD'. The choice is
# lower-cased because that is the form passed on to model.predict.
_OUTPUT_STYLE_CHOICES = ('JSON', 'UD', 'IAHLT_UD')
_chosen_style = st.selectbox('Output Style: ', _OUTPUT_STYLE_CHOICES, index=1)
output_style = _chosen_style.lower()

# Sentence to analyze; an empty string until the user types something.
sentence = st.text_input('Enter a sentence to analyze:')

# Column headers of a UD (CoNLL-U style) row, in order.
_UD_COLUMNS = ["ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC"]


def _ud_to_tree(ud_lines):
    """Convert UD output lines to the ``[dict(word, dep_head_idx, dep_func)]`` tree format.

    The first two entries of ``ud_lines`` are header lines and are skipped.
    Every other non-blank entry is a tab-separated row. Multiword-token
    ranges (ID containing '-') and empty nodes (ID containing '.') are not
    part of the basic dependency tree and are skipped. List position ``i``
    therefore corresponds to word ID ``i + 1``, so HEAD ``0`` (the root)
    becomes ``-1``.
    """
    tree = []
    for row in ud_lines[2:]:
        if not row.strip():
            continue
        cols = row.split('\t')
        if '-' in cols[0] or '.' in cols[0]:
            continue
        tree.append(dict(word=cols[1], dep_head_idx=int(cols[6]) - 1, dep_func=cols[7]))
    return tree


def _ud_cell_text(text):
    """Escape a UD row for a Markdown table rendered with unsafe_allow_html=True."""
    # HTML-escape first so '<', '>' and '&' from user input cannot open tags,
    # then apply the Markdown escapes, whose '&...;' entities must not be
    # escaped a second time.
    text = html.escape(text, quote=False)
    return text.replace('_', '\\_').replace('|', '&#124;').replace(':', '&colon;')


def _ud_to_markdown_table(ud_lines):
    """Build a right-to-left Markdown table, wrapped in a ``<div>``, from UD output lines."""
    parts = ["<div dir='rtl' style='text-align: right;'>\n\n"]
    # The two header lines become Markdown headings ("##" + the line's own "#"s).
    parts.extend("##" + html.escape(header, quote=False) + "\n" for header in ud_lines[:2])
    parts.append("| " + " | ".join(_UD_COLUMNS) + " |\n")
    parts.append("| " + " | ".join(["---"] * len(_UD_COLUMNS)) + " |\n")
    for row in ud_lines[2:]:
        if not row.strip():
            continue
        parts.append("| " + " | ".join(_ud_cell_text(row).split('\t')) + " |\n")
    parts.append("</div>")
    return "".join(parts)


if sentence:
    # Echo the input sentence
    st.text(sentence)

    # Model prediction for a single-sentence batch
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst, output_style=output_style)[0]

    if output_style in ('ud', 'iahlt_ud'):
        # UD styles: output is a list of text lines (two headers, then rows)
        display_tree(_ud_to_tree(output))
        display_download('\n'.join(output))
        st.markdown(_ud_to_markdown_table(output), unsafe_allow_html=True)
    else:
        # JSON style: each token carries its parse under 'syntax'
        display_tree([w['syntax'] for w in output['tokens']])

        json_output = json.dumps(output, ensure_ascii=False, indent=2)
        display_download(json_output)

        # and the full json
        st.markdown("```json\n" + json_output + "\n```")