azhongai666666's picture
Update app.py
ce4e7b6 verified
raw
history blame
1.76 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.models.layers import to_2tuple, trunc_normal_
import torchvision.transforms as transforms
from torchvision import models
import gradio as gr
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from model import dehazeformer_t
# Pick the device used for inference: CUDA if available, else Apple MPS, else CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# Build the DehazeFormer-T model and move it onto the chosen device.
t_model_load = dehazeformer_t().to(device)

# map_location remaps the checkpoint's tensors onto `device`. Without it, a
# checkpoint saved on a GPU cannot be loaded on a CPU/MPS-only host.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
best_model_weights = torch.load('best_t_model_weights.pth', map_location=device)
t_model_load.load_state_dict(best_model_weights)
# This script only runs inference, so switch the model to evaluation mode up front.
t_model_load.eval()
def pred_one_image(inp):
    """Dehaze one image with the loaded DehazeFormer model and render the result.

    Args:
        inp: PIL image provided by the Gradio ``gr.Image(type="pil")`` input.

    Returns:
        A PIL image (decoded from PNG) of the model output drawn on a
        10x10-inch matplotlib figure at 300 dpi with the axes hidden.
    """
    import io

    # Convert to 3-channel RGB *before* resizing. Pillow always uses NEAREST
    # resampling for "P"/"1" images, so resizing first degrades palette inputs.
    # The model is fed a fixed 256x256 input scaled to [0, 1].
    one_image = np.array(inp.convert("RGB").resize((256, 256))) / 255
    # HWC -> CHW, then add a batch dimension and move to the model's device.
    one_image = np.moveaxis(one_image, -1, 0)
    one_image = torch.tensor(one_image).float().unsqueeze(0).to(device)

    with torch.no_grad():
        t_model_load.eval()
        output = t_model_load(one_image)
    print(output.shape)
    # Drop the batch dimension and convert CHW -> HWC for matplotlib.
    output = output[0].cpu().permute((1, 2, 0))

    fig = plt.figure(figsize=(10, 10))
    try:
        # Draw through the figure's own axes rather than pyplot's "current
        # figure" global, which other requests could change.
        ax = fig.add_subplot()
        ax.imshow(output.numpy())
        ax.axis("off")
        # Render into memory instead of a shared 'image.png' on disk, so
        # concurrent requests cannot overwrite each other's result.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=300)  # dpi sets the output resolution
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each request would leak one figure.
        plt.close(fig)

    buf.seek(0)
    out_img = Image.open(buf)
    out_img.load()  # decode eagerly so the returned image is fully materialized
    return out_img
# Gradio UI: one PIL image in, the dehazed rendering (PIL image) out.
demo = gr.Interface(
    fn=pred_one_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    # The original passed `examples=[image_path]`, but `image_path` is never
    # defined in this script, so building the interface raised NameError.
    # TODO: add `examples=[...]` with real file paths once sample images ship.
)

if __name__ == "__main__":
    # Per the Gradio docs, debug=True blocks the main thread and surfaces errors
    # in the console.
    demo.launch(debug=True)