"""Shiny app that overlays a fine-tuned SAM segmentation mask on an uploaded image."""

import functools
import json
import os
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
from transformers import SamModel

# Preprocessing fed to the model: resize to 1024x1024 and convert to a float
# tensor scaled to [0, 1].
# NOTE(review): no mean/std normalisation is applied here; confirm this matches
# the preprocessing used when model.pth was trained.
image_resize_transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
])

# Side length (pixels) of the original/overlay images written out for display.
DISPLAY_SIZE = 1024

app_ui = ui.page_fluid(
    ui.input_file(
        "file2",
        "Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    ui.output_image("original_image"),
    ui.output_image("image_display"),
)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the fine-tuned SAM model once and return ``(model, device)``.

    Cached so the base weights are fetched and ``model.pth`` is read on the
    first upload only, rather than on every upload.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    model.load_state_dict(torch.load("model.pth", map_location=device))
    model.eval()
    model.to(device)
    return model, device


def _save_temp_png(img):
    """Write *img* to a new temporary PNG file and return its path.

    The file handle is closed before returning. The file itself is kept
    (``delete=False``) so the render functions can read it; callers are
    responsible for deleting it.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        img.save(tmp, "PNG")
    return tmp.name


def server(input: Inputs, output: Outputs, session: Session):
    """Connect the upload input to the model and the two image outputs."""
    # Temporary PNGs created for this session; deleted when the session ends.
    temp_paths: list[str] = []

    def _cleanup_temp_files():
        for path in temp_paths:
            Path(path).unlink(missing_ok=True)
        temp_paths.clear()

    session.on_ended(_cleanup_temp_files)

    @reactive.calc
    def loaded_image():
        """Segment the uploaded image.

        Returns ``(original_path, overlay_path)``. Both are temporary PNG
        paths, resized to DISPLAY_SIZE x DISPLAY_SIZE. The overlay path holds
        the original with the predicted mask alpha-composited on top.
        Returns None while no file has been uploaded.
        """
        file: list[FileInfo] | None = input.file2()
        if file is None:
            return None

        model, device = _load_model()

        # Open the upload once and close the underlying file promptly.
        with Image.open(file[0]["datapath"]) as uploaded:
            image = uploaded.convert("RGB")

        image_tensor = image_resize_transform(image).to(device)
        with torch.no_grad():
            outputs = model(pixel_values=image_tensor.unsqueeze(0), multimask_output=False)
            # pred_masks holds raw logits. Squash them into [0, 1] before
            # scaling to 8-bit; otherwise values outside [0, 1] overflow the
            # uint8 cast below.
            probabilities = torch.sigmoid(outputs.pred_masks.squeeze(1))

        mask_array = probabilities[:, 0, :, :].cpu().squeeze().numpy()
        mask_array = (mask_array * 255).astype(np.uint8)

        mask = Image.fromarray(mask_array)
        mask = mask.resize((DISPLAY_SIZE, DISPLAY_SIZE), Image.LANCZOS).convert("RGBA")
        # Constant ~50% alpha so the mask tints the image instead of hiding it.
        mask.putalpha(Image.new("L", mask.size, 128))

        display_image = image.resize((DISPLAY_SIZE, DISPLAY_SIZE), Image.LANCZOS).convert("RGBA")
        combined = Image.alpha_composite(display_image, mask)

        original_path = _save_temp_png(display_image)
        # Save the composited overlay (previously the bare mask was saved and
        # `combined` was discarded).
        combined_path = _save_temp_png(combined)
        temp_paths.extend((original_path, combined_path))
        return original_path, combined_path

    @render.image
    def original_image():
        """Show the resized upload, or nothing until a file is chosen."""
        result = loaded_image()
        if result is None:
            return None
        img_path, _ = result
        return {"src": img_path, "width": "300px"}

    @render.image
    def image_display():
        """Show the upload with the predicted mask overlaid."""
        result = loaded_image()
        if result is None:
            return None
        _, img_path = result
        return {"src": img_path, "width": "300px"}


app = App(app_ui, server)