"""Streamlit front-end for the UNET image colorizer.

Upload an image; it is converted to the model's input format, run through
``model`` and the colorized result is shown next to the original upload.
"""
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from model import model, image_transforms_gs, image_transforms_rgb


def _prepare_input(image: Image.Image) -> torch.Tensor:
    """Convert a decoded upload into a model-ready batch of one image.

    Single-band images (modes such as "L", "1", "I", "F") go through
    ``image_transforms_gs``; every other mode is first converted to
    3-channel RGB and goes through ``image_transforms_rgb``.

    Args:
        image: The decoded uploaded image.

    Returns:
        The transformed tensor with a leading batch axis of size 1.
    """
    # Palette ("P") images are single-band, but their values are palette
    # indices rather than intensities, so they must be expanded to RGB.
    # Multi-band modes such as RGBA, LA or CMYK are also converted so the RGB
    # transform always receives exactly three channels.
    if len(image.getbands()) == 1 and image.mode != "P":
        tensor = image_transforms_gs(image)
    else:
        tensor = image_transforms_rgb(image.convert("RGB"))
    return tensor.unsqueeze(0)


def _colorize(batch: torch.Tensor) -> np.ndarray:
    """Run the colorizer on a single-image batch.

    Args:
        batch: Output of ``_prepare_input`` (batch axis of size 1).

    Returns:
        The first output image as a height x width x channels NumPy array.
        Assumes ``model`` returns (batch, channels, H, W) — TODO confirm
        against model.py.
    """
    if isinstance(model, nn.Module):
        # Inference mode: dropout off, batch-norm uses running statistics.
        model.eval()
        # Send the input to wherever the weights live (CPU or GPU).
        param = next(model.parameters(), None)
        if param is not None:
            batch = batch.to(param.device)
    # No gradients are needed for display; avoids building an autograd graph.
    with torch.no_grad():
        output = model(batch)
    # Index [0] drops only the batch axis; a bare .squeeze() would also drop
    # a spatial axis of size 1 and break the permute.
    return output[0].permute(1, 2, 0).cpu().numpy()


st.title("UNET Image Colorizer")

upload_file = st.file_uploader("Upload Image")

if upload_file is not None:
    try:
        image = Image.open(upload_file)
        # Image.open is lazy; decode now so corrupt files fail here.
        image.load()
    except OSError:  # PIL.UnidentifiedImageError subclasses OSError
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    image_color = _colorize(_prepare_input(image))

    col1, col2 = st.columns(2)
    # Left: the upload exactly as received; right: the model's output.
    col1.image(upload_file)
    col2.image(image_color, clamp=True, channels='RGB')