DCGAN Training on MNIST Dataset
Deep Convolutional Generative Adversarial Network - Digit-wise Training
Back to Home
Cell 1: Import Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import copy
Cell 2: Set Device
# Select CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
Cell 3: Load and Separate Dataset by Digit Class
# Dataset preparation - separate by digit class
transform = transforms.Compose([
    transforms.Resize(16),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
full_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)

# Separate data by digit class (0-9).
# Read labels from the raw `targets` tensor instead of enumerate(full_dataset):
# iterating the dataset decodes and transforms every one of the 60k images on
# each of the 10 passes, while a boolean mask on `targets` is a single tensor op.
digit_datasets = {}   # digit -> Subset of up to 1500 images of that digit
digit_loaders = {}    # digit -> shuffled DataLoader over that subset
for digit in range(10):
    indices = (full_dataset.targets == digit).nonzero(as_tuple=True)[0].tolist()
    digit_datasets[digit] = torch.utils.data.Subset(full_dataset, indices[:1500])
    digit_loaders[digit] = torch.utils.data.DataLoader(
        digit_datasets[digit],
        batch_size=64,
        shuffle=True
    )
    print(f"Digit {digit}: {len(digit_datasets[digit])} images")
Cell 4: Define Hyperparameters
# Hyperparameters
latent_dim = 100        # size of the generator's input noise vector (N, latent_dim, 1, 1)
epochs_per_digit = 25   # training epochs for each independent per-digit GAN
lr = 0.0002             # Adam learning rate, shared by generator and discriminator
beta1 = 0.5             # Adam beta1 term used in both optimizers
Cell 5: Define Generator Architecture
# Generator
class Generator(nn.Module):
    """DCGAN generator: maps a (N, latent_dim, 1, 1) noise tensor to (N, 1, 16, 16) images.

    Output passes through Tanh, so pixel values lie in [-1, 1] — matching the
    Normalize((0.5,), (0.5,)) applied to the real data.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(latent_dim, 256, 4, 1, 0, bias=False),  # 1x1 -> 4x4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),         # 4x4 -> 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),           # 8x8 -> 16x16
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from latent batch z."""
        return self.net(z)
Cell 6: Define Discriminator Architecture
# Discriminator: 16x16 image -> probability real/fake
class Discriminator(nn.Module):
    """DCGAN discriminator: scores (N, 1, 16, 16) images as real (→1) vs fake (→0)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 128, 4, 2, 1, bias=False),    # 16x16 -> 8x8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),  # 8x8 -> 4x4
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),    # 4x4 -> 1x1
            nn.Sigmoid(),                              # probability in (0, 1)
        )

    def forward(self, x):
        """Return per-image probabilities with shape (N, 1)."""
        score = self.model(x)
        return score.view(-1, 1)
Cell 7: Define Weight Initialization Function
# Weight initialization
def weights_init(m):
    """DCGAN-style init: N(0, 0.02) for conv weights; N(1, 0.02) weight and
    zero bias for batch-norm layers. Apply via module.apply(weights_init)."""
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
Cell 8: Define Loss Criterion
# Binary cross-entropy on the discriminator's sigmoid output; used for both
# the D step (real vs fake labels) and the G step (fakes scored against real labels).
criterion = nn.BCELoss()
Cell 9: Define Visualization Function
def show_real_vs_fake(real, fake, digit):
    """Show a side-by-side grid of real vs generated images for one digit.

    Args:
        real: batch of real images (N, 1, H, W); may live on the GPU.
        fake: batch of generated images, same layout; may live on the GPU.
        digit: digit label used only in the subplot titles.
    """
    # Detach and move to CPU before handing tensors to matplotlib:
    # imshow cannot consume CUDA tensors, and callers pass device tensors.
    real_grid = make_grid(real[:8].detach().cpu(), nrow=4, normalize=True)
    fake_grid = make_grid(fake[:8].detach().cpu(), nrow=4, normalize=True)
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(real_grid.permute(1, 2, 0))
    plt.title(f"Real Images - Digit {digit}")
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(fake_grid.permute(1, 2, 0))
    plt.title(f"Generated Images - Digit {digit}")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
Cell 10: Define Image Quality Metrics Function
def calculate_image_quality(images):
    """Compute simple quality statistics for a batch of generated images.

    Args:
        images: tensor batch (N, C, H, W); detached and moved to CPU internally.

    Returns:
        dict with 'diversity' (pixel std over the whole batch), 'sharpness'
        (mean absolute spatial gradient per image, averaged over the batch)
        and 'contrast' (average per-image dynamic range).
    """
    arr = images.detach().cpu().numpy()
    # Diversity: spread of pixel values across the entire batch.
    diversity = np.std(arr)
    # Sharpness: |d/dx| + |d/dy| averaged per image, then over images.
    per_image_sharpness = []
    for img in arr:
        plane = img.squeeze()
        gy, gx = np.gradient(plane)
        per_image_sharpness.append(np.mean(np.abs(gx)) + np.mean(np.abs(gy)))
    # Contrast: average max-min range per image.
    contrast = np.mean([im.max() - im.min() for im in arr])
    return {
        'diversity': diversity,
        'sharpness': np.mean(per_image_sharpness),
        'contrast': contrast
    }
Cell 11: Initialize Training Storage
# Per-digit generator checkpoints (deep-copied state_dicts), filled by the training loop.
trained_generators = {}
# Per-digit training curves: G/D loss per epoch plus the image-quality dict per epoch.
metrics_history = {digit: {'g_loss': [], 'd_loss': [], 'quality': []} for digit in range(10)}
# Number of recent epochs averaged when smoothing the reported G loss.
moving_avg_window = 3
Cell 12: Train GANs for All Digits (0-9) Sequentially
# Train one independent DCGAN (generator + discriminator pair) per digit class.
# NOTE(review): G and D are updated on alternating batches (even batch_idx -> D,
# odd -> G), so each network is stepped on only half of the batches per epoch.
for digit in range(10):
    print(f"\n{'='*70}")
    print(f"TRAINING DIGIT {digit}")
    print(f"{'='*70}\n")
    # Fresh networks and optimizers for every digit — no weight sharing across digits.
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    generator.apply(weights_init)
    discriminator.apply(weights_init)
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    # Cache one real batch for visualization
    fixed_real_batch, _ = next(iter(digit_loaders[digit]))
    fixed_real_batch = fixed_real_batch.to(device)
    for epoch in range(epochs_per_digit):
        epoch_g_loss = 0
        epoch_d_loss = 0
        num_batches = 0
        for batch_idx, (real_images, _) in enumerate(digit_loaders[digit]):
            batch_size = real_images.size(0)
            real_images = real_images.to(device)
            # Soft labels: 0.9 for real, 0.1 for fake (instead of hard 1/0).
            real_labels = torch.full((batch_size, 1), 0.9, device=device)
            fake_labels = torch.full((batch_size, 1), 0.1, device=device)
            if batch_idx % 2 == 0:
                # --- Discriminator step: real batch vs a detached fake batch ---
                optimizer_D.zero_grad()
                real_output = discriminator(real_images)
                d_real_loss = criterion(real_output, real_labels)
                noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
                fake_images = generator(noise)
                # detach() stops D's gradients from flowing back into the generator.
                fake_output = discriminator(fake_images.detach())
                d_fake_loss = criterion(fake_output, fake_labels)
                d_loss = d_real_loss + d_fake_loss
                d_loss.backward()
                optimizer_D.step()
                epoch_d_loss += d_loss.item()
            else:
                # --- Generator step: push D's score on fakes toward the real label ---
                noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
                fake_images = generator(noise)
                optimizer_G.zero_grad()
                fake_output = discriminator(fake_images)
                g_loss = criterion(fake_output, real_labels)
                g_loss.backward()
                optimizer_G.step()
                epoch_g_loss += g_loss.item()
            num_batches += 1
        # NOTE(review): epoch_g_loss accumulates over only ~num_batches/2 generator
        # steps yet is divided by num_batches, so avg_g_loss under-reports the
        # per-step G loss by roughly half; avg_d_loss divides by num_batches // 2.
        # This skews the printed losses and the quality-emoji thresholds below.
        avg_g_loss = epoch_g_loss / num_batches
        avg_d_loss = epoch_d_loss / (num_batches // 2)
        # Sample 16 images (no autograd graph) for the per-epoch quality metrics.
        with torch.no_grad():
            sample_noise = torch.randn(16, latent_dim, 1, 1, device=device)
            sample_images = generator(sample_noise)
        quality_metrics = calculate_image_quality(sample_images)
        metrics_history[digit]['g_loss'].append(avg_g_loss)
        metrics_history[digit]['d_loss'].append(avg_d_loss)
        metrics_history[digit]['quality'].append(quality_metrics)
        # Smooth the G loss over the last moving_avg_window epochs before grading.
        if len(metrics_history[digit]['g_loss']) >= moving_avg_window:
            smoothed_g_loss = sum(metrics_history[digit]['g_loss'][-moving_avg_window:]) / moving_avg_window
        else:
            smoothed_g_loss = avg_g_loss
        quality = "✅" if smoothed_g_loss < 2.5 else "⚠️" if smoothed_g_loss < 4.5 else "❌"
        # Report (and visualize) on the first epoch and every 5th epoch after.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{epochs_per_digit}] | "
                  f"G Loss: {avg_g_loss:.3f} | D Loss: {avg_d_loss:.3f} | Quality: {quality}")
            # Real vs Generated visualization
            show_real_vs_fake(fixed_real_batch, sample_images, digit)
    # Deep-copy the weights so later cells can reload this digit's generator.
    trained_generators[digit] = copy.deepcopy(generator.state_dict())
    print(f"\n✓ Digit {digit} training complete!\n")
print("\nALL DIGITS TRAINING COMPLETE!")
Cell 13: Generate All Trained Digits Grid (6 samples per digit)
# Load all trained generators and show all digits together
print("Loading all trained generators...\n")
generator = Generator().to(device)
all_samples = []
for digit in range(10):
    # Restore this digit's trained weights into the shared generator instance.
    generator.load_state_dict(trained_generators[digit])
    generator.eval()
    with torch.no_grad():
        z = torch.randn(6, latent_dim, 1, 1, device=device)
        all_samples.append(generator(z))
# One row of 6 samples per digit, digits 0-9 top to bottom.
final_grid = make_grid(torch.cat(all_samples), nrow=6, normalize=True)
plt.figure(figsize=(12, 18))
plt.imshow(final_grid.cpu().permute(1, 2, 0))
plt.title("All Trained Digits (6 samples per digit, rows 0-9)", fontsize=14, fontweight='bold')
plt.axis("off")
plt.tight_layout()
plt.show()
Cell 14: Plot Training Metrics by Digit
# Plot training metrics for all digits
fig, axes = plt.subplots(2, 5, figsize=(18, 8))
fig.suptitle('Training Metrics by Digit (Same GAN Architecture)', fontsize=16, fontweight='bold')
# axes.flatten() walks the 2x5 grid row-major, matching digits 0-9 in order.
for digit, ax in enumerate(axes.flatten()):
    history = metrics_history[digit]
    ax.plot(history['g_loss'], label='G Loss', color='blue', linewidth=2)
    ax.plot(history['d_loss'], label='D Loss', color='red', linewidth=2)
    ax.set_title(f'Digit {digit}', fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Cell 15: Plot Image Quality Metrics
# Image quality metrics visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# One panel per quality metric, every digit plotted as its own curve.
quality_panels = [
    ('diversity', 'Pixel Diversity'),
    ('sharpness', 'Image Sharpness'),
    ('contrast', 'Image Contrast'),
]
for ax, (metric_name, title) in zip(axes, quality_panels):
    for digit in range(10):
        curve = [epoch_quality[metric_name] for epoch_quality in metrics_history[digit]['quality']]
        ax.plot(curve, label=f'Digit {digit}', alpha=0.7, linewidth=1.5)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend(loc='best', ncol=2, fontsize=8)
    ax.grid(True, alpha=0.3)
plt.suptitle('Image Quality Metrics Across Training', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
Cell 16: Display Final Training Summary
# Final quality summary
divider = "=" * 70
print("\n" + divider)
print("FINAL TRAINING SUMMARY")
print(divider)
for digit in range(10):
    history = metrics_history[digit]
    final_g_loss = history['g_loss'][-1]
    final_d_loss = history['d_loss'][-1]
    final_quality = history['quality'][-1]
    # Grade each digit by its last reported generator loss.
    status = ("Excellent" if final_g_loss < 2.5
              else "Good" if final_g_loss < 4.5
              else "Needs improvement")
    print(f"Digit {digit}: G Loss: {final_g_loss:.3f} | D Loss: {final_d_loss:.3f} | {status}")
    print(f" Quality - Diversity: {final_quality['diversity']:.4f} | "
          f"Sharpness: {final_quality['sharpness']:.4f} | "
          f"Contrast: {final_quality['contrast']:.4f}")
print(divider)
Cell 17: Interactive Digit Generation
# Interactive generation function
def generate_digit(digit, num_samples=5):
    """
    Generate images for a specific digit using the trained generator.
    Args:
        digit (int): Digit to generate (0-9)
        num_samples (int): Number of samples to generate
    """
    if not 0 <= digit <= 9:
        print("Error: Please enter a digit between 0 and 9")
        return
    print(f"Generating {num_samples} images for digit {digit}...")
    # Load trained generator for this digit
    net = Generator().to(device)
    net.load_state_dict(trained_generators[digit])
    net.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim, 1, 1, device=device)
        generated_images = net(z)
    # Display
    grid = make_grid(generated_images, nrow=num_samples, normalize=True)
    plt.figure(figsize=(num_samples * 2, 2))
    plt.imshow(grid.cpu().permute(1, 2, 0))
    plt.title(f"Generated Images for Digit: {digit}", fontsize=14, fontweight='bold')
    plt.axis("off")
    plt.tight_layout()
    plt.show()
    # Quality metrics
    quality = calculate_image_quality(generated_images)
    print(f"\nQuality Metrics:")
    print(f" Diversity: {quality['diversity']:.4f}")
    print(f" Sharpness: {quality['sharpness']:.4f}")
    print(f" Contrast: {quality['contrast']:.4f}")
# Usage banner for the interactive generation helper defined above.
separator = "=" * 70
print("\n" + separator)
print("INTERACTIVE GENERATION READY")
print(separator)
print("Usage: generate_digit(digit, num_samples=5)")
print("Example: generate_digit(3, 8) # Generate 8 images of digit 3")
print(separator)