Computer vision is one of the most exciting fields in artificial intelligence, enabling machines to understand and interpret visual information. In this three-part series, we'll explore the core techniques of computer vision using Python and PyTorch. We'll start with image classification, then move to object detection, and finally tackle semantic segmentation.

Introduction to Computer Vision & PyTorch

PyTorch has become the go-to framework for computer vision research and production due to its flexibility and intuitive design. Let's start by setting up our environment and understanding the fundamentals.

Setting Up the Environment

First, let's install the necessary packages:

Bash
pip install torch torchvision torchaudio
pip install matplotlib numpy pillow scikit-learn seaborn tqdm
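
It's worth verifying the installation before going further; this also tells us whether a GPU is visible:

Python
import torch
import torchvision

print(f"PyTorch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")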

Understanding Tensors and Image Data

In PyTorch, images are represented as tensors with shape (C, H, W) where C is channels, H is height, and W is width.

Python
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Load an image, forcing three channels so grayscale/RGBA files also work
def load_image(image_path):
    image = Image.open(image_path).convert('RGB')
    return image

# Convert a PIL image to a tensor. The mean/std values are the standard
# ImageNet statistics, matching the pretrained models we'll use later.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Example: Load and transform an image
# image = load_image('path/to/your/image.jpg')
# tensor_image = transform(image)
# print(f"Image shape: {tensor_image.shape}")  # Should be [3, 224, 224]

Dataset Preparation

For this tutorial, we'll use the CIFAR-10 dataset, which contains 60,000 32x32 color images in 10 classes (50,000 for training and 10,000 for testing).

Python
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Define transforms for training and testing.
# Augmentation runs only at training time; both splits share the same
# ImageNet normalization statistics used above.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data', 
    train=True, 
    download=True, 
    transform=train_transform
)

test_dataset = datasets.CIFAR10(
    root='./data', 
    train=False, 
    download=True, 
    transform=test_transform
)

# Create data loaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# CIFAR-10 class names
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck']

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Visualizing the Dataset

Python
def show_sample_images(dataset, num_samples=8):
    # The fixed 2x4 grid assumes num_samples <= 8
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i in range(num_samples):
        image, label = dataset[i]
        # Denormalize for visualization
        image = image * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        image = image + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        image = torch.clamp(image, 0, 1)

        ax = axes[i // 4, i % 4]
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(f'{classes[label]}')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Display sample images
show_sample_images(train_dataset)

Building Classification Models

Simple CNN from Scratch

Let's build a simple Convolutional Neural Network:

Python
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # After three 2x2 poolings, a 32x32 input shrinks to 4x4,
        # so the flattened feature size is 128 * 4 * 4
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # First conv block
        x = F.relu(self.conv1(x))
        x = self.pool(x)

        # Second conv block
        x = F.relu(self.conv2(x))
        x = self.pool(x)

        # Third conv block
        x = F.relu(self.conv3(x))
        x = self.pool(x)

        # Flatten for fully connected layers
        x = x.view(x.size(0), -1)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# Initialize the model
model = SimpleCNN(num_classes=10)
print(model)
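
A dummy forward pass is a cheap way to verify the architecture before training; the output should contain one logit per class:

Python
# Sanity-check with a random CIFAR-10-sized batch
dummy_batch = torch.randn(4, 3, 32, 32)
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([4, 10])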

Using Pre-trained Models

For better performance, let's fine-tune a pre-trained ResNet model. Its ImageNet weights were learned on 224x224 inputs; the network still accepts 32x32 CIFAR images thanks to its adaptive average pooling, though adding a Resize(224) step to the transform pipelines above often improves accuracy at the cost of compute.

Python
import torchvision.models as models

# Load pre-trained ResNet18 (recent torchvision versions deprecate
# pretrained=True in favor of the weights argument)
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Modify the final layer for CIFAR-10 (10 classes)
resnet_model.fc = nn.Linear(resnet_model.fc.in_features, 10)

print(f"Model parameters: {sum(p.numel() for p in resnet_model.parameters()):,}")

Training and Evaluation

Training Loop

Python
import torch.optim as optim
from tqdm import tqdm

def train_model(model, train_loader, test_loader, num_epochs=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    train_losses = []
    train_accuracies = []
    test_accuracies = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        # Calculate training metrics
        epoch_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct_train / total_train

        # Validation phase
        test_accuracy = evaluate_model(model, test_loader, device)

        train_losses.append(epoch_loss)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Loss: {epoch_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Test Acc: {test_accuracy:.2f}%')

        scheduler.step()

    return train_losses, train_accuracies, test_accuracies

def evaluate_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

# Define the device once at module level so the prediction and reporting
# helpers below can reuse it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Train the model
train_losses, train_accs, test_accs = train_model(resnet_model, train_loader, test_loader, num_epochs=10)
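
After training, it's worth persisting the learned weights so you can reload them later without retraining (the filename here is just an example):

Python
# Save the fine-tuned weights; restore them later with load_state_dict()
torch.save(resnet_model.state_dict(), 'resnet18_cifar10.pth')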

Visualization and Results

Python
def plot_training_history(train_losses, train_accs, test_accs):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot loss
    ax1.plot(train_losses)
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True)

    # Plot accuracy
    ax2.plot(train_accs, label='Train Accuracy')
    ax2.plot(test_accs, label='Test Accuracy')
    ax2.set_title('Model Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

# Plot training history
plot_training_history(train_losses, train_accs, test_accs)

Model Interpretation

Python
def predict_single_image(model, image_path, transform, classes, device):
    model.eval()
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = F.softmax(output, dim=1)
        _, predicted = torch.max(output, 1)

    # Get top 3 predictions
    top_probs, top_indices = torch.topk(probabilities[0], 3)

    plt.figure(figsize=(10, 5))

    # Display image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f'Predicted: {classes[predicted.item()]}')
    plt.axis('off')

    # Display top predictions
    plt.subplot(1, 2, 2)
    plt.barh(range(3), top_probs.cpu().numpy())
    plt.yticks(range(3), [classes[i] for i in top_indices])
    plt.xlabel('Probability')
    plt.title('Top 3 Predictions')
    plt.gca().invert_yaxis()

    plt.tight_layout()
    plt.show()

# Example usage:
# predict_single_image(resnet_model, 'path/to/test/image.jpg', test_transform, classes, device)
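
If you saved a checkpoint earlier, this helper can also run in a fresh session by rebuilding the architecture and restoring the weights. A sketch, assuming the example filename from the training step:

Python
# Rebuild the architecture, then load the saved weights
restored = models.resnet18(weights=None)
restored.fc = nn.Linear(restored.fc.in_features, 10)
restored.load_state_dict(torch.load('resnet18_cifar10.pth', map_location=device))
restored.to(device)
restored.eval()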

Confusion Matrix and Classification Report

Python
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

def generate_classification_report(model, test_loader, classes, device):
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Classification report
    report = classification_report(all_labels, all_predictions, target_names=classes)
    print("Classification Report:")
    print(report)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    return all_predictions, all_labels

# Generate classification report
predictions, labels = generate_classification_report(resnet_model, test_loader, classes, device)
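
The predictions and labels returned above also make it easy to compute per-class accuracy, which often reveals which classes (cat vs. dog, for example) the model confuses most:

Python
# Per-class accuracy is the diagonal of the confusion matrix
# divided by each row's total
cm = confusion_matrix(labels, predictions)
per_class_acc = 100 * cm.diagonal() / cm.sum(axis=1)
for cls, acc in zip(classes, per_class_acc):
    print(f'{cls}: {acc:.1f}%')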

Key Takeaways

In this first part of our computer vision series, we covered:

  1. Environment Setup: Installing PyTorch and essential libraries
  2. Data Handling: Loading, preprocessing, and augmenting image data
  3. Model Architecture: Building CNNs from scratch and using pre-trained models
  4. Training Process: Implementing training loops with proper validation
  5. Evaluation: Analyzing model performance with various metrics

Best Practices for Image Classification:

  • Data Augmentation: Use random transformations to improve generalization
  • Transfer Learning: Leverage pre-trained models for better performance
  • Proper Validation: Always evaluate on unseen data
  • Learning Rate Scheduling: Adjust learning rate during training
  • Regularization: Use dropout and weight decay to prevent overfitting

In the next part, we'll dive into object detection, where we'll not only classify objects but also locate them within images. This introduces new challenges, such as handling a variable number of objects per image and learning to predict bounding boxes.

Next Steps

  • Experiment with different CNN architectures (VGG, DenseNet, EfficientNet)
  • Try advanced data augmentation techniques (Mixup, CutMix)
  • Implement attention mechanisms to understand what the model focuses on
  • Practice with different datasets (ImageNet, custom datasets)

Stay tuned for Part 2, where we'll explore the fascinating world of object detection!