Sevixdd's picture
Update app.py
295eea2 verified
import logging
import gradio as gr
from queue import Queue
import time
from prometheus_client import start_http_server, Counter, Histogram, Gauge
import threading
import psutil
import random
from transformers import pipeline
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import requests
from datasets import load_dataset
import os
from logging import FileHandler
from typing import Iterable
# Log file locations: chat_log.log gets INFO and above, debug.log gets everything.
log_file_path = 'chat_log.log'
debug_log_file_path = 'debug.log'

# Make sure both log files exist on disk. (The handlers below open them with
# mode 'w+', which truncates them on every start anyway.)
for _path in (log_file_path, debug_log_file_path):
    if not os.path.exists(_path):
        with open(_path, 'w') as _fh:
            _fh.write(" ")

# Root logger: pass everything from DEBUG up; each handler filters further.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# One shared line format for every handler.
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%d-%b-%y %H:%M:%S',
)


def _build_file_handler(path, level):
    """Return a FileHandler writing to *path* (mode 'w+') filtered at *level*."""
    handler = FileHandler(filename=path, mode='w+')
    handler.setLevel(level)
    handler.setFormatter(formatter)
    return handler


info_handler = _build_file_handler(log_file_path, logging.INFO)
debug_handler = _build_file_handler(debug_log_file_path, logging.DEBUG)
# Logging handler that captures formatted records for display in the Gradio UI.
class GradioHandler(logging.Handler):
    """Logging handler that puts each formatted record onto a queue."""

    def __init__(self, logs_queue):
        """Remember *logs_queue*, the queue that formatted log lines are put onto."""
        super().__init__()
        self.logs_queue = logs_queue

    def emit(self, record):
        """Format *record* and put the resulting string on ``self.logs_queue``.

        As the logging framework expects of handlers, failures while formatting
        or enqueueing are reported through ``Handler.handleError`` rather than
        propagating into the code that made the logging call.
        """
        try:
            log_entry = self.format(record)
            self.logs_queue.put(log_entry)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
# Queue filled by GradioHandler with formatted INFO-and-above log lines.
logs_queue = Queue()

# Handler feeding the in-app log view; same line format as the file handlers.
gradio_handler = GradioHandler(logs_queue)
gradio_handler.setLevel(logging.INFO)
gradio_handler.setFormatter(formatter)

# Attach all three handlers to the root logger, in this order.
for _handler in (info_handler, debug_handler, gradio_handler):
    logger.addHandler(_handler)
# Load the fine-tuned NER model. Failures are logged at ERROR with a traceback
# so they appear in chat_log.log and the UI log view, not only in debug.log.
# NOTE: on failure `ner_pipeline` stays undefined and later uses raise NameError.
try:
    ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
    logger.debug("NER pipeline loaded.")
except Exception as e:
    logger.error(f"Error loading NER pipeline: {e}", exc_info=True)

# Load the dataset used by the index-based chat path.
# NOTE: on failure `dataset` stays undefined and later uses raise NameError.
try:
    dataset = load_dataset("surrey-nlp/PLOD-filtered")
    logger.debug("Dataset loaded.")
except Exception as e:
    logger.error(f"Error loading dataset: {e}", exc_info=True)
# --- Prometheus Metrics Setup ---
# Module-level metric objects shared by the request handler and the
# background update threads. A setup failure is logged at ERROR with a
# traceback; the names it failed to create stay undefined.
try:
    REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
    REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
    ERROR_COUNT = Counter('gradio_error_count', 'Total number of errors')
    RESPONSE_SIZE = Histogram('gradio_response_size_bytes', 'Size of responses in bytes')
    CPU_USAGE = Gauge('system_cpu_usage_percent', 'System CPU usage in percent')
    MEM_USAGE = Gauge('system_memory_usage_percent', 'System memory usage in percent')
    QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue')
    logger.debug("Prometheus metrics setup complete.")
except Exception as e:
    # NOTE(review): presumably most often a duplicate-registration error if this
    # module is executed twice in one process — verify.
    logger.error(f"Error setting up Prometheus metrics: {e}", exc_info=True)
# --- Queue and Metrics ---
# Queue of in-flight chat inputs; its size is exported through a gauge.
chat_queue = Queue()

