Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torchvision import datasets, transforms, models | |
from torch.utils.data import DataLoader | |
def train(): | |
# Ustawienia parametrów treningu | |
img_width, img_height = 224, 224 # Wymiary obrazu wymagane przez model ResNet | |
batch_size = 32 # Liczba obrazów przetwarzanych na raz podczas treningu | |
epochs = 10 # Liczba epok treningu | |
learning_rate = 0.001 # Wskaźnik uczenia się dla optymalizatora | |
model_path = 'animal_classifier_resnet.pth' # Ścieżka do zapisu wytrenowanego modelu | |
# Sprawdzenie, czy jest dostępny GPU i przypisanie urządzenia do zmiennej `device` | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
print(f"Using device: {device}") | |
if device.type == 'cuda': | |
print(f"CUDA Device: {torch.cuda.get_device_name(0)}") | |
print(f"CUDA Version: {torch.version.cuda}") | |
# Transformacje danych wejściowych (zmiana rozmiaru, normalizacja) | |
transform = transforms.Compose([ | |
transforms.Resize((img_width, img_height)), # Zmiana rozmiaru obrazu | |
transforms.ToTensor(), # Konwersja obrazu do tensora | |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalizacja obrazu | |
]) | |
# Przygotowanie danych treningowych z katalogu `raw-img` | |
data_dir = 'raw-img' # Ścieżka do katalogu z obrazami treningowymi | |
train_dataset = datasets.ImageFolder(data_dir, transform=transform) # Wczytanie obrazów i zastosowanie transformacji | |
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # Tworzenie loadera danych | |
# Użycie pretrenowanego modelu ResNet18 | |
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) | |
num_ftrs = model.fc.in_features # Liczba wejściowych cech ostatniej warstwy | |
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes)) # Zastąpienie ostatniej warstwy dopasowanej do liczby klas w danych | |
# Przeniesienie modelu na GPU, jeśli jest dostępny | |
model = model.to(device) | |
# Definicja funkcji kosztu (CrossEntropyLoss) i optymalizatora (Adam) | |
criterion = nn.CrossEntropyLoss() # Funkcja kosztu | |
optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Optymalizator | |
# Trening modelu | |
for epoch in range(epochs): # Pętla przez epoki | |
model.train() # Ustawienie modelu w tryb treningowy | |
running_loss = 0.0 # Zmienna do śledzenia straty | |
for inputs, labels in train_loader: # Pętla przez batch'e danych | |
# Przeniesienie danych na GPU, jeśli jest dostępny | |
inputs, labels = inputs.to(device), labels.to(device) | |
optimizer.zero_grad() # Zerowanie gradientów | |
outputs = model(inputs) # Przekazanie danych przez model | |
loss = criterion(outputs, labels) # Obliczenie straty | |
loss.backward() # Propagacja wsteczna | |
optimizer.step() # Aktualizacja wag modelu | |
running_loss += loss.item() # Akumulacja straty | |
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}") # Wyświetlenie średniej straty na epokę | |
# Zapisywanie wytrenowanego modelu do pliku | |
torch.save(model.state_dict(), model_path) | |
print(f"Model saved to {model_path}") | |
if __name__ == "__main__": | |
train() |