"""Backend helpers for an ophthalmology assistant chatbot.

The module queries Hugging Face inference endpoints (an LLM, Whisper speech
recognition and two text-to-speech models), loads cataract/glaucoma rows from
two SQLite databases at import time, and exposes the callback functions used
by a Gradio interface.
"""
import logging
import os
import re
import sqlite3
import time
from contextlib import closing
from io import BytesIO

import gradio as gr
import requests
from dotenv import load_dotenv
from pydub import AudioSegment

# Configure logging
logging.basicConfig(level=logging.DEBUG)

# Load environment variables
load_dotenv()

# Configure Hugging Face API URL and headers for Meta-Llama-3-70B-Instruct
api_url = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
huggingface_api_key = os.getenv("HF_API_TOKEN")
if not huggingface_api_key:
    logging.warning("HF_API_TOKEN is not set; Hugging Face requests will not be authenticated.")
headers = {"Authorization": f"Bearer {huggingface_api_key}"}

# Shared settings for every outbound Hugging Face request.
MAX_RETRIES = 5
RETRY_DELAY = 1  # seconds
REQUEST_TIMEOUT = 120  # seconds; without a timeout requests.post can block forever

# Matches e.g. "Glaucoma 12" or "cataract 3" inside a user question.
_CONDITION_ID_PATTERN = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)


def _call_with_retries(action, description):
    """Call ``action()`` up to MAX_RETRIES times and return its result.

    Any exception triggers a retry after RETRY_DELAY seconds. When every
    attempt fails, ``{"error": <message of the last exception>}`` is returned
    instead of raising, which is the contract the callers below rely on.
    """
    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            return action()
        except Exception as e:  # deliberate best-effort boundary around network/audio I/O
            last_error = e
            logging.warning("%s failed (attempt %d/%d): %s", description, attempt + 1, MAX_RETRIES, e)
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
    return {"error": str(last_error)}


# Function to query the Hugging Face model
def query_huggingface(payload):
    """POST ``payload`` as JSON to the LLM endpoint and return the decoded reply.

    Returns the parsed JSON body (whatever the API sent), or
    ``{"error": <message>}`` when the request fails or the body is not JSON.
    """
    logging.debug("Querying model with payload: %s", payload)
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=REQUEST_TIMEOUT)
        logging.debug("Received response: %s %s", response.status_code, response.text)
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body (e.g. an HTML error page).
        logging.exception("LLM request failed")
        return {"error": str(e)}


