import logging import gradio as gr from queue import Queue import time from prometheus_client import start_http_server, Counter, Histogram, Gauge import threading import psutil import random from transformers import pipeline from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score import requests from datasets import load_dataset import os from logging import FileHandler from typing import Iterable # Ensure the log files exist log_file_path = 'chat_log.log' debug_log_file_path = 'debug.log' if not os.path.exists(log_file_path): with open(log_file_path, 'w') as f: f.write(" ") if not os.path.exists(debug_log_file_path): with open(debug_log_file_path, 'w') as f: f.write(" ") # Create logger instance logger = logging.getLogger() logger.setLevel(logging.DEBUG) # Set logger level to the lowest level needed #Create formatter formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S') # Create handlers info_handler = FileHandler( filename=log_file_path, mode='w+') info_handler.setLevel(logging.INFO) info_handler.setFormatter(formatter) debug_handler = FileHandler(filename=debug_log_file_path, mode='w+') debug_handler.setLevel(logging.DEBUG) debug_handler.setFormatter(formatter) # Function to capture logs for Gradio display class GradioHandler(logging.Handler): def __init__(self, logs_queue): super().__init__() self.logs_queue = logs_queue def emit(self, record): log_entry = self.format(record) self.logs_queue.put(log_entry) # Create a logs queue logs_queue = Queue() # Create and configure Gradio handler gradio_handler = GradioHandler(logs_queue) gradio_handler.setLevel(logging.INFO) gradio_handler.setFormatter(formatter) # Add handlers to the logger logger.addHandler(info_handler) logger.addHandler(debug_handler) logger.addHandler(gradio_handler) # Load the model try: ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner") logger.debug("NER pipeline loaded.") except Exception as e: logger.debug(f"Error loading NER pipeline: {e}") # Load the dataset try: dataset = load_dataset("surrey-nlp/PLOD-filtered") logger.debug("Dataset loaded.") except Exception as e: logger.debug(f"Error loading dataset: {e}") # --- Prometheus Metrics Setup --- try: REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests') REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds') ERROR_COUNT = Counter('gradio_error_count', 'Total number of errors') RESPONSE_SIZE = Histogram('gradio_response_size_bytes', 'Size of responses in bytes') CPU_USAGE = Gauge('system_cpu_usage_percent', 'System CPU usage in percent') MEM_USAGE = Gauge('system_memory_usage_percent', 'System memory usage in percent') QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue') logger.debug("Prometheus metrics setup complete.") except Exception as e: logger.debug(f"Error setting up Prometheus metrics: {e}") # --- Queue and Metrics --- chat_queue = Queue() # Define chat_queue globally label_mapping = { 0: 'B-O', 1: 'B-AC', 3: 'B-LF', 4: 'I-LF' } def classification(message): # Predict using the model ner_results = ner_pipeline(" ".join(message)) detailed_response = [] model_predicted_labels = [] for result in ner_results: token = result['word'] score = result['score'] entity = result['entity'] label_id = int(entity.split('_')[-1]) # Extract numeric label from entity model_predicted_labels.append(label_mapping[label_id]) detailed_response.append(f"Token: {token}, Entity: {label_mapping[label_id]}, Score: {score:.4f}") response = "\n".join(detailed_response) response_size = len(response.encode('utf-8')) RESPONSE_SIZE.observe(response_size) time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time return response, model_predicted_labels # --- Chat Function with Monitoring --- def chat_function(input, datasets): logger.debug("Starting chat_function") with REQUEST_LATENCY.time(): REQUEST_COUNT.inc() try: if input.isnumeric(): chat_queue.put(input) # Get the example from the dataset if datasets: example = datasets[int(input)] else: example = dataset['train'][int(input)] tokens = example['tokens'] ground_truth_labels = [label_mapping[label] for label in example['ner_tags']] # Call the classification function response, model_predicted_labels = classification(tokens) # Ensure the model and ground truth labels are the same length for comparison model_predicted_labels = model_predicted_labels[:len(ground_truth_labels)] precision = precision_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0) recall = recall_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0) f1 = f1_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0) accuracy = accuracy_score(ground_truth_labels, model_predicted_labels) metrics_response = (f"Precision: {precision:.4f}\n" f"Recall: {recall:.4f}\n" f"F1 Score: {f1:.4f}\n" f"Accuracy: {accuracy:.4f}") full_response = f"**Record**:\nTokens: {tokens}\nGround Truth Labels: {ground_truth_labels}\n\n**Predictions**:\n{response}\n\n**Metrics**:\n{metrics_response}" logger.info(f"\nInput details: \n Received index from user: {input} Sending response to user: {full_response}") else: chat_queue.put(input) response, predicted_labels = classification([input]) full_response = f"Input details: \n**Input Sentence:** {input}\n\n**Predictions**:\n{response}\n\n" logger.info(full_response) chat_queue.get() return full_response except Exception as e: ERROR_COUNT.inc() logger.error(f"Error in chat processing: {e}", exc_info=True) return f"An error occurred. Please try again. Error: {e}" # Function to simulate stress test def stress_test(num_requests, message, delay): def send_chat_message(): try: response = requests.post("http://127.0.0.1:7860/api/predict/", json={ "data": [message], "fn_index": 0 # This might need to be updated based on your Gradio app's function index }) logger.debug(f"Request payload: {message}",exc_info=True) logger.debug(f"Response: {response.json()}",exc_info=True) except Exception as e: logger.debug(f"Error during stress test request: {e}", exc_info=True) threads = [] for _ in range(num_requests): t = threading.Thread(target=send_chat_message) t.start() threads.append(t) time.sleep(delay) # Delay between requests for t in threads: t.join() # --- Gradio Interface with Background Image and Three Windows --- with gr.Blocks(title="PLOD Filtered with Monitoring") as demo: # Load CSS for background image with gr.Tab("Sentence input"): gr.Markdown("## Chat with the Bot") index_input = gr.Textbox(label="Enter A sentence:", lines=1) output = gr.Markdown(label="Response") chat_interface = gr.Interface(fn=chat_function, inputs=[index_input], outputs=output) with gr.Tab("Dataset and Index Input"): gr.Markdown("## Chat with the Bot") interface = gr.Interface(fn = chat_function, inputs=[gr.Textbox(label="Enter dataset index:", lines=1), gr.UploadButton(label ="Upload Dataset", file_types=[".csv", ".tsv"])], outputs = gr.Markdown(label="Response")) with gr.Tab("Model Parameters"): model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False) # Display model parameters with gr.Tab("Performance Metrics"): request_count_display = gr.Number(label="Request Count", value=0) avg_latency_display = gr.Number(label="Avg. Response Time (s)", value=0) with gr.Tab("Infrastructure"): cpu_usage_display = gr.Number(label="CPU Usage (%)", value=0) mem_usage_display = gr.Number(label="Memory Usage (%)", value=0) with gr.Tab("Logs"): logs_display = gr.Textbox(label="Logs", lines=10) # Increased lines for better visibility with gr.Tab("Stress Testing"): num_requests_input = gr.Number(label="Number of Requests", value=10) index_input_stress = gr.Textbox(label="Dataset Index", value="2") delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1) stress_test_button = gr.Button("Start Stress Test") stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False) def run_stress_test(num_requests, index, delay): stress_test_status.value = "Stress test started..." try: stress_test(num_requests, index, delay) stress_test_status.value = "Stress test completed." except Exception as e: stress_test_status.value = f"Stress test failed: {e}" stress_test_button.click(run_stress_test, [num_requests_input, index_input_stress, delay_input], stress_test_status) img = gr.Image( "stag.jpeg", label="Image" ) # --- Update Functions --- def update_metrics(request_count_display, avg_latency_display): while True: request_count = REQUEST_COUNT._value.get() latency_samples = REQUEST_LATENCY.collect()[0].samples avg_latency = sum(s.value for s in latency_samples) / len(latency_samples if latency_samples else [1]) # Avoid division by zero request_count_display.value = request_count avg_latency_display.value = round(avg_latency, 2) time.sleep(5) # Update every 5 seconds def update_usage(cpu_usage_display, mem_usage_display): while True: cpu_usage_display.value = psutil.cpu_percent() mem_usage_display.value = psutil.virtual_memory().percent CPU_USAGE.set(psutil.cpu_percent()) MEM_USAGE.set(psutil.virtual_memory().percent) time.sleep(5) def update_logs(logs_display): while True: info_log_vector = [] logs = [] while not logs_queue.empty(): logs.append(logs_queue.get()) logs_display.value = "\n".join(logs[-10:]) time.sleep(1) # Update every 1 second def display_model_params(model_params_display): while True: model_params = ner_pipeline.model.config.to_dict() model_params_str = "\n".join(f"{key}: {value}" for key, value in model_params.items()) model_params_display.value = model_params_str time.sleep(10) # Update every 10 seconds def update_queue_length(): while True: QUEUE_LENGTH.set(chat_queue.qsize()) time.sleep(1) # Update every second # --- Start Threads --- threading.Thread(target=start_http_server, args=(8000,), daemon=True).start() threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start() threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start() threading.Thread(target=update_logs, args=(logs_display), daemon=True).start() threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start() threading.Thread(target=update_queue_length, daemon=True).start() # Launch the app demo.launch(share=True)