File size: 3,082 Bytes
c5bd7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import os

import torch
from torchvision import transforms

import data_setup, model_builder, engine, utils, plotting

# To avoid glitches related to CUDA, probe the GPU once before training.
def set_memory_limit():
    """Choose the torch device to train on.

    Despite its name, this function sets no memory limit. It checks whether
    CUDA is usable by allocating a one-element tensor on the GPU. It falls
    back to the CPU when CUDA is unavailable or that allocation fails.

    Returns:
        str: ``"cuda"`` if a tensor could be allocated on the GPU, otherwise
        ``"cpu"``. A string is always returned (never ``None``).
    """
    if not torch.cuda.is_available():
        # Without this branch the function used to return None on CPU-only
        # machines, and that None was passed on as the training device.
        print("Device is CPU.")
        return "cpu"
    try:
        # A tiny allocation forces CUDA initialisation, so driver/runtime
        # problems surface here instead of in the middle of training.
        torch.tensor([1], device="cuda")
    except RuntimeError as err:
        # CUDA failures are reported by torch as RuntimeError subclasses.
        # Narrower than the previous bare `except:`, which also caught
        # KeyboardInterrupt/SystemExit.
        print(f"CUDA probe failed: {err}")
        print("Device is CPU.")
        return "cpu"
    print("Device is GPU/CUDA.")
    return "cuda"

# Command-line interface: every training hyperparameter can be overridden
# from the shell; the defaults reproduce the original training setup.
parser = argparse.ArgumentParser(description="Train a model for Classification of types of Trash.")
parser.add_argument("--train_dir", type=str, default="data/train", help="Directory containing training images")
parser.add_argument("--test_dir", type=str, default="data/test", help="Directory containing testing images")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for training")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
parser.add_argument("--num_epochs", type=int, default=20, help="Number of epochs to train for")
# Previously hard-coded to 15. Now configurable like the other
# hyperparameters, with the same default so existing runs are unchanged.
parser.add_argument("--hidden_units", type=int, default=15, help="Number of hidden units in the CNN")
args = parser.parse_args()

# Copy the parsed values into the module-level names used by the rest of the script.
train_dir = args.train_dir
test_dir = args.test_dir
LEARNING_RATE = args.learning_rate
BATCH_SIZE = args.batch_size
NUM_EPOCHS = args.num_epochs
HIDDEN_UNITS = args.hidden_units

# Preprocessing applied to every image: resize to a fixed 112x112
# resolution, then convert it to a tensor for the model.
data_transform = transforms.Compose(
    [
        transforms.Resize(size=(112, 112)),
        transforms.ToTensor(),
    ]
)

# Build the train/test DataLoaders from the image directories. data_setup
# also returns the class names, which determine the model's output size.
train_dataloader, test_dataloader, class_names = data_setup.train_test_dataloader(
    train_dir=train_dir, test_dir=test_dir,
    transform=data_transform, batch_size=BATCH_SIZE,
)

# Pick the compute device, then build the CNN and move it onto that device.
# input_shape=3 presumably matches the RGB channels of the images (verify
# against model_builder); the output layer has one unit per class.
device = set_memory_limit()
model = model_builder.TrashClassificationCNNModel(
    input_shape=3,
    hidden_units=HIDDEN_UNITS,
    output_shape=len(class_names),
).to(device)

# Multi-class cross-entropy objective, optimised with Adam over all model parameters.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Run the training loop. The returned `metrics` object is what gets plotted
# at the end of the script (its exact structure is defined by engine.train).
metrics = engine.train(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=NUM_EPOCHS,
    device=device,
)

# Location of the trained weights. The evaluation step below re-reads them,
# so the load path is built from the same two pieces passed to save_model.
# This replaces the previous hard-coded "models\Trash_..." literal, which
# only resolves on Windows and contains an invalid "\T" escape sequence.
# NOTE(review): assumes utils.save_model writes to <target_dir>/<model_name>.
MODEL_DIR = "models"
MODEL_NAME = "Trash_Classification_Model_COLOURED.pth"

# Save the model
utils.save_model(model=model,
                 target_dir=MODEL_DIR,
                 model_name=MODEL_NAME)

# Release unused cached GPU memory from training before evaluation.
torch.cuda.empty_cache()

# Plot the Confusion Matrix from the saved weights on disk.
plotting.plot_confusion_Matrix(model_path=os.path.join(MODEL_DIR, MODEL_NAME),
                               dataloader=test_dataloader,
                               class_names=class_names,
                               device=device)

# Plot the Loss and Accuracy
plotting.plot_metrics(metrics)