import os
import base64
import json
import time

import streamlit as st
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.models as models
import timm
import torchmetrics
from torchmetrics import F1Score, Recall, Accuracy
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
import captum
from transformers import pipeline, set_seed
from transformers import BioGptTokenizer, BioGptForCausalLM

text_model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")

labels_path = 'labels.json'
with open(labels_path) as json_data:
    idx_to_labels = json.load(json_data)
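# Note: this module-level setup re-runs on every Streamlit rerun of the page. Below is a
# minimal sketch of one way to build the text-generation stack only once per process,
# assuming a Streamlit version that provides st.cache_resource (>= 1.18). It is not wired
# into the code further down and the function name is illustrative only:
#
# @st.cache_resource
# def load_text_generator():
#     model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
#     tok = BioGptTokenizer.from_pretrained("microsoft/biogpt")
#     return pipeline("text-generation", model=model, tokenizer=tok)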
class FineTuneModel(pl.LightningModule):
    def __init__(self, model_name, num_classes, learning_rate, dropout_rate, beta1, beta2, eps):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.dropout_rate = dropout_rate
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=self.num_classes)
        self.loss_fn = nn.CrossEntropyLoss()
        self.f1 = F1Score(task='multiclass', num_classes=self.num_classes)
        self.recall = Recall(task='multiclass', num_classes=self.num_classes)
        self.accuracy = Accuracy(task='multiclass', num_classes=self.num_classes)
        # for param in self.model.parameters():
        #     param.requires_grad = True
        # self.model.classifier = nn.Sequential(nn.Dropout(p=self.dropout_rate), nn.Linear(self.model.classifier.in_features, self.num_classes))
        # self.model.classifier.requires_grad = True

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        acc = self.accuracy(y_hat.argmax(dim=1), y)
        f1 = self.f1(y_hat.argmax(dim=1), y)
        recall = self.recall(y_hat.argmax(dim=1), y)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True)
        self.log('train_f1', f1, on_step=False, on_epoch=True)
        self.log('train_recall', recall, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        acc = self.accuracy(y_hat.argmax(dim=1), y)
        f1 = self.f1(y_hat.argmax(dim=1), y)
        recall = self.recall(y_hat.argmax(dim=1), y)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True)
        self.log('val_f1', f1, on_step=False, on_epoch=True)
        self.log('val_recall', recall, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, betas=(self.beta1, self.beta2), eps=self.eps)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
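# The LightningModule above mirrors the training-time model definition; keeping it in this
# file presumably lets torch.load() further down unpickle the saved checkpoint, which
# references the FineTuneModel class. The section below builds the Streamlit page: a
# base64-embedded logo header, a disclaimer, and the upload/diagnosis UI.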
# Get the current working directory
current_dir = os.getcwd()
# Construct the absolute path to the logo.png file
logo_path = os.path.join(current_dir, "logo.png")
with open(logo_path, "rb") as f:
    image_data = f.read()
image_base64 = base64.b64encode(image_data).decode("utf-8")
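# Embedding the logo as a base64 data URI lets the raw-HTML header below display a local
# image through st.markdown, which cannot reference arbitrary files on disk directly.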
# Add custom CSS for the header
header_css = """
<style>
    .header {
        display: flex;
        align-items: center;
        justify-content: center;
        padding: 20px;
        font-size: 18px;
    }
    .header img {
        margin-right: 10px;
        width: 80px;
        height: 50px;
    }
    .header p {
        font-size: 14px;
    }
</style>
"""
# Render the custom CSS
st.markdown(header_css, unsafe_allow_html=True)
# Render the header
header_html = f"""
<div class="header">
    <img src="data:image/png;base64,{image_base64}" alt="Logo"/>
    <p>Disclaimer: This web app is for demonstration purposes only and is not intended for commercial use. Contact contact@1001epochs.co.uk for the full solution.</p>
</div>
"""
st.markdown(header_html, unsafe_allow_html=True)

st.markdown("<h1 style='text-align: center;'>Chest X-ray Diagnosis</h1>", unsafe_allow_html=True)
# Display a file uploader widget for the user to upload an image
uploaded_file = st.file_uploader("Choose a chest X-ray image file", type=["jpg", "jpeg", "png"])
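# The block below resolves efficientnet_b2's pretrained preprocessing config via timm
# (presumably the same transform used when the classifier was fine-tuned), runs the saved
# model on the uploaded image, and maps the top-1 class index back to a name via
# labels.json, which is expected to map stringified indices to label names
# (e.g. {"0": "<label>", ...} — the actual names depend on the training setup).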
# Load the uploaded image and run the diagnosis, or prompt the user to upload a file
if uploaded_file is not None:
    # Chest X-rays are often single-channel; convert to RGB so the ImageNet-style transform gets 3 channels
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded chest X-ray', use_column_width=True)
    # Build the preprocessing pipeline that efficientnet_b2 expects
    model = timm.create_model(model_name='efficientnet_b2', pretrained=True, num_classes=4)
    data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
    transform = timm.data.create_transform(**data_cfg)
    model_transforms = torchvision.transforms.Compose([transform])
    transformed_image = model_transforms(image)
    # Load the fine-tuned classifier; map_location='cpu' so a GPU-saved checkpoint also loads on CPU-only hosts
    xray_model = torch.load('models/timm_xray_model.pth', map_location='cpu')
    xray_model.eval()
    with torch.inference_mode():
        with st.spinner('Running diagnosis...'):
            prediction = torch.nn.functional.softmax(xray_model(transformed_image.unsqueeze(dim=0))[0], dim=0)
            prediction_score, pred_label_idx = torch.topk(prediction, 1)
    pred_label_idx.squeeze_()
    predicted_label = idx_to_labels[str(pred_label_idx.item())]
    st.write(f'Predicted Label: {predicted_label}')
    if st.button('Know More'):
        generator = pipeline("text-generation", model=text_model, tokenizer=tokenizer)
        input_text = f"Patient has {predicted_label} and is advised to take the following medicines:"
        with st.spinner('Generating Text'):
            # Generate once and reuse the result (the original called the pipeline twice)
            generated = generator(input_text, max_length=300, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)
        st.markdown(generated[0]['generated_text'])
else:
    st.success("Please upload an image file ⚕️")
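# For reference, a minimal sketch of how a checkpoint like models/timm_xray_model.pth could
# be produced with the FineTuneModel above. Hyperparameters and the DataLoaders are
# illustrative assumptions, not the values used for the deployed model:
#
# model = FineTuneModel(model_name='efficientnet_b2', num_classes=4, learning_rate=1e-4,
#                       dropout_rate=0.3, beta1=0.9, beta2=0.999, eps=1e-8)
# trainer = pl.Trainer(max_epochs=10)
# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# torch.save(model, 'models/timm_xray_model.pth')  # saving the full object matches the torch.load() call above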