Welcome to the final part of our computer vision series! We've progressed from classifying entire images to detecting and localizing objects. Now we'll tackle semantic segmentation, the most fine-grained task in this series, where we classify every single pixel in an image.

Semantic segmentation is crucial for applications requiring precise spatial understanding: autonomous driving (road, sidewalk, vehicles), medical imaging (organ boundaries, tumor detection), satellite imagery (land use classification), and augmented reality (scene understanding for object placement).

Understanding Pixel-Level Prediction

Unlike classification (image-level) and detection (object-level), segmentation operates at the pixel level:

  • Classification: "This image contains a cat"
  • Detection: "There's a cat at coordinates (x, y, w, h)"
  • Segmentation: "These specific pixels belong to the cat class"
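The difference is easiest to see in the shapes of the outputs. Below is a minimal NumPy sketch. It assumes a network that emits raw class scores (logits). A classifier produces one score vector per image. A segmentation network produces one score vector per pixel, and taking the argmax over the class axis yields an H x W label map. The sizes and random scores are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, height, width = 3, 4, 6  # illustrative sizes

# Classification: one score per class for the whole image -> one label
image_logits = rng.normal(size=num_classes)
image_label = int(np.argmax(image_logits))

# Segmentation: one score per class for every pixel, shape (C, H, W)
pixel_logits = rng.normal(size=(num_classes, height, width))
# Argmax over the class axis gives an (H, W) map of class IDs
label_map = np.argmax(pixel_logits, axis=0)

print(image_label)      # a single class ID
print(label_map.shape)  # (4, 6)
```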

Types of Segmentation

  1. Semantic Segmentation: Assign a class label to every pixel (all cars share the same "car" label)
  2. Instance Segmentation: Separate individual objects of countable classes (car #1, car #2, etc.)
  3. Panoptic Segmentation: Combine both, giving every pixel a class label and, for countable objects, an instance ID
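The three label formats can be sketched with small NumPy arrays. The tiny scene is hypothetical. The encoding panoptic_id = class_id * 1000 + instance_id is one common convention, but it is an assumption here and not a universal standard. Note that the semantic map alone cannot tell the two cars apart.

```python
import numpy as np

# Hypothetical 2x4 scene: class 0 = road ("stuff"), class 1 = car ("thing")
semantic = np.array([[0, 1, 1, 0],
                     [0, 1, 0, 1]])

# Instance IDs separate countable objects; 0 means "no instance" (stuff)
instance = np.array([[0, 1, 1, 0],
                     [0, 1, 0, 2]])

# The car pixels all share class 1 but belong to two distinct instances
print(len(np.unique(instance[semantic == 1])))  # 2

# Panoptic view, under the assumed encoding class_id * 1000 + instance_id
panoptic = semantic * 1000 + instance
print(sorted(set(panoptic.ravel().tolist())))  # [0, 1001, 1002]

# Decoding recovers both the class and the instance of every pixel
assert np.array_equal(panoptic // 1000, semantic)
assert np.array_equal(panoptic % 1000, instance)
```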
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2

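def visualize_segmentation(image, mask, classes, alpha=0.7):
    """
    Visualize a segmentation mask next to, and overlaid on, the original image.

    image: (H, W, 3) array or (3, H, W) tensor, un-normalized, in [0, 1] or [0, 255]
    mask:  (H, W) array or tensor of integer class IDs
    """
    # Work with NumPy arrays on the CPU
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().permute(1, 2, 0).numpy()

    # Create color map for classes (Set3 has only 12 colors, so with more
    # classes some of them will share a color)
    colors = plt.cm.Set3(np.linspace(0, 1, len(classes)))

    # Convert mask to RGB
    h, w = mask.shape
    mask_rgb = np.zeros((h, w, 3))

    for class_id, color in enumerate(colors):
        mask_rgb[mask == class_id] = color[:3]

    # Scale [0, 1] float images to uint8 [0, 255] (clip guards against overshoots)
    if image.max() <= 1.0:
        image = (np.clip(image, 0, 1) * 255).astype(np.uint8)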

    # Blend images
    blended = cv2.addWeighted(
        image.astype(np.uint8), 
        1 - alpha, 
        (mask_rgb * 255).astype(np.uint8), 
        alpha, 
        0
    )

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(mask_rgb)
    plt.title('Segmentation Mask')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(blended)
    plt.title('Overlay')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage
# visualize_segmentation(image, mask, ['background', 'person', 'car', 'road'])
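The per-class loop in visualize_segmentation can also be written as a single palette lookup. Indexing an (N, 3) color table with an (H, W) array of class IDs returns an (H, W, 3) image. The sketch below uses a hypothetical colorize_mask helper and a hypothetical three-color palette, and it paints ignored pixels (255) black.

```python
import numpy as np

def colorize_mask(mask, palette, ignore_index=255, ignore_color=(0, 0, 0)):
    """Map an (H, W) integer label map to an (H, W, 3) uint8 RGB image."""
    palette = np.asarray(palette, dtype=np.uint8)
    mask = np.asarray(mask)
    rgb = np.zeros(mask.shape + (3,), dtype=np.uint8)
    valid = mask != ignore_index
    # Fancy indexing: each valid pixel takes the palette row of its class ID
    rgb[valid] = palette[mask[valid]]
    rgb[~valid] = ignore_color
    return rgb

palette = [(0, 0, 0), (220, 20, 60), (0, 0, 142)]  # hypothetical 3-class palette
mask = np.array([[0, 1], [2, 255]])
print(colorize_mask(mask, palette)[0, 1].tolist())  # [220, 20, 60]
```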

Dataset Preparation

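Pascal VOC and Cityscapes are the standard benchmarks for semantic segmentation. To keep the loader simple, the dataset class below does not parse either dataset's native layout. Instead, it assumes a simplified, Cityscapes-style structure. RGB images live in images/<split>/, single-channel PNG masks of class IDs live in masks/<split>/, and the value 255 marks pixels to ignore. Native Cityscapes annotations use label IDs that must first be converted to its 19 training IDs.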

Python
from torch.utils.data import Dataset, DataLoader
import os
import json

class CityscapesDataset(Dataset):
    """
    Custom dataset class for Cityscapes-style segmentation data
    """
    def __init__(self, root_dir, split='train', transforms=None):
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms

        # Define class mapping
        self.classes = [
            'background', 'road', 'sidewalk', 'building', 'wall', 'fence',
            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
            'motorcycle', 'bicycle'
        ]

        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.num_classes = len(self.classes)

        # Load image and mask file paths
        self.images_dir = os.path.join(root_dir, 'images', split)
        self.masks_dir = os.path.join(root_dir, 'masks', split)

        self.image_files = sorted([f for f in os.listdir(self.images_dir) 
                                  if f.endswith(('.jpg', '.png'))])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        # Load mask
        mask_name = img_name.replace('.jpg', '.png')  # Assuming masks are PNG
        mask_path = os.path.join(self.masks_dir, mask_name)
        mask = Image.open(mask_path)
        mask = np.array(mask, dtype=np.int64)

        if self.transforms:
            image, mask = self.transforms(image, mask)

        return image, mask

class SegmentationTransforms:
    """
    Custom transforms for segmentation that apply same transforms to both image and mask
    """
    def __init__(self, image_size=(256, 256), train=True):
        self.image_size = image_size
        self.train = train

    def __call__(self, image, mask):
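        # Resize (PIL sizes are (width, height)): bilinear for the image,
        # nearest-neighbor for the mask so no invalid in-between class IDs appear
        image = image.resize(self.image_size, Image.BILINEAR)
        # PIL cannot build an image from an int64 array, so convert to uint8 first
        mask = Image.fromarray(mask.astype(np.uint8)).resize(self.image_size, Image.NEAREST)
        mask = np.array(mask, dtype=np.int64)

        if self.train:
            # Random horizontal flip, applied identically to image and mask
            if np.random.rand() > 0.5:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                # .copy() removes the negative stride that torch.from_numpy rejects
                mask = np.fliplr(mask).copy()

            # Placeholder: random crops, scaling, and color jitter would go here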

        # Convert to tensor
        image = transforms.ToTensor()(image)
        image = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )(image)

        mask = torch.from_numpy(mask).long()

        return image, mask

# Create datasets
# train_dataset = CityscapesDataset(
#     root_dir='path/to/cityscapes',
#     split='train',
#     transforms=SegmentationTransforms(train=True)
# )

# val_dataset = CityscapesDataset(
#     root_dir='path/to/cityscapes',
#     split='val',
#     transforms=SegmentationTransforms(train=False)
# )

# Data loaders
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
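SegmentationTransforms leaves random cropping as a placeholder. For any geometric augmentation in segmentation, the image and the mask must receive exactly the same transform. The minimal NumPy sketch below assumes an (H, W, C) image array and an (H, W) mask. paired_random_crop and paired_hflip are hypothetical helpers and are not part of torchvision.

```python
import numpy as np

def paired_random_crop(image, mask, crop_h, crop_w, rng):
    """Cut the same random window out of an (H, W, C) image and its (H, W) mask."""
    h, w = mask.shape
    if image.shape[:2] != (h, w):
        raise ValueError("image and mask must have the same height and width")
    if crop_h > h or crop_w > w:
        raise ValueError("crop is larger than the input")
    top = int(rng.integers(0, h - crop_h + 1))
    left = int(rng.integers(0, w - crop_w + 1))
    return (image[top:top + crop_h, left:left + crop_w],
            mask[top:top + crop_h, left:left + crop_w])

def paired_hflip(image, mask):
    """Mirror both arrays left-right; .copy() avoids negative strides."""
    return image[:, ::-1].copy(), mask[:, ::-1].copy()

# Usage: encode each pixel's position in both arrays so alignment is easy to check
rng = np.random.default_rng(0)
mask = np.arange(6 * 8).reshape(6, 8)
image = np.stack([mask] * 3, axis=-1)
img_crop, mask_crop = paired_random_crop(image, mask, 4, 5, rng)
img_flip, mask_flip = paired_hflip(img_crop, mask_crop)
```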

Segmentation Architectures

Fully Convolutional Network (FCN)

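FCN replaces the fully connected layers of a classification network with convolutional layers. As a result, it accepts inputs of arbitrary size and outputs a spatial map of class scores instead of a single vector. The version below is the simplest variant, often called FCN-32s: it upsamples the coarse 1/32-resolution score map back to the input size in a single step.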

Python
class FCN(nn.Module):
    """
    Fully Convolutional Network for semantic segmentation
    """
    def __init__(self, num_classes=21, pretrained=True):
        super(FCN, self).__init__()

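        # Use pre-trained VGG16 as backbone (torchvision >= 0.13 weights API)
        weights = torchvision.models.VGG16_Weights.DEFAULT if pretrained else None
        vgg = torchvision.models.vgg16(weights=weights)

        # Extract features (all layers except classifier); output stride is 32
        self.features = vgg.features

        # Convolutional replacement for VGG's classifier (randomly initialized here).
        # padding=3 stops the 7x7 conv from shrinking the 1/32-resolution map
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 4096, kernel_size=7, padding=3),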
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.5),
            nn.Conv2d(4096, 4096, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.5),
            nn.Conv2d(4096, num_classes, kernel_size=1)
        )

        # Transposed convolution for upsampling
        self.upsample = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=64, stride=32, padding=16
        )

    def forward(self, x):
        input_size = x.size()[2:]

        # Feature extraction
        x = self.features(x)

        # Classification
        x = self.classifier(x)

        # Upsample to original size
        x = self.upsample(x)

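        # The 32x upsampling restores the input size exactly when H and W are
        # multiples of 32; otherwise resize the logits to match the input
        if x.shape[-2:] != input_size:
            x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)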

        return x

fcn_model = FCN(num_classes=20)
print(f"FCN parameters: {sum(p.numel() for p in fcn_model.parameters()):,}")
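The layer sizes in the FCN code follow from the standard convolution output-size formulas, which the plain-Python sketch below reproduces (the function names are illustrative). VGG16's five 2x2 max-pools shrink a 256x256 input to 8x8. A transposed convolution with kernel 64, stride 32, and padding 16 then multiplies each spatial dimension by exactly 32. The sketch also shows why the 7x7 convolution needs padding=3: without padding, the 8x8 map shrinks to 2x2 and the upsampled output is only 64x64.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a Conv2d along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length of a ConvTranspose2d along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

feature_size = 256 // 32  # five 2x2 max-pools in VGG16: 256 -> 8

print(conv_transpose_out(conv_out(feature_size, 7, padding=3), 64, 32, 16))  # 256
print(conv_transpose_out(conv_out(feature_size, 7, padding=0), 64, 32, 16))  # 64
```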

U-Net Architecture

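U-Net pairs a contracting encoder with a symmetric expanding decoder. At each scale, a skip connection concatenates the encoder's high-resolution, low-level features with the upsampled, high-level decoder features, which helps recover sharp object boundaries. The implementation below uses padded convolutions, so the output has the same spatial size as the input. The original U-Net used unpadded convolutions and produced a smaller output.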

Python
class DoubleConv(nn.Module):
    """Double convolution block used in U-Net"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Pad x1 to match x2 size if needed
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        # Concatenate along channel dimension
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet(nn.Module):
    """
    U-Net architecture for semantic segmentation
    """
    def __init__(self, n_channels=3, n_classes=21, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)

        return logits

unet_model = UNet(n_channels=3, n_classes=20, bilinear=False)
print(f"U-Net parameters: {sum(p.numel() for p in unet_model.parameters()):,}")
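The padding step in Up.forward only matters when the input height or width is not divisible by 16, which is the product of the four 2x poolings. The plain-Python sketch below uses hypothetical helper names. For one spatial dimension, it traces the feature-map sizes and the (before, after) padding that Up.forward would apply at each decoder stage.

```python
def encoder_sizes(size, depth=4):
    """Spatial size after the input block and after each MaxPool2d(2) (which floors)."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)
    return sizes

def decoder_pads(size, depth=4):
    """(before, after) padding applied by each Up block, deepest first."""
    sizes = encoder_sizes(size, depth)
    current = sizes[-1]
    pads = []
    for skip in reversed(sizes[:-1]):
        diff = skip - current * 2          # skip size minus upsampled size
        pads.append((diff // 2, diff - diff // 2))
        current = skip                     # after padding + concat, sizes match
    return pads

print(encoder_sizes(250))  # [250, 125, 62, 31, 15]
print(decoder_pads(250))   # [(0, 1), (0, 0), (0, 1), (0, 0)]
print(decoder_pads(256))   # [(0, 0), (0, 0), (0, 0), (0, 0)]
```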

DeepLab with Atrous Convolution

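DeepLab uses atrous (dilated) convolutions, which space out the kernel taps to enlarge the receptive field without adding parameters or reducing resolution. Its Atrous Spatial Pyramid Pooling (ASPP) module captures context at multiple scales. It does this by running several such convolutions with different dilation rates in parallel, alongside an image-level pooling branch.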
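Two pieces of arithmetic explain the ASPP settings. First, a k x k kernel with dilation d covers k + (k - 1)(d - 1) pixels per side. A 3x3 kernel at rates 6, 12, and 18 therefore spans 13, 25, and 37 pixels while still having only 9 weights. Second, setting the padding equal to the dilation, as ASPPConv does, makes a 3x3 dilated convolution preserve the spatial size. The plain-Python 1-D sketch below illustrates both points (the function names are illustrative).

```python
def effective_kernel(kernel, dilation):
    """Span of a dilated kernel along one dimension."""
    return kernel + (kernel - 1) * (dilation - 1)

def dilated_conv1d(signal, weights, dilation):
    """1-D dilated convolution (cross-correlation, as in PyTorch) with zero
    padding of dilation * (k - 1) // 2 per side, which keeps the length
    unchanged for odd kernel sizes."""
    k = len(weights)
    pad = dilation * (k - 1) // 2
    padded = [0.0] * pad + list(signal) + [0.0] * pad
    return [sum(weights[j] * padded[i + j * dilation] for j in range(k))
            for i in range(len(signal))]

print([effective_kernel(3, d) for d in (6, 12, 18)])  # [13, 25, 37]

# An impulse at index 3 shows up in outputs spaced `dilation` apart
out = dilated_conv1d([0, 0, 0, 1, 0, 0, 0, 0], [1, 1, 1], dilation=2)
print(len(out), [i for i, v in enumerate(out) if v != 0])  # 8 [1, 3, 5]
```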

Python
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)

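class ASPPPooling(nn.Sequential):
    # Image-level branch: global average pool -> 1x1 conv -> upsample back.
    # Its BatchNorm sees a single value per channel per image, so training
    # requires a batch size greater than 1.
    def __init__(self, in_channels, out_channels):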
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling module
    """
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)

class DeepLabV3(nn.Module):
    """
    DeepLabV3 with ResNet backbone
    """
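    def __init__(self, num_classes=21, output_stride=16):
        super(DeepLabV3, self).__init__()

        # A plain ResNet has output stride 32. Replacing the stride of the last
        # stage (OS 16) or the last two stages (OS 8) with dilation keeps the
        # feature map larger, and the ASPP rates are chosen to match
        if output_stride == 16:
            replace_stride_with_dilation = [False, False, True]
            atrous_rates = [6, 12, 18]
        elif output_stride == 8:
            replace_stride_with_dilation = [False, True, True]
            atrous_rates = [12, 24, 36]
        else:
            raise ValueError('output_stride must be 8 or 16')

        # Use ImageNet-pretrained ResNet50 as backbone (torchvision >= 0.13 API)
        resnet = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.DEFAULT,
            replace_stride_with_dilation=replace_stride_with_dilation,
        )

        # Keep everything except the global pooling and fully connected layers
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # ASPP module
        self.aspp = ASPP(2048, atrous_rates)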

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, 1)
        )

    def forward(self, x):
        input_size = x.shape[-2:]

        # Feature extraction
        features = self.backbone(x)

        # ASPP
        x = self.aspp(features)

        # Classification
        x = self.classifier(x)

        # Upsample to input size
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

        return x

deeplab_model = DeepLabV3(num_classes=20)
print(f"DeepLab parameters: {sum(p.numel() for p in deeplab_model.parameters()):,}")

Training Semantic Segmentation Models

Loss Functions for Segmentation

Python
class SegmentationLoss(nn.Module):
    """
    Combined loss function for semantic segmentation
    """
    def __init__(self, num_classes, ignore_index=255, weight=None):
        super(SegmentationLoss, self).__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index

        # Cross entropy loss
        self.ce_loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)

        # Dice loss directly optimizes region overlap, so it is less dominated
        # by large classes than per-pixel cross entropy
        self.dice_loss = DiceLoss(num_classes, ignore_index)

    def forward(self, predictions, targets):
        ce_loss = self.ce_loss(predictions, targets)
        dice_loss = self.dice_loss(predictions, targets)

        # Combine losses
        total_loss = ce_loss + dice_loss

        return total_loss, ce_loss, dice_loss

class DiceLoss(nn.Module):
    """
    Dice loss for segmentation
    """
    def __init__(self, num_classes, ignore_index=255, smooth=1e-5):
        super(DiceLoss, self).__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.smooth = smooth

    def forward(self, predictions, targets):
        # Convert predictions to probabilities
        probs = F.softmax(predictions, dim=1)

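        # Mask of pixels that count toward the loss
        if self.ignore_index is not None:
            valid = targets != self.ignore_index
        else:
            valid = torch.ones_like(targets, dtype=torch.bool)

        # F.one_hot fails on values >= num_classes (such as 255), so ignored
        # pixels are temporarily relabeled as class 0; they are zeroed out below
        safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
        targets_one_hot = F.one_hot(safe_targets, self.num_classes).permute(0, 3, 1, 2).float()

        # Exclude ignored pixels from both predictions and targets
        mask = valid.unsqueeze(1).float()
        probs = probs * mask
        targets_one_hot = targets_one_hot * mask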

        # Calculate Dice coefficient for each class
        dice_scores = []
        for class_idx in range(self.num_classes):
            pred_class = probs[:, class_idx]
            target_class = targets_one_hot[:, class_idx]

            intersection = torch.sum(pred_class * target_class)
            union = torch.sum(pred_class) + torch.sum(target_class)

            dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
            dice_scores.append(dice)

        # Average Dice loss
        dice_loss = 1.0 - torch.mean(torch.stack(dice_scores))

        return dice_loss

class FocalLoss(nn.Module):
    """
    Focal loss for handling class imbalance
    """
    def __init__(self, alpha=1, gamma=2, reduction='mean', ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

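    def forward(self, inputs, targets):
        # Per-pixel cross entropy; ignored pixels get a loss of 0
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)  # predicted probability of the true class
        # (1 - pt)^gamma down-weights easy, well-classified pixels
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss

        if self.reduction == 'mean':
            # Average only over non-ignored pixels
            valid = targets != self.ignore_index
            return focal_loss[valid].mean()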
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
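SegmentationLoss accepts an optional per-class weight for cross entropy, which is one way to counter class imbalance: road and sky cover far more pixels than poles or traffic lights. One common recipe is median frequency balancing, where weight_c = median_freq / freq_c and freq_c is class c's pixel count divided by the total pixel count of the images that contain c. The minimal NumPy sketch below assumes masks are integer arrays with 255 marking ignored pixels. median_frequency_weights is a hypothetical helper.

```python
import numpy as np

def median_frequency_weights(masks, num_classes, ignore_index=255):
    """Per-class weights median(freq) / freq(c). Here freq(c) is class c's pixel
    count divided by the valid pixels of the images that contain class c."""
    class_pixels = np.zeros(num_classes, dtype=np.float64)
    present_pixels = np.zeros(num_classes, dtype=np.float64)
    for mask in masks:
        labels = np.asarray(mask).ravel().astype(np.int64)
        labels = labels[labels != ignore_index]
        counts = np.bincount(labels, minlength=num_classes)[:num_classes]
        class_pixels += counts
        present_pixels[counts > 0] += labels.size
    # Classes that never appear keep a weight of 0 in this sketch
    weights = np.zeros(num_classes, dtype=np.float64)
    seen = present_pixels > 0
    freq = class_pixels[seen] / present_pixels[seen]
    weights[seen] = np.median(freq) / freq
    return weights.astype(np.float32)

masks = [np.array([[0, 0, 0, 1]]), np.array([[0, 0, 2, 2]])]
print(median_frequency_weights(masks, num_classes=3))
```

The result could be passed as SegmentationLoss(num_classes, weight=torch.from_numpy(weights)). The loss module then has to live on the same device as the model outputs. Very rare classes can receive large weights, so you may want to cap them.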

Training Loop

Python
def train_segmentation_model(model, train_loader, val_loader, num_epochs=50, num_classes=20):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Loss function and optimizer (.to(device) also moves any class weights)
    criterion = SegmentationLoss(num_classes=num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    # Halve the learning rate when validation loss stops improving for 5 epochs
    # (the deprecated `verbose` flag is omitted)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        # Training phase
        model.train()
        running_loss = 0.0
        running_ce_loss = 0.0
        running_dice_loss = 0.0

        for i, (images, masks) in enumerate(train_loader):
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            total_loss, ce_loss, dice_loss = criterion(outputs, masks)

            # Backward pass
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            running_ce_loss += ce_loss.item()
            running_dice_loss += dice_loss.item()

            if (i + 1) % 10 == 0:
                print(f'Batch [{i+1}/{len(train_loader)}], '
                      f'Loss: {total_loss.item():.4f}')

        # Calculate epoch losses
        epoch_loss = running_loss / len(train_loader)
        epoch_ce_loss = running_ce_loss / len(train_loader)
        epoch_dice_loss = running_dice_loss / len(train_loader)

        train_losses.append(epoch_loss)

        # Validation phase
        val_loss = validate_segmentation_model(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        print(f'Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dice: {epoch_dice_loss:.4f})')
        print(f'Val Loss: {val_loss:.4f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_segmentation_model.pth')
            print('Saved best model!')

        scheduler.step(val_loss)
        print()

    return train_losses, val_losses

def validate_segmentation_model(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            total_loss, _, _ = criterion(outputs, masks)

            running_loss += total_loss.item()

    return running_loss / len(val_loader)

# Train the model
# train_losses, val_losses = train_segmentation_model(unet_model, train_loader, val_loader)
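The loop above lowers the learning rate with ReduceLROnPlateau. Many segmentation codebases, DeepLab among them, instead use a "poly" schedule that decays the rate every iteration: lr = base_lr * (1 - iter / max_iter) ** power, typically with power 0.9. Here is a plain-Python sketch with a hypothetical poly_lr helper. In PyTorch, the same curve can be attached to an optimizer through torch.optim.lr_scheduler.LambdaLR, stepping the scheduler once per batch rather than once per epoch.

```python
def poly_lr(base_lr, iteration, max_iterations, power=0.9):
    """'Poly' schedule: decays the learning rate from base_lr to 0 over max_iterations."""
    if not 0 <= iteration <= max_iterations:
        raise ValueError("iteration must be between 0 and max_iterations")
    return base_lr * (1 - iteration / max_iterations) ** power

max_iter = 1000
schedule = [poly_lr(1e-2, it, max_iter) for it in (0, 250, 500, 750, 1000)]
print([f"{lr:.5f}" for lr in schedule])

# A PyTorch equivalent (sketch), stepped once per batch; max(0.0, ...) guards
# against a negative base if training runs past max_iter:
# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer, lambda it: max(0.0, 1 - it / max_iter) ** 0.9)
```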

Evaluation Metrics

IoU and mIoU Calculation

Python
def calculate_iou(pred, target, num_classes, ignore_index=255):
    """
    Calculate Intersection over Union (IoU) for each class
    """
    ious = []
    pred = pred.flatten()
    target = target.flatten()

    # Remove ignored pixels
    if ignore_index is not None:
        mask = target != ignore_index
        pred = pred[mask]
        target = target[mask]

    for class_id in range(num_classes):
        pred_inds = pred == class_id
        target_inds = target == class_id

        if target_inds.sum() == 0:  # Class absent from this image: skip it (false positives for it go unpenalized)
            ious.append(float('nan'))
            continue

        intersection = (pred_inds & target_inds).sum().float()
        union = (pred_inds | target_inds).sum().float()

        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append((intersection / union).item())

    return ious

def evaluate_segmentation(model, test_loader, num_classes, device, class_names=None, ignore_index=255):
    """
    Evaluate a segmentation model: per-image IoU averaged over the test set,
    plus pixel accuracy over non-ignored pixels
    """
    model.eval()
    all_ious = []
    correct_total = 0
    valid_total = 0

    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)

            # Calculate metrics for each image in batch
            for pred, target in zip(predictions.cpu(), masks.cpu()):
                # IoU calculation
                ious = calculate_iou(pred, target, num_classes, ignore_index)
                all_ious.append(ious)

                # Pixel accuracy, counting only non-ignored pixels
                valid = target != ignore_index
                correct_total += ((pred == target) & valid).sum().item()
                valid_total += valid.sum().item()

    # Average each class's IoU over the images that contain it, then over classes
    all_ious = np.array(all_ious)
    mean_ious = np.nanmean(all_ious, axis=0)
    overall_miou = np.nanmean(mean_ious)

    # Calculate pixel accuracy
    pixel_accuracy = correct_total / max(valid_total, 1)

    # Print results
    print(f'Overall mIoU: {overall_miou:.4f}')
    print(f'Pixel Accuracy: {pixel_accuracy:.4f}')
    print('\nPer-class IoU:')

    for i, iou in enumerate(mean_ious):
        class_name = class_names[i] if class_names else f'Class {i}'
        print(f'{class_name}: {iou:.4f}')

    return overall_miou, pixel_accuracy, mean_ious

# Example evaluation
# miou, pixel_acc, class_ious = evaluate_segmentation(
#     unet_model, test_loader, num_classes=20, device=device, class_names=dataset.classes
# )
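evaluate_segmentation averages IoU per image. Benchmark protocols such as Pascal VOC and Cityscapes instead accumulate intersections and unions over the entire dataset before dividing. That approach weights every pixel equally and also penalizes false positives in images where a class is absent. The minimal NumPy sketch below computes that dataset-level mIoU from a running confusion matrix. The class name ConfusionMatrixMIoU is hypothetical. Inside the evaluation loop, you would call metric.update(predictions.cpu().numpy(), masks.cpu().numpy()) for each batch.

```python
import numpy as np

class ConfusionMatrixMIoU:
    """Accumulates a (true class, predicted class) pixel-count matrix over a dataset.
    Assumes all non-ignored labels and predictions lie in [0, num_classes)."""

    def __init__(self, num_classes, ignore_index=255):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    def update(self, pred, target):
        """pred, target: integer label arrays of the same shape (any batch layout)."""
        pred = np.asarray(pred).ravel().astype(np.int64)
        target = np.asarray(target).ravel().astype(np.int64)
        valid = target != self.ignore_index
        pairs = self.num_classes * target[valid] + pred[valid]
        counts = np.bincount(pairs, minlength=self.num_classes ** 2)
        self.matrix += counts.reshape(self.num_classes, self.num_classes)

    def compute(self):
        """Return (per-class IoU, mIoU); classes never seen or predicted get NaN."""
        tp = np.diag(self.matrix).astype(np.float64)
        fp = self.matrix.sum(axis=0) - tp   # predicted as c, actually another class
        fn = self.matrix.sum(axis=1) - tp   # actually c, predicted as another class
        union = tp + fp + fn
        with np.errstate(divide='ignore', invalid='ignore'):
            iou = np.where(union > 0, tp / union, np.nan)
        return iou, float(np.nanmean(iou))

# Usage on a tiny example (255 = ignored pixel)
metric = ConfusionMatrixMIoU(num_classes=2)
metric.update(np.array([[0, 1], [1, 1]]), np.array([[0, 1], [0, 255]]))
per_class_iou, miou = metric.compute()
print(per_class_iou, miou)  # [0.5 0.5] 0.5
```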

Confusion Matrix for Segmentation

Python
import seaborn as sns

def plot_segmentation_confusion_matrix(model, test_loader, class_names, device, normalize=True, ignore_index=255):
    """
    Plot a pixel-level confusion matrix for segmentation results
    """
    model.eval()
    num_classes = len(class_names)
    # Accumulate counts directly: collecting millions of per-pixel labels in
    # Python lists would be very slow and memory-hungry
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.cpu().numpy()

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()

            # Remove ignored pixels, then count (true, predicted) pairs
            valid = masks != ignore_index
            pairs = num_classes * masks[valid].astype(np.int64) + predictions[valid]
            cm += np.bincount(pairs, minlength=num_classes ** 2).reshape(num_classes, num_classes)

    if normalize:
        # Row-normalize (per-class recall); rows with no pixels stay at 0
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = cm.astype('float') / np.maximum(row_sums, 1)

    # Plot
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
                xticklabels=class_names, yticklabels=class_names,
                cmap='Blues')
    plt.title('Segmentation Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

    return cm

# Example usage
# cm = plot_segmentation_confusion_matrix(unet_model, test_loader, dataset.classes, device)

Advanced Techniques

Multi-Scale Training and Testing

Python
class MultiScaleTraining:
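    """
    Multi-scale training augmentation: randomly rescale, then crop or pad back to
    base_size x base_size. Takes a PIL image and a NumPy mask and returns the same
    types, so convert to tensors afterwards. Resizing to a square distorts the
    aspect ratio of non-square inputs.
    """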
    def __init__(self, base_size=512, scales=[0.5, 0.75, 1.0, 1.25, 1.5]):
        self.base_size = base_size
        self.scales = scales

    def __call__(self, image, mask):
        # Randomly select a scale
        scale = np.random.choice(self.scales)
        new_size = int(self.base_size * scale)

        # Resize image and mask (nearest-neighbor for the mask; PIL needs uint8, not int64)
        image = image.resize((new_size, new_size), Image.BILINEAR)
        mask = Image.fromarray(mask.astype(np.uint8)).resize((new_size, new_size), Image.NEAREST)
        mask = np.array(mask, dtype=np.int64)

        # Random crop to base_size if larger
        if new_size > self.base_size:
            h, w = mask.shape
            start_h = np.random.randint(0, h - self.base_size + 1)
            start_w = np.random.randint(0, w - self.base_size + 1)

            image = image.crop((start_w, start_h, 
                              start_w + self.base_size, 
                              start_h + self.base_size))
            mask = mask[start_h:start_h + self.base_size,
                       start_w:start_w + self.base_size]

        # Pad if smaller
        elif new_size < self.base_size:
            pad = self.base_size - new_size
            pad_top = pad // 2
            pad_bottom = pad - pad_top

            image = np.array(image)
            image = np.pad(image, ((pad_top, pad_bottom), (pad_top, pad_bottom), (0, 0)), 
                          mode='constant')
            image = Image.fromarray(image)

            mask = np.pad(mask, ((pad_top, pad_bottom), (pad_top, pad_bottom)), 
                         mode='constant', constant_values=255)

        return image, mask

def test_time_augmentation(model, image, device, scales=[0.75, 1.0, 1.25]):
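    """
    Test-time augmentation: average logits over several scales and horizontal flips.
    image is a single normalized (C, H, W) tensor; returns (1, num_classes, H, W) logits.
    """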
    model.eval()
    predictions = []

    with torch.no_grad():
        for scale in scales:
            # Scale image
            h, w = image.shape[-2:]
            new_h, new_w = int(h * scale), int(w * scale)

            scaled_image = F.interpolate(
                image.unsqueeze(0), 
                size=(new_h, new_w), 
                mode='bilinear', 
                align_corners=False
            )

            # Forward pass
            output = model(scaled_image.to(device))

            # Scale back to original size
            output = F.interpolate(
                output, 
                size=(h, w), 
                mode='bilinear', 
                align_corners=False
            )

            predictions.append(output)

            # Also test horizontal flip
            flipped_image = torch.flip(scaled_image, dims=[3])
            flipped_output = model(flipped_image.to(device))
            flipped_output = torch.flip(flipped_output, dims=[3])
            flipped_output = F.interpolate(
                flipped_output, 
                size=(h, w), 
                mode='bilinear', 
                align_corners=False
            )
            predictions.append(flipped_output)

    # Average all predictions
    final_prediction = torch.mean(torch.stack(predictions), dim=0)
    return final_prediction

# Example usage
# multiscale_transform = MultiScaleTraining(base_size=256, scales=[0.5, 0.75, 1.0, 1.25])
# tta_prediction = test_time_augmentation(model, test_image, device)

Real-World Applications

Medical Image Segmentation

Python
class MedicalSegmentationModel(nn.Module):
    """
    Standard U-Net configured for single-channel (grayscale) input and
    two output classes (background / foreground)
    """
    def __init__(self, in_channels=1, out_channels=2):  # Grayscale input, binary output
        super(MedicalSegmentationModel, self).__init__()
        self.unet = UNet(n_channels=in_channels, n_classes=out_channels)

    def forward(self, x):
        return self.unet(x)

def preprocess_medical_image(image_path, target_size=(256, 256)):
    """
    Preprocess medical images (DICOM, etc.)
    """
    # For DICOM files, you would use pydicom
    # import pydicom
    # dicom = pydicom.dcmread(image_path)
    # image = dicom.pixel_array

    # For now, assume regular image files
    image = Image.open(image_path).convert('L')  # Grayscale
    image = image.resize(target_size)

    # Normalize to [0, 1]
    image = np.array(image, dtype=np.float32) / 255.0

    # Apply medical-specific normalization
    # (e.g., HU windowing for CT scans)

    return transforms.ToTensor()(image)

def calculate_dice_coefficient(pred, target):
    """
    Calculate the Dice coefficient between two binary (0/1) mask tensors
    """
    smooth = 1e-5
    pred_flat = pred.flatten().float()
    target_flat = target.flatten().float()

    intersection = (pred_flat * target_flat).sum()
    union = pred_flat.sum() + target_flat.sum()

    dice = (2.0 * intersection + smooth) / (union + smooth)
    return dice
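The placeholder comment in preprocess_medical_image mentions HU windowing. For CT, stored pixel values are first converted to Hounsfield units (HU) using the DICOM RescaleSlope and RescaleIntercept attributes: HU = pixel * slope + intercept. A window (center, width) then selects the intensity range of interest and maps it to [0, 1]. The minimal NumPy sketch below uses hypothetical helper names. Its soft-tissue defaults of center 40 HU and width 400 HU are a common choice, not a requirement.

```python
import numpy as np

def to_hounsfield(pixel_array, slope, intercept):
    """Convert stored CT values to Hounsfield units."""
    return np.asarray(pixel_array, dtype=np.float32) * slope + intercept

def window_ct(hu, center=40.0, width=400.0):
    """Clip HU values to [center - width/2, center + width/2] and scale to [0, 1]."""
    hu = np.asarray(hu, dtype=np.float32)
    low, high = center - width / 2, center + width / 2
    return (np.clip(hu, low, high) - low) / (high - low)

raw = np.array([0, 864, 1064, 1264, 4000])   # hypothetical stored values
hu = to_hounsfield(raw, slope=1.0, intercept=-1024.0)
print(hu.tolist())             # [-1024.0, -160.0, 40.0, 240.0, 2976.0]
print(window_ct(hu).tolist())  # [0.0, 0.0, 0.5, 1.0, 1.0]
```

With pydicom, the inputs would come from dicom.pixel_array, dicom.RescaleSlope, and dicom.RescaleIntercept.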

Autonomous Driving Segmentation

Python
class AutonomousDrivingSegmentation:
    """
    Frame-by-frame segmentation of driving video (this simple loop is not
    optimized for real-time throughput)
    """
    def __init__(self, model_path, device):
        self.device = device
        self.model = self.load_model(model_path)

        # Cityscapes classes relevant for driving
        self.driving_classes = {
            0: 'road',
            1: 'sidewalk', 
            2: 'building',
            3: 'wall',
            4: 'fence',
            5: 'pole',
            6: 'traffic_light',
            7: 'traffic_sign',
            8: 'vegetation',
            9: 'terrain',
            10: 'sky',
            11: 'person',
            12: 'rider',
            13: 'car',
            14: 'truck',
            15: 'bus',
            16: 'train',
            17: 'motorcycle',
            18: 'bicycle'
        }

        # Define colors for visualization
        self.colors = np.random.randint(0, 255, (len(self.driving_classes), 3))

    def load_model(self, model_path):
        model = DeepLabV3(num_classes=19)
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.to(self.device)
        model.eval()
        return model

    def preprocess_frame(self, frame):
        """Preprocess video frame for segmentation"""
        # Resize and normalize
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_pil = Image.fromarray(frame_rgb)

        transform = transforms.Compose([
            transforms.Resize((256, 512)),  # (height, width): the 1:2 aspect ratio of Cityscapes' 2048x1024 frames
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])

        return transform(frame_pil).unsqueeze(0)

    def segment_frame(self, frame):
        """Segment a single frame"""
        with torch.no_grad():
            input_tensor = self.preprocess_frame(frame).to(self.device)
            output = self.model(input_tensor)
            prediction = torch.argmax(output, dim=1).squeeze().cpu().numpy()

        return prediction

    def visualize_segmentation(self, frame, segmentation):
        """Overlay segmentation on original frame"""
        # Resize segmentation to match frame size
        h, w = frame.shape[:2]
        segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (w, h), 
                                        interpolation=cv2.INTER_NEAREST)

        # Create colored segmentation
        colored_seg = np.zeros((h, w, 3), dtype=np.uint8)
        for class_id in range(len(self.driving_classes)):
            mask = segmentation_resized == class_id
            colored_seg[mask] = self.colors[class_id]

        # Blend with original frame
        alpha = 0.6
        blended = cv2.addWeighted(frame, 1-alpha, colored_seg, alpha, 0)

        return blended

    def process_video(self, video_path, output_path):
        """Process entire video for segmentation"""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Could not open video: {video_path}")

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Setup video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Segment frame
            segmentation = self.segment_frame(frame)

            # Visualize
            result_frame = self.visualize_segmentation(frame, segmentation)

            # Write frame
            out.write(result_frame)

        cap.release()
        out.release()

# Example usage
# driving_segmenter = AutonomousDrivingSegmentation('cityscapes_model.pth', device)
# driving_segmenter.process_video('driving_video.mp4', 'segmented_output.mp4')
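The per-class loop in `visualize_segmentation` can also be written as a single NumPy lookup: indexing a `(num_classes, 3)` palette with the integer label map returns an `(H, W, 3)` color image in one vectorized step. This is a sketch, assuming labels lie in `0..num_classes-1`; `make_palette` and `colorize_labels` are hypothetical helper names, not OpenCV or PyTorch APIs.

```python
import numpy as np

def make_palette(num_classes, seed=0):
    """Hypothetical helper: fixed random palette, one uint8 color per class (reproducible via seed)."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)

def colorize_labels(label_map, palette):
    """Hypothetical helper: map an (H, W) integer label map to an (H, W, 3) uint8 image.

    Fancy indexing palette[label_map] fetches every pixel's class color at once,
    replacing an explicit loop over classes.
    """
    label_map = np.asarray(label_map)
    if label_map.min() < 0 or label_map.max() >= len(palette):
        raise ValueError("label values must be in [0, num_classes)")
    return palette[label_map]

# Usage: a tiny 2x3 label map with 19 Cityscapes-style classes
palette = make_palette(19)
labels = np.array([[0, 1, 2],
                   [13, 13, 10]])
colored = colorize_labels(labels, palette)
print(colored.shape, colored.dtype)  # (2, 3, 3) uint8
```

Inside the class, the same idea would reduce the coloring loop to `colored_seg = self.colors[segmentation_resized]`, provided `self.colors` is a uint8 array with one row per class.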

Deployment and Optimization

Model Optimization for Production

Python
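def optimize_model_for_deployment(model, example_input):
    """
    Prepare deployment variants of a segmentation model:
    a TorchScript trace, a dynamically quantized CPU copy, and an ONNX file.
    """
    import copy

    model.eval()

    # 1. TorchScript tracing (records the ops executed on example_input)
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_input)

    # 2. Dynamic quantization (CPU only). The default dynamic mapping converts
    #    nn.Linear and recurrent layers; nn.Conv2d is not converted and stays float,
    #    so conv-heavy segmentation nets gain little here. Quantizing convolutions
    #    needs static post-training quantization or quantization-aware training.
    quantized_model = torch.ao.quantization.quantize_dynamic(
        copy.deepcopy(model).cpu(), {nn.Linear}, dtype=torch.qint8
    )

    # 3. ONNX export (for cross-platform runtimes such as ONNX Runtime)
    torch.onnx.export(
        model,
        example_input,
        "segmentation_model.onnx",
        export_params=True,
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )

    return traced_model, quantized_model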

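# Benchmark different optimizations
def benchmark_models(original_model, traced_model, quantized_model, test_input, device,
                     warmup=10, iters=100):
    """
    Compare the average inference latency of the model variants.
    Each variant runs on the device its weights live on (a dynamically
    quantized model is CPU-only), falling back to `device`. GPU and CPU
    timings are not directly comparable.
    """
    import time

    def model_device(model):
        params = list(model.parameters())
        return params[0].device if params else torch.device(device)

    models = {
        'Original': original_model,
        'TorchScript': traced_model,
        'Quantized': quantized_model
    }

    results = {}

    for name, model in models.items():
        model.eval()
        run_device = model_device(model)
        inp = test_input.to(run_device)
        is_cuda = run_device.type == 'cuda'

        with torch.no_grad():
            # Warmup: the first calls pay one-time costs (kernel selection, TorchScript optimization)
            for _ in range(warmup):
                _ = model(inp)

            # Benchmark; CUDA launches are asynchronous, so synchronize before reading the clock
            times = []
            for _ in range(iters):
                if is_cuda:
                    torch.cuda.synchronize()
                start_time = time.perf_counter()
                _ = model(inp)
                if is_cuda:
                    torch.cuda.synchronize()
                times.append(time.perf_counter() - start_time)

        avg_time = np.mean(times) * 1000  # Convert to milliseconds
        results[name] = avg_time

        print(f'{name} on {run_device}: {avg_time:.2f} ms per inference')

    return results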

# Example usage
# example_input = torch.randn(1, 3, 256, 256).to(device)
# traced_model, quantized_model = optimize_model_for_deployment(unet_model, example_input)
# benchmark_results = benchmark_models(unet_model, traced_model, quantized_model, example_input, device)
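To make `dtype=torch.qint8` concrete, here is a simplified NumPy sketch of per-tensor symmetric 8-bit weight quantization: each float weight is stored as an int8 code plus one shared float scale. It is an illustration under simplifying assumptions (scale = max|w| / 127, round-to-nearest), not PyTorch's internal code, whose observers compute the scale slightly differently.

```python
import numpy as np

def quantize_symmetric_int8(w):
    """Per-tensor symmetric quantization: w ~= scale * q, with q in [-127, 127]."""
    w = np.asarray(w, dtype=np.float32)
    max_abs = float(np.abs(w).max())
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid dividing by zero for all-zero tensors
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover approximate float32 weights from int8 codes and the shared scale."""
    return q.astype(np.float32) * np.float32(scale)

rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.05, size=(64, 128)).astype(np.float32)

q, scale = quantize_symmetric_int8(weights)
recovered = dequantize(q, scale)

print(f"float32: {weights.nbytes} bytes, int8: {q.nbytes} bytes")
print(f"max abs error: {np.abs(weights - recovered).max():.6f} (bound: scale/2 = {scale / 2:.6f})")
```

Each weight drops from 4 bytes to 1, so weight storage shrinks by about 4x, and the rounding error per weight is at most half a quantization step (`scale / 2`).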

Key Takeaways

In this final part of our computer vision series, we explored:

  1. Semantic Segmentation Fundamentals: Pixel-level classification and its applications
  2. Advanced Architectures: FCN's skip-connected upsampling, U-Net's encoder-decoder with skip connections, and DeepLab's atrous convolutions
  3. Training Strategies: Custom loss functions, multi-scale training, and evaluation metrics
  4. Real-World Applications: Medical imaging and autonomous driving use cases
  5. Deployment Optimization: Model optimization techniques for production
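The evaluation metrics above usually come down to per-class Intersection over Union (IoU). A compact way to compute it is to accumulate a confusion matrix over the dataset and read IoU off its diagonal. This NumPy sketch assumes integer label maps, predictions in `[0, num_classes)`, and treats `ignore_index=255` as unlabeled pixels (a common convention, assumed here rather than taken from this series).

```python
import numpy as np

def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """(num_classes, num_classes) counts: rows = ground truth, cols = prediction.

    Assumes predictions lie in [0, num_classes); pixels whose target equals
    ignore_index are excluded.
    """
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    valid = target != ignore_index  # drop unlabeled pixels
    idx = num_classes * target[valid].astype(np.int64) + pred[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def iou_from_confusion(cm):
    """Per-class IoU = TP / (TP + FP + FN); classes absent from both pred and target give NaN."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as class c, but wrong
    fn = cm.sum(axis=1) - tp  # truly class c, but missed
    denom = tp + fp + fn
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(denom > 0, tp / denom, np.nan)

target = np.array([[0, 0, 1],
                   [1, 2, 255]])
pred = np.array([[0, 1, 1],
                 [1, 2, 0]])
cm = confusion_matrix(pred, target, num_classes=3)
iou = iou_from_confusion(cm)
print(iou)              # per-class IoU
print(np.nanmean(iou))  # mean IoU over classes that occur
```

Accumulating the confusion matrix across all images before dividing gives dataset-level mIoU, which is not the same as averaging per-image IoU scores.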

Best Practices for Semantic Segmentation:

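  • Data Augmentation: Apply each geometric transform identically to image and mask, and resize masks with nearest-neighbor interpolation
  • Loss Functions: Combine multiple losses (CE + Dice + Focal)
  • Multi-scale Training: Train on randomly rescaled inputs and average predictions over several scales at test time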
  • Skip Connections: Use encoder-decoder architectures for better feature fusion
  • Post-processing: Apply CRF or similar techniques for boundary refinement
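Segmentation-aware augmentation mostly means that every geometric transform applied to an image must be applied identically to its mask, and mask labels must never be blended by interpolation (resize or rotate masks with nearest-neighbor). Below is a minimal NumPy sketch, assuming an `(H, W, C)` image and an `(H, W)` integer mask; `joint_flip_crop` is a hypothetical helper, and only flip and crop are shown because they need no interpolation. Photometric changes such as brightness or color jitter would apply to the image only.

```python
import numpy as np

def joint_flip_crop(image, mask, crop_size, rng, flip_prob=0.5):
    """Apply the same random horizontal flip and random crop to an image and its mask.

    image: (H, W, C) array; mask: (H, W) integer label map; crop_size: (ch, cw).
    """
    if image.shape[:2] != mask.shape:
        raise ValueError("image and mask must have the same height and width")
    h, w = mask.shape
    ch, cw = crop_size
    if ch > h or cw > w:
        raise ValueError("crop size larger than input")

    # One random decision shared by image and mask keeps pixels aligned with labels
    if rng.random() < flip_prob:
        image = image[:, ::-1]
        mask = mask[:, ::-1]

    # Same crop window for both arrays
    top = rng.integers(0, h - ch + 1)
    left = rng.integers(0, w - cw + 1)
    image = image[top:top + ch, left:left + cw]
    mask = mask[top:top + ch, left:left + cw]

    # Contiguous copies (flipped slices are strided views, e.g. before torch.from_numpy)
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)

rng = np.random.default_rng(42)
img = np.arange(4 * 6 * 3).reshape(4, 6, 3)
msk = np.arange(4 * 6).reshape(4, 6)
aug_img, aug_msk = joint_flip_crop(img, msk, (2, 3), rng)
print(aug_img.shape, aug_msk.shape)
```

In this toy input, each image pixel's first channel equals three times its mask value, so checking that relation after augmentation confirms the two arrays stayed aligned.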

Architecture Selection Guidelines:

  • U-Net: A strong default for medical imaging and small datasets
  • DeepLab: Well suited to natural images and large datasets
  • FCN: A simple, solid baseline architecture
  • Custom architectures: Design based on specific domain requirements

Deployment Considerations:

  • Real-time Requirements: Use lightweight architectures (MobileNet backbones)
  • Memory Constraints: Apply quantization and pruning
  • Accuracy vs Speed: Measure accuracy (e.g., mIoU) and latency together, and choose the smallest model that meets both targets
  • Hardware Optimization: Leverage GPUs, TPUs, or dedicated inference engines such as TensorRT, ONNX Runtime, or OpenVINO
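Pruning is often done by magnitude: zero out the fraction of weights with the smallest absolute values. The NumPy sketch below applies unstructured magnitude pruning to a single weight array; `magnitude_prune` is a hypothetical helper, and PyTorch offers a comparable utility, `torch.nn.utils.prune.l1_unstructured`. Zeros only reduce memory if the weights are then stored in a sparse format or the pruning is structured (removing whole channels).

```python
import numpy as np

def magnitude_prune(weights, amount):
    """Zero the `amount` fraction of entries with the smallest |w|.

    Returns (pruned_copy, keep_mask); ties are broken by position.
    """
    if not 0.0 <= amount <= 1.0:
        raise ValueError("amount must be in [0, 1]")
    w = np.asarray(weights)
    k = int(round(amount * w.size))  # number of weights to remove
    mask = np.ones(w.shape, dtype=bool)
    if k > 0:
        # axis=None sorts the flattened |w|; the first k indices are the smallest
        smallest = np.argsort(np.abs(w), axis=None)[:k]
        mask.flat[smallest] = False
    return w * mask, mask

rng = np.random.default_rng(0)
w = rng.normal(size=(16, 16)).astype(np.float32)
pruned, mask = magnitude_prune(w, amount=0.5)
print(f"sparsity: {1 - mask.mean():.2f}")
```

Pruning is typically followed by a short fine-tuning run so the remaining weights can compensate for the removed ones.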

Series Conclusion

Throughout this three-part series, we've journeyed through the core areas of computer vision:

  1. Part 1 - Classification: Understanding what's in an image
  2. Part 2 - Detection: Finding where objects are located
  3. Part 3 - Segmentation: Precisely delineating every pixel

Each level builds upon the previous, increasing in complexity and detail. Modern computer vision applications often combine these approaches: autonomous vehicles use all three for comprehensive scene understanding, and medical systems integrate classification and segmentation for diagnosis and treatment planning.

The Path Forward

Computer vision continues to evolve rapidly:

  • Transformer Architectures: Vision Transformers (ViTs) now compete with CNNs across classification, detection, and segmentation
  • Self-Supervised Learning: Learning from unlabeled data
  • Neural Architecture Search: Searching for architectures automatically instead of designing them by hand
  • Edge AI: Deploying powerful models on mobile and embedded devices
  • Multimodal AI: Combining vision with language and other modalities

Next Steps for Your Journey

  1. Practice with Real Data: Work on projects that interest you
  2. Stay Updated: Follow recent papers and implementations
  3. Join the Community: Participate in computer vision challenges and forums
  4. Specialize: Choose a domain (medical, automotive, robotics) for deeper expertise
  5. Build End-to-End Systems: Integrate models into complete applications

The field of computer vision offers endless possibilities for innovation and impact. Whether you're building the next breakthrough in medical imaging, contributing to autonomous systems, or creating new forms of creative expression, the foundations we've covered provide a solid starting point for your computer vision journey.

Remember: the best way to learn computer vision is by doing. Start building, experimenting, and pushing the boundaries of what's possible with visual AI!