animals_detection / train_model.py
wiklif's picture
second commit
daa4333
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
def train():
# Ustawienia parametrów treningu
img_width, img_height = 224, 224 # Wymiary obrazu wymagane przez model ResNet
batch_size = 32 # Liczba obrazów przetwarzanych na raz podczas treningu
epochs = 10 # Liczba epok treningu
learning_rate = 0.001 # Wskaźnik uczenia się dla optymalizatora
model_path = 'animal_classifier_resnet.pth' # Ścieżka do zapisu wytrenowanego modelu
# Sprawdzenie, czy jest dostępny GPU i przypisanie urządzenia do zmiennej `device`
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
print(f"CUDA Version: {torch.version.cuda}")
# Transformacje danych wejściowych (zmiana rozmiaru, normalizacja)
transform = transforms.Compose([
transforms.Resize((img_width, img_height)), # Zmiana rozmiaru obrazu
transforms.ToTensor(), # Konwersja obrazu do tensora
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalizacja obrazu
])
# Przygotowanie danych treningowych z katalogu `raw-img`
data_dir = 'raw-img' # Ścieżka do katalogu z obrazami treningowymi
train_dataset = datasets.ImageFolder(data_dir, transform=transform) # Wczytanie obrazów i zastosowanie transformacji
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # Tworzenie loadera danych
# Użycie pretrenowanego modelu ResNet18
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features # Liczba wejściowych cech ostatniej warstwy
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes)) # Zastąpienie ostatniej warstwy dopasowanej do liczby klas w danych
# Przeniesienie modelu na GPU, jeśli jest dostępny
model = model.to(device)
# Definicja funkcji kosztu (CrossEntropyLoss) i optymalizatora (Adam)
criterion = nn.CrossEntropyLoss() # Funkcja kosztu
optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Optymalizator
# Trening modelu
for epoch in range(epochs): # Pętla przez epoki
model.train() # Ustawienie modelu w tryb treningowy
running_loss = 0.0 # Zmienna do śledzenia straty
for inputs, labels in train_loader: # Pętla przez batch'e danych
# Przeniesienie danych na GPU, jeśli jest dostępny
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad() # Zerowanie gradientów
outputs = model(inputs) # Przekazanie danych przez model
loss = criterion(outputs, labels) # Obliczenie straty
loss.backward() # Propagacja wsteczna
optimizer.step() # Aktualizacja wag modelu
running_loss += loss.item() # Akumulacja straty
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}") # Wyświetlenie średniej straty na epokę
# Zapisywanie wytrenowanego modelu do pliku
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")
if __name__ == "__main__":
train()