# Last commit 03856d4 by abhirajeshbhai: "loaded new colorizer weights" (674 bytes).
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from model import model, image_transforms_gs, image_transforms_rgb
st.title("UNET Image Colorizer")
upload_file = st.file_uploader("Upload Image")

# file_uploader yields None when no file has been uploaded.
if upload_file is not None:
    try:
        image = Image.open(upload_file)
        # Image.open is lazy; force a full decode here so a corrupt or
        # non-image upload fails inside this try rather than deep inside the
        # transform pipeline. UnidentifiedImageError subclasses OSError.
        image.load()
    except OSError:
        st.error("Could not read the uploaded file as an image.")
    else:
        if image.mode == "P" or len(image.getbands()) > 1:
            # Palette images hold colour indices rather than intensities, and
            # RGBA/LA/CMYK carry a channel count the RGB pipeline may not
            # expect, so normalise every non-grayscale image to plain RGB.
            if image.mode != "RGB":
                image = image.convert("RGB")
            model_input = image_transforms_rgb(image)
        else:
            # Single-band image: use the grayscale pipeline.
            model_input = image_transforms_gs(image)

        # NOTE(review): assumes `model` is already in eval mode (e.g. set in
        # model.py) — confirm; otherwise dropout/batch-norm layers behave as
        # in training here.
        with torch.no_grad():
            output = model(model_input.unsqueeze(0))
        # Remove only the batch axis added above: a bare .squeeze() would also
        # collapse a height or width of 1 and break the CHW -> HWC permute.
        image_color = output.squeeze(0).permute(1, 2, 0).cpu().numpy()

        col1, col2 = st.columns(2)
        # Left: the file as uploaded; right: the colorized model output.
        col1.image(upload_file)
        col2.image(image_color, clamp=True, channels='RGB')