# Author: Malak-Omar
# Last commit: 6c286b4 "Update app.py" (verified)
import gradio as gr
import fitz # Import pymupdf
from PyPDF2 import PdfReader # Import PyPDF2
import requests # For handling URLs
from bs4 import BeautifulSoup # For extracting text from HTML
import torch
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification
# Fine-tuned weights for the multi-label classifier; the file's contents are
# passed to `load_state_dict` below, so it is expected to hold a state_dict.
model_save_path = "xlm_roberta_multilabel_model.pth"

# Run on CUDA when a GPU is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer loaded from "tokenizer" (presumably a directory shipped with the
# app -- confirm).
tokenizer = XLMRobertaTokenizer.from_pretrained("tokenizer")

# Build the base XLM-RoBERTa architecture with an 8-output classification
# head, overwrite its weights with the fine-tuned checkpoint, then move it to
# `device` and switch to inference mode (`.to()` and `.eval()` both return
# the module itself, so they can be chained).
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files
# (newer torch versions accept weights_only=True for plain state_dicts).
model = XLMRobertaForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=8)
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.to(device).eval()
# Timeout (seconds) for fetching a policy URL, so a slow or unresponsive
# server cannot hang the request handler indefinitely.
_URL_TIMEOUT_SECONDS = 15

# Compliance categories in the order of the model's 8 output columns:
# prediction column i corresponds to entry i here. (The original label
# numbering, 1-4 and 6-9, skipped label 5.) The second field is the weight
# the category contributes to the overall compliance percentage.
_LABEL_DETAILS = (
    ("User Consent", 11.11),
    ("Data Collection and Processing", 16.67),
    ("Data Retention", 11.11),
    ("Data Protection and Sharing", 16.67),
    ("User Rights", 16.67),
    ("Advertisements", 11.11),
    ("Breach Notification", 11.11),
    ("Responsibility", 5.56),
)


def _read_upload_bytes(input_file):
    """Return the raw bytes of an uploaded file.

    Gradio may pass an upload as a file path (str / PathLike), as a
    tempfile-like object exposing the path via ``.name``, or as raw bytes;
    all three are handled. Any other object must be readable via ``.read()``.
    """
    import os

    if isinstance(input_file, (bytes, bytearray)):
        return bytes(input_file)
    if isinstance(input_file, (str, os.PathLike)):
        path = input_file
    else:
        path = getattr(input_file, "name", None)
    if path is not None:
        with open(path, "rb") as fh:
            return fh.read()
    # No usable path: treat it as an open stream.
    return input_file.read()


def _extract_pdf_text(input_file):
    """Extract the text of an uploaded PDF (pymupdf first, PyPDF2 as fallback).

    Raises the PyPDF2 exception when both parsers fail.
    """
    import io

    data = _read_upload_bytes(input_file)
    try:
        with fitz.open(stream=data, filetype="pdf") as pdf:
            return "".join(page.get_text() for page in pdf)
    except Exception as e:
        print(f"Error occurred while processing PDF with pymupdf: {str(e)}")
    try:
        reader = PdfReader(io.BytesIO(data))
        # extract_text() may return None for pages without a text layer
        # (depending on the PyPDF2 version); treat those as empty.
        return "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        print(f"Error occurred while processing PDF with PyPDF2: {str(e)}")
        raise


def _fetch_url_text(url):
    """Download ``url`` and return the visible text of the HTML page."""
    response = requests.get(url, timeout=_URL_TIMEOUT_SECONDS)
    # Fail on HTTP errors (404, 500, ...) instead of analysing the error page.
    response.raise_for_status()
    soup = BeautifulSoup(response.content, "html.parser")
    # Script/style contents are not policy text; drop them so they do not
    # consume the model's limited token budget.
    for tag in soup(["script", "style", "noscript"]):
        tag.decompose()
    # The separator keeps words from adjacent elements from being glued together.
    return soup.get_text(separator=" ", strip=True)


