Fake-Detect / detect.py
ubuntu
Initial Commit
3fdc2a4
raw
history blame contribute delete
No virus
2.52 kB
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from io import BytesIO
from scipy.ndimage import gaussian_filter
from model import CLIPViTL14Model
import seaborn as sns
import matplotlib.pyplot as plt
# Per-channel normalization means, keyed by the name of the statistics set
# selected through ``stat_from``.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

# Per-channel normalization standard deviations, keyed like ``MEAN``.
STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
def png2jpg(img, quality):
    """Round-trip an image through in-memory JPEG compression.

    Args:
        img: PIL image to compress; it is saved with ``format='jpeg'``.
        quality: JPEG quality forwarded to ``img.save`` (ranging from 0-95,
            75 is the default).

    Returns:
        A new PIL image rebuilt from the decoded JPEG pixel data.
    """
    # The context manager releases the buffer even when encoding or decoding
    # raises, which the previous manual ``close()`` did not guarantee.
    with BytesIO() as out:
        img.save(out, format='jpeg', quality=quality)
        # Rewind so decoding starts at the first byte of the JPEG stream.
        out.seek(0)
        # Materialize the pixels as an array while the buffer is still open;
        # the decoded image must not be read after the buffer is closed.
        decoded = np.array(Image.open(out))
    return Image.fromarray(decoded)
def gaussian_blur(img, sigma):
    """Blur an image with a Gaussian filter, one channel at a time.

    Args:
        img: PIL image (or anything ``np.array`` can convert to an
            ``(H, W)`` or ``(H, W, C)`` array with ``C >= 3``).
        sigma: Standard deviation of the Gaussian kernel, forwarded to
            ``scipy.ndimage.gaussian_filter``.

    Returns:
        A new PIL image built from the blurred pixel array.
    """
    # np.array copies the pixel data, so the in-place filtering below never
    # touches the caller's image.
    pixels = np.array(img)
    if pixels.ndim == 2:
        # Single-channel image: there is no channel axis to index, so filter
        # the whole plane (previously this case raised IndexError).
        gaussian_filter(pixels, output=pixels, sigma=sigma)
    else:
        # Filter each of the first three channels independently so the blur
        # never mixes values across channels; any further channel (e.g.
        # alpha) is left untouched, as before.
        for channel in range(3):
            plane = pixels[:, :, channel]
            gaussian_filter(plane, output=plane, sigma=sigma)
    return Image.fromarray(pixels)
def plot_pie_chart(false_prob, save_path):
    """Save a Real-vs-Fake pie chart for a single prediction.

    Args:
        false_prob: Probability assigned to the 'Fake' label; the 'Real'
            wedge is drawn with ``1 - false_prob``.
        save_path: Destination passed to ``plt.savefig``.
    """
    labels = ['Real', 'Fake']
    probabilities = [1 - false_prob, false_prob]
    colors = ['#ADD8E6', '#FFC0CB']  # light blue and light red
    explode = (0.1, 0)  # offset applied to the first ('Real') wedge

    fig = plt.figure(figsize=(6, 6))
    try:
        plt.pie(
            probabilities,
            labels=labels,
            colors=colors,
            explode=explode,
            autopct='%1.1f%%',
            startangle=90,
        )
        plt.axis('equal')
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # close it here so repeated calls do not accumulate open figures.
        plt.close(fig)
def detect(
    img_path: str,
    save_path: str,
    pretrained_path: str = None,
    stat_from: str = "clip",
    gaussian_sigma: int = None,
    jpeg_quality: int = None,
    device: str = "cpu",
):
    """Score one image with the CLIP-based detector and save a pie chart.

    The image is loaded as RGB, optionally blurred and/or JPEG-compressed,
    resized to 224x224, normalized, and passed through ``CLIPViTL14Model``.
    The sigmoid of the first model output is plotted as the 'Fake'
    probability by ``plot_pie_chart``.

    Args:
        img_path: Path of the image to analyse.
        save_path: Where the resulting pie chart is written.
        pretrained_path: Optional path to a state dict loaded into
            ``model.fc``; skipped when falsy.
        stat_from: Key into ``MEAN`` / ``STD`` selecting the normalization
            statistics ("imagenet" or "clip").
        gaussian_sigma: If not None, sigma for ``gaussian_blur``.
        jpeg_quality: If not None, quality for ``png2jpg``.
        device: Torch device on which the input and model are placed.

    Raises:
        KeyError: If ``stat_from`` is not a key of ``MEAN`` / ``STD``.
    """
    # Open inside a context manager so the underlying file is closed as soon
    # as the RGB copy has been created.
    with Image.open(img_path) as source:
        img = source.convert("RGB")

    # Optional perturbations applied before the model sees the image.
    if gaussian_sigma is not None:
        img = gaussian_blur(img, gaussian_sigma)
    if jpeg_quality is not None:
        img = png2jpg(img, jpeg_quality)

    # transform
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
    ])
    batch = transform(img)
    if batch.ndim == 3:
        # Add the leading batch dimension for a single image.
        batch = batch.unsqueeze(dim=0)
    batch = batch.to(device=device)

    model = CLIPViTL14Model()
    if pretrained_path:
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(pretrained_path, map_location=device)
        model.fc.load_state_dict(state_dict)
    model.eval()
    model.to(device=device)

    # Inference only: disable autograd so no graph is built for the forward
    # pass.
    with torch.no_grad():
        fake_prob = model(batch).sigmoid().flatten().tolist()[0]
    plot_pie_chart(fake_prob, save_path)