DCGAN Training on MNIST Dataset

Deep Convolutional Generative Adversarial Network - Digit-wise Training

Back to Home
Cell 1: Import Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import copy
Cell 2: Set Device
# Select the compute device: prefer the GPU when available, else fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)
Cell 3: Load and Separate Dataset by Digit Class
# Dataset preparation - separate by digit class.
# Images are resized to 16x16 and normalized to [-1, 1] to match the
# generator's Tanh output range.
transform = transforms.Compose([
    transforms.Resize(16),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

full_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)

# Separate data by digit class (0-9).
# Index lookup uses the raw label tensor rather than enumerate(full_dataset):
# enumerating would decode and transform all 60k images once per digit
# (10 full passes); comparing `targets` is one cheap tensor op per digit
# and yields the exact same indices in the same order.
digit_datasets = {}
digit_loaders = {}

labels = full_dataset.targets  # raw label tensor, same order as the dataset

for digit in range(10):
    # Keep at most 1500 examples per digit, preserving dataset order.
    indices = (labels == digit).nonzero(as_tuple=True)[0][:1500].tolist()
    digit_datasets[digit] = torch.utils.data.Subset(full_dataset, indices)
    digit_loaders[digit] = torch.utils.data.DataLoader(
        digit_datasets[digit],
        batch_size=64,
        shuffle=True
    )
    print(f"Digit {digit}: {len(digit_datasets[digit])} images")
Cell 4: Define Hyperparameters
# Hyperparameters
latent_dim = 100        # length of the generator's input noise vector z
epochs_per_digit = 25   # training epochs for each per-digit GAN
lr = 0.0002             # Adam learning rate (DCGAN paper default)
beta1 = 0.5             # Adam beta1 (DCGAN paper recommends 0.5 over the 0.9 default)
Cell 5: Define Generator Architecture
# Generator: latent vector (latent_dim x 1 x 1) -> 16x16 grayscale image
class Generator(nn.Module):
    """DCGAN generator producing 16x16 single-channel images in [-1, 1].

    Args:
        latent_dim (int): Length of the input noise vector. Defaults to 100,
            matching the module-level hyperparameter, so existing
            ``Generator()`` call sites behave exactly as before. Previously
            the class read the module global directly; the parameter removes
            that hidden coupling.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            # latent_dim x 1 x 1 -> 256 x 4 x 4
            nn.ConvTranspose2d(latent_dim, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 256 x 4 x 4 -> 128 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 128 x 8 x 8 -> 1 x 16 x 16; Tanh maps pixels to [-1, 1]
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise z of shape (N, latent_dim, 1, 1) to images (N, 1, 16, 16)."""
        return self.net(z)
Cell 6: Define Discriminator Architecture
# Discriminator: 16x16 image -> probability real/fake
class Discriminator(nn.Module):
    """Convolutional critic mapping a (N, 1, 16, 16) batch to (N, 1) probabilities."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 128, 4, 2, 1, bias=False),   # 16x16 -> 8x8
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False), # 8x8 -> 4x4
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 4, 1, 0, bias=False),   # 4x4 -> 1x1
            nn.Sigmoid(),
        ]
        # Positional Sequential keeps state_dict keys ("net.0.weight", ...)
        # identical to a directly-built Sequential.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a per-sample real/fake probability with shape (N, 1)."""
        score = self.net(x)
        return score.view(-1, 1)
Cell 7: Define Weight Initialization Function
# Weight initialization (DCGAN convention): conv weights ~ N(0, 0.02);
# BatchNorm scale ~ N(1, 0.02) with zero bias. Other modules are untouched.
def weights_init(m):
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
Cell 8: Define Loss Criterion
# Binary cross-entropy on the discriminator's sigmoid output (standard GAN loss).
criterion = nn.BCELoss()
Cell 9: Define Visualization Function
def show_real_vs_fake(real, fake, digit):
    """Display 8 real images beside 8 generated ones for the given digit."""
    panels = [
        (make_grid(real[:8], nrow=4, normalize=True), f"Real Images - Digit {digit}"),
        (make_grid(fake[:8], nrow=4, normalize=True), f"Generated Images - Digit {digit}"),
    ]

    plt.figure(figsize=(8, 4))
    for position, (grid, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(grid.permute(1, 2, 0))
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()
Cell 10: Define Image Quality Metrics Function
def calculate_image_quality(images):
    """Compute simple quality statistics for a batch of generated images.

    Args:
        images: Tensor of shape (N, C, H, W); moved to CPU/numpy internally.

    Returns:
        dict with 'diversity' (pixel std over the whole batch), 'sharpness'
        (mean absolute spatial gradient per image, averaged over the batch)
        and 'contrast' (mean per-image max-min range).
    """
    batch = images.detach().cpu().numpy()

    per_image_sharpness = []
    for img in batch:
        plane = img.squeeze()
        gy, gx = np.gradient(plane)
        per_image_sharpness.append(np.mean(np.abs(gx)) + np.mean(np.abs(gy)))

    return {
        'diversity': np.std(batch),
        'sharpness': np.mean(per_image_sharpness),
        'contrast': np.mean([img.max() - img.min() for img in batch]),
    }
Cell 11: Initialize Training Storage
# Per-digit results: final generator weights plus per-epoch training metrics.
trained_generators = {}
metrics_history = {}
for digit in range(10):
    # Independent metric lists per digit (fresh lists each iteration).
    metrics_history[digit] = {'g_loss': [], 'd_loss': [], 'quality': []}
moving_avg_window = 3  # epochs used when smoothing the generator loss
Cell 12: Train GANs for All Digits (0-9) Sequentially
# Train one independent DCGAN per digit class, sequentially.
#
# Schedule: the discriminator is updated only on even-numbered batches (a
# simple brake so D doesn't overpower G early); the generator is updated on
# every batch. Labels are smoothed (0.9 real / 0.1 fake) instead of hard 1/0.
#
# Fix: the average D loss was previously divided by `num_batches // 2`, but
# D trains on batch indices 0, 2, 4, ... i.e. ceil(num_batches / 2) times —
# an off-by-one for odd batch counts, and a ZeroDivisionError whenever an
# epoch has a single batch. D updates are now counted explicitly.
for digit in range(10):
    print(f"\n{'='*70}")
    print(f"TRAINING DIGIT {digit}")
    print(f"{'='*70}\n")

    # Fresh models, initialization and optimizers for every digit.
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    # Cache one real batch for visualization
    fixed_real_batch, _ = next(iter(digit_loaders[digit]))
    fixed_real_batch = fixed_real_batch.to(device)

    for epoch in range(epochs_per_digit):
        epoch_g_loss = 0
        epoch_d_loss = 0
        num_batches = 0
        num_d_updates = 0  # D only trains on alternating batches

        for batch_idx, (real_images, _) in enumerate(digit_loaders[digit]):
            batch_size = real_images.size(0)
            real_images = real_images.to(device)

            # One-sided label smoothing: 0.9 for real, 0.1 for fake.
            real_labels = torch.full((batch_size, 1), 0.9, device=device)
            fake_labels = torch.full((batch_size, 1), 0.1, device=device)

            if batch_idx % 2 == 0:
                # --- Discriminator step (even batches only) ---
                optimizer_D.zero_grad()

                real_output = discriminator(real_images)
                d_real_loss = criterion(real_output, real_labels)

                noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
                fake_images = generator(noise)
                # detach() keeps D's loss from backpropagating into G.
                fake_output = discriminator(fake_images.detach())
                d_fake_loss = criterion(fake_output, fake_labels)

                d_loss = d_real_loss + d_fake_loss
                d_loss.backward()
                optimizer_D.step()

                epoch_d_loss += d_loss.item()
                num_d_updates += 1
            else:
                # Odd batches: only produce fresh fakes for the G step below.
                noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
                fake_images = generator(noise)

            # --- Generator step (every batch): make D call the fakes real ---
            optimizer_G.zero_grad()
            fake_output = discriminator(fake_images)
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward()
            optimizer_G.step()

            epoch_g_loss += g_loss.item()
            num_batches += 1

        avg_g_loss = epoch_g_loss / num_batches
        # Divide by the actual number of D updates (see fix note above);
        # max(..., 1) guards the degenerate single-batch epoch.
        avg_d_loss = epoch_d_loss / max(num_d_updates, 1)

        # Fixed-size sample batch for quality metrics and visualization.
        with torch.no_grad():
            sample_noise = torch.randn(16, latent_dim, 1, 1, device=device)
            sample_images = generator(sample_noise)
            quality_metrics = calculate_image_quality(sample_images)

        metrics_history[digit]['g_loss'].append(avg_g_loss)
        metrics_history[digit]['d_loss'].append(avg_d_loss)
        metrics_history[digit]['quality'].append(quality_metrics)

        # Smooth G loss over a short window to damp epoch-to-epoch noise.
        if len(metrics_history[digit]['g_loss']) >= moving_avg_window:
            smoothed_g_loss = sum(metrics_history[digit]['g_loss'][-moving_avg_window:]) / moving_avg_window
        else:
            smoothed_g_loss = avg_g_loss

        quality = "✅" if smoothed_g_loss < 2.5 else "⚠️" if smoothed_g_loss < 4.5 else "❌"

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{epochs_per_digit}] | "
                  f"G Loss: {avg_g_loss:.3f} | D Loss: {avg_d_loss:.3f} | Quality: {quality}")

            # Real vs Generated visualization
            show_real_vs_fake(fixed_real_batch, sample_images, digit)

    # Snapshot the weights; deepcopy so later training cannot mutate them.
    trained_generators[digit] = copy.deepcopy(generator.state_dict())
    print(f"\n✓ Digit {digit} training complete!\n")

