animals_detection / predict_model.py
wiklif's picture
second commit
daa4333
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
import os
from PIL import Image
# Ustawienia parametrów modelu
img_width, img_height = 224, 224 # Wymiary obrazu wymagane przez model ResNet
model_path = 'animal_classifier_resnet.pth' # Ścieżka do wytrenowanego modelu
# Sprawdzenie, czy jest dostępny GPU i przypisanie urządzenia do zmiennej `device`
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Wyłączenie CuDNN, chyba że twoja karta wspiera bibliotekę CuDNN (NVIDIA CUDA Deep Neural Network library)
# torch.backends.cudnn.enabled = False
# Transformacje danych wejściowych (zmiana rozmiaru, normalizacja)
transform = transforms.Compose([
transforms.Resize((img_width, img_height)), # Zmiana rozmiaru obrazu
transforms.ToTensor(), # Konwersja obrazu do tensoru
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalizacja obrazu
])
# Inicjalizacja modelu ResNet18 bez wstępnych wag
model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features # Liczba wejściowych cech ostatniej warstwy
model.fc = nn.Linear(num_ftrs, len(datasets.ImageFolder('raw-img', transform=transform).classes)) # Zastąpienie ostatniej warstwy dopasowanej do liczby klas w danych
# Przeniesienie modelu na GPU, jeśli jest dostępny
model = model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device)) # Wczytanie wytrenowanych wag modelu
model.eval() # Ustawienie modelu w tryb ewaluacyjny
# Funkcja do przewidywania klasy obrazu
def predict(image_path):
image = Image.open(image_path).convert('RGB') # Otworzenie obrazu i konwersja do RGB
image = transform(image).unsqueeze(0) # Zastosowanie transformacji i dodanie wymiaru batch
image = image.to(device) # Przeniesienie obrazu na GPU, jeśli jest dostępny
with torch.no_grad(): # Wyłączenie gradientów dla przewidywania
output = model(image) # Przekazanie obrazu przez model
_, predicted = torch.max(output, 1) # Wybranie klasy z najwyższym prawdopodobieństwem
return datasets.ImageFolder('raw-img', transform=transform).classes[predicted.item()] # Zwrócenie nazwy klasy
# Przykład użycia funkcji predict do przewidywania klasy obrazów w katalogu `recognize`
test_image_dir = 'recognize' # Ścieżka do katalogu z obrazami do przewidywania
for filename in os.listdir(test_image_dir): # Pętla przez pliki w katalogu
if filename.lower().endswith(('.jpg', '.jpeg', '.png')): # Sprawdzenie rozszerzenia pliku
image_path = os.path.join(test_image_dir, filename) # Pełna ścieżka do pliku
print(f"Prediction for {filename}: {predict(image_path)}") # Wyświetlenie przewidywanej klasy dla obrazu
def predict_all():
results = []
test_image_dir = 'recognize'
for filename in os.listdir(test_image_dir):
if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
image_path = os.path.join(test_image_dir, filename)
prediction = predict(image_path)
results.append(f"Prediction for {filename}: {prediction}")
return results
if __name__ == "__main__":
print("\n".join(predict_all()))