def _predict_labels(document):
    """Run the classifier on ``document`` and return its 0/1 vector of 8 labels."""
    # Only the first 512 tokens are classified; longer text is truncated.
    inputs = tokenizer(document, return_tensors="pt", max_length=512, padding="max_length", truncation=True)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multi-label head: an independent sigmoid per label, thresholded at 0.5.
    probabilities = torch.sigmoid(logits)
    return (probabilities > 0.5).int().cpu().numpy()[0]


def _format_results(result):
    """Build the (report, percentage, rank) tuple displayed by the UI."""
    report_lines = []
    total_compliance_percentage = 0
    categories_adhered = 0
    for label_presence, (label_name, label_percentage) in zip(result, _LABEL_DETAILS):
        if label_presence == 1:
            mark = "\u2713"  # check mark
            categories_adhered += 1
            total_compliance_percentage += label_percentage
        else:
            mark = "\u2717"  # ballot X
        report_lines.append(f"\u2022 {label_name}: {mark}\n")  # bullet line

    # The rank depends on how many categories are present, not on their weights.
    if categories_adhered <= 3:
        compliance_rank = "Weakly Compliant"
    elif categories_adhered <= 6:
        compliance_rank = "Moderately Compliant"
    else:
        compliance_rank = "Mostly Compliant"

    percentage = str(round(total_compliance_percentage, 2)) + "%"
    return "".join(report_lines), percentage, compliance_rank


def analyze_document(input_type, input_text, input_file, input_url):
    """Assess a privacy policy against the app's compliance categories.

    Args:
        input_type: "Text", "URL", or any other value (treated as a PDF
            upload), as chosen with the input-type radio button.
        input_text: Policy text, used when ``input_type == "Text"``.
        input_file: Uploaded PDF (path, file-like object, or bytes), used for
            the PDF case.
        input_url: Address of the policy page, used when ``input_type == "URL"``.

    Returns:
        A 3-tuple for the output textboxes: a bulleted per-category
        check/cross report, the weighted compliance percentage as a string
        (e.g. "44.45%"), and the compliance level. On any failure the tuple
        is ``(error message, 0, "")``.
    """
    if input_type == "Text":
        document = input_text
    elif input_type == "URL":
        try:
            document = _fetch_url_text(input_url)
        except Exception as e:
            return f"Error fetching data from URL: {str(e)}", 0, ""
    else:
        if input_file is None:
            return "Error: no PDF file uploaded.", 0, ""
        try:
            document = _extract_pdf_text(input_file)
        except Exception as e:
            return f"Error: {str(e)}", 0, ""

    # Blank input (empty textbox, scanned PDF without a text layer, empty page)
    # would otherwise be scored by the model and yield a meaningless result.
    if not document or not document.strip():
        return "Error: no document text to analyze.", 0, ""

    try:
        result = _predict_labels(document)
    except Exception as e:
        return f"Error: {str(e)}", 0, ""

    return _format_results(result)
def clear_input():
    """Return the reset values for the form as a 6-tuple.

    The tuple is ``("Text", "", None, "", 0, "")``: the default input type
    followed by empty/zero placeholder values.
    NOTE(review): the return value is not wired to any component's outputs
    in the visible code; confirm which six components it is meant to reset.
    """
    default_input_type = "Text"
    empty_text, no_file, empty_url = "", None, ""
    return default_input_type, empty_text, no_file, empty_url, 0, ""
