asadAbdullah committed on
Commit
6feb2e4
1 Parent(s): 7e5c1c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -140
app.py CHANGED
@@ -2,24 +2,26 @@
2
  import os
3
  import pandas as pd
4
  import streamlit as st
 
5
  from transformers import pipeline
6
  from sentence_transformers import SentenceTransformer, util
7
  import requests
8
  import json
9
 
10
- # Configure Hugging Face API token securely (ensure it's set in environment variables)
11
  api_key = os.getenv("HF_API_KEY")
12
 
13
- # Load the CSV dataset (place the CSV in the same directory as app.py in Hugging Face Spaces)
14
  try:
15
  data = pd.read_csv('genetic-Final.csv') # Ensure the dataset filename is correct
16
  except FileNotFoundError:
17
  st.error("Dataset file not found. Please upload it to this directory.")
18
 
19
- # Initialize Sentence Transformer model for RAG-based retrieval
20
- retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
21
 
22
- # Preprocess the dataset by creating a combined description column
23
  if 'combined_description' not in data.columns:
24
  data['combined_description'] = (
25
  data['Symptoms'].fillna('') + " " +
@@ -31,151 +33,42 @@ if 'combined_description' not in data.columns:
31
  data['Emergency Treatment'].fillna('')
32
  )
33
 
34
- # Define weights for each column based on importance
35
- column_weights = {
36
- 'Symptoms': 0.4,
37
- 'Severity Level': 0.2,
38
- 'Risk Assessment': 0.1,
39
- 'Treatment Options': 0.15,
40
- 'Suggested Medical Tests': 0.05,
41
- 'Minimum Values for Medical Tests': 0.05,
42
- 'Emergency Treatment': 0.05
43
- }
44
-
45
- # Precompute embeddings for each weighted column
46
- for col in column_weights.keys():
47
- if f"{col}_embeddings" not in data.columns:
48
- data[f"{col}_embeddings"] = data[col].fillna("").apply(lambda x: retriever_model.encode(x).tolist())
49
-
50
- # Function to retrieve relevant information with weighted scoring
51
- def get_weighted_relevant_info(query, top_k=3):
52
- query_embedding = retriever_model.encode(query)
53
- weighted_similarities = []
54
- for idx, row in data.iterrows():
55
- weighted_score = 0
56
- for col, weight in column_weights.items():
57
- if row[f"{col}_embeddings"]:
58
- col_similarity = util.cos_sim(query_embedding, row[f"{col}_embeddings"])[0][0].item()
59
- weighted_score += col_similarity * weight
60
- weighted_similarities.append(weighted_score)
61
-
62
- top_indices = sorted(range(len(weighted_similarities)), key=lambda i: weighted_similarities[i], reverse=True)[:top_k]
63
- return data.iloc[top_indices]
64
-
65
- # Generate embeddings for the combined description if not already done
66
- if 'embeddings' not in data.columns:
67
- data['embeddings'] = data['combined_description'].apply(lambda x: retriever_model.encode(x).tolist() if x else [])
68
-
69
- # Function to retrieve relevant information based on user query (non-weighted)
70
- def get_relevant_info(query, top_k=3):
71
- query_embedding = retriever_model.encode(query)
72
- similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
73
- top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
74
- return data.iloc[top_indices]
75
-
76
- # Enhanced response generation function with debugging
77
- # Import required libraries
78
- import os
79
- import pandas as pd
80
- import streamlit as st
81
- from transformers import pipeline
82
- from sentence_transformers import SentenceTransformer, util
83
- import requests
84
- import json
85
-
86
- # Configure Hugging Face API token securely (ensure it's set in environment variables)
87
- api_key = os.getenv("HF_API_KEY")
88
-
89
- # Load the CSV dataset (place the CSV in the same directory as app.py in Hugging Face Spaces)
90
- try:
91
- data = pd.read_csv('genetic-Final.csv') # Ensure the dataset filename is correct
92
- except FileNotFoundError:
93
- st.error("Dataset file not found. Please upload it to this directory.")
94
-
95
  # Initialize Sentence Transformer model for RAG-based retrieval
96
  retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
97
 
