Generative Adversarial Networks (GANs)

Procedure

The objective of this procedure is to implement a Generative Adversarial Network (GAN) with a Deep Convolutional GAN (DCGAN) architecture on the MNIST dataset. A generator and a discriminator are trained adversarially so that the generator learns the data distribution and produces digit-like images. The procedure focuses on understanding adversarial training dynamics, monitoring generator and discriminator losses, and visually comparing real and generated images across epochs in a "Real vs Generated" panel to evaluate model performance.


1. Environment Setup

Import the required Python libraries: PyTorch, Torchvision, NumPy, and Matplotlib.


2. Dataset Preparation

  • Load the MNIST handwritten digit dataset.
  • Apply pre-processing steps including resizing images to 16×16, converting them to tensors, and normalizing pixel values to the range [−1, 1].
  • Separate the dataset into individual subsets for each digit (0–9).
  • For each digit class, select a fixed number of samples and create separate DataLoaders to train the GAN independently on each digit.

3. Hyper-parameter Initialization

  • Define key hyper-parameters: latent vector size, number of training epochs per digit, learning rate, and the Adam optimizer momentum term (β₁).
  • Set binary cross-entropy loss as the adversarial loss function.
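One possible set of values, sketched below; the specific numbers follow common DCGAN practice and are assumptions, not values fixed by this procedure:

```python
import torch.nn as nn

LATENT_DIM = 100       # size of the noise vector z (illustrative)
EPOCHS_PER_DIGIT = 50  # training epochs for each digit-specific GAN (illustrative)
LR = 2e-4              # Adam learning rate (common DCGAN choice)
BETA1 = 0.5            # Adam first-moment decay, the "momentum term" above

# Binary cross-entropy as the adversarial loss
adversarial_loss = nn.BCELoss()
```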

4. Model Architecture Definition

  • Design the Generator using transposed convolutional layers, batch normalization, and ReLU activations to transform random noise into digit images.
  • Design the Discriminator using convolutional layers with LeakyReLU activations to classify images as real or fake.
  • Apply weight initialization to both networks to ensure stable training.
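A sketch of the two networks for 16×16 outputs; the exact layer widths are one workable choice, not the only one, and the weight initialization follows the standard DCGAN scheme (normal with standard deviation 0.02):

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Noise (B, latent_dim, 1, 1) -> image (B, 1, 16, 16)."""
    def __init__(self, latent_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 128, 4, 1, 0, bias=False),  # -> 4x4
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),          # -> 8x8
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),            # -> 16x16
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized data
        )
    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """Image (B, 1, 16, 16) -> real/fake probability (B, 1)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),    # -> 8x8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # -> 4x4
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 0, bias=False),   # -> 1x1
            nn.Sigmoid(),
        )
    def forward(self, x):
        return self.net(x).view(-1, 1)

def weights_init(m):
    # Standard DCGAN initialization: N(0, 0.02) for conv weights,
    # N(1, 0.02) for batch-norm scales, zero batch-norm bias
    name = m.__class__.__name__
    if "Conv" in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.zeros_(m.bias.data)
```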

5. Training Strategy (Digit-wise GAN Training)

For each digit from 0 to 9:

  • Initialize a new Generator and Discriminator.
  • Assign Adam optimizers to both networks.
  • Cache a fixed batch of real images for visualization purposes.
  • Train the GAN for a fixed number of epochs using alternating optimization:
    • Update the Discriminator by minimizing loss on real and fake samples.
    • Update the Generator by minimizing its loss to fool the Discriminator.
  • Use label smoothing to improve training stability.
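One epoch of the alternating scheme above can be sketched as follows. The networks, optimizers, and BCE criterion are assumed from the earlier steps; the 0.9 real-label target implements one-sided label smoothing and is a common choice, assumed here rather than specified by the procedure:

```python
import torch

def train_epoch(G, D, loader, opt_G, opt_D, criterion, latent_dim=100, device="cpu"):
    """Run one epoch of alternating D/G updates; return average losses."""
    g_losses, d_losses = [], []
    for real, _ in loader:
        real = real.to(device)
        b = real.size(0)
        real_labels = torch.full((b, 1), 0.9, device=device)  # smoothed "real"
        fake_labels = torch.zeros(b, 1, device=device)

        # Discriminator step: real batch plus a detached fake batch
        opt_D.zero_grad()
        z = torch.randn(b, latent_dim, 1, 1, device=device)
        fake = G(z)
        d_loss = (criterion(D(real), real_labels)
                  + criterion(D(fake.detach()), fake_labels))
        d_loss.backward()
        opt_D.step()

        # Generator step: push D's output on fakes toward "real"
        opt_G.zero_grad()
        g_loss = criterion(D(fake), torch.ones(b, 1, device=device))
        g_loss.backward()
        opt_G.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())
    return sum(g_losses) / len(g_losses), sum(d_losses) / len(d_losses)
```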

6. Epoch-wise Monitoring and Visualization

  • Compute average Generator and Discriminator losses for each epoch.
  • Generate sample images using fixed noise vectors.
  • Display a "Real vs Generated" image panel at regular intervals to visually evaluate learning progress.
  • Assign a simple quality indicator based on smoothed generator loss values.
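A sketch of the "Real vs Generated" panel: a fixed noise batch is reused every epoch so progress is visible on the same latent points, and `0.5 * x + 0.5` undoes the [−1, 1] normalization for display. The figure is saved to a file here, an assumption to keep the sketch headless:

```python
import torch
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

def show_panel(real, fake, n=8, path="panel.png"):
    """Save a 2-row panel: real images on top, generated images below."""
    real = (real[:n] * 0.5 + 0.5).clamp(0, 1)
    fake = (fake[:n] * 0.5 + 0.5).clamp(0, 1)
    fig, axes = plt.subplots(2, n, figsize=(n, 2.4))
    for i in range(n):
        axes[0, i].imshow(real[i, 0].cpu(), cmap="gray")
        axes[1, i].imshow(fake[i, 0].detach().cpu(), cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].axis("off")
    axes[0, 0].set_title("Real", loc="left")
    axes[1, 0].set_title("Generated", loc="left")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
    return path
```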

7. Image Quality Evaluation

After each epoch, compute quantitative quality metrics:

  • Pixel diversity
  • Image sharpness
  • Image contrast

Store loss values and quality metrics for later analysis.
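The procedure does not fix exact formulas for the three metrics, so the definitions below are assumptions: diversity as per-pixel spread across the batch, sharpness as the variance of a Laplacian response, and contrast as per-image pixel spread. Inputs are batches of shape (B, 1, H, W) in [−1, 1]:

```python
import torch
import torch.nn.functional as F

def pixel_diversity(batch):
    # Mean per-pixel standard deviation across the batch: higher means
    # samples differ more from each other (a rough mode-collapse check).
    return batch.std(dim=0).mean().item()

def sharpness(batch):
    # Variance of the Laplacian filter response, a standard blur measure.
    lap = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]]).view(1, 1, 3, 3)
    return F.conv2d(batch, lap).var().item()

def contrast(batch):
    # Mean per-image pixel standard deviation.
    return batch.std(dim=(1, 2, 3)).mean().item()
```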


8. Result Visualization

  • Load all trained generators and generate sample images for every digit.
  • Display all generated digits together in a single grid for comparison.

9. Training Metrics Analysis

  • Plot Generator and Discriminator loss curves for all digits.
  • Visualize image quality metrics across epochs to analyse training behaviour.
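A sketch of the loss-curve plot, assuming the per-epoch loss histories from step 6 were stored as lists keyed by digit (e.g. `g_hist[digit] = [epoch-1 loss, ...]`):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

def plot_losses(g_hist, d_hist, path="losses.png"):
    """Plot per-digit Generator and Discriminator loss curves side by side."""
    fig, (ax_g, ax_d) = plt.subplots(1, 2, figsize=(10, 4))
    for digit in sorted(g_hist):
        ax_g.plot(g_hist[digit], label=str(digit))
        ax_d.plot(d_hist[digit], label=str(digit))
    ax_g.set_title("Generator loss")
    ax_d.set_title("Discriminator loss")
    for ax in (ax_g, ax_d):
        ax.set_xlabel("epoch")
        ax.legend(fontsize=6)
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
    return path
```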

10. Final Training Summary

  • Display final loss values and quality metrics for each digit.
  • Categorize training performance as Excellent, Good, or Needs Improvement.
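The categorization could look like the sketch below; the cut-off values are purely illustrative assumptions, since the procedure does not fix them:

```python
def grade(final_g_loss):
    """Map a smoothed final generator loss to a coarse category.
    Thresholds are illustrative only, not specified by the procedure."""
    if final_g_loss < 1.0:
        return "Excellent"
    if final_g_loss < 2.5:
        return "Good"
    return "Needs Improvement"
```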

11. Interactive Sample Generation

  • Implement an interactive function to generate images for any selected digit.
  • Allow the user to specify the number of samples.
  • Display generated images along with corresponding quality metrics.
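The interactive sampler can be sketched as follows, assuming the digit-to-generator mapping from step 5; displaying the images and attaching quality metrics would reuse the helpers from steps 6 and 7:

```python
import torch

def generate_samples(generators, digit, n_samples=8, latent_dim=100):
    """Generate n_samples images in [0, 1] for the requested digit."""
    if digit not in generators:
        raise ValueError(f"no trained generator for digit {digit}")
    z = torch.randn(n_samples, latent_dim, 1, 1)
    with torch.no_grad():
        imgs = generators[digit](z) * 0.5 + 0.5  # rescale from [-1, 1] to [0, 1]
    return imgs
```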

12. Result Analysis

  • Generate samples for all digits to verify overall GAN performance.
  • Observe diversity and realism in the generated handwritten digits.