|
import streamlit as st |
|
from transformers import pipeline |
|
|
|
|
|
@st.cache_resource
def _load_pipeline():
    """Build the relation-extraction pipeline once and reuse it.

    ``st.cache_resource`` keeps the returned pipeline object alive across
    reruns, so the model weights are loaded a single time instead of on
    every new input text.
    """
    return pipeline(
        "text-classification",
        model="harshildarji/privacy-policy-relation-extraction",
        framework="pt",
    )


# Bound the per-text result cache: the input is free text typed by users,
# so an unbounded cache would grow for as long as the app runs.
@st.cache_data(max_entries=256)
def get_results(text):
    """Classify *text* and return the score of every relation label.

    Args:
        text: The passage to analyze.

    Returns:
        A one-element list wrapping the list of ``{"label": ..., "score": ...}``
        dicts for *text* (all labels), i.e. ``[[{...}, ...]]``. This is the same
        outer shape the legacy ``return_all_scores=True`` call produced, so
        callers keep indexing the result with ``[0]``.
    """
    pipe = _load_pipeline()
    # ``top_k=None`` is the documented replacement for the deprecated
    # ``return_all_scores=True``; passing the text as a one-item list makes
    # the pipeline return one list of all-label scores per input.
    return pipe([text], top_k=None)
|
|
|
|
|
st.title("Privacy Policy Relation Extraction")

example = st.text_area(
    "Enter text:",
    value="We store your basic account information, including your name, username, and email address until you ask us to delete them.",
    height=150,
)

if st.button("Analyze"):
    if not example.strip():
        # Skip the model call for blank input -- there is nothing to classify.
        st.warning("Please enter some text to analyze.")
    else:
        with st.spinner("Processing..."):
            results = get_results(example)
        st.session_state.results = results
        # Remember which text these results belong to, so later edits to the
        # text area can be detected on subsequent reruns.
        st.session_state.analyzed_text = example

if "results" in st.session_state:
    # Results persist in session state across reruns; warn when they no
    # longer match the text currently shown in the input box.
    if st.session_state.get("analyzed_text") != example:
        st.info(
            'The text has changed since the last analysis. Click "Analyze" to refresh the results.'
        )

    threshold = st.slider(
        "Confidence threshold:", min_value=0.0, max_value=1.0, value=0.5
    )

    # Keep only labels at or above the threshold, highest score first.
    # ``results[0]`` is the list of label/score dicts for the single input.
    sorted_results = sorted(
        (
            result
            for result in st.session_state.results[0]
            if result["score"] >= threshold
        ),
        key=lambda x: x["score"],
        reverse=True,
    )

    if sorted_results:
        for result in sorted_results:
            cols = st.columns([3, 5, 0.5])
            with cols[0]:
                st.write(f"**{result['label']}**")
            with cols[1]:
                # st.progress takes an int percentage in [0, 100].
                st.progress(int(result["score"] * 100))
            with cols[2]:
                st.write(f"**{result['score']:.2f}**")
    else:
        st.warning("No relations found with the specified threshold.")
|
|
|
# Inject custom CSS for the page container and the "Analyze" button.
# ``unsafe_allow_html=True`` is required so Streamlit renders the raw <style>
# tag instead of escaping it.
st.markdown(
    """
<style>
/* NOTE(review): legacy Streamlit class name; newer Streamlit releases may
   not render an element with this class, in which case this rule is inert. */
.reportview-container {
background: #f0f2f6;
padding: 20px;
}
.stButton button {
/* Previously "border: 0.5;", which is invalid CSS (a bare number is not a
   border value) and was silently ignored by the browser. */
border-width: 0.5px;
transition: background-color 0.3s, transform 0.2s;
border-radius: 10px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
}
.stButton button:hover {
transform: translateY(-1px);
}
</style>
""",
    unsafe_allow_html=True,
)
|
|