Welcome to the final part of our computer vision series! We've progressed from classifying entire images to detecting and localizing objects. Now we'll tackle semantic segmentation - the most detailed level of image understanding, where we classify every single pixel in an image.
Semantic segmentation is crucial for applications requiring precise spatial understanding: autonomous driving (road, sidewalk, vehicles), medical imaging (organ boundaries, tumor detection), satellite imagery (land use classification), and augmented reality (scene understanding for object placement).
Understanding Pixel-Level Prediction
Unlike classification (image-level) and detection (object-level), segmentation operates at the pixel level:
- Classification: "This image contains a cat"
- Detection: "There's a cat at coordinates (x, y, w, h)"
- Segmentation: "These specific pixels belong to the cat class"
Types of Segmentation
- Semantic Segmentation: Classify every pixel by category (every car pixel gets the same "vehicle" label, with no distinction between individual cars)
- Instance Segmentation: Separate individual objects (car #1, car #2, etc.)
- Panoptic Segmentation: Combines both semantic and instance segmentation
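To make pixel-level prediction concrete, here is a minimal sketch of the tensor shapes involved. It uses an untrained torchvision baseline (fcn_resnet50, chosen purely for convenience), so the predicted labels are meaningless, but the shapes are representative:
import torch
import torchvision
model = torchvision.models.segmentation.fcn_resnet50(num_classes=21)
model.eval()
with torch.no_grad():
    images = torch.randn(1, 3, 256, 256)      # (N, 3, H, W) input batch
    logits = model(images)['out']             # (N, 21, H, W) per-class scores
    labels = logits.argmax(dim=1)             # (N, H, W): one class id per pixel
print(logits.shape, labels.shape)
# torch.Size([1, 21, 256, 256]) torch.Size([1, 256, 256])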
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
def visualize_segmentation(image, mask, classes, alpha=0.7):
"""
Visualize segmentation mask overlaid on original image
"""
# Create color map for classes
colors = plt.cm.Set3(np.linspace(0, 1, len(classes)))
# Convert mask to RGB
h, w = mask.shape
mask_rgb = np.zeros((h, w, 3))
for class_id, color in enumerate(colors):
mask_rgb[mask == class_id] = color[:3]
# Overlay on original image
if isinstance(image, torch.Tensor):
image = image.permute(1, 2, 0).numpy()
    # Convert to a displayable uint8 image in [0, 255]
    # (assumes an unnormalized image; undo any mean/std normalization before calling)
    if image.max() <= 1.0:
        image = image * 255
    image = np.clip(image, 0, 255).astype(np.uint8)
# Blend images
blended = cv2.addWeighted(
image.astype(np.uint8),
1 - alpha,
(mask_rgb * 255).astype(np.uint8),
alpha,
0
)
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title('Original Image')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(mask_rgb)
plt.title('Segmentation Mask')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(blended)
plt.title('Overlay')
plt.axis('off')
plt.tight_layout()
plt.show()
# Example usage
# visualize_segmentation(image, mask, ['background', 'person', 'car', 'road'])
Dataset Preparation
We'll work with Pascal VOC and Cityscapes, the standard benchmark formats for semantic segmentation. The custom dataset class below follows a simplified Cityscapes-style layout.
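For reference, here is the directory structure the class expects. This is a placeholder layout, not the official Cityscapes release (which ships images under leftImg8bit/ and labels under gtFine/):
# path/to/cityscapes/
#   images/train/*.jpg or *.png   RGB images
#   images/val/...
#   masks/train/*.png             single-channel PNGs whose pixel values are class
#   masks/val/...                 indices (0-19 here); 255 can mark ignored pixels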
from torch.utils.data import Dataset, DataLoader
import os
import json
class CityscapesDataset(Dataset):
"""
Custom dataset class for Cityscapes-style segmentation data
"""
def __init__(self, root_dir, split='train', transforms=None):
self.root_dir = root_dir
self.split = split
self.transforms = transforms
# Define class mapping
self.classes = [
'background', 'road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
'motorcycle', 'bicycle'
]
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
self.num_classes = len(self.classes)
# Load image and mask file paths
self.images_dir = os.path.join(root_dir, 'images', split)
self.masks_dir = os.path.join(root_dir, 'masks', split)
self.image_files = sorted([f for f in os.listdir(self.images_dir)
if f.endswith(('.jpg', '.png'))])
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
# Load image
img_name = self.image_files[idx]
img_path = os.path.join(self.images_dir, img_name)
image = Image.open(img_path).convert('RGB')
# Load mask
mask_name = img_name.replace('.jpg', '.png') # Assuming masks are PNG
mask_path = os.path.join(self.masks_dir, mask_name)
mask = Image.open(mask_path)
mask = np.array(mask, dtype=np.int64)
if self.transforms:
image, mask = self.transforms(image, mask)
return image, mask
class SegmentationTransforms:
"""
Custom transforms for segmentation that apply same transforms to both image and mask
"""
def __init__(self, image_size=(256, 256), train=True):
self.image_size = image_size
self.train = train
def __call__(self, image, mask):
# Resize
image = image.resize(self.image_size, Image.BILINEAR)
        mask = Image.fromarray(mask.astype(np.uint8)).resize(self.image_size, Image.NEAREST)  # uint8 so PIL can handle the array
mask = np.array(mask, dtype=np.int64)
if self.train:
# Random horizontal flip
if np.random.rand() > 0.5:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
                mask = np.fliplr(mask).copy()  # copy() so torch.from_numpy gets contiguous memory
# Random crop (simplified)
# Add more augmentations as needed
# Convert to tensor
image = transforms.ToTensor()(image)
image = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)(image)
mask = torch.from_numpy(mask).long()
return image, mask
# Create datasets
# train_dataset = CityscapesDataset(
# root_dir='path/to/cityscapes',
# split='train',
# transforms=SegmentationTransforms(train=True)
# )
# val_dataset = CityscapesDataset(
# root_dir='path/to/cityscapes',
# split='val',
# transforms=SegmentationTransforms(train=False)
# )
# Data loaders
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
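Once real data is in place, a quick sanity check like the sketch below (assuming the loaders above have been created) confirms that images and masks are batched with matching spatial sizes and that mask values stay in the expected range:
# Sanity-check one batch (assumes train_loader exists)
images, masks = next(iter(train_loader))
print(images.shape)  # e.g. torch.Size([8, 3, 256, 256]), float32
print(masks.shape)   # e.g. torch.Size([8, 256, 256]), int64 class indices
print(masks.min().item(), masks.max().item())  # should stay within the class range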
Segmentation Architectures
Fully Convolutional Network (FCN)
FCN replaces fully connected layers with convolutional layers, allowing input of any size.
class FCN(nn.Module):
"""
Fully Convolutional Network for semantic segmentation
"""
def __init__(self, num_classes=21, pretrained=True):
super(FCN, self).__init__()
# Use pre-trained VGG16 as backbone
vgg = torchvision.models.vgg16(pretrained=pretrained)
# Extract features (all layers except classifier)
self.features = vgg.features
# Convert classifier to fully convolutional
self.classifier = nn.Sequential(
            nn.Conv2d(512, 4096, kernel_size=7, padding=3),  # padding=3 preserves the 1/32-scale feature map so the 32x upsampling below restores the input size
nn.ReLU(inplace=True),
nn.Dropout2d(0.5),
nn.Conv2d(4096, 4096, kernel_size=1),
nn.ReLU(inplace=True),
nn.Dropout2d(0.5),
nn.Conv2d(4096, num_classes, kernel_size=1)
)
# Transposed convolution for upsampling
self.upsample = nn.ConvTranspose2d(
num_classes, num_classes, kernel_size=64, stride=32, padding=16
)
def forward(self, x):
input_size = x.size()[2:]
# Feature extraction
x = self.features(x)
# Classification
x = self.classifier(x)
# Upsample to original size
x = self.upsample(x)
# Crop to exact input size
x = x[:, :, :input_size[0], :input_size[1]]
return x
fcn_model = FCN(num_classes=20)
print(f"FCN parameters: {sum(p.numel() for p in fcn_model.parameters()):,}")
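A quick forward pass with a dummy batch verifies that the upsampled output matches the input resolution (a sketch assuming a 256x256 input, i.e. dimensions divisible by 32 so the VGG16 backbone yields an 8x8 feature map):
with torch.no_grad():
    out = fcn_model(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 20, 256, 256])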
U-Net Architecture
U-Net uses skip connections to combine low-level and high-level features.
class DoubleConv(nn.Module):
"""Double convolution block used in U-Net"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# Pad x1 to match x2 size if needed
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# Concatenate along channel dimension
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class UNet(nn.Module):
"""
U-Net architecture for semantic segmentation
"""
def __init__(self, n_channels=3, n_classes=21, bilinear=False):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
unet_model = UNet(n_channels=3, n_classes=20, bilinear=False)
print(f"U-Net parameters: {sum(p.numel() for p in unet_model.parameters()):,}")
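Setting bilinear=True swaps the learned transposed convolutions for fixed bilinear upsampling, which shrinks the decoder considerably; a quick comparison:
unet_bilinear = UNet(n_channels=3, n_classes=20, bilinear=True)
print(f"U-Net (bilinear) parameters: {sum(p.numel() for p in unet_bilinear.parameters()):,}")
# Fewer parameters than the transposed-convolution variant above, at the cost of
# upsampling weights that cannot be learned.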
DeepLab with Atrous Convolution
DeepLab uses atrous (dilated) convolutions to capture multi-scale context.
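Before building the full model, here is a minimal sketch of the key idea: a 3x3 convolution with dilation rate r covers the same area as a (2r+1)x(2r+1) kernel while keeping only nine weights per channel pair, and with padding=r it preserves the spatial resolution:
x = torch.randn(1, 256, 32, 32)
for rate in [1, 6, 12, 18]:
    conv = nn.Conv2d(256, 256, kernel_size=3, padding=rate, dilation=rate)
    print(rate, conv(x).shape)  # spatial size stays 32x32 for every rate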
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()
]
super(ASPPConv, self).__init__(*modules)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU())
def forward(self, x):
size = x.shape[-2:]
for mod in self:
x = mod(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
"""
Atrous Spatial Pyramid Pooling module
"""
def __init__(self, in_channels, atrous_rates, out_channels=256):
super(ASPP, self).__init__()
modules = []
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()))
rates = tuple(atrous_rates)
for rate in rates:
modules.append(ASPPConv(in_channels, out_channels, rate))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5))
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)
class DeepLabV3(nn.Module):
"""
DeepLabV3 with ResNet backbone
"""
def __init__(self, num_classes=21, output_stride=16):
super(DeepLabV3, self).__init__()
# Use ResNet50 as backbone
resnet = torchvision.models.resnet50(pretrained=True)
        # Modify ResNet for segmentation: drop the average pool and fully connected head.
        # Note: without replacing strides with dilations, this backbone still downsamples
        # by 32x, so output_stride below only selects the ASPP dilation rates.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
# ASPP module
if output_stride == 16:
atrous_rates = [6, 12, 18]
else: # output_stride == 8
atrous_rates = [12, 24, 36]
self.aspp = ASPP(2048, atrous_rates)
# Final classifier
self.classifier = nn.Sequential(
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Conv2d(256, num_classes, 1)
)
def forward(self, x):
input_size = x.shape[-2:]
# Feature extraction
features = self.backbone(x)
# ASPP
x = self.aspp(features)
# Classification
x = self.classifier(x)
# Upsample to input size
x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
return x
deeplab_model = DeepLabV3(num_classes=20)
print(f"DeepLab parameters: {sum(p.numel() for p in deeplab_model.parameters()):,}")
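The simplified model above keeps the ResNet backbone's 32x downsampling. If you just need a working DeepLabV3, torchvision ships a ready-made variant that handles the stride and dilation details internally (a sketch; the pretrained/weights argument depends on your torchvision version):
# Ready-made alternative from torchvision (untrained here)
tv_deeplab = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=20)
tv_deeplab.eval()
with torch.no_grad():
    out = tv_deeplab(torch.randn(1, 3, 256, 256))['out']
print(out.shape)  # torch.Size([1, 20, 256, 256])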
Training Semantic Segmentation Models
Loss Functions for Segmentation
class SegmentationLoss(nn.Module):
"""
Combined loss function for semantic segmentation
"""
def __init__(self, num_classes, ignore_index=255, weight=None):
super(SegmentationLoss, self).__init__()
self.num_classes = num_classes
self.ignore_index = ignore_index
# Cross entropy loss
self.ce_loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
# Dice loss for better boundary handling
self.dice_loss = DiceLoss(num_classes, ignore_index)
def forward(self, predictions, targets):
ce_loss = self.ce_loss(predictions, targets)
dice_loss = self.dice_loss(predictions, targets)
# Combine losses
total_loss = ce_loss + dice_loss
return total_loss, ce_loss, dice_loss
class DiceLoss(nn.Module):
"""
Dice loss for segmentation
"""
def __init__(self, num_classes, ignore_index=255, smooth=1e-5):
super(DiceLoss, self).__init__()
self.num_classes = num_classes
self.ignore_index = ignore_index
self.smooth = smooth
def forward(self, predictions, targets):
        # Convert predictions to probabilities
        probs = F.softmax(predictions, dim=1)
        # Mask out ignore_index *before* one-hot encoding
        # (F.one_hot would fail on labels such as 255)
        if self.ignore_index is not None:
            valid = (targets != self.ignore_index)
            targets = targets.clone()
            targets[~valid] = 0
            mask = valid.unsqueeze(1).float()
        else:
            mask = torch.ones_like(targets, dtype=probs.dtype).unsqueeze(1)
        # Create one-hot encoding for targets and zero out ignored pixels
        targets_one_hot = F.one_hot(targets, self.num_classes).permute(0, 3, 1, 2).float()
        probs = probs * mask
        targets_one_hot = targets_one_hot * mask
# Calculate Dice coefficient for each class
dice_scores = []
for class_idx in range(self.num_classes):
pred_class = probs[:, class_idx]
target_class = targets_one_hot[:, class_idx]
intersection = torch.sum(pred_class * target_class)
union = torch.sum(pred_class) + torch.sum(target_class)
dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
dice_scores.append(dice)
# Average Dice loss
dice_loss = 1.0 - torch.mean(torch.stack(dice_scores))
return dice_loss
class FocalLoss(nn.Module):
"""
Focal loss for handling class imbalance
"""
def __init__(self, alpha=1, gamma=2, reduction='mean', ignore_index=255):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index)
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
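A quick smoke test of the combined loss on random tensors (a sketch; shapes mirror the models above, and a corner of each target is marked with the ignore label) helps catch shape and ignore-index mistakes before a full training run:
num_classes = 20
logits = torch.randn(2, num_classes, 64, 64)          # model output (N, C, H, W)
targets = torch.randint(0, num_classes, (2, 64, 64))  # ground truth (N, H, W)
targets[:, :5, :5] = 255                              # ignored pixels
criterion = SegmentationLoss(num_classes=num_classes)
total, ce, dice = criterion(logits, targets)
print(f"total={total.item():.4f} ce={ce.item():.4f} dice={dice.item():.4f}")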
Training Loop
def train_segmentation_model(model, train_loader, val_loader, num_epochs=50):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Loss function and optimizer
criterion = SegmentationLoss(num_classes=20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5, verbose=True
)
best_val_loss = float('inf')
train_losses = []
val_losses = []
for epoch in range(num_epochs):
print(f'Epoch {epoch+1}/{num_epochs}')
print('-' * 20)
# Training phase
model.train()
running_loss = 0.0
running_ce_loss = 0.0
running_dice_loss = 0.0
for i, (images, masks) in enumerate(train_loader):
images = images.to(device)
masks = masks.to(device)
optimizer.zero_grad()
# Forward pass
outputs = model(images)
total_loss, ce_loss, dice_loss = criterion(outputs, masks)
# Backward pass
total_loss.backward()
optimizer.step()
running_loss += total_loss.item()
running_ce_loss += ce_loss.item()
running_dice_loss += dice_loss.item()
if (i + 1) % 10 == 0:
print(f'Batch [{i+1}/{len(train_loader)}], '
f'Loss: {total_loss.item():.4f}')
# Calculate epoch losses
epoch_loss = running_loss / len(train_loader)
epoch_ce_loss = running_ce_loss / len(train_loader)
epoch_dice_loss = running_dice_loss / len(train_loader)
train_losses.append(epoch_loss)
# Validation phase
val_loss = validate_segmentation_model(model, val_loader, criterion, device)
val_losses.append(val_loss)
print(f'Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dice: {epoch_dice_loss:.4f})')
print(f'Val Loss: {val_loss:.4f}')
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), 'best_segmentation_model.pth')
print('Saved best model!')
scheduler.step(val_loss)
print()
return train_losses, val_losses
def validate_segmentation_model(model, val_loader, criterion, device):
model.eval()
running_loss = 0.0
with torch.no_grad():
for images, masks in val_loader:
images = images.to(device)
masks = masks.to(device)
outputs = model(images)
total_loss, _, _ = criterion(outputs, masks)
running_loss += total_loss.item()
return running_loss / len(val_loader)
# Train the model
# train_losses, val_losses = train_segmentation_model(unet_model, train_loader, val_loader)
Evaluation Metrics
IoU and mIoU Calculation
def calculate_iou(pred, target, num_classes, ignore_index=255):
"""
Calculate Intersection over Union (IoU) for each class
"""
ious = []
pred = pred.flatten()
target = target.flatten()
# Remove ignored pixels
if ignore_index is not None:
mask = target != ignore_index
pred = pred[mask]
target = target[mask]
for class_id in range(num_classes):
pred_inds = pred == class_id
target_inds = target == class_id
if target_inds.sum() == 0: # No ground truth for this class
ious.append(float('nan'))
continue
intersection = (pred_inds & target_inds).sum().float()
union = (pred_inds | target_inds).sum().float()
if union == 0:
ious.append(float('nan'))
else:
ious.append((intersection / union).item())
return ious
def evaluate_segmentation(model, test_loader, num_classes, device, class_names=None):
"""
Comprehensive evaluation of segmentation model
"""
model.eval()
all_ious = []
pixel_acc_total = 0
pixel_count_total = 0
with torch.no_grad():
for images, masks in test_loader:
images = images.to(device)
masks = masks.to(device)
# Forward pass
outputs = model(images)
predictions = torch.argmax(outputs, dim=1)
# Calculate metrics for each image in batch
for pred, target in zip(predictions.cpu(), masks.cpu()):
# IoU calculation
ious = calculate_iou(pred, target, num_classes)
all_ious.append(ious)
# Pixel accuracy
correct_pixels = (pred == target).sum().float()
total_pixels = target.numel()
pixel_acc_total += correct_pixels
pixel_count_total += total_pixels
# Calculate mean IoU
all_ious = np.array(all_ious)
mean_ious = np.nanmean(all_ious, axis=0)
overall_miou = np.nanmean(mean_ious)
# Calculate pixel accuracy
pixel_accuracy = (pixel_acc_total / pixel_count_total).item()
# Print results
print(f'Overall mIoU: {overall_miou:.4f}')
print(f'Pixel Accuracy: {pixel_accuracy:.4f}')
print('\nPer-class IoU:')
for i, iou in enumerate(mean_ious):
class_name = class_names[i] if class_names else f'Class {i}'
print(f'{class_name}: {iou:.4f}')
return overall_miou, pixel_accuracy, mean_ious
# Example evaluation
# miou, pixel_acc, class_ious = evaluate_segmentation(
# unet_model, test_loader, num_classes=20, device=device, class_names=dataset.classes
# )
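The IoU helper is easy to verify by hand on a toy mask; in this 2x2 sketch, class 0 overlaps on two of its three union pixels and class 1 on one of two:
pred = torch.tensor([[0, 1],
                     [1, 0]])
target = torch.tensor([[0, 1],
                       [0, 0]])
print(calculate_iou(pred, target, num_classes=2))  # [0.666..., 0.5]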
Confusion Matrix for Segmentation
import seaborn as sns
from sklearn.metrics import confusion_matrix
def plot_segmentation_confusion_matrix(model, test_loader, class_names, device, normalize=True):
"""
Plot confusion matrix for segmentation results
"""
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
for images, masks in test_loader:
images = images.to(device)
masks = masks.cpu().numpy()
outputs = model(images)
predictions = torch.argmax(outputs, dim=1).cpu().numpy()
# Flatten predictions and targets
for pred, target in zip(predictions, masks):
# Remove ignore pixels
mask = target != 255
all_preds.extend(pred[mask])
all_targets.extend(target[mask])
# Calculate confusion matrix
cm = confusion_matrix(all_targets, all_preds, labels=range(len(class_names)))
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
# Plot
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
xticklabels=class_names, yticklabels=class_names,
cmap='Blues')
plt.title('Segmentation Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()
return cm
# Example usage
# cm = plot_segmentation_confusion_matrix(unet_model, test_loader, dataset.classes, device)
Advanced Techniques
Multi-Scale Training and Testing
class MultiScaleTraining:
"""
Multi-scale training for better segmentation performance
"""
def __init__(self, base_size=512, scales=[0.5, 0.75, 1.0, 1.25, 1.5]):
self.base_size = base_size
self.scales = scales
def __call__(self, image, mask):
# Randomly select a scale
scale = np.random.choice(self.scales)
new_size = int(self.base_size * scale)
# Resize image and mask
image = image.resize((new_size, new_size), Image.BILINEAR)
mask = Image.fromarray(mask).resize((new_size, new_size), Image.NEAREST)
mask = np.array(mask)
# Random crop to base_size if larger
if new_size > self.base_size:
h, w = mask.shape
start_h = np.random.randint(0, h - self.base_size + 1)
start_w = np.random.randint(0, w - self.base_size + 1)
image = image.crop((start_w, start_h,
start_w + self.base_size,
start_h + self.base_size))
mask = mask[start_h:start_h + self.base_size,
start_w:start_w + self.base_size]
# Pad if smaller
elif new_size < self.base_size:
pad = self.base_size - new_size
pad_top = pad // 2
pad_bottom = pad - pad_top
image = np.array(image)
image = np.pad(image, ((pad_top, pad_bottom), (pad_top, pad_bottom), (0, 0)),
mode='constant')
image = Image.fromarray(image)
mask = np.pad(mask, ((pad_top, pad_bottom), (pad_top, pad_bottom)),
mode='constant', constant_values=255)
return image, mask
def test_time_augmentation(model, image, device, scales=[0.75, 1.0, 1.25]):
"""
Test-time augmentation for better inference results
"""
model.eval()
predictions = []
with torch.no_grad():
for scale in scales:
# Scale image
h, w = image.shape[-2:]
new_h, new_w = int(h * scale), int(w * scale)
scaled_image = F.interpolate(
image.unsqueeze(0),
size=(new_h, new_w),
mode='bilinear',
align_corners=False
)
# Forward pass
output = model(scaled_image.to(device))
# Scale back to original size
output = F.interpolate(
output,
size=(h, w),
mode='bilinear',
align_corners=False
)
predictions.append(output)
# Also test horizontal flip
flipped_image = torch.flip(scaled_image, dims=[3])
flipped_output = model(flipped_image.to(device))
flipped_output = torch.flip(flipped_output, dims=[3])
flipped_output = F.interpolate(
flipped_output,
size=(h, w),
mode='bilinear',
align_corners=False
)
predictions.append(flipped_output)
# Average all predictions
final_prediction = torch.mean(torch.stack(predictions), dim=0)
return final_prediction
# Example usage
# multiscale_transform = MultiScaleTraining(base_size=256, scales=[0.5, 0.75, 1.0, 1.25])
# tta_prediction = test_time_augmentation(model, test_image, device)
Real-World Applications
Medical Image Segmentation
class MedicalSegmentationModel(nn.Module):
"""
Specialized U-Net for medical image segmentation
"""
def __init__(self, in_channels=1, out_channels=2): # Grayscale input, binary output
super(MedicalSegmentationModel, self).__init__()
self.unet = UNet(n_channels=in_channels, n_classes=out_channels)
def forward(self, x):
return self.unet(x)
def preprocess_medical_image(image_path, target_size=(256, 256)):
"""
Preprocess medical images (DICOM, etc.)
"""
# For DICOM files, you would use pydicom
# import pydicom
# dicom = pydicom.dcmread(image_path)
# image = dicom.pixel_array
# For now, assume regular image files
image = Image.open(image_path).convert('L') # Grayscale
image = image.resize(target_size)
# Normalize to [0, 1]
image = np.array(image, dtype=np.float32) / 255.0
# Apply medical-specific normalization
# (e.g., HU windowing for CT scans)
return transforms.ToTensor()(image)
def calculate_dice_coefficient(pred, target):
"""
Calculate Dice coefficient for medical segmentation
"""
smooth = 1e-5
pred_flat = pred.flatten()
target_flat = target.flatten()
intersection = (pred_flat * target_flat).sum()
union = pred_flat.sum() + target_flat.sum()
dice = (2.0 * intersection + smooth) / (union + smooth)
return dice
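In practice the Dice coefficient is computed on a binarized prediction. A minimal sketch (model, scan, and gt_mask are hypothetical placeholders for a trained MedicalSegmentationModel, a preprocessed scan tensor, and a binary ground-truth mask):
# Hypothetical usage: `model` is a trained MedicalSegmentationModel, `scan` comes from
# preprocess_medical_image (shape (1, H, W)), `gt_mask` is a (H, W) tensor of 0s and 1s
model.eval()
with torch.no_grad():
    logits = model(scan.unsqueeze(0))            # (1, 2, H, W)
    pred_mask = logits.argmax(dim=1).squeeze(0)  # (H, W): 0 = background, 1 = structure
dice = calculate_dice_coefficient(pred_mask.float(), gt_mask.float())
print(f"Dice: {dice:.4f}")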
Autonomous Driving Segmentation
class AutonomousDrivingSegmentation:
"""
Real-time segmentation for autonomous driving
"""
def __init__(self, model_path, device):
self.device = device
self.model = self.load_model(model_path)
# Cityscapes classes relevant for driving
self.driving_classes = {
0: 'road',
1: 'sidewalk',
2: 'building',
3: 'wall',
4: 'fence',
5: 'pole',
6: 'traffic_light',
7: 'traffic_sign',
8: 'vegetation',
9: 'terrain',
10: 'sky',
11: 'person',
12: 'rider',
13: 'car',
14: 'truck',
15: 'bus',
16: 'train',
17: 'motorcycle',
18: 'bicycle'
}
# Define colors for visualization
self.colors = np.random.randint(0, 255, (len(self.driving_classes), 3))
def load_model(self, model_path):
model = DeepLabV3(num_classes=19)
model.load_state_dict(torch.load(model_path, map_location=self.device))
model.to(self.device)
model.eval()
return model
def preprocess_frame(self, frame):
"""Preprocess video frame for segmentation"""
# Resize and normalize
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_pil = Image.fromarray(frame_rgb)
transform = transforms.Compose([
transforms.Resize((256, 512)), # Typical driving camera aspect ratio
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
return transform(frame_pil).unsqueeze(0)
def segment_frame(self, frame):
"""Segment a single frame"""
with torch.no_grad():
input_tensor = self.preprocess_frame(frame).to(self.device)
output = self.model(input_tensor)
prediction = torch.argmax(output, dim=1).squeeze().cpu().numpy()
return prediction
def visualize_segmentation(self, frame, segmentation):
"""Overlay segmentation on original frame"""
# Resize segmentation to match frame size
h, w = frame.shape[:2]
segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (w, h),
interpolation=cv2.INTER_NEAREST)
# Create colored segmentation
colored_seg = np.zeros((h, w, 3), dtype=np.uint8)
for class_id in range(len(self.driving_classes)):
mask = segmentation_resized == class_id
colored_seg[mask] = self.colors[class_id]
# Blend with original frame
alpha = 0.6
blended = cv2.addWeighted(frame, 1-alpha, colored_seg, alpha, 0)
return blended
def process_video(self, video_path, output_path):
"""Process entire video for segmentation"""
cap = cv2.VideoCapture(video_path)
# Get video properties
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Setup video writer
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
while True:
ret, frame = cap.read()
if not ret:
break
# Segment frame
segmentation = self.segment_frame(frame)
# Visualize
result_frame = self.visualize_segmentation(frame, segmentation)
# Write frame
out.write(result_frame)
cap.release()
out.release()
# Example usage
# driving_segmenter = AutonomousDrivingSegmentation('cityscapes_model.pth', device)
# driving_segmenter.process_video('driving_video.mp4', 'segmented_output.mp4')
Deployment and Optimization
Model Optimization for Production
def optimize_model_for_deployment(model, example_input):
"""
Optimize segmentation model for deployment
"""
model.eval()
# 1. TorchScript compilation
traced_model = torch.jit.trace(model, example_input)
    # 2. Dynamic quantization (for CPU deployment)
    #    Note: quantize_dynamic only covers a few layer types (nn.Linear, RNNs);
    #    nn.Conv2d is left untouched, so conv-heavy segmentation models usually
    #    need static quantization or quantization-aware training instead.
    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
# 3. ONNX export (for cross-platform deployment)
torch.onnx.export(
model,
example_input,
"segmentation_model.onnx",
export_params=True,
opset_version=11,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
return traced_model, quantized_model
# Benchmark different optimizations
def benchmark_models(original_model, traced_model, quantized_model, test_input, device):
"""
Compare inference speed of different model optimizations
"""
import time
models = {
'Original': original_model,
'TorchScript': traced_model,
'Quantized': quantized_model
}
results = {}
for name, model in models.items():
model.eval()
# Warmup
for _ in range(10):
with torch.no_grad():
_ = model(test_input)
        # Benchmark (synchronize so asynchronous GPU kernels are included in the timing)
        times = []
        for _ in range(100):
            if test_input.is_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            with torch.no_grad():
                _ = model(test_input)
            if test_input.is_cuda:
                torch.cuda.synchronize()
            end_time = time.time()
            times.append(end_time - start_time)
avg_time = np.mean(times) * 1000 # Convert to milliseconds
results[name] = avg_time
print(f'{name}: {avg_time:.2f} ms per inference')
return results
# Example usage
# example_input = torch.randn(1, 3, 256, 256).to(device)
# traced_model, quantized_model = optimize_model_for_deployment(unet_model, example_input)
# benchmark_results = benchmark_models(unet_model, traced_model, quantized_model, example_input, device)
Key Takeaways
In this final part of our computer vision series, we explored:
- Semantic Segmentation Fundamentals: Pixel-level classification and its applications
- Advanced Architectures: FCN, U-Net, and DeepLab with their unique approaches
- Training Strategies: Custom loss functions, multi-scale training, and evaluation metrics
- Real-World Applications: Medical imaging and autonomous driving use cases
- Deployment Optimization: Model optimization techniques for production
Best Practices for Semantic Segmentation:
- Data Augmentation: Use segmentation-aware transformations
- Loss Functions: Combine multiple losses (CE + Dice + Focal)
- Multi-scale Training: Train and test at different resolutions
- Skip Connections: Use encoder-decoder architectures for better feature fusion
- Post-processing: Apply CRF or similar techniques for boundary refinement
Architecture Selection Guidelines:
- U-Net: Best for medical imaging and small datasets
- DeepLab: Excellent for natural images and large datasets
- FCN: Good baseline, simpler architecture
- Custom architectures: Design based on specific domain requirements
Deployment Considerations:
- Real-time Requirements: Use lightweight architectures (MobileNet backbones)
- Memory Constraints: Apply quantization and pruning
- Accuracy vs Speed: Balance model complexity with performance requirements
- Hardware Optimization: Leverage GPU, TPU, or specialized inference engines
Series Conclusion
Throughout this three-part series, we've journeyed through the core areas of computer vision:
- Part 1 - Classification: Understanding what's in an image
- Part 2 - Detection: Finding where objects are located
- Part 3 - Segmentation: Precisely delineating every pixel
Each level builds upon the previous, increasing in complexity and detail. Modern computer vision applications often combine these approaches: autonomous vehicles use all three for comprehensive scene understanding, and medical systems integrate classification and segmentation for diagnosis and treatment planning.
The Path Forward
Computer vision continues to evolve rapidly:
- Transformer Architectures: Vision Transformers (ViTs) are revolutionizing the field
- Self-Supervised Learning: Learning from unlabeled data
- Neural Architecture Search: Automatically designing optimal architectures
- Edge AI: Deploying powerful models on mobile and embedded devices
- Multimodal AI: Combining vision with language and other modalities
Next Steps for Your Journey
- Practice with Real Data: Work on projects that interest you
- Stay Updated: Follow recent papers and implementations
- Join the Community: Participate in computer vision challenges and forums
- Specialize: Choose a domain (medical, automotive, robotics) for deeper expertise
- Build End-to-End Systems: Integrate models into complete applications
The field of computer vision offers endless possibilities for innovation and impact. Whether you're building the next breakthrough in medical imaging, contributing to autonomous systems, or creating new forms of creative expression, the foundations we've covered provide a solid starting point for your computer vision journey.
Remember: the best way to learn computer vision is by doing. Start building, experimenting, and pushing the boundaries of what's possible with visual AI!