98
- # Preprocess the dataset by creating a combined description column
99
- if 'combined_description' not in data.columns:
100
- data['combined_description'] = (
101
- data['Symptoms'].fillna('') + " " +
102
- data['Severity Level'].fillna('') + " " +
103
- data['Risk Assessment'].fillna('') + " " +
104
- data['Treatment Options'].fillna('') + " " +
105
- data['Suggested Medical Tests'].fillna('') + " " +
106
- data['Minimum Values for Medical Tests'].fillna('') + " " +
107
- data['Emergency Treatment'].fillna('')
108
- )
109
-
110
- # Define weights for each column based on importance
111
- column_weights = {
112
- 'Symptoms': 0.4,
113
- 'Severity Level': 0.2,
114
- 'Risk Assessment': 0.1,
115
- 'Treatment Options': 0.15,
116
- 'Suggested Medical Tests': 0.05,
117
- 'Minimum Values for Medical Tests': 0.05,
118
- 'Emergency Treatment': 0.05
119
- }
120
-
121
- # Precompute embeddings for each weighted column
122
- for col in column_weights.keys():
123
- if f"{col}_embeddings" not in data.columns:
124
- data[f"{col}_embeddings"] = data[col].fillna("").apply(lambda x: retriever_model.encode(x).tolist())
125
-
126
- # Function to retrieve relevant information with weighted scoring
127
- def get_weighted_relevant_info(query, top_k=3):
128
- query_embedding = retriever_model.encode(query)
129
- weighted_similarities = []
130
- for idx, row in data.iterrows():
131
- weighted_score = 0
132
- for col, weight in column_weights.items():
133
- if row[f"{col}_embeddings"]:
134
- col_similarity = util.cos_sim(query_embedding, row[f"{col}_embeddings"])[0][0].item()
135
- weighted_score += col_similarity * weight
136
- weighted_similarities.append(weighted_score)
137
-
138
- top_indices = sorted(range(len(weighted_similarities)), key=lambda i: weighted_similarities[i], reverse=True)[:top_k]
139
- return data.iloc[top_indices]
140
 
141
- # Generate embeddings for the combined description if not already done
142
  if 'embeddings' not in data.columns:
143
- data['embeddings'] = data['combined_description'].apply(lambda x: retriever_model.encode(x).tolist() if x else [])
144
 
145
- # Function to retrieve relevant information based on user query (non-weighted)
146
  def get_relevant_info(query, top_k=3):
147
  query_embedding = retriever_model.encode(query)
148
  similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
149
  top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
150
  return data.iloc[top_indices]
151
 
152
- # Enhanced response generation function with debugging
153
  def generate_response(input_text, relevant_info):
 
154
  context = "\n".join(relevant_info['combined_description'].tolist())
155
  input_with_context = f"Context: {context}\n\nUser Query: {input_text}"
156
-
157
- api_url = "https://api-inference.huggingface.co/models/m42-health/Llama3-Med42-8B"
158
- headers = {"Authorization": f"Bearer {api_key}"}
159
- payload = {"inputs": input_with_context}
160
-
161
- try:
162
- response = requests.post(api_url, headers=headers, json=payload)
163
- st.write("API Raw Response:", response.text) # Display raw response for debugging
164
-
165
- # Check response status
166
- if response.status_code != 200:
167
- return f"Error: API responded with status code {response.status_code}. Full response: {response.json()}"
168
-
169
- # Parse and validate response
170
- response_data = response.json()
171
- if isinstance(response_data, list) and "generated_text" in response_data[0]:
172
- return response_data[0]["generated_text"]
173
- else:
174
- return f"Unexpected response format from API. Full response: {response_data}"
175
- except Exception as e:
176
- return f"Error during API request: {e}"
177
-
178
 
 
 
 
 
 
 
 
179
 
180
  # Streamlit UI for the Chatbot
181
  def main():
@@ -190,22 +83,24 @@ def main():
190
 
191
  # Process the query if provided
192
  if user_query:
193
- st.write("### FAQ and Responses:")
194
 
195
- # Retrieve relevant information from the dataset
196
- relevant_info = get_weighted_relevant_info(user_query)
 
197
  for i, row in relevant_info.iterrows():
198
- st.write(f"- {row['combined_description']}")
199
 
200
- # Generate a response from the model
201
  response = generate_response(user_query, relevant_info)
202
  st.write("#### Model's Response:")
203
  st.write(response)
204
 
205
  # Process the uploaded file (if any)
206
  if uploaded_file:
 
207
  st.write("### Uploaded Report Analysis:")
208
- report_text = "Extracted report content here" # Placeholder for file processing
209
  st.write(report_text)
210
 
211
  if __name__ == "__main__":
 
2
  import os
3
  import pandas as pd
4
  import streamlit as st
5
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
6
  from transformers import pipeline
7
  from sentence_transformers import SentenceTransformer, util
8
  import requests
9
  import json
10
 
11
+ # Configure Hugging Face API token securely
12
  api_key = os.getenv("HF_API_KEY")
13
 
14
# Load the CSV dataset
try:
    data = pd.read_csv('genetic-Final.csv')  # Ensure the dataset filename is correct
