Spaces:
Sleeping
Sleeping
import gradio as gr | |
import fitz # Import pymupdf | |
from PyPDF2 import PdfReader # Import PyPDF2 | |
import requests # For handling URLs | |
from bs4 import BeautifulSoup # For extracting text from HTML | |
import torch | |
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification | |
# File holding the fine-tuned weights that are loaded into the model below.
model_save_path = "xlm_roberta_multilabel_model.pth"

# Prefer CUDA when torch reports it as available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer is resolved from the "tokenizer" name/path.
tokenizer = XLMRobertaTokenizer.from_pretrained("tokenizer")

# Start from the base XLM-RoBERTa checkpoint with an 8-label classification
# head, then replace its parameters with the saved state dict.
model = XLMRobertaForSequenceClassification.from_pretrained(
    "xlm-roberta-base",
    num_labels=8,
)
model.load_state_dict(
    torch.load(model_save_path, map_location=device)
)

# Move the model to the selected device and switch it to evaluation mode.
model.to(device)
model.eval()
# Categories scored by the classifier, listed in the same order as the model's
# outputs (prediction i belongs to entry i).  The original label ids were
# 1, 2, 3, 4, 6, 7, 8, 9 -- id 5 is missing -- so only the order matters here.
# The second item is the share each category adds to the compliance percentage.
_CATEGORIES = (
    ("User Consent", 11.11),
    ("Data Collection and Processing", 16.67),
    ("Data Retention", 11.11),
    ("Data Protection and Sharing", 16.67),
    ("User Rights", 16.67),
    ("Advertisements", 11.11),
    ("Breach Notification", 11.11),
    ("Responsibility", 5.56),
)

# Upper bound on how long a URL fetch may block the UI handler.
_REQUEST_TIMEOUT_SECONDS = 30


def _read_upload_bytes(upload):
    """Return the raw bytes of an upload given as a path or a file-like object."""
    path = upload if isinstance(upload, str) else getattr(upload, "name", None)
    if isinstance(path, str):
        with open(path, "rb") as handle:
            return handle.read()
    # No usable path: fall back to reading the object itself from the start.
    upload.seek(0)
    return upload.read()


def _extract_pdf_text(input_file):
    """Extract the text of an uploaded PDF, trying PyMuPDF first, then PyPDF2.

    Raises whatever the PyPDF2 fallback raises when both libraries fail.
    """
    import io  # local import so this block does not depend on the file header

    # Read the bytes once so both parsers see the complete document; a shared
    # stream would already be consumed by the time the fallback runs.
    pdf_bytes = _read_upload_bytes(input_file)

    try:
        with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf:
            return "".join(page.get_text() for page in pdf)
    except Exception as e:
        print(f"Error occurred while processing PDF with pymupdf: {str(e)}")

    try:
        reader = PdfReader(io.BytesIO(pdf_bytes))
        # extract_text() may yield None for a page without a text layer.
        return "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        print(f"Error occurred while processing PDF with PyPDF2: {str(e)}")
        raise


def _summarize_predictions(predictions):
    """Turn per-category 0/1 predictions into (report, percentage, rank)."""
    report_lines = []
    total_compliance_percentage = 0.0
    categories_adhered = 0

    for (label_name, label_percentage), label_presence in zip(_CATEGORIES, predictions):
        if label_presence == 1:
            report_lines.append(f"• {label_name}: ✓\n")
            categories_adhered += 1
            total_compliance_percentage += label_percentage
        else:
            # The mark must differ from the one above, otherwise present and
            # absent categories look the same in the report.
            report_lines.append(f"• {label_name}: ✗\n")

    # The rank depends on how many categories are met, not on their weights.
    if categories_adhered <= 3:
        compliance_rank = "Weakly Compliant"
    elif categories_adhered <= 6:
        compliance_rank = "Moderately Compliant"
    else:
        compliance_rank = "Mostly Compliant"

    # The rounded weights add up to 100.01, so cap the total at 100.
    percentage = round(min(total_compliance_percentage, 100.0), 2)
    return "".join(report_lines), f"{percentage}%", compliance_rank