# Stylesheet for the app. The percentage box's right margin is an unquoted
# length, and the logo is limited by max-height only.
APP_CSS = """
@media screen and (max-width: 768px) {
    .input-column {
        order: 1;
    }
    .output-column {
        order: 2;
    }
    .gr-row {
        flex-direction: column-reverse;
    }
}
.gr-row {
    display: flex;
    justify-content: center;
}
.input-column, .output-column {
    width: 100%;
}
#input-textbox {
    direction: ltr;
    text-align: left !important;
    height: 360px;
    overflow-y: scroll !important;
    scrollbar-width: thin !important;
    scrollbar-color: #888 #ffffff !important;
}
#output-textbox {
    direction: ltr;
    text-align: left !important;
    height: 270px;
    overflow-y: scroll !important;
    scrollbar-width: thin !important;
    scrollbar-color: #888 #ffffff !important;
}
textarea::placeholder {
    text-align: left;
    direction: ltr;
}
input::placeholder {
    text-align: left;
    direction: ltr;
}
#logo img {
    max-height: 100px;
}
.no-border {
    border: none !important;
    box-shadow: none !important;
    background: none !important;
    padding: 0 !important;
}
#centered-markdown {
    text-align: center;
}
.percentage-output {
    direction: ltr;
    text-align: left;
    display: flex;
    justify-content: space-between;
    width: 100%;
    margin-right: 10px;
}
.compliance-rank-output {
    direction: ltr;
    text-align: left;
}
.PDF_file {
    text-align: left;
    direction: ltr;
    justify-content: space-between;
}
"""

with gr.Blocks(css=APP_CSS) as demo:
    gr.Markdown("### [Arabic](https://huggingface.co/spaces/Malak-Omar/Arabic_Mumtathil)")

    # Header row: an (empty) centered title area and the logo.
    with gr.Row():
        with gr.Column(scale=4):
            gr.Markdown("", elem_id="centered-markdown")
        with gr.Column(scale=0):
            logo = gr.Image(value="logo.png", elem_id="logo", elem_classes="no-border")

    with gr.Row():
        # Input section: only one of text / file / URL is visible at a time.
        with gr.Column(scale=4, elem_classes=['input-column']):
            input_type = gr.Radio(["Text", "PDF File", "URL"], label="Input Type", value="Text", elem_classes="PDF_file")
            input_text = gr.Textbox(
                lines=14, label="Privacy Policy",
                placeholder="Enter text here",
                elem_id="input-textbox",
                rtl=False,
                visible=True
            )
            input_file = gr.File(label="Upload PDF", file_types=[".pdf"], visible=False)
            # NOTE(review): reuses elem_id "input-textbox" from input_text, so the
            # page has duplicate DOM ids and this box also gets the 360px height rule.
            input_url = gr.Textbox(label="Privacy Policy URL", placeholder="Enter URL here", elem_id="input-textbox", visible=False)
            submit_button = gr.Button("Submit")

        # Output section: per-category report, compliance level, and percentage.
        with gr.Column(scale=4, elem_classes=['output-column']):
            output_component = gr.Textbox(lines=10, label="Privacy Policy Compliance", elem_id="output-textbox", rtl=False)
            compliance_rank_output = gr.Textbox(
                label="Compliance Level",
                elem_classes="compliance-rank-output",
                rtl=False,
                info="Compliance level calculated based on the number of categories aligned with the Saudi Personal Data Protection Law")
            percentage_output = gr.Textbox(
                label="Compliance Percentage",
                elem_classes="percentage-output",
                rtl=False,
                info="Compliance percentage calculated based on the importance of each category as defined by the Saudi Personal Data Protection Law")

    def update_input_visibility(input_type):
        """Show only the input widget for the chosen type and reset all fields.

        Returns updates for (input_text, input_file, input_url) followed by
        empty values for the three output textboxes, so results from a
        previous input do not stay on screen after switching input type.
        """
        show_text = input_type == "Text"
        show_url = input_type == "URL"
        show_file = not (show_text or show_url)
        return (
            gr.update(visible=show_text, value=""),
            gr.update(visible=show_file, value=None),
            gr.update(visible=show_url, value=""),
            "",
            "",
            "",
        )

    input_type.change(
        fn=update_input_visibility,
        inputs=input_type,
        outputs=[input_text, input_file, input_url, output_component, percentage_output, compliance_rank_output],
    )
    submit_button.click(fn=analyze_document, inputs=[input_type, input_text, input_file, input_url], outputs=[output_component, percentage_output, compliance_rank_output])

    with gr.Row():
        gr.Markdown("Compliance assessment for privacy policies with the Saudi Personal Data Protection Law", elem_id="centered-markdown")

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()