import torch import torch.nn as nn from torchvision import datasets, transforms, models import os from PIL import Image # Ustawienia parametrów modelu img_width, img_height = 224, 224 # Wymiary obrazu wymagane przez model ResNet model_path = 'animal_classifier_resnet.pth' # Ścieżka do wytrenowanego modelu # Sprawdzenie, czy jest dostępny GPU i przypisanie urządzenia do zmiennej `device` device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Wyłączenie CuDNN, chyba że twoja karta wspiera bibliotekę CuDNN (NVIDIA CUDA Deep Neural Network library) # torch.backends.cudnn.enabled = False # Transformacje danych wejściowych (zmiana rozmiaru, normalizacja) transform = transforms.Compose([ transforms.Resize((img_width, img_height)), # Zmiana rozmiaru obrazu transforms.ToTensor(), # Konwersja obrazu do tensoru transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalizacja obrazu ]) # Inicjalizacja modelu ResNet18 bez wstępnych wag model = models.resnet18(weights=None) num_ftrs = model.fc.in_features # Liczba wejściowych cech ostatniej warstwy model.fc = nn.Linear(num_ftrs, len(datasets.ImageFolder('raw-img', transform=transform).classes)) # Zastąpienie ostatniej warstwy dopasowanej do liczby klas w danych # Przeniesienie modelu na GPU, jeśli jest dostępny model = model.to(device) model.load_state_dict(torch.load(model_path, map_location=device)) # Wczytanie wytrenowanych wag modelu model.eval() # Ustawienie modelu w tryb ewaluacyjny # Funkcja do przewidywania klasy obrazu def predict(image_path): image = Image.open(image_path).convert('RGB') # Otworzenie obrazu i konwersja do RGB image = transform(image).unsqueeze(0) # Zastosowanie transformacji i dodanie wymiaru batch image = image.to(device) # Przeniesienie obrazu na GPU, jeśli jest dostępny with torch.no_grad(): # Wyłączenie gradientów dla przewidywania output = model(image) # Przekazanie obrazu przez model _, predicted = torch.max(output, 1) # Wybranie klasy z najwyższym prawdopodobieństwem return datasets.ImageFolder('raw-img', transform=transform).classes[predicted.item()] # Zwrócenie nazwy klasy # Przykład użycia funkcji predict do przewidywania klasy obrazów w katalogu `recognize` test_image_dir = 'recognize' # Ścieżka do katalogu z obrazami do przewidywania for filename in os.listdir(test_image_dir): # Pętla przez pliki w katalogu if filename.lower().endswith(('.jpg', '.jpeg', '.png')): # Sprawdzenie rozszerzenia pliku image_path = os.path.join(test_image_dir, filename) # Pełna ścieżka do pliku print(f"Prediction for {filename}: {predict(image_path)}") # Wyświetlenie przewidywanej klasy dla obrazu def predict_all(): results = [] test_image_dir = 'recognize' for filename in os.listdir(test_image_dir): if filename.lower().endswith(('.jpg', '.jpeg', '.png')): image_path = os.path.join(test_image_dir, filename) prediction = predict(image_path) results.append(f"Prediction for {filename}: {prediction}") return results if __name__ == "__main__": print("\n".join(predict_all()))