def analyze_document(input_type, input_text, input_file, input_url):
    """Classify a privacy policy and report which categories it covers.

    Args:
        input_type: "Text" or "URL"; any other value is treated as a PDF upload.
        input_text: Policy text, used when input_type is "Text".
        input_file: Uploaded PDF (a path or a file-like object), used for PDFs.
        input_url: Page to download and strip to text when input_type is "URL".

    Returns:
        A 3-tuple ``(report, percentage, rank)``.  On success ``report`` has one
        line per category, ``percentage`` is a string such as ``"44.45%"`` and
        ``rank`` is one of "Weakly Compliant", "Moderately Compliant" or
        "Mostly Compliant".  On any failure the tuple is
        ``(error message, 0, "")``; the function does not raise.
    """
    if input_type == "Text":
        document = input_text
    elif input_type == "URL":
        if not input_url or not input_url.strip():
            return "Error fetching data from URL: no URL was provided.", 0, ""
        try:
            # NOTE(review): the URL comes straight from the user; consider
            # restricting which hosts may be fetched if this runs on a
            # network with internal services.
            response = requests.get(input_url.strip(), timeout=_REQUEST_TIMEOUT_SECONDS)
            # Treat 4xx/5xx responses as errors instead of scoring an error page.
            response.raise_for_status()
            soup = BeautifulSoup(response.content, "html.parser")
            document = soup.get_text()  # all text nodes of the page
        except Exception as e:
            return f"Error fetching data from URL: {str(e)}", 0, ""
    else:
        if input_file is None:
            return "Error: no PDF file was uploaded.", 0, ""
        try:
            document = _extract_pdf_text(input_file)
        except Exception as e:
            return f"Error: {str(e)}", 0, ""

    # Nothing to classify: avoid scoring an empty string.
    if not document or not document.strip():
        return "Error: no text was found to analyze.", 0, ""

    try:
        # Only the first 512 tokens are scored; longer input is truncated.
        inputs = tokenizer(
            document,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )
        inputs = {key: val.to(device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.sigmoid(outputs.logits)
        # Independent 0.5 threshold per label; keep the single batch row.
        predictions = (probabilities > 0.5).int()[0].tolist()
    except Exception as e:
        return f"Error: {str(e)}", 0, ""

    return _summarize_predictions(predictions)
def clear_input():
    """Return the six reset values for the form as a tuple.

    The values are, in order: "Text", "", None, "", 0 and "".  They presumably
    line up with the input type, text box, file upload, URL box and the
    outputs -- confirm against whichever event handler consumes them.
    """
    reset_values = ("Text", "", None, "", 0, "")
    return reset_values
# Custom stylesheet for the page.
# NOTE(review): two declarations below are not valid CSS and are likely being
# ignored by browsers: `width: 500;` (no unit) under `#logo img` and
# `margin-right: "10px"` (quoted value) under `.percentage-output`.  They are
# left as written because correcting them would change the rendered layout;
# confirm the intended values before fixing.
_CUSTOM_CSS = """
    @media screen and (max-width: 768px) {
        .input-column {
            order: 1;
        }
        .output-column {
            order: 2;
        }
        .gr-row {
            flex-direction: column-reverse;
        }
    }
    .gr-row {
        display: flex;
        justify-content: center;
    }
    .input-column, .output-column {
        width: 100%;
    }
    #input-textbox {
        direction: ltr;
        text-align: left !important;
        height: 360px;
        overflow-y: scroll !important;
        scrollbar-width: thin !important;
        scrollbar-color: #888 #ffffff !important; }
    #output-textbox {
        direction: ltr;
        text-align: left !important;
        height: 270px;
        overflow-y: scroll !important;
        scrollbar-width: thin !important;
        scrollbar-color: #888 #ffffff !important; }
    textarea::placeholder {
        text-align: left;
        direction: ltr;
    }
    input::placeholder {
        text-align: left;
        direction: ltr;
    }
    #logo img {
        max-height: 100px;
        width: 500;
    }
    .no-border {
        border: none !important;
        box-shadow: none !important;
        background: none !important;
        padding: 0 !important;
    }
    #centered-markdown {
        text-align: center;
    }
    .percentage-output {
        direction: ltr;
        text-align: left;
        display: flex;
        justify-content: space-between;
        width: 100%;
        margin-right: "10px"
    }
    .compliance-rank-output {
        direction: ltr;
        text-align: left;
    }
    .PDF_file {
        text-align: left;
        direction: ltr;
        justify-content: space-between;
    }
"""

with gr.Blocks(css=_CUSTOM_CSS) as demo:
    # Link to the Arabic version of the app.
    gr.Markdown("### [Arabic](https://huggingface.co/spaces/Malak-Omar/Arabic_Mumtathil)")

    # Header row: an empty spacer column next to the logo.
    with gr.Row():
        with gr.Column(scale=4):
            gr.Markdown("", elem_id="centered-markdown")
        with gr.Column(scale=0):
            logo = gr.Image(value="logo.png", elem_id="logo", elem_classes="no-border")

    with gr.Row():
        # Input section: a radio selects which of the three inputs is visible.
        with gr.Column(scale=4, elem_classes=['input-column']):
            input_type = gr.Radio(["Text", "PDF File", "URL"], label="Input Type", value="Text", elem_classes="PDF_file")
            input_text = gr.Textbox(
                lines=14, label="Privacy Policy",
                placeholder="Enter text here",
                elem_id="input-textbox",
                rtl=False,
                visible=True
            )
            input_file = gr.File(label="Upload PDF", file_types=[".pdf"], visible=False)
            input_url = gr.Textbox(label="Privacy Policy URL", placeholder="Enter URL here", elem_id="input-textbox", visible=False)
            submit_button = gr.Button("Submit")

        # Output section: per-category report, compliance level and percentage.
        with gr.Column(scale=4, elem_classes=['output-column']):
            output_component = gr.Textbox(lines=10, label="Privacy Policy Compliance", elem_id="output-textbox", rtl=False)
            compliance_rank_output = gr.Textbox(
                label="Compliance Level",
                elem_classes="compliance-rank-output",
                rtl=False,
                info="Compliance level calculated based on the number of categories aligned with the Saudi Personal Data Protection Law")
            percentage_output = gr.Textbox(
                label="Compliance Percentage",
                elem_classes="percentage-output",
                rtl=False,
                info="Compliance percentage calculated based on the importance of each category as defined by the Saudi Personal Data Protection Law")

    def update_input_visibility(input_type):
        """Show only the widget for the chosen input type and reset the form.

        Returns six values: updates for the text box, the file upload and the
        URL box (each emptied, with only the selected one visible), followed
        by empty strings that clear the three output boxes.  Any value other
        than "Text" or "URL" selects the file upload.
        """
        show_text = input_type == "Text"
        show_url = input_type == "URL"
        show_file = not (show_text or show_url)
        return (
            gr.update(visible=show_text, value=""),
            gr.update(visible=show_file, value=None),
            gr.update(visible=show_url, value=""),
            # Clear stale results so they are not mistaken for the new input.
            "",
            "",
            "",
        )

    input_type.change(
        fn=update_input_visibility,
        inputs=input_type,
        outputs=[
            input_text,
            input_file,
            input_url,
            output_component,
            percentage_output,
            compliance_rank_output,
        ],
    )
    submit_button.click(fn=analyze_document, inputs=[input_type, input_text, input_file, input_url], outputs=[output_component, percentage_output, compliance_rank_output])

    # Footer caption.
    with gr.Row():
        gr.Markdown("Compliance assessment for privacy policies with the Saudi Personal Data Protection Law", elem_id="centered-markdown")

demo.launch()