# Numeric tag id -> label string. Used both for the model's "LABEL_<id>"
# entities and for the dataset's `ner_tags`.
# NOTE(review): id 2 has no entry, so any lookup of it raises KeyError —
# presumably it should be 'I-AC'; confirm against the dataset's label list.
label_mapping = dict(zip(
    (0, 1, 3, 4),
    ('B-O', 'B-AC', 'B-LF', 'I-LF'),
))
def _entity_to_label(entity):
    """Map a pipeline entity name such as ``"LABEL_3"`` to its label string.

    Falls back to the raw entity name when the suffix after the last ``_`` is
    not an integer or has no entry in ``label_mapping``. One unexpected tag
    then shows up verbatim instead of failing the whole request.
    """
    try:
        return label_mapping[int(entity.split('_')[-1])]
    except (ValueError, KeyError):
        return entity


def classification(message):
    """Run the NER pipeline over *message* and format its predictions.

    Args:
        message: A sequence of strings (dataset tokens, or a one-element list
            holding a raw sentence). The items are joined with single spaces
            before being passed to the pipeline.

    Returns:
        A tuple ``(response, labels)``. ``response`` has one
        ``"Token: ..., Entity: ..., Score: ..."`` line per pipeline result.
        ``labels`` lists the predicted label strings in the same order.

    Side effects:
        Records the UTF-8 byte size of ``response`` in the RESPONSE_SIZE
        histogram, then sleeps 0.5-2.5 s to simulate processing time.
    """
    ner_results = ner_pipeline(" ".join(message))

    detailed_response = []
    model_predicted_labels = []
    for result in ner_results:
        label = _entity_to_label(result['entity'])
        model_predicted_labels.append(label)
        detailed_response.append(
            f"Token: {result['word']}, Entity: {label}, Score: {result['score']:.4f}"
        )

    response = "\n".join(detailed_response)
    RESPONSE_SIZE.observe(len(response.encode('utf-8')))
    time.sleep(random.uniform(0.5, 2.5))  # Simulate processing time
    return response, model_predicted_labels
# --- Chat Function with Monitoring ---
def _index_response(index_text, datasets):
    """Classify dataset record ``int(index_text)`` and report metrics against its labels."""
    index = int(index_text)
    # NOTE(review): `datasets` is whatever the "Upload Dataset" button supplies.
    # That is presumably a file path/object rather than a list of records —
    # confirm before relying on this branch.
    example = datasets[index] if datasets else dataset['train'][index]
    tokens = example['tokens']
    ground_truth_labels = [label_mapping[label] for label in example['ner_tags']]

    response, model_predicted_labels = classification(tokens)

    # Score position by position over the common prefix. sklearn raises when
    # the two lists differ in length, and the pipeline may return more or
    # fewer items than there are dataset tokens (presumably sub-word pieces —
    # confirm), so this alignment is approximate.
    common = min(len(ground_truth_labels), len(model_predicted_labels))
    y_true = ground_truth_labels[:common]
    y_pred = model_predicted_labels[:common]
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)

    metrics_response = (f"Precision: {precision:.4f}\n"
                        f"Recall: {recall:.4f}\n"
                        f"F1 Score: {f1:.4f}\n"
                        f"Accuracy: {accuracy:.4f}")
    full_response = f"**Record**:\nTokens: {tokens}\nGround Truth Labels: {ground_truth_labels}\n\n**Predictions**:\n{response}\n\n**Metrics**:\n{metrics_response}"
    logger.info(f"\nInput details: \n Received index from user: {index_text} Sending response to user: {full_response}")
    return full_response


def _sentence_response(sentence):
    """Classify a free-text *sentence* and return its Markdown report."""
    response, _ = classification([sentence])
    full_response = f"Input details: \n**Input Sentence:** {sentence}\n\n**Predictions**:\n{response}\n\n"
    logger.info(full_response)
    return full_response