except FileNotFoundError:
    st.error("Dataset file not found. Please upload it to this directory.")
    # Without the CSV, `data` is never bound and the statements that follow
    # (which read `data.columns`) would raise NameError. Halt this script
    # run so the user only sees the error message above.
    st.stop()
19
 
20
+ # Load DistilBERT Tokenizer and Model
21
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
22
+ model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
23
 
24
+ # Preprocessing the dataset (if needed)
25
  if 'combined_description' not in data.columns:
26
  data['combined_description'] = (
27
  data['Symptoms'].fillna('') + " " +
 
33
  data['Emergency Treatment'].fillna('')
34
  )
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Initialize Sentence Transformer model for RAG-based retrieval
37
  retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
38
 
39
# Define a function to embed a description with the same encoder used for queries
def generate_embedding(description):
    """Return the sentence embedding of one dataset description.

    The vector comes from ``retriever_model``, the encoder that
    ``get_relevant_info`` also applies to the user query, so both sides of
    the cosine similarity share one vector space. The earlier version
    returned the logits of a DistilBERT sequence-classification head, which
    are class scores rather than text embeddings and do not match the size
    of the query vector.

    Args:
        description: Text to embed. Anything that is not a string (for
            example a missing value) is embedded as the empty string, so
            every row gets a vector that can be compared with the query.

    Returns:
        The embedding produced by ``retriever_model.encode`` for the text.
    """
    text = description if isinstance(description, str) else ""
    return retriever_model.encode(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ # Generate embeddings for the combined description
49
  if 'embeddings' not in data.columns:
50
+ data['embeddings'] = data['combined_description'].apply(generate_embedding)
51
 
52
# Function to retrieve relevant information based on user query
def get_relevant_info(query, top_k=3):
    """Return the ``top_k`` dataset rows most similar to ``query``.

    The query is encoded with ``retriever_model`` and compared, by cosine
    similarity, against every entry of ``data['embeddings']``. Rows come
    back ordered from highest to lowest similarity; ties keep their
    original dataset order.
    """
    query_vector = retriever_model.encode(query)

    # One similarity score per dataset row, in dataset order.
    scores = []
    for doc_vector in data['embeddings']:
        scores.append(util.cos_sim(query_vector, doc_vector)[0][0].item())

    # Stable descending sort on the score, carrying the row position along.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    best_positions = [position for position, _ in ranked[:top_k]]
    return data.iloc[best_positions]
58
 
59
# Function to build a response from the retrieved dataset records
def generate_response(input_text, relevant_info):
    """Build a textual answer for ``input_text`` from the retrieved rows.

    The earlier version ran the query through a DistilBERT
    sequence-classification model and passed the argmax of its logits (a
    class index) to ``tokenizer.decode`` as if it were a token id, so the
    "response" was an unrelated vocabulary entry or an empty string. A
    classification head cannot generate text, so this version answers
    extractively from the retrieved records instead.

    Args:
        input_text: The user's query; echoed in the answer for context.
        relevant_info: DataFrame with a ``combined_description`` column.
            The first usable row is treated as the best match
            (``get_relevant_info`` returns rows ordered by descending
            similarity).

    Returns:
        A string containing the first non-blank description, or a fixed
        message when no usable description is available.
    """
    # Keep only real, non-blank text; missing values are skipped.
    descriptions = [
        text.strip()
        for text in relevant_info['combined_description'].tolist()
        if isinstance(text, str) and text.strip()
    ]
    if not descriptions:
        return "No relevant information was found for your query."

    return f"Most relevant information for \"{input_text}\":\n\n{descriptions[0]}"
72
 
73
  # Streamlit UI for the Chatbot
74
  def main():
 
83
 
84
  # Process the query if provided
85
  if user_query:
86
+ st.write("### Query Response:")
87
 
88
+ # Retrieve relevant information from dataset
89
+ relevant_info = get_relevant_info(user_query)
90
+ st.write("#### Relevant Medical Information:")
91
  for i, row in relevant_info.iterrows():
92
+ st.write(f"- {row['combined_description']}") # Adjust to show meaningful info
93
 
94
+ # Generate a response from DistilBERT model
95
  response = generate_response(user_query, relevant_info)
96
  st.write("#### Model's Response:")
97
  st.write(response)
98
 
99
  # Process the uploaded file (if any)
100
  if uploaded_file:
101
+ # Display analysis of the uploaded report file (process based on file type)
102
  st.write("### Uploaded Report Analysis:")
103
+ report_text = "Extracted report content here" # Placeholder for file processing logic
104
  st.write(report_text)
105
 
106
  if __name__ == "__main__":