Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
from torchvision import transforms | |
import random as rand | |
def plot_incorrect_preds(incorrect, classes, num_imgs, *, mean=None, std=None):
    """Plot a random sample of misclassified images in a 5-column grid.

    Each subplot shows one image, un-normalized back to display range and
    titled ``"<target>|<predicted>"`` using the names in ``classes``.

    Args:
        incorrect: Indexable collection of records. Each record is read as
            ``(data, target, pred, ...)``, where ``data`` is a normalized image
            tensor (presumably C x H x W, as ``ToPILImage`` expects — confirm
            against the caller), and ``target``/``pred`` are single-element
            tensors (``.item()`` is called on them).
        classes: Sequence mapping a class index to its display name.
        num_imgs: Number of images to plot; must be a positive multiple of 5.
        mean: Optional per-channel mean that was used to normalize the images.
        std: Optional per-channel std that was used to normalize the images.
            When ``mean``/``std`` are omitted, a hard-coded 3-channel inverse
            transform is used; its constants invert a ``Normalize`` with
            mean ~= (0.4914, 0.4822, 0.4465) and std ~= (0.247, 0.243, 0.261).

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``num_imgs`` is not a positive multiple of 5, if
            ``incorrect`` holds fewer than ``num_imgs`` records, or if only
            one of ``mean``/``std`` is given.
    """
    # Validate with exceptions rather than ``assert`` so the checks survive
    # ``python -O``.
    if num_imgs <= 0 or num_imgs % 5 != 0:
        raise ValueError(f"num_imgs must be a positive multiple of 5, got {num_imgs}")
    if len(incorrect) < num_imgs:
        raise ValueError(
            f"requested {num_imgs} images but only {len(incorrect)} are available"
        )
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be given together")

    # Inverse of Normalize(mean, std): x * std + mean == (x - (-mean / std)) / (1 / std).
    if mean is None:
        unnormalize = transforms.Normalize(
            (-1.98947368, -1.98436214, -1.71072797),
            (4.048583, 4.11522634, 3.83141762),
        )
    else:
        unnormalize = transforms.Normalize(
            [-m / s for m, s in zip(mean, std)],
            [1.0 / s for s in std],
        )
    to_pil = transforms.ToPILImage()

    sample_inds = rand.sample(range(len(incorrect)), num_imgs)
    num_rows = num_imgs // 5
    fig = plt.figure(figsize=(10, num_imgs // 2))
    plt.suptitle("Target | Predicted Label")
    for i, idx in enumerate(sample_inds):
        record = incorrect[idx]
        plt.subplot(num_rows, 5, i + 1, aspect="auto")
        # detach/cpu: ToPILImage needs a CPU tensor without autograd history.
        # clamp: for float input ToPILImage does mul(255).byte(), which wraps
        # out-of-range values instead of saturating them.
        image = unnormalize(record[0].detach().cpu()).clamp(0.0, 1.0)
        plt.imshow(to_pil(image))
        plt.title(f"{classes[record[1].item()]}|{classes[record[2].item()]}")
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot exists, instead of on each iteration.
    plt.tight_layout()
    return fig