Autoencoders for Representation Learning

An Interactive Deep Learning Experiment on the FashionMNIST Dataset


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# Load FashionMNIST dataset
transform = transforms.ToTensor()

train_data = datasets.FashionMNIST(
    root="/kaggle/working",
    train=True,
    download=True,  # download if not already present locally
    transform=transform
)

test_data = datasets.FashionMNIST(
    root="/kaggle/working",
    train=False,
    download=True,  # download if not already present locally
    transform=transform
)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

print("Dataset loaded successfully!")
print(f"Training samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")

# Architecture: Deeper network with BatchNorm, Dropout, and 2D latent space
# Progression: 784 → 512 → 256 → 128 → 64 → 32 → 16 → 8 → 4 → 2
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Deeper encoder: Compressing from 784 down to 2
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            
            nn.Linear(32, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            
            nn.Linear(16, 8),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            
            nn.Linear(8, 4),
            nn.BatchNorm1d(4),
            nn.ReLU(),
            
            nn.Linear(4, 2)  # Latent space (2D for visualization)
        )
        
        # Deeper decoder: Reconstructing from 2 back to 784
        self.decoder = nn.Sequential(
            nn.Linear(2, 4),
            nn.BatchNorm1d(4),
            nn.ReLU(),
            
            nn.Linear(4, 8),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            
            nn.Linear(8, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            
            nn.Linear(16, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            
            nn.Linear(32, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            
            nn.Linear(512, 784),
            nn.Sigmoid()  # Use Sigmoid for pixel values between [0, 1]
        )
    
    def forward(self, x):
        z = self.encoder(x)
        out = self.decoder(z)
        # Reshape output back to image dimensions if necessary for your loss function
        # out = out.view(-1, 1, 28, 28)
        return out, z

print("Model architecture defined successfully!")

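The training and evaluation code below calls an add_noise helper that is never defined in this notebook. A minimal sketch, assuming additive Gaussian noise clamped back to the valid pixel range (the default noise_factor=0.3 is an assumption; the noise-robustness test later passes explicit factors):

def add_noise(images, noise_factor=0.3):
    # Assumed implementation: the original notebook uses add_noise without
    # defining it. Adds Gaussian noise scaled by noise_factor, then clamps
    # back to [0, 1] so the result stays a valid pixel image.
    noisy = images + noise_factor * torch.randn_like(images)
    return noisy.clamp(0.0, 1.0)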
# TRAINING LOOP WITH REGULARIZATION
# Trains a denoising objective (noisy inputs, clean targets) with a
# combined MSE + L1 loss, gradient clipping, and LR scheduling

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Autoencoder().to(device)

criterion_mse = nn.MSELoss()
criterion_l1 = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

epochs = 100
best_loss = float('inf')

for epoch in range(epochs):
    model.train()
    total_loss = 0
    
    for images, _ in train_loader:
        images = images.to(device)
        noisy = add_noise(images)
        outputs, _ = model(noisy)
        
        targets = images.view(images.size(0), -1)  # flatten to match decoder output
        loss_mse = criterion_mse(outputs, targets)
        loss_l1 = criterion_l1(outputs, targets)
        loss = 0.7 * loss_mse + 0.3 * loss_l1
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    scheduler.step(avg_loss)
    
    if avg_loss < best_loss:
        best_loss = avg_loss
    
    if (epoch + 1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.4f}")

print(f"\nTraining completed! Best Loss: {best_loss:.4f}")

# VISUALIZATION: BASIC RECONSTRUCTION
# Display original, noisy, and reconstructed images side-by-side

model.eval()
images, _ = next(iter(test_loader))
images = images.to(device)
noisy = add_noise(images)

with torch.no_grad():
    reconstructed, _ = model(noisy)

# Display 8 samples
n = 8
plt.figure(figsize=(12, 5))

for i in range(n):
    # Original
    plt.subplot(3, n, i+1)
    plt.imshow(images[i].cpu().squeeze(), cmap='gray')
    plt.axis('off')
    
    # Noisy
    plt.subplot(3, n, i+1+n)
    plt.imshow(noisy[i].cpu().squeeze(), cmap='gray')
    plt.axis('off')
    
    # Reconstructed
    plt.subplot(3, n, i+1+2*n)
    plt.imshow(reconstructed[i].view(28,28).cpu(), cmap='gray')
    plt.axis('off')

plt.tight_layout()
plt.show()

print("Reconstruction visualization completed!")


# NOISE ROBUSTNESS TEST
# Reconstruct the same batch at increasing noise levels and report the MSE

noise_levels = [0.1, 0.25, 0.4, 0.6]

for nf in noise_levels:
    noisy_test = add_noise(images, nf)
    with torch.no_grad():
        recon, _ = model(noisy_test)
    mse = ((recon.view_as(images) - images) ** 2).mean().item()
    print(f"Noise level: {nf} - Reconstruction MSE: {mse:.4f}")

print("\nNoise robustness test completed!")

# LATENT SPACE VISUALIZATION
# Encode the full test set and collect the 2D latent codes with labels

latents = []
labels = []

with torch.no_grad():
    for images, lbls in test_loader:
        images = images.to(device)
        _, z = model(images)
        latents.append(z.cpu())
        labels.append(lbls)

latents = torch.cat(latents).numpy()
labels = torch.cat(labels).numpy()

class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]

plt.figure(figsize=(12, 10))
scatter = plt.scatter(latents[:, 0], latents[:, 1], c=labels,
                      cmap='tab10', s=5, alpha=0.7)
plt.xlabel("Latent Coordinate X")
plt.ylabel("Latent Coordinate Y")
plt.title("2D Latent Space Projection")
cbar = plt.colorbar(scatter, ticks=range(10))
cbar.ax.set_yticklabels(class_names)  # map colors to FashionMNIST classes
plt.grid(True)
plt.show()

print("Latent space visualization completed!")

# QUANTITATIVE EVALUATION
# Compute MSE, PSNR, and a simplified SSIM over the full test set

model.eval()
total_mse = total_ssim = total_psnr = total_samples = 0

def calculate_ssim(img1, img2):
    # Simplified global SSIM: statistics are computed over the whole image
    # rather than a sliding window, so values will differ from windowed
    # implementations such as skimage's structural_similarity.
    C1, C2 = 0.01 ** 2, 0.03 ** 2
    mu1, mu2 = img1.mean(), img2.mean()
    sigma1_sq = ((img1 - mu1) ** 2).mean()
    sigma2_sq = ((img2 - mu2) ** 2).mean()
    sigma12 = ((img1 - mu1) * (img2 - mu2)).mean()
    return (((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) /
            ((mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2))).item()

with torch.no_grad():
    for images, _ in test_loader:
        images = images.to(device)
        noisy = add_noise(images)
        outputs, _ = model(noisy)
        outputs = outputs.view_as(images)  # Reshape to match images

        mse = ((outputs - images) ** 2).mean(dim=[1, 2, 3])
        total_mse += mse.sum().item()
        total_psnr += (20 * torch.log10(1.0 / torch.sqrt(mse))).sum().item()
        
        total_ssim += sum(calculate_ssim(images[i].squeeze(), outputs[i].squeeze()) 
                             for i in range(images.size(0)))
            
        total_samples += images.size(0)

print(f"\nMSE: {total_mse/total_samples:.6f} | PSNR: {total_psnr/total_samples:.2f} dB | SSIM: {total_ssim/total_samples:.4f}")