Spaces:
Sleeping
Sleeping
asadAbdullah
commited on
Commit
•
6feb2e4
1
Parent(s):
7e5c1c8
Update app.py
Browse files
app.py
CHANGED
@@ -2,24 +2,26 @@
|
|
2 |
import os
|
3 |
import pandas as pd
|
4 |
import streamlit as st
|
|
|
5 |
from transformers import pipeline
|
6 |
from sentence_transformers import SentenceTransformer, util
|
7 |
import requests
|
8 |
import json
|
9 |
|
10 |
-
# Configure Hugging Face API token securely
|
11 |
api_key = os.getenv("HF_API_KEY")
|
12 |
|
13 |
-
# Load the CSV dataset
|
14 |
try:
|
15 |
data = pd.read_csv('genetic-Final.csv') # Ensure the dataset filename is correct
|
16 |
except FileNotFoundError:
|
17 |
st.error("Dataset file not found. Please upload it to this directory.")
|
18 |
|
19 |
-
#
|
20 |
-
|
|
|
21 |
|
22 |
-
#
|
23 |
if 'combined_description' not in data.columns:
|
24 |
data['combined_description'] = (
|
25 |
data['Symptoms'].fillna('') + " " +
|
@@ -31,151 +33,42 @@ if 'combined_description' not in data.columns:
|
|
31 |
data['Emergency Treatment'].fillna('')
|
32 |
)
|
33 |
|
34 |
-
# Define weights for each column based on importance
|
35 |
-
column_weights = {
|
36 |
-
'Symptoms': 0.4,
|
37 |
-
'Severity Level': 0.2,
|
38 |
-
'Risk Assessment': 0.1,
|
39 |
-
'Treatment Options': 0.15,
|
40 |
-
'Suggested Medical Tests': 0.05,
|
41 |
-
'Minimum Values for Medical Tests': 0.05,
|
42 |
-
'Emergency Treatment': 0.05
|
43 |
-
}
|
44 |
-
|
45 |
-
# Precompute embeddings for each weighted column
|
46 |
-
for col in column_weights.keys():
|
47 |
-
if f"{col}_embeddings" not in data.columns:
|
48 |
-
data[f"{col}_embeddings"] = data[col].fillna("").apply(lambda x: retriever_model.encode(x).tolist())
|
49 |
-
|
50 |
-
# Function to retrieve relevant information with weighted scoring
|
51 |
-
def get_weighted_relevant_info(query, top_k=3):
|
52 |
-
query_embedding = retriever_model.encode(query)
|
53 |
-
weighted_similarities = []
|
54 |
-
for idx, row in data.iterrows():
|
55 |
-
weighted_score = 0
|
56 |
-
for col, weight in column_weights.items():
|
57 |
-
if row[f"{col}_embeddings"]:
|
58 |
-
col_similarity = util.cos_sim(query_embedding, row[f"{col}_embeddings"])[0][0].item()
|
59 |
-
weighted_score += col_similarity * weight
|
60 |
-
weighted_similarities.append(weighted_score)
|
61 |
-
|
62 |
-
top_indices = sorted(range(len(weighted_similarities)), key=lambda i: weighted_similarities[i], reverse=True)[:top_k]
|
63 |
-
return data.iloc[top_indices]
|
64 |
-
|
65 |
-
# Generate embeddings for the combined description if not already done
|
66 |
-
if 'embeddings' not in data.columns:
|
67 |
-
data['embeddings'] = data['combined_description'].apply(lambda x: retriever_model.encode(x).tolist() if x else [])
|
68 |
-
|
69 |
-
# Function to retrieve relevant information based on user query (non-weighted)
|
70 |
-
def get_relevant_info(query, top_k=3):
|
71 |
-
query_embedding = retriever_model.encode(query)
|
72 |
-
similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
|
73 |
-
top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
|
74 |
-
return data.iloc[top_indices]
|
75 |
-
|
76 |
-
# Enhanced response generation function with debugging
|
77 |
-
# Import required libraries
|
78 |
-
import os
|
79 |
-
import pandas as pd
|
80 |
-
import streamlit as st
|
81 |
-
from transformers import pipeline
|
82 |
-
from sentence_transformers import SentenceTransformer, util
|
83 |
-
import requests
|
84 |
-
import json
|
85 |
-
|
86 |
-
# Configure Hugging Face API token securely (ensure it's set in environment variables)
|
87 |
-
api_key = os.getenv("HF_API_KEY")
|
88 |
-
|
89 |
-
# Load the CSV dataset (place the CSV in the same directory as app.py in Hugging Face Spaces)
|
90 |
-
try:
|
91 |
-
data = pd.read_csv('genetic-Final.csv') # Ensure the dataset filename is correct
|
92 |
-
except FileNotFoundError:
|
93 |
-
st.error("Dataset file not found. Please upload it to this directory.")
|
94 |
-
|
95 |
# Initialize Sentence Transformer model for RAG-based retrieval
|
96 |
retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
97 |
|
98 |
-
#
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
data['Minimum Values for Medical Tests'].fillna('') + " " +
|
107 |
-
data['Emergency Treatment'].fillna('')
|
108 |
-
)
|
109 |
-
|
110 |
-
# Define weights for each column based on importance
|
111 |
-
column_weights = {
|
112 |
-
'Symptoms': 0.4,
|
113 |
-
'Severity Level': 0.2,
|
114 |
-
'Risk Assessment': 0.1,
|
115 |
-
'Treatment Options': 0.15,
|
116 |
-
'Suggested Medical Tests': 0.05,
|
117 |
-
'Minimum Values for Medical Tests': 0.05,
|
118 |
-
'Emergency Treatment': 0.05
|
119 |
-
}
|
120 |
-
|
121 |
-
# Precompute embeddings for each weighted column
|
122 |
-
for col in column_weights.keys():
|
123 |
-
if f"{col}_embeddings" not in data.columns:
|
124 |
-
data[f"{col}_embeddings"] = data[col].fillna("").apply(lambda x: retriever_model.encode(x).tolist())
|
125 |
-
|
126 |
-
# Function to retrieve relevant information with weighted scoring
|
127 |
-
def get_weighted_relevant_info(query, top_k=3):
|
128 |
-
query_embedding = retriever_model.encode(query)
|
129 |
-
weighted_similarities = []
|
130 |
-
for idx, row in data.iterrows():
|
131 |
-
weighted_score = 0
|
132 |
-
for col, weight in column_weights.items():
|
133 |
-
if row[f"{col}_embeddings"]:
|
134 |
-
col_similarity = util.cos_sim(query_embedding, row[f"{col}_embeddings"])[0][0].item()
|
135 |
-
weighted_score += col_similarity * weight
|
136 |
-
weighted_similarities.append(weighted_score)
|
137 |
-
|
138 |
-
top_indices = sorted(range(len(weighted_similarities)), key=lambda i: weighted_similarities[i], reverse=True)[:top_k]
|
139 |
-
return data.iloc[top_indices]
|
140 |
|
141 |
-
# Generate embeddings for the combined description
|
142 |
if 'embeddings' not in data.columns:
|
143 |
-
data['embeddings'] = data['combined_description'].apply(
|
144 |
|
145 |
-
# Function to retrieve relevant information based on user query
|
146 |
def get_relevant_info(query, top_k=3):
|
147 |
query_embedding = retriever_model.encode(query)
|
148 |
similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
|
149 |
top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
|
150 |
return data.iloc[top_indices]
|
151 |
|
152 |
-
#
|
153 |
def generate_response(input_text, relevant_info):
|
|
|
154 |
context = "\n".join(relevant_info['combined_description'].tolist())
|
155 |
input_with_context = f"Context: {context}\n\nUser Query: {input_text}"
|
156 |
-
|
157 |
-
api_url = "https://api-inference.huggingface.co/models/m42-health/Llama3-Med42-8B"
|
158 |
-
headers = {"Authorization": f"Bearer {api_key}"}
|
159 |
-
payload = {"inputs": input_with_context}
|
160 |
-
|
161 |
-
try:
|
162 |
-
response = requests.post(api_url, headers=headers, json=payload)
|
163 |
-
st.write("API Raw Response:", response.text) # Display raw response for debugging
|
164 |
-
|
165 |
-
# Check response status
|
166 |
-
if response.status_code != 200:
|
167 |
-
return f"Error: API responded with status code {response.status_code}. Full response: {response.json()}"
|
168 |
-
|
169 |
-
# Parse and validate response
|
170 |
-
response_data = response.json()
|
171 |
-
if isinstance(response_data, list) and "generated_text" in response_data[0]:
|
172 |
-
return response_data[0]["generated_text"]
|
173 |
-
else:
|
174 |
-
return f"Unexpected response format from API. Full response: {response_data}"
|
175 |
-
except Exception as e:
|
176 |
-
return f"Error during API request: {e}"
|
177 |
-
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
# Streamlit UI for the Chatbot
|
181 |
def main():
|
@@ -190,22 +83,24 @@ def main():
|
|
190 |
|
191 |
# Process the query if provided
|
192 |
if user_query:
|
193 |
-
st.write("###
|
194 |
|
195 |
-
# Retrieve relevant information from
|
196 |
-
relevant_info =
|
|
|
197 |
for i, row in relevant_info.iterrows():
|
198 |
-
st.write(f"- {row['combined_description']}")
|
199 |
|
200 |
-
# Generate a response from
|
201 |
response = generate_response(user_query, relevant_info)
|
202 |
st.write("#### Model's Response:")
|
203 |
st.write(response)
|
204 |
|
205 |
# Process the uploaded file (if any)
|
206 |
if uploaded_file:
|
|
|
207 |
st.write("### Uploaded Report Analysis:")
|
208 |
-
report_text = "Extracted report content here" # Placeholder for file processing
|
209 |
st.write(report_text)
|
210 |
|
211 |
if __name__ == "__main__":
|
|
|
2 |
import os
|
3 |
import pandas as pd
|
4 |
import streamlit as st
|
5 |
+
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
6 |
from transformers import pipeline
|
7 |
from sentence_transformers import SentenceTransformer, util
|
8 |
import requests
|
9 |
import json
|
10 |
|
11 |
+
# Configure Hugging Face API token securely
|
12 |
api_key = os.getenv("HF_API_KEY")
|
13 |
|
14 |
+
# Load the CSV dataset
|
15 |
try:
|
16 |
data = pd.read_csv('genetic-Final.csv') # Ensure the dataset filename is correct
|
17 |
except FileNotFoundError:
|
18 |
st.error("Dataset file not found. Please upload it to this directory.")
|
19 |
|
20 |
+
# Load DistilBERT Tokenizer and Model
|
21 |
+
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
|
22 |
+
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
|
23 |
|
24 |
+
# Preprocessing the dataset (if needed)
|
25 |
if 'combined_description' not in data.columns:
|
26 |
data['combined_description'] = (
|
27 |
data['Symptoms'].fillna('') + " " +
|
|
|
33 |
data['Emergency Treatment'].fillna('')
|
34 |
)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
# Initialize Sentence Transformer model for RAG-based retrieval
|
37 |
retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
38 |
|
39 |
+
# Define a function to get embeddings using DistilBERT
|
40 |
+
def generate_embedding(description):
|
41 |
+
if description:
|
42 |
+
inputs = tokenizer(description, return_tensors='pt', truncation=True, padding=True, max_length=512)
|
43 |
+
outputs = model(**inputs)
|
44 |
+
return outputs.logits.detach().numpy().flatten()
|
45 |
+
else:
|
46 |
+
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
+
# Generate embeddings for the combined description
|
49 |
if 'embeddings' not in data.columns:
|
50 |
+
data['embeddings'] = data['combined_description'].apply(generate_embedding)
|
51 |
|
52 |
+
# Function to retrieve relevant information based on user query
|
53 |
def get_relevant_info(query, top_k=3):
|
54 |
query_embedding = retriever_model.encode(query)
|
55 |
similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
|
56 |
top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
|
57 |
return data.iloc[top_indices]
|
58 |
|
59 |
+
# Function to generate response using DistilBERT (integrating with the model)
|
60 |
def generate_response(input_text, relevant_info):
|
61 |
+
# Concatenate the relevant information as context for the model
|
62 |
context = "\n".join(relevant_info['combined_description'].tolist())
|
63 |
input_with_context = f"Context: {context}\n\nUser Query: {input_text}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
# Simple logic for generating a response using DistilBERT-based model
|
66 |
+
inputs = tokenizer(input_with_context, return_tensors='pt', truncation=True, padding=True, max_length=512)
|
67 |
+
outputs = model(**inputs)
|
68 |
+
logits = outputs.logits.detach().numpy().flatten()
|
69 |
+
response = tokenizer.decode(logits.argmax(), skip_special_tokens=True)
|
70 |
+
|
71 |
+
return response
|
72 |
|
73 |
# Streamlit UI for the Chatbot
|
74 |
def main():
|
|
|
83 |
|
84 |
# Process the query if provided
|
85 |
if user_query:
|
86 |
+
st.write("### Query Response:")
|
87 |
|
88 |
+
# Retrieve relevant information from dataset
|
89 |
+
relevant_info = get_relevant_info(user_query)
|
90 |
+
st.write("#### Relevant Medical Information:")
|
91 |
for i, row in relevant_info.iterrows():
|
92 |
+
st.write(f"- {row['combined_description']}") # Adjust to show meaningful info
|
93 |
|
94 |
+
# Generate a response from DistilBERT model
|
95 |
response = generate_response(user_query, relevant_info)
|
96 |
st.write("#### Model's Response:")
|
97 |
st.write(response)
|
98 |
|
99 |
# Process the uploaded file (if any)
|
100 |
if uploaded_file:
|
101 |
+
# Display analysis of the uploaded report file (process based on file type)
|
102 |
st.write("### Uploaded Report Analysis:")
|
103 |
+
report_text = "Extracted report content here" # Placeholder for file processing logic
|
104 |
st.write(report_text)
|
105 |
|
106 |
if __name__ == "__main__":
|