final submission readyu...moved notebookks
This commit is contained in:
File diff suppressed because one or more lines are too long
511
notebooks/Final-Submission.ipynb
Normal file
511
notebooks/Final-Submission.ipynb
Normal file
@@ -0,0 +1,511 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "21b10b99",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Task 1: Load Dataset\n",
|
||||
"Load images from disk and count per class to verify dataset integrity"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d318d1f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"data_dir = '../data/raw/vehicle_classification'\n",
|
||||
"\n",
|
||||
"total_count = 0\n",
|
||||
"\n",
|
||||
"for class_name in os.listdir(data_dir):\n",
|
||||
" class_path = os.path.join(data_dir, class_name)\n",
|
||||
" if os.path.isdir(class_path):\n",
|
||||
" count = len(os.listdir(class_path))\n",
|
||||
" total_count += count\n",
|
||||
" print(f\"{class_name}: {count} images\")\n",
|
||||
"\n",
|
||||
"print(f\"Total Count: {total_count} images\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "64122ad4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Check out sample image from dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5604ace3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from PIL import Image\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"# first image in first folder\n",
|
||||
"first_class = os.listdir(data_dir)[0]\n",
|
||||
"first_image_path = os.path.join(data_dir, first_class, os.listdir(os.path.join(data_dir, first_class))[0])\n",
|
||||
"\n",
|
||||
"img = Image.open(first_image_path)\n",
|
||||
"print(f\"Size: {img.size}\")\n",
|
||||
"print(f\"Mode: {img.mode}\")\n",
|
||||
"plt.imshow(img)\n",
|
||||
"plt.title(first_class)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c19ec00a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ensure that all images are RGB, all of same resolution"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3cedd586",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sizes = set()\n",
|
||||
"modes = set()\n",
|
||||
"\n",
|
||||
"for class_name in os.listdir(data_dir):\n",
|
||||
" class_path = os.path.join(data_dir, class_name)\n",
|
||||
" if not os.path.isdir(class_path):\n",
|
||||
" continue\n",
|
||||
" for img_name in os.listdir(class_path):\n",
|
||||
" img = Image.open(os.path.join(class_path, img_name))\n",
|
||||
" sizes.add(img.size)\n",
|
||||
" modes.add(img.mode)\n",
|
||||
"\n",
|
||||
"print(f\"Unique sizes: {sizes}\")\n",
|
||||
"print(f\"Unique modes: {modes}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "88ac961b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Accelrate torch with GPU or MPS if available (credit: Claude)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8f556b22",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" DEVICE = torch.device('cuda')\n",
|
||||
" print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
|
||||
"elif torch.backends.mps.is_available():\n",
|
||||
" DEVICE = torch.device('mps')\n",
|
||||
" print('Apple Silicon (MPS)')\n",
|
||||
"else:\n",
|
||||
" DEVICE = torch.device('cpu')\n",
|
||||
" print('CPU')\n",
|
||||
"\n",
|
||||
"print(f'Running on: {DEVICE}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3ad97919",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Task 2: Split Dataset 80:20 (Train / Test)\n",
|
||||
"\n",
|
||||
"Augmentation applied to training set only — test set kept clean for fair evaluation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f68c1a25",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import math \n",
|
||||
"from torchvision import datasets, transforms\n",
|
||||
"from torch.utils.data import random_split, DataLoader\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"train_transform = transforms.Compose([\n",
|
||||
" transforms.Resize((64, 64)), # Resize to 64x64 (even though all images are)\n",
|
||||
" transforms.RandomHorizontalFlip(), # randomly mirror image\n",
|
||||
" transforms.RandomRotation(20), # rotate up to 20 degrees\n",
|
||||
" transforms.ColorJitter(brightness=0.3, contrast=0.3), # vary lighting\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize tensors as in Pytorch tutorial\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"test_transform = transforms.Compose([\n",
|
||||
" transforms.Resize((64,64)), # Resize to 64x64 (even though all images are)\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize tensors as in Pytorch tutorial\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"train_full = datasets.ImageFolder(root=data_dir, transform=train_transform) #Load full dataset with train transform\n",
|
||||
"test_full = datasets.ImageFolder(root=data_dir, transform=test_transform) # Load full dataset with test transform\n",
|
||||
"\n",
|
||||
"train_size = math.floor(len(train_full) * 0.8) #80% split for training\n",
|
||||
"test_size = len(train_full) - train_size # Remaining 20% used for testing\n",
|
||||
"\n",
|
||||
"torch.manual_seed(42) # Fixes the RNG to the same starting point...42 is convention according to GeeksForGeeks\n",
|
||||
"indices = torch.randperm(len(train_full)).tolist() #randomly shuffle the indices\n",
|
||||
"\n",
|
||||
"train_indices = indices[:train_size] # First 80% of indices\n",
|
||||
"test_indices = indices[train_size:] # remaining 20% of indices\n",
|
||||
"\n",
|
||||
"train_dataset = torch.utils.data.Subset(train_full, train_indices) #Create final datasets\n",
|
||||
"test_dataset = torch.utils.data.Subset(test_full, test_indices)\n",
|
||||
"\n",
|
||||
"print(f\"Train: {len(train_dataset)}, Test: {len(test_dataset)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2eede814",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Credit: Claude: load dataset into batches (64 is standard), and dedicate n threads to the process (min 1, preferrably 4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e1539eaa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"NUM_WORKERS = min(4, os.cpu_count() or 1)\n",
|
||||
"PIN_MEMORY = (DEVICE.type == 'cuda') # Pin memory if GPU available for CUDA\n",
|
||||
"\n",
|
||||
"print(NUM_WORKERS)\n",
|
||||
"\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,\n",
|
||||
" num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)\n",
|
||||
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False,\n",
|
||||
" num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)\n",
|
||||
"classes = train_full.classes \n",
|
||||
"print(classes)\n",
|
||||
"print(len(classes))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e7255041",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Task 3: CNN Architecture\n",
|
||||
"Model takes a batch of (3, 64, 64) images and outputs 8 class scores"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d1b7d9ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.nn as nn\n",
|
||||
"\n",
|
||||
"class Net(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Net, self).__init__()\n",
|
||||
"\n",
|
||||
" self.features = nn.Sequential(\n",
|
||||
" # go from 64x64 to 32x32\n",
|
||||
"\n",
|
||||
" # kernel size = 3x3 filter patch\n",
|
||||
" #padding = 1, so 64x64 stays 64x64 after conv\n",
|
||||
" nn.Conv2d(3, 32, kernel_size=3, padding=1), # 3 channel RRGB, 32 filters (recommended), and adding 1 pixel of zeros so keep output at same size\n",
|
||||
"\n",
|
||||
" nn.BatchNorm2d(32), # mirrors conv2d output\n",
|
||||
" nn.ReLU(), #Activation fn\n",
|
||||
" nn.MaxPool2d(2,2), # 2x2 window, stride = 2: so halved --> 32x32\n",
|
||||
"\n",
|
||||
" #Go from 32x32 --> 16x16 in the same manner\n",
|
||||
"\n",
|
||||
" nn.Conv2d(32, 64, kernel_size=3, padding=1), # Double filters --> more complex features detected\n",
|
||||
" nn.BatchNorm2d(64),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.MaxPool2d(2,2), \n",
|
||||
" \n",
|
||||
" # Go from 16x 16 0 --> 8x8 in the same manner\n",
|
||||
"\n",
|
||||
" nn.Conv2d(64, 128, kernel_size=3, padding=1), # Double filters again --> even more complex rfeatures detected\n",
|
||||
" nn.BatchNorm2d(128),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.MaxPool2d(2,2), \n",
|
||||
"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" self.classifier = nn.Sequential(\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear (128 * 8 *8, 512), # flattened size of 8x8 * 128, 512 is arbitrary number of hidden neurons (recommended by GeeksForGeeks)....tunned to learn details without overfitting\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.5),#Randomly zero 50% of neurons --> prevent memorization and overfitting\n",
|
||||
" nn.Linear(512, len(classes)) # one score per vehicle class\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x): \n",
|
||||
" x = self.features(x) # extract spatial features via conv blocks\n",
|
||||
" x = self.classifier(x) # flatten to 8 vehicle clases\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"model = Net().to(DEVICE)\n",
|
||||
"device = DEVICE\n",
|
||||
"\n",
|
||||
"print(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "22e71032",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Loss fn and optimizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "54d11a04",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.optim as optim\n",
|
||||
"\n",
|
||||
"criterion = nn.CrossEntropyLoss() # Applied softmax to convert scores --> probabilities --> penalizes model \n",
|
||||
"\n",
|
||||
"# Changed to adam optimizer (internal momentumn calculation)\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "572d80e3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Task 4: Train Model\n",
|
||||
"\n",
|
||||
"Track loss and accuracy per epoch — stored in lists for plotting"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "374d0590",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_losses, train_accs = [], [] #To store accs for visualization\n",
|
||||
"\n",
|
||||
"for epoch in range(30): # 30 epochs\n",
|
||||
" running_loss = 0.0 # keep track of running loss\n",
|
||||
" correct = 0\n",
|
||||
" total = 0\n",
|
||||
"\n",
|
||||
" for i, data in enumerate(train_loader, 0):\n",
|
||||
" inputs, labels = data\n",
|
||||
" inputs, labels = inputs.to(device), labels.to(device)\n",
|
||||
" optimizer.zero_grad() # clear prior gradients\n",
|
||||
" outputs = model(inputs) # forward pass \n",
|
||||
" loss = criterion(outputs, labels) # compare to GT\n",
|
||||
" loss.backward() # Backprop...compute gradient of loss\n",
|
||||
" optimizer.step() #Use adam optimizer to update weights using gradients\n",
|
||||
"\n",
|
||||
" running_loss += loss.item() #Extract scalar loss value \n",
|
||||
" _, predicted = torch.max(outputs, 1) # Take index of highest score as predicted class\n",
|
||||
" total += labels.size(0)\n",
|
||||
" correct += (predicted == labels).sum().item() #update correct tally\n",
|
||||
"\n",
|
||||
" epoch_loss = running_loss / len(train_loader) #Compute avg loss\n",
|
||||
" epoch_acc = 100 * correct / total #Avg acc accross epoch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" #Adding epochs to list\n",
|
||||
" train_losses.append(epoch_loss)\n",
|
||||
" train_accs.append(epoch_acc) \n",
|
||||
"\n",
|
||||
" print(f'Epoch {epoch+1}: Loss={epoch_loss:.3f}, Accuracy={epoch_acc:.2f}%')\n",
|
||||
"\n",
|
||||
"print('Finished Training')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "26ab705a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Task 6 (Bonus): Plot Loss & Accuracy Curves\n",
|
||||
"\n",
|
||||
"Visualises how loss decreased and accuracy improved across 30 epochs [credit: Claude]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c71ee0ff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Plot 1: Training loss vs epoch \n",
|
||||
"ax1.plot(train_losses, color='steelblue', linewidth=2)\n",
|
||||
"ax1.set_title('Training Loss')\n",
|
||||
"ax1.set_xlabel('Epoch')\n",
|
||||
"ax1.set_ylabel('Loss')\n",
|
||||
"ax1.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"#Plot 2: training acc. vs epoch\n",
|
||||
"ax2.plot(train_accs, color='darkorange', linewidth=2)\n",
|
||||
"ax2.set_title('Training Accuracy')\n",
|
||||
"ax2.set_xlabel('Epoch')\n",
|
||||
"ax2.set_ylabel('Accuracy (%)')\n",
|
||||
"ax2.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"#Concat two plots, save, and show\n",
|
||||
"plt.suptitle('Training Curves', fontsize=14, fontweight='bold')\n",
|
||||
"plt.tight_layout()\n",
|
||||
"os.makedirs('../results', exist_ok=True)\n",
|
||||
"plt.savefig('../results/training_curves.png', dpi=150, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b3bfda75",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Save Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2bf2b9a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.makedirs('../models', exist_ok=True)\n",
|
||||
"PATH = '../models/final-classifier.pth'\n",
|
||||
"torch.save(model.state_dict(), PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "057d5d72",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Task 5: Final Accuracy\n",
|
||||
"Evaluate on both train and test sets with Dropout disabled (model.eval())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9e54f566",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.eval() # Switch to eval mode (disabled droupout --> higher acc) (credit: Claude)\n",
|
||||
"train_correct, train_total = 0, 0\n",
|
||||
"with torch.no_grad(): #Gradient computation not needed for inference\n",
|
||||
" for images, labels in train_loader:\n",
|
||||
" images, labels = images.to(device), labels.to(device)\n",
|
||||
" outputs = model(images) #Fwd pass only\n",
|
||||
" _, predicted = torch.max(outputs, 1) #highest score = predicted class\n",
|
||||
" train_total += labels.size(0) #Count total \n",
|
||||
" train_correct += (predicted == labels).sum().item() # Count correct\n",
|
||||
"\n",
|
||||
"# Test accuracy - repeat with test set\n",
|
||||
"test_correct, test_total = 0, 0\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for images, labels in test_loader:\n",
|
||||
" images, labels = images.to(device), labels.to(device)\n",
|
||||
" outputs = model(images)\n",
|
||||
" _, predicted = torch.max(outputs, 1)\n",
|
||||
" test_total += labels.size(0)\n",
|
||||
" test_correct += (predicted == labels).sum().item()\n",
|
||||
"\n",
|
||||
"print(f'Final Train Accuracy : {100 * train_correct / train_total:.2f}%')\n",
|
||||
"print(f'Final Test Accuracy : {100 * test_correct / test_total:.2f}%')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "60666242",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Credit Claude: Testing accuracy per class"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8cc7ed40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"correct_pred = {classname: 0 for classname in classes} # Correct predicitions per class\n",
|
||||
"total_pred = {classname: 0 for classname in classes} # total images seen per class\n",
|
||||
"\n",
|
||||
"model.eval()\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for data in test_loader:\n",
|
||||
" images, labels = data\n",
|
||||
" images, labels = images.to(device), labels.to(device)\n",
|
||||
" outputs = model(images)\n",
|
||||
" _, predictions = torch.max(outputs, 1) #predicted class index per class\n",
|
||||
" for label, prediction in zip(labels, predictions):\n",
|
||||
" if label == prediction:\n",
|
||||
" correct_pred[classes[label]] += 1\n",
|
||||
" total_pred[classes[label]] += 1\n",
|
||||
"\n",
|
||||
"for classname, correct_count in correct_pred.items():\n",
|
||||
" accuracy = 100 * float(correct_count) / total_pred[classname]\n",
|
||||
" print(f'Accuracy for class: {classname:10s} is {accuracy:.1f}%')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -42,7 +42,7 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"data_dir = '../data/raw/vehicle_classification'\n",
|
||||
"data_dir = '../../data/raw/vehicle_classification'\n",
|
||||
"\n",
|
||||
"for class_name in os.listdir(data_dir):\n",
|
||||
" class_path = os.path.join(data_dir, class_name)\n",
|
||||
@@ -378,16 +378,16 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1: Loss=1.480, Accuracy=50.95%\n",
|
||||
"Epoch 2: Loss=1.073, Accuracy=62.46%\n",
|
||||
"Epoch 3: Loss=0.974, Accuracy=66.04%\n",
|
||||
"Epoch 4: Loss=0.899, Accuracy=68.51%\n",
|
||||
"Epoch 5: Loss=0.834, Accuracy=70.87%\n",
|
||||
"Epoch 6: Loss=0.779, Accuracy=72.60%\n",
|
||||
"Epoch 7: Loss=0.730, Accuracy=74.29%\n",
|
||||
"Epoch 8: Loss=0.689, Accuracy=75.64%\n",
|
||||
"Epoch 9: Loss=0.649, Accuracy=77.14%\n",
|
||||
"Epoch 10: Loss=0.620, Accuracy=78.21%\n",
|
||||
"Epoch 1: Loss=1.518, Accuracy=51.03%\n",
|
||||
"Epoch 2: Loss=1.110, Accuracy=62.02%\n",
|
||||
"Epoch 3: Loss=0.976, Accuracy=66.24%\n",
|
||||
"Epoch 4: Loss=0.906, Accuracy=68.52%\n",
|
||||
"Epoch 5: Loss=0.838, Accuracy=70.77%\n",
|
||||
"Epoch 6: Loss=0.782, Accuracy=72.33%\n",
|
||||
"Epoch 7: Loss=0.735, Accuracy=74.25%\n",
|
||||
"Epoch 8: Loss=0.696, Accuracy=75.42%\n",
|
||||
"Epoch 9: Loss=0.665, Accuracy=76.53%\n",
|
||||
"Epoch 10: Loss=0.634, Accuracy=77.40%\n",
|
||||
"Finished Training\n"
|
||||
]
|
||||
}
|
||||
@@ -430,7 +430,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PATH = '../models/tutorial-cnn.pth'\n",
|
||||
"PATH = '../../models/tutorial-cnn.pth'\n",
|
||||
"torch.save(model.state_dict(), PATH)"
|
||||
]
|
||||
},
|
||||
@@ -452,7 +452,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 76.35%\n"
|
||||
"Test Accuracy: 75.66%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -482,14 +482,14 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy for class: Bicycle is 45.2%\n",
|
||||
"Accuracy for class: Bus is 71.3%\n",
|
||||
"Accuracy for class: Car is 77.3%\n",
|
||||
"Accuracy for class: Motorcycle is 81.2%\n",
|
||||
"Accuracy for class: NonVehicles is 98.2%\n",
|
||||
"Accuracy for class: Taxi is 37.0%\n",
|
||||
"Accuracy for class: Truck is 35.6%\n",
|
||||
"Accuracy for class: Van is 35.5%\n"
|
||||
"Accuracy for class: Bicycle is 73.6%\n",
|
||||
"Accuracy for class: Bus is 65.1%\n",
|
||||
"Accuracy for class: Car is 86.2%\n",
|
||||
"Accuracy for class: Motorcycle is 55.7%\n",
|
||||
"Accuracy for class: NonVehicles is 99.3%\n",
|
||||
"Accuracy for class: Taxi is 32.9%\n",
|
||||
"Accuracy for class: Truck is 24.8%\n",
|
||||
"Accuracy for class: Van is 22.9%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -42,7 +42,7 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"data_dir = '../data/raw/vehicle_classification'\n",
|
||||
"data_dir = '../../data/raw/vehicle_classification'\n",
|
||||
"\n",
|
||||
"for class_name in os.listdir(data_dir):\n",
|
||||
" class_path = os.path.join(data_dir, class_name)\n",
|
||||
@@ -378,36 +378,36 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1: Loss=1.732, Accuracy=43.80%\n",
|
||||
"Epoch 2: Loss=1.403, Accuracy=54.06%\n",
|
||||
"Epoch 3: Loss=1.137, Accuracy=60.66%\n",
|
||||
"Epoch 4: Loss=1.003, Accuracy=64.95%\n",
|
||||
"Epoch 5: Loss=0.917, Accuracy=67.92%\n",
|
||||
"Epoch 6: Loss=0.845, Accuracy=70.22%\n",
|
||||
"Epoch 7: Loss=0.796, Accuracy=72.06%\n",
|
||||
"Epoch 8: Loss=0.741, Accuracy=73.88%\n",
|
||||
"Epoch 9: Loss=0.711, Accuracy=75.05%\n",
|
||||
"Epoch 10: Loss=0.670, Accuracy=76.52%\n",
|
||||
"Epoch 11: Loss=0.636, Accuracy=77.59%\n",
|
||||
"Epoch 12: Loss=0.611, Accuracy=78.49%\n",
|
||||
"Epoch 13: Loss=0.580, Accuracy=79.42%\n",
|
||||
"Epoch 14: Loss=0.546, Accuracy=80.87%\n",
|
||||
"Epoch 15: Loss=0.519, Accuracy=81.77%\n",
|
||||
"Epoch 16: Loss=0.502, Accuracy=82.36%\n",
|
||||
"Epoch 17: Loss=0.474, Accuracy=83.28%\n",
|
||||
"Epoch 18: Loss=0.445, Accuracy=84.34%\n",
|
||||
"Epoch 19: Loss=0.424, Accuracy=85.12%\n",
|
||||
"Epoch 20: Loss=0.392, Accuracy=86.07%\n",
|
||||
"Epoch 21: Loss=0.366, Accuracy=86.99%\n",
|
||||
"Epoch 22: Loss=0.330, Accuracy=88.30%\n",
|
||||
"Epoch 23: Loss=0.304, Accuracy=89.51%\n",
|
||||
"Epoch 24: Loss=0.267, Accuracy=90.67%\n",
|
||||
"Epoch 25: Loss=0.227, Accuracy=92.35%\n",
|
||||
"Epoch 26: Loss=0.222, Accuracy=92.38%\n",
|
||||
"Epoch 27: Loss=0.176, Accuracy=94.24%\n",
|
||||
"Epoch 28: Loss=0.156, Accuracy=94.71%\n",
|
||||
"Epoch 29: Loss=0.143, Accuracy=94.92%\n",
|
||||
"Epoch 30: Loss=0.124, Accuracy=95.87%\n",
|
||||
"Epoch 1: Loss=1.693, Accuracy=43.96%\n",
|
||||
"Epoch 2: Loss=1.285, Accuracy=56.70%\n",
|
||||
"Epoch 3: Loss=1.070, Accuracy=62.78%\n",
|
||||
"Epoch 4: Loss=0.964, Accuracy=66.34%\n",
|
||||
"Epoch 5: Loss=0.891, Accuracy=68.85%\n",
|
||||
"Epoch 6: Loss=0.826, Accuracy=70.92%\n",
|
||||
"Epoch 7: Loss=0.764, Accuracy=72.87%\n",
|
||||
"Epoch 8: Loss=0.734, Accuracy=74.25%\n",
|
||||
"Epoch 9: Loss=0.690, Accuracy=75.77%\n",
|
||||
"Epoch 10: Loss=0.636, Accuracy=77.49%\n",
|
||||
"Epoch 11: Loss=0.616, Accuracy=78.49%\n",
|
||||
"Epoch 12: Loss=0.581, Accuracy=79.67%\n",
|
||||
"Epoch 13: Loss=0.558, Accuracy=80.53%\n",
|
||||
"Epoch 14: Loss=0.526, Accuracy=81.44%\n",
|
||||
"Epoch 15: Loss=0.498, Accuracy=82.36%\n",
|
||||
"Epoch 16: Loss=0.458, Accuracy=83.76%\n",
|
||||
"Epoch 17: Loss=0.436, Accuracy=84.62%\n",
|
||||
"Epoch 18: Loss=0.398, Accuracy=86.02%\n",
|
||||
"Epoch 19: Loss=0.378, Accuracy=86.44%\n",
|
||||
"Epoch 20: Loss=0.327, Accuracy=88.69%\n",
|
||||
"Epoch 21: Loss=0.310, Accuracy=89.20%\n",
|
||||
"Epoch 22: Loss=0.265, Accuracy=90.82%\n",
|
||||
"Epoch 23: Loss=0.237, Accuracy=91.60%\n",
|
||||
"Epoch 24: Loss=0.212, Accuracy=92.75%\n",
|
||||
"Epoch 25: Loss=0.193, Accuracy=93.34%\n",
|
||||
"Epoch 26: Loss=0.161, Accuracy=94.54%\n",
|
||||
"Epoch 27: Loss=0.119, Accuracy=96.25%\n",
|
||||
"Epoch 28: Loss=0.095, Accuracy=96.94%\n",
|
||||
"Epoch 29: Loss=0.099, Accuracy=96.86%\n",
|
||||
"Epoch 30: Loss=0.078, Accuracy=97.66%\n",
|
||||
"Finished Training\n"
|
||||
]
|
||||
}
|
||||
@@ -450,7 +450,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PATH = '../models/tutorial-cnn-modified.pth'\n",
|
||||
"PATH = '../../models/tutorial-cnn-modified.pth'\n",
|
||||
"torch.save(model.state_dict(), PATH)"
|
||||
]
|
||||
},
|
||||
@@ -472,7 +472,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 78.54%\n"
|
||||
"Test Accuracy: 76.90%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -502,14 +502,14 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy for class: Bicycle is 52.4%\n",
|
||||
"Accuracy for class: Bus is 79.0%\n",
|
||||
"Accuracy for class: Car is 74.9%\n",
|
||||
"Accuracy for class: Motorcycle is 85.9%\n",
|
||||
"Accuracy for class: NonVehicles is 98.7%\n",
|
||||
"Accuracy for class: Taxi is 58.2%\n",
|
||||
"Accuracy for class: Truck is 44.9%\n",
|
||||
"Accuracy for class: Van is 29.8%\n"
|
||||
"Accuracy for class: Bicycle is 61.4%\n",
|
||||
"Accuracy for class: Bus is 75.6%\n",
|
||||
"Accuracy for class: Car is 76.0%\n",
|
||||
"Accuracy for class: Motorcycle is 76.0%\n",
|
||||
"Accuracy for class: NonVehicles is 98.5%\n",
|
||||
"Accuracy for class: Taxi is 45.6%\n",
|
||||
"Accuracy for class: Truck is 38.0%\n",
|
||||
"Accuracy for class: Van is 31.6%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -42,7 +42,7 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"data_dir = '../data/raw/vehicle_classification'\n",
|
||||
"data_dir = '../../data/raw/vehicle_classification'\n",
|
||||
"\n",
|
||||
"for class_name in os.listdir(data_dir):\n",
|
||||
" class_path = os.path.join(data_dir, class_name)\n",
|
||||
@@ -378,36 +378,36 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1: Loss=1.337, Accuracy=56.65%\n",
|
||||
"Epoch 2: Loss=0.965, Accuracy=66.88%\n",
|
||||
"Epoch 3: Loss=0.870, Accuracy=69.80%\n",
|
||||
"Epoch 4: Loss=0.789, Accuracy=72.31%\n",
|
||||
"Epoch 5: Loss=0.746, Accuracy=73.78%\n",
|
||||
"Epoch 6: Loss=0.692, Accuracy=75.63%\n",
|
||||
"Epoch 7: Loss=0.659, Accuracy=76.80%\n",
|
||||
"Epoch 8: Loss=0.624, Accuracy=77.99%\n",
|
||||
"Epoch 9: Loss=0.608, Accuracy=78.58%\n",
|
||||
"Epoch 10: Loss=0.566, Accuracy=80.22%\n",
|
||||
"Epoch 11: Loss=0.540, Accuracy=80.91%\n",
|
||||
"Epoch 12: Loss=0.517, Accuracy=81.71%\n",
|
||||
"Epoch 13: Loss=0.495, Accuracy=82.49%\n",
|
||||
"Epoch 14: Loss=0.460, Accuracy=84.00%\n",
|
||||
"Epoch 15: Loss=0.436, Accuracy=84.83%\n",
|
||||
"Epoch 16: Loss=0.405, Accuracy=85.72%\n",
|
||||
"Epoch 17: Loss=0.380, Accuracy=86.55%\n",
|
||||
"Epoch 18: Loss=0.348, Accuracy=87.83%\n",
|
||||
"Epoch 19: Loss=0.322, Accuracy=88.65%\n",
|
||||
"Epoch 20: Loss=0.294, Accuracy=89.97%\n",
|
||||
"Epoch 21: Loss=0.265, Accuracy=90.71%\n",
|
||||
"Epoch 22: Loss=0.233, Accuracy=91.95%\n",
|
||||
"Epoch 23: Loss=0.210, Accuracy=92.85%\n",
|
||||
"Epoch 24: Loss=0.184, Accuracy=93.87%\n",
|
||||
"Epoch 25: Loss=0.163, Accuracy=94.46%\n",
|
||||
"Epoch 26: Loss=0.137, Accuracy=95.61%\n",
|
||||
"Epoch 27: Loss=0.118, Accuracy=96.28%\n",
|
||||
"Epoch 28: Loss=0.096, Accuracy=97.20%\n",
|
||||
"Epoch 29: Loss=0.080, Accuracy=97.70%\n",
|
||||
"Epoch 30: Loss=0.066, Accuracy=98.04%\n",
|
||||
"Epoch 1: Loss=1.301, Accuracy=56.17%\n",
|
||||
"Epoch 2: Loss=1.014, Accuracy=64.94%\n",
|
||||
"Epoch 3: Loss=0.933, Accuracy=67.15%\n",
|
||||
"Epoch 4: Loss=0.858, Accuracy=69.63%\n",
|
||||
"Epoch 5: Loss=0.783, Accuracy=72.12%\n",
|
||||
"Epoch 6: Loss=0.749, Accuracy=73.16%\n",
|
||||
"Epoch 7: Loss=0.695, Accuracy=75.33%\n",
|
||||
"Epoch 8: Loss=0.662, Accuracy=76.60%\n",
|
||||
"Epoch 9: Loss=0.634, Accuracy=77.59%\n",
|
||||
"Epoch 10: Loss=0.618, Accuracy=78.29%\n",
|
||||
"Epoch 11: Loss=0.586, Accuracy=79.37%\n",
|
||||
"Epoch 12: Loss=0.554, Accuracy=80.26%\n",
|
||||
"Epoch 13: Loss=0.538, Accuracy=81.11%\n",
|
||||
"Epoch 14: Loss=0.507, Accuracy=82.21%\n",
|
||||
"Epoch 15: Loss=0.489, Accuracy=82.95%\n",
|
||||
"Epoch 16: Loss=0.462, Accuracy=83.55%\n",
|
||||
"Epoch 17: Loss=0.435, Accuracy=84.84%\n",
|
||||
"Epoch 18: Loss=0.412, Accuracy=85.40%\n",
|
||||
"Epoch 19: Loss=0.390, Accuracy=86.27%\n",
|
||||
"Epoch 20: Loss=0.364, Accuracy=87.19%\n",
|
||||
"Epoch 21: Loss=0.348, Accuracy=87.81%\n",
|
||||
"Epoch 22: Loss=0.305, Accuracy=89.40%\n",
|
||||
"Epoch 23: Loss=0.279, Accuracy=90.41%\n",
|
||||
"Epoch 24: Loss=0.268, Accuracy=90.56%\n",
|
||||
"Epoch 25: Loss=0.227, Accuracy=92.43%\n",
|
||||
"Epoch 26: Loss=0.209, Accuracy=92.78%\n",
|
||||
"Epoch 27: Loss=0.181, Accuracy=94.17%\n",
|
||||
"Epoch 28: Loss=0.173, Accuracy=94.14%\n",
|
||||
"Epoch 29: Loss=0.140, Accuracy=95.50%\n",
|
||||
"Epoch 30: Loss=0.126, Accuracy=95.96%\n",
|
||||
"Finished Training\n"
|
||||
]
|
||||
}
|
||||
@@ -450,7 +450,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PATH = '../models/adam-optimized.pth'\n",
|
||||
"PATH = '../../models/adam-optimized.pth'\n",
|
||||
"torch.save(model.state_dict(), PATH)"
|
||||
]
|
||||
},
|
||||
@@ -472,7 +472,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 75.68%\n"
|
||||
"Test Accuracy: 76.50%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -502,14 +502,14 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy for class: Bicycle is 68.8%\n",
|
||||
"Accuracy for class: Bus is 76.1%\n",
|
||||
"Accuracy for class: Car is 66.5%\n",
|
||||
"Accuracy for class: Motorcycle is 78.4%\n",
|
||||
"Accuracy for class: NonVehicles is 98.5%\n",
|
||||
"Accuracy for class: Taxi is 42.6%\n",
|
||||
"Accuracy for class: Truck is 41.2%\n",
|
||||
"Accuracy for class: Van is 39.9%\n"
|
||||
"Accuracy for class: Bicycle is 66.5%\n",
|
||||
"Accuracy for class: Bus is 68.0%\n",
|
||||
"Accuracy for class: Car is 73.8%\n",
|
||||
"Accuracy for class: Motorcycle is 77.1%\n",
|
||||
"Accuracy for class: NonVehicles is 99.0%\n",
|
||||
"Accuracy for class: Taxi is 41.3%\n",
|
||||
"Accuracy for class: Truck is 43.8%\n",
|
||||
"Accuracy for class: Van is 23.6%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -42,7 +42,7 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"data_dir = '../data/raw/vehicle_classification'\n",
|
||||
"data_dir = '../../data/raw/vehicle_classification'\n",
|
||||
"\n",
|
||||
"for class_name in os.listdir(data_dir):\n",
|
||||
" class_path = os.path.join(data_dir, class_name)\n",
|
||||
@@ -418,36 +418,36 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1: Loss=1.455, Accuracy=59.60%\n",
|
||||
"Epoch 2: Loss=0.872, Accuracy=69.47%\n",
|
||||
"Epoch 3: Loss=0.748, Accuracy=73.57%\n",
|
||||
"Epoch 4: Loss=0.669, Accuracy=76.44%\n",
|
||||
"Epoch 5: Loss=0.613, Accuracy=78.60%\n",
|
||||
"Epoch 6: Loss=0.575, Accuracy=79.73%\n",
|
||||
"Epoch 7: Loss=0.539, Accuracy=80.91%\n",
|
||||
"Epoch 8: Loss=0.507, Accuracy=82.19%\n",
|
||||
"Epoch 9: Loss=0.483, Accuracy=82.79%\n",
|
||||
"Epoch 10: Loss=0.461, Accuracy=83.69%\n",
|
||||
"Epoch 11: Loss=0.439, Accuracy=84.35%\n",
|
||||
"Epoch 12: Loss=0.409, Accuracy=85.46%\n",
|
||||
"Epoch 13: Loss=0.396, Accuracy=85.98%\n",
|
||||
"Epoch 14: Loss=0.369, Accuracy=86.39%\n",
|
||||
"Epoch 15: Loss=0.356, Accuracy=87.00%\n",
|
||||
"Epoch 16: Loss=0.351, Accuracy=87.15%\n",
|
||||
"Epoch 17: Loss=0.330, Accuracy=88.06%\n",
|
||||
"Epoch 18: Loss=0.307, Accuracy=88.95%\n",
|
||||
"Epoch 19: Loss=0.286, Accuracy=89.57%\n",
|
||||
"Epoch 20: Loss=0.253, Accuracy=90.44%\n",
|
||||
"Epoch 21: Loss=0.244, Accuracy=91.00%\n",
|
||||
"Epoch 22: Loss=0.233, Accuracy=91.22%\n",
|
||||
"Epoch 23: Loss=0.225, Accuracy=91.56%\n",
|
||||
"Epoch 24: Loss=0.201, Accuracy=92.63%\n",
|
||||
"Epoch 25: Loss=0.198, Accuracy=92.55%\n",
|
||||
"Epoch 26: Loss=0.178, Accuracy=93.46%\n",
|
||||
"Epoch 27: Loss=0.169, Accuracy=93.60%\n",
|
||||
"Epoch 28: Loss=0.149, Accuracy=94.44%\n",
|
||||
"Epoch 29: Loss=0.153, Accuracy=94.34%\n",
|
||||
"Epoch 30: Loss=0.135, Accuracy=95.05%\n",
|
||||
"Epoch 1: Loss=1.427, Accuracy=61.06%\n",
|
||||
"Epoch 2: Loss=0.814, Accuracy=71.39%\n",
|
||||
"Epoch 3: Loss=0.709, Accuracy=74.82%\n",
|
||||
"Epoch 4: Loss=0.645, Accuracy=77.18%\n",
|
||||
"Epoch 5: Loss=0.597, Accuracy=79.08%\n",
|
||||
"Epoch 6: Loss=0.564, Accuracy=80.10%\n",
|
||||
"Epoch 7: Loss=0.529, Accuracy=81.31%\n",
|
||||
"Epoch 8: Loss=0.500, Accuracy=82.12%\n",
|
||||
"Epoch 9: Loss=0.468, Accuracy=83.23%\n",
|
||||
"Epoch 10: Loss=0.451, Accuracy=84.01%\n",
|
||||
"Epoch 11: Loss=0.430, Accuracy=84.79%\n",
|
||||
"Epoch 12: Loss=0.423, Accuracy=84.93%\n",
|
||||
"Epoch 13: Loss=0.385, Accuracy=86.47%\n",
|
||||
"Epoch 14: Loss=0.368, Accuracy=86.48%\n",
|
||||
"Epoch 15: Loss=0.346, Accuracy=87.31%\n",
|
||||
"Epoch 16: Loss=0.341, Accuracy=87.57%\n",
|
||||
"Epoch 17: Loss=0.310, Accuracy=89.02%\n",
|
||||
"Epoch 18: Loss=0.303, Accuracy=89.11%\n",
|
||||
"Epoch 19: Loss=0.274, Accuracy=89.85%\n",
|
||||
"Epoch 20: Loss=0.258, Accuracy=90.29%\n",
|
||||
"Epoch 21: Loss=0.246, Accuracy=90.69%\n",
|
||||
"Epoch 22: Loss=0.239, Accuracy=91.06%\n",
|
||||
"Epoch 23: Loss=0.227, Accuracy=91.57%\n",
|
||||
"Epoch 24: Loss=0.207, Accuracy=92.31%\n",
|
||||
"Epoch 25: Loss=0.191, Accuracy=92.94%\n",
|
||||
"Epoch 26: Loss=0.178, Accuracy=93.24%\n",
|
||||
"Epoch 27: Loss=0.163, Accuracy=93.96%\n",
|
||||
"Epoch 28: Loss=0.147, Accuracy=94.67%\n",
|
||||
"Epoch 29: Loss=0.150, Accuracy=94.36%\n",
|
||||
"Epoch 30: Loss=0.130, Accuracy=95.14%\n",
|
||||
"Finished Training\n"
|
||||
]
|
||||
}
|
||||
@@ -490,7 +490,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PATH = '../models/batch-norm-dropout.pth'\n",
|
||||
"PATH = '../../models/batch-norm-dropout.pth'\n",
|
||||
"torch.save(model.state_dict(), PATH)"
|
||||
]
|
||||
},
|
||||
@@ -512,7 +512,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 83.30%\n"
|
||||
"Test Accuracy: 82.26%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -542,14 +542,14 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy for class: Bicycle is 72.2%\n",
|
||||
"Accuracy for class: Bus is 83.6%\n",
|
||||
"Accuracy for class: Car is 82.9%\n",
|
||||
"Accuracy for class: Motorcycle is 84.2%\n",
|
||||
"Accuracy for class: NonVehicles is 99.6%\n",
|
||||
"Accuracy for class: Taxi is 52.5%\n",
|
||||
"Accuracy for class: Truck is 48.5%\n",
|
||||
"Accuracy for class: Van is 45.0%\n"
|
||||
"Accuracy for class: Bicycle is 72.7%\n",
|
||||
"Accuracy for class: Bus is 81.5%\n",
|
||||
"Accuracy for class: Car is 78.2%\n",
|
||||
"Accuracy for class: Motorcycle is 82.6%\n",
|
||||
"Accuracy for class: NonVehicles is 100.0%\n",
|
||||
"Accuracy for class: Taxi is 53.2%\n",
|
||||
"Accuracy for class: Truck is 52.9%\n",
|
||||
"Accuracy for class: Van is 41.8%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -569,7 +569,7 @@
|
||||
" total_pred[dataset.classes[label]] += 1\n",
|
||||
"\n",
|
||||
"for classname, correct_count in correct_pred.items():\n",
|
||||
" accuracy = 100 * float(correct_count) / total_pred[classname]\n",
|
||||
" accuracy = 100 * float(correct_count) / total_pred[classname] \n",
|
||||
" print(f'Accuracy for class: {classname:10s} is {accuracy:.1f}%')"
|
||||
]
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 12,
|
||||
"id": "7a37220a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -20,7 +20,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 13,
|
||||
"id": "d318d1f0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -42,7 +42,7 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"data_dir = '../data/raw/vehicle_classification'\n",
|
||||
"data_dir = '../../data/raw/vehicle_classification'\n",
|
||||
"\n",
|
||||
"for class_name in os.listdir(data_dir):\n",
|
||||
" class_path = os.path.join(data_dir, class_name)\n",
|
||||
@@ -61,7 +61,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 14,
|
||||
"id": "5604ace3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -110,7 +110,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 15,
|
||||
"id": "3cedd586",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -150,7 +150,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 16,
|
||||
"id": "8f556b22",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -184,7 +184,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 17,
|
||||
"id": "37793c77",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -222,7 +222,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 18,
|
||||
"id": "f68c1a25",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -279,7 +279,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 19,
|
||||
"id": "e1539eaa",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -311,7 +311,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 20,
|
||||
"id": "d1b7d9ca",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -409,7 +409,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 21,
|
||||
"id": "54d11a04",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -432,7 +432,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 22,
|
||||
"id": "374d0590",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -440,36 +440,36 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1: Loss=1.405, Accuracy=59.91%\n",
|
||||
"Epoch 2: Loss=0.880, Accuracy=69.16%\n",
|
||||
"Epoch 3: Loss=0.770, Accuracy=72.99%\n",
|
||||
"Epoch 4: Loss=0.714, Accuracy=75.05%\n",
|
||||
"Epoch 5: Loss=0.660, Accuracy=76.92%\n",
|
||||
"Epoch 6: Loss=0.629, Accuracy=78.41%\n",
|
||||
"Epoch 7: Loss=0.596, Accuracy=79.03%\n",
|
||||
"Epoch 8: Loss=0.588, Accuracy=79.40%\n",
|
||||
"Epoch 9: Loss=0.556, Accuracy=80.52%\n",
|
||||
"Epoch 10: Loss=0.543, Accuracy=81.06%\n",
|
||||
"Epoch 11: Loss=0.520, Accuracy=81.85%\n",
|
||||
"Epoch 12: Loss=0.513, Accuracy=81.94%\n",
|
||||
"Epoch 13: Loss=0.501, Accuracy=82.56%\n",
|
||||
"Epoch 14: Loss=0.485, Accuracy=82.90%\n",
|
||||
"Epoch 15: Loss=0.481, Accuracy=82.92%\n",
|
||||
"Epoch 16: Loss=0.469, Accuracy=83.50%\n",
|
||||
"Epoch 17: Loss=0.459, Accuracy=84.02%\n",
|
||||
"Epoch 18: Loss=0.453, Accuracy=83.94%\n",
|
||||
"Epoch 19: Loss=0.445, Accuracy=84.24%\n",
|
||||
"Epoch 20: Loss=0.434, Accuracy=84.77%\n",
|
||||
"Epoch 21: Loss=0.427, Accuracy=84.97%\n",
|
||||
"Epoch 22: Loss=0.412, Accuracy=85.20%\n",
|
||||
"Epoch 23: Loss=0.411, Accuracy=85.15%\n",
|
||||
"Epoch 24: Loss=0.402, Accuracy=85.52%\n",
|
||||
"Epoch 25: Loss=0.403, Accuracy=85.32%\n",
|
||||
"Epoch 26: Loss=0.393, Accuracy=85.93%\n",
|
||||
"Epoch 27: Loss=0.386, Accuracy=86.07%\n",
|
||||
"Epoch 28: Loss=0.380, Accuracy=86.28%\n",
|
||||
"Epoch 29: Loss=0.372, Accuracy=86.65%\n",
|
||||
"Epoch 30: Loss=0.362, Accuracy=87.02%\n",
|
||||
"Epoch 1: Loss=1.515, Accuracy=58.09%\n",
|
||||
"Epoch 2: Loss=0.901, Accuracy=68.38%\n",
|
||||
"Epoch 3: Loss=0.786, Accuracy=72.40%\n",
|
||||
"Epoch 4: Loss=0.723, Accuracy=74.59%\n",
|
||||
"Epoch 5: Loss=0.665, Accuracy=76.57%\n",
|
||||
"Epoch 6: Loss=0.634, Accuracy=77.73%\n",
|
||||
"Epoch 7: Loss=0.602, Accuracy=78.87%\n",
|
||||
"Epoch 8: Loss=0.593, Accuracy=79.30%\n",
|
||||
"Epoch 9: Loss=0.580, Accuracy=79.91%\n",
|
||||
"Epoch 10: Loss=0.548, Accuracy=80.70%\n",
|
||||
"Epoch 11: Loss=0.535, Accuracy=80.95%\n",
|
||||
"Epoch 12: Loss=0.517, Accuracy=81.76%\n",
|
||||
"Epoch 13: Loss=0.521, Accuracy=81.70%\n",
|
||||
"Epoch 14: Loss=0.499, Accuracy=82.46%\n",
|
||||
"Epoch 15: Loss=0.488, Accuracy=82.87%\n",
|
||||
"Epoch 16: Loss=0.483, Accuracy=83.06%\n",
|
||||
"Epoch 17: Loss=0.480, Accuracy=82.94%\n",
|
||||
"Epoch 18: Loss=0.465, Accuracy=83.34%\n",
|
||||
"Epoch 19: Loss=0.464, Accuracy=83.39%\n",
|
||||
"Epoch 20: Loss=0.443, Accuracy=84.22%\n",
|
||||
"Epoch 21: Loss=0.445, Accuracy=84.40%\n",
|
||||
"Epoch 22: Loss=0.436, Accuracy=84.60%\n",
|
||||
"Epoch 23: Loss=0.417, Accuracy=85.12%\n",
|
||||
"Epoch 24: Loss=0.412, Accuracy=85.36%\n",
|
||||
"Epoch 25: Loss=0.410, Accuracy=85.11%\n",
|
||||
"Epoch 26: Loss=0.405, Accuracy=85.42%\n",
|
||||
"Epoch 27: Loss=0.402, Accuracy=85.44%\n",
|
||||
"Epoch 28: Loss=0.388, Accuracy=85.96%\n",
|
||||
"Epoch 29: Loss=0.383, Accuracy=86.22%\n",
|
||||
"Epoch 30: Loss=0.377, Accuracy=86.23%\n",
|
||||
"Finished Training\n"
|
||||
]
|
||||
}
|
||||
@@ -507,12 +507,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 23,
|
||||
"id": "2bf2b9a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PATH = '../models/data-augmented.pth'\n",
|
||||
"PATH = '../../models/data-augmented.pth'\n",
|
||||
"torch.save(model.state_dict(), PATH)"
|
||||
]
|
||||
},
|
||||
@@ -526,7 +526,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 24,
|
||||
"id": "bc158602",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -534,7 +534,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 82.71%\n"
|
||||
"Test Accuracy: 84.17%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -556,7 +556,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 25,
|
||||
"id": "8cc7ed40",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -564,14 +564,14 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy for class: Bicycle is 68.8%\n",
|
||||
"Accuracy for class: Bus is 83.2%\n",
|
||||
"Accuracy for class: Car is 78.8%\n",
|
||||
"Accuracy for class: Motorcycle is 90.5%\n",
|
||||
"Accuracy for class: NonVehicles is 99.7%\n",
|
||||
"Accuracy for class: Taxi is 49.4%\n",
|
||||
"Accuracy for class: Truck is 59.2%\n",
|
||||
"Accuracy for class: Van is 40.1%\n"
|
||||
"Accuracy for class: Bicycle is 77.4%\n",
|
||||
"Accuracy for class: Bus is 85.3%\n",
|
||||
"Accuracy for class: Car is 86.3%\n",
|
||||
"Accuracy for class: Motorcycle is 81.8%\n",
|
||||
"Accuracy for class: NonVehicles is 99.8%\n",
|
||||
"Accuracy for class: Taxi is 60.8%\n",
|
||||
"Accuracy for class: Truck is 52.8%\n",
|
||||
"Accuracy for class: Van is 41.3%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
BIN
results/training_curves.png
Normal file
BIN
results/training_curves.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 60 KiB |
Reference in New Issue
Block a user