def chat_function(input, datasets=None):
    """Handle one chat request from either UI tab and return Markdown.

    Args:
        input: Text from the user. If it consists only of decimal digits it is
            treated as a record index. Otherwise it is classified as a sentence.
        datasets: Optional value from the "Upload Dataset" button. When truthy
            it is indexed instead of ``dataset['train']``. Defaults to ``None``
            because the "Sentence input" tab wires only one input here.

    Returns:
        A Markdown report. For a sentence it contains the predictions. For an
        index it also contains the ground truth and weighted
        precision/recall/F1/accuracy. If processing raised, an error message
        is returned instead; the error is counted and logged.
    """
    logger.debug("Starting chat_function")
    with REQUEST_LATENCY.time():
        REQUEST_COUNT.inc()
        try:
            chat_queue.put(input)
            try:
                # isdecimal (not isnumeric): '½' is numeric but int() rejects it.
                if input.isdecimal():
                    return _index_response(input, datasets)
                return _sentence_response(input)
            finally:
                # Always balance the put() above so the queue-length gauge
                # does not drift upward after failed requests.
                chat_queue.get()
        except Exception as e:
            ERROR_COUNT.inc()
            logger.error(f"Error in chat processing: {e}", exc_info=True)
            return f"An error occurred. Please try again. Error: {e}"
# Function to simulate stress test
def stress_test(num_requests, message, delay,
                url="http://127.0.0.1:7860/api/predict/", timeout=30):
    """Start *num_requests* POST requests carrying *message* and wait for all of them.

    Args:
        num_requests: How many requests to send. Coerced with ``int()`` because
            a ``gr.Number`` value may arrive as a float (e.g. ``10.0``), which
            ``range()`` rejects.
        message: Value placed in the request body's ``data`` list.
        delay: Seconds to sleep after starting each request.
        url: Endpoint to POST to. NOTE(review): newer Gradio versions may not
            expose ``/api/predict/`` — confirm for the installed version.
        timeout: Per-request timeout in seconds, so a hung request cannot
            block the final ``join()`` forever.

    Request failures are logged at DEBUG level and otherwise ignored.
    """
    def send_chat_message():
        try:
            response = requests.post(url, json={
                "data": [message],
                "fn_index": 0,  # This might need to be updated based on your Gradio app's function index
            }, timeout=timeout)
            logger.debug(f"Request payload: {message}")
            logger.debug(f"Response: {response.json()}")
        except Exception as e:
            logger.debug(f"Error during stress test request: {e}", exc_info=True)

    threads = []
    for _ in range(int(num_requests)):
        t = threading.Thread(target=send_chat_message)
        t.start()
        threads.append(t)
        time.sleep(delay)  # Delay between requests
    for t in threads:
        t.join()
# --- Gradio Interface ---
# Tabs: sentence chat, dataset-index chat, model parameters, performance
# metrics, infrastructure usage, logs, and stress testing. The display
# components created here are handed to the background update threads.
with gr.Blocks(title="PLOD Filtered with Monitoring") as demo:
    with gr.Tab("Sentence input"):
        gr.Markdown("## Chat with the Bot")
        index_input = gr.Textbox(label="Enter A sentence:", lines=1)
        output = gr.Markdown(label="Response")
        # Only one input is wired here, so chat_function receives just the text.
        chat_interface = gr.Interface(fn=chat_function, inputs=[index_input], outputs=output)

    with gr.Tab("Dataset and Index Input"):
        gr.Markdown("## Chat with the Bot")
        interface = gr.Interface(
            fn=chat_function,
            inputs=[
                gr.Textbox(label="Enter dataset index:", lines=1),
                gr.UploadButton(label="Upload Dataset", file_types=[".csv", ".tsv"]),
            ],
            outputs=gr.Markdown(label="Response"),
        )

    with gr.Tab("Model Parameters"):
        model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False)  # Display model parameters

    with gr.Tab("Performance Metrics"):
        request_count_display = gr.Number(label="Request Count", value=0)
        avg_latency_display = gr.Number(label="Avg. Response Time (s)", value=0)

    with gr.Tab("Infrastructure"):
        cpu_usage_display = gr.Number(label="CPU Usage (%)", value=0)
        mem_usage_display = gr.Number(label="Memory Usage (%)", value=0)

    with gr.Tab("Logs"):
        logs_display = gr.Textbox(label="Logs", lines=10)  # Increased lines for better visibility

    with gr.Tab("Stress Testing"):
        num_requests_input = gr.Number(label="Number of Requests", value=10)
        index_input_stress = gr.Textbox(label="Dataset Index", value="2")
        delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
        stress_test_button = gr.Button("Start Stress Test")
        stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)

        def run_stress_test(num_requests, index, delay):
            """Run the stress test and return the status text for ``stress_test_status``.

            The click listener below lists ``stress_test_status`` as its
            output, so the returned string is what the textbox shows.
            Assigning ``.value`` inside the handler (as before) left the
            function returning ``None``, which cleared the box.
            """
            try:
                # gr.Number may yield a float; range() inside stress_test needs an int.
                stress_test(int(num_requests), index, delay)
            except Exception as e:
                return f"Stress test failed: {e}"
            return "Stress test completed."

        stress_test_button.click(
            run_stress_test,
            [num_requests_input, index_input_stress, delay_input],
            stress_test_status,
        )

    img = gr.Image(
        "stag.jpeg", label="Image"
    )
