
Teaching AI to See - Computer Vision

Discover how computers can understand images! Learn about convolutional neural networks and build your own image classifier.

Tags: Computer Vision · CNN · Image Recognition · Project

Teaching AI to See 👁️

When you look at a photo of a cat, you instantly recognize it. But how does a computer "see"? Welcome to the fascinating world of computer vision — where we teach machines to understand images!

How Humans vs Computers See

Your brain processes images in a hierarchy:

  1. Your eyes detect light and colors
  2. Early visual areas find edges and simple shapes
  3. Higher visual areas recognize complex patterns
  4. The final stage identifies whole objects

Computers can do this too using Convolutional Neural Networks (CNNs)!

💡 Fun Fact: CNNs were inspired by how cat brains process visual information. Scientists studied how neurons in the cat visual cortex respond to edges and shapes, and that research shaped the first neural networks for vision!

What is a Convolution?

A convolution is like using a magnifying glass to look for specific patterns. We slide a small "filter" across the image to detect features.

The Math

For each position, we multiply overlapping pixels and sum the results:

$$(I * K)(i, j) = \sum_m \sum_n I(i+m, j+n) \cdot K(m, n)$$

Where:

  • $I$ is the image
  • $K$ is the kernel/filter
  • The result is a "feature map" showing where that pattern appears
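The formula can be written out directly in a few lines of NumPy. This is a minimal sketch of the math, not how PyTorch computes it internally (real libraries use heavily optimized kernels):

```python
import numpy as np

def convolve2d(I, K):
    """Slide kernel K over image I and sum the overlapping products,
    exactly as in the formula above. (Technically this is cross-correlation,
    which is what deep learning libraries actually compute.)"""
    kh, kw = K.shape
    oh, ow = I.shape[0] - kh + 1, I.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(I[i:i+kh, j:j+kw] * K)
    return out

I = np.arange(16, dtype=float).reshape(4, 4)   # a tiny 4x4 "image"
K = np.ones((2, 2)) / 4                        # 2x2 averaging filter
print(convolve2d(I, K))                        # a 3x3 feature map
```

Each output value summarizes one small window of the input, which is exactly what "detecting a local pattern" means.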

Example: Edge Detection

Here's a simple vertical edge detector kernel:

$$K = \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix}$$

When this slides over an image, it highlights vertical edges!
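You can check this on a toy image whose left half is dark and right half is bright. Sliding the kernel by hand (a small NumPy sketch using the same multiply-and-sum rule as the formula above) produces large values only at the boundary:

```python
import numpy as np

# Toy image: dark left half (0), bright right half (1) -> one vertical edge
img = np.zeros((5, 6))
img[:, 3:] = 1.0

# The vertical edge detector kernel from above (a Sobel filter)
K = np.array([[-1, 0, 1],
              [-2, 0, 2],
              [-1, 0, 1]], dtype=float)

# Apply the kernel at every valid position
kh, kw = K.shape
out = np.zeros((img.shape[0] - kh + 1, img.shape[1] - kw + 1))
for i in range(out.shape[0]):
    for j in range(out.shape[1]):
        out[i, j] = np.sum(img[i:i+kh, j:j+kw] * K)

print(out)  # large values only in the columns around the edge
```

Flat regions (all-dark or all-bright windows) cancel to zero; only windows straddling the edge light up.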

🎯 Activity: Look at objects around you. Can you trace their edges with your finger? That's what an edge detection filter does!

Building a CNN

A CNN has three main types of layers:

1. Convolutional Layers 🔍

Find patterns like edges, textures, and shapes.

import torch.nn as nn

# A convolutional layer
conv_layer = nn.Conv2d(
    in_channels=3,      # RGB image (3 color channels)
    out_channels=16,    # Number of filters to learn
    kernel_size=3,      # 3x3 filter size
    padding=1           # Keep image size the same
)

# When we apply it to an image
# Input:  1 image, 3 channels, 224x224 pixels
# Output: 1 image, 16 channels, 224x224 pixels
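You can verify those shapes by pushing a random tensor through the layer (a quick sketch; the 224x224 size is just an example, and the random tensor stands in for a real image):

```python
import torch
import torch.nn as nn

conv_layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

x = torch.randn(1, 3, 224, 224)   # batch of 1 RGB "image", 224x224 pixels
y = conv_layer(x)
print(tuple(y.shape))             # (1, 16, 224, 224): same size, 16 feature maps
```

With `padding=1` and a 3x3 kernel, the height and width stay the same; only the channel count changes.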

2. Pooling Layers 📉

Reduce image size while keeping important information.

# Max pooling - keeps the most important feature in each region
pool = nn.MaxPool2d(kernel_size=2, stride=2)

# This halves the width and height, reducing computation!
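A quick sketch of what that looks like in practice, including the "keep the maximum" behavior on a tiny 2x2 example:

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)

x = torch.randn(1, 16, 224, 224)
y = pool(x)
print(tuple(y.shape))  # (1, 16, 112, 112): half the height and width

# On a tiny 2x2 region, max pooling keeps only the largest value
tiny = torch.tensor([[[[1., 2.], [3., 4.]]]])
print(pool(tiny).item())  # 4.0
```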

3. Fully Connected Layers 🎯

Make the final classification decision.
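Before this layer, the feature maps are flattened into one long vector; a linear layer then turns that vector into one score per class. A minimal sketch (the 128 x 28 x 28 feature size is an example, matching the network built below):

```python
import torch
import torch.nn as nn

# Suppose the last conv block left us with 128 feature maps of size 28x28
features = torch.randn(1, 128 * 28 * 28)   # flattened into one long vector

fc = nn.Linear(128 * 28 * 28, 2)           # 2 output scores: cat vs dog
scores = fc(features)
print(tuple(scores.shape))                 # (1, 2)
```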

Complete CNN Example

Let's build a CNN to classify cats vs dogs:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CatDogCNN(nn.Module):
    def __init__(self):
        super(CatDogCNN, self).__init__()
        
        # First convolutional block
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        
        # Second convolutional block
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        # Third convolutional block
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        # Calculate size after convolutions and pooling
        # 224 -> 112 -> 56 -> 28 (3 poolings)
        self.fc1 = nn.Linear(128 * 28 * 28, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 2)  # Cat or Dog
    
    def forward(self, x):
        # Block 1
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        # Block 2
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        # Block 3
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Create model
model = CatDogCNN()
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
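The `128 * 28 * 28` input size of `fc1` comes from the pooling arithmetic in the comment above; here is a quick check of that calculation:

```python
# padding=1 convolutions keep the spatial size; each 2x2 max-pool halves it
size = 224
for _ in range(3):          # three pooled blocks: 224 -> 112 -> 56 -> 28
    size //= 2

print(size)                 # 28
print(128 * size * size)    # 100352 input features for fc1
```

If you change the input resolution or the number of pooled blocks, this is the number you must recompute, or the first linear layer will reject the flattened tensor.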

What Each Layer Learns

| Layer | What It Detects | Visual |
|-------|----------------|--------|
| 1st Conv | Edges, corners, simple textures | ▓▓░░ |
| 2nd Conv | Shapes, circles, corners | ◯ △ □ |
| 3rd Conv | Textures, patterns | Fur, scales |
| 4th Conv | Parts of objects | Eyes, ears, noses |
| Fully Connected | Complete objects | Cat, Dog, Bird |

🖼️ Visualize It: As you go deeper into the network, the "vision" becomes more abstract. Early layers see pixels, late layers see concepts!

Data Augmentation: More Data for Free

One challenge in computer vision is needing lots of images. Data augmentation creates variations of your images:

from torchvision import transforms

# Transform pipeline
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),      # Random crop
    transforms.RandomHorizontalFlip(),       # Flip left-right
    transforms.RandomRotation(15),           # Rotate slightly
    transforms.ColorJitter(brightness=0.2),  # Change brightness
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

It's like showing the network the same cat from different angles!

Your First Project: Handwritten Digit Classifier

Let's build something you can run right now:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Load the MNIST dataset (handwritten digits)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64)

# Simple CNN for digits
class DigitCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# Training (simplified)
model = DigitCNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("Training started...")
for epoch in range(2):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

print("Training complete! 🎉")

⚠️ Note for Parents: This code downloads ~10MB of data on first run. The MNIST dataset is a classic benchmark containing 70,000 handwritten digits (0-9).

Real-World Applications

CNNs power amazing technologies:

  • Medical Imaging - Detecting diseases in X-rays and MRIs
  • Self-Driving Cars - Recognizing pedestrians, signs, and obstacles
  • Face Recognition - Unlocking phones and tagging photos
  • Satellite Imagery - Monitoring deforestation and urban growth
  • Art Creation - Style transfer and deep dreams

Testing Your Understanding

Think About It

Why do you think CNNs work better than regular neural networks for images?

Hint: Consider what would happen if you flattened a 1000x1000 image into a vector for a regular neural network!

Answer: A 1000x1000 image has 1,000,000 pixels. A regular neural network would need millions of connections just for the first layer! CNNs use the fact that nearby pixels are related, dramatically reducing parameters.
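The difference is easy to put in numbers (illustrative sizes: a 1000x1000 grayscale image, 128 hidden units for the dense layer, and 16 filters of size 3x3 for the convolutional one, biases ignored):

```python
pixels = 1000 * 1000

# Fully connected: every pixel connects to every one of 128 hidden units
dense_params = pixels * 128
print(f"{dense_params:,}")   # 128,000,000 weights in the first layer alone

# Convolutional: 16 filters of size 3x3, shared across the whole image
conv_params = 16 * (3 * 3)
print(f"{conv_params:,}")    # 144 weights
```

Weight sharing is the key: the same small filter is reused at every position, so the parameter count depends on the filter size, not the image size.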

What's Next?

In this lesson, you learned:

  • How convolutions detect image features
  • The architecture of CNNs
  • How to build your own image classifier

In the next lesson, we'll explore transfer learning — using pre-trained models to solve new problems with less data!


"The best way to understand computer vision is to build something that can see." 🔬