# Streamlit demo: chest X-ray classification with optional BioGPT text generation.
#
# An uploaded image is classified by the model stored at
# models/timm_xray_model.pth; label names come from labels.json. Clicking
# "Know More" feeds the predicted label into a BioGPT text-generation pipeline
# and renders the sampled text.
#
# Streamlit re-executes this whole script on every widget interaction, so the
# heavyweight model loads are wrapped in st.cache_resource and run once per
# server process instead of once per rerun.
import base64
import json
import os
import time

import captum
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import streamlit as st
import timm
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import torchmetrics
import torchvision
import torchvision.models as models
from lightning.pytorch.loggers import WandbLogger
from PIL import Image
from torchmetrics import Accuracy, F1Score, Recall
from transformers import BioGptForCausalLM, BioGptTokenizer, pipeline, set_seed

XRAY_MODEL_PATH = 'models/timm_xray_model.pth'


@st.cache_resource
def _load_biogpt():
    """Load the BioGPT model and tokenizer once per Streamlit server process."""
    return (
        BioGptForCausalLM.from_pretrained("microsoft/biogpt"),
        BioGptTokenizer.from_pretrained("microsoft/biogpt"),
    )


text_model, tokenizer = _load_biogpt()

labels_path = 'labels.json'
with open(labels_path, encoding="utf-8") as json_data:
    # Keys are looked up with str(class_index) below, so they are presumably
    # stringified indices such as "0" -- confirm against labels.json.
    idx_to_labels = json.load(json_data)


class FineTuneModel(pl.LightningModule):
    """Lightning wrapper around a timm classifier with Adam + StepLR training.

    The app itself never trains. The class is presumably defined here so that
    ``torch.load`` can unpickle a checkpoint saved as a ``FineTuneModel``
    (TODO confirm against how the .pth file was produced).
    """

    def __init__(self, model_name, num_classes, learning_rate, dropout_rate, beta1, beta2, eps):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        # Stored but not currently applied: the timm model's own head is used.
        self.dropout_rate = dropout_rate
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=self.num_classes)
        self.loss_fn = nn.CrossEntropyLoss()
        self.f1 = F1Score(task='multiclass', num_classes=self.num_classes)
        self.recall = Recall(task='multiclass', num_classes=self.num_classes)
        self.accuracy = Accuracy(task='multiclass', num_classes=self.num_classes)

    def forward(self, x):
        return self.model(x)

    def _step_and_log(self, batch, stage):
        """Run one batch, log epoch-level loss/acc/f1/recall under ``{stage}_*``, return the loss."""
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        preds = y_hat.argmax(dim=1)
        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True)
        self.log(f'{stage}_acc', self.accuracy(preds, y), on_step=False, on_epoch=True)
        self.log(f'{stage}_f1', self.f1(preds, y), on_step=False, on_epoch=True)
        self.log(f'{stage}_recall', self.recall(preds, y), on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step_and_log(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self._step_and_log(batch, 'val')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            betas=(self.beta1, self.beta2),
            eps=self.eps,
        )
        # Decay the learning rate 10x every 10 epochs.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}


@st.cache_resource
def _load_xray_model():
    """Load the fine-tuned X-ray classifier once per process, in eval mode."""
    # map_location="cpu": the input tensor built below lives on CPU, so the
    # model must too, even if the checkpoint was saved from a GPU session.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    model = torch.load(XRAY_MODEL_PATH, map_location="cpu")
    model.eval()
    return model


@st.cache_resource
def _build_transform():
    """Build timm's eval preprocessing for efficientnet_b2 (resize, crop, normalize)."""
    # Only the architecture's pretrained_cfg (input size, mean/std,
    # interpolation) is read, so there is no need to download weights.
    reference = timm.create_model(model_name='efficientnet_b2', pretrained=False, num_classes=4)
    data_cfg = timm.data.resolve_data_config(reference.pretrained_cfg)
    return timm.data.create_transform(**data_cfg)


@st.cache_resource
def _get_text_generator():
    """Wrap the cached BioGPT model/tokenizer in a text-generation pipeline."""
    model, tok = _load_biogpt()
    return pipeline("text-generation", model=model, tokenizer=tok)


# Read the logo from the current working directory and base64-encode it.
current_dir = os.getcwd()
logo_path = os.path.join(current_dir, "logo.png")
with open(logo_path, "rb") as f:
    image_data = f.read()
image_base64 = base64.b64encode(image_data).decode("utf-8")

# NOTE(review): the CSS/HTML markup in these strings appears to have been
# stripped from this copy of the file (image_base64 is no longer referenced);
# only the visible text is preserved -- restore the original markup if available.
header_css = """ """
st.markdown(header_css, unsafe_allow_html=True)

header_html = """
Logo

Disclaimer: This web app is for demonstration purposes only and not intended for commercial use. Contact: contact@1001epochs.co.uk for full solution.

"""
st.markdown(header_html, unsafe_allow_html=True)
st.markdown("Chest Xray Diagnosis", unsafe_allow_html=True)

uploaded_file = st.file_uploader("Choose a Chest X-Ray image file", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Diagnosis', width=224, use_column_width=True)

    # X-rays are frequently single-channel ("L") or RGBA PNGs, while the timm
    # transform normalizes with 3-channel mean/std -- force RGB first.
    transformed_image = _build_transform()(image.convert("RGB"))
    xray_model = _load_xray_model()

    with torch.inference_mode(), st.spinner("Running diagnosis"):
        logits = xray_model(transformed_image.unsqueeze(dim=0))[0]
        prediction = torch.nn.functional.softmax(logits, dim=0)
        prediction_score, pred_label_idx = torch.topk(prediction, 1)
    predicted_label = idx_to_labels[str(pred_label_idx.item())]
    st.write(f'Predicted Label: {predicted_label}')

    if st.button('Know More'):
        generator = _get_text_generator()
        input_text = f"Patient has {predicted_label} and is advised to take the following medicines:"
        # Generate once, inside the spinner, and display that same sample.
        with st.spinner('Generating Text'):
            outputs = generator(
                input_text,
                max_length=300,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1,
            )
        st.markdown(outputs[0]['generated_text'])
else:
    st.success("Please upload an image file ⚕️")