Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from torchvision import datasets, transforms, models | |
import os | |
from PIL import Image | |
# Ustawienia parametrów modelu | |
img_width, img_height = 224, 224 # Wymiary obrazu wymagane przez model ResNet | |
model_path = 'animal_classifier_resnet.pth' # Ścieżka do wytrenowanego modelu | |
# Sprawdzenie, czy jest dostępny GPU i przypisanie urządzenia do zmiennej `device` | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
print(f"Using device: {device}") | |
# Wyłączenie CuDNN, chyba że twoja karta wspiera bibliotekę CuDNN (NVIDIA CUDA Deep Neural Network library) | |
# torch.backends.cudnn.enabled = False | |
# Transformacje danych wejściowych (zmiana rozmiaru, normalizacja) | |
transform = transforms.Compose([ | |
transforms.Resize((img_width, img_height)), # Zmiana rozmiaru obrazu | |
transforms.ToTensor(), # Konwersja obrazu do tensoru | |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalizacja obrazu | |
]) | |
# Inicjalizacja modelu ResNet18 bez wstępnych wag | |
model = models.resnet18(weights=None) | |
num_ftrs = model.fc.in_features # Liczba wejściowych cech ostatniej warstwy | |
model.fc = nn.Linear(num_ftrs, len(datasets.ImageFolder('raw-img', transform=transform).classes)) # Zastąpienie ostatniej warstwy dopasowanej do liczby klas w danych | |
# Przeniesienie modelu na GPU, jeśli jest dostępny | |
model = model.to(device) | |
model.load_state_dict(torch.load(model_path, map_location=device)) # Wczytanie wytrenowanych wag modelu | |
model.eval() # Ustawienie modelu w tryb ewaluacyjny | |
# Funkcja do przewidywania klasy obrazu | |
def predict(image_path): | |
image = Image.open(image_path).convert('RGB') # Otworzenie obrazu i konwersja do RGB | |
image = transform(image).unsqueeze(0) # Zastosowanie transformacji i dodanie wymiaru batch | |
image = image.to(device) # Przeniesienie obrazu na GPU, jeśli jest dostępny | |
with torch.no_grad(): # Wyłączenie gradientów dla przewidywania | |
output = model(image) # Przekazanie obrazu przez model | |
_, predicted = torch.max(output, 1) # Wybranie klasy z najwyższym prawdopodobieństwem | |
return datasets.ImageFolder('raw-img', transform=transform).classes[predicted.item()] # Zwrócenie nazwy klasy | |
# Przykład użycia funkcji predict do przewidywania klasy obrazów w katalogu `recognize` | |
test_image_dir = 'recognize' # Ścieżka do katalogu z obrazami do przewidywania | |
for filename in os.listdir(test_image_dir): # Pętla przez pliki w katalogu | |
if filename.lower().endswith(('.jpg', '.jpeg', '.png')): # Sprawdzenie rozszerzenia pliku | |
image_path = os.path.join(test_image_dir, filename) # Pełna ścieżka do pliku | |
print(f"Prediction for {filename}: {predict(image_path)}") # Wyświetlenie przewidywanej klasy dla obrazu | |
def predict_all(): | |
results = [] | |
test_image_dir = 'recognize' | |
for filename in os.listdir(test_image_dir): | |
if filename.lower().endswith(('.jpg', '.jpeg', '.png')): | |
image_path = os.path.join(test_image_dir, filename) | |
prediction = predict(image_path) | |
results.append(f"Prediction for {filename}: {prediction}") | |
return results | |
if __name__ == "__main__": | |
print("\n".join(predict_all())) | |