From Local Optima to Complete Collapse: When Optimization Turns into Catastrophe


Introduction

A smart architect sets out to design the best building in the world. They begin the work, and everything goes perfectly, but suddenly:
  • They get stuck in a mediocre design and run out of creativity (local optimum)
  • They freeze at a point, not knowing which direction to go (saddle point)
  • They repeat the same design for months without progress (plateau)
  • Or worst of all: they suddenly forget everything and start from scratch!
  • Or more catastrophic: they design only one type of building and completely lose diversity!
  • Or more terrifying: their calculations blow up and the numbers shoot off to infinity!
These are exactly the challenges that deep learning models face. In previous articles, we examined three main optimization challenges, but the real world of AI is full of other catastrophes that can destroy a multi-million dollar project in seconds.
In this comprehensive article, we dive deep into these catastrophes and discover:
  • Why GANs suddenly produce only one image (Mode Collapse)
  • Why robots forget old things when learning new things (Catastrophic Forgetting)
  • Why sometimes loss turns into NaN and everything breaks (Gradient Explosion)
  • How the world's largest companies deal with these catastrophes

Classification of Optimization Catastrophes

| Catastrophe Type | Main Symptom | Severity | Vulnerable Architectures |
|---|---|---|---|
| Mode Collapse | Identical, repetitive outputs | 🔴 Critical | GANs |
| Catastrophic Forgetting | Forgetting previous knowledge | 🔴 Critical | All (especially continual learning) |
| Gradient Explosion | Loss → NaN or Inf | 🔴 Critical | RNNs, deep networks |
| Gradient Vanishing | Early layers don't learn | 🟡 Medium | Very deep networks, RNNs |
| Training Instability | Severe loss oscillations | 🟠 High | GANs, large Transformers |
| Dead Neurons | Part of the network inactive | 🟡 Medium | Networks with ReLU |
| Oscillation/Divergence | Non-convergence | 🟠 High | High LR, inappropriate architecture |

Catastrophe 1: Mode Collapse - When GAN Forgets What Diversity Is

Definition and Symptoms

Mode Collapse is one of the worst nightmares of Generative Adversarial Networks. In this state:
  • Generator produces only one or few limited outputs
  • Diversity completely disappears
  • Model found the "safest" way and won't leave it
Concrete Example: Imagine you have a GAN that should generate different faces. After a few epochs, you notice it only produces one face with minor variations - all blonde hair, all blue eyes! This is Mode Collapse.

Why Does Mode Collapse Happen?

Mathematical Reason: Generator and Discriminator are in a minimax game:
min_G max_D V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]
When Generator finds a way to fool Discriminator (like producing one specific type of image), it has no incentive to explore other modes.
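To see the incentive problem concretely, here is a minimal, self-contained sketch of the two sides of the minimax objective; d_real and d_fake are toy stand-ins for the discriminator's outputs D(x) and D(G(z)), not part of any real GAN implementation:
python
import torch

# Toy stand-ins: in a real GAN these probabilities come from D(x) and D(G(z))
d_real = torch.rand(64, 1) * 0.5 + 0.5   # D is fairly confident these are real
d_fake = torch.rand(64, 1) * 0.4 + 0.1   # D is suspicious of the generated batch

# Discriminator maximizes V(D, G), i.e. minimizes its negative
d_loss = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())

# Generator minimizes log(1 - D(G(z))): once one kind of fake reliably fools D,
# this term is already small, and nothing in it rewards covering more modes
g_loss = torch.log(1 - d_fake).mean()

print(d_loss.item(), g_loss.item())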

Types of Mode Collapse

1. Complete Collapse
  • Generator produces only one output
  • Worst possible case
  • Project completely failed
2. Partial Collapse
  • Generator produces a few outputs (e.g., 5-10 types)
  • But should produce thousands of different types
  • Much more common than Complete Collapse
3. Mode Hopping
  • Generator jumps from one mode to another every few epochs
  • Never learns all modes simultaneously
  • Very confusing and hard to debug

Professional Solutions

1. Spectral Normalization

python
import torch.nn as nn
from torch.nn.utils import spectral_norm

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Instead of regular Conv2d, wrap each layer in spectral_norm
        self.conv1 = spectral_norm(nn.Conv2d(128, 256, 3))
        self.conv2 = spectral_norm(nn.Conv2d(256, 512, 3))
Amazing Impact:
  • Loss landscape becomes smoother
  • Training more stable
  • Mode Collapse drastically reduced
Real Application: StyleGAN, BigGAN and most modern GANs

2. Minibatch Discrimination

python
import torch
import torch.nn as nn

class MinibatchDiscrimination(nn.Module):
    def forward(self, x):
        # x: (batch, features) - flatten image features before this layer if needed
        # Pairwise L2 distances between all samples in the batch
        distances = torch.cdist(x, x, p=2)
        # Mean distance of each sample to the rest of the batch
        diversity_score = distances.mean(dim=1, keepdim=True)
        # If all samples are nearly identical, this feature is close to zero → probably fake
        return torch.cat([x, diversity_score], dim=1)
How It Works:
  • If Generator produces all identical images
  • Discriminator understands (because diversity is low)
  • Generator is forced to create diversity

3. Progressive Growing

Idea: Start with small images, gradually get bigger.
python
# Epoch 1-10: 4x4 images
# Epoch 11-20: 8x8 images
# Epoch 21-30: 16x16 images
# ...
# Epoch 61-70: 1024x1024 images
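As a rough sketch of how such a schedule can be wired up (the function names below are illustrative, and the smooth fade-in between resolutions used by ProGAN/StyleGAN is omitted):
python
import torch.nn.functional as F

def current_resolution(epoch, start_res=4, epochs_per_stage=10, max_res=1024):
    """Resolution schedule: 4x4 for the first stage, doubling every stage."""
    stage = epoch // epochs_per_stage
    return min(start_res * (2 ** stage), max_res)

def resize_real_batch(real_images, epoch):
    """Downsample real images to the resolution the GAN is currently trained at."""
    res = current_resolution(epoch)
    return F.interpolate(real_images, size=(res, res),
                         mode='bilinear', align_corners=False)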
Why It Works:
  • At low resolution, learning general structures
  • Gradually adding details
  • Mode Collapse happens less at low resolution
Big Success: StyleGAN could create photorealistic images with this technique

Catastrophe 2: Catastrophic Forgetting - The Disaster of Forgetting

Definition and Importance

Catastrophic Forgetting is one of the biggest challenges in continual learning:
Definition: When a model learns a new task, it catastrophically forgets its previous knowledge.
Human Example:
  • You're fluent in English
  • You start learning French
  • After 6 months, when you want to speak English, you've forgotten!
  • This is very rare in humans, but very common in AI!

Why Does It Happen in Neural Networks?

Mathematical Reason - Stability-Plasticity Dilemma:
Neural network must have two contradictory properties:
  1. Stability: Preserve old knowledge
  2. Plasticity: Learn new knowledge
Problem: These two usually contradict each other!
Network weights: W

Task 1: W moves toward Task 1 optimum → W₁
Task 2: W₁ moves toward Task 2 optimum → W₂

But: W₂ might be very bad for Task 1!

Professional Solutions

1. Elastic Weight Consolidation (EWC)

Idea: Some weights are very important for previous task - don't let them change much!
Mathematical Formula:
Loss = Loss_task_new + λ Σ F_i (θ_i - θ*_i)²

F_i = Fisher Information Matrix (importance of weight i for previous task)
θ*_i = optimal weight for previous task
λ = protection amount (usually 1000-10000)
Implementation Code:
python
import torch
import torch.nn.functional as F

class EWC:
    def __init__(self, model, dataloader, lambda_=1000):
        self.model = model
        self.lambda_ = lambda_
        self.fisher = {}
        self.optimal_params = {}
        # Compute Fisher Information on the current (old) task
        self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader):
        """
        Estimate the importance of each weight for the current task.
        """
        self.model.eval()
        for name, param in self.model.named_parameters():
            self.fisher[name] = torch.zeros_like(param)
            self.optimal_params[name] = param.data.clone()

        for data, target in dataloader:
            self.model.zero_grad()
            output = self.model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            # Fisher ≈ squared gradient
            for name, param in self.model.named_parameters():
                self.fisher[name] += param.grad.data ** 2

        # Normalize by the number of batches
        for name in self.fisher:
            self.fisher[name] /= len(dataloader)

    def penalty(self):
        """
        Penalty for moving weights that were important for the previous task.
        """
        loss = 0
        for name, param in self.model.named_parameters():
            fisher = self.fisher[name]
            optimal = self.optimal_params[name]
            loss += (fisher * (param - optimal) ** 2).sum()
        return self.lambda_ * loss
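In the training loop, the penalty is simply added to the new task's loss. A minimal usage sketch (model, optimizer, task_a_loader, and task_b_loader are assumed to exist):
python
# After finishing Task A, build the EWC object on Task A's data
ewc = EWC(model, task_a_loader, lambda_=1000)

# Then train on Task B with the extra penalty term
for data, target in task_b_loader:
    optimizer.zero_grad()
    loss = F.cross_entropy(model(data), target) + ewc.penalty()
    loss.backward()
    optimizer.step()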
Real Results:
  • In MNIST split: without EWC → 70% Forgetting
  • With EWC → only 15% Forgetting!

2. Progressive Neural Networks

Idea: For each new task, add a new column to the network!
python
import torch.nn as nn

class ProgressiveNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.columns = nn.ModuleList()              # one column per task
        self.lateral_connections = nn.ModuleList()

    def add_task(self, input_size, hidden_size, output_size):
        """
        Add a new column for a new task.
        """
        new_column = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

        # Lateral connections from all previous columns
        if len(self.columns) > 0:
            lateral = nn.ModuleList([
                nn.Linear(hidden_size, hidden_size)
                for _ in range(len(self.columns))
            ])
            self.lateral_connections.append(lateral)

        self.columns.append(new_column)

        # Freeze all previous columns
        for i in range(len(self.columns) - 1):
            for param in self.columns[i].parameters():
                param.requires_grad = False
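The class above only builds and freezes the columns. One possible forward pass, summing lateral contributions from the frozen earlier columns, could look like the sketch below; it assumes all columns share the same input and hidden sizes and is not the full architecture from the original paper:
python
def forward(self, x, task_id):
    """Sketch: hidden activation of the current column plus lateral inputs."""
    column = self.columns[task_id]
    h = column[1](column[0](x))                    # Linear + ReLU of the current column

    if task_id > 0:
        laterals = self.lateral_connections[task_id - 1]
        for prev_col, lat in zip(self.columns[:task_id], laterals):
            prev_h = prev_col[1](prev_col[0](x))   # frozen previous column
            h = h + lat(prev_h)                    # lateral connection

    return column[2](h)                            # output head of the current column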
Advantages:
  • No forgetting! Because previous weights are frozen
  • Each task can use knowledge from previous ones
Disadvantages:
  • Network becomes very large (for 10 tasks, 10 times!)
  • Impractical for many tasks

3. Gradient Episodic Memory (GEM)

Idea: Keep a few samples from previous tasks, and while training on a new task, make sure the update doesn't increase the loss on them - in other words, the new gradient must not point against the previous tasks' gradients!
python
import torch
import torch.nn.functional as F

class GEM:
    def __init__(self, model, memory_size_per_task=100):
        self.model = model
        self.memory = {}   # {task_id: (data, labels)}
        self.memory_size = memory_size_per_task

    def store_samples(self, task_id, dataloader):
        """
        Store a few representative samples from a task.
        """
        data_list, label_list = [], []
        for data, labels in dataloader:
            data_list.append(data)
            label_list.append(labels)
            if len(data_list) * data.size(0) >= self.memory_size:
                break
        self.memory[task_id] = (
            torch.cat(data_list)[:self.memory_size],
            torch.cat(label_list)[:self.memory_size]
        )

    def compute_gradient(self, task_id):
        """
        Gradient of the loss on the stored samples of a previous task.
        """
        data, labels = self.memory[task_id]
        self.model.zero_grad()
        loss = F.cross_entropy(self.model(data), labels)
        loss.backward()
        return [p.grad.data.clone() for p in self.model.parameters()]

    def project_gradient(self, current_grad):
        """
        If the new gradient conflicts with a previous task's gradient, project it.
        """
        for task_id in self.memory.keys():
            mem_grad = self.compute_gradient(task_id)
            # Dot product between the new gradient and the memory gradient
            dot = sum((g1 * g2).sum() for g1, g2 in zip(current_grad, mem_grad))
            # A negative dot product means the update damages the previous task
            if dot < 0:
                # Project the gradient away from the conflicting direction
                mem_norm = sum((g ** 2).sum() for g in mem_grad)
                for i, (g, m) in enumerate(zip(current_grad, mem_grad)):
                    current_grad[i] = g - (dot / mem_norm) * m
        return current_grad
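A sketch of how the projection fits into training on a new task: extract the gradients, project them, and write them back before the optimizer step (model, optimizer, task_a_loader, and task_b_loader are assumed):
python
gem = GEM(model)
gem.store_samples(task_id=0, dataloader=task_a_loader)

for data, target in task_b_loader:
    optimizer.zero_grad()
    F.cross_entropy(model(data), target).backward()

    # Extract, project, and write back the gradients
    grads = [p.grad.data.clone() for p in model.parameters()]
    grads = gem.project_gradient(grads)
    for p, g in zip(model.parameters(), grads):
        p.grad.data.copy_(g)

    optimizer.step()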
Advantages:
  • Very effective! Forgetting almost zero
  • Requires limited memory (just a few samples from each task)
Real Application: Used in lifelong learning systems

Catastrophe 3: Gradient Explosion - When Numbers Explode

Definition and Symptoms

Gradient Explosion is one of the scariest moments in neural network training:
Symptoms:
  • Loss suddenly reaches NaN or Inf
  • Weights become very large numbers (10¹⁰ or more!)
  • Model completely breaks in a few iterations
  • Needs complete restart
Why Does It Happen?
In recurrent neural networks and deep networks, gradient backprops through multiple layers:
gradient_layer_1 = gradient_output × W_n × W_(n-1) × ... × W_2 × W_1

If each W > 1:
gradient becomes very large (e.g., 1.1^100 = 13780)

If each W < 1:
gradient becomes very small (e.g., 0.9^100 = 0.0000266)
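The numbers above are easy to reproduce; here is a tiny sketch that multiplies a single gradient value by the same per-layer factor 100 times:
python
grad = 1.0
for _ in range(100):
    grad *= 1.1          # every layer scales the gradient by 1.1
print(grad)              # ≈ 13780.6  → explosion

grad = 1.0
for _ in range(100):
    grad *= 0.9          # every layer scales the gradient by 0.9
print(grad)              # ≈ 2.66e-05 → vanishing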

Principled Solutions

1. Gradient Clipping (Most Powerful Solution)

python
# Method 1: clip by norm (recommended)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Method 2: clip by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

# Complete usage inside the training loop
for epoch in range(num_epochs):
    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # 🔧 This line prevents the catastrophe!
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
How It Works:
python
import math

# Suppose the gradients are:
gradients = [10.0, 50.0, 100.0, 5.0]
norm = math.sqrt(sum(g ** 2 for g in gradients))    # sqrt(10² + 50² + 100² + 5²) ≈ 112.36

max_norm = 1.0
# If norm > max_norm, rescale:
scale = max_norm / norm                             # 1.0 / 112.36 ≈ 0.0089

# New gradients:
clipped_gradients = [g * scale for g in gradients]
# ≈ [0.089, 0.445, 0.890, 0.0445]
Result: Gradients are limited but their direction is preserved!

2. Proper Weight Initialization

python
import torch.nn as nn

def init_weights(m):
    if isinstance(m, nn.Linear):
        # Xavier initialization to prevent explosion
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LSTM):
        # Orthogonal initialization for the recurrent weights
        for name, param in m.named_parameters():
            if 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

model.apply(init_weights)

3. Monitoring and Early Detection

python
class GradientMonitor:
    def __init__(self, alert_threshold=10.0):
        self.alert_threshold = alert_threshold
        self.history = []

    def check_gradients(self, model):
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        self.history.append(total_norm)
        if total_norm > self.alert_threshold:
            print(f"⚠️ WARNING: Gradient norm = {total_norm:.2f}")
            return True
        return False

Catastrophe 4: Gradient Vanishing - Silence of Deep Layers

Definition and Impact

Gradient Vanishing is the opposite of Explosion:
  • Gradients approach zero
  • Early layers of network don't learn
  • Model stays shallow even if it's deep!
Numerical Example:
Suppose network has 100 layers
Each layer: activation = sigmoid(Wx + b)

gradient at layer 100 = 1.0
gradient at layer 50 = 0.01
gradient at layer 10 = 0.0000001 ← almost zero!
gradient at layer 1 = 10^-20 ← completely zero!
Result: Layers 1 to 10 don't learn at all!
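You can watch this happen with a few lines of PyTorch: stack many sigmoid layers and compare the gradient norm of the first and last layers. A minimal sketch (the exact numbers depend on the random initialization):
python
import torch
import torch.nn as nn

# 50 Linear + Sigmoid layers, 64 units each
layers = []
for _ in range(50):
    layers += [nn.Linear(64, 64), nn.Sigmoid()]
model = nn.Sequential(*layers)

x = torch.randn(8, 64)
model(x).sum().backward()

print(f"last-layer grad norm:  {model[-2].weight.grad.norm().item():.2e}")
print(f"first-layer grad norm: {model[0].weight.grad.norm().item():.2e}")  # many orders of magnitude smaller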

Effective Solutions

1. Using ReLU and Its Variants

python
import torch.nn as nn

# ❌ Bad: Sigmoid (derivative at most 0.25)
activation = nn.Sigmoid()

# ✅ Good: ReLU (derivative 1 for x > 0)
activation = nn.ReLU()

# ✅ Better: Leaky ReLU (non-zero derivative for all x)
activation = nn.LeakyReLU(negative_slope=0.01)

# ✅ Best for Transformers: GELU
activation = nn.GELU()

2. Residual Connections (Skip Connections)

python
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layer1 = nn.Linear(dim, dim)
        self.layer2 = nn.Linear(dim, dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        residual = x                    # save the input
        out = self.layer1(x)
        out = self.activation(out)
        out = self.layer2(out)
        # Add the residual
        out = out + residual            # 🔧 This line saves the gradient flow!
        out = self.activation(out)
        return out
Why It Works:
Without skip connection:
gradient_layer1 = gradient × W_100 × W_99 × ... × W_2 → becomes zero

With skip connection:
gradient_layer1 = gradient × (1 + W_100 × W_99 × ... × W_2) → at least original gradient remains!
Big Success: ResNet could have 1000+ layers without vanishing gradient!

Catastrophe 5: Training Instability - Deadly Oscillations

Symptoms and Detection

Training Instability:
  • Loss irregularly goes up and down
  • Some epochs model gets worse
  • No smooth convergence
python
# Example oscillating loss:
Epoch 1: Loss = 2.5
Epoch 2: Loss = 2.1
Epoch 3: Loss = 1.8
Epoch 4: Loss = 3.2 ← 💥 Why worse?
Epoch 5: Loss = 1.5
Epoch 6: Loss = 2.9 ← 💥 Again!
Epoch 7: Loss = 1.2

Advanced Solutions

1. Learning Rate Warmup and Decay

python
class WarmupScheduler:
    def __init__(self, optimizer, warmup_steps, total_steps):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.step_count = 0
        self.base_lr = optimizer.param_groups[0]['lr']

    def step(self):
        self.step_count += 1
        if self.step_count < self.warmup_steps:
            # Warmup: gradual increase
            lr = self.base_lr * (self.step_count / self.warmup_steps)
        else:
            # Decay: gradual (linear) decrease
            progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr = self.base_lr * (1 - progress)

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
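A usage sketch: the scheduler is stepped once per batch, not once per epoch (model, optimizer, criterion, dataloader, and num_epochs are assumed to exist):
python
scheduler = WarmupScheduler(optimizer,
                            warmup_steps=1000,
                            total_steps=num_epochs * len(dataloader))

for epoch in range(num_epochs):
    for data, target in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        scheduler.step()   # adjust the learning rate after every optimizer step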

2. Exponential Moving Average (EMA) of Weights

python
class EMA:
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        # Store a copy of the current weights
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """
        Update the EMA (shadow) weights after each optimizer step.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()
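The class only maintains the shadow copy; to benefit from it, the shadow weights are copied into the model before evaluation. A usage sketch (model, optimizer, criterion, and dataloader are assumed):
python
import torch

ema = EMA(model, decay=0.999)

for data, target in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
    ema.update()          # smooth the weights after every update

# At evaluation time, copy the smoothed (shadow) weights into the model
with torch.no_grad():
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.copy_(ema.shadow[name])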
Advantage: EMA weights are usually more stable and accurate!

Comprehensive Checklist: Preventing Optimization Catastrophes

Before Training ✅

Architecture:
  • Use Residual Connections in deep networks
  • Batch/Layer Normalization in appropriate places
  • ReLU/GELU instead of Sigmoid/Tanh
  • LSTM/GRU instead of simple RNN
Initialization:
  • He initialization for ReLU
  • Xavier/Glorot for Tanh
  • Orthogonal initialization for RNN
Hyperparameters:
  • Appropriate learning rate (usually 1e-4 to 1e-3 for Adam)
  • Reasonable batch size (32-256)
  • Gradient clipping enabled (max_norm=1.0)

During Training 📊

Monitoring:
  • Loss on train and validation
  • Gradient norm
  • Current learning rate
  • Model weights (check for NaN/Inf - see the sketch below)
Early Warning Signs:
  • Loss → NaN: Gradient explosion
  • Oscillating loss: LR too high or batch size too small
  • Validation worse than train: Overfitting
  • Identical outputs: Mode collapse
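For the NaN/Inf check mentioned in the monitoring list, here is a minimal sketch; training_is_healthy is just an illustrative helper, not a library function:
python
import math
import torch

def training_is_healthy(model, loss):
    """Return False if the loss or any weight has become NaN/Inf."""
    if not math.isfinite(loss.item()):
        return False
    for param in model.parameters():
        if not torch.isfinite(param).all():
            return False
    return True

# Inside the training loop, right after loss.backward():
# if not training_is_healthy(model, loss):
#     stop, restore the last checkpoint, and lower the learning rate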

Emergency Actions 🚨

When Loss → NaN:
  1. Immediately stop
  2. Restore from previous checkpoint
  3. Reduce learning rate 10x
  4. Adjust gradient clipping (smaller max_norm)
  5. Restart
When Mode Collapse:
  1. Reduce learning rate
  2. Add more noise to input
  3. Strengthen Discriminator (in GANs)
  4. Use Minibatch Discrimination
When Catastrophic Forgetting:
  1. Implement EWC or GEM
  2. Reduce learning rate
  3. Use Knowledge Distillation

Conclusion: The Art of Survival in the Optimization World

Optimization catastrophes are an inseparable part of deep learning. In previous articles, we examined three main optimization challenges (local optima, saddle points, plateaus), and in this article, we covered even more critical catastrophes:
Key Points:
Mode Collapse: Don't sacrifice diversity for safety - use Minibatch Discrimination and Spectral Normalization
Catastrophic Forgetting: Protect past memory - EWC, Progressive Networks, or GEM
Gradient Explosion: Always have gradient clipping - it's vital insurance
Gradient Vanishing: ReLU + Residual Connections = success recipe
Training Instability: Warmup + Decay + EMA = stability
Key Industry Lessons:
  1. Google: Spectral Normalization in BigGAN → greatly reduced Mode Collapse
  2. Microsoft: Residual Connections in ResNet → overcame Vanishing Gradient
  3. OpenAI: Gradient Clipping in GPT → prevented Gradient Explosion
  4. DeepMind: EWC in game AI → reduced Catastrophic Forgetting
Final Recommendation:
In the real world, success belongs to those who:
  • Anticipate: See catastrophes before they happen
  • Are Prepared: Have backup and recovery systems
  • Learn: Take lessons from past failures
  • Monitor: Track everything
Remember: The best way to prevent catastrophe is to be prepared for it!