print("\nALL DIGITS TRAINING COMPLETE!")
Cell 13: Generate All Trained Digits Grid (6 samples per digit)
# Load all trained generators and show all digits together
print("Loading all trained generators...\n")

generator = Generator().to(device)
all_samples = []

for digit in range(10):
    # Swap in the weights trained for this digit and draw 6 samples.
    generator.load_state_dict(trained_generators[digit])
    generator.eval()

    with torch.no_grad():
        z = torch.randn(6, latent_dim, 1, 1, device=device)
        all_samples.append(generator(z))

# One row of 6 images per digit: rows 0-9 top to bottom.
final_grid = make_grid(torch.cat(all_samples), nrow=6, normalize=True)

plt.figure(figsize=(12, 18))
plt.imshow(final_grid.cpu().permute(1, 2, 0))
plt.title("All Trained Digits (6 samples per digit, rows 0-9)", fontsize=14, fontweight='bold')
plt.axis("off")
plt.tight_layout()
plt.show()
Cell 14: Plot Training Metrics by Digit
# Plot training metrics for all digits
fig, axes = plt.subplots(2, 5, figsize=(18, 8))
fig.suptitle('Training Metrics by Digit (Same GAN Architecture)', fontsize=16, fontweight='bold')

# axes.flat walks the 2x5 grid row by row, matching digits 0-9 in order.
for digit, ax in zip(range(10), axes.flat):
    ax.plot(metrics_history[digit]['g_loss'], label='G Loss', color='blue', linewidth=2)
    ax.plot(metrics_history[digit]['d_loss'], label='D Loss', color='red', linewidth=2)
    ax.set_title(f'Digit {digit}', fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Cell 15: Plot Image Quality Metrics
# Image quality metrics visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# One subplot per metric; all ten digit curves are overlaid on each subplot.
panels = zip(axes,
             ['diversity', 'sharpness', 'contrast'],
             ['Pixel Diversity', 'Image Sharpness', 'Image Contrast'])

for ax, metric_name, title in panels:
    for digit in range(10):
        series = [q[metric_name] for q in metrics_history[digit]['quality']]
        ax.plot(series, label=f'Digit {digit}', alpha=0.7, linewidth=1.5)

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend(loc='best', ncol=2, fontsize=8)
    ax.grid(True, alpha=0.3)

plt.suptitle('Image Quality Metrics Across Training', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
Cell 16: Display Final Training Summary
# Final quality summary
print("\n" + "="*70)
print("FINAL TRAINING SUMMARY")
print("="*70)

# Thresholds mirror the per-epoch quality indicator used during training.
def _loss_status(g_loss):
    if g_loss < 2.5:
        return "Excellent"
    if g_loss < 4.5:
        return "Good"
    return "Needs improvement"

for digit in range(10):
    final_g_loss = metrics_history[digit]['g_loss'][-1]
    final_d_loss = metrics_history[digit]['d_loss'][-1]
    final_quality = metrics_history[digit]['quality'][-1]

    print(f"Digit {digit}: G Loss: {final_g_loss:.3f} | D Loss: {final_d_loss:.3f} | {_loss_status(final_g_loss)}")
    print(f"  Quality - Diversity: {final_quality['diversity']:.4f} | "
          f"Sharpness: {final_quality['sharpness']:.4f} | "
          f"Contrast: {final_quality['contrast']:.4f}")

print("="*70)
Cell 17: Interactive Digit Generation
# Interactive generation function
def generate_digit(digit, num_samples=5):
    """
    Generate images for a specific digit using the trained generator.

    Args:
        digit (int): Digit to generate (0-9)
        num_samples (int): Number of samples to generate

    Returns:
        None. Shows a grid of generated images and prints quality metrics,
        or prints an error message for an invalid digit.
    """
    # Validate type as well as range: a non-integer such as 3.5 previously
    # slipped past the range check and crashed on the weights lookup below.
    if not isinstance(digit, int) or digit < 0 or digit > 9:
        print("Error: Please enter a digit between 0 and 9")
        return

    print(f"Generating {num_samples} images for digit {digit}...")

    # Load trained generator for this digit
    generator = Generator().to(device)
    generator.load_state_dict(trained_generators[digit])
    generator.eval()

    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, 1, 1, device=device)
        generated_images = generator(noise)

    # Display
    grid = make_grid(generated_images, nrow=num_samples, normalize=True)

    plt.figure(figsize=(num_samples * 2, 2))
    plt.imshow(grid.cpu().permute(1, 2, 0))
    plt.title(f"Generated Images for Digit: {digit}", fontsize=14, fontweight='bold')
    plt.axis("off")
    plt.tight_layout()
    plt.show()

    # Quality metrics
    quality = calculate_image_quality(generated_images)
    print(f"\nQuality Metrics:")
    print(f"  Diversity: {quality['diversity']:.4f}")
    print(f"  Sharpness: {quality['sharpness']:.4f}")
    print(f"  Contrast: {quality['contrast']:.4f}")

# Usage banner for the interactive helper above (output unchanged).
print("\n" + "="*70,
      "INTERACTIVE GENERATION READY",
      "="*70,
      "Usage: generate_digit(digit, num_samples=5)",
      "Example: generate_digit(3, 8)  # Generate 8 images of digit 3",
      "="*70,
      sep="\n")