# --- Update Functions ---
def _mean_latency_seconds():
    """Return REQUEST_LATENCY's observed sum divided by its count (0.0 before any request).

    The histogram's collected samples mix ``_bucket`` entries with ``_count``,
    ``_sum`` and ``_created``. Averaging all of them does not yield a latency,
    so only ``_sum`` and ``_count`` are used.
    """
    total = 0.0
    count = 0.0
    for sample in REQUEST_LATENCY.collect()[0].samples:
        if sample.name.endswith('_sum'):
            total = sample.value
        elif sample.name.endswith('_count'):
            count = sample.value
    return total / count if count else 0.0


def update_metrics(request_count_display, avg_latency_display):
    """Every 5 s, copy the request count and mean latency onto the given components.

    NOTE(review): assigning ``.value`` on a component after launch presumably
    does not push the value to already-connected browsers. Confirm; a periodic
    ``demo.load(..., every=...)``/timer-driven update may be needed.
    """
    while True:
        request_count = REQUEST_COUNT._value.get()
        avg_latency = _mean_latency_seconds()
        request_count_display.value = request_count
        avg_latency_display.value = round(avg_latency, 2)
        time.sleep(5)  # Update every 5 seconds
def update_usage(cpu_usage_display, mem_usage_display):
    """Every 5 s, record system CPU and memory usage in the UI components and gauges."""
    while True:
        # Take one sample per tick and reuse it. psutil.cpu_percent() with no
        # interval reports usage since its previous call, so calling it twice
        # back to back made the gauge see a near-zero-length window that
        # disagreed with the value shown in the UI.
        cpu = psutil.cpu_percent()
        mem = psutil.virtual_memory().percent
        cpu_usage_display.value = cpu
        mem_usage_display.value = mem
        CPU_USAGE.set(cpu)
        MEM_USAGE.set(mem)
        time.sleep(5)
def update_logs(logs_display, max_lines=10):
    """Each second, drain ``logs_queue`` and show the newest *max_lines* entries.

    Entries are kept across ticks, so the view no longer blanks during a second
    in which nothing was logged. *max_lines* is expected to be positive.
    """
    recent = []
    while True:
        while not logs_queue.empty():
            recent.append(logs_queue.get())
        recent = recent[-max_lines:]  # bound memory to what is displayed
        logs_display.value = "\n".join(recent)
        time.sleep(1)  # Update every 1 second
def display_model_params(model_params_display):
    """Every 10 s, write the NER model's config as ``key: value`` lines into the component."""
    while True:
        config = ner_pipeline.model.config.to_dict()
        lines = [f"{key}: {value}" for key, value in config.items()]
        model_params_display.value = "\n".join(lines)
        time.sleep(10)  # Update every 10 seconds
def update_queue_length():
    """Publish ``chat_queue``'s current size to the QUEUE_LENGTH gauge once per second."""
    while True:
        depth = chat_queue.qsize()
        QUEUE_LENGTH.set(depth)
        time.sleep(1)  # Update every second
# --- Start Threads ---
# Prometheus scrape endpoint on port 8000 plus the background refresh loops,
# all as daemon threads so they do not keep the process alive on exit.
threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
# `args` must be a tuple. Thread calls target(*args), and `(logs_display)` is
# just the component itself, not a one-element tuple.
threading.Thread(target=update_logs, args=(logs_display,), daemon=True).start()
threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
threading.Thread(target=update_queue_length, daemon=True).start()

# Launch the app
demo.launch(share=True)