# Function to query the Whisper model for audio transcription
def query_whisper(audio_path):
    """Send the audio file at ``audio_path`` to the Whisper endpoint.

    Returns the decoded JSON response, or ``{"error": <message>}`` if the file
    is missing or all retry attempts fail.
    """
    API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"

    # A missing file will not appear by retrying, so fail fast instead of
    # sleeping through every attempt.
    if not audio_path or not os.path.exists(audio_path):
        return {"error": f"Audio file does not exist: {audio_path}"}

    def _transcribe():
        with open(audio_path, "rb") as f:
            data = f.read()
        response = requests.post(API_URL_WHISPER, headers=headers, data=data, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json()

    return _call_with_retries(_transcribe, "Whisper transcription")


# Function to generate speech from text using Nithu TTS
def generate_speech_nithu(answer):
    """Synthesize ``answer`` with the Nithu TTS endpoint.

    The response body is decoded as FLAC and re-exported as WAV. Returns the
    path of the written WAV file, or ``{"error": <message>}`` if every attempt
    fails.
    """
    API_URL_TTS_Nithu = "https://api-inference.huggingface.co/models/Nithu/text-to-speech"
    payload = {"inputs": answer}

    def _synthesize():
        response = requests.post(API_URL_TTS_Nithu, headers=headers, json=payload, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        audio_segment = AudioSegment.from_file(BytesIO(response.content), format="flac")
        # NOTE: fixed output path, so each call overwrites the previous answer.
        audio_file_path = "/tmp/answer_nithu.wav"
        audio_segment.export(audio_file_path, format="wav")
        return audio_file_path

    return _call_with_retries(_synthesize, "Nithu TTS")


# Function to generate speech from text using Ryan TTS
def generate_speech_ryan(answer):
    """Synthesize ``answer`` with the Ryan (ESPnet) TTS endpoint.

    Expects a JSON response carrying ``audio`` and ``sampling_rate``. Returns
    the path of the written WAV file, or ``{"error": <message>}`` if every
    attempt fails (including when the response lacks those fields).
    """
    API_URL_TTS_Ryan = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_fastspeech2"
    payload = {"inputs": answer}

    def _synthesize():
        response = requests.post(API_URL_TTS_Ryan, headers=headers, json=payload, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        response_json = response.json()
        audio = response_json.get("audio", None)
        sampling_rate = response_json.get("sampling_rate", None)
        if not (audio and sampling_rate):
            raise ValueError("Invalid response format from Ryan TTS API")
        # NOTE(review): a JSON-decoded "audio" value is not bytes, so BytesIO
        # will likely reject it -- confirm the endpoint's response format.
        audio_segment = AudioSegment.from_file(BytesIO(audio), format="wav")
        # NOTE: fixed output path, so each call overwrites the previous answer.
        audio_file_path = "/tmp/answer_ryan.wav"
        audio_segment.export(audio_file_path, format="wav")
        return audio_file_path

    return _call_with_retries(_synthesize, "Ryan TTS")


def _fetch_all_rows(db_path, query):
    """Run ``query`` against the SQLite file at ``db_path`` and return all rows."""
    # closing() guarantees the connection is released even when the query fails.
    with closing(sqlite3.connect(db_path)) as conn:
        return conn.execute(query).fetchall()


# Function to fetch patient data from both databases
def fetch_patient_data(cataract_db_path, glaucoma_db_path):
    """Load every row of the cataract and glaucoma result tables.

    Returns a dict with keys ``'cataract_results'`` and ``'results'``. Each
    value is the list of row tuples, or an error-message string when that
    database could not be read.
    """
    sources = (
        ('cataract_results', cataract_db_path, "SELECT * FROM cataract_results", "cataract"),
        ('results', glaucoma_db_path, "SELECT * FROM results", "glaucoma"),
    )
    patient_data = {}
    for key, db_path, query, label in sources:
        try:
            patient_data[key] = _fetch_all_rows(db_path, query)
        except Exception as e:  # best effort: this runs at import time and must not crash the app
            logging.warning("Could not read %s results from %s: %s", label, db_path, e)
            patient_data[key] = f"Error fetching {label} results: {str(e)}"
    return patient_data


# Minimum number of columns a row needs before it can be formatted.
_CATARACT_MIN_COLUMNS = 6
_GLAUCOMA_MIN_COLUMNS = 7


def _format_cataract_row(row):
    """Format one cataract row (needs at least 6 columns) as a text line."""
    return (
        f"Patient ID: {row[0]}, Red Quantity: {row[2]}, Green Quantity: {row[3]}, "
        f"Blue Quantity: {row[4]}, Stage: {row[5]}\n"
    )


def _format_glaucoma_row(row):
    """Format one glaucoma row (needs at least 7 columns) as a text line."""
    return (
        f"Patient ID: {row[0]}, Cup Area: {row[2]}, Disk Area: {row[3]}, Rim Area: {row[4]}, "
        f"Rim to Disc Line Ratio: {row[5]}, DDLS Stage: {row[6]}\n"
    )


def _render_section(data, title, min_columns, format_row, label):
    """Render one condition's data (rows or an error string) plus a blank line."""
    if isinstance(data, str):
        # fetch_patient_data stores an error message instead of rows on failure.
        return data + "\n\n"
    lines = [title]
    for row in data:
        if len(row) >= min_columns:
            lines.append(format_row(row))
        else:
            lines.append(f"Error: Incomplete data row in {label} results\n")
    lines.append("\n")
    return "".join(lines)


# Function to transform fetched data into a readable format
def transform_patient_data(patient_data):
    """Turn the dict produced by ``fetch_patient_data`` into display text.

    Each present section is rendered as a title followed by one line per row
    (or the stored error message) and a trailing blank line.
    """
    parts = ["Readable Patient Data:\n\n"]
    if 'cataract_results' in patient_data:
        parts.append(_render_section(
            patient_data['cataract_results'], "Cataract Results:\n",
            _CATARACT_MIN_COLUMNS, _format_cataract_row, "cataract",
        ))
    if 'results' in patient_data:
        parts.append(_render_section(
            patient_data['results'], "Glaucoma Results:\n",
            _GLAUCOMA_MIN_COLUMNS, _format_glaucoma_row, "glaucoma",
        ))
    return "".join(parts)


# Paths to your databases
cataract_db_path = 'cataract_results.db'
glaucoma_db_path = 'glaucoma_results.db'

# Fetch and transform patient data
patient_data = fetch_patient_data(cataract_db_path, glaucoma_db_path)
readable_patient_data = transform_patient_data(patient_data)


# Function to extract details from the input prompt
def extract_details_from_prompt(prompt):
    """Find every "<condition> <number>" mention in ``prompt``.

    Returns a list of ``(condition, patient_id)`` tuples where ``condition``
    is ``"Glaucoma"`` or ``"Cataract"`` (capitalized regardless of input case)
    and ``patient_id`` is an int.
    """
    return [
        (condition.capitalize(), int(number))
        for condition, number in _CONDITION_ID_PATTERN.findall(prompt)
    ]


# Function to fetch specific patient data based on the condition and ID
def get_specific_patient_data(patient_data, condition, patient_id):
    """Return the formatted record of ``patient_id`` for ``condition``.

    The result is the section title followed by the first row whose first
    column equals ``patient_id`` (title only when no row matches). A matching
    row with too few columns is reported as incomplete instead of raising.
    An unrecognized ``condition`` yields an empty string.
    """
    if condition == "Cataract":
        key, title = 'cataract_results', "Cataract Results:\n"
        min_columns, format_row, label = _CATARACT_MIN_COLUMNS, _format_cataract_row, "cataract"
    elif condition == "Glaucoma":
        key, title = 'results', "Glaucoma Results:\n"
        min_columns, format_row, label = _GLAUCOMA_MIN_COLUMNS, _format_glaucoma_row, "glaucoma"
    else:
        return ""

    specific_data = title
    # The stored value may be an error string; the tuple check skips its characters.
    for row in patient_data.get(key, []):
        if isinstance(row, tuple) and row and row[0] == patient_id:
            if len(row) >= min_columns:
                specific_data += format_row(row)
            else:
                specific_data += f"Error: Incomplete data row in {label} results\n"
            break
    return specific_data


# Function to aggregate patient history for all mentioned IDs in the question
def get_aggregated_patient_history(patient_data, details):
    """Concatenate the records for every ``(condition, patient_id)`` in ``details``.

    Records are separated by a blank line; surrounding whitespace is stripped.
    """
    sections = (
        get_specific_patient_data(patient_data, condition, patient_id) + "\n"
        for condition, patient_id in details
    )
    return "".join(sections).strip()


# Toggle visibility of input elements based on input type
def toggle_visibility(input_type):
    """Return two Gradio updates: the first visible for "Voice", the second otherwise."""
    voice_selected = input_type == "Voice"
    return gr.update(visible=voice_selected), gr.update(visible=not voice_selected)


def cleanup_response(response):
    """Return the text after the first "Answer:" marker, stripped.

    When the marker is absent, ``response`` is returned unchanged.
    """
    _, marker, answer = response.partition("Answer:")
    return answer.strip() if marker else response


# Gradio interface for the chatbot
def chatbot(audio, input_type, text):
    """Answer a typed or spoken question using the LLM and patient history.

    Args:
        audio: Uploaded audio (an object with a ``name`` path attribute, or a
            path string); only used when ``input_type`` is "Voice".
        input_type: "Voice" to transcribe ``audio``; anything else uses ``text``.
        text: The typed question.

    Returns:
        A ``(answer_text, None)`` tuple; ``answer_text`` is an error message
        when transcription fails or no question was supplied.
    """
    if input_type == "Voice":
        if audio is None:
            return "Error transcribing audio: no audio was provided.", None
        # Accept both file-like uploads (with .name) and plain path strings.
        audio_path = getattr(audio, "name", audio)
        transcription = query_whisper(audio_path)
        if not isinstance(transcription, dict):
            return "Error transcribing audio: unexpected response format.", None
        if "error" in transcription:
            return "Error transcribing audio: " + str(transcription["error"]), None
        query = transcription.get('text')
        if query is None:
            return "Error transcribing audio: no text in transcription response.", None
    else:
        query = text

    if not query or not query.strip():
        return "Please provide a question.", None

    details = extract_details_from_prompt(query)
    patient_history = get_aggregated_patient_history(patient_data, details)

    payload = {
        "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
    }
    logging.debug("Raw input to the LLM: %s", payload['inputs'])

    response = query_huggingface(payload)
    fallback = "Sorry, I couldn't generate a response."
    if isinstance(response, list):
        first = response[0] if response else {}
        raw_response = first.get("generated_text", fallback) if isinstance(first, dict) else fallback
    elif isinstance(response, dict):
        raw_response = response.get("generated_text", fallback)
    else:
        raw_response = fallback
    logging.debug("Raw output from the LLM: %s", raw_response)

    clean_response = cleanup_response(raw_response)
    return clean_response, None


# Gradio interface for generating voice response
def generate_voice_response(tts_model, text_response):
    """Synthesize ``text_response`` with the selected TTS model.

    Returns ``(audio_file_path, None)`` on success and ``(None, None)`` when
    the model name is unknown or synthesis failed (the failure is logged).
    """
    generators = {
        "Nithu (Custom)": generate_speech_nithu,
        "Ryan (ESPnet)": generate_speech_ryan,
    }
    generate = generators.get(tts_model)
    if generate is None:
        return None, None

    result = generate(text_response)
    if isinstance(result, dict):
        # The generators signal failure with {"error": ...}; never hand that
        # dict on as if it were an audio file path.
        logging.error("TTS generation with %s failed: %s", tts_model, result.get("error"))
        return None, None
    return result, None


# Function to update patient history in the interface
def update_patient_history():
    """Return the readable patient data computed at import time."""
    